diff --git a/utils/difflinker_sample_and_analyze.py b/utils/difflinker_sample_and_analyze.py index ebcbbee8..557e0562 100644 --- a/utils/difflinker_sample_and_analyze.py +++ b/utils/difflinker_sample_and_analyze.py @@ -7,11 +7,11 @@ import torch from rdkit import Chem -from src import const -from src.datasets import collate_with_fragment_edges, get_dataloader, parse_molecule -from src.lightning import DDPM -from src.linker_size_lightning import SizeClassifier -from src.visualizer import save_xyz_file, visualize_chain +from .src import const +from .src.datasets import collate_with_fragment_edges, get_dataloader, parse_molecule +from .src.lightning import DDPM +from .src.linker_size_lightning import SizeClassifier +from .src.visualizer import save_xyz_file, visualize_chain from tqdm import tqdm @@ -177,6 +177,10 @@ def sample_fn(_data): save_xyz_file(output_dir, h, x, node_mask, names=names, is_geom=ddpm.is_geom, suffix='') +def run_dflk_sample_analyze(input_path, model, linker_size, output_dir="./", n_samples=5, n_steps=None, anchors=None): + main(input_path=input_path, model=model, output_dir=output_dir, n_samples=n_samples, n_steps=n_steps, linker_size=str(linker_size), anchors=anchors) + + if __name__ == '__main__': start_time = time.time() args = parser.parse_args()