Skip to content

Commit

Permalink
predictions are working
Browse files Browse the repository at this point in the history
  • Loading branch information
tkmamidi committed Jan 7, 2024
1 parent c1cff1c commit 0fbc9ca
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 46 deletions.
7 changes: 4 additions & 3 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Union

from fastapi import FastAPI
from utils.query import get_ditto_score

# run me https://fastapi.tiangolo.com/#installation
app = FastAPI()
Expand All @@ -25,8 +25,9 @@ def get_scores(chromosome: str, start: int, end: int):
@app.get("/var/{chromosome}/{position}/{ref}/{alt}")
def get_variant_score(chromosome, position, ref, alt):
# TODO call the get_scores function to perform look up, if score not precomputed then call dynamic generation
scores = get_scores(chromosome=chromosome, start=position, end=position)
return {"variant": f"chr{chromosome}:g.{position}{ref}>{alt}"}
scores = get_ditto_score(chrom=chromosome, pos=position, ref=ref, alt=alt)
# return {"variant": f"chr{chromosome}:g.{position}{ref}>{alt}"}
return {"scores": scores}


@app.get("/hgvs/{hgvs_cdna}")
Expand Down
6 changes: 3 additions & 3 deletions src/utils/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

def parse_and_predict(dataframe, config_dict, clf):
# Drop variant info columns so we can perform one-hot encoding
dataframe["so"] = dataframe["consequence"]
dataframe["so"] = dataframe["consequence"].copy()
dataframe = dataframe.drop(config_dict["id_cols"], axis=1)
dataframe = dataframe.replace([".", "-", ""], np.nan)
for key in dataframe.columns:
Expand Down Expand Up @@ -44,5 +44,5 @@ def parse_and_predict(dataframe, config_dict, clf):
)

y_score = 1 - clf.predict(df2, verbose=0)
del temp_df
return df2, y_score
del temp_df,df2
return y_score
60 changes: 20 additions & 40 deletions src/utils/Home.py → src/utils/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ def get_parser(data_config):


# Function to load model and weights
def load_model():
def load_model(clf_path):
# Load model and weights
clf = keras.models.load_model("./results/Neural_network")
clf.load_weights("./results/weights.h5")
clf = keras.models.load_model(clf_path + "/Neural_network")
clf.load_weights(clf_path + "/weights.h5")
return clf

# Function to query variant reference allele based on posiiton from UCSC API
Expand Down Expand Up @@ -60,8 +60,8 @@ def query_variant(chrom: str, pos: int, allele_len: int) -> json:

return get_fields.json()

def main():
repo_root = Path(__file__).parent.parent
def get_ditto_score(chrom: str, pos: int, ref: str, alt: str):
repo_root = Path(__file__).parent.parent.parent
# Load the col config file as dictionary
config_f = repo_root / "configs" / "col_config.yaml"
config_dict = get_col_configs(config_f)
Expand All @@ -71,43 +71,23 @@ def main():
parser = get_parser(data_config)

# Load the model and data
clf = load_model()

# Query variant reference allele based on posiiton from UCSC API
try:
actual_ref = query_variant(str(chrom), int(pos), len(ref))["dna"].upper()

# Handle invalid variant position
except:
print("Please enter a valid variant info.")

if ref == actual_ref and ref != alt:
try:
# Query variant annotations via opencravat API and get data as dataframe
overall = parser.query_variant(chrom=str(chrom), pos=int(pos), ref=ref, alt=alt)
except:
overall = pd.DataFrame()

# Check if variant annotations are found
if overall.empty:
print(
"Could not get variant annotations from OpenCravat's API. Please check the variant info and try again."
)

else:
# Select transcript
transcript = st.selectbox(
"**Select a transcript:**", options=list(overall["transcript"].unique())
)

# Filter data based on selected transcript
clf_path = repo_root / "results"
clf = load_model(str(clf_path))

overall = parser.query_variant(chrom=str(chrom), pos=int(pos), ref=ref, alt=alt)
# Check if variant annotations are found
if overall.empty:
return(
"Could not get variant annotations from OpenCravat's API. Please check the variant info and try again."
)
else:
score_dict = {}
for transcript in overall["transcript"].unique():
transcript_data = overall[overall["transcript"] == transcript].reset_index(
drop=True
)

# Parse and predict
df2, y_score = parse_and_predict(transcript_data, config_dict, clf)
y_score = parse_and_predict(transcript_data, config_dict, clf)
y_score = round(y_score[0][0], 2)
score_dict[transcript] = str(y_score)
return score_dict

if __name__ == "__main__":
main()

0 comments on commit 0fbc9ca

Please sign in to comment.