diff --git a/poetry.lock b/poetry.lock index d8b5f277..55716e7f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -336,8 +336,8 @@ files = [ lazy-object-proxy = ">=1.4.0" typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} wrapt = [ - {version = ">=1.14,<2", markers = "python_version >= \"3.11\""}, {version = ">=1.11,<2", markers = "python_version < \"3.11\""}, + {version = ">=1.14,<2", markers = "python_version >= \"3.11\""}, ] [[package]] @@ -370,6 +370,21 @@ docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphi tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] +[[package]] +name = "autopep8" +version = "2.3.1" +description = "A tool that automatically formats Python code to conform to the PEP 8 style guide" +optional = false +python-versions = ">=3.8" +files = [ + {file = "autopep8-2.3.1-py2.py3-none-any.whl", hash = "sha256:a203fe0fcad7939987422140ab17a930f684763bf7335bdb6709991dd7ef6c2d"}, + {file = "autopep8-2.3.1.tar.gz", hash = "sha256:8d6c87eba648fdcfc83e29b788910b8643171c395d9c4bcf115ece035b9c9dda"}, +] + +[package.dependencies] +pycodestyle = ">=2.12.0" +tomli = {version = "*", markers = "python_version < \"3.11\""} + [[package]] name = "babel" version = "2.16.0" @@ -2132,8 +2147,8 @@ files = [ google-auth = ">=2.14.1,<3.0.dev0" googleapis-common-protos = ">=1.56.2,<2.0.dev0" proto-plus = [ - {version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""}, {version = ">=1.22.3,<2.0.0dev", markers = "python_version < \"3.13\""}, + {version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""}, ] protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" requests = ">=2.18.0,<3.0.0.dev0" @@ -5311,8 +5326,8 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.22.4", markers = "python_version < \"3.11\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""}, ] python-dateutil = ">=2.8.2" @@ -6024,8 +6039,8 @@ files = [ annotated-types = ">=0.6.0" pydantic-core = "2.23.4" typing-extensions = [ - {version = ">=4.12.2", markers = "python_version >= \"3.13\""}, {version = ">=4.6.1", markers = "python_version < \"3.13\""}, + {version = ">=4.12.2", markers = "python_version >= \"3.13\""}, ] [package.extras] @@ -6233,8 +6248,8 @@ files = [ astroid = ">=2.15.8,<=2.17.0-dev0" colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} dill = [ - {version = ">=0.3.6", markers = "python_version >= \"3.11\""}, {version = ">=0.2", markers = "python_version < \"3.11\""}, + {version = ">=0.3.6", markers = "python_version >= \"3.11\""}, ] isort = ">=4.2.5,<6" mccabe = ">=0.6,<0.8" @@ -9655,4 +9670,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "aa2ebc7bfc0815f46c4de87f1930ae0b9213c8432d53c4a3185d8073e7612552" +content-hash = "718e461c373c7bb63a0d46b2389e15b20cda1d8889898b6c9a0c3729a4a9aab2" diff --git a/pyproject.toml b/pyproject.toml index bb66ea05..effb49a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ setuptools = "^69.5.1" shelved-cache = "^0.3.1" curategpt = "^0.2.2" psutil = "^6.1.0" +autopep8 = "^2.3.1" 
[tool.poetry.plugins."pheval.plugins"] template = "malco.runner:MalcoRunner" diff --git a/src/malco/analysis/check_lens.py b/src/malco/analysis/check_lens.py index bc8cfdc9..ca47a492 100644 --- a/src/malco/analysis/check_lens.py +++ b/src/malco/analysis/check_lens.py @@ -1,12 +1,13 @@ -import pandas as pd +import pandas as pd from typing import List import pandas as pd import yaml -#from malco.post_process.post_process_results_format import read_raw_result_yaml +# from malco.post_process.post_process_results_format import read_raw_result_yaml from pathlib import Path import sys + def read_raw_result_yaml(raw_result_path: Path) -> List[dict]: """ Read the raw result file. @@ -18,16 +19,17 @@ def read_raw_result_yaml(raw_result_path: Path) -> List[dict]: dict: Contents of the raw result file. """ with open(raw_result_path, 'r') as raw_result: - return list(yaml.safe_load_all(raw_result.read().replace(u'\x04',''))) # Load and convert to list + return list(yaml.safe_load_all(raw_result.read().replace(u'\x04', ''))) # Load and convert to list + unique_ppkts = {} -#model=str(sys.argv[1]) +# model=str(sys.argv[1]) models = ["gpt-3.5-turbo", "gpt-4-turbo", "gpt-4", "gpt-4o"] for model in models: - print("==="*10, "\nEvaluating now: ", model, "\n"+"==="*10) - + print("===" * 10, "\nEvaluating now: ", model, "\n" + "===" * 10) + yamlfile = f"out_openAI_models/raw_results/multimodel/{model}/results.yaml" - all_results=read_raw_result_yaml(yamlfile) + all_results = read_raw_result_yaml(yamlfile) counter = 0 labelvec = [] @@ -64,4 +66,4 @@ def read_raw_result_yaml(raw_result_path: Path) -> List[dict]: if i in unique_ppkts["gpt-3.5-turbo"]: continue else: - print(f"Missing ppkt in gpt-3.5-turbo is:\t", i) \ No newline at end of file + print(f"Missing ppkt in gpt-3.5-turbo is:\t", i) diff --git a/src/malco/analysis/count_grounding_failures.py b/src/malco/analysis/count_grounding_failures.py index c3e2b06b..262d258b 100644 --- a/src/malco/analysis/count_grounding_failures.py +++ b/src/malco/analysis/count_grounding_failures.py @@ -4,8 +4,8 @@ mfile = "../outputdir_all_2024_07_04/en/results.tsv" df = pd.read_csv( - mfile, sep="\t" #, header=None, names=["description", "term", "label"] - ) + mfile, sep="\t" # , header=None, names=["description", "term", "label"] +) terms = df["term"] counter = 0 @@ -17,4 +17,4 @@ counter += 1 print(counter) -print(grounded) \ No newline at end of file +print(grounded) diff --git a/src/malco/analysis/count_translated_prompts.py b/src/malco/analysis/count_translated_prompts.py index 869870b6..f6d778ca 100644 --- a/src/malco/analysis/count_translated_prompts.py +++ b/src/malco/analysis/count_translated_prompts.py @@ -14,9 +14,9 @@ promptfiles = {} for lang in langs: promptfiles[lang] = [] - for (dirpath, dirnames, filenames) in os.walk(fp+lang): + for (dirpath, dirnames, filenames) in os.walk(fp + lang): for fn in filenames: - fn = fn[0:-14] # TODO may be problematic if there are 2 "_" before "{langcode}-" + fn = fn[0:-14] # TODO may be problematic if there are 2 "_" before "{langcode}-" # Maybe something along the lines of other script disease_avail_knowledge.py # ppkt_label = ppkt[0].replace('_en-prompt.txt','') promptfiles[lang].append(fn) @@ -34,4 +34,4 @@ intersection = enset & esset & deset & itset & nlset & zhset & trset -print("Common ppkts are: ", len(intersection)) \ No newline at end of file +print("Common ppkts are: ", len(intersection)) diff --git a/src/malco/analysis/eval_diagnose_category.py b/src/malco/analysis/eval_diagnose_category.py index 3a8c1b9e..ba7851bb 
100644 --- a/src/malco/analysis/eval_diagnose_category.py +++ b/src/malco/analysis/eval_diagnose_category.py @@ -17,9 +17,11 @@ outpath = "disease_groups/" pc_cache_file = outpath + "diagnoses_hereditary_cond" -pc = PersistentCache(LRUCache, pc_cache_file, maxsize=4096) - +pc = PersistentCache(LRUCache, pc_cache_file, maxsize=4096) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + def mondo_adapter() -> OboGraphInterface: """ Get the adapter for the MONDO ontology. @@ -27,10 +29,12 @@ def mondo_adapter() -> OboGraphInterface: Returns: Adapter: The adapter. """ - return get_adapter("sqlite:obo:mondo") + return get_adapter("sqlite:obo:mondo") # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -def mondo_mapping(term, adapter): + + +def mondo_mapping(term, adapter): mondos = [] for m in adapter.sssom_mappings([term], source="OMIM"): if m.predicate_id == "skos:exactMatch": @@ -38,6 +42,8 @@ def mondo_mapping(term, adapter): return mondos # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + @cached(pc, key=lambda omim_term, disease_categories, mondo: hashkey(omim_term)) def find_category(omim_term, disease_categories, mondo): if not isinstance(mondo, MappingProviderInterface): @@ -47,49 +53,51 @@ def find_category(omim_term, disease_categories, mondo): if not mondo_term: print(omim_term) return None - - ancestor_list = mondo.ancestors(mondo_term, # only IS_A->same result - predicates=[IS_A, PART_OF]) #, reflexive=True) # method=GraphTraversalMethod.ENTAILMENT - + + ancestor_list = mondo.ancestors(mondo_term, # only IS_A->same result + # , reflexive=True) # method=GraphTraversalMethod.ENTAILMENT + predicates=[IS_A, PART_OF]) + for mondo_ancestor in ancestor_list: if mondo_ancestor in disease_categories: - #TODO IMPORTANT! Like this, at the first match the function exits!! - return mondo_ancestor # This should be smt like MONDO:0045024 (cancer or benign tumor) - + # TODO IMPORTANT! Like this, at the first match the function exits!! + return mondo_ancestor # This should be smt like MONDO:0045024 (cancer or benign tumor) + print("Special issue following: ") print(omim_term) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -#===================================================== +# ===================================================== # Script starts here. Name model: -model=str(sys.argv[1]) -#===================================================== +model = str(sys.argv[1]) +# ===================================================== # Find 42 diseases categories mondo = mondo_adapter() -disease_categories = mondo.relationships(objects = ["MONDO:0003847"], # hereditary diseases +disease_categories = mondo.relationships(objects=["MONDO:0003847"], # hereditary diseases predicates=[IS_A, PART_OF]) # only IS_A->same result -#disease_categories = mondo.relationships(objects = ["MONDO:0700096"], # only IS_A->same result +# disease_categories = mondo.relationships(objects = ["MONDO:0700096"], # only IS_A->same result # predicates=[IS_A, PART_OF]) # make df contingency table with header=diseases_category, correct, incorrect and initialize all to 0. 
-header = ["label","correct", "incorrect"] +header = ["label", "correct", "incorrect"] dc_list = [i[0] for i in list(disease_categories)] contingency_table = pd.DataFrame(0, index=dc_list, columns=header) for j in dc_list: - contingency_table.loc[j,"label"] = mondo.label(j) + contingency_table.loc[j, "label"] = mondo.label(j) breakpoint() filename = f"out_openAI_models/multimodel/{model}/full_df_results.tsv" # label term score rank correct_term is_correct reciprocal_rank # PMID_35962790_Family_B_Individual_3__II_6__en-prompt.txt MONDO:0008675 1.0 1.0 OMIM:620545 False 0.0 df = pd.read_csv( - filename, sep="\t" - ) + filename, sep="\t" +) -ppkts = df.groupby("label")[["term", "correct_term", "is_correct"]] -count_fails=0 +ppkts = df.groupby("label")[["term", "correct_term", "is_correct"]] +count_fails = 0 omim_wo_match = {} for ppkt in ppkts: @@ -97,11 +105,11 @@ def find_category(omim_term, disease_categories, mondo): category_index = find_category(ppkt[1].iloc[0]["correct_term"], dc_list, mondo) if not category_index: count_fails += 1 - #print(f"Category index for {ppkt[1].iloc[0]["correct_term"]} ") + # print(f"Category index for {ppkt[1].iloc[0]["correct_term"]} ") omim_wo_match[ppkt[0]] = ppkt[1].iloc[0]["correct_term"] continue - #cat_ind = find_cat_index(category) - # is there a true? ppkt is tuple ("filename"/"label"/what has been used for grouping, dataframe) --> ppkt[1] is a dataframe + # cat_ind = find_cat_index(category) + # is there a true? ppkt is tuple ("filename"/"label"/what has been used for grouping, dataframe) --> ppkt[1] is a dataframe if not any(ppkt[1]["is_correct"]): # no --> increase incorrect try: @@ -117,12 +125,12 @@ def find_category(omim_term, disease_categories, mondo): print("issue here") continue -print("\n\n", "==="*15,"\n") -print(f"For whatever reason find_category() returned None in {count_fails} cases, wich follow:\n") # print to file! -#print(contingency_table) +print("\n\n", "===" * 15, "\n") +print(f"For whatever reason find_category() returned None in {count_fails} cases, wich follow:\n") # print to file! +# print(contingency_table) print("\n\nOf which the following are unique OMIMs:\n", set(list(omim_wo_match.values()))) -#print(omim_wo_match, "\n\nOf which the following are unique OMIMs:\n", set(list(omim_wo_match.values()))) +# print(omim_wo_match, "\n\nOf which the following are unique OMIMs:\n", set(list(omim_wo_match.values()))) cont_table_file = f"{outpath}{model}.tsv" # Will overwrite -#contingency_table.to_csv(cont_table_file, sep='\t') \ No newline at end of file +# contingency_table.to_csv(cont_table_file, sep='\t') diff --git a/src/malco/analysis/monarchKG_classifier.py b/src/malco/analysis/monarchKG_classifier.py index c628efa6..7fb5a03e 100644 --- a/src/malco/analysis/monarchKG_classifier.py +++ b/src/malco/analysis/monarchKG_classifier.py @@ -1,8 +1,8 @@ -# Monarch KG -# Idea: for each ppkt, make contingency table NF/F and in box write +# Monarch KG +# Idea: for each ppkt, make contingency table NF/F and in box write # average number of connections. Thus 7 K of entries with num_edges, y=0,1 # Think about mouse weight and obesity as an example. -import numpy +import numpy from neo4j import GraphDatabase # Connect to the Neo4j database @@ -11,7 +11,7 @@ # From results take ppkts ground truth correct result and 0,1 # Map OMIM to MONDO -# +# # Need to decide what to project out. Maybe simply all edges connected to the MONDO terms I have. 
# At this point for each MONDO term I have count the edges # Define the Cypher query @@ -26,4 +26,4 @@ with driver.session() as session: results = session.run(query) for record in results: - data.append(record) \ No newline at end of file + data.append(record) diff --git a/src/malco/analysis/test_curate_script.py b/src/malco/analysis/test_curate_script.py index a83f20de..a5e98db0 100644 --- a/src/malco/analysis/test_curate_script.py +++ b/src/malco/analysis/test_curate_script.py @@ -1,9 +1,9 @@ -import yaml +import yaml from pathlib import Path from typing import List from malco.post_process.extended_scoring import clean_service_answer, ground_diagnosis_text_to_mondo from oaklib import get_adapter - + def read_raw_result_yaml(raw_result_path: Path) -> List[dict]: """ @@ -16,7 +16,7 @@ def read_raw_result_yaml(raw_result_path: Path) -> List[dict]: dict: Contents of the raw result file. """ with open(raw_result_path, 'r') as raw_result: - return list(yaml.safe_load_all(raw_result.read().replace(u'\x04',''))) # Load and convert to list + return list(yaml.safe_load_all(raw_result.read().replace(u'\x04', ''))) # Load and convert to list annotator = get_adapter("sqlite:obo:mondo") @@ -29,25 +29,25 @@ def read_raw_result_yaml(raw_result_path: Path) -> List[dict]: j = 0 for this_result in all_results: extracted_object = this_result.get("extracted_object") - if extracted_object: # Necessary because this is how I keep track of multiple runs + if extracted_object: # Necessary because this is how I keep track of multiple runs ontogpt_text = this_result.get("input_text") # its a single string, should be parseable through curategpt cleaned_text = clean_service_answer(ontogpt_text) assert cleaned_text != "", "Cleaning failed: the cleaned text is empty." result = ground_diagnosis_text_to_mondo(annotator, cleaned_text, verbose=False) - label = extracted_object.get('label') # pubmed id + label = extracted_object.get('label') # pubmed id # terms will now ONLY contain MONDO IDs OR 'N/A'. 
The latter should be dealt with downstream - terms = [i[1][0][0] for i in result] - #terms = extracted_object.get('terms') # list of strings, the mondo id or description + terms = [i[1][0][0] for i in result] + # terms = extracted_object.get('terms') # list of strings, the mondo id or description if terms: - # Note, the if allows for rerunning ppkts that failed due to connection issues - # We can have multiple identical ppkts/prompts in results.yaml as long as only one has a terms field + # Note, the if allows for rerunning ppkts that failed due to connection issues + # We can have multiple identical ppkts/prompts in results.yaml as long as only one has a terms field num_terms = len(terms) score = [1 / (i + 1) for i in range(num_terms)] # score is reciprocal rank - rank_list = [ i+1 for i in range(num_terms)] + rank_list = [i + 1 for i in range(num_terms)] for term, scr, rank in zip(terms, score, rank_list): data.append({'label': label, 'term': term, 'score': scr, 'rank': rank}) - if j>20: + if j > 20: break j += 1 diff --git a/src/malco/analysis/time_ic/disease_avail_knowledge.py b/src/malco/analysis/time_ic/disease_avail_knowledge.py index 0eee3a70..04ba475b 100644 --- a/src/malco/analysis/time_ic/disease_avail_knowledge.py +++ b/src/malco/analysis/time_ic/disease_avail_knowledge.py @@ -14,6 +14,10 @@ `runoak -g hpoa_file -G hpoa -i hpo_file information-content -p i --use-associations .all` """ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +from scipy.stats import mannwhitneyu +from scipy.stats import ttest_ind +from scipy.stats import kstest +from scipy.stats import chi2_contingency import sys import os import pandas as pd @@ -32,10 +36,10 @@ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ model = str(sys.argv[1]) try: - make_plots = str(sys.argv[2])=="plot" + make_plots = str(sys.argv[2]) == "plot" except IndexError: - make_plots = False - print("\nYou can pass \"plot\" as a second CLI argument and this will generate nice plots!\n") + make_plots = False + print("\nYou can pass \"plot\" as a second CLI argument and this will generate nice plots!\n") # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # PATHS: # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -43,7 +47,7 @@ hpoa_file_path = data_dir / "phenotype.hpoa" ic_file = data_dir / "ic_hpoa.txt" original_ppkt_dir = data_dir / "ppkt-store-0.1.19" -outdir = Path.cwd() / "src" / "malco" / "analysis" / "time_ic" +outdir = Path.cwd() / "src" / "malco" / "analysis" / "time_ic" ranking_results_filename = f"out_openAI_models/multimodel/{model}/full_df_results.tsv" # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -53,8 +57,8 @@ rank_date_dict = pickle.load(f) # import df of LLM results rank_results_df = pd.read_csv( - ranking_results_filename, sep="\t" - ) + ranking_results_filename, sep="\t" +) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -62,70 +66,69 @@ # Look for correlation in box plot of ppkts' rank vs time # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ dates = [] -dates_wo_none = [] # messy workaround +dates_wo_none = [] # messy workaround ranks = [] ranks_wo_none = [] for key, data in rank_date_dict.items(): - r = data[0] - d = dt.datetime.strptime(data[1], '%Y-%m-%d').date() - 
dates.append(d) - ranks.append(r) - if r is not None: - dates_wo_none.append(d) - ranks_wo_none.append(r) - - -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# Correlation? + r = data[0] + d = dt.datetime.strptime(data[1], '%Y-%m-%d').date() + dates.append(d) + ranks.append(r) + if r is not None: + dates_wo_none.append(d) + ranks_wo_none.append(r) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Correlation? years_only = [] -for i in range(len(dates)): - years_only.append(dates[i].year) +for i in range(len(dates)): + years_only.append(dates[i].year) years_only_wo_none = [] -for i in range(len(dates_wo_none)): - years_only_wo_none.append(dates[i].year) +for i in range(len(dates_wo_none)): + years_only_wo_none.append(dates[i].year) if make_plots: - sns.boxplot(x=years_only_wo_none, y=ranks_wo_none) - plt.xlabel("Year of HPOA annotation") - plt.ylabel("Rank") - plt.title("LLM performance uncorrelated with date of discovery") - plt.savefig(outdir / "boxplot_discovery_date.png") + sns.boxplot(x=years_only_wo_none, y=ranks_wo_none) + plt.xlabel("Year of HPOA annotation") + plt.ylabel("Rank") + plt.title("LLM performance uncorrelated with date of discovery") + plt.savefig(outdir / "boxplot_discovery_date.png") # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Statistical test, simplest idea: chi2 of contingency table with: # y<=2009 and y>2009 clmns and found vs not-found counts, one count per ppkt # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -cont_table = [[0, 0], [0, 0]] # contains counts +cont_table = [[0, 0], [0, 0]] # contains counts for i, d in enumerate(years_only): - if d < 2010: - if ranks[i] == None: - cont_table[0][1] += 1 - else: - cont_table[0][0] += 1 - else: - if ranks[i] == None: - cont_table[1][1] += 1 - else: - cont_table[1][0] += 1 - -df_contingency_table = pd.DataFrame(cont_table, - index=["found", "not_found"], + if d < 2010: + if ranks[i] == None: + cont_table[0][1] += 1 + else: + cont_table[0][0] += 1 + else: + if ranks[i] == None: + cont_table[1][1] += 1 + else: + cont_table[1][0] += 1 + +df_contingency_table = pd.DataFrame(cont_table, + index=["found", "not_found"], columns=["y<2010", "y>=2010"]) print(df_contingency_table) print("H0: no correlation between column 1 and 2:") -from scipy.stats import chi2_contingency -res = chi2_contingency(cont_table) +res = chi2_contingency(cont_table) print("Results from \u03c7\N{SUPERSCRIPT TWO} test on contingency table:\n", res) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# IC: For each phenpacket, list observed HPOs and compute average IC. Is it correlated with +# IC: For each phenpacket, list observed HPOs and compute average IC. Is it correlated with # success? I.e., start with f/nf, 1/0 on y-axis vs avg(IC) on x-axis # Import ppkts -ppkts = rank_results_df.groupby("label")[["term", "correct_term", "is_correct", "rank"]] +ppkts = rank_results_df.groupby("label")[["term", "correct_term", "is_correct", "rank"]] # Import IC-file as dict with open(ic_file) as f: @@ -136,62 +139,64 @@ ppkts_with_zero_hpos = [] ppkts_with_missing_hpos = [] -# Iterate over ppkts, which are json. +# Iterate over ppkts, which are json. 
for subdir, dirs, files in os.walk(original_ppkt_dir): - # For each ppkt - for filename in files: - if filename.endswith('.json'): - file_path = os.path.join(subdir, filename) - with open(file_path, mode="r", encoding="utf-8") as read_file: - ppkt = json.load(read_file) - ppkt_id = re.sub('[^\\w]', '_', ppkt['id']) - ic = 0 - num_hpos = 0 - # For each HPO - for i in ppkt['phenotypicFeatures']: + # For each ppkt + for filename in files: + if filename.endswith('.json'): + file_path = os.path.join(subdir, filename) + with open(file_path, mode="r", encoding="utf-8") as read_file: + ppkt = json.load(read_file) + ppkt_id = re.sub('[^\\w]', '_', ppkt['id']) + ic = 0 + num_hpos = 0 + # For each HPO + for i in ppkt['phenotypicFeatures']: + try: + if i["excluded"]: # skip excluded + continue + except KeyError: + pass + hpo = i["type"]["id"] + try: + ic += float(ic_dict[hpo]) + num_hpos += 1 + except KeyError as e: + missing_in_ic_dict.append(e.args[0]) + ppkts_with_missing_hpos.append(ppkt_id) + + # print(f"No entry for {e}.") + + # For now we are fine with average IC try: - if i["excluded"]: # skip excluded - continue - except KeyError: - pass - hpo = i["type"]["id"] - try: - ic += float(ic_dict[hpo]) - num_hpos += 1 - except KeyError as e: - missing_in_ic_dict.append(e.args[0]) - ppkts_with_missing_hpos.append(ppkt_id) - - #print(f"No entry for {e}.") - - # For now we are fine with average IC - try: - ppkt_ic[ppkt_id] = ic/num_hpos - # TODO max ic instead try - except ZeroDivisionError as e: - ppkts_with_zero_hpos.append(ppkt_id) - #print(f"No HPOs for {ppkt["id"]}.") - + ppkt_ic[ppkt_id] = ic / num_hpos + # TODO max ic instead try + except ZeroDivisionError as e: + ppkts_with_zero_hpos.append(ppkt_id) + # print(f"No HPOs for {ppkt["id"]}.") + missing_in_ic_dict_unique = set(missing_in_ic_dict) ppkts_with_missing_hpos = set(ppkts_with_missing_hpos) -print(f"\nNumber of (unique) HPOs without IC-value is {len(missing_in_ic_dict_unique)}.") # 65 -print(f"Number of ppkts with zero observed HPOs is {len(ppkts_with_zero_hpos)}. These are left out.") # 141 -#TODO check 141 -print(f"Number of ppkts where at least one HPO is missing its IC value is {len(ppkts_with_missing_hpos)}. These are left out from the average.\n") # 172 +print(f"\nNumber of (unique) HPOs without IC-value is {len(missing_in_ic_dict_unique)}.") # 65 +print(f"Number of ppkts with zero observed HPOs is {len(ppkts_with_zero_hpos)}. These are left out.") # 141 +# TODO check 141 +# 172 +print( + f"Number of ppkts where at least one HPO is missing its IC value is {len(ppkts_with_missing_hpos)}. 
These are left out from the average.\n") ppkt_ic_df = pd.DataFrame.from_dict(ppkt_ic, orient='index', columns=['avg(IC)']) -ppkt_ic_df['Diagnosed'] = 0 +ppkt_ic_df['Diagnosed'] = 0 for ppkt in ppkts: - if any(ppkt[1]["is_correct"]): - ppkt_label = ppkt[0].replace('_en-prompt.txt','') - if ppkt_label in ppkts_with_zero_hpos: - continue - ppkt_ic_df.loc[ppkt_label,'Diagnosed'] = 1 + if any(ppkt[1]["is_correct"]): + ppkt_label = ppkt[0].replace('_en-prompt.txt', '') + if ppkt_label in ppkts_with_zero_hpos: + continue + ppkt_ic_df.loc[ppkt_label, 'Diagnosed'] = 1 # xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx # See https://github.com/monarch-initiative/phenopacket-store/issues/157 -label_manual_removal = ["PMID_27764983_Family_1_individual__J", +label_manual_removal = ["PMID_27764983_Family_1_individual__J", "PMID_35991565_Family_I__3"] ppkt_ic_df = ppkt_ic_df.drop(label_manual_removal) # xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx @@ -205,39 +210,36 @@ # T-test: unlikely that the two samples are such due to sample bias. # Likely, there is a correlation between average IC and whether the case is being solved. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# One-sample Kolmogorov-Smirnov test: compares the underlying distribution F(x) of a sample +# One-sample Kolmogorov-Smirnov test: compares the underlying distribution F(x) of a sample # against a given distribution G(x), here the normal distribution -from scipy.stats import kstest -kolsmirnov_result = kstest(ppkt_ic_df['avg(IC)'],'norm') -print(kolsmirnov_result,"\n") -# When interpreting be careful, and I quote one answer from: +kolsmirnov_result = kstest(ppkt_ic_df['avg(IC)'], 'norm') +print(kolsmirnov_result, "\n") +# When interpreting be careful, and I quote one answer from: # https://stats.stackexchange.com/questions/2492/is-normality-testing-essentially-useless -# "The question normality tests answer: Is there convincing -# evidence of any deviation from the Gaussian ideal? -# With moderately large real data sets, the answer is almost always yes." +# "The question normality tests answer: Is there convincing +# evidence of any deviation from the Gaussian ideal? +# With moderately large real data sets, the answer is almost always yes." # --------------- t-test --------------- # TODO: regardless of above, our distributions are not normal. 
Discard t and add non-parametric U-test (Mann-Whitney) -from scipy.stats import ttest_ind -found_ic = list(ppkt_ic_df.loc[ppkt_ic_df['Diagnosed']>0, 'avg(IC)']) -not_found_ic = list(ppkt_ic_df.loc[ppkt_ic_df['Diagnosed']<1, 'avg(IC)']) +found_ic = list(ppkt_ic_df.loc[ppkt_ic_df['Diagnosed'] > 0, 'avg(IC)']) +not_found_ic = list(ppkt_ic_df.loc[ppkt_ic_df['Diagnosed'] < 1, 'avg(IC)']) tresult = ttest_ind(found_ic, not_found_ic, equal_var=False) -print("T-test result:\n", tresult,"\n") +print("T-test result:\n", tresult, "\n") # --------------- u-test --------------- -from scipy.stats import mannwhitneyu u_value, p_of_u = mannwhitneyu(found_ic, not_found_ic) -print(f"U-test, u_value={u_value} and its associated p_val={p_of_u}","\n") +print(f"U-test, u_value={u_value} and its associated p_val={p_of_u}", "\n") # --------------- plot --------------- if make_plots: - plt.hist(found_ic, bins=25, color='c', edgecolor='k', alpha=0.5, density=True) - plt.hist(not_found_ic, bins=25, color='r', edgecolor='k', alpha=0.5, density=True) - plt.xlabel("Average Information Content") - plt.ylabel("Counts") - plt.legend(['Successful Diagnosis', 'Unsuccessful Diagnosis']) - plt.savefig(outdir / "inf_content_histograms.png") + plt.hist(found_ic, bins=25, color='c', edgecolor='k', alpha=0.5, density=True) + plt.hist(not_found_ic, bins=25, color='r', edgecolor='k', alpha=0.5, density=True) + plt.xlabel("Average Information Content") + plt.ylabel("Counts") + plt.legend(['Successful Diagnosis', 'Unsuccessful Diagnosis']) + plt.savefig(outdir / "inf_content_histograms.png") # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Still important TODO! @@ -245,6 +247,3 @@ # 1) cont table 8 cells... chi2 test # 2) MRR test --> one way # 3) rank based u-test, max 50, 51 for not found or >50 - - - diff --git a/src/malco/analysis/time_ic/logit_predict_llm.py b/src/malco/analysis/time_ic/logit_predict_llm.py index d7a6f9d3..b26c648b 100644 --- a/src/malco/analysis/time_ic/logit_predict_llm.py +++ b/src/malco/analysis/time_ic/logit_predict_llm.py @@ -11,25 +11,27 @@ # https://towardsdatascience.com/building-a-logistic-regression-in-python-step-by-step-becd4d56c9c8 [1/10/24] import pandas as pd import numpy as np -import pickle +import pickle from sklearn.linear_model import LogisticRegression from sklearn.model_selection import train_test_split from sklearn.metrics import confusion_matrix, classification_report from pathlib import Path # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Path -outdir = Path.cwd() / "src" / "malco" / "analysis" / "time_ic" +outdir = Path.cwd() / "src" / "malco" / "analysis" / "time_ic" # Import ppkt_ic_df = pd.read_csv(outdir / "ppkt_ic.tsv", delimiter='\t', index_col=0) with open(outdir / "rank_date_dict.pkl", 'rb') as f: rank_date_dict = pickle.load(f) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -X_train, X_test, y_train, y_test = train_test_split(ppkt_ic_df[['avg(IC)']], ppkt_ic_df['Diagnosed'], test_size=0.2, random_state=0) +X_train, X_test, y_train, y_test = train_test_split( + ppkt_ic_df[['avg(IC)']], ppkt_ic_df['Diagnosed'], test_size=0.2, random_state=0) logreg = LogisticRegression() logreg.fit(X_train, y_train) y_pred = logreg.predict(X_test) -print('Accuracy of 1 parameter (IC) logistic regression classifier on test set: {:.2f}'.format(logreg.score(X_test, y_test))) +print('Accuracy of 1 parameter (IC) logistic regression classifier on 
test set: {:.2f}'.format( + logreg.score(X_test, y_test))) cm1d = confusion_matrix(y_test, y_pred) print(cm1d) class_report = classification_report(y_test, y_pred) @@ -50,11 +52,13 @@ date_df.set_index(new_index, inplace=True) ic_date_df = date_df.join(ppkt_ic_df, how='inner') -X_train, X_test, y_train, y_test = train_test_split(ic_date_df[['avg(IC)','date']], ic_date_df['Diagnosed'], test_size=0.2, random_state=0) +X_train, X_test, y_train, y_test = train_test_split( + ic_date_df[['avg(IC)', 'date']], ic_date_df['Diagnosed'], test_size=0.2, random_state=0) logreg = LogisticRegression() logreg.fit(X_train, y_train) y_pred = logreg.predict(X_test) -print('\nAccuracy of 2 PARAMETER (IC and time) logistic regression classifier on test set: {:.2f}'.format(logreg.score(X_test, y_test))) +print('\nAccuracy of 2 PARAMETER (IC and time) logistic regression classifier on test set: {:.2f}'.format( + logreg.score(X_test, y_test))) cm2d = confusion_matrix(y_test, y_pred) print(cm2d) class_report = classification_report(y_test, y_pred) diff --git a/src/malco/analysis/time_ic/rank_date_exploratory.py b/src/malco/analysis/time_ic/rank_date_exploratory.py index 405267d6..37042d4b 100644 --- a/src/malco/analysis/time_ic/rank_date_exploratory.py +++ b/src/malco/analysis/time_ic/rank_date_exploratory.py @@ -13,9 +13,9 @@ >>> python src/malco/analysis/time_ic/diseases_avail_knowledge.py gpt-4o """ -import pickle +import pickle from pathlib import Path -import pandas as pd +import pandas as pd import sys import numpy as np # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -24,17 +24,17 @@ ranking_results_filename = f"out_openAI_models/multimodel/{model}/full_df_results.tsv" data_dir = Path.home() / "data" hpoa_file_path = data_dir / "phenotype.hpoa" -outdir = Path.cwd() / "src" / "malco" / "analysis" / "time_ic" +outdir = Path.cwd() / "src" / "malco" / "analysis" / "time_ic" # (1) HPOA for dates # HPOA import and setup hpoa_df = pd.read_csv( - hpoa_file_path, sep="\t" , header=4, low_memory=False # Necessary to suppress Warning we don't care about - ) + hpoa_file_path, sep="\t", header=4, low_memory=False # Necessary to suppress Warning we don't care about +) labels_to_drop = ["disease_name", "qualifier", "hpo_id", "reference", "evidence", - "onset", "frequency", "sex", "modifier", "aspect"] + "onset", "frequency", "sex", "modifier", "aspect"] hpoa_df = hpoa_df.drop(columns=labels_to_drop) hpoa_df['date'] = hpoa_df["biocuration"].str.extract(r'\[(.*?)\]') @@ -47,40 +47,40 @@ # import df of LLM results rank_results_df = pd.read_csv( - ranking_results_filename, sep="\t" - ) + ranking_results_filename, sep="\t" +) # Go through results data and make set of found vs not found diseases. 
found_diseases = [] not_found_diseases = [] -rank_date_dict = {} -ppkts = rank_results_df.groupby("label")[["term", "correct_term", "is_correct", "rank"]] -for ppkt in ppkts: #TODO 1st for ppkt in ppkts - # ppkt is tuple ("filename", dataframe) --> ppkt[1] is a dataframe - disease = ppkt[1].iloc[0]['correct_term'] - - if any(ppkt[1]["is_correct"]): - found_diseases.append(disease) - index_of_match = ppkt[1]["is_correct"].to_list().index(True) - try: - rank = ppkt[1].iloc[index_of_match]["rank"] # inverse rank does not work well - rank_date_dict[ppkt[0]] = [rank.item(), - hpoa_unique.loc[ppkt[1].iloc[0]["correct_term"]]] - # If ppkt[1].iloc[0]["correct_term"] is nan, then KeyError "e" is nan - except (ValueError, KeyError) as e: - print(f"Error {e} for {ppkt[0]}, disease {ppkt[1].iloc[0]['correct_term']}.") - - else: - not_found_diseases.append(disease) - try: - rank_date_dict[ppkt[0]] = [None, - hpoa_unique.loc[ppkt[1].iloc[0]["correct_term"]]] - except (ValueError, KeyError) as e: - #pass - #TODO collect the below somewhere - print(f"Error {e} for {ppkt[0]}, disease {ppkt[1].iloc[0]['correct_term']}.") - +rank_date_dict = {} +ppkts = rank_results_df.groupby("label")[["term", "correct_term", "is_correct", "rank"]] +for ppkt in ppkts: # TODO 1st for ppkt in ppkts + # ppkt is tuple ("filename", dataframe) --> ppkt[1] is a dataframe + disease = ppkt[1].iloc[0]['correct_term'] + + if any(ppkt[1]["is_correct"]): + found_diseases.append(disease) + index_of_match = ppkt[1]["is_correct"].to_list().index(True) + try: + rank = ppkt[1].iloc[index_of_match]["rank"] # inverse rank does not work well + rank_date_dict[ppkt[0]] = [rank.item(), + hpoa_unique.loc[ppkt[1].iloc[0]["correct_term"]]] + # If ppkt[1].iloc[0]["correct_term"] is nan, then KeyError "e" is nan + except (ValueError, KeyError) as e: + print(f"Error {e} for {ppkt[0]}, disease {ppkt[1].iloc[0]['correct_term']}.") + + else: + not_found_diseases.append(disease) + try: + rank_date_dict[ppkt[0]] = [None, + hpoa_unique.loc[ppkt[1].iloc[0]["correct_term"]]] + except (ValueError, KeyError) as e: + # pass + # TODO collect the below somewhere + print(f"Error {e} for {ppkt[0]}, disease {ppkt[1].iloc[0]['correct_term']}.") + # gpt-4o output, reasonable enough to throw out 62 cases ~1%. 3 OMIMs to check and 3 nan # TODO clean up here # len(rank_date_dict) --> 6625 @@ -101,47 +101,47 @@ overlap = [] for i in found_set: - if i in notfound_set: - overlap.append(i) + if i in notfound_set: + overlap.append(i) print(f"Number of found diseases by {model} is {len(found_set)}.") print(f"Number of not found diseases by {model} is {len(notfound_set)}.") print(f"Diseases sometimes found, sometimes not, by {model} are {len(overlap)}.\n") # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# Look at the 263-129 (gpt-4o) found diseases not present in not-found set ("always found") +# Look at the 263-129 (gpt-4o) found diseases not present in not-found set ("always found") # and the opposite namely "never found" diseases. Average date of two sets is? -always_found = found_set - notfound_set # 134 -never_found = notfound_set - found_set # 213 +always_found = found_set - notfound_set # 134 +never_found = notfound_set - found_set # 213 # meaning 347/476, 27% sometimes found sometimes not, 28% always found, 45% never found. 
# Compute average date of always vs never found diseases -results_dict = {} # turns out being 281 long +results_dict = {} # turns out being 281 long # TODO get rid of next line, bzw hpoa_unique does not work for loop below hpoa_df.drop_duplicates(subset='database_id', inplace=True) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ for af in always_found: - try: - results_dict[af] = [True, hpoa_df.loc[hpoa_df['database_id'] == af, 'date'].item() ] - #results_dict[af] = [True, hpoa_unique.loc[hpoa_unique['database_id'] == af, 'date'].item() ] - except ValueError: - print(f"No HPOA in always_found for {af}.") + try: + results_dict[af] = [True, hpoa_df.loc[hpoa_df['database_id'] == af, 'date'].item()] + # results_dict[af] = [True, hpoa_unique.loc[hpoa_unique['database_id'] == af, 'date'].item() ] + except ValueError: + print(f"No HPOA in always_found for {af}.") for nf in never_found: - try: - results_dict[nf] = [False, hpoa_df.loc[hpoa_df['database_id'] == nf, 'date'].item() ] - #results_dict[nf] = [False, hpoa_unique.loc[hpoa_unique['database_id'] == nf, 'date'].item() ] - except ValueError: - print(f"No HPOA in never_found for {nf}.") + try: + results_dict[nf] = [False, hpoa_df.loc[hpoa_df['database_id'] == nf, 'date'].item()] + # results_dict[nf] = [False, hpoa_unique.loc[hpoa_unique['database_id'] == nf, 'date'].item() ] + except ValueError: + print(f"No HPOA in never_found for {nf}.") -#TODO No HPOA for ... comes from for ppkt in ppkts, then +# TODO No HPOA for ... comes from for ppkt in ppkts, then # disease = ppkt[1].iloc[0]['correct_term'] res_to_clean = pd.DataFrame.from_dict(results_dict).transpose() -res_to_clean.columns=["found","date"] +res_to_clean.columns = ["found", "date"] res_to_clean['date'] = pd.to_datetime(res_to_clean.date).values.astype(np.int64) final_avg = pd.DataFrame(pd.to_datetime(res_to_clean.groupby('found').mean()['date'])) print(final_avg) -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ \ No newline at end of file +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/malco/post_process/df_save_util.py b/src/malco/post_process/df_save_util.py index df6479ed..bd140b71 100644 --- a/src/malco/post_process/df_save_util.py +++ b/src/malco/post_process/df_save_util.py @@ -2,6 +2,7 @@ import os import pandas as pd + def safe_save_tsv(path, filename, df): full_path = path / filename # If full_path already exists, prepend "old_" @@ -12,4 +13,4 @@ def safe_save_tsv(path, filename, df): os.remove(old_full_path) shutil.copy(full_path, old_full_path) os.remove(full_path) - df.to_csv(full_path, sep='\t', index=False) \ No newline at end of file + df.to_csv(full_path, sep='\t', index=False) diff --git a/src/malco/post_process/extended_scoring.py b/src/malco/post_process/extended_scoring.py index d60fdaf5..2fae6841 100644 --- a/src/malco/post_process/extended_scoring.py +++ b/src/malco/post_process/extended_scoring.py @@ -10,6 +10,8 @@ dd_re = re.compile(r"^[^A-z]*Differential Diagnosis") # Function to clean and remove "Differential Diagnosis" header if present + + def clean_service_answer(answer: str) -> str: """Remove the 'Differential Diagnosis' header if present, and clean the first line.""" lines = answer.split('\n') @@ -18,6 +20,8 @@ def clean_service_answer(answer: str) -> str: return '\n'.join(cleaned_lines) # Clean the diagnosis line by removing leading numbers, periods, asterisks, and spaces + + def 
clean_diagnosis_line(line: str) -> str: """Remove leading numbers, asterisks, and unnecessary punctuation/spaces from the diagnosis.""" line = re.sub(r'^\**\d+\.\s*', '', line) # Remove leading numbers and periods @@ -25,6 +29,8 @@ def clean_diagnosis_line(line: str) -> str: return line.strip() # Strip any remaining spaces # Split a diagnosis into its main name and synonym if present + + def split_diagnosis_and_synonym(diagnosis: str) -> Tuple[str, str]: """Split the diagnosis into main name and synonym (if present in parentheses).""" match = re.match(r'^(.*)\s*\((.*)\)\s*$', diagnosis) @@ -33,6 +39,7 @@ def split_diagnosis_and_synonym(diagnosis: str) -> Tuple[str, str]: return main_diagnosis.strip(), synonym.strip() return diagnosis, None # Return the original diagnosis if no synonym is found + def perform_curategpt_grounding( diagnosis: str, path: str, @@ -44,7 +51,7 @@ def perform_curategpt_grounding( ) -> List[Tuple[str, str]]: """ Use curategpt to perform grounding for a given diagnosis when initial attempts fail. - + Parameters: - diagnosis: The diagnosis text to ground. - path: The path to the database. You'll need to create an index of Mondo using curategpt in this db @@ -55,13 +62,13 @@ def perform_curategpt_grounding( - limit: The number of search results to return. - relevance_factor: The distance threshold for relevance filtering. - verbose: Whether to print verbose output for debugging. - + Returns: - List of tuples: [(Mondo ID, Label), ...] """ # Initialize the database store db = get_store(database_type, path) - + # Perform the search using the provided diagnosis results = db.search(diagnosis, collection=collection) @@ -79,7 +86,7 @@ def perform_curategpt_grounding( for obj, distance, _meta in limited_results: disease_mondo_id = obj.get("original_id") # Use the 'original_id' field for Mondo ID disease_label = obj.get("label") - + if disease_mondo_id and disease_label: pred_ids.append(disease_mondo_id) pred_labels.append(disease_label) @@ -118,40 +125,43 @@ def perform_oak_grounding( if any(ann.object_id.startswith(prefix) for prefix in include_list) } ) - + if filtered_annotations: return filtered_annotations else: match_type = "exact" if exact_match else "inexact" if verbose: print(f"No {match_type} grounded IDs found for: {diagnosis}") - pass + pass return [('N/A', 'No grounding found')] # Now, integrate curategpt into your ground_diagnosis_text_to_mondo function + + def ground_diagnosis_text_to_mondo( annotator: TextAnnotatorInterface, differential_diagnosis: str, verbose: bool = False, include_list: List[str] = ["MONDO:"], use_ontogpt_grounding: bool = True, - curategpt_path: str = "../curategpt/stagedb/", + curategpt_path: str = "../curategpt/stagedb/", curategpt_collection: str = "ont_mondo", curategpt_database_type: str = "chromadb" ) -> List[Tuple[str, List[Tuple[str, str]]]]: results = [] - + # Split the input into lines and process each one for line in differential_diagnosis.splitlines(): clean_line = clean_diagnosis_line(line) - + # Skip header lines like "**Differential diagnosis:**" if not clean_line or "Differential diagnosis" in clean_line.lower(): continue - + # Try grounding the full line first (exact match) - grounded = perform_oak_grounding(annotator, clean_line, exact_match=True, verbose=verbose, include_list=include_list) - + grounded = perform_oak_grounding(annotator, clean_line, exact_match=True, + verbose=verbose, include_list=include_list) + # Try grounding with curategpt if no grounding is found if use_ontogpt_grounding and grounded == [('N/A', 'No 
grounding found')]: grounded = perform_curategpt_grounding( @@ -161,13 +171,13 @@ def ground_diagnosis_text_to_mondo( database_type=curategpt_database_type, verbose=verbose ) - + # If still no grounding is found, log the final failure if grounded == [('N/A', 'No grounding found')]: if verbose: print(f"Final grounding failed for: {clean_line}") - + # Append the grounded results (even if no grounding was found) results.append((clean_line, grounded)) - return results \ No newline at end of file + return results diff --git a/src/malco/post_process/generate_plots.py b/src/malco/post_process/generate_plots.py index abddfb88..927245ea 100644 --- a/src/malco/post_process/generate_plots.py +++ b/src/malco/post_process/generate_plots.py @@ -6,6 +6,7 @@ # Make a nice plot, use it as function or as script + def make_plots(mrr_file, data_dir, languages, num_ppkt, models, topn_aggr_file, comparing): plot_dir = data_dir.parents[0] / "plots" plot_dir.mkdir(exist_ok=True) @@ -13,39 +14,38 @@ def make_plots(mrr_file, data_dir, languages, num_ppkt, models, topn_aggr_file, # For plot filenam labeling use lowest number of ppkt available for all models/languages etc. num_ppkt = min(num_ppkt.values()) - if comparing=="model": + if comparing == "model": name_string = str(len(models)) else: name_string = str(len(languages)) - with mrr_file.open('r', newline = '') as f: - lines = csv.reader(f, quoting = csv.QUOTE_NONNUMERIC, delimiter = '\t', lineterminator='\n') + with mrr_file.open('r', newline='') as f: + lines = csv.reader(f, quoting=csv.QUOTE_NONNUMERIC, delimiter='\t', lineterminator='\n') results_files = next(lines) mrr_scores = next(lines) - + print(results_files) print(mrr_scores) # Plotting the mrr results - sns.barplot(x = results_files, y = mrr_scores) + sns.barplot(x=results_files, y=mrr_scores) plt.xlabel("Results File") plt.ylabel("Mean Reciprocal Rank (MRR)") plt.title("MRR of Correct Answers Across Different Results Files") - plot_path = plot_dir / (name_string + "_" + comparing + "_" + str(num_ppkt) + "ppkt.png") + plot_path = plot_dir / (name_string + "_" + comparing + "_" + str(num_ppkt) + "ppkt.png") plt.savefig(plot_path) plt.close() # Plotting bar-plots with top ranks df_aggr = pd.read_csv(topn_aggr_file, delimiter='\t') - - sns.barplot(x="Rank_in", y="percentage", data = df_aggr, hue = comparing) + + sns.barplot(x="Rank_in", y="percentage", data=df_aggr, hue=comparing) plt.xlabel("Number of Ranks in") plt.ylabel("Percentage of Cases") plt.ylim([0.0, 1.0]) plt.title("Rank Comparison for Differential Diagnosis") plt.legend(title=comparing) - plot_path = plot_dir / ("barplot_" + name_string + "_" + comparing + "_" + str(num_ppkt) + "ppkt.png") + plot_path = plot_dir / ("barplot_" + name_string + "_" + comparing + "_" + str(num_ppkt) + "ppkt.png") plt.savefig(plot_path) plt.close() - diff --git a/src/malco/post_process/mondo_score_utils.py b/src/malco/post_process/mondo_score_utils.py index 8735b994..55a5f849 100644 --- a/src/malco/post_process/mondo_score_utils.py +++ b/src/malco/post_process/mondo_score_utils.py @@ -2,7 +2,7 @@ from oaklib.interfaces import MappingProviderInterface from pathlib import Path -from typing import List +from typing import List from cachetools.keys import hashkey @@ -10,7 +10,7 @@ PARTIAL_SCORE = 0.5 -def omim_mappings(term: str, adapter) -> List[str]: +def omim_mappings(term: str, adapter) -> List[str]: """ Get the OMIM mappings for a term. @@ -26,7 +26,7 @@ def omim_mappings(term: str, adapter) -> List[str]: Returns: str: The OMIM mappings. 
- """ + """ omims = [] for m in adapter.sssom_mappings([term], source="OMIM"): if m.predicate_id == "skos:exactMatch": @@ -45,13 +45,13 @@ def score_grounded_result(prediction: str, ground_truth: str, mondo, cache=None) The predicted Mondo is equivalent to the ground truth OMIM (via skos:exactMatches in Mondo): - + >>> score_grounded_result("MONDO:0007566", "OMIM:132800", get_adapter("sqlite:obo:mondo")) 1.0 The predicted Mondo is a disease entity that groups multiple OMIMs, one of which is the ground truth: - + >>> score_grounded_result("MONDO:0008029", "OMIM:158810", get_adapter("sqlite:obo:mondo")) 0.5 @@ -65,12 +65,11 @@ def score_grounded_result(prediction: str, ground_truth: str, mondo, cache=None) """ if not isinstance(mondo, MappingProviderInterface): raise ValueError("Adapter is not an MappingProviderInterface") - + if prediction == ground_truth: # predication is the correct OMIM return FULL_SCORE - ground_truths = get_ground_truth_from_cache_or_compute(prediction, mondo, cache) if ground_truth in ground_truths: # prediction is a MONDO that directly maps to a correct OMIM @@ -84,14 +83,15 @@ def score_grounded_result(prediction: str, ground_truth: str, mondo, cache=None) return PARTIAL_SCORE return 0.0 + def get_ground_truth_from_cache_or_compute( - term, - adapter, + term, + adapter, cache, ): if cache is None: return omim_mappings(term, adapter) - + k = hashkey(term) try: ground_truths = cache[k] @@ -102,4 +102,3 @@ def get_ground_truth_from_cache_or_compute( cache[k] = ground_truths cache.misses += 1 return ground_truths - diff --git a/src/malco/post_process/post_process.py b/src/malco/post_process/post_process.py index 84427879..c7fe9624 100644 --- a/src/malco/post_process/post_process.py +++ b/src/malco/post_process/post_process.py @@ -27,10 +27,10 @@ def post_process(self) -> None: create_standardised_results(curategpt, raw_results_dir=raw_results_lang, - output_dir=output_lang, + output_dir=output_lang, output_file_name="results.tsv", ) - + elif self.modality == "several_models": for model in models: raw_results_model = raw_results_dir / "multimodel" / model @@ -40,6 +40,6 @@ def post_process(self) -> None: create_standardised_results(curategpt, raw_results_dir=raw_results_model, - output_dir=output_model, + output_dir=output_model, output_file_name="results.tsv", ) diff --git a/src/malco/post_process/post_process_results_format.py b/src/malco/post_process/post_process_results_format.py index 54046a1b..c47ab679 100644 --- a/src/malco/post_process/post_process_results_format.py +++ b/src/malco/post_process/post_process_results_format.py @@ -11,7 +11,6 @@ from malco.post_process.df_save_util import safe_save_tsv from malco.post_process.extended_scoring import clean_service_answer, ground_diagnosis_text_to_mondo from oaklib import get_adapter - def read_raw_result_yaml(raw_result_path: Path) -> List[dict]: @@ -25,15 +24,15 @@ def read_raw_result_yaml(raw_result_path: Path) -> List[dict]: dict: Contents of the raw result file. 
""" with open(raw_result_path, 'r') as raw_result: - return list(yaml.safe_load_all(raw_result.read().replace(u'\x04',''))) # Load and convert to list + return list(yaml.safe_load_all(raw_result.read().replace(u'\x04', ''))) # Load and convert to list def create_standardised_results(curategpt: bool, - raw_results_dir: Path, + raw_results_dir: Path, output_dir: Path, output_file_name: str ) -> pd.DataFrame: - + data = [] if curategpt: annotator = get_adapter("sqlite:obo:mondo") @@ -48,20 +47,20 @@ def create_standardised_results(curategpt: bool, if extracted_object: label = extracted_object.get('label') terms = extracted_object.get('terms') - if curategpt and terms: + if curategpt and terms: ontogpt_text = this_result.get("input_text") # its a single string, should be parseable through curategpt cleaned_text = clean_service_answer(ontogpt_text) assert cleaned_text != "", "Cleaning failed: the cleaned text is empty." result = ground_diagnosis_text_to_mondo(annotator, cleaned_text, verbose=False) # terms will now ONLY contain MONDO IDs OR 'N/A'. The latter should be dealt with downstream - terms = [i[1][0][0] for i in result] # MONDO_ID + terms = [i[1][0][0] for i in result] # MONDO_ID if terms: - # Note, the if allows for rerunning ppkts that failed due to connection issues - # We can have multiple identical ppkts/prompts in results.yaml as long as only one has a terms field + # Note, the if allows for rerunning ppkts that failed due to connection issues + # We can have multiple identical ppkts/prompts in results.yaml as long as only one has a terms field num_terms = len(terms) score = [1 / (i + 1) for i in range(num_terms)] # score is reciprocal rank - rank_list = [ i+1 for i in range(num_terms)] + rank_list = [i + 1 for i in range(num_terms)] for term, scr, rank in zip(terms, score, rank_list): data.append({'label': label, 'term': term, 'score': scr, 'rank': rank}) @@ -164,6 +163,7 @@ def extract_pheval_gene_requirements(self) -> List[PhEvalGeneResult]: ) return pheval_result + ''' def create_standardised_results(raw_results_dir: Path, output_dir: Path) -> None: """ @@ -187,4 +187,4 @@ def create_standardised_results(raw_results_dir: Path, output_dir: Path) -> None output_dir=output_dir, tool_result_path=raw_result_path, ) -''' \ No newline at end of file +''' diff --git a/src/malco/post_process/ranking_utils.py b/src/malco/post_process/ranking_utils.py index dc03c51e..29dbffe1 100644 --- a/src/malco/post_process/ranking_utils.py +++ b/src/malco/post_process/ranking_utils.py @@ -1,4 +1,4 @@ -import os +import os import csv from pathlib import Path from datetime import datetime @@ -15,16 +15,18 @@ from malco.post_process.df_save_util import safe_save_tsv from malco.post_process.mondo_score_utils import score_grounded_result from cachetools import LRUCache -from typing import List +from typing import List from cachetools.keys import hashkey from shelved_cache import PersistentCache FULL_SCORE = 1.0 PARTIAL_SCORE = 0.5 + def cache_info(self): return f"CacheInfo: hits={self.hits}, misses={self.misses}, maxsize={self.wrapped.maxsize}, currsize={self.wrapped.currsize}" + def mondo_adapter() -> OboGraphInterface: """ Get the adapter for the MONDO ontology. @@ -32,40 +34,41 @@ def mondo_adapter() -> OboGraphInterface: Returns: Adapter: The adapter. 
""" - return get_adapter("sqlite:obo:mondo") + return get_adapter("sqlite:obo:mondo") + def compute_mrr_and_ranks( - comparing: str, - output_dir: Path, + comparing: str, + output_dir: Path, out_subdir: str, - prompt_dir: str, + prompt_dir: str, correct_answer_file: str, - ) -> Path: +) -> Path: - # Read in results TSVs from self.output_dir that match glob results*tsv + # Read in results TSVs from self.output_dir that match glob results*tsv out_caches = Path("caches") - #out_caches = output_dir / "caches" + # out_caches = output_dir / "caches" out_caches.mkdir(exist_ok=True) output_dir = output_dir / out_subdir results_data = [] results_files = [] num_ppkt = {} pc2_cache_file = str(out_caches / "score_grounded_result_cache") - pc2 = PersistentCache(LRUCache, pc2_cache_file, maxsize=524288) + pc2 = PersistentCache(LRUCache, pc2_cache_file, maxsize=524288) pc1_cache_file = str(out_caches / "omim_mappings_cache") pc1 = PersistentCache(LRUCache, pc1_cache_file, maxsize=524288) # Treat hits and misses as run-specific arguments, write them cache_log pc1.hits = pc1.misses = 0 pc2.hits = pc2.misses = 0 PersistentCache.cache_info = cache_info - + mode_index = 0 for subdir, dirs, files in os.walk(output_dir): for filename in files: if filename.startswith("result") and filename.endswith(".tsv"): file_path = os.path.join(subdir, filename) df = pd.read_csv(file_path, sep="\t") - num_ppkt[subdir.split('/')[-1]] = df["label"].nunique() + num_ppkt[subdir.split('/')[-1]] = df["label"].nunique() results_data.append(df) # Append both the subdirectory relative to output_dir and the filename results_files.append(os.path.relpath(file_path, output_dir)) @@ -84,21 +87,22 @@ def compute_mrr_and_ranks( cache_file = out_caches / "cache_log.txt" - with cache_file.open('a', newline = '') as cf: + with cache_file.open('a', newline='') as cf: now_is = datetime.now().strftime("%Y%m%d-%H%M%S") - cf.write("Timestamp: " + now_is +"\n\n") + cf.write("Timestamp: " + now_is + "\n\n") mondo = mondo_adapter() i = 0 # Each df is a model or a language for df in results_data: # For each label in the results file, find if the correct term is ranked df["rank"] = df.groupby("label")["score"].rank(ascending=False, method="first") - label_4_non_eng = df["label"].str.replace("_[a-z][a-z]-prompt", "_en-prompt", regex=True) #TODO is bug here? + label_4_non_eng = df["label"].str.replace( + "_[a-z][a-z]-prompt", "_en-prompt", regex=True) # TODO is bug here? # df['correct_term'] is an OMIM # df['term'] is Mondo or OMIM ID, or even disease label df["correct_term"] = label_4_non_eng.map(label_to_correct_term) - + # Make sure caching is used in the following by unwrapping explicitly results = [] for idx, row in df.iterrows(): @@ -125,34 +129,34 @@ def compute_mrr_and_ranks( full_df_path = output_dir / results_files[i].split("/")[0] full_df_filename = "full_df_results.tsv" safe_save_tsv(full_df_path, full_df_filename, df) - + # Calculate MRR for this file mrr = df.groupby("label")["reciprocal_rank"].max().mean() mrr_scores.append(mrr) - + # Calculate top of each rank rank_df.loc[i, comparing] = results_files[i].split("/")[0] - - ppkts = df.groupby("label")[["rank","is_correct"]] + + ppkts = df.groupby("label")[["rank", "is_correct"]] index_matches = df.index[df['is_correct']] - + # for each group for ppkt in ppkts: - # is there a true? ppkt is tuple ("filename", dataframe) --> ppkt[1] is a dataframe + # is there a true? 
ppkt is tuple ("filename", dataframe) --> ppkt[1] is a dataframe if not any(ppkt[1]["is_correct"]): # no --> increase nf = "not found" - rank_df.loc[i,"nf"] += 1 + rank_df.loc[i, "nf"] += 1 else: # yes --> what's it rank? It's jind = ppkt[1].index[ppkt[1]['is_correct']] j = int(ppkt[1]['rank'].loc[jind].values[0]) - if j<11: + if j < 11: # increase n - rank_df.loc[i,"n"+str(j)] += 1 + rank_df.loc[i, "n" + str(j)] += 1 else: # increase n10p - rank_df.loc[i,"n10p"] += 1 - + rank_df.loc[i, "n10p"] += 1 + # Write cache charatcteristics to file cf.write(results_files[i]) cf.write('\nscore_grounded_result cache info:\n') @@ -164,9 +168,9 @@ def compute_mrr_and_ranks( pc1.close() pc2.close() - + for modelname in num_ppkt.keys(): - rank_df.loc[rank_df['model']==modelname,'num_cases'] = num_ppkt[modelname] + rank_df.loc[rank_df['model'] == modelname, 'num_cases'] = num_ppkt[modelname] data_dir = output_dir / "rank_data" data_dir.mkdir(exist_ok=True) topn_file_name = "topn_result.tsv" @@ -177,27 +181,28 @@ def compute_mrr_and_ranks( print(mrr_scores) mrr_file = data_dir / "mrr_result.tsv" - # write out results for plotting - with mrr_file.open('w', newline = '') as dat: - writer = csv.writer(dat, quoting = csv.QUOTE_NONNUMERIC, delimiter = '\t', lineterminator='\n') + # write out results for plotting + with mrr_file.open('w', newline='') as dat: + writer = csv.writer(dat, quoting=csv.QUOTE_NONNUMERIC, delimiter='\t', lineterminator='\n') writer.writerow(results_files) writer.writerow(mrr_scores) df = pd.read_csv(topn_file, delimiter='\t') df["top1"] = (df['n1']) / df["num_cases"] - df["top3"] = (df["n1"] + df["n2"] + df["n3"] ) / df["num_cases"] - df["top5"] = (df["n1"] + df["n2"] + df["n3"] + df["n4"] + df["n5"] ) / df["num_cases"] - df["top10"] = (df["n1"] + df["n2"] + df["n3"] + df["n4"] + df["n5"] + df["n6"] + df["n7"] + df["n8"] + df["n9"] + df["n10"] ) / df["num_cases"] + df["top3"] = (df["n1"] + df["n2"] + df["n3"]) / df["num_cases"] + df["top5"] = (df["n1"] + df["n2"] + df["n3"] + df["n4"] + df["n5"]) / df["num_cases"] + df["top10"] = (df["n1"] + df["n2"] + df["n3"] + df["n4"] + df["n5"] + df["n6"] + + df["n7"] + df["n8"] + df["n9"] + df["n10"]) / df["num_cases"] df["not_found"] = (df["nf"]) / df["num_cases"] - + df_aggr = pd.DataFrame() - df_aggr = pd.melt(df, id_vars=comparing, value_vars=["top1", "top3", "top5", "top10", "not_found"], var_name="Rank_in", value_name="percentage") + df_aggr = pd.melt(df, id_vars=comparing, value_vars=[ + "top1", "top3", "top5", "top10", "not_found"], var_name="Rank_in", value_name="percentage") # If "topn_aggr.tsv" already exists, prepend "old_" # It's the user's responsibility to know only up to 2 versions can exist, then data is lost topn_aggr_file_name = "topn_aggr.tsv" topn_aggr_file = data_dir / topn_aggr_file_name safe_save_tsv(data_dir, topn_aggr_file_name, df_aggr) - + return mrr_file, data_dir, num_ppkt, topn_aggr_file - \ No newline at end of file diff --git a/src/malco/prepare/setup_phenopackets.py b/src/malco/prepare/setup_phenopackets.py index c35921a2..41c117f4 100644 --- a/src/malco/prepare/setup_phenopackets.py +++ b/src/malco/prepare/setup_phenopackets.py @@ -1,10 +1,11 @@ import zipfile -import os +import os import requests -phenopacket_zip_url="https://github.com/monarch-initiative/phenopacket-store/releases/download/0.1.11/all_phenopackets.zip" +phenopacket_zip_url = "https://github.com/monarch-initiative/phenopacket-store/releases/download/0.1.11/all_phenopackets.zip" # TODO just point to a folder w/ ppkts 
-phenopacket_dir="phenopacket-store" +phenopacket_dir = "phenopacket-store" + def setup_phenopackets(self) -> str: phenopacket_store_path = os.path.join(self.input_dir, phenopacket_dir) diff --git a/src/malco/prepare/setup_run_pars.py b/src/malco/prepare/setup_run_pars.py index 6944ee90..0828b21f 100644 --- a/src/malco/prepare/setup_run_pars.py +++ b/src/malco/prepare/setup_run_pars.py @@ -2,6 +2,7 @@ import csv import sys + def import_inputdata(self) -> None: """ Example input file is located in ``self.input_dir`` and named run_parameters.csv @@ -23,7 +24,7 @@ def import_inputdata(self) -> None: Meaning run multilingual prompts with those 5 aforementioned languages, and only execute the function postprocess, not run. """ with open(self.input_dir / "run_parameters.csv", 'r') as pars: - lines = csv.reader(pars, quoting = csv.QUOTE_NONNUMERIC, delimiter = ',', lineterminator='\n') + lines = csv.reader(pars, quoting=csv.QUOTE_NONNUMERIC, delimiter=',', lineterminator='\n') in_langs = next(lines) in_models = next(lines) in_what_to_run = next(lines) @@ -33,17 +34,17 @@ def import_inputdata(self) -> None: if (l > 1 and m > 1): sys.exit("Error, either run multiple languages or models, not both, exiting...") elif l == 1 and m >= 1: - if in_langs[0]=="en": - self.modality = "several_models" # English and more than 1 model defaults to multiple models + if in_langs[0] == "en": + self.modality = "several_models" # English and more than 1 model defaults to multiple models else: if m > 1: sys.exit("Error, only English and multiple models supported, exiting...") - else: # m==1 - self.modality = "several_languages" # non English defaults to multiple languages - elif l > 1: + else: # m==1 + self.modality = "several_languages" # non English defaults to multiple languages + elif l > 1: self.modality = "several_languages" self.languages = tuple(in_langs) - self.models = tuple(in_models) + self.models = tuple(in_models) self.do_run_step = in_what_to_run[0] # only run the run part of the code self.do_postprocess_step = in_what_to_run[1] # only run the postprocess part of the code diff --git a/src/malco/run/run.py b/src/malco/run/run.py index 1ad3e416..cfc7c14f 100644 --- a/src/malco/run/run.py +++ b/src/malco/run/run.py @@ -7,13 +7,14 @@ from malco.run.search_ppkts import search_ppkts + def call_ontogpt( - lang: str, - raw_results_dir: Path, - input_dir: Path, - model: str, + lang: str, + raw_results_dir: Path, + input_dir: Path, + model: str, modality: typing.Literal['several_languages', 'several_models'], -)-> None: +) -> None: """ Wrapper used for parallel execution of ontogpt. 
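For reference, the reciprocal-rank and MRR logic that the ranking_utils.py hunks above only reformat can be exercised in isolation. The sketch below uses toy data; the column names mirror the TSVs written by create_standardised_results and compute_mrr_and_ranks, but the values and the direct is_correct/rank shortcut are illustrative assumptions rather than the pipeline's actual MONDO grounding and scoring step.

import pandas as pd

df = pd.DataFrame({
    "label": ["case_A", "case_A", "case_B", "case_B"],
    "term": ["MONDO:0000001", "MONDO:0000002", "MONDO:0000003", "MONDO:0000004"],
    "score": [1.0, 0.5, 1.0, 0.5],   # reciprocal-rank scores, as in create_standardised_results
    "is_correct": [False, True, True, False],
})

# Rank candidates within each case by score, best first; ties broken by row order
df["rank"] = df.groupby("label")["score"].rank(ascending=False, method="first")

# Simplified correctness signal: the real pipeline derives this via score_grounded_result
df["reciprocal_rank"] = df["is_correct"] / df["rank"]

# MRR: keep each case's best reciprocal rank, then average over cases
mrr = df.groupby("label")["reciprocal_rank"].max().mean()
print(mrr)  # 0.75 here: case_A hits at rank 2 (0.5), case_B at rank 1 (1.0)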
@@ -37,9 +38,9 @@ def call_ontogpt( else: raise ValueError('Not permitted run modality!\n') - selected_indir = search_ppkts(input_dir, prompt_dir, raw_results_dir, lang_or_model_dir) + selected_indir = search_ppkts(input_dir, prompt_dir, raw_results_dir, lang_or_model_dir) yaml_file = f"{raw_results_dir}/{lang_or_model_dir}/results.yaml" - + if os.path.isfile(yaml_file): old_yaml_file = yaml_file yaml_file = f"{raw_results_dir}/{lang_or_model_dir}/new_results.yaml" @@ -97,15 +98,16 @@ def run(self, if modality == "several_languages": with multiprocessing.Pool(processes=max_workers) as pool: try: - pool.starmap(call_ontogpt, [(lang, raw_results_dir / "multilingual", input_dir, "gpt-4o", modality) for lang in langs]) + pool.starmap(call_ontogpt, [(lang, raw_results_dir / "multilingual", + input_dir, "gpt-4o", modality) for lang in langs]) except FileExistsError as e: raise ValueError('Did not clean up after last run, check tmp dir: \n' + e) - if modality == "several_models": # English only many models with multiprocessing.Pool(processes=max_workers) as pool: try: - pool.starmap(call_ontogpt, [("en", raw_results_dir / "multimodel", input_dir, model, modality) for model in models]) + pool.starmap(call_ontogpt, [("en", raw_results_dir / "multimodel", + input_dir, model, modality) for model in models]) except FileExistsError as e: raise ValueError('Did not clean up after last run, check tmp dir: \n' + e) diff --git a/src/malco/run/search_ppkts.py b/src/malco/run/search_ppkts.py index d0cfd560..2cf5467f 100644 --- a/src/malco/run/search_ppkts.py +++ b/src/malco/run/search_ppkts.py @@ -3,21 +3,21 @@ import shutil from malco.post_process.post_process_results_format import read_raw_result_yaml - + def search_ppkts(input_dir, prompt_dir, raw_results_dir, lang_or_model): """ Check what ppkts have already been computed in current output dir, for current run parameters. - + ontogpt will run every .txt that is in inputdir, we need a tmp inputdir excluding already run cases. Source of truth is the results.yaml output by ontogpt. Only extracted_object containing terms is considered successfully run. Note that rerunning """ - + # List of "labels" that are already present in results.yaml iff terms is not None files = [] - + yaml_file = f"{raw_results_dir}/{lang_or_model}/results.yaml" if os.path.isfile(yaml_file): # tmp inputdir contains prompts yet to be computed for a given model (pars set) @@ -36,11 +36,10 @@ def search_ppkts(input_dir, prompt_dir, raw_results_dir, lang_or_model): else: return prompt_dir - # prompts: ls prompt_dir promptfiles = [] for (dirpath, dirnames, filenames) in os.walk(prompt_dir): - promptfiles.extend(filenames) + promptfiles.extend(filenames) break # foreach promptfile in original_inputdir @@ -52,4 +51,4 @@ def search_ppkts(input_dir, prompt_dir, raw_results_dir, lang_or_model): else: shutil.copyfile(prompt_dir + promptfile, selected_indir + "/" + promptfile) - return selected_indir \ No newline at end of file + return selected_indir diff --git a/src/malco/runner.py b/src/malco/runner.py index 21dbb42e..0528c567 100644 --- a/src/malco/runner.py +++ b/src/malco/runner.py @@ -11,6 +11,7 @@ from malco.post_process.generate_plots import make_plots import os + class MalcoRunner(PhEvalRunner): input_dir: Path testdata_dir: Path @@ -19,7 +20,6 @@ class MalcoRunner(PhEvalRunner): config_file: Path version: str - def prepare(self): """ Pre-process any data and inputs necessary to run the tool. 
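The run() hunks above only re-wrap the multiprocessing.Pool.starmap fan-out; the pattern itself is shown below with a stand-in worker in place of call_ontogpt. The language list, paths, and the fake worker are assumptions for illustration, not project code.

import multiprocessing
from pathlib import Path


def fake_call_ontogpt(lang, raw_results_dir, input_dir, model, modality):
    # Stand-in for call_ontogpt: just report the arguments it would be invoked with
    return f"{modality}: {model} / {lang} -> {raw_results_dir / lang}"


if __name__ == "__main__":
    langs = ["en", "es", "de"]             # illustrative language codes
    raw_results_dir = Path("raw_results")  # illustrative paths
    input_dir = Path("inputs")
    # One (args...) tuple per language, dispatched in parallel, as in run()
    with multiprocessing.Pool(processes=3) as pool:
        out = pool.starmap(
            fake_call_ontogpt,
            [(lang, raw_results_dir / "multilingual", input_dir, "gpt-4o", "several_languages")
             for lang in langs],
        )
    print(out)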
@@ -35,13 +35,12 @@ def run(self): pass if self.do_run_step: run(self, - ) + ) # Cleanup tmp_dir = f"{self.input_dir}/prompts/tmp/" if os.path.isdir(tmp_dir): rmtree(tmp_dir) - def post_process(self, print_plot=True, prompts_subdir_name="prompts", @@ -54,24 +53,22 @@ def post_process(self, print("post processing results to PhEval standardised TSV output.") post_process(self) - - - if self.modality=="several_languages": + + if self.modality == "several_languages": comparing = "language" - out_subdir="multilingual" - elif self.modality=="several_models": + out_subdir = "multilingual" + elif self.modality == "several_models": comparing = "model" - out_subdir="multimodel" + out_subdir = "multimodel" else: raise ValueError('Not permitted run modality!\n') mrr_file, data_dir, num_ppkt, topn_aggr_file = compute_mrr_and_ranks(comparing, - output_dir=self.output_dir, - out_subdir=out_subdir, - prompt_dir=os.path.join(self.input_dir, prompts_subdir_name), - correct_answer_file=correct_answer_file) - + output_dir=self.output_dir, + out_subdir=out_subdir, + prompt_dir=os.path.join( + self.input_dir, prompts_subdir_name), + correct_answer_file=correct_answer_file) + if print_plot: make_plots(mrr_file, data_dir, self.languages, num_ppkt, self.models, topn_aggr_file, comparing) - - \ No newline at end of file
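The import_inputdata() hunk earlier in this diff reads run_parameters.csv with csv.QUOTE_NONNUMERIC, so quoted fields come back as strings while unquoted flags come back as floats. A minimal sketch follows; the example file contents are an assumption inferred from the parsing code and docstring, not a file shipped with the project.

import csv
import io

# Assumed layout: line 1 = languages, line 2 = models, line 3 = run/post_process flags
example = '"en","es","de","it","nl"\n"gpt-4o"\n0,1\n'
lines = csv.reader(io.StringIO(example), quoting=csv.QUOTE_NONNUMERIC,
                   delimiter=',', lineterminator='\n')
in_langs = next(lines)        # ['en', 'es', 'de', 'it', 'nl']
in_models = next(lines)       # ['gpt-4o']
in_what_to_run = next(lines)  # [0.0, 1.0] -> skip the run step, do post-processing only
print(in_langs, in_models, in_what_to_run)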