-
Notifications
You must be signed in to change notification settings - Fork 36
/
fingerprints.py
282 lines (240 loc) · 8.51 KB
/
fingerprints.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
import argparse
import csv
from functools import partial
import gzip
from itertools import chain, islice
import os
from pathlib import Path
from typing import Iterable, Iterator, List, Optional, Set, Tuple
import h5py
import numpy as np
import ray
from rdkit import Chem, DataStructs
from rdkit.Chem import rdMolDescriptors as rdmd
from tqdm import tqdm
# Connect to an existing Ray cluster if one is reachable; otherwise start a
# local Ray instance. NOTE: this runs at import time, so importing this module
# initializes Ray as a side effect.
try:
    if "redis_password" in os.environ:
        # cluster started by a job script that exports the head address and the
        # redis password; the part of "ip_head" before ":" is used as this
        # node's IP (presumably "<ip>:<port>" -- confirm against launch script)
        ray.init(
            address=os.environ["ip_head"],
            _node_ip_address=os.environ["ip_head"].split(":")[0],
            _redis_password=os.environ["redis_password"],
        )
    else:
        ray.init(address="auto")
except ConnectionError:
    # no running cluster: start a local instance sized to the CPUs this process
    # may use. os.sched_getaffinity is not available on every platform, so fall
    # back to the total CPU count where it is missing.
    if hasattr(os, "sched_getaffinity"):
        _num_cpus = len(os.sched_getaffinity(0))
    else:
        _num_cpus = os.cpu_count() or 1
    ray.init(num_cpus=_num_cpus)
def get_smis(
    libaries: Iterable[str], title_line: bool = True, delimiter: str = ",", smiles_col: int = 0
) -> Iterator[str]:
    """Yield the SMILES strings contained in one or more delimited library files

    Parameters
    ----------
    libaries : Iterable[str]
        the filepaths of the library files. Files whose suffix is ".gz" are read
        as gzip-compressed text. (The parameter's spelling is kept so existing
        keyword callers keep working.)
    title_line : bool, default=True
        whether each file begins with a header line that should be skipped
    delimiter : str, default=','
        the column separator
    smiles_col : int, default=0
        the column containing the SMILES string

    Yields
    ------
    str
        the SMILES string of each non-blank row, in file order
    """
    for library in libaries:
        # the csv module expects files opened with newline=""
        if Path(library).suffix == ".gz":
            open_ = partial(gzip.open, mode="rt", newline="")
        else:
            open_ = partial(open, newline="")

        with open_(library) as fid:
            reader = csv.reader(fid, delimiter=delimiter)
            if title_line:
                # default of None: an empty file has no header, and a bare next()
                # would raise StopIteration (-> RuntimeError inside a generator)
                next(reader, None)
            for row in reader:
                if not row:
                    # csv.reader yields [] for blank lines (e.g. a trailing newline)
                    continue
                yield row[smiles_col]
def batches(it: Iterable, chunk_size: int) -> Iterator[List]:
    """Split an iterable into consecutive lists of at most `chunk_size` items

    The input is turned into an iterator immediately; each list is then built
    lazily when requested, and iteration ends at the first empty chunk.
    """
    source = iter(it)

    def _take_chunk() -> List:
        return list(islice(source, chunk_size))

    # two-argument iter(): keep calling _take_chunk until it returns []
    return iter(_take_chunk, [])
@ray.remote
def _smis_to_mols(smis: Iterable) -> List[Optional[Chem.Mol]]:
    """Ray task: parse a chunk of SMILES strings into RDKit molecules

    Parameters
    ----------
    smis : Iterable
        the SMILES strings to parse

    Returns
    -------
    List[Optional[Chem.Mol]]
        the result of Chem.MolFromSmiles for each input string, in order
        (None where parsing failed)
    """
    return list(map(Chem.MolFromSmiles, smis))
def smis_to_mols(smis: Iterable[str]) -> List[Optional[Chem.Mol]]:
    """Parse SMILES strings into RDKit molecules in parallel with Ray

    The input is split into chunks of (2 x the cluster's CPU count) strings,
    and each chunk is parsed by a separate `_smis_to_mols` task.

    Parameters
    ----------
    smis : Iterable[str]
        the SMILES strings to parse

    Returns
    -------
    List[Optional[Chem.Mol]]
        the parsed molecules in input order, with None for each string that
        failed to parse
    """
    n_cpus = int(ray.cluster_resources()["CPU"])
    chunk_len = n_cpus * 2
    pending = [_smis_to_mols.remote(chunk) for chunk in batches(smis, chunk_len)]
    # ray.get on a list of refs returns the results in the same order
    return list(chain.from_iterable(ray.get(pending)))
