diff --git a/.github/unittest/install_dependencies.sh b/.github/unittest/install_dependencies.sh index fe486931..cb176267 100644 --- a/.github/unittest/install_dependencies.sh +++ b/.github/unittest/install_dependencies.sh @@ -5,7 +5,8 @@ python -m pip install --upgrade pip python -m pip install flake8 pytest pytest-cov hydra-core tqdm python -m pip install torch torchvision python -m pip install transformers promptsmiles torchrl rdkit==2023.3.3 MolScore # causal-conv1d>=1.4.0 mamba-ssm==1.2.2 -pip install deepsmiles selfies smi2sdf smi2svg safe-mol # atomInSmiles +python -m pip install deepsmiles selfies smi2sdf smi2svg # atomInSmiles safe +python -m pip install molbloom # Verify installations python -c "import transformers; print(transformers.__version__)" diff --git a/.github/unittest/install_dependencies_nightly.sh b/.github/unittest/install_dependencies_nightly.sh index 2de98c3f..751fc97b 100644 --- a/.github/unittest/install_dependencies_nightly.sh +++ b/.github/unittest/install_dependencies_nightly.sh @@ -5,7 +5,8 @@ python -m pip install --upgrade pip python -m pip install flake8 pytest pytest-cov hydra-core tqdm python -m pip install torch torchvision python -m pip install transformers promptsmiles torchrl rdkit==2023.3.3 MolScore # causal-conv1d>=1.4.0 mamba-ssm==1.2.2 -pip install deepsmiles selfies smi2sdf smi2svg safe-mol # atomInSmiles +python -m pip install deepsmiles selfies smi2sdf smi2svg # atomInSmiles safe +python -m pip install molbloom # Verify installations python -c "import transformers; print(transformers.__version__)" diff --git a/README.md b/README.md index c0cc09ce..39d47bc4 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,8 @@ [![python](https://img.shields.io/badge/python-3.9%20|%203.10%20|%203.11-blue)](https://www.python.org/downloads/) [![arXiv](https://img.shields.io/badge/arXiv-2405.04657-red.svg)](https://arxiv.org/abs/2405.04657) [![JCIM](https://img.shields.io/badge/JCIM-10.1021%2Facs.jcim.4c00895-blue)](https://doi.org/10.1021/acs.jcim.4c00895) +[![unit-tests](https://github.com/Acellera/acegen-open/actions/workflows/unit_tests.yml/badge.svg)](https://github.com/Acellera/acegen-open/actions/workflows/unit_tests.yml) + --- diff --git a/tests/data/smiles_test_set.bloom b/tests/data/smiles_test_set.bloom new file mode 100644 index 00000000..0e9f9bde Binary files /dev/null and b/tests/data/smiles_test_set.bloom differ diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index d691ec21..0497b27f 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -1,7 +1,28 @@ +import os +import shutil +import tempfile + +import pytest import torch -from acegen.data import smiles_to_tensordict +from acegen.data import ( + load_dataset, + MolBloomDataset, + smiles_to_tensordict, + SMILESDataset, +) +from acegen.vocabulary.tokenizers import SMILESTokenizerChEMBL +from acegen.vocabulary.vocabulary import Vocabulary from tensordict import TensorDict +from torch.utils.data import DataLoader + + +try: + from molbloom import BloomFilter, CustomFilter + + _has_molbloom = True +except ImportError: + _has_molbloom = False def test_smiles_to_tensordict(): @@ -51,3 +72,33 @@ def test_smiles_to_tensordict(): # Check rewards are in the right position assert (result["next"]["reward"][next_tensordict["done"]].cpu() == reward).all() + + +@pytest.mark.parametrize("randomize_smiles", [False, True]) +def test_load_dataset(randomize_smiles): + dataset_path = os.path.dirname(__file__) + "/data/smiles_test_set" + dataset_str = load_dataset(dataset_path) + assert type(dataset_str) == list + assert len(dataset_str) == 1000 + vocab = Vocabulary.create_from_strings( + dataset_str, tokenizer=SMILESTokenizerChEMBL() + ) + temp_dir = tempfile.mkdtemp() + dataset = SMILESDataset( + cache_path=temp_dir, + dataset_path=dataset_path, + vocabulary=vocab, + randomize_smiles=randomize_smiles, + ) + if _has_molbloom: + molbloom_dataset = MolBloomDataset(dataset_path=dataset_path) + assert dataset_str[0] in molbloom_dataset + dataloader = DataLoader( + dataset, + batch_size=4, + shuffle=True, + collate_fn=dataset.collate_fn, + ) + data_batch = dataloader.__iter__().__next__() + assert isinstance(data_batch, TensorDict) + shutil.rmtree(temp_dir) diff --git a/tests/test_scoring_function.py b/tests/test_scoring_function.py new file mode 100644 index 00000000..66dfab34 --- /dev/null +++ b/tests/test_scoring_function.py @@ -0,0 +1,30 @@ +import os +import shutil +import tempfile + +from acegen.scoring_functions import ( + custom_scoring_functions, + register_custom_scoring_function, + Task, +) + + +def test_scoring_functions_utils(): + temp_dir = tempfile.mkdtemp() + register_custom_scoring_function("QED", "acegen.scoring_functions.chemistry.QED") + task = Task( + name="QED2", + scoring_function=custom_scoring_functions["QED"], + budget=4, + output_dir=temp_dir, + ) + assert not task.finished + counter = 0 + for i in range(4): + score = task(["CC1=CC=CC=C1"]) + assert len(score) == 1 + counter += 1 + assert counter == 4 + assert task.finished + assert os.path.isfile(f"{temp_dir}/compounds.csv") + shutil.rmtree(temp_dir) diff --git a/tutorials/understanding_the_smiles_environment.md b/tutorials/understanding_the_token_environment.md similarity index 93% rename from tutorials/understanding_the_smiles_environment.md rename to tutorials/understanding_the_token_environment.md index 581bb635..e4347891 100644 --- a/tutorials/understanding_the_smiles_environment.md +++ b/tutorials/understanding_the_token_environment.md @@ -1,9 +1,10 @@ -# Tutorial: Understanding the SMILES Environment +# Tutorial: Understanding the token Environment --- -In this tutorial, we will demonstrate how to create an AceGen environment for smiles generation. We will also explain -how to interact with it using TensorDicts and how to understand its expected inputs and outputs. +In this tutorial, we will demonstrate how to create an AceGen environment for smiles generation. +The same approach could be used for other chemical languages. +We will also explain how to interact with this environment using TensorDicts and how to understand its expected inputs and outputs. ## Prerequisite Knowledge on TorchRL and Tensordict @@ -15,10 +16,9 @@ dictionary of tensors as input and return a dictionary of tensors as output. --- -## What is the SMILES environment? +## What is the token environment? -The SMILES environment is a Tensordict-compatible environment for molecule generation with SMILES, and a key component -of the AceGen library. In particular, it is the component that manages the segment of the RL loop responsible for +The token environment is a Tensordict-compatible environment for molecule generation with language, and a key component of the AceGen library. In particular, it is the component that manages the segment of the RL loop responsible for providing observations in response to the agent's actions. This environment class inherits from the TorchRL base environment component ``EnvBase``, providing a range of advantages that include input and output data transformations, compatibility with Gym-based APIs, efficient vectorized options (enabling the generation of multiple molecules in parallel), @@ -27,11 +27,11 @@ environment, all TorchRL components become available for creating potential RL s --- -## How to create a SMILES environment? +## How to create a token environment? ### Create a vocabulary -To create a SMILES environment, we first need to create a vocabulary. The vocabulary maps characters to indices and +To create a token environment, we first need to create a vocabulary. The vocabulary maps characters to indices and vice versa. There are 3 ways to create a vocabulary in AceGen. 1. Create a vocabulary from a list of characters @@ -86,9 +86,9 @@ env = TokenEnv( --- -## How to interact with the SMILES environment? +## How to interact with the token environment? -To start exploring how to use the SMILES environment, we can create an initial observation +To start exploring how to use the token environment, we can create an initial observation ```python initial_td = env.reset() @@ -226,9 +226,9 @@ print([vocab1.decode(seq) for seq in rollout["action"].numpy()]) --- -## What are the exact expected inputs and outputs of the SMILES environment? +## What are the exact expected inputs and outputs of the token environment? -We can better understand the expected inputs and outputs of the SMILES environment by running the following code +We can better understand the expected inputs and outputs of the token environment by running the following code snippets, which will print the full action, observation, done, and reward specs. ```python