From 988404cd3ba142bc7a56e641ac675279d8a6d995 Mon Sep 17 00:00:00 2001 From: Leonardo macOS Date: Wed, 20 Nov 2024 16:04:14 +0100 Subject: [PATCH 1/2] trying to edit formatting with autopep8 to conform to pep8 --- poetry.lock | 27 ++- pyproject.toml | 1 + src/malco/analysis/check_lens.py | 18 +- .../analysis/count_grounding_failures.py | 6 +- .../analysis/count_translated_prompts.py | 6 +- src/malco/analysis/eval_diagnose_category.py | 68 +++--- src/malco/analysis/monarchKG_classifier.py | 10 +- src/malco/analysis/test_curate_script.py | 22 +- .../time_ic/disease_avail_knowledge.py | 225 +++++++++--------- .../analysis/time_ic/logit_predict_llm.py | 16 +- .../analysis/time_ic/rank_date_exploratory.py | 108 ++++----- src/malco/post_process/df_save_util.py | 3 +- src/malco/post_process/extended_scoring.py | 40 ++-- src/malco/post_process/generate_plots.py | 20 +- src/malco/post_process/mondo_score_utils.py | 21 +- src/malco/post_process/post_process.py | 6 +- .../post_process_results_format.py | 20 +- src/malco/post_process/ranking_utils.py | 83 ++++--- src/malco/prepare/setup_phenopackets.py | 7 +- src/malco/prepare/setup_run_pars.py | 15 +- src/malco/run/run.py | 22 +- src/malco/run/search_ppkts.py | 13 +- src/malco/runner.py | 29 +-- 23 files changed, 415 insertions(+), 371 deletions(-) diff --git a/poetry.lock b/poetry.lock index d8b5f277..55716e7f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -336,8 +336,8 @@ files = [ lazy-object-proxy = ">=1.4.0" typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} wrapt = [ - {version = ">=1.14,<2", markers = "python_version >= \"3.11\""}, {version = ">=1.11,<2", markers = "python_version < \"3.11\""}, + {version = ">=1.14,<2", markers = "python_version >= \"3.11\""}, ] [[package]] @@ -370,6 +370,21 @@ docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphi tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] +[[package]] +name = "autopep8" +version = "2.3.1" +description = "A tool that automatically formats Python code to conform to the PEP 8 style guide" +optional = false +python-versions = ">=3.8" +files = [ + {file = "autopep8-2.3.1-py2.py3-none-any.whl", hash = "sha256:a203fe0fcad7939987422140ab17a930f684763bf7335bdb6709991dd7ef6c2d"}, + {file = "autopep8-2.3.1.tar.gz", hash = "sha256:8d6c87eba648fdcfc83e29b788910b8643171c395d9c4bcf115ece035b9c9dda"}, +] + +[package.dependencies] +pycodestyle = ">=2.12.0" +tomli = {version = "*", markers = "python_version < \"3.11\""} + [[package]] name = "babel" version = "2.16.0" @@ -2132,8 +2147,8 @@ files = [ google-auth = ">=2.14.1,<3.0.dev0" googleapis-common-protos = ">=1.56.2,<2.0.dev0" proto-plus = [ - {version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""}, {version = ">=1.22.3,<2.0.0dev", markers = "python_version < \"3.13\""}, + {version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""}, ] protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" requests = ">=2.18.0,<3.0.0.dev0" @@ -5311,8 +5326,8 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.22.4", markers = "python_version < \"3.11\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.23.2", 
markers = "python_version == \"3.11\""}, ] python-dateutil = ">=2.8.2" @@ -6024,8 +6039,8 @@ files = [ annotated-types = ">=0.6.0" pydantic-core = "2.23.4" typing-extensions = [ - {version = ">=4.12.2", markers = "python_version >= \"3.13\""}, {version = ">=4.6.1", markers = "python_version < \"3.13\""}, + {version = ">=4.12.2", markers = "python_version >= \"3.13\""}, ] [package.extras] @@ -6233,8 +6248,8 @@ files = [ astroid = ">=2.15.8,<=2.17.0-dev0" colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} dill = [ - {version = ">=0.3.6", markers = "python_version >= \"3.11\""}, {version = ">=0.2", markers = "python_version < \"3.11\""}, + {version = ">=0.3.6", markers = "python_version >= \"3.11\""}, ] isort = ">=4.2.5,<6" mccabe = ">=0.6,<0.8" @@ -9655,4 +9670,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "aa2ebc7bfc0815f46c4de87f1930ae0b9213c8432d53c4a3185d8073e7612552" +content-hash = "718e461c373c7bb63a0d46b2389e15b20cda1d8889898b6c9a0c3729a4a9aab2" diff --git a/pyproject.toml b/pyproject.toml index bb66ea05..effb49a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ setuptools = "^69.5.1" shelved-cache = "^0.3.1" curategpt = "^0.2.2" psutil = "^6.1.0" +autopep8 = "^2.3.1" [tool.poetry.plugins."pheval.plugins"] template = "malco.runner:MalcoRunner" diff --git a/src/malco/analysis/check_lens.py b/src/malco/analysis/check_lens.py index bc8cfdc9..ca47a492 100644 --- a/src/malco/analysis/check_lens.py +++ b/src/malco/analysis/check_lens.py @@ -1,12 +1,13 @@ -import pandas as pd +import pandas as pd from typing import List import pandas as pd import yaml -#from malco.post_process.post_process_results_format import read_raw_result_yaml +# from malco.post_process.post_process_results_format import read_raw_result_yaml from pathlib import Path import sys + def read_raw_result_yaml(raw_result_path: Path) -> List[dict]: """ Read the raw result file. @@ -18,16 +19,17 @@ def read_raw_result_yaml(raw_result_path: Path) -> List[dict]: dict: Contents of the raw result file. 
""" with open(raw_result_path, 'r') as raw_result: - return list(yaml.safe_load_all(raw_result.read().replace(u'\x04',''))) # Load and convert to list + return list(yaml.safe_load_all(raw_result.read().replace(u'\x04', ''))) # Load and convert to list + unique_ppkts = {} -#model=str(sys.argv[1]) +# model=str(sys.argv[1]) models = ["gpt-3.5-turbo", "gpt-4-turbo", "gpt-4", "gpt-4o"] for model in models: - print("==="*10, "\nEvaluating now: ", model, "\n"+"==="*10) - + print("===" * 10, "\nEvaluating now: ", model, "\n" + "===" * 10) + yamlfile = f"out_openAI_models/raw_results/multimodel/{model}/results.yaml" - all_results=read_raw_result_yaml(yamlfile) + all_results = read_raw_result_yaml(yamlfile) counter = 0 labelvec = [] @@ -64,4 +66,4 @@ def read_raw_result_yaml(raw_result_path: Path) -> List[dict]: if i in unique_ppkts["gpt-3.5-turbo"]: continue else: - print(f"Missing ppkt in gpt-3.5-turbo is:\t", i) \ No newline at end of file + print(f"Missing ppkt in gpt-3.5-turbo is:\t", i) diff --git a/src/malco/analysis/count_grounding_failures.py b/src/malco/analysis/count_grounding_failures.py index c3e2b06b..262d258b 100644 --- a/src/malco/analysis/count_grounding_failures.py +++ b/src/malco/analysis/count_grounding_failures.py @@ -4,8 +4,8 @@ mfile = "../outputdir_all_2024_07_04/en/results.tsv" df = pd.read_csv( - mfile, sep="\t" #, header=None, names=["description", "term", "label"] - ) + mfile, sep="\t" # , header=None, names=["description", "term", "label"] +) terms = df["term"] counter = 0 @@ -17,4 +17,4 @@ counter += 1 print(counter) -print(grounded) \ No newline at end of file +print(grounded) diff --git a/src/malco/analysis/count_translated_prompts.py b/src/malco/analysis/count_translated_prompts.py index 869870b6..f6d778ca 100644 --- a/src/malco/analysis/count_translated_prompts.py +++ b/src/malco/analysis/count_translated_prompts.py @@ -14,9 +14,9 @@ promptfiles = {} for lang in langs: promptfiles[lang] = [] - for (dirpath, dirnames, filenames) in os.walk(fp+lang): + for (dirpath, dirnames, filenames) in os.walk(fp + lang): for fn in filenames: - fn = fn[0:-14] # TODO may be problematic if there are 2 "_" before "{langcode}-" + fn = fn[0:-14] # TODO may be problematic if there are 2 "_" before "{langcode}-" # Maybe something along the lines of other script disease_avail_knowledge.py # ppkt_label = ppkt[0].replace('_en-prompt.txt','') promptfiles[lang].append(fn) @@ -34,4 +34,4 @@ intersection = enset & esset & deset & itset & nlset & zhset & trset -print("Common ppkts are: ", len(intersection)) \ No newline at end of file +print("Common ppkts are: ", len(intersection)) diff --git a/src/malco/analysis/eval_diagnose_category.py b/src/malco/analysis/eval_diagnose_category.py index 3a8c1b9e..ba7851bb 100644 --- a/src/malco/analysis/eval_diagnose_category.py +++ b/src/malco/analysis/eval_diagnose_category.py @@ -17,9 +17,11 @@ outpath = "disease_groups/" pc_cache_file = outpath + "diagnoses_hereditary_cond" -pc = PersistentCache(LRUCache, pc_cache_file, maxsize=4096) - +pc = PersistentCache(LRUCache, pc_cache_file, maxsize=4096) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + def mondo_adapter() -> OboGraphInterface: """ Get the adapter for the MONDO ontology. @@ -27,10 +29,12 @@ def mondo_adapter() -> OboGraphInterface: Returns: Adapter: The adapter. 
""" - return get_adapter("sqlite:obo:mondo") + return get_adapter("sqlite:obo:mondo") # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -def mondo_mapping(term, adapter): + + +def mondo_mapping(term, adapter): mondos = [] for m in adapter.sssom_mappings([term], source="OMIM"): if m.predicate_id == "skos:exactMatch": @@ -38,6 +42,8 @@ def mondo_mapping(term, adapter): return mondos # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + @cached(pc, key=lambda omim_term, disease_categories, mondo: hashkey(omim_term)) def find_category(omim_term, disease_categories, mondo): if not isinstance(mondo, MappingProviderInterface): @@ -47,49 +53,51 @@ def find_category(omim_term, disease_categories, mondo): if not mondo_term: print(omim_term) return None - - ancestor_list = mondo.ancestors(mondo_term, # only IS_A->same result - predicates=[IS_A, PART_OF]) #, reflexive=True) # method=GraphTraversalMethod.ENTAILMENT - + + ancestor_list = mondo.ancestors(mondo_term, # only IS_A->same result + # , reflexive=True) # method=GraphTraversalMethod.ENTAILMENT + predicates=[IS_A, PART_OF]) + for mondo_ancestor in ancestor_list: if mondo_ancestor in disease_categories: - #TODO IMPORTANT! Like this, at the first match the function exits!! - return mondo_ancestor # This should be smt like MONDO:0045024 (cancer or benign tumor) - + # TODO IMPORTANT! Like this, at the first match the function exits!! + return mondo_ancestor # This should be smt like MONDO:0045024 (cancer or benign tumor) + print("Special issue following: ") print(omim_term) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -#===================================================== +# ===================================================== # Script starts here. Name model: -model=str(sys.argv[1]) -#===================================================== +model = str(sys.argv[1]) +# ===================================================== # Find 42 diseases categories mondo = mondo_adapter() -disease_categories = mondo.relationships(objects = ["MONDO:0003847"], # hereditary diseases +disease_categories = mondo.relationships(objects=["MONDO:0003847"], # hereditary diseases predicates=[IS_A, PART_OF]) # only IS_A->same result -#disease_categories = mondo.relationships(objects = ["MONDO:0700096"], # only IS_A->same result +# disease_categories = mondo.relationships(objects = ["MONDO:0700096"], # only IS_A->same result # predicates=[IS_A, PART_OF]) # make df contingency table with header=diseases_category, correct, incorrect and initialize all to 0. 
-header = ["label","correct", "incorrect"] +header = ["label", "correct", "incorrect"] dc_list = [i[0] for i in list(disease_categories)] contingency_table = pd.DataFrame(0, index=dc_list, columns=header) for j in dc_list: - contingency_table.loc[j,"label"] = mondo.label(j) + contingency_table.loc[j, "label"] = mondo.label(j) breakpoint() filename = f"out_openAI_models/multimodel/{model}/full_df_results.tsv" # label term score rank correct_term is_correct reciprocal_rank # PMID_35962790_Family_B_Individual_3__II_6__en-prompt.txt MONDO:0008675 1.0 1.0 OMIM:620545 False 0.0 df = pd.read_csv( - filename, sep="\t" - ) + filename, sep="\t" +) -ppkts = df.groupby("label")[["term", "correct_term", "is_correct"]] -count_fails=0 +ppkts = df.groupby("label")[["term", "correct_term", "is_correct"]] +count_fails = 0 omim_wo_match = {} for ppkt in ppkts: @@ -97,11 +105,11 @@ def find_category(omim_term, disease_categories, mondo): category_index = find_category(ppkt[1].iloc[0]["correct_term"], dc_list, mondo) if not category_index: count_fails += 1 - #print(f"Category index for {ppkt[1].iloc[0]["correct_term"]} ") + # print(f"Category index for {ppkt[1].iloc[0]["correct_term"]} ") omim_wo_match[ppkt[0]] = ppkt[1].iloc[0]["correct_term"] continue - #cat_ind = find_cat_index(category) - # is there a true? ppkt is tuple ("filename"/"label"/what has been used for grouping, dataframe) --> ppkt[1] is a dataframe + # cat_ind = find_cat_index(category) + # is there a true? ppkt is tuple ("filename"/"label"/what has been used for grouping, dataframe) --> ppkt[1] is a dataframe if not any(ppkt[1]["is_correct"]): # no --> increase incorrect try: @@ -117,12 +125,12 @@ def find_category(omim_term, disease_categories, mondo): print("issue here") continue -print("\n\n", "==="*15,"\n") -print(f"For whatever reason find_category() returned None in {count_fails} cases, wich follow:\n") # print to file! -#print(contingency_table) +print("\n\n", "===" * 15, "\n") +print(f"For whatever reason find_category() returned None in {count_fails} cases, wich follow:\n") # print to file! +# print(contingency_table) print("\n\nOf which the following are unique OMIMs:\n", set(list(omim_wo_match.values()))) -#print(omim_wo_match, "\n\nOf which the following are unique OMIMs:\n", set(list(omim_wo_match.values()))) +# print(omim_wo_match, "\n\nOf which the following are unique OMIMs:\n", set(list(omim_wo_match.values()))) cont_table_file = f"{outpath}{model}.tsv" # Will overwrite -#contingency_table.to_csv(cont_table_file, sep='\t') \ No newline at end of file +# contingency_table.to_csv(cont_table_file, sep='\t') diff --git a/src/malco/analysis/monarchKG_classifier.py b/src/malco/analysis/monarchKG_classifier.py index c628efa6..7fb5a03e 100644 --- a/src/malco/analysis/monarchKG_classifier.py +++ b/src/malco/analysis/monarchKG_classifier.py @@ -1,8 +1,8 @@ -# Monarch KG -# Idea: for each ppkt, make contingency table NF/F and in box write +# Monarch KG +# Idea: for each ppkt, make contingency table NF/F and in box write # average number of connections. Thus 7 K of entries with num_edges, y=0,1 # Think about mouse weight and obesity as an example. -import numpy +import numpy from neo4j import GraphDatabase # Connect to the Neo4j database @@ -11,7 +11,7 @@ # From results take ppkts ground truth correct result and 0,1 # Map OMIM to MONDO -# +# # Need to decide what to project out. Maybe simply all edges connected to the MONDO terms I have. 
# At this point for each MONDO term I have count the edges # Define the Cypher query @@ -26,4 +26,4 @@ with driver.session() as session: results = session.run(query) for record in results: - data.append(record) \ No newline at end of file + data.append(record) diff --git a/src/malco/analysis/test_curate_script.py b/src/malco/analysis/test_curate_script.py index a83f20de..a5e98db0 100644 --- a/src/malco/analysis/test_curate_script.py +++ b/src/malco/analysis/test_curate_script.py @@ -1,9 +1,9 @@ -import yaml +import yaml from pathlib import Path from typing import List from malco.post_process.extended_scoring import clean_service_answer, ground_diagnosis_text_to_mondo from oaklib import get_adapter - + def read_raw_result_yaml(raw_result_path: Path) -> List[dict]: """ @@ -16,7 +16,7 @@ def read_raw_result_yaml(raw_result_path: Path) -> List[dict]: dict: Contents of the raw result file. """ with open(raw_result_path, 'r') as raw_result: - return list(yaml.safe_load_all(raw_result.read().replace(u'\x04',''))) # Load and convert to list + return list(yaml.safe_load_all(raw_result.read().replace(u'\x04', ''))) # Load and convert to list annotator = get_adapter("sqlite:obo:mondo") @@ -29,25 +29,25 @@ def read_raw_result_yaml(raw_result_path: Path) -> List[dict]: j = 0 for this_result in all_results: extracted_object = this_result.get("extracted_object") - if extracted_object: # Necessary because this is how I keep track of multiple runs + if extracted_object: # Necessary because this is how I keep track of multiple runs ontogpt_text = this_result.get("input_text") # its a single string, should be parseable through curategpt cleaned_text = clean_service_answer(ontogpt_text) assert cleaned_text != "", "Cleaning failed: the cleaned text is empty." result = ground_diagnosis_text_to_mondo(annotator, cleaned_text, verbose=False) - label = extracted_object.get('label') # pubmed id + label = extracted_object.get('label') # pubmed id # terms will now ONLY contain MONDO IDs OR 'N/A'. 
The latter should be dealt with downstream - terms = [i[1][0][0] for i in result] - #terms = extracted_object.get('terms') # list of strings, the mondo id or description + terms = [i[1][0][0] for i in result] + # terms = extracted_object.get('terms') # list of strings, the mondo id or description if terms: - # Note, the if allows for rerunning ppkts that failed due to connection issues - # We can have multiple identical ppkts/prompts in results.yaml as long as only one has a terms field + # Note, the if allows for rerunning ppkts that failed due to connection issues + # We can have multiple identical ppkts/prompts in results.yaml as long as only one has a terms field num_terms = len(terms) score = [1 / (i + 1) for i in range(num_terms)] # score is reciprocal rank - rank_list = [ i+1 for i in range(num_terms)] + rank_list = [i + 1 for i in range(num_terms)] for term, scr, rank in zip(terms, score, rank_list): data.append({'label': label, 'term': term, 'score': scr, 'rank': rank}) - if j>20: + if j > 20: break j += 1 diff --git a/src/malco/analysis/time_ic/disease_avail_knowledge.py b/src/malco/analysis/time_ic/disease_avail_knowledge.py index 0eee3a70..04ba475b 100644 --- a/src/malco/analysis/time_ic/disease_avail_knowledge.py +++ b/src/malco/analysis/time_ic/disease_avail_knowledge.py @@ -14,6 +14,10 @@ `runoak -g hpoa_file -G hpoa -i hpo_file information-content -p i --use-associations .all` """ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +from scipy.stats import mannwhitneyu +from scipy.stats import ttest_ind +from scipy.stats import kstest +from scipy.stats import chi2_contingency import sys import os import pandas as pd @@ -32,10 +36,10 @@ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ model = str(sys.argv[1]) try: - make_plots = str(sys.argv[2])=="plot" + make_plots = str(sys.argv[2]) == "plot" except IndexError: - make_plots = False - print("\nYou can pass \"plot\" as a second CLI argument and this will generate nice plots!\n") + make_plots = False + print("\nYou can pass \"plot\" as a second CLI argument and this will generate nice plots!\n") # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # PATHS: # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -43,7 +47,7 @@ hpoa_file_path = data_dir / "phenotype.hpoa" ic_file = data_dir / "ic_hpoa.txt" original_ppkt_dir = data_dir / "ppkt-store-0.1.19" -outdir = Path.cwd() / "src" / "malco" / "analysis" / "time_ic" +outdir = Path.cwd() / "src" / "malco" / "analysis" / "time_ic" ranking_results_filename = f"out_openAI_models/multimodel/{model}/full_df_results.tsv" # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -53,8 +57,8 @@ rank_date_dict = pickle.load(f) # import df of LLM results rank_results_df = pd.read_csv( - ranking_results_filename, sep="\t" - ) + ranking_results_filename, sep="\t" +) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -62,70 +66,69 @@ # Look for correlation in box plot of ppkts' rank vs time # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ dates = [] -dates_wo_none = [] # messy workaround +dates_wo_none = [] # messy workaround ranks = [] ranks_wo_none = [] for key, data in rank_date_dict.items(): - r = data[0] - d = dt.datetime.strptime(data[1], '%Y-%m-%d').date() - 
dates.append(d) - ranks.append(r) - if r is not None: - dates_wo_none.append(d) - ranks_wo_none.append(r) - - -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# Correlation? + r = data[0] + d = dt.datetime.strptime(data[1], '%Y-%m-%d').date() + dates.append(d) + ranks.append(r) + if r is not None: + dates_wo_none.append(d) + ranks_wo_none.append(r) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Correlation? years_only = [] -for i in range(len(dates)): - years_only.append(dates[i].year) +for i in range(len(dates)): + years_only.append(dates[i].year) years_only_wo_none = [] -for i in range(len(dates_wo_none)): - years_only_wo_none.append(dates[i].year) +for i in range(len(dates_wo_none)): + years_only_wo_none.append(dates[i].year) if make_plots: - sns.boxplot(x=years_only_wo_none, y=ranks_wo_none) - plt.xlabel("Year of HPOA annotation") - plt.ylabel("Rank") - plt.title("LLM performance uncorrelated with date of discovery") - plt.savefig(outdir / "boxplot_discovery_date.png") + sns.boxplot(x=years_only_wo_none, y=ranks_wo_none) + plt.xlabel("Year of HPOA annotation") + plt.ylabel("Rank") + plt.title("LLM performance uncorrelated with date of discovery") + plt.savefig(outdir / "boxplot_discovery_date.png") # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Statistical test, simplest idea: chi2 of contingency table with: # y<=2009 and y>2009 clmns and found vs not-found counts, one count per ppkt # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -cont_table = [[0, 0], [0, 0]] # contains counts +cont_table = [[0, 0], [0, 0]] # contains counts for i, d in enumerate(years_only): - if d < 2010: - if ranks[i] == None: - cont_table[0][1] += 1 - else: - cont_table[0][0] += 1 - else: - if ranks[i] == None: - cont_table[1][1] += 1 - else: - cont_table[1][0] += 1 - -df_contingency_table = pd.DataFrame(cont_table, - index=["found", "not_found"], + if d < 2010: + if ranks[i] == None: + cont_table[0][1] += 1 + else: + cont_table[0][0] += 1 + else: + if ranks[i] == None: + cont_table[1][1] += 1 + else: + cont_table[1][0] += 1 + +df_contingency_table = pd.DataFrame(cont_table, + index=["found", "not_found"], columns=["y<2010", "y>=2010"]) print(df_contingency_table) print("H0: no correlation between column 1 and 2:") -from scipy.stats import chi2_contingency -res = chi2_contingency(cont_table) +res = chi2_contingency(cont_table) print("Results from \u03c7\N{SUPERSCRIPT TWO} test on contingency table:\n", res) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# IC: For each phenpacket, list observed HPOs and compute average IC. Is it correlated with +# IC: For each phenpacket, list observed HPOs and compute average IC. Is it correlated with # success? I.e., start with f/nf, 1/0 on y-axis vs avg(IC) on x-axis # Import ppkts -ppkts = rank_results_df.groupby("label")[["term", "correct_term", "is_correct", "rank"]] +ppkts = rank_results_df.groupby("label")[["term", "correct_term", "is_correct", "rank"]] # Import IC-file as dict with open(ic_file) as f: @@ -136,62 +139,64 @@ ppkts_with_zero_hpos = [] ppkts_with_missing_hpos = [] -# Iterate over ppkts, which are json. +# Iterate over ppkts, which are json. 
for subdir, dirs, files in os.walk(original_ppkt_dir): - # For each ppkt - for filename in files: - if filename.endswith('.json'): - file_path = os.path.join(subdir, filename) - with open(file_path, mode="r", encoding="utf-8") as read_file: - ppkt = json.load(read_file) - ppkt_id = re.sub('[^\\w]', '_', ppkt['id']) - ic = 0 - num_hpos = 0 - # For each HPO - for i in ppkt['phenotypicFeatures']: + # For each ppkt + for filename in files: + if filename.endswith('.json'): + file_path = os.path.join(subdir, filename) + with open(file_path, mode="r", encoding="utf-8") as read_file: + ppkt = json.load(read_file) + ppkt_id = re.sub('[^\\w]', '_', ppkt['id']) + ic = 0 + num_hpos = 0 + # For each HPO + for i in ppkt['phenotypicFeatures']: + try: + if i["excluded"]: # skip excluded + continue + except KeyError: + pass + hpo = i["type"]["id"] + try: + ic += float(ic_dict[hpo]) + num_hpos += 1 + except KeyError as e: + missing_in_ic_dict.append(e.args[0]) + ppkts_with_missing_hpos.append(ppkt_id) + + # print(f"No entry for {e}.") + + # For now we are fine with average IC try: - if i["excluded"]: # skip excluded - continue - except KeyError: - pass - hpo = i["type"]["id"] - try: - ic += float(ic_dict[hpo]) - num_hpos += 1 - except KeyError as e: - missing_in_ic_dict.append(e.args[0]) - ppkts_with_missing_hpos.append(ppkt_id) - - #print(f"No entry for {e}.") - - # For now we are fine with average IC - try: - ppkt_ic[ppkt_id] = ic/num_hpos - # TODO max ic instead try - except ZeroDivisionError as e: - ppkts_with_zero_hpos.append(ppkt_id) - #print(f"No HPOs for {ppkt["id"]}.") - + ppkt_ic[ppkt_id] = ic / num_hpos + # TODO max ic instead try + except ZeroDivisionError as e: + ppkts_with_zero_hpos.append(ppkt_id) + # print(f"No HPOs for {ppkt["id"]}.") + missing_in_ic_dict_unique = set(missing_in_ic_dict) ppkts_with_missing_hpos = set(ppkts_with_missing_hpos) -print(f"\nNumber of (unique) HPOs without IC-value is {len(missing_in_ic_dict_unique)}.") # 65 -print(f"Number of ppkts with zero observed HPOs is {len(ppkts_with_zero_hpos)}. These are left out.") # 141 -#TODO check 141 -print(f"Number of ppkts where at least one HPO is missing its IC value is {len(ppkts_with_missing_hpos)}. These are left out from the average.\n") # 172 +print(f"\nNumber of (unique) HPOs without IC-value is {len(missing_in_ic_dict_unique)}.") # 65 +print(f"Number of ppkts with zero observed HPOs is {len(ppkts_with_zero_hpos)}. These are left out.") # 141 +# TODO check 141 +# 172 +print( + f"Number of ppkts where at least one HPO is missing its IC value is {len(ppkts_with_missing_hpos)}. 
These are left out from the average.\n") ppkt_ic_df = pd.DataFrame.from_dict(ppkt_ic, orient='index', columns=['avg(IC)']) -ppkt_ic_df['Diagnosed'] = 0 +ppkt_ic_df['Diagnosed'] = 0 for ppkt in ppkts: - if any(ppkt[1]["is_correct"]): - ppkt_label = ppkt[0].replace('_en-prompt.txt','') - if ppkt_label in ppkts_with_zero_hpos: - continue - ppkt_ic_df.loc[ppkt_label,'Diagnosed'] = 1 + if any(ppkt[1]["is_correct"]): + ppkt_label = ppkt[0].replace('_en-prompt.txt', '') + if ppkt_label in ppkts_with_zero_hpos: + continue + ppkt_ic_df.loc[ppkt_label, 'Diagnosed'] = 1 # xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx # See https://github.com/monarch-initiative/phenopacket-store/issues/157 -label_manual_removal = ["PMID_27764983_Family_1_individual__J", +label_manual_removal = ["PMID_27764983_Family_1_individual__J", "PMID_35991565_Family_I__3"] ppkt_ic_df = ppkt_ic_df.drop(label_manual_removal) # xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx @@ -205,39 +210,36 @@ # T-test: unlikely that the two samples are such due to sample bias. # Likely, there is a correlation between average IC and whether the case is being solved. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# One-sample Kolmogorov-Smirnov test: compares the underlying distribution F(x) of a sample +# One-sample Kolmogorov-Smirnov test: compares the underlying distribution F(x) of a sample # against a given distribution G(x), here the normal distribution -from scipy.stats import kstest -kolsmirnov_result = kstest(ppkt_ic_df['avg(IC)'],'norm') -print(kolsmirnov_result,"\n") -# When interpreting be careful, and I quote one answer from: +kolsmirnov_result = kstest(ppkt_ic_df['avg(IC)'], 'norm') +print(kolsmirnov_result, "\n") +# When interpreting be careful, and I quote one answer from: # https://stats.stackexchange.com/questions/2492/is-normality-testing-essentially-useless -# "The question normality tests answer: Is there convincing -# evidence of any deviation from the Gaussian ideal? -# With moderately large real data sets, the answer is almost always yes." +# "The question normality tests answer: Is there convincing +# evidence of any deviation from the Gaussian ideal? +# With moderately large real data sets, the answer is almost always yes." # --------------- t-test --------------- # TODO: regardless of above, our distributions are not normal. 
Discard t and add non-parametric U-test (Mann-Whitney) -from scipy.stats import ttest_ind -found_ic = list(ppkt_ic_df.loc[ppkt_ic_df['Diagnosed']>0, 'avg(IC)']) -not_found_ic = list(ppkt_ic_df.loc[ppkt_ic_df['Diagnosed']<1, 'avg(IC)']) +found_ic = list(ppkt_ic_df.loc[ppkt_ic_df['Diagnosed'] > 0, 'avg(IC)']) +not_found_ic = list(ppkt_ic_df.loc[ppkt_ic_df['Diagnosed'] < 1, 'avg(IC)']) tresult = ttest_ind(found_ic, not_found_ic, equal_var=False) -print("T-test result:\n", tresult,"\n") +print("T-test result:\n", tresult, "\n") # --------------- u-test --------------- -from scipy.stats import mannwhitneyu u_value, p_of_u = mannwhitneyu(found_ic, not_found_ic) -print(f"U-test, u_value={u_value} and its associated p_val={p_of_u}","\n") +print(f"U-test, u_value={u_value} and its associated p_val={p_of_u}", "\n") # --------------- plot --------------- if make_plots: - plt.hist(found_ic, bins=25, color='c', edgecolor='k', alpha=0.5, density=True) - plt.hist(not_found_ic, bins=25, color='r', edgecolor='k', alpha=0.5, density=True) - plt.xlabel("Average Information Content") - plt.ylabel("Counts") - plt.legend(['Successful Diagnosis', 'Unsuccessful Diagnosis']) - plt.savefig(outdir / "inf_content_histograms.png") + plt.hist(found_ic, bins=25, color='c', edgecolor='k', alpha=0.5, density=True) + plt.hist(not_found_ic, bins=25, color='r', edgecolor='k', alpha=0.5, density=True) + plt.xlabel("Average Information Content") + plt.ylabel("Counts") + plt.legend(['Successful Diagnosis', 'Unsuccessful Diagnosis']) + plt.savefig(outdir / "inf_content_histograms.png") # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Still important TODO! @@ -245,6 +247,3 @@ # 1) cont table 8 cells... chi2 test # 2) MRR test --> one way # 3) rank based u-test, max 50, 51 for not found or >50 - - - diff --git a/src/malco/analysis/time_ic/logit_predict_llm.py b/src/malco/analysis/time_ic/logit_predict_llm.py index d7a6f9d3..b26c648b 100644 --- a/src/malco/analysis/time_ic/logit_predict_llm.py +++ b/src/malco/analysis/time_ic/logit_predict_llm.py @@ -11,25 +11,27 @@ # https://towardsdatascience.com/building-a-logistic-regression-in-python-step-by-step-becd4d56c9c8 [1/10/24] import pandas as pd import numpy as np -import pickle +import pickle from sklearn.linear_model import LogisticRegression from sklearn.model_selection import train_test_split from sklearn.metrics import confusion_matrix, classification_report from pathlib import Path # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Path -outdir = Path.cwd() / "src" / "malco" / "analysis" / "time_ic" +outdir = Path.cwd() / "src" / "malco" / "analysis" / "time_ic" # Import ppkt_ic_df = pd.read_csv(outdir / "ppkt_ic.tsv", delimiter='\t', index_col=0) with open(outdir / "rank_date_dict.pkl", 'rb') as f: rank_date_dict = pickle.load(f) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -X_train, X_test, y_train, y_test = train_test_split(ppkt_ic_df[['avg(IC)']], ppkt_ic_df['Diagnosed'], test_size=0.2, random_state=0) +X_train, X_test, y_train, y_test = train_test_split( + ppkt_ic_df[['avg(IC)']], ppkt_ic_df['Diagnosed'], test_size=0.2, random_state=0) logreg = LogisticRegression() logreg.fit(X_train, y_train) y_pred = logreg.predict(X_test) -print('Accuracy of 1 parameter (IC) logistic regression classifier on test set: {:.2f}'.format(logreg.score(X_test, y_test))) +print('Accuracy of 1 parameter (IC) logistic regression classifier on 
test set: {:.2f}'.format( + logreg.score(X_test, y_test))) cm1d = confusion_matrix(y_test, y_pred) print(cm1d) class_report = classification_report(y_test, y_pred) @@ -50,11 +52,13 @@ date_df.set_index(new_index, inplace=True) ic_date_df = date_df.join(ppkt_ic_df, how='inner') -X_train, X_test, y_train, y_test = train_test_split(ic_date_df[['avg(IC)','date']], ic_date_df['Diagnosed'], test_size=0.2, random_state=0) +X_train, X_test, y_train, y_test = train_test_split( + ic_date_df[['avg(IC)', 'date']], ic_date_df['Diagnosed'], test_size=0.2, random_state=0) logreg = LogisticRegression() logreg.fit(X_train, y_train) y_pred = logreg.predict(X_test) -print('\nAccuracy of 2 PARAMETER (IC and time) logistic regression classifier on test set: {:.2f}'.format(logreg.score(X_test, y_test))) +print('\nAccuracy of 2 PARAMETER (IC and time) logistic regression classifier on test set: {:.2f}'.format( + logreg.score(X_test, y_test))) cm2d = confusion_matrix(y_test, y_pred) print(cm2d) class_report = classification_report(y_test, y_pred) diff --git a/src/malco/analysis/time_ic/rank_date_exploratory.py b/src/malco/analysis/time_ic/rank_date_exploratory.py index 405267d6..37042d4b 100644 --- a/src/malco/analysis/time_ic/rank_date_exploratory.py +++ b/src/malco/analysis/time_ic/rank_date_exploratory.py @@ -13,9 +13,9 @@ >>> python src/malco/analysis/time_ic/diseases_avail_knowledge.py gpt-4o """ -import pickle +import pickle from pathlib import Path -import pandas as pd +import pandas as pd import sys import numpy as np # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -24,17 +24,17 @@ ranking_results_filename = f"out_openAI_models/multimodel/{model}/full_df_results.tsv" data_dir = Path.home() / "data" hpoa_file_path = data_dir / "phenotype.hpoa" -outdir = Path.cwd() / "src" / "malco" / "analysis" / "time_ic" +outdir = Path.cwd() / "src" / "malco" / "analysis" / "time_ic" # (1) HPOA for dates # HPOA import and setup hpoa_df = pd.read_csv( - hpoa_file_path, sep="\t" , header=4, low_memory=False # Necessary to suppress Warning we don't care about - ) + hpoa_file_path, sep="\t", header=4, low_memory=False # Necessary to suppress Warning we don't care about +) labels_to_drop = ["disease_name", "qualifier", "hpo_id", "reference", "evidence", - "onset", "frequency", "sex", "modifier", "aspect"] + "onset", "frequency", "sex", "modifier", "aspect"] hpoa_df = hpoa_df.drop(columns=labels_to_drop) hpoa_df['date'] = hpoa_df["biocuration"].str.extract(r'\[(.*?)\]') @@ -47,40 +47,40 @@ # import df of LLM results rank_results_df = pd.read_csv( - ranking_results_filename, sep="\t" - ) + ranking_results_filename, sep="\t" +) # Go through results data and make set of found vs not found diseases. 
found_diseases = [] not_found_diseases = [] -rank_date_dict = {} -ppkts = rank_results_df.groupby("label")[["term", "correct_term", "is_correct", "rank"]] -for ppkt in ppkts: #TODO 1st for ppkt in ppkts - # ppkt is tuple ("filename", dataframe) --> ppkt[1] is a dataframe - disease = ppkt[1].iloc[0]['correct_term'] - - if any(ppkt[1]["is_correct"]): - found_diseases.append(disease) - index_of_match = ppkt[1]["is_correct"].to_list().index(True) - try: - rank = ppkt[1].iloc[index_of_match]["rank"] # inverse rank does not work well - rank_date_dict[ppkt[0]] = [rank.item(), - hpoa_unique.loc[ppkt[1].iloc[0]["correct_term"]]] - # If ppkt[1].iloc[0]["correct_term"] is nan, then KeyError "e" is nan - except (ValueError, KeyError) as e: - print(f"Error {e} for {ppkt[0]}, disease {ppkt[1].iloc[0]['correct_term']}.") - - else: - not_found_diseases.append(disease) - try: - rank_date_dict[ppkt[0]] = [None, - hpoa_unique.loc[ppkt[1].iloc[0]["correct_term"]]] - except (ValueError, KeyError) as e: - #pass - #TODO collect the below somewhere - print(f"Error {e} for {ppkt[0]}, disease {ppkt[1].iloc[0]['correct_term']}.") - +rank_date_dict = {} +ppkts = rank_results_df.groupby("label")[["term", "correct_term", "is_correct", "rank"]] +for ppkt in ppkts: # TODO 1st for ppkt in ppkts + # ppkt is tuple ("filename", dataframe) --> ppkt[1] is a dataframe + disease = ppkt[1].iloc[0]['correct_term'] + + if any(ppkt[1]["is_correct"]): + found_diseases.append(disease) + index_of_match = ppkt[1]["is_correct"].to_list().index(True) + try: + rank = ppkt[1].iloc[index_of_match]["rank"] # inverse rank does not work well + rank_date_dict[ppkt[0]] = [rank.item(), + hpoa_unique.loc[ppkt[1].iloc[0]["correct_term"]]] + # If ppkt[1].iloc[0]["correct_term"] is nan, then KeyError "e" is nan + except (ValueError, KeyError) as e: + print(f"Error {e} for {ppkt[0]}, disease {ppkt[1].iloc[0]['correct_term']}.") + + else: + not_found_diseases.append(disease) + try: + rank_date_dict[ppkt[0]] = [None, + hpoa_unique.loc[ppkt[1].iloc[0]["correct_term"]]] + except (ValueError, KeyError) as e: + # pass + # TODO collect the below somewhere + print(f"Error {e} for {ppkt[0]}, disease {ppkt[1].iloc[0]['correct_term']}.") + # gpt-4o output, reasonable enough to throw out 62 cases ~1%. 3 OMIMs to check and 3 nan # TODO clean up here # len(rank_date_dict) --> 6625 @@ -101,47 +101,47 @@ overlap = [] for i in found_set: - if i in notfound_set: - overlap.append(i) + if i in notfound_set: + overlap.append(i) print(f"Number of found diseases by {model} is {len(found_set)}.") print(f"Number of not found diseases by {model} is {len(notfound_set)}.") print(f"Diseases sometimes found, sometimes not, by {model} are {len(overlap)}.\n") # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# Look at the 263-129 (gpt-4o) found diseases not present in not-found set ("always found") +# Look at the 263-129 (gpt-4o) found diseases not present in not-found set ("always found") # and the opposite namely "never found" diseases. Average date of two sets is? -always_found = found_set - notfound_set # 134 -never_found = notfound_set - found_set # 213 +always_found = found_set - notfound_set # 134 +never_found = notfound_set - found_set # 213 # meaning 347/476, 27% sometimes found sometimes not, 28% always found, 45% never found. 
# Compute average date of always vs never found diseases -results_dict = {} # turns out being 281 long +results_dict = {} # turns out being 281 long # TODO get rid of next line, bzw hpoa_unique does not work for loop below hpoa_df.drop_duplicates(subset='database_id', inplace=True) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ for af in always_found: - try: - results_dict[af] = [True, hpoa_df.loc[hpoa_df['database_id'] == af, 'date'].item() ] - #results_dict[af] = [True, hpoa_unique.loc[hpoa_unique['database_id'] == af, 'date'].item() ] - except ValueError: - print(f"No HPOA in always_found for {af}.") + try: + results_dict[af] = [True, hpoa_df.loc[hpoa_df['database_id'] == af, 'date'].item()] + # results_dict[af] = [True, hpoa_unique.loc[hpoa_unique['database_id'] == af, 'date'].item() ] + except ValueError: + print(f"No HPOA in always_found for {af}.") for nf in never_found: - try: - results_dict[nf] = [False, hpoa_df.loc[hpoa_df['database_id'] == nf, 'date'].item() ] - #results_dict[nf] = [False, hpoa_unique.loc[hpoa_unique['database_id'] == nf, 'date'].item() ] - except ValueError: - print(f"No HPOA in never_found for {nf}.") + try: + results_dict[nf] = [False, hpoa_df.loc[hpoa_df['database_id'] == nf, 'date'].item()] + # results_dict[nf] = [False, hpoa_unique.loc[hpoa_unique['database_id'] == nf, 'date'].item() ] + except ValueError: + print(f"No HPOA in never_found for {nf}.") -#TODO No HPOA for ... comes from for ppkt in ppkts, then +# TODO No HPOA for ... comes from for ppkt in ppkts, then # disease = ppkt[1].iloc[0]['correct_term'] res_to_clean = pd.DataFrame.from_dict(results_dict).transpose() -res_to_clean.columns=["found","date"] +res_to_clean.columns = ["found", "date"] res_to_clean['date'] = pd.to_datetime(res_to_clean.date).values.astype(np.int64) final_avg = pd.DataFrame(pd.to_datetime(res_to_clean.groupby('found').mean()['date'])) print(final_avg) -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ \ No newline at end of file +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/malco/post_process/df_save_util.py b/src/malco/post_process/df_save_util.py index df6479ed..bd140b71 100644 --- a/src/malco/post_process/df_save_util.py +++ b/src/malco/post_process/df_save_util.py @@ -2,6 +2,7 @@ import os import pandas as pd + def safe_save_tsv(path, filename, df): full_path = path / filename # If full_path already exists, prepend "old_" @@ -12,4 +13,4 @@ def safe_save_tsv(path, filename, df): os.remove(old_full_path) shutil.copy(full_path, old_full_path) os.remove(full_path) - df.to_csv(full_path, sep='\t', index=False) \ No newline at end of file + df.to_csv(full_path, sep='\t', index=False) diff --git a/src/malco/post_process/extended_scoring.py b/src/malco/post_process/extended_scoring.py index d60fdaf5..2fae6841 100644 --- a/src/malco/post_process/extended_scoring.py +++ b/src/malco/post_process/extended_scoring.py @@ -10,6 +10,8 @@ dd_re = re.compile(r"^[^A-z]*Differential Diagnosis") # Function to clean and remove "Differential Diagnosis" header if present + + def clean_service_answer(answer: str) -> str: """Remove the 'Differential Diagnosis' header if present, and clean the first line.""" lines = answer.split('\n') @@ -18,6 +20,8 @@ def clean_service_answer(answer: str) -> str: return '\n'.join(cleaned_lines) # Clean the diagnosis line by removing leading numbers, periods, asterisks, and spaces + + def 
clean_diagnosis_line(line: str) -> str: """Remove leading numbers, asterisks, and unnecessary punctuation/spaces from the diagnosis.""" line = re.sub(r'^\**\d+\.\s*', '', line) # Remove leading numbers and periods @@ -25,6 +29,8 @@ def clean_diagnosis_line(line: str) -> str: return line.strip() # Strip any remaining spaces # Split a diagnosis into its main name and synonym if present + + def split_diagnosis_and_synonym(diagnosis: str) -> Tuple[str, str]: """Split the diagnosis into main name and synonym (if present in parentheses).""" match = re.match(r'^(.*)\s*\((.*)\)\s*$', diagnosis) @@ -33,6 +39,7 @@ def split_diagnosis_and_synonym(diagnosis: str) -> Tuple[str, str]: return main_diagnosis.strip(), synonym.strip() return diagnosis, None # Return the original diagnosis if no synonym is found + def perform_curategpt_grounding( diagnosis: str, path: str, @@ -44,7 +51,7 @@ def perform_curategpt_grounding( ) -> List[Tuple[str, str]]: """ Use curategpt to perform grounding for a given diagnosis when initial attempts fail. - + Parameters: - diagnosis: The diagnosis text to ground. - path: The path to the database. You'll need to create an index of Mondo using curategpt in this db @@ -55,13 +62,13 @@ def perform_curategpt_grounding( - limit: The number of search results to return. - relevance_factor: The distance threshold for relevance filtering. - verbose: Whether to print verbose output for debugging. - + Returns: - List of tuples: [(Mondo ID, Label), ...] """ # Initialize the database store db = get_store(database_type, path) - + # Perform the search using the provided diagnosis results = db.search(diagnosis, collection=collection) @@ -79,7 +86,7 @@ def perform_curategpt_grounding( for obj, distance, _meta in limited_results: disease_mondo_id = obj.get("original_id") # Use the 'original_id' field for Mondo ID disease_label = obj.get("label") - + if disease_mondo_id and disease_label: pred_ids.append(disease_mondo_id) pred_labels.append(disease_label) @@ -118,40 +125,43 @@ def perform_oak_grounding( if any(ann.object_id.startswith(prefix) for prefix in include_list) } ) - + if filtered_annotations: return filtered_annotations else: match_type = "exact" if exact_match else "inexact" if verbose: print(f"No {match_type} grounded IDs found for: {diagnosis}") - pass + pass return [('N/A', 'No grounding found')] # Now, integrate curategpt into your ground_diagnosis_text_to_mondo function + + def ground_diagnosis_text_to_mondo( annotator: TextAnnotatorInterface, differential_diagnosis: str, verbose: bool = False, include_list: List[str] = ["MONDO:"], use_ontogpt_grounding: bool = True, - curategpt_path: str = "../curategpt/stagedb/", + curategpt_path: str = "../curategpt/stagedb/", curategpt_collection: str = "ont_mondo", curategpt_database_type: str = "chromadb" ) -> List[Tuple[str, List[Tuple[str, str]]]]: results = [] - + # Split the input into lines and process each one for line in differential_diagnosis.splitlines(): clean_line = clean_diagnosis_line(line) - + # Skip header lines like "**Differential diagnosis:**" if not clean_line or "Differential diagnosis" in clean_line.lower(): continue - + # Try grounding the full line first (exact match) - grounded = perform_oak_grounding(annotator, clean_line, exact_match=True, verbose=verbose, include_list=include_list) - + grounded = perform_oak_grounding(annotator, clean_line, exact_match=True, + verbose=verbose, include_list=include_list) + # Try grounding with curategpt if no grounding is found if use_ontogpt_grounding and grounded == [('N/A', 'No 
grounding found')]: grounded = perform_curategpt_grounding( @@ -161,13 +171,13 @@ def ground_diagnosis_text_to_mondo( database_type=curategpt_database_type, verbose=verbose ) - + # If still no grounding is found, log the final failure if grounded == [('N/A', 'No grounding found')]: if verbose: print(f"Final grounding failed for: {clean_line}") - + # Append the grounded results (even if no grounding was found) results.append((clean_line, grounded)) - return results \ No newline at end of file + return results diff --git a/src/malco/post_process/generate_plots.py b/src/malco/post_process/generate_plots.py index abddfb88..927245ea 100644 --- a/src/malco/post_process/generate_plots.py +++ b/src/malco/post_process/generate_plots.py @@ -6,6 +6,7 @@ # Make a nice plot, use it as function or as script + def make_plots(mrr_file, data_dir, languages, num_ppkt, models, topn_aggr_file, comparing): plot_dir = data_dir.parents[0] / "plots" plot_dir.mkdir(exist_ok=True) @@ -13,39 +14,38 @@ def make_plots(mrr_file, data_dir, languages, num_ppkt, models, topn_aggr_file, # For plot filenam labeling use lowest number of ppkt available for all models/languages etc. num_ppkt = min(num_ppkt.values()) - if comparing=="model": + if comparing == "model": name_string = str(len(models)) else: name_string = str(len(languages)) - with mrr_file.open('r', newline = '') as f: - lines = csv.reader(f, quoting = csv.QUOTE_NONNUMERIC, delimiter = '\t', lineterminator='\n') + with mrr_file.open('r', newline='') as f: + lines = csv.reader(f, quoting=csv.QUOTE_NONNUMERIC, delimiter='\t', lineterminator='\n') results_files = next(lines) mrr_scores = next(lines) - + print(results_files) print(mrr_scores) # Plotting the mrr results - sns.barplot(x = results_files, y = mrr_scores) + sns.barplot(x=results_files, y=mrr_scores) plt.xlabel("Results File") plt.ylabel("Mean Reciprocal Rank (MRR)") plt.title("MRR of Correct Answers Across Different Results Files") - plot_path = plot_dir / (name_string + "_" + comparing + "_" + str(num_ppkt) + "ppkt.png") + plot_path = plot_dir / (name_string + "_" + comparing + "_" + str(num_ppkt) + "ppkt.png") plt.savefig(plot_path) plt.close() # Plotting bar-plots with top ranks df_aggr = pd.read_csv(topn_aggr_file, delimiter='\t') - - sns.barplot(x="Rank_in", y="percentage", data = df_aggr, hue = comparing) + + sns.barplot(x="Rank_in", y="percentage", data=df_aggr, hue=comparing) plt.xlabel("Number of Ranks in") plt.ylabel("Percentage of Cases") plt.ylim([0.0, 1.0]) plt.title("Rank Comparison for Differential Diagnosis") plt.legend(title=comparing) - plot_path = plot_dir / ("barplot_" + name_string + "_" + comparing + "_" + str(num_ppkt) + "ppkt.png") + plot_path = plot_dir / ("barplot_" + name_string + "_" + comparing + "_" + str(num_ppkt) + "ppkt.png") plt.savefig(plot_path) plt.close() - diff --git a/src/malco/post_process/mondo_score_utils.py b/src/malco/post_process/mondo_score_utils.py index 8735b994..55a5f849 100644 --- a/src/malco/post_process/mondo_score_utils.py +++ b/src/malco/post_process/mondo_score_utils.py @@ -2,7 +2,7 @@ from oaklib.interfaces import MappingProviderInterface from pathlib import Path -from typing import List +from typing import List from cachetools.keys import hashkey @@ -10,7 +10,7 @@ PARTIAL_SCORE = 0.5 -def omim_mappings(term: str, adapter) -> List[str]: +def omim_mappings(term: str, adapter) -> List[str]: """ Get the OMIM mappings for a term. @@ -26,7 +26,7 @@ def omim_mappings(term: str, adapter) -> List[str]: Returns: str: The OMIM mappings. 
- """ + """ omims = [] for m in adapter.sssom_mappings([term], source="OMIM"): if m.predicate_id == "skos:exactMatch": @@ -45,13 +45,13 @@ def score_grounded_result(prediction: str, ground_truth: str, mondo, cache=None) The predicted Mondo is equivalent to the ground truth OMIM (via skos:exactMatches in Mondo): - + >>> score_grounded_result("MONDO:0007566", "OMIM:132800", get_adapter("sqlite:obo:mondo")) 1.0 The predicted Mondo is a disease entity that groups multiple OMIMs, one of which is the ground truth: - + >>> score_grounded_result("MONDO:0008029", "OMIM:158810", get_adapter("sqlite:obo:mondo")) 0.5 @@ -65,12 +65,11 @@ def score_grounded_result(prediction: str, ground_truth: str, mondo, cache=None) """ if not isinstance(mondo, MappingProviderInterface): raise ValueError("Adapter is not an MappingProviderInterface") - + if prediction == ground_truth: # predication is the correct OMIM return FULL_SCORE - ground_truths = get_ground_truth_from_cache_or_compute(prediction, mondo, cache) if ground_truth in ground_truths: # prediction is a MONDO that directly maps to a correct OMIM @@ -84,14 +83,15 @@ def score_grounded_result(prediction: str, ground_truth: str, mondo, cache=None) return PARTIAL_SCORE return 0.0 + def get_ground_truth_from_cache_or_compute( - term, - adapter, + term, + adapter, cache, ): if cache is None: return omim_mappings(term, adapter) - + k = hashkey(term) try: ground_truths = cache[k] @@ -102,4 +102,3 @@ def get_ground_truth_from_cache_or_compute( cache[k] = ground_truths cache.misses += 1 return ground_truths - diff --git a/src/malco/post_process/post_process.py b/src/malco/post_process/post_process.py index 84427879..c7fe9624 100644 --- a/src/malco/post_process/post_process.py +++ b/src/malco/post_process/post_process.py @@ -27,10 +27,10 @@ def post_process(self) -> None: create_standardised_results(curategpt, raw_results_dir=raw_results_lang, - output_dir=output_lang, + output_dir=output_lang, output_file_name="results.tsv", ) - + elif self.modality == "several_models": for model in models: raw_results_model = raw_results_dir / "multimodel" / model @@ -40,6 +40,6 @@ def post_process(self) -> None: create_standardised_results(curategpt, raw_results_dir=raw_results_model, - output_dir=output_model, + output_dir=output_model, output_file_name="results.tsv", ) diff --git a/src/malco/post_process/post_process_results_format.py b/src/malco/post_process/post_process_results_format.py index 54046a1b..c47ab679 100644 --- a/src/malco/post_process/post_process_results_format.py +++ b/src/malco/post_process/post_process_results_format.py @@ -11,7 +11,6 @@ from malco.post_process.df_save_util import safe_save_tsv from malco.post_process.extended_scoring import clean_service_answer, ground_diagnosis_text_to_mondo from oaklib import get_adapter - def read_raw_result_yaml(raw_result_path: Path) -> List[dict]: @@ -25,15 +24,15 @@ def read_raw_result_yaml(raw_result_path: Path) -> List[dict]: dict: Contents of the raw result file. 
""" with open(raw_result_path, 'r') as raw_result: - return list(yaml.safe_load_all(raw_result.read().replace(u'\x04',''))) # Load and convert to list + return list(yaml.safe_load_all(raw_result.read().replace(u'\x04', ''))) # Load and convert to list def create_standardised_results(curategpt: bool, - raw_results_dir: Path, + raw_results_dir: Path, output_dir: Path, output_file_name: str ) -> pd.DataFrame: - + data = [] if curategpt: annotator = get_adapter("sqlite:obo:mondo") @@ -48,20 +47,20 @@ def create_standardised_results(curategpt: bool, if extracted_object: label = extracted_object.get('label') terms = extracted_object.get('terms') - if curategpt and terms: + if curategpt and terms: ontogpt_text = this_result.get("input_text") # its a single string, should be parseable through curategpt cleaned_text = clean_service_answer(ontogpt_text) assert cleaned_text != "", "Cleaning failed: the cleaned text is empty." result = ground_diagnosis_text_to_mondo(annotator, cleaned_text, verbose=False) # terms will now ONLY contain MONDO IDs OR 'N/A'. The latter should be dealt with downstream - terms = [i[1][0][0] for i in result] # MONDO_ID + terms = [i[1][0][0] for i in result] # MONDO_ID if terms: - # Note, the if allows for rerunning ppkts that failed due to connection issues - # We can have multiple identical ppkts/prompts in results.yaml as long as only one has a terms field + # Note, the if allows for rerunning ppkts that failed due to connection issues + # We can have multiple identical ppkts/prompts in results.yaml as long as only one has a terms field num_terms = len(terms) score = [1 / (i + 1) for i in range(num_terms)] # score is reciprocal rank - rank_list = [ i+1 for i in range(num_terms)] + rank_list = [i + 1 for i in range(num_terms)] for term, scr, rank in zip(terms, score, rank_list): data.append({'label': label, 'term': term, 'score': scr, 'rank': rank}) @@ -164,6 +163,7 @@ def extract_pheval_gene_requirements(self) -> List[PhEvalGeneResult]: ) return pheval_result + ''' def create_standardised_results(raw_results_dir: Path, output_dir: Path) -> None: """ @@ -187,4 +187,4 @@ def create_standardised_results(raw_results_dir: Path, output_dir: Path) -> None output_dir=output_dir, tool_result_path=raw_result_path, ) -''' \ No newline at end of file +''' diff --git a/src/malco/post_process/ranking_utils.py b/src/malco/post_process/ranking_utils.py index dc03c51e..29dbffe1 100644 --- a/src/malco/post_process/ranking_utils.py +++ b/src/malco/post_process/ranking_utils.py @@ -1,4 +1,4 @@ -import os +import os import csv from pathlib import Path from datetime import datetime @@ -15,16 +15,18 @@ from malco.post_process.df_save_util import safe_save_tsv from malco.post_process.mondo_score_utils import score_grounded_result from cachetools import LRUCache -from typing import List +from typing import List from cachetools.keys import hashkey from shelved_cache import PersistentCache FULL_SCORE = 1.0 PARTIAL_SCORE = 0.5 + def cache_info(self): return f"CacheInfo: hits={self.hits}, misses={self.misses}, maxsize={self.wrapped.maxsize}, currsize={self.wrapped.currsize}" + def mondo_adapter() -> OboGraphInterface: """ Get the adapter for the MONDO ontology. @@ -32,40 +34,41 @@ def mondo_adapter() -> OboGraphInterface: Returns: Adapter: The adapter. 
""" - return get_adapter("sqlite:obo:mondo") + return get_adapter("sqlite:obo:mondo") + def compute_mrr_and_ranks( - comparing: str, - output_dir: Path, + comparing: str, + output_dir: Path, out_subdir: str, - prompt_dir: str, + prompt_dir: str, correct_answer_file: str, - ) -> Path: +) -> Path: - # Read in results TSVs from self.output_dir that match glob results*tsv + # Read in results TSVs from self.output_dir that match glob results*tsv out_caches = Path("caches") - #out_caches = output_dir / "caches" + # out_caches = output_dir / "caches" out_caches.mkdir(exist_ok=True) output_dir = output_dir / out_subdir results_data = [] results_files = [] num_ppkt = {} pc2_cache_file = str(out_caches / "score_grounded_result_cache") - pc2 = PersistentCache(LRUCache, pc2_cache_file, maxsize=524288) + pc2 = PersistentCache(LRUCache, pc2_cache_file, maxsize=524288) pc1_cache_file = str(out_caches / "omim_mappings_cache") pc1 = PersistentCache(LRUCache, pc1_cache_file, maxsize=524288) # Treat hits and misses as run-specific arguments, write them cache_log pc1.hits = pc1.misses = 0 pc2.hits = pc2.misses = 0 PersistentCache.cache_info = cache_info - + mode_index = 0 for subdir, dirs, files in os.walk(output_dir): for filename in files: if filename.startswith("result") and filename.endswith(".tsv"): file_path = os.path.join(subdir, filename) df = pd.read_csv(file_path, sep="\t") - num_ppkt[subdir.split('/')[-1]] = df["label"].nunique() + num_ppkt[subdir.split('/')[-1]] = df["label"].nunique() results_data.append(df) # Append both the subdirectory relative to output_dir and the filename results_files.append(os.path.relpath(file_path, output_dir)) @@ -84,21 +87,22 @@ def compute_mrr_and_ranks( cache_file = out_caches / "cache_log.txt" - with cache_file.open('a', newline = '') as cf: + with cache_file.open('a', newline='') as cf: now_is = datetime.now().strftime("%Y%m%d-%H%M%S") - cf.write("Timestamp: " + now_is +"\n\n") + cf.write("Timestamp: " + now_is + "\n\n") mondo = mondo_adapter() i = 0 # Each df is a model or a language for df in results_data: # For each label in the results file, find if the correct term is ranked df["rank"] = df.groupby("label")["score"].rank(ascending=False, method="first") - label_4_non_eng = df["label"].str.replace("_[a-z][a-z]-prompt", "_en-prompt", regex=True) #TODO is bug here? + label_4_non_eng = df["label"].str.replace( + "_[a-z][a-z]-prompt", "_en-prompt", regex=True) # TODO is bug here? # df['correct_term'] is an OMIM # df['term'] is Mondo or OMIM ID, or even disease label df["correct_term"] = label_4_non_eng.map(label_to_correct_term) - + # Make sure caching is used in the following by unwrapping explicitly results = [] for idx, row in df.iterrows(): @@ -125,34 +129,34 @@ def compute_mrr_and_ranks( full_df_path = output_dir / results_files[i].split("/")[0] full_df_filename = "full_df_results.tsv" safe_save_tsv(full_df_path, full_df_filename, df) - + # Calculate MRR for this file mrr = df.groupby("label")["reciprocal_rank"].max().mean() mrr_scores.append(mrr) - + # Calculate top of each rank rank_df.loc[i, comparing] = results_files[i].split("/")[0] - - ppkts = df.groupby("label")[["rank","is_correct"]] + + ppkts = df.groupby("label")[["rank", "is_correct"]] index_matches = df.index[df['is_correct']] - + # for each group for ppkt in ppkts: - # is there a true? ppkt is tuple ("filename", dataframe) --> ppkt[1] is a dataframe + # is there a true? 
ppkt is tuple ("filename", dataframe) --> ppkt[1] is a dataframe if not any(ppkt[1]["is_correct"]): # no --> increase nf = "not found" - rank_df.loc[i,"nf"] += 1 + rank_df.loc[i, "nf"] += 1 else: # yes --> what's it rank? It's jind = ppkt[1].index[ppkt[1]['is_correct']] j = int(ppkt[1]['rank'].loc[jind].values[0]) - if j<11: + if j < 11: # increase n - rank_df.loc[i,"n"+str(j)] += 1 + rank_df.loc[i, "n" + str(j)] += 1 else: # increase n10p - rank_df.loc[i,"n10p"] += 1 - + rank_df.loc[i, "n10p"] += 1 + # Write cache charatcteristics to file cf.write(results_files[i]) cf.write('\nscore_grounded_result cache info:\n') @@ -164,9 +168,9 @@ def compute_mrr_and_ranks( pc1.close() pc2.close() - + for modelname in num_ppkt.keys(): - rank_df.loc[rank_df['model']==modelname,'num_cases'] = num_ppkt[modelname] + rank_df.loc[rank_df['model'] == modelname, 'num_cases'] = num_ppkt[modelname] data_dir = output_dir / "rank_data" data_dir.mkdir(exist_ok=True) topn_file_name = "topn_result.tsv" @@ -177,27 +181,28 @@ def compute_mrr_and_ranks( print(mrr_scores) mrr_file = data_dir / "mrr_result.tsv" - # write out results for plotting - with mrr_file.open('w', newline = '') as dat: - writer = csv.writer(dat, quoting = csv.QUOTE_NONNUMERIC, delimiter = '\t', lineterminator='\n') + # write out results for plotting + with mrr_file.open('w', newline='') as dat: + writer = csv.writer(dat, quoting=csv.QUOTE_NONNUMERIC, delimiter='\t', lineterminator='\n') writer.writerow(results_files) writer.writerow(mrr_scores) df = pd.read_csv(topn_file, delimiter='\t') df["top1"] = (df['n1']) / df["num_cases"] - df["top3"] = (df["n1"] + df["n2"] + df["n3"] ) / df["num_cases"] - df["top5"] = (df["n1"] + df["n2"] + df["n3"] + df["n4"] + df["n5"] ) / df["num_cases"] - df["top10"] = (df["n1"] + df["n2"] + df["n3"] + df["n4"] + df["n5"] + df["n6"] + df["n7"] + df["n8"] + df["n9"] + df["n10"] ) / df["num_cases"] + df["top3"] = (df["n1"] + df["n2"] + df["n3"]) / df["num_cases"] + df["top5"] = (df["n1"] + df["n2"] + df["n3"] + df["n4"] + df["n5"]) / df["num_cases"] + df["top10"] = (df["n1"] + df["n2"] + df["n3"] + df["n4"] + df["n5"] + df["n6"] + + df["n7"] + df["n8"] + df["n9"] + df["n10"]) / df["num_cases"] df["not_found"] = (df["nf"]) / df["num_cases"] - + df_aggr = pd.DataFrame() - df_aggr = pd.melt(df, id_vars=comparing, value_vars=["top1", "top3", "top5", "top10", "not_found"], var_name="Rank_in", value_name="percentage") + df_aggr = pd.melt(df, id_vars=comparing, value_vars=[ + "top1", "top3", "top5", "top10", "not_found"], var_name="Rank_in", value_name="percentage") # If "topn_aggr.tsv" already exists, prepend "old_" # It's the user's responsibility to know only up to 2 versions can exist, then data is lost topn_aggr_file_name = "topn_aggr.tsv" topn_aggr_file = data_dir / topn_aggr_file_name safe_save_tsv(data_dir, topn_aggr_file_name, df_aggr) - + return mrr_file, data_dir, num_ppkt, topn_aggr_file - \ No newline at end of file diff --git a/src/malco/prepare/setup_phenopackets.py b/src/malco/prepare/setup_phenopackets.py index c35921a2..41c117f4 100644 --- a/src/malco/prepare/setup_phenopackets.py +++ b/src/malco/prepare/setup_phenopackets.py @@ -1,10 +1,11 @@ import zipfile -import os +import os import requests -phenopacket_zip_url="https://github.com/monarch-initiative/phenopacket-store/releases/download/0.1.11/all_phenopackets.zip" +phenopacket_zip_url = "https://github.com/monarch-initiative/phenopacket-store/releases/download/0.1.11/all_phenopackets.zip" # TODO just point to a folder w/ ppkts 
-phenopacket_dir="phenopacket-store" +phenopacket_dir = "phenopacket-store" + def setup_phenopackets(self) -> str: phenopacket_store_path = os.path.join(self.input_dir, phenopacket_dir) diff --git a/src/malco/prepare/setup_run_pars.py b/src/malco/prepare/setup_run_pars.py index 6944ee90..0828b21f 100644 --- a/src/malco/prepare/setup_run_pars.py +++ b/src/malco/prepare/setup_run_pars.py @@ -2,6 +2,7 @@ import csv import sys + def import_inputdata(self) -> None: """ Example input file is located in ``self.input_dir`` and named run_parameters.csv @@ -23,7 +24,7 @@ def import_inputdata(self) -> None: Meaning run multilingual prompts with those 5 aforementioned languages, and only execute the function postprocess, not run. """ with open(self.input_dir / "run_parameters.csv", 'r') as pars: - lines = csv.reader(pars, quoting = csv.QUOTE_NONNUMERIC, delimiter = ',', lineterminator='\n') + lines = csv.reader(pars, quoting=csv.QUOTE_NONNUMERIC, delimiter=',', lineterminator='\n') in_langs = next(lines) in_models = next(lines) in_what_to_run = next(lines) @@ -33,17 +34,17 @@ def import_inputdata(self) -> None: if (l > 1 and m > 1): sys.exit("Error, either run multiple languages or models, not both, exiting...") elif l == 1 and m >= 1: - if in_langs[0]=="en": - self.modality = "several_models" # English and more than 1 model defaults to multiple models + if in_langs[0] == "en": + self.modality = "several_models" # English and more than 1 model defaults to multiple models else: if m > 1: sys.exit("Error, only English and multiple models supported, exiting...") - else: # m==1 - self.modality = "several_languages" # non English defaults to multiple languages - elif l > 1: + else: # m==1 + self.modality = "several_languages" # non English defaults to multiple languages + elif l > 1: self.modality = "several_languages" self.languages = tuple(in_langs) - self.models = tuple(in_models) + self.models = tuple(in_models) self.do_run_step = in_what_to_run[0] # only run the run part of the code self.do_postprocess_step = in_what_to_run[1] # only run the postprocess part of the code diff --git a/src/malco/run/run.py b/src/malco/run/run.py index 1ad3e416..cfc7c14f 100644 --- a/src/malco/run/run.py +++ b/src/malco/run/run.py @@ -7,13 +7,14 @@ from malco.run.search_ppkts import search_ppkts + def call_ontogpt( - lang: str, - raw_results_dir: Path, - input_dir: Path, - model: str, + lang: str, + raw_results_dir: Path, + input_dir: Path, + model: str, modality: typing.Literal['several_languages', 'several_models'], -)-> None: +) -> None: """ Wrapper used for parallel execution of ontogpt. 
@@ -37,9 +38,9 @@ def call_ontogpt( else: raise ValueError('Not permitted run modality!\n') - selected_indir = search_ppkts(input_dir, prompt_dir, raw_results_dir, lang_or_model_dir) + selected_indir = search_ppkts(input_dir, prompt_dir, raw_results_dir, lang_or_model_dir) yaml_file = f"{raw_results_dir}/{lang_or_model_dir}/results.yaml" - + if os.path.isfile(yaml_file): old_yaml_file = yaml_file yaml_file = f"{raw_results_dir}/{lang_or_model_dir}/new_results.yaml" @@ -97,15 +98,16 @@ def run(self, if modality == "several_languages": with multiprocessing.Pool(processes=max_workers) as pool: try: - pool.starmap(call_ontogpt, [(lang, raw_results_dir / "multilingual", input_dir, "gpt-4o", modality) for lang in langs]) + pool.starmap(call_ontogpt, [(lang, raw_results_dir / "multilingual", + input_dir, "gpt-4o", modality) for lang in langs]) except FileExistsError as e: raise ValueError('Did not clean up after last run, check tmp dir: \n' + e) - if modality == "several_models": # English only many models with multiprocessing.Pool(processes=max_workers) as pool: try: - pool.starmap(call_ontogpt, [("en", raw_results_dir / "multimodel", input_dir, model, modality) for model in models]) + pool.starmap(call_ontogpt, [("en", raw_results_dir / "multimodel", + input_dir, model, modality) for model in models]) except FileExistsError as e: raise ValueError('Did not clean up after last run, check tmp dir: \n' + e) diff --git a/src/malco/run/search_ppkts.py b/src/malco/run/search_ppkts.py index d0cfd560..2cf5467f 100644 --- a/src/malco/run/search_ppkts.py +++ b/src/malco/run/search_ppkts.py @@ -3,21 +3,21 @@ import shutil from malco.post_process.post_process_results_format import read_raw_result_yaml - + def search_ppkts(input_dir, prompt_dir, raw_results_dir, lang_or_model): """ Check what ppkts have already been computed in current output dir, for current run parameters. - + ontogpt will run every .txt that is in inputdir, we need a tmp inputdir excluding already run cases. Source of truth is the results.yaml output by ontogpt. Only extracted_object containing terms is considered successfully run. Note that rerunning """ - + # List of "labels" that are already present in results.yaml iff terms is not None files = [] - + yaml_file = f"{raw_results_dir}/{lang_or_model}/results.yaml" if os.path.isfile(yaml_file): # tmp inputdir contains prompts yet to be computed for a given model (pars set) @@ -36,11 +36,10 @@ def search_ppkts(input_dir, prompt_dir, raw_results_dir, lang_or_model): else: return prompt_dir - # prompts: ls prompt_dir promptfiles = [] for (dirpath, dirnames, filenames) in os.walk(prompt_dir): - promptfiles.extend(filenames) + promptfiles.extend(filenames) break # foreach promptfile in original_inputdir @@ -52,4 +51,4 @@ def search_ppkts(input_dir, prompt_dir, raw_results_dir, lang_or_model): else: shutil.copyfile(prompt_dir + promptfile, selected_indir + "/" + promptfile) - return selected_indir \ No newline at end of file + return selected_indir diff --git a/src/malco/runner.py b/src/malco/runner.py index 21dbb42e..0528c567 100644 --- a/src/malco/runner.py +++ b/src/malco/runner.py @@ -11,6 +11,7 @@ from malco.post_process.generate_plots import make_plots import os + class MalcoRunner(PhEvalRunner): input_dir: Path testdata_dir: Path @@ -19,7 +20,6 @@ class MalcoRunner(PhEvalRunner): config_file: Path version: str - def prepare(self): """ Pre-process any data and inputs necessary to run the tool. 
@@ -35,13 +35,12 @@ def run(self): pass if self.do_run_step: run(self, - ) + ) # Cleanup tmp_dir = f"{self.input_dir}/prompts/tmp/" if os.path.isdir(tmp_dir): rmtree(tmp_dir) - def post_process(self, print_plot=True, prompts_subdir_name="prompts", @@ -54,24 +53,22 @@ def post_process(self, print("post processing results to PhEval standardised TSV output.") post_process(self) - - - if self.modality=="several_languages": + + if self.modality == "several_languages": comparing = "language" - out_subdir="multilingual" - elif self.modality=="several_models": + out_subdir = "multilingual" + elif self.modality == "several_models": comparing = "model" - out_subdir="multimodel" + out_subdir = "multimodel" else: raise ValueError('Not permitted run modality!\n') mrr_file, data_dir, num_ppkt, topn_aggr_file = compute_mrr_and_ranks(comparing, - output_dir=self.output_dir, - out_subdir=out_subdir, - prompt_dir=os.path.join(self.input_dir, prompts_subdir_name), - correct_answer_file=correct_answer_file) - + output_dir=self.output_dir, + out_subdir=out_subdir, + prompt_dir=os.path.join( + self.input_dir, prompts_subdir_name), + correct_answer_file=correct_answer_file) + if print_plot: make_plots(mrr_file, data_dir, self.languages, num_ppkt, self.models, topn_aggr_file, comparing) - - \ No newline at end of file From 5bf84fd00a2f028052648c083b90a2af11418cb1 Mon Sep 17 00:00:00 2001 From: Leonardo macOS Date: Wed, 20 Nov 2024 17:50:00 +0100 Subject: [PATCH 2/2] fixed several things, among which flake8 version, tox.ini formatting, lots in the source code format, ran it once to make sure it still works --- caches/cache_log.txt | 14 +++ caches/omim_mappings_cache.db | Bin 10391552 -> 10391552 bytes caches/score_grounded_result_cache.db | Bin 10108928 -> 10117120 bytes src/malco/analysis/check_lens.py | 35 +++--- .../analysis/count_grounding_failures.py | 5 +- .../analysis/count_translated_prompts.py | 34 +++--- src/malco/analysis/eval_diagnose_category.py | 38 ++++--- src/malco/analysis/test_curate_script.py | 20 ++-- .../time_ic/disease_avail_knowledge.py | 80 +++++++------- .../analysis/time_ic/logit_predict_llm.py | 43 +++++--- .../analysis/time_ic/rank_date_exploratory.py | 58 ++++++---- src/malco/post_process/df_save_util.py | 5 +- src/malco/post_process/extended_scoring.py | 59 +++++----- src/malco/post_process/generate_plots.py | 16 +-- src/malco/post_process/mondo_score_utils.py | 8 +- src/malco/post_process/post_process.py | 25 ++--- .../post_process_results_format.py | 42 +++---- src/malco/post_process/ranking_utils.py | 103 +++++++++++------- src/malco/prepare/setup_phenopackets.py | 35 ------ src/malco/prepare/setup_run_pars.py | 24 ++-- src/malco/run/run.py | 49 +++++---- src/malco/run/search_ppkts.py | 12 +- src/malco/runner.py | 52 +++++---- tox.ini | 38 +++---- 24 files changed, 434 insertions(+), 361 deletions(-) delete mode 100644 src/malco/prepare/setup_phenopackets.py diff --git a/caches/cache_log.txt b/caches/cache_log.txt index c58bcb90..46225e0c 100644 --- a/caches/cache_log.txt +++ b/caches/cache_log.txt @@ -164,3 +164,17 @@ CacheInfo: hits=176206, misses=6703, maxsize=524288, currsize=29413 omim_mappings cache info: CacheInfo: hits=80073, misses=2993, maxsize=524288, currsize=23613 +Timestamp: 20241120-174511 + +gpt-4o/results.tsv +score_grounded_result cache info: +CacheInfo: hits=45, misses=30, maxsize=524288 , currsize=29443 +omim_mappings cache info: +CacheInfo: hits=615, misses=5, maxsize=524288 , currsize=23618 + +gpt-4o-mini/results.tsv +score_grounded_result cache info: 
+CacheInfo: hits=98, misses=70, maxsize=524288 , currsize=29483
+omim_mappings cache info:
+CacheInfo: hits=1153, misses=14, maxsize=524288 , currsize=23627
+
diff --git a/caches/omim_mappings_cache.db b/caches/omim_mappings_cache.db
index 74e30758ff0da2f59051e12580ab5810bdf586cd..1dca25ae7bcdb04ed7b77ff3c628fb82f03ebaf0 100644
GIT binary patch
[base85-encoded binary delta omitted]

diff --git a/caches/score_grounded_result_cache.db b/caches/score_grounded_result_cache.db
GIT binary patch
[base85-encoded binary delta omitted]
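The CacheInfo lines appended to caches/cache_log.txt above are produced by the cache_info helper that ranking_utils.py attaches to shelved_cache's PersistentCache. A minimal sketch of that pattern follows; the cache file name is illustrative (the real run uses caches/omim_mappings_cache and caches/score_grounded_result_cache), everything else mirrors the code in this patch.

    from cachetools import LRUCache
    from shelved_cache import PersistentCache


    def cache_info(self):
        # hits/misses are plain per-run counters set by the caller;
        # wrapped is the underlying LRUCache holding the actual entries
        return (
            f"CacheInfo: hits={self.hits}, misses={self.misses}, "
            f"maxsize={self.wrapped.maxsize}, currsize={self.wrapped.currsize}"
        )


    # Illustrative file name; compute_mrr_and_ranks() opens its caches under caches/
    pc = PersistentCache(LRUCache, "example_cache", maxsize=524288)
    pc.hits = pc.misses = 0
    PersistentCache.cache_info = cache_info
    print(pc.cache_info())  # -> "CacheInfo: hits=0, misses=0, maxsize=524288, currsize=0"
    pc.close()
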
diff --git a/src/malco/analysis/check_lens.py b/src/malco/analysis/check_lens.py index ca47a492..5eaa0cf8 100644 --- a/src/malco/analysis/check_lens.py +++ b/src/malco/analysis/check_lens.py @@ -1,11 +1,11 @@ -import pandas as pd +import sys + +# from malco.post_process.post_process_results_format import read_raw_result_yaml +from pathlib import Path from typing import List import pandas as pd import yaml -# from malco.post_process.post_process_results_format import read_raw_result_yaml -from pathlib import Path -import sys def read_raw_result_yaml(raw_result_path: Path) -> List[dict]: @@ -18,8 +18,10 @@ def read_raw_result_yaml(raw_result_path: Path) -> List[dict]: Returns: dict: Contents of the raw result file. """ - with open(raw_result_path, 'r') as raw_result: - return list(yaml.safe_load_all(raw_result.read().replace(u'\x04', ''))) # Load and convert to list + with open(raw_result_path, "r") as raw_result: + return list( + yaml.safe_load_all(raw_result.read().replace("\x04", "")) + ) # Load and convert to list unique_ppkts = {} @@ -38,20 +40,27 @@ def read_raw_result_yaml(raw_result_path: Path) -> List[dict]: for this_result in all_results: extracted_object = this_result.get("extracted_object") if extracted_object: - label = extracted_object.get('label') + label = extracted_object.get("label") labelvec.append(label) - terms = extracted_object.get('terms') + terms = extracted_object.get("terms") if terms: counter += 1 full_df_file = f"out_openAI_models/multimodel/{model}/results.tsv" - df = pd.read_csv(full_df_file, sep='\t') - num_ppkts = df['label'].nunique() - unique_ppkts[model] = df['label'].unique() + df = pd.read_csv(full_df_file, sep="\t") + num_ppkts = df["label"].nunique() + unique_ppkts[model] = df["label"].unique() # The first should be equivalent to grepping "raw_" in some results.yaml print("The number of prompts that have something in results.yaml are: ", len(labelvec)) - print("The number of prompts that have a non-empty differential (i.e. term is not None) is:", counter) - print("The number of unique prompts/ppkts with a non-empty differential in results.tsv are:", num_ppkts, "\n") + print( + "The number of prompts that have a non-empty differential (i.e. 
term is not None) is:", + counter, + ) + print( + "The number of unique prompts/ppkts with a non-empty differential in results.tsv are:", + num_ppkts, + "\n", + ) # This we know a posteriori, gpt-4o and gpt-4-turbo both have 5213 phenopackets # Thus, let's print out what is missing in the others diff --git a/src/malco/analysis/count_grounding_failures.py b/src/malco/analysis/count_grounding_failures.py index 262d258b..2527e307 100644 --- a/src/malco/analysis/count_grounding_failures.py +++ b/src/malco/analysis/count_grounding_failures.py @@ -1,11 +1,10 @@ # Quick check how often the grounding failed # Need to be in short_letter branch import pandas as pd + mfile = "../outputdir_all_2024_07_04/en/results.tsv" -df = pd.read_csv( - mfile, sep="\t" # , header=None, names=["description", "term", "label"] -) +df = pd.read_csv(mfile, sep="\t") # , header=None, names=["description", "term", "label"] terms = df["term"] counter = 0 diff --git a/src/malco/analysis/count_translated_prompts.py b/src/malco/analysis/count_translated_prompts.py index f6d778ca..c8a8dc33 100644 --- a/src/malco/analysis/count_translated_prompts.py +++ b/src/malco/analysis/count_translated_prompts.py @@ -1,20 +1,22 @@ import os import re + fp = "/Users/leonardo/IdeaProjects/phenopacket2prompt/prompts/" -langs = ["en", - "es", - "de", - "it", - "nl", - "tr", - "zh", - ] +langs = [ + "en", + "es", + "de", + "it", + "nl", + "tr", + "zh", +] promptfiles = {} for lang in langs: promptfiles[lang] = [] - for (dirpath, dirnames, filenames) in os.walk(fp + lang): + for dirpath, dirnames, filenames in os.walk(fp + lang): for fn in filenames: fn = fn[0:-14] # TODO may be problematic if there are 2 "_" before "{langcode}-" # Maybe something along the lines of other script disease_avail_knowledge.py @@ -24,13 +26,13 @@ intersection = set() -enset = set(promptfiles['en']) -esset = set(promptfiles['es']) -deset = set(promptfiles['de']) -itset = set(promptfiles['it']) -nlset = set(promptfiles['nl']) -zhset = set(promptfiles['zh']) -trset = set(promptfiles['tr']) +enset = set(promptfiles["en"]) +esset = set(promptfiles["es"]) +deset = set(promptfiles["de"]) +itset = set(promptfiles["it"]) +nlset = set(promptfiles["nl"]) +zhset = set(promptfiles["zh"]) +trset = set(promptfiles["tr"]) intersection = enset & esset & deset & itset & nlset & zhset & trset diff --git a/src/malco/analysis/eval_diagnose_category.py b/src/malco/analysis/eval_diagnose_category.py index ba7851bb..4b60a30d 100644 --- a/src/malco/analysis/eval_diagnose_category.py +++ b/src/malco/analysis/eval_diagnose_category.py @@ -1,16 +1,15 @@ -import pandas as pd -import numpy as np import sys +import numpy as np +import pandas as pd +from cachetools import LRUCache, cached +from cachetools.keys import hashkey +from oaklib import get_adapter from oaklib.datamodels.vocabulary import IS_A, PART_OF -from oaklib.interfaces import MappingProviderInterface -from oaklib.interfaces import OboGraphInterface +from oaklib.interfaces import MappingProviderInterface, OboGraphInterface from oaklib.interfaces.obograph_interface import GraphTraversalMethod -from oaklib import get_adapter - -from cachetools import cached, LRUCache -from cachetools.keys import hashkey from shelved_cache import PersistentCache + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -31,6 +30,7 @@ def mondo_adapter() -> OboGraphInterface: """ return 
get_adapter("sqlite:obo:mondo") + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -41,6 +41,7 @@ def mondo_mapping(term, adapter): mondos.append(m.subject_id) return mondos + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -54,9 +55,11 @@ def find_category(omim_term, disease_categories, mondo): print(omim_term) return None - ancestor_list = mondo.ancestors(mondo_term, # only IS_A->same result - # , reflexive=True) # method=GraphTraversalMethod.ENTAILMENT - predicates=[IS_A, PART_OF]) + ancestor_list = mondo.ancestors( + mondo_term, # only IS_A->same result + # , reflexive=True) # method=GraphTraversalMethod.ENTAILMENT + predicates=[IS_A, PART_OF], + ) for mondo_ancestor in ancestor_list: if mondo_ancestor in disease_categories: @@ -76,8 +79,9 @@ def find_category(omim_term, disease_categories, mondo): mondo = mondo_adapter() -disease_categories = mondo.relationships(objects=["MONDO:0003847"], # hereditary diseases - predicates=[IS_A, PART_OF]) # only IS_A->same result +disease_categories = mondo.relationships( + objects=["MONDO:0003847"], predicates=[IS_A, PART_OF] # hereditary diseases +) # only IS_A->same result # disease_categories = mondo.relationships(objects = ["MONDO:0700096"], # only IS_A->same result # predicates=[IS_A, PART_OF]) @@ -92,9 +96,7 @@ def find_category(omim_term, disease_categories, mondo): # label term score rank correct_term is_correct reciprocal_rank # PMID_35962790_Family_B_Individual_3__II_6__en-prompt.txt MONDO:0008675 1.0 1.0 OMIM:620545 False 0.0 -df = pd.read_csv( - filename, sep="\t" -) +df = pd.read_csv(filename, sep="\t") ppkts = df.groupby("label")[["term", "correct_term", "is_correct"]] count_fails = 0 @@ -126,7 +128,9 @@ def find_category(omim_term, disease_categories, mondo): continue print("\n\n", "===" * 15, "\n") -print(f"For whatever reason find_category() returned None in {count_fails} cases, wich follow:\n") # print to file! +print( + f"For whatever reason find_category() returned None in {count_fails} cases, wich follow:\n" +) # print to file! # print(contingency_table) print("\n\nOf which the following are unique OMIMs:\n", set(list(omim_wo_match.values()))) # print(omim_wo_match, "\n\nOf which the following are unique OMIMs:\n", set(list(omim_wo_match.values()))) diff --git a/src/malco/analysis/test_curate_script.py b/src/malco/analysis/test_curate_script.py index a5e98db0..9c00e3bb 100644 --- a/src/malco/analysis/test_curate_script.py +++ b/src/malco/analysis/test_curate_script.py @@ -1,9 +1,11 @@ -import yaml from pathlib import Path from typing import List -from malco.post_process.extended_scoring import clean_service_answer, ground_diagnosis_text_to_mondo + +import yaml from oaklib import get_adapter +from malco.post_process.extended_scoring import clean_service_answer, ground_diagnosis_text_to_mondo + def read_raw_result_yaml(raw_result_path: Path) -> List[dict]: """ @@ -15,12 +17,16 @@ def read_raw_result_yaml(raw_result_path: Path) -> List[dict]: Returns: dict: Contents of the raw result file. 
""" - with open(raw_result_path, 'r') as raw_result: - return list(yaml.safe_load_all(raw_result.read().replace(u'\x04', ''))) # Load and convert to list + with open(raw_result_path, "r") as raw_result: + return list( + yaml.safe_load_all(raw_result.read().replace("\x04", "")) + ) # Load and convert to list annotator = get_adapter("sqlite:obo:mondo") -some_yaml_res = Path("/Users/leonardo/git/malco/out_openAI_models/raw_results/multimodel/gpt-4/results.yaml") +some_yaml_res = Path( + "/Users/leonardo/git/malco/out_openAI_models/raw_results/multimodel/gpt-4/results.yaml" +) data = [] @@ -36,7 +42,7 @@ def read_raw_result_yaml(raw_result_path: Path) -> List[dict]: assert cleaned_text != "", "Cleaning failed: the cleaned text is empty." result = ground_diagnosis_text_to_mondo(annotator, cleaned_text, verbose=False) - label = extracted_object.get('label') # pubmed id + label = extracted_object.get("label") # pubmed id # terms will now ONLY contain MONDO IDs OR 'N/A'. The latter should be dealt with downstream terms = [i[1][0][0] for i in result] # terms = extracted_object.get('terms') # list of strings, the mondo id or description @@ -47,7 +53,7 @@ def read_raw_result_yaml(raw_result_path: Path) -> List[dict]: score = [1 / (i + 1) for i in range(num_terms)] # score is reciprocal rank rank_list = [i + 1 for i in range(num_terms)] for term, scr, rank in zip(terms, score, rank_list): - data.append({'label': label, 'term': term, 'score': scr, 'rank': rank}) + data.append({"label": label, "term": term, "score": scr, "rank": rank}) if j > 20: break j += 1 diff --git a/src/malco/analysis/time_ic/disease_avail_knowledge.py b/src/malco/analysis/time_ic/disease_avail_knowledge.py index 04ba475b..8379477e 100644 --- a/src/malco/analysis/time_ic/disease_avail_knowledge.py +++ b/src/malco/analysis/time_ic/disease_avail_knowledge.py @@ -13,23 +13,23 @@ `runoak -g hpoa_file -G hpoa -i hpo_file information-content -p i --use-associations .all` """ -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -from scipy.stats import mannwhitneyu -from scipy.stats import ttest_ind -from scipy.stats import kstest -from scipy.stats import chi2_contingency -import sys -import os -import pandas as pd -import numpy as np + import datetime as dt +import json +import os +import pickle +import re +import sys from pathlib import Path -import matplotlib.pyplot as plt + import matplotlib.dates as mdates +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd import seaborn as sns -import json -import re -import pickle + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +from scipy.stats import chi2_contingency, kstest, mannwhitneyu, ttest_ind # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Parse input: @@ -39,7 +39,7 @@ make_plots = str(sys.argv[2]) == "plot" except IndexError: make_plots = False - print("\nYou can pass \"plot\" as a second CLI argument and this will generate nice plots!\n") + print('\nYou can pass "plot" as a second CLI argument and this will generate nice plots!\n') # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # PATHS: # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -53,12 +53,10 @@ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # IMPORT # 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -with open(outdir / "rank_date_dict.pkl", 'rb') as f: +with open(outdir / "rank_date_dict.pkl", "rb") as f: rank_date_dict = pickle.load(f) # import df of LLM results -rank_results_df = pd.read_csv( - ranking_results_filename, sep="\t" -) +rank_results_df = pd.read_csv(ranking_results_filename, sep="\t") # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -71,7 +69,7 @@ ranks_wo_none = [] for key, data in rank_date_dict.items(): r = data[0] - d = dt.datetime.strptime(data[1], '%Y-%m-%d').date() + d = dt.datetime.strptime(data[1], "%Y-%m-%d").date() dates.append(d) ranks.append(r) if r is not None: @@ -115,9 +113,9 @@ else: cont_table[1][0] += 1 -df_contingency_table = pd.DataFrame(cont_table, - index=["found", "not_found"], - columns=["y<2010", "y>=2010"]) +df_contingency_table = pd.DataFrame( + cont_table, index=["found", "not_found"], columns=["y<2010", "y>=2010"] +) print(df_contingency_table) print("H0: no correlation between column 1 and 2:") res = chi2_contingency(cont_table) @@ -143,15 +141,15 @@ for subdir, dirs, files in os.walk(original_ppkt_dir): # For each ppkt for filename in files: - if filename.endswith('.json'): + if filename.endswith(".json"): file_path = os.path.join(subdir, filename) with open(file_path, mode="r", encoding="utf-8") as read_file: ppkt = json.load(read_file) - ppkt_id = re.sub('[^\\w]', '_', ppkt['id']) + ppkt_id = re.sub("[^\\w]", "_", ppkt["id"]) ic = 0 num_hpos = 0 # For each HPO - for i in ppkt['phenotypicFeatures']: + for i in ppkt["phenotypicFeatures"]: try: if i["excluded"]: # skip excluded continue @@ -178,33 +176,35 @@ missing_in_ic_dict_unique = set(missing_in_ic_dict) ppkts_with_missing_hpos = set(ppkts_with_missing_hpos) print(f"\nNumber of (unique) HPOs without IC-value is {len(missing_in_ic_dict_unique)}.") # 65 -print(f"Number of ppkts with zero observed HPOs is {len(ppkts_with_zero_hpos)}. These are left out.") # 141 +print( + f"Number of ppkts with zero observed HPOs is {len(ppkts_with_zero_hpos)}. These are left out." +) # 141 # TODO check 141 # 172 print( - f"Number of ppkts where at least one HPO is missing its IC value is {len(ppkts_with_missing_hpos)}. These are left out from the average.\n") + f"Number of ppkts where at least one HPO is missing its IC value is {len(ppkts_with_missing_hpos)}. 
These are left out from the average.\n" +) -ppkt_ic_df = pd.DataFrame.from_dict(ppkt_ic, orient='index', columns=['avg(IC)']) -ppkt_ic_df['Diagnosed'] = 0 +ppkt_ic_df = pd.DataFrame.from_dict(ppkt_ic, orient="index", columns=["avg(IC)"]) +ppkt_ic_df["Diagnosed"] = 0 for ppkt in ppkts: if any(ppkt[1]["is_correct"]): - ppkt_label = ppkt[0].replace('_en-prompt.txt', '') + ppkt_label = ppkt[0].replace("_en-prompt.txt", "") if ppkt_label in ppkts_with_zero_hpos: continue - ppkt_ic_df.loc[ppkt_label, 'Diagnosed'] = 1 + ppkt_ic_df.loc[ppkt_label, "Diagnosed"] = 1 # xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx # See https://github.com/monarch-initiative/phenopacket-store/issues/157 -label_manual_removal = ["PMID_27764983_Family_1_individual__J", - "PMID_35991565_Family_I__3"] +label_manual_removal = ["PMID_27764983_Family_1_individual__J", "PMID_35991565_Family_I__3"] ppkt_ic_df = ppkt_ic_df.drop(label_manual_removal) # xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx # ppkt_ic_df['Diagnosed'].value_counts() # Diagnosed # 0.0 4182 64% # 1.0 2347 36% -ppkt_ic_df.to_csv(outdir / "ppkt_ic.tsv", sep='\t', index=True) +ppkt_ic_df.to_csv(outdir / "ppkt_ic.tsv", sep="\t", index=True) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # T-test: unlikely that the two samples are such due to sample bias. @@ -212,7 +212,7 @@ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # One-sample Kolmogorov-Smirnov test: compares the underlying distribution F(x) of a sample # against a given distribution G(x), here the normal distribution -kolsmirnov_result = kstest(ppkt_ic_df['avg(IC)'], 'norm') +kolsmirnov_result = kstest(ppkt_ic_df["avg(IC)"], "norm") print(kolsmirnov_result, "\n") # When interpreting be careful, and I quote one answer from: # https://stats.stackexchange.com/questions/2492/is-normality-testing-essentially-useless @@ -222,8 +222,8 @@ # --------------- t-test --------------- # TODO: regardless of above, our distributions are not normal. 
Discard t and add non-parametric U-test (Mann-Whitney) -found_ic = list(ppkt_ic_df.loc[ppkt_ic_df['Diagnosed'] > 0, 'avg(IC)']) -not_found_ic = list(ppkt_ic_df.loc[ppkt_ic_df['Diagnosed'] < 1, 'avg(IC)']) +found_ic = list(ppkt_ic_df.loc[ppkt_ic_df["Diagnosed"] > 0, "avg(IC)"]) +not_found_ic = list(ppkt_ic_df.loc[ppkt_ic_df["Diagnosed"] < 1, "avg(IC)"]) tresult = ttest_ind(found_ic, not_found_ic, equal_var=False) print("T-test result:\n", tresult, "\n") @@ -234,11 +234,11 @@ # --------------- plot --------------- if make_plots: - plt.hist(found_ic, bins=25, color='c', edgecolor='k', alpha=0.5, density=True) - plt.hist(not_found_ic, bins=25, color='r', edgecolor='k', alpha=0.5, density=True) + plt.hist(found_ic, bins=25, color="c", edgecolor="k", alpha=0.5, density=True) + plt.hist(not_found_ic, bins=25, color="r", edgecolor="k", alpha=0.5, density=True) plt.xlabel("Average Information Content") plt.ylabel("Counts") - plt.legend(['Successful Diagnosis', 'Unsuccessful Diagnosis']) + plt.legend(["Successful Diagnosis", "Unsuccessful Diagnosis"]) plt.savefig(outdir / "inf_content_histograms.png") # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/malco/analysis/time_ic/logit_predict_llm.py b/src/malco/analysis/time_ic/logit_predict_llm.py index b26c648b..9615a0f9 100644 --- a/src/malco/analysis/time_ic/logit_predict_llm.py +++ b/src/malco/analysis/time_ic/logit_predict_llm.py @@ -3,6 +3,11 @@ will be successfull or not? Machine Learning Task: can we build a logit to classify into two?""" +import pickle +from pathlib import Path + +import numpy as np + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ML(1): let's try with IC # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -10,28 +15,30 @@ # Maybe some ideas: # https://towardsdatascience.com/building-a-logistic-regression-in-python-step-by-step-becd4d56c9c8 [1/10/24] import pandas as pd -import numpy as np -import pickle from sklearn.linear_model import LogisticRegression +from sklearn.metrics import classification_report, confusion_matrix from sklearn.model_selection import train_test_split -from sklearn.metrics import confusion_matrix, classification_report -from pathlib import Path + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Path outdir = Path.cwd() / "src" / "malco" / "analysis" / "time_ic" # Import -ppkt_ic_df = pd.read_csv(outdir / "ppkt_ic.tsv", delimiter='\t', index_col=0) -with open(outdir / "rank_date_dict.pkl", 'rb') as f: +ppkt_ic_df = pd.read_csv(outdir / "ppkt_ic.tsv", delimiter="\t", index_col=0) +with open(outdir / "rank_date_dict.pkl", "rb") as f: rank_date_dict = pickle.load(f) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ X_train, X_test, y_train, y_test = train_test_split( - ppkt_ic_df[['avg(IC)']], ppkt_ic_df['Diagnosed'], test_size=0.2, random_state=0) + ppkt_ic_df[["avg(IC)"]], ppkt_ic_df["Diagnosed"], test_size=0.2, random_state=0 +) logreg = LogisticRegression() logreg.fit(X_train, y_train) y_pred = logreg.predict(X_test) -print('Accuracy of 1 parameter (IC) logistic regression classifier on test set: {:.2f}'.format( - logreg.score(X_test, y_test))) +print( + "Accuracy of 1 parameter (IC) logistic regression classifier on test set: {:.2f}".format( + logreg.score(X_test, y_test) + ) +) cm1d = confusion_matrix(y_test, y_pred) print(cm1d) class_report 
= classification_report(y_test, y_pred) @@ -45,20 +52,24 @@ # rank_date_dict { 'pubmedid' + '_en-prompt.txt': [rank, '2012-05-11'],...} # ppkt_ic_df has row names with 'pubmedid' # for entry in ppkt_ic_df if rowname in rank_date_dict.keys() get that .values() index[1] and parse first 4 entries (year is 4 digits) -date_df = pd.DataFrame.from_dict(rank_date_dict, orient='index', columns=['rank', 'date']) -date_df.drop(columns='rank', inplace=True) -date_df['date'] = date_df['date'].str[0:4] +date_df = pd.DataFrame.from_dict(rank_date_dict, orient="index", columns=["rank", "date"]) +date_df.drop(columns="rank", inplace=True) +date_df["date"] = date_df["date"].str[0:4] new_index = np.array([i[0:-14].rstrip("_") for i in date_df.index.to_list()]) date_df.set_index(new_index, inplace=True) -ic_date_df = date_df.join(ppkt_ic_df, how='inner') +ic_date_df = date_df.join(ppkt_ic_df, how="inner") X_train, X_test, y_train, y_test = train_test_split( - ic_date_df[['avg(IC)', 'date']], ic_date_df['Diagnosed'], test_size=0.2, random_state=0) + ic_date_df[["avg(IC)", "date"]], ic_date_df["Diagnosed"], test_size=0.2, random_state=0 +) logreg = LogisticRegression() logreg.fit(X_train, y_train) y_pred = logreg.predict(X_test) -print('\nAccuracy of 2 PARAMETER (IC and time) logistic regression classifier on test set: {:.2f}'.format( - logreg.score(X_test, y_test))) +print( + "\nAccuracy of 2 PARAMETER (IC and time) logistic regression classifier on test set: {:.2f}".format( + logreg.score(X_test, y_test) + ) +) cm2d = confusion_matrix(y_test, y_pred) print(cm2d) class_report = classification_report(y_test, y_pred) diff --git a/src/malco/analysis/time_ic/rank_date_exploratory.py b/src/malco/analysis/time_ic/rank_date_exploratory.py index 37042d4b..716903b2 100644 --- a/src/malco/analysis/time_ic/rank_date_exploratory.py +++ b/src/malco/analysis/time_ic/rank_date_exploratory.py @@ -14,10 +14,12 @@ """ import pickle -from pathlib import Path -import pandas as pd import sys +from pathlib import Path + import numpy as np +import pandas as pd + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Parse user input and set paths: model = str(sys.argv[1]) @@ -30,25 +32,36 @@ # (1) HPOA for dates # HPOA import and setup hpoa_df = pd.read_csv( - hpoa_file_path, sep="\t", header=4, low_memory=False # Necessary to suppress Warning we don't care about + hpoa_file_path, + sep="\t", + header=4, + low_memory=False, # Necessary to suppress Warning we don't care about ) -labels_to_drop = ["disease_name", "qualifier", "hpo_id", "reference", "evidence", - "onset", "frequency", "sex", "modifier", "aspect"] +labels_to_drop = [ + "disease_name", + "qualifier", + "hpo_id", + "reference", + "evidence", + "onset", + "frequency", + "sex", + "modifier", + "aspect", +] hpoa_df = hpoa_df.drop(columns=labels_to_drop) -hpoa_df['date'] = hpoa_df["biocuration"].str.extract(r'\[(.*?)\]') -hpoa_df = hpoa_df.drop(columns='biocuration') -hpoa_df = hpoa_df[hpoa_df['database_id'].str.startswith("OMIM")] +hpoa_df["date"] = hpoa_df["biocuration"].str.extract(r"\[(.*?)\]") +hpoa_df = hpoa_df.drop(columns="biocuration") +hpoa_df = hpoa_df[hpoa_df["database_id"].str.startswith("OMIM")] hpoa_unique = hpoa_df.groupby("database_id").date.min() # Now length 8251, and e.g. 
hpoa_unique.loc["OMIM:620662"] -> '2024-04-15' # import df of LLM results -rank_results_df = pd.read_csv( - ranking_results_filename, sep="\t" -) +rank_results_df = pd.read_csv(ranking_results_filename, sep="\t") # Go through results data and make set of found vs not found diseases. @@ -58,15 +71,17 @@ ppkts = rank_results_df.groupby("label")[["term", "correct_term", "is_correct", "rank"]] for ppkt in ppkts: # TODO 1st for ppkt in ppkts # ppkt is tuple ("filename", dataframe) --> ppkt[1] is a dataframe - disease = ppkt[1].iloc[0]['correct_term'] + disease = ppkt[1].iloc[0]["correct_term"] if any(ppkt[1]["is_correct"]): found_diseases.append(disease) index_of_match = ppkt[1]["is_correct"].to_list().index(True) try: rank = ppkt[1].iloc[index_of_match]["rank"] # inverse rank does not work well - rank_date_dict[ppkt[0]] = [rank.item(), - hpoa_unique.loc[ppkt[1].iloc[0]["correct_term"]]] + rank_date_dict[ppkt[0]] = [ + rank.item(), + hpoa_unique.loc[ppkt[1].iloc[0]["correct_term"]], + ] # If ppkt[1].iloc[0]["correct_term"] is nan, then KeyError "e" is nan except (ValueError, KeyError) as e: print(f"Error {e} for {ppkt[0]}, disease {ppkt[1].iloc[0]['correct_term']}.") @@ -74,8 +89,7 @@ else: not_found_diseases.append(disease) try: - rank_date_dict[ppkt[0]] = [None, - hpoa_unique.loc[ppkt[1].iloc[0]["correct_term"]]] + rank_date_dict[ppkt[0]] = [None, hpoa_unique.loc[ppkt[1].iloc[0]["correct_term"]]] except (ValueError, KeyError) as e: # pass # TODO collect the below somewhere @@ -86,7 +100,7 @@ # len(rank_date_dict) --> 6625 # len(ppkts) --> 6687 -with open(outdir / "rank_date_dict.pkl", 'wb') as f: +with open(outdir / "rank_date_dict.pkl", "wb") as f: pickle.dump(rank_date_dict, f) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -120,19 +134,19 @@ results_dict = {} # turns out being 281 long # TODO get rid of next line, bzw hpoa_unique does not work for loop below -hpoa_df.drop_duplicates(subset='database_id', inplace=True) +hpoa_df.drop_duplicates(subset="database_id", inplace=True) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ for af in always_found: try: - results_dict[af] = [True, hpoa_df.loc[hpoa_df['database_id'] == af, 'date'].item()] + results_dict[af] = [True, hpoa_df.loc[hpoa_df["database_id"] == af, "date"].item()] # results_dict[af] = [True, hpoa_unique.loc[hpoa_unique['database_id'] == af, 'date'].item() ] except ValueError: print(f"No HPOA in always_found for {af}.") for nf in never_found: try: - results_dict[nf] = [False, hpoa_df.loc[hpoa_df['database_id'] == nf, 'date'].item()] + results_dict[nf] = [False, hpoa_df.loc[hpoa_df["database_id"] == nf, "date"].item()] # results_dict[nf] = [False, hpoa_unique.loc[hpoa_unique['database_id'] == nf, 'date'].item() ] except ValueError: print(f"No HPOA in never_found for {nf}.") @@ -141,7 +155,7 @@ # disease = ppkt[1].iloc[0]['correct_term'] res_to_clean = pd.DataFrame.from_dict(results_dict).transpose() res_to_clean.columns = ["found", "date"] -res_to_clean['date'] = pd.to_datetime(res_to_clean.date).values.astype(np.int64) -final_avg = pd.DataFrame(pd.to_datetime(res_to_clean.groupby('found').mean()['date'])) +res_to_clean["date"] = pd.to_datetime(res_to_clean.date).values.astype(np.int64) +final_avg = pd.DataFrame(pd.to_datetime(res_to_clean.groupby("found").mean()["date"])) print(final_avg) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git 
a/src/malco/post_process/df_save_util.py b/src/malco/post_process/df_save_util.py index bd140b71..4217bf69 100644 --- a/src/malco/post_process/df_save_util.py +++ b/src/malco/post_process/df_save_util.py @@ -1,6 +1,5 @@ -import shutil import os -import pandas as pd +import shutil def safe_save_tsv(path, filename, df): @@ -13,4 +12,4 @@ def safe_save_tsv(path, filename, df): os.remove(old_full_path) shutil.copy(full_path, old_full_path) os.remove(full_path) - df.to_csv(full_path, sep='\t', index=False) + df.to_csv(full_path, sep="\t", index=False) diff --git a/src/malco/post_process/extended_scoring.py b/src/malco/post_process/extended_scoring.py index 2fae6841..3e173b2b 100644 --- a/src/malco/post_process/extended_scoring.py +++ b/src/malco/post_process/extended_scoring.py @@ -1,10 +1,11 @@ import re -import os -from oaklib.interfaces.text_annotator_interface import TextAnnotationConfiguration -from oaklib.interfaces.text_annotator_interface import TextAnnotatorInterface -from curategpt.store import get_store from typing import List, Tuple +from curategpt.store import get_store +from oaklib.interfaces.text_annotator_interface import ( + TextAnnotationConfiguration, + TextAnnotatorInterface, +) # Compile a regex pattern to detect lines starting with "Differential Diagnosis:" dd_re = re.compile(r"^[^A-z]*Differential Diagnosis") @@ -14,26 +15,28 @@ def clean_service_answer(answer: str) -> str: """Remove the 'Differential Diagnosis' header if present, and clean the first line.""" - lines = answer.split('\n') + lines = answer.split("\n") # Filter out any line that starts with "Differential Diagnosis:" cleaned_lines = [line for line in lines if not dd_re.match(line)] - return '\n'.join(cleaned_lines) + return "\n".join(cleaned_lines) + # Clean the diagnosis line by removing leading numbers, periods, asterisks, and spaces def clean_diagnosis_line(line: str) -> str: """Remove leading numbers, asterisks, and unnecessary punctuation/spaces from the diagnosis.""" - line = re.sub(r'^\**\d+\.\s*', '', line) # Remove leading numbers and periods - line = line.strip('*') # Remove asterisks around the text + line = re.sub(r"^\**\d+\.\s*", "", line) # Remove leading numbers and periods + line = line.strip("*") # Remove asterisks around the text return line.strip() # Strip any remaining spaces + # Split a diagnosis into its main name and synonym if present def split_diagnosis_and_synonym(diagnosis: str) -> Tuple[str, str]: """Split the diagnosis into main name and synonym (if present in parentheses).""" - match = re.match(r'^(.*)\s*\((.*)\)\s*$', diagnosis) + match = re.match(r"^(.*)\s*\((.*)\)\s*$", diagnosis) if match: main_diagnosis, synonym = match.groups() return main_diagnosis.strip(), synonym.strip() @@ -47,7 +50,7 @@ def perform_curategpt_grounding( database_type: str = "chromadb", limit: int = 1, relevance_factor: float = 0.23, - verbose: bool = False + verbose: bool = False, ) -> List[Tuple[str, str]]: """ Use curategpt to perform grounding for a given diagnosis when initial attempts fail. 
@@ -74,7 +77,11 @@ def perform_curategpt_grounding( # Filter results based on relevance factor (distance) if relevance_factor is not None: - results = [(obj, distance, _meta) for obj, distance, _meta in results if distance <= relevance_factor] + results = [ + (obj, distance, _meta) + for obj, distance, _meta in results + if distance <= relevance_factor + ] # Limit the results to the specified number (limit) limited_results = results[:limit] @@ -83,7 +90,7 @@ def perform_curategpt_grounding( pred_ids = [] pred_labels = [] - for obj, distance, _meta in limited_results: + for obj, _distance, _meta in limited_results: disease_mondo_id = obj.get("original_id") # Use the 'original_id' field for Mondo ID disease_label = obj.get("label") @@ -95,7 +102,7 @@ def perform_curategpt_grounding( if len(pred_ids) == 0: if verbose: print(f"No grounded IDs found for {diagnosis}") - return [('N/A', 'No grounding found')] + return [("N/A", "No grounding found")] return list(zip(pred_ids, pred_labels)) @@ -104,9 +111,9 @@ def perform_curategpt_grounding( def perform_oak_grounding( annotator: TextAnnotatorInterface, diagnosis: str, - exact_match: bool = True, - verbose: bool = False, - include_list: List[str] = ["MONDO:"], + exact_match: bool, + verbose: bool, + include_list: List[str], ) -> List[Tuple[str, str]]: """ Perform grounding for a diagnosis. The 'exact_match' flag controls whether exact or inexact @@ -133,7 +140,8 @@ def perform_oak_grounding( if verbose: print(f"No {match_type} grounded IDs found for: {diagnosis}") pass - return [('N/A', 'No grounding found')] + return [("N/A", "No grounding found")] + # Now, integrate curategpt into your ground_diagnosis_text_to_mondo function @@ -141,12 +149,12 @@ def perform_oak_grounding( def ground_diagnosis_text_to_mondo( annotator: TextAnnotatorInterface, differential_diagnosis: str, - verbose: bool = False, - include_list: List[str] = ["MONDO:"], + verbose: bool, + include_list: List[str], use_ontogpt_grounding: bool = True, curategpt_path: str = "../curategpt/stagedb/", curategpt_collection: str = "ont_mondo", - curategpt_database_type: str = "chromadb" + curategpt_database_type: str = "chromadb", ) -> List[Tuple[str, List[Tuple[str, str]]]]: results = [] @@ -159,21 +167,22 @@ def ground_diagnosis_text_to_mondo( continue # Try grounding the full line first (exact match) - grounded = perform_oak_grounding(annotator, clean_line, exact_match=True, - verbose=verbose, include_list=include_list) + grounded = perform_oak_grounding( + annotator, clean_line, exact_match=True, verbose=verbose, include_list=include_list + ) # Try grounding with curategpt if no grounding is found - if use_ontogpt_grounding and grounded == [('N/A', 'No grounding found')]: + if use_ontogpt_grounding and grounded == [("N/A", "No grounding found")]: grounded = perform_curategpt_grounding( diagnosis=clean_line, path=curategpt_path, collection=curategpt_collection, database_type=curategpt_database_type, - verbose=verbose + verbose=verbose, ) # If still no grounding is found, log the final failure - if grounded == [('N/A', 'No grounding found')]: + if grounded == [("N/A", "No grounding found")]: if verbose: print(f"Final grounding failed for: {clean_line}") diff --git a/src/malco/post_process/generate_plots.py b/src/malco/post_process/generate_plots.py index 927245ea..1f355049 100644 --- a/src/malco/post_process/generate_plots.py +++ b/src/malco/post_process/generate_plots.py @@ -1,8 +1,8 @@ -import seaborn as sns +import csv + import matplotlib.pyplot as plt import pandas as pd -import os 
-import csv +import seaborn as sns # Make a nice plot, use it as function or as script @@ -19,8 +19,8 @@ def make_plots(mrr_file, data_dir, languages, num_ppkt, models, topn_aggr_file, else: name_string = str(len(languages)) - with mrr_file.open('r', newline='') as f: - lines = csv.reader(f, quoting=csv.QUOTE_NONNUMERIC, delimiter='\t', lineterminator='\n') + with mrr_file.open("r", newline="") as f: + lines = csv.reader(f, quoting=csv.QUOTE_NONNUMERIC, delimiter="\t", lineterminator="\n") results_files = next(lines) mrr_scores = next(lines) @@ -37,7 +37,7 @@ def make_plots(mrr_file, data_dir, languages, num_ppkt, models, topn_aggr_file, plt.close() # Plotting bar-plots with top ranks - df_aggr = pd.read_csv(topn_aggr_file, delimiter='\t') + df_aggr = pd.read_csv(topn_aggr_file, delimiter="\t") sns.barplot(x="Rank_in", y="percentage", data=df_aggr, hue=comparing) @@ -46,6 +46,8 @@ def make_plots(mrr_file, data_dir, languages, num_ppkt, models, topn_aggr_file, plt.ylim([0.0, 1.0]) plt.title("Rank Comparison for Differential Diagnosis") plt.legend(title=comparing) - plot_path = plot_dir / ("barplot_" + name_string + "_" + comparing + "_" + str(num_ppkt) + "ppkt.png") + plot_path = plot_dir / ( + "barplot_" + name_string + "_" + comparing + "_" + str(num_ppkt) + "ppkt.png" + ) plt.savefig(plot_path) plt.close() diff --git a/src/malco/post_process/mondo_score_utils.py b/src/malco/post_process/mondo_score_utils.py index 55a5f849..81993ab1 100644 --- a/src/malco/post_process/mondo_score_utils.py +++ b/src/malco/post_process/mondo_score_utils.py @@ -1,10 +1,8 @@ -from oaklib.datamodels.vocabulary import IS_A -from oaklib.interfaces import MappingProviderInterface -from pathlib import Path - from typing import List -from cachetools.keys import hashkey +from cachetools.keys import hashkey +from oaklib.datamodels.vocabulary import IS_A +from oaklib.interfaces import MappingProviderInterface FULL_SCORE = 1.0 PARTIAL_SCORE = 0.5 diff --git a/src/malco/post_process/post_process.py b/src/malco/post_process/post_process.py index c7fe9624..f80611f3 100644 --- a/src/malco/post_process/post_process.py +++ b/src/malco/post_process/post_process.py @@ -1,7 +1,4 @@ -from pathlib import Path - from malco.post_process.post_process_results_format import create_standardised_results -import os def post_process(self) -> None: @@ -25,11 +22,12 @@ def post_process(self) -> None: raw_results_lang.mkdir(exist_ok=True, parents=True) output_lang.mkdir(exist_ok=True, parents=True) - create_standardised_results(curategpt, - raw_results_dir=raw_results_lang, - output_dir=output_lang, - output_file_name="results.tsv", - ) + create_standardised_results( + curategpt, + raw_results_dir=raw_results_lang, + output_dir=output_lang, + output_file_name="results.tsv", + ) elif self.modality == "several_models": for model in models: @@ -38,8 +36,9 @@ def post_process(self) -> None: raw_results_model.mkdir(exist_ok=True, parents=True) output_model.mkdir(exist_ok=True, parents=True) - create_standardised_results(curategpt, - raw_results_dir=raw_results_model, - output_dir=output_model, - output_file_name="results.tsv", - ) + create_standardised_results( + curategpt, + raw_results_dir=raw_results_model, + output_dir=output_model, + output_file_name="results.tsv", + ) diff --git a/src/malco/post_process/post_process_results_format.py b/src/malco/post_process/post_process_results_format.py index c47ab679..3ad54271 100644 --- a/src/malco/post_process/post_process_results_format.py +++ b/src/malco/post_process/post_process_results_format.py @@ 
-1,16 +1,15 @@ import json -import os from pathlib import Path from typing import List -import shutil + import pandas as pd import yaml -from pheval.post_processing.post_processing import PhEvalGeneResult, generate_pheval_result -from pheval.utils.file_utils import all_files -from pheval.utils.phenopacket_utils import GeneIdentifierUpdater, create_hgnc_dict +from oaklib import get_adapter +from pheval.post_processing.post_processing import PhEvalGeneResult +from pheval.utils.phenopacket_utils import GeneIdentifierUpdater + from malco.post_process.df_save_util import safe_save_tsv from malco.post_process.extended_scoring import clean_service_answer, ground_diagnosis_text_to_mondo -from oaklib import get_adapter def read_raw_result_yaml(raw_result_path: Path) -> List[dict]: @@ -23,15 +22,15 @@ def read_raw_result_yaml(raw_result_path: Path) -> List[dict]: Returns: dict: Contents of the raw result file. """ - with open(raw_result_path, 'r') as raw_result: - return list(yaml.safe_load_all(raw_result.read().replace(u'\x04', ''))) # Load and convert to list + with open(raw_result_path, "r") as raw_result: + return list( + yaml.safe_load_all(raw_result.read().replace("\x04", "")) + ) # Load and convert to list -def create_standardised_results(curategpt: bool, - raw_results_dir: Path, - output_dir: Path, - output_file_name: str - ) -> pd.DataFrame: +def create_standardised_results( + curategpt: bool, raw_results_dir: Path, output_dir: Path, output_file_name: str +) -> pd.DataFrame: data = [] if curategpt: @@ -45,24 +44,28 @@ def create_standardised_results(curategpt: bool, for this_result in all_results: extracted_object = this_result.get("extracted_object") if extracted_object: - label = extracted_object.get('label') - terms = extracted_object.get('terms') + label = extracted_object.get("label") + terms = extracted_object.get("terms") if curategpt and terms: ontogpt_text = this_result.get("input_text") # its a single string, should be parseable through curategpt cleaned_text = clean_service_answer(ontogpt_text) assert cleaned_text != "", "Cleaning failed: the cleaned text is empty." - result = ground_diagnosis_text_to_mondo(annotator, cleaned_text, verbose=False) - # terms will now ONLY contain MONDO IDs OR 'N/A'. The latter should be dealt with downstream + result = ground_diagnosis_text_to_mondo( + annotator, cleaned_text, verbose=False, include_list=["MONDO:"] + ) + # terms will now ONLY contain MONDO IDs OR 'N/A'. + # The latter should be dealt with downstream terms = [i[1][0][0] for i in result] # MONDO_ID if terms: # Note, the if allows for rerunning ppkts that failed due to connection issues - # We can have multiple identical ppkts/prompts in results.yaml as long as only one has a terms field + # We can have multiple identical ppkts/prompts in results.yaml + # as long as only one has a terms field num_terms = len(terms) score = [1 / (i + 1) for i in range(num_terms)] # score is reciprocal rank rank_list = [i + 1 for i in range(num_terms)] for term, scr, rank in zip(terms, score, rank_list): - data.append({'label': label, 'term': term, 'score': scr, 'rank': rank}) + data.append({"label": label, "term": term, "score": scr, "rank": rank}) # Create DataFrame df = pd.DataFrame(data) @@ -76,6 +79,7 @@ def create_standardised_results(curategpt: bool, # these are from the template and not currently used outside of tests + def read_raw_result(raw_result_path: Path) -> List[dict]: """ Read the raw result file. 
diff --git a/src/malco/post_process/ranking_utils.py b/src/malco/post_process/ranking_utils.py index 29dbffe1..e64323a3 100644 --- a/src/malco/post_process/ranking_utils.py +++ b/src/malco/post_process/ranking_utils.py @@ -1,30 +1,26 @@ -import os import csv -from pathlib import Path +import os from datetime import datetime -import pandas as pd -import numpy as np -import pickle as pkl -import shutil +from pathlib import Path -from oaklib.interfaces import OboGraphInterface -from oaklib.datamodels.vocabulary import IS_A -from oaklib.interfaces import MappingProviderInterface +import numpy as np +import pandas as pd +from cachetools import LRUCache +from cachetools.keys import hashkey from oaklib import get_adapter +from oaklib.interfaces import OboGraphInterface +from shelved_cache import PersistentCache from malco.post_process.df_save_util import safe_save_tsv from malco.post_process.mondo_score_utils import score_grounded_result -from cachetools import LRUCache -from typing import List -from cachetools.keys import hashkey -from shelved_cache import PersistentCache FULL_SCORE = 1.0 PARTIAL_SCORE = 0.5 def cache_info(self): - return f"CacheInfo: hits={self.hits}, misses={self.misses}, maxsize={self.wrapped.maxsize}, currsize={self.wrapped.currsize}" + return f"CacheInfo: hits={self.hits}, misses={self.misses}, maxsize={self.wrapped.maxsize}\ + , currsize={self.wrapped.currsize}" def mondo_adapter() -> OboGraphInterface: @@ -62,13 +58,12 @@ def compute_mrr_and_ranks( pc2.hits = pc2.misses = 0 PersistentCache.cache_info = cache_info - mode_index = 0 - for subdir, dirs, files in os.walk(output_dir): + for subdir, _dirs, files in os.walk(output_dir): for filename in files: if filename.startswith("result") and filename.endswith(".tsv"): file_path = os.path.join(subdir, filename) df = pd.read_csv(file_path, sep="\t") - num_ppkt[subdir.split('/')[-1]] = df["label"].nunique() + num_ppkt[subdir.split("/")[-1]] = df["label"].nunique() results_data.append(df) # Append both the subdirectory relative to output_dir and the filename results_files.append(os.path.relpath(file_path, output_dir)) @@ -82,12 +77,27 @@ def compute_mrr_and_ranks( label_to_correct_term = answers.set_index("label")["term"].to_dict() # Calculate the Mean Reciprocal Rank (MRR) for each file mrr_scores = [] - header = [comparing, "n1", "n2", "n3", "n4", "n5", "n6", "n7", "n8", "n9", "n10", "n10p", "nf", 'num_cases'] + header = [ + comparing, + "n1", + "n2", + "n3", + "n4", + "n5", + "n6", + "n7", + "n8", + "n9", + "n10", + "n10p", + "nf", + "num_cases", + ] rank_df = pd.DataFrame(0, index=np.arange(len(results_files)), columns=header) cache_file = out_caches / "cache_log.txt" - with cache_file.open('a', newline='') as cf: + with cache_file.open("a", newline="") as cf: now_is = datetime.now().strftime("%Y%m%d-%H%M%S") cf.write("Timestamp: " + now_is + "\n\n") mondo = mondo_adapter() @@ -97,7 +107,8 @@ def compute_mrr_and_ranks( # For each label in the results file, find if the correct term is ranked df["rank"] = df.groupby("label")["score"].rank(ascending=False, method="first") label_4_non_eng = df["label"].str.replace( - "_[a-z][a-z]-prompt", "_en-prompt", regex=True) # TODO is bug here? + "_[a-z][a-z]-prompt", "_en-prompt", regex=True + ) # TODO is bug here? 
# df['correct_term'] is an OMIM # df['term'] is Mondo or OMIM ID, or even disease label @@ -105,22 +116,22 @@ def compute_mrr_and_ranks( # Make sure caching is used in the following by unwrapping explicitly results = [] - for idx, row in df.iterrows(): + for _idx, row in df.iterrows(): # call OAK and get OMIM IDs for df['term'] and see if df['correct_term'] is one of them # in the case of phenotypic series, if Mondo corresponds to grouping term, accept it - k = hashkey(row['term'], row['correct_term']) + k = hashkey(row["term"], row["correct_term"]) try: val = pc2[k] pc2.hits += 1 except KeyError: # cache miss - val = score_grounded_result(row['term'], row['correct_term'], mondo, pc1) + val = score_grounded_result(row["term"], row["correct_term"], mondo, pc1) pc2[k] = val pc2.misses += 1 is_correct = val > 0 results.append(is_correct) - df['is_correct'] = results + df["is_correct"] = results df["reciprocal_rank"] = df.apply( lambda row: 1 / row["rank"] if row["is_correct"] else 0, axis=1 ) @@ -138,7 +149,6 @@ def compute_mrr_and_ranks( rank_df.loc[i, comparing] = results_files[i].split("/")[0] ppkts = df.groupby("label")[["rank", "is_correct"]] - index_matches = df.index[df['is_correct']] # for each group for ppkt in ppkts: @@ -148,8 +158,8 @@ def compute_mrr_and_ranks( rank_df.loc[i, "nf"] += 1 else: # yes --> what's it rank? It's - jind = ppkt[1].index[ppkt[1]['is_correct']] - j = int(ppkt[1]['rank'].loc[jind].values[0]) + jind = ppkt[1].index[ppkt[1]["is_correct"]] + j = int(ppkt[1]["rank"].loc[jind].values[0]) if j < 11: # increase n rank_df.loc[i, "n" + str(j)] += 1 @@ -159,18 +169,18 @@ def compute_mrr_and_ranks( # Write cache charatcteristics to file cf.write(results_files[i]) - cf.write('\nscore_grounded_result cache info:\n') + cf.write("\nscore_grounded_result cache info:\n") cf.write(str(pc2.cache_info())) - cf.write('\nomim_mappings cache info:\n') + cf.write("\nomim_mappings cache info:\n") cf.write(str(pc1.cache_info())) - cf.write('\n\n') + cf.write("\n\n") i = i + 1 pc1.close() pc2.close() for modelname in num_ppkt.keys(): - rank_df.loc[rank_df['model'] == modelname, 'num_cases'] = num_ppkt[modelname] + rank_df.loc[rank_df["model"] == modelname, "num_cases"] = num_ppkt[modelname] data_dir = output_dir / "rank_data" data_dir.mkdir(exist_ok=True) topn_file_name = "topn_result.tsv" @@ -182,22 +192,37 @@ def compute_mrr_and_ranks( mrr_file = data_dir / "mrr_result.tsv" # write out results for plotting - with mrr_file.open('w', newline='') as dat: - writer = csv.writer(dat, quoting=csv.QUOTE_NONNUMERIC, delimiter='\t', lineterminator='\n') + with mrr_file.open("w", newline="") as dat: + writer = csv.writer(dat, quoting=csv.QUOTE_NONNUMERIC, delimiter="\t", lineterminator="\n") writer.writerow(results_files) writer.writerow(mrr_scores) - df = pd.read_csv(topn_file, delimiter='\t') - df["top1"] = (df['n1']) / df["num_cases"] + df = pd.read_csv(topn_file, delimiter="\t") + df["top1"] = (df["n1"]) / df["num_cases"] df["top3"] = (df["n1"] + df["n2"] + df["n3"]) / df["num_cases"] df["top5"] = (df["n1"] + df["n2"] + df["n3"] + df["n4"] + df["n5"]) / df["num_cases"] - df["top10"] = (df["n1"] + df["n2"] + df["n3"] + df["n4"] + df["n5"] + df["n6"] + - df["n7"] + df["n8"] + df["n9"] + df["n10"]) / df["num_cases"] + df["top10"] = ( + df["n1"] + + df["n2"] + + df["n3"] + + df["n4"] + + df["n5"] + + df["n6"] + + df["n7"] + + df["n8"] + + df["n9"] + + df["n10"] + ) / df["num_cases"] df["not_found"] = (df["nf"]) / df["num_cases"] df_aggr = pd.DataFrame() - df_aggr = pd.melt(df, 
id_vars=comparing, value_vars=[ - "top1", "top3", "top5", "top10", "not_found"], var_name="Rank_in", value_name="percentage") + df_aggr = pd.melt( + df, + id_vars=comparing, + value_vars=["top1", "top3", "top5", "top10", "not_found"], + var_name="Rank_in", + value_name="percentage", + ) # If "topn_aggr.tsv" already exists, prepend "old_" # It's the user's responsibility to know only up to 2 versions can exist, then data is lost diff --git a/src/malco/prepare/setup_phenopackets.py b/src/malco/prepare/setup_phenopackets.py deleted file mode 100644 index 41c117f4..00000000 --- a/src/malco/prepare/setup_phenopackets.py +++ /dev/null @@ -1,35 +0,0 @@ -import zipfile -import os -import requests - -phenopacket_zip_url = "https://github.com/monarch-initiative/phenopacket-store/releases/download/0.1.11/all_phenopackets.zip" -# TODO just point to a folder w/ ppkts -phenopacket_dir = "phenopacket-store" - - -def setup_phenopackets(self) -> str: - phenopacket_store_path = os.path.join(self.input_dir, phenopacket_dir) - if os.path.exists(phenopacket_store_path): - print(f"{phenopacket_store_path} exists, skipping download.") - else: - print(f"{phenopacket_store_path} doesn't exist, downloading phenopackets...") - download_phenopackets(self, phenopacket_zip_url, phenopacket_dir) - return phenopacket_store_path - - -def download_phenopackets(self, phenopacket_zip_url, phenopacket_dir): - # Ensure the directory for storing the phenopackets exists - phenopacket_store_path = os.path.join(self.input_dir, phenopacket_dir) - os.makedirs(phenopacket_store_path, exist_ok=True) - - # Download the phenopacket release zip file - response = requests.get(phenopacket_zip_url) - zip_path = os.path.join(self.input_dir, "all_phenopackets.zip") - with open(zip_path, "wb") as f: - f.write(response.content) - print("Download completed.") - - # Unzip the phenopacket release zip file - with zipfile.ZipFile(zip_path, "r") as zip_ref: - zip_ref.extractall(phenopacket_store_path) - print("Unzip completed.") diff --git a/src/malco/prepare/setup_run_pars.py b/src/malco/prepare/setup_run_pars.py index 0828b21f..2465dca4 100644 --- a/src/malco/prepare/setup_run_pars.py +++ b/src/malco/prepare/setup_run_pars.py @@ -12,7 +12,8 @@ def import_inputdata(self) -> None: "gpt-3.5-turbo", "gpt-4", "gpt-4-turbo", "gpt-4o" 1, 0 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ - Meaning run english prompts with those 4 aforementioned models, and only execute the run function, not postprocess. + Meaning run english prompts with those 4 aforementioned models, and only execute the run function, + not postprocess. Or something like: ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ @@ -21,30 +22,33 @@ def import_inputdata(self) -> None: 0, 1 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ - Meaning run multilingual prompts with those 5 aforementioned languages, and only execute the function postprocess, not run. + Meaning run multilingual prompts with those 5 aforementioned languages, + and only execute the function postprocess, not run. 
""" - with open(self.input_dir / "run_parameters.csv", 'r') as pars: - lines = csv.reader(pars, quoting=csv.QUOTE_NONNUMERIC, delimiter=',', lineterminator='\n') + with open(self.input_dir / "run_parameters.csv", "r") as pars: + lines = csv.reader(pars, quoting=csv.QUOTE_NONNUMERIC, delimiter=",", lineterminator="\n") in_langs = next(lines) in_models = next(lines) in_what_to_run = next(lines) - l = len(in_langs) + ll = len(in_langs) m = len(in_models) - if (l > 1 and m > 1): + if ll > 1 and m > 1: sys.exit("Error, either run multiple languages or models, not both, exiting...") - elif l == 1 and m >= 1: + elif ll == 1 and m >= 1: if in_langs[0] == "en": - self.modality = "several_models" # English and more than 1 model defaults to multiple models + self.modality = ( + "several_models" # English and more than 1 model defaults to multiple models + ) else: if m > 1: sys.exit("Error, only English and multiple models supported, exiting...") else: # m==1 self.modality = "several_languages" # non English defaults to multiple languages - elif l > 1: + elif ll > 1: self.modality = "several_languages" self.languages = tuple(in_langs) self.models = tuple(in_models) - self.do_run_step = in_what_to_run[0] # only run the run part of the code + self.do_run_step = in_what_to_run[0] # only run the run part of the code self.do_postprocess_step = in_what_to_run[1] # only run the postprocess part of the code diff --git a/src/malco/run/run.py b/src/malco/run/run.py index cfc7c14f..9765ea87 100644 --- a/src/malco/run/run.py +++ b/src/malco/run/run.py @@ -1,9 +1,9 @@ -from pathlib import Path import multiprocessing -import subprocess -import shutil import os +import shutil +import subprocess import typing +from pathlib import Path from malco.run.search_ppkts import search_ppkts @@ -13,7 +13,7 @@ def call_ontogpt( raw_results_dir: Path, input_dir: Path, model: str, - modality: typing.Literal['several_languages', 'several_models'], + modality: typing.Literal["several_languages", "several_models"], ) -> None: """ Wrapper used for parallel execution of ontogpt. 
@@ -28,15 +28,15 @@ def call_ontogpt( Returns: None """ - prompt_dir = f'{input_dir}/prompts/' - if modality == 'several_languages': + prompt_dir = f"{input_dir}/prompts/" + if modality == "several_languages": lang_or_model_dir = lang prompt_dir += f"{lang_or_model_dir}/" - elif modality == 'several_models': + elif modality == "several_models": lang_or_model_dir = model prompt_dir += "en/" else: - raise ValueError('Not permitted run modality!\n') + raise ValueError("Not permitted run modality!\n") selected_indir = search_ppkts(input_dir, prompt_dir, raw_results_dir, lang_or_model_dir) yaml_file = f"{raw_results_dir}/{lang_or_model_dir}/results.yaml" @@ -60,11 +60,12 @@ def call_ontogpt( process = subprocess.Popen(command, shell=True) process.communicate() - # Note: if file.txt.result is empty, what ends up in the yaml is still OK thanks to L39 in post_process_results_format.py + # Note: if file.txt.result is empty, what ends up in the yaml is still OK thanks + # to post_process_results_format.py print(f"Finished command for language {lang} and model {model}") try: - with open(yaml_file, 'r') as file2concat: - with open(old_yaml_file, 'a') as original_file: + with open(yaml_file, "r") as file2concat: + with open(old_yaml_file, "a") as original_file: shutil.copyfileobj(file2concat, original_file) os.remove(yaml_file) except NameError: @@ -73,8 +74,7 @@ def call_ontogpt( pass -def run(self, - max_workers: int = None) -> None: +def run(self, max_workers: int = None) -> None: """ Run the tool to obtain the raw results. @@ -85,7 +85,6 @@ def run(self, langs: Tuple of languages. max_workers: Maximum number of worker processes to use. """ - testdata_dir = self.testdata_dir raw_results_dir = self.raw_results_dir input_dir = self.input_dir langs = self.languages @@ -98,16 +97,26 @@ def run(self, if modality == "several_languages": with multiprocessing.Pool(processes=max_workers) as pool: try: - pool.starmap(call_ontogpt, [(lang, raw_results_dir / "multilingual", - input_dir, "gpt-4o", modality) for lang in langs]) + pool.starmap( + call_ontogpt, + [ + (lang, raw_results_dir / "multilingual", input_dir, "gpt-4o", modality) + for lang in langs + ], + ) except FileExistsError as e: - raise ValueError('Did not clean up after last run, check tmp dir: \n' + e) + raise ValueError("Did not clean up after last run, check tmp dir: \n" + e) if modality == "several_models": # English only many models with multiprocessing.Pool(processes=max_workers) as pool: try: - pool.starmap(call_ontogpt, [("en", raw_results_dir / "multimodel", - input_dir, model, modality) for model in models]) + pool.starmap( + call_ontogpt, + [ + ("en", raw_results_dir / "multimodel", input_dir, model, modality) + for model in models + ], + ) except FileExistsError as e: - raise ValueError('Did not clean up after last run, check tmp dir: \n' + e) + raise ValueError("Did not clean up after last run, check tmp dir: \n" + e) diff --git a/src/malco/run/search_ppkts.py b/src/malco/run/search_ppkts.py index 2cf5467f..de491325 100644 --- a/src/malco/run/search_ppkts.py +++ b/src/malco/run/search_ppkts.py @@ -1,6 +1,6 @@ import os -import yaml import shutil + from malco.post_process.post_process_results_format import read_raw_result_yaml @@ -8,11 +8,11 @@ def search_ppkts(input_dir, prompt_dir, raw_results_dir, lang_or_model): """ Check what ppkts have already been computed in current output dir, for current run parameters. 
- ontogpt will run every .txt that is in inputdir, we need a tmp inputdir + ontogpt will run every .txt that is in inputdir, we need a tmp inputdir excluding already run cases. Source of truth is the results.yaml output by ontogpt. Only extracted_object containing terms is considered successfully run. - Note that rerunning + Note that rerunning """ # List of "labels" that are already present in results.yaml iff terms is not None @@ -28,8 +28,8 @@ def search_ppkts(input_dir, prompt_dir, raw_results_dir, lang_or_model): for this_result in all_results: extracted_object = this_result.get("extracted_object") if extracted_object: - label = extracted_object.get('label') - terms = extracted_object.get('terms') + label = extracted_object.get("label") + terms = extracted_object.get("terms") if terms: # ONLY if terms is non-empty, it was successful files.append(label) @@ -38,7 +38,7 @@ def search_ppkts(input_dir, prompt_dir, raw_results_dir, lang_or_model): # prompts: ls prompt_dir promptfiles = [] - for (dirpath, dirnames, filenames) in os.walk(prompt_dir): + for _dirpath, _dirnames, filenames in os.walk(prompt_dir): promptfiles.extend(filenames) break diff --git a/src/malco/runner.py b/src/malco/runner.py index 0528c567..9c12db5b 100644 --- a/src/malco/runner.py +++ b/src/malco/runner.py @@ -1,15 +1,14 @@ -from dataclasses import dataclass +import os from pathlib import Path from shutil import rmtree + from pheval.runners.runner import PhEvalRunner -from malco.post_process.ranking_utils import compute_mrr_and_ranks +from malco.post_process.generate_plots import make_plots from malco.post_process.post_process import post_process -from malco.run.run import run -from malco.prepare.setup_phenopackets import setup_phenopackets +from malco.post_process.ranking_utils import compute_mrr_and_ranks from malco.prepare.setup_run_pars import import_inputdata -from malco.post_process.generate_plots import make_plots -import os +from malco.run.run import run class MalcoRunner(PhEvalRunner): @@ -34,18 +33,20 @@ def run(self): print("running with predictor") pass if self.do_run_step: - run(self, - ) + run( + self, + ) # Cleanup tmp_dir = f"{self.input_dir}/prompts/tmp/" if os.path.isdir(tmp_dir): rmtree(tmp_dir) - def post_process(self, - print_plot=True, - prompts_subdir_name="prompts", - correct_answer_file="correct_results.tsv" - ): + def post_process( + self, + print_plot=True, + prompts_subdir_name="prompts", + correct_answer_file="correct_results.tsv", + ): """ Post-process the raw output into PhEval standardised TSV output. 
""" @@ -61,14 +62,23 @@ def post_process(self, comparing = "model" out_subdir = "multimodel" else: - raise ValueError('Not permitted run modality!\n') + raise ValueError("Not permitted run modality!\n") - mrr_file, data_dir, num_ppkt, topn_aggr_file = compute_mrr_and_ranks(comparing, - output_dir=self.output_dir, - out_subdir=out_subdir, - prompt_dir=os.path.join( - self.input_dir, prompts_subdir_name), - correct_answer_file=correct_answer_file) + mrr_file, data_dir, num_ppkt, topn_aggr_file = compute_mrr_and_ranks( + comparing, + output_dir=self.output_dir, + out_subdir=out_subdir, + prompt_dir=os.path.join(self.input_dir, prompts_subdir_name), + correct_answer_file=correct_answer_file, + ) if print_plot: - make_plots(mrr_file, data_dir, self.languages, num_ppkt, self.models, topn_aggr_file, comparing) + make_plots( + mrr_file, + data_dir, + self.languages, + num_ppkt, + self.models, + topn_aggr_file, + comparing, + ) diff --git a/tox.ini b/tox.ini index 249de6ca..4d402c4d 100644 --- a/tox.ini +++ b/tox.ini @@ -33,15 +33,16 @@ description = Run linters. [testenv:flake8] skip_install = true deps = - flake8<5.0.0 + flake8>5.0.0 flake8-bandit flake8-black flake8-bugbear flake8-colors flake8-isort pep8-naming +# as in doctest env, do not try to enforce anything in src/malco/analysis, stuff there can be messy commands = - flake8 src/ tests/ + flake8 src/malco/runner.py src/malco/run/ src/malco/prepare/ src/malco/post_process/ tests/ description = Run the flake8 tool with several plugins (bandit, docstrings, import order, pep8 naming). [testenv:doctest] @@ -61,28 +62,17 @@ commands = ######################### [flake8] ignore = - E203 - W503 - S311 - #C901 # needs code change so ignoring for now. - #E731 # needs code change so ignoring for now. - #S101 # asserts are fine - #S106 # flags false positives with test_table_filler - #N801 # mixed case is bad but there's a lot of auto-generated code - #N815 # same ^ - S404 # Consider possible security implications associated with the subprocess module. - S603 # Check for execution of untrusted input - subprocess call - #S108 # Probable insecure usage of temp file/directory. - #S307 # Use of possibly insecure function - consider using safer ast.literal_eval. - #S603 # subprocess call - check for execution of untrusted input. - S607 # Starting a process with a partial executable path ["open" in both cases] - #S608 # Possible SQL injection vector through string-based query construction. - #B024 # StreamingWriter is an abstract base class, but it has no abstract methods. - # Remember to use @abstractmethod, @abstractclassmethod and/or @abstractproperty decorators. - #B027 # empty method in an abstract base class, but has no abstract decorator. Consider adding @abstractmethod - #N803 # math-oriented classes can ignore this (e.g. hypergeometric.py) - #N806 # math-oriented classes can ignore this (e.g. hypergeometric.py) - #B019 + E203, + W503, + S311, + S404, + S603, + S607, + E501, + S101, + S403, + S602 + max-line-length = 120 max-complexity = 13