-
Notifications
You must be signed in to change notification settings - Fork 2
/
predictor.py
99 lines (77 loc) · 2.83 KB
/
predictor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import sys
import os
import json
import numpy as np
import pandas as pd
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
print ("Importing modules...")
import modules
# print ("Done")
##################################################
# Load the sample table from disk and pull out the three columns the rest of
# the script indexes by sample position: the sequence, the ligand id (used as
# the ground-truth metal) and the fingerprint (used as the ground-truth sites).
print("Reading data from disk...", end=' ', flush=True)
df = pd.read_parquet('./datasets/Metal_all_20180601.parquet')
seqs, metals, fingerprints = (
    np.array(df[column]) for column in ('sequence', 'ligandId', 'fingerprint')
)
print("Done")
##################################################
# Read the metal-name -> class-index table and build the inverse mapping, which
# is later used to turn the predictor's arg-max index back into a metal name.
print("Using FOFE encoder...", end=' ', flush=True)
with open("./dictionaries/metal_dict", 'r') as dict_file:
    metal_dict = json.load(dict_file)
num_to_metal = {index: name for name, index in metal_dict.items()}
print("Done")
##################################################
# Rebuild the metal classifier: the architecture comes from a JSON description
# and the trained weights from the matching HDF5 file.
# The message is flushed so it is visible while the model is still loading,
# matching the other progress messages in this script.
print("Loading metal_predictor...", end=' ', flush=True)
from keras.models import model_from_json
# A context manager guarantees the handle is closed even if read() raises.
with open('./models/metal_predict.json', 'r') as json_file:
    loaded_model_json = json_file.read()
metal_predictor = model_from_json(loaded_model_json)
metal_predictor.load_weights("./models/metal_predict.h5")
print("Done")
##################################################
# Multiplier applied to the standard deviation when thresholding the
# binding-site scores (passed to threshold_func further down).
factor = 2.33


def threshold_func(y_in, factor):
    """Binarise ``y_in`` slice by slice along its first axis.

    For every index ``i`` along axis 0, a cutoff of
    ``mean(y_in[i]) + factor * std(y_in[i])`` is computed and each entry of
    ``y_in[i]`` that is strictly greater than the cutoff is marked.

    Args:
        y_in: NumPy array of scores; each ``y_in[i]`` is thresholded
            independently of the others.
        factor: Number of standard deviations above the mean an entry must
            exceed to be marked.

    Returns:
        An array with the same shape and dtype as ``y_in`` holding 1 where
        the entry exceeded its slice's cutoff and 0 elsewhere.
    """
    y_out = np.zeros_like(y_in)
    n_rows = y_in.shape[0]
    for idx in range(n_rows):
        row = y_in[idx]
        cutoff = np.mean(row) + factor * np.std(row)
        # The boolean mask is cast to y_out's dtype on assignment.
        y_out[idx] = np.greater(row, cutoff)
    return y_out


print(f"Threshold factor set to {factor}")
print("--------------------------------------------------")
# Select which sample to analyse: a random one when no command-line argument
# is given, otherwise the index passed as the first argument.
# The valid range is taken from the loaded data instead of a hard-coded row
# count, so the check stays correct if the dataset file changes.
num_samples = len(seqs)
if num_samples == 0:
    sys.exit("Dataset is empty; nothing to predict")
if len(sys.argv) == 1:
    choice = np.random.randint(num_samples)
    print ("No input is provided. Randomly choose index...[" + str(choice) + "]")
else:
    try:
        choice = int(sys.argv[1])
    except ValueError:
        # Exit with a readable message instead of a raw traceback.
        sys.exit("Index must be an integer, got [" + sys.argv[1] + "]")
    print ("Choose index [" + str(choice) + "]")
    if not 0 <= choice < num_samples:
        sys.exit("Index must be within [0, " + str(num_samples - 1) + "]")
print ("--------------------------------------------------")
# Show the chosen sequence; anything longer than 60 characters is abbreviated
# to its first and last 30 characters. Message text is kept verbatim.
seq = seqs[choice]
if len(seq) <= 60:
    shown = seq
else:
    shown = seq[:30] + "..." + seq[-30:]
print("The seuqnce is [" + shown + "]\n")

# Run the metal classifier on the FOFE-encoded sequence and map the
# highest-scoring output index back to a metal name.
metal_out = metal_predictor.predict(modules.FOFE(seq))
metal = num_to_metal[np.argmax(metal_out)]
print("This sample is binded to [" + metal + "]")
print(" Ground truth [" + metals[choice] + "]\n")
# Load the binding-site model that belongs to the predicted metal. The model
# files are looked up as ./models/<metal>.json and ./models/<metal>.h5.
model_stem = './models/' + metal
# A context manager guarantees the handle is closed even if read() raises.
with open(model_stem + '.json', 'r') as json_file:
    loaded_model_json = json_file.read()
MBS_predictor = model_from_json(loaded_model_json)
MBS_predictor.load_weights(model_stem + '.h5')

# Score the sequence, then keep only the scores that stand out from the rest
# (see threshold_func).
MBS_out = MBS_predictor.predict(modules.FOFE(seqs[choice]))
MBS_OneHot = threshold_func(MBS_out, factor)
# Only the first row of the mask is reported, so index it directly instead of
# computing the indices of every row and discarding all but the first.
MBS = np.where(MBS_OneHot[0] == 1)[0]

print ("This sample has [", end="")
print (*MBS, sep=",", end="")
print ("] binding sites")
print (" Ground truth [", end="")
print (*fingerprints[choice], sep=",", end="")
print ("]")