Skip to content

Commit

Permalink
Update filter_and_merge.py
Browse files (browse the repository at this point in the history)
  • Loading branch information
hyunp2 authored Sep 26, 2023
1 parent 64b341d commit 26a815e
Showing 1 changed file with 36 additions and 34 deletions.
70 changes: 36 additions & 34 deletions utils/filter_and_merge.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -10,7 +10,7 @@
from src.datasets import read_sdf


def run(input_dir, output_dir, template, n):
def run(input_dir, output_dir, template):

os.makedirs(output_dir, exist_ok=True)
out_table_path = os.path.join(output_dir, f'{template}_table.csv')
Expand All @@ -23,43 +23,45 @@ def run(input_dir, output_dir, template, n):
full_fragments = []
full_linkers = []

for idx in range(n):
mol_path = os.path.join(input_dir, f'{template}_mol_{idx}.sdf')
frag_path = os.path.join(input_dir, f'{template}_frag_{idx}.sdf')
link_path = os.path.join(input_dir, f'{template}_link_{idx}.sdf')
table_path = os.path.join(input_dir, f'{template}_table_{idx}.csv')

# for idx in range(n):
mol_path = os.path.join(input_dir, f'{template}_mol.sdf')
frag_path = os.path.join(input_dir, f'{template}_frag.sdf')
link_path = os.path.join(input_dir, f'{template}_link.sdf')
table_path = os.path.join(input_dir, f'{template}_table.csv')

table = pd.read_csv(table_path)
table['idx'] = table.index
grouped_table = (
table
.groupby(['molecule', 'fragments', 'linker', 'anchor_1', 'anchor_2'])
.min()
.reset_index()
.sort_values(by='idx')
)
idx_to_keep = set(grouped_table['idx'].unique())
table['keep'] = table['idx'].isin(idx_to_keep)
table = pd.read_csv(table_path)
table['idx'] = table.index
grouped_table = (
table
.groupby(['molecule', 'fragments', 'linker', 'anchor_1', 'anchor_2'])
.min()
.reset_index()
.sort_values(by='idx')
)
idx_to_keep = set(grouped_table['idx'].unique())
table['keep'] = table['idx'].isin(idx_to_keep)

generator = tqdm(
zip(table.iterrows(), read_sdf(mol_path), read_sdf(frag_path), read_sdf(link_path)),
total=len(table),
desc=str(idx),
)
try:
for (_, row), molecule, fragments, linker in generator:
if row['keep']:
if molecule.GetProp('_Name') != row['molecule']:
print('Molecule _Name:', molecule.GetProp('_Name'), row['molecule'])
continue
generator = tqdm(
zip(table.iterrows(), read_sdf(mol_path), read_sdf(frag_path), read_sdf(link_path)),
total=len(table),
desc='Full data',
)
try:
for (_, row), molecule, fragments, linker in generator:
if row['keep']:
if molecule.GetProp('_Name') != row['molecule']:
print('Molecule _Name:', molecule.GetProp('_Name'), row['molecule'])
continue

full_table.append(row)
full_molecules.append(molecule)
full_fragments.append(fragments)
full_linkers.append(linker)
except:
pass
full_table.append(row)
full_molecules.append(molecule)
full_fragments.append(fragments)
full_linkers.append(linker)
except:
pass


full_table = pd.DataFrame(full_table)
full_table.to_csv(out_table_path, index=False)
with Chem.SDWriter(open(out_mol_path, 'w')) as writer:
Expand Down

0 comments on commit 26a815e

Please sign in to comment.