@ray.remote
def _mols_to_fps(
    mols: Iterable[Chem.Mol], fingerprint: str = "pair", radius: int = 2, length: int = 2048
) -> np.ndarray:
    """Ray task: encode a chunk of molecules as a dense fingerprint matrix

    fingerprint functions must be wrapped in a static function so that they
    may be pickled for parallel processing

    Parameters
    ----------
    mols : Iterable[Chem.Mol]
        the molecules to encode (no None handling is done here)
    fingerprint : str, default='pair'
        the type of fingerprint to generate: 'morgan', 'pair', 'rdkit' or 'maccs'
    radius : int, default=2
        the radius of a 'morgan' fingerprint. For the path-based 'pair' and
        'rdkit' fingerprints the maximum path length is 1 + radius. Unused for
        'maccs'
    length : int, default=2048
        the number of bits in the fingerprint. Ignored for 'maccs', whose keys
        have a fixed width of 167 bits

    Returns
    -------
    np.ndarray
        an array of shape (n_mols, n_bits) whose i-th row holds the bits of the
        i-th molecule's fingerprint; n_bits is `length` (167 for 'maccs')

    Raises
    ------
    NotImplementedError
        if `fingerprint` is not one of the recognized types
    """
    # materialize so len() works for any iterable, not only sequences
    mols = list(mols)

    if fingerprint == "morgan":
        fps = [
            rdmd.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=length, useChirality=True)
            for mol in mols
        ]
    elif fingerprint == "pair":
        fps = [
            rdmd.GetHashedAtomPairFingerprintAsBitVect(
                mol, minLength=1, maxLength=1 + radius, nBits=length
            )
            for mol in mols
        ]
    elif fingerprint == "rdkit":
        fps = [
            rdmd.RDKFingerprint(mol, minPath=1, maxPath=1 + radius, fpSize=length) for mol in mols
        ]
    elif fingerprint == "maccs":
        fps = [rdmd.GetMACCSKeysFingerprint(mol) for mol in mols]
        # MACCS keys have a fixed width that does not depend on `length`;
        # ConvertToNumpyArray cannot resize a row view of a wider matrix
        length = 167
    else:
        raise NotImplementedError(f'Unrecognized fingerprint: "{fingerprint}"')

    X = np.empty((len(mols), length))
    for fp, x in zip(fps, X):
        # fills the row view `x` in place with the fingerprint's bits
        DataStructs.ConvertToNumpyArray(fp, x)

    return X
def mols_to_fps(
    mols: Iterable[Chem.Mol], fingerprint: str = "pair", radius: int = 2, length: int = 2048
) -> np.ndarray:
    """Calculate the fingerprint of each molecule in parallel with Ray

    The molecules are split into chunks of (16 x the cluster's CPU count) and
    each chunk is encoded by a separate `_mols_to_fps` task.

    Parameters
    ----------
    mols : Iterable[Chem.Mol]
        the molecules
    fingerprint : str, default='pair'
        the type of fingerprint to calculate (passed through to `_mols_to_fps`)
    radius : int, default=2
        the radius (or path length) of the fingerprint
    length : int, default=2048
        the number of bits in the fingerprint

    Returns
    -------
    np.ndarray
        the per-molecule fingerprints stacked row-wise in input order. An empty
        input yields an array of shape (0, length)
    """
    chunksize = int(ray.cluster_resources()["CPU"] * 16)
    refs = [
        _mols_to_fps.remote(mols_chunk, fingerprint, radius, length)
        for mols_chunk in batches(mols, chunksize)
    ]
    if not refs:
        # np.vstack raises ValueError when given no arrays
        return np.empty((0, length))

    fps_chunks = [
        ray.get(r) for r in tqdm(refs, desc="Calculating fingerprints", unit="chunk", leave=False)
    ]
    return np.vstack(fps_chunks)
def fps_hdf5(
    smis: Iterable[str],
    size: int,
    fingerprint: str = "pair",
    radius: int = 2,
    length: int = 2048,
    filepath: str = "fps.h5",
) -> Tuple[str, Set[int]]:
    """Prepare an HDF5 file containing the feature matrix of the input SMILES
    strings

    The file holds one dataset, "fps", with one row per *valid* SMILES string
    in input order: row k is the fingerprint of the k-th string that could be
    parsed. Invalid strings get no row; their indices are returned instead.

    Parameters
    ----------
    smis : Iterable[str]
        the SMILES strings from which to build the feature matrix
    size : int
        the total number of smiles strings. Used to preallocate the dataset and
        size the progress bar; the dataset is grown or shrunk to the actual
        number of valid molecules, so an inaccurate value only costs a resize
    fingerprint : str, default='pair'
        the type of fingerprint to calculate
    radius : int, default=2
        the "radius" of the fingerprint to calculate. For path-based
        fingerprints, this corresponds to the path length
    length : int, default=2048
        the length/number of bits in the fingerprint. Ignored for 'maccs',
        which always has 167 bits
    filepath : str, default='fps.h5'
        the filepath of the output HDF5 file

    Returns
    -------
    str
        the filepath of the output HDF5 file
    invalid_idxs : Set[int]
        the set of (0-based) indices in `smis` of strings that failed to parse
    """
    if fingerprint == "maccs":
        length = 167  # MACCS keys have a fixed width

    CHUNKSIZE = 1024
    batch_size = 4 * CHUNKSIZE * int(ray.cluster_resources()["CPU"])
    n_batches = -(-size // batch_size)  # ceiling division

    invalid_idxs = set()
    n_seen = 0  # SMILES consumed so far (index space of `smis`)
    n_written = 0  # fingerprint rows written so far (index space of the dataset)

    with h5py.File(filepath, "w") as h5f:
        # an unlimited row axis lets the dataset shrink to the number of valid
        # molecules, grow if `size` was an underestimate, and use 1024-row chunks
        # even when size < 1024
        fps_dset = h5f.create_dataset(
            "fps",
            (size, length),
            maxshape=(None, length),
            chunks=(CHUNKSIZE, length),
            dtype="int8",
        )

        for smis_batch in tqdm(
            batches(smis, batch_size),
            total=n_batches or None,
            desc="Precalculating fps",
            unit="batch",
            unit_scale=batch_size,
        ):
            mols = smis_to_mols(smis_batch)
            invalid_idxs.update(n_seen + j for j, mol in enumerate(mols) if mol is None)
            n_seen += len(mols)

            valid_mols = [mol for mol in mols if mol is not None]
            if not valid_mols:
                continue

            fps = mols_to_fps(valid_mols, fingerprint, radius, length)
            # write at the next free row, NOT at the input offset: invalid
            # molecules have no row, so the two indices diverge after the first
            end = n_written + len(fps)
            if end > fps_dset.shape[0]:
                fps_dset.resize(end, axis=0)
            fps_dset[n_written:end] = fps
            n_written = end

        if fps_dset.shape[0] != n_written:
            fps_dset.resize(n_written, axis=0)

    return filepath, invalid_idxs
def main():
    """Command-line entry point: fingerprint a molecule library into an HDF5 file"""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-o",
        "--output",
        type=Path,
        help="the filepath of the output fingerprints HDF5 file. If no suffix is provided, will add '.h5'. If no name is provided, output file will be named <library>.h5",
    )
    parser.add_argument(
        "--fingerprint",
        default="pair",
        choices=["morgan", "rdkit", "pair", "maccs"],
        help="the type of encoder to use",
    )
    parser.add_argument(
        "--radius", type=int, default=2, help="the radius or path length to use for fingerprints"
    )
    parser.add_argument("--length", type=int, default=2048, help="the length of the fingerprint")
    parser.add_argument(
        "-l",
        "--library",
        required=True,
        nargs="+",
        help="the files containing members of the MoleculePool",
    )
    parser.add_argument(
        "--no-title-line",
        action="store_true",
        help="whether there is no title line in the library file",
    )
    parser.add_argument(
        "--total-size",
        type=int,
        help="(if known) the total number of molecules in the library file",
    )
    parser.add_argument(
        "-d", "--delimiter", default=",", help="the column separator in the library file"
    )
    parser.add_argument(
        "--smiles-col",
        default=0,
        type=int,
        help="the column containing the SMILES string in the library file",
    )

    args = parser.parse_args()
    args.title_line = not args.no_title_line

    # honor the documented contract: only add ".h5" when the given output
    # path has no suffix; default to <first library>.h5
    if args.output is None:
        path = Path(args.library[0]).with_suffix(".h5")
    elif args.output.suffix:
        path = args.output
    else:
        path = args.output.with_suffix(".h5")

    if args.total_size is None:
        # extra pass over the library just to count its entries
        args.total_size = sum(
            1 for _ in get_smis(args.library, args.title_line, args.delimiter, args.smiles_col)
        )

    print("Precalculating feature matrix ...", end=" ")
    smis = get_smis(args.library, args.title_line, args.delimiter, args.smiles_col)
    fps, invalid_lines = fps_hdf5(
        smis, args.total_size, args.fingerprint, args.radius, args.length, path
    )
    print("Done!")
    print(f'Feature matrix was saved to "{fps}"', flush=True)

    # invalid_lines holds int indices: str.join needs strings, and sorting makes
    # the suggested command line deterministic
    invalid_str = " ".join(map(str, sorted(invalid_lines)))
    print(
        "When using this fingerprints file, you should add "
        f'"--fps {path} --invalid-lines {invalid_str}" to the command line '
        "or the config file to speed up pool construction"
    )


if __name__ == "__main__":
    main()