From 0fbc9caa94a7e17b8a0071ca91c120456303443f Mon Sep 17 00:00:00 2001 From: Tarun Mamidi Date: Sun, 7 Jan 2024 00:20:09 -0600 Subject: [PATCH] predictions are working --- src/main.py | 7 ++-- src/utils/predict.py | 6 ++-- src/utils/{Home.py => query.py} | 60 +++++++++++---------------------- 3 files changed, 27 insertions(+), 46 deletions(-) rename src/utils/{Home.py => query.py} (61%) diff --git a/src/main.py b/src/main.py index 5afba5b..695b30b 100644 --- a/src/main.py +++ b/src/main.py @@ -1,6 +1,6 @@ from typing import Union - from fastapi import FastAPI +from utils.query import get_ditto_score # run me https://fastapi.tiangolo.com/#installation app = FastAPI() @@ -25,8 +25,9 @@ def get_scores(chromosome: str, start: int, end: int): @app.get("/var/{chromosome}/{position}/{ref}/{alt}") def get_variant_score(chromosome, position, ref, alt): # TODO call the get_scores function to perform look up, if score not precomputed then call dynamic generation - scores = get_scores(chromosome=chromosome, start=position, end=position) - return {"variant": f"chr{chromosome}:g.{position}{ref}>{alt}"} + scores = get_ditto_score(chrom=chromosome, pos=position, ref=ref, alt=alt) + # return {"variant": f"chr{chromosome}:g.{position}{ref}>{alt}"} + return {"scores": scores} @app.get("/hgvs/{hgvs_cdna}") diff --git a/src/utils/predict.py b/src/utils/predict.py index b0cb294..a462c55 100644 --- a/src/utils/predict.py +++ b/src/utils/predict.py @@ -4,7 +4,7 @@ def parse_and_predict(dataframe, config_dict, clf): # Drop variant info columns so we can perform one-hot encoding - dataframe["so"] = dataframe["consequence"] + dataframe["so"] = dataframe["consequence"].copy() dataframe = dataframe.drop(config_dict["id_cols"], axis=1) dataframe = dataframe.replace([".", "-", ""], np.nan) for key in dataframe.columns: @@ -44,5 +44,5 @@ def parse_and_predict(dataframe, config_dict, clf): ) y_score = 1 - clf.predict(df2, verbose=0) - del temp_df - return df2, y_score + del temp_df,df2 + return y_score diff --git a/src/utils/Home.py b/src/utils/query.py similarity index 61% rename from src/utils/Home.py rename to src/utils/query.py index d25f1c6..ac2d662 100644 --- a/src/utils/Home.py +++ b/src/utils/query.py @@ -28,10 +28,10 @@ def get_parser(data_config): # Function to load model and weights -def load_model(): +def load_model(clf_path): # Load model and weights - clf = keras.models.load_model("./results/Neural_network") - clf.load_weights("./results/weights.h5") + clf = keras.models.load_model(clf_path + "/Neural_network") + clf.load_weights(clf_path + "/weights.h5") return clf # Function to query variant reference allele based on posiiton from UCSC API @@ -60,8 +60,8 @@ def query_variant(chrom: str, pos: int, allele_len: int) -> json: return get_fields.json() -def main(): - repo_root = Path(__file__).parent.parent +def get_ditto_score(chrom: str, pos: int, ref: str, alt: str): + repo_root = Path(__file__).parent.parent.parent # Load the col config file as dictionary config_f = repo_root / "configs" / "col_config.yaml" config_dict = get_col_configs(config_f) @@ -71,43 +71,23 @@ def main(): parser = get_parser(data_config) # Load the model and data - clf = load_model() - - # Query variant reference allele based on posiiton from UCSC API - try: - actual_ref = query_variant(str(chrom), int(pos), len(ref))["dna"].upper() - - # Handle invalid variant position - except: - print("Please enter a valid variant info.") - - if ref == actual_ref and ref != alt: - try: - # Query variant annotations via opencravat API and get data as dataframe - overall = parser.query_variant(chrom=str(chrom), pos=int(pos), ref=ref, alt=alt) - except: - overall = pd.DataFrame() - - # Check if variant annotations are found - if overall.empty: - print( - "Could not get variant annotations from OpenCravat's API. Please check the variant info and try again." - ) - - else: - # Select transcript - transcript = st.selectbox( - "**Select a transcript:**", options=list(overall["transcript"].unique()) - ) - - # Filter data based on selected transcript + clf_path = repo_root / "results" + clf = load_model(str(clf_path)) + + overall = parser.query_variant(chrom=str(chrom), pos=int(pos), ref=ref, alt=alt) + # Check if variant annotations are found + if overall.empty: + return( + "Could not get variant annotations from OpenCravat's API. Please check the variant info and try again." + ) + else: + score_dict = {} + for transcript in overall["transcript"].unique(): transcript_data = overall[overall["transcript"] == transcript].reset_index( drop=True ) - - # Parse and predict - df2, y_score = parse_and_predict(transcript_data, config_dict, clf) + y_score = parse_and_predict(transcript_data, config_dict, clf) y_score = round(y_score[0][0], 2) + score_dict[transcript] = str(y_score) + return score_dict -if __name__ == "__main__": - main()