Skip to content

Commit

Permalink
improve tests
Browse files Browse the repository at this point in the history
  • Loading branch information
albertbou92 committed Aug 8, 2024
1 parent f3581f1 commit dcae46d
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 27 deletions.
3 changes: 2 additions & 1 deletion .github/unittest/install_dependencies.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ python -m pip install --upgrade pip
python -m pip install flake8 pytest pytest-cov hydra-core tqdm
python -m pip install torch torchvision
python -m pip install transformers promptsmiles torchrl rdkit==2023.3.3 MolScore # causal-conv1d>=1.4.0 mamba-ssm==1.2.2
pip install deepsmiles selfies smi2sdf smi2svg # atomInSmiles safe
python -m pip install deepsmiles selfies smi2sdf smi2svg # atomInSmiles safe
python -m pip install molbloom

# Verify installations
python -c "import transformers; print(transformers.__version__)"
Expand Down
3 changes: 2 additions & 1 deletion .github/unittest/install_dependencies_nightly.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ python -m pip install --upgrade pip
python -m pip install flake8 pytest pytest-cov hydra-core tqdm
python -m pip install torch torchvision
python -m pip install transformers promptsmiles torchrl rdkit==2023.3.3 MolScore # causal-conv1d>=1.4.0 mamba-ssm==1.2.2
pip install deepsmiles selfies smi2sdf smi2svg # atomInSmiles safe
python -m pip install deepsmiles selfies smi2sdf smi2svg # atomInSmiles safe
python -m pip install molbloom

# Verify installations
python -c "import transformers; print(transformers.__version__)"
Expand Down
Binary file added tests/data/smiles_test_set.bloom
Binary file not shown.
53 changes: 52 additions & 1 deletion tests/test_data_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,28 @@
import os
import shutil
import tempfile

import pytest
import torch
from acegen.data import smiles_to_tensordict

from acegen.data import (
load_dataset,
MolBloomDataset,
smiles_to_tensordict,
SMILESDataset,
)
from acegen.vocabulary.tokenizers import SMILESTokenizerChEMBL
from acegen.vocabulary.vocabulary import Vocabulary
from tensordict import TensorDict
from torch.utils.data import DataLoader


try:
from molbloom import BloomFilter, CustomFilter

_has_molbloom = True
except ImportError:
_has_molbloom = False


def test_smiles_to_tensordict():
Expand Down Expand Up @@ -51,3 +72,33 @@ def test_smiles_to_tensordict():

# Check rewards are in the right position
assert (result["next"]["reward"][next_tensordict["done"]].cpu() == reward).all()


@pytest.mark.parametrize("randomize_smiles", [False, True])
def test_load_dataset(randomize_smiles):
    """Load the bundled SMILES test set and draw one batch from it.

    The test checks that:
      * ``load_dataset`` returns a list of 1000 entries,
      * a ``SMILESDataset`` built on that file can be consumed through a
        ``DataLoader`` using the dataset's own ``collate_fn``, yielding a
        ``TensorDict`` batch,
      * when ``molbloom`` is importable, the first entry is reported as a
        member of the ``MolBloomDataset`` built from the same path.

    Args:
        randomize_smiles: forwarded to ``SMILESDataset``; the test runs once
            with ``False`` and once with ``True``.
    """
    dataset_path = os.path.join(os.path.dirname(__file__), "data", "smiles_test_set")
    dataset_str = load_dataset(dataset_path)
    assert isinstance(dataset_str, list)
    assert len(dataset_str) == 1000

    vocab = Vocabulary.create_from_strings(
        dataset_str, tokenizer=SMILESTokenizerChEMBL()
    )

    temp_dir = tempfile.mkdtemp()
    try:
        dataset = SMILESDataset(
            cache_path=temp_dir,
            dataset_path=dataset_path,
            vocabulary=vocab,
            randomize_smiles=randomize_smiles,
        )

        # molbloom is an optional dependency; only exercise it when present.
        if _has_molbloom:
            molbloom_dataset = MolBloomDataset(dataset_path=dataset_path)
            assert dataset_str[0] in molbloom_dataset

        dataloader = DataLoader(
            dataset,
            batch_size=4,
            shuffle=True,
            collate_fn=dataset.collate_fn,
        )
        data_batch = next(iter(dataloader))
        assert isinstance(data_batch, TensorDict)
    finally:
        # Always remove the cache directory, even when an assertion above
        # fails, so failed runs do not leave temp directories behind.
        shutil.rmtree(temp_dir)
33 changes: 21 additions & 12 deletions tests/test_tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from acegen.vocabulary.tokenizers import (
AISTokenizer,
SAFETokenizer,
DeepSMILESTokenizer,
SAFETokenizer,
SELFIESTokenizer,
SMILESTokenizerChEMBL,
SMILESTokenizerEnamine,
Expand Down Expand Up @@ -54,18 +54,28 @@
"CC1=CC=CC=C1", # Toluene (C7H8)
]

@pytest.mark.parametrize("tokenizer, available, error", [
(DeepSMILESTokenizer, _has_deepsmiles, DEEPSMILES_ERR if not _has_deepsmiles else None),
(SELFIESTokenizer, _has_selfies, SELFIES_ERR if not _has_selfies else None),
(SMILESTokenizerChEMBL, True, None),
(SMILESTokenizerEnamine, True, None),
(SMILESTokenizerGuacaMol, True, None),
# (AISTokenizer, _has_AIS, AIS_ERR if not _has_AIS else None),
# (SAFETokenizer, _has_SAFE, SAFE_ERR if not _has_SAFE else None),
])

@pytest.mark.parametrize(
"tokenizer, available, error",
[
(
DeepSMILESTokenizer,
_has_deepsmiles,
DEEPSMILES_ERR if not _has_deepsmiles else None,
),
(SELFIESTokenizer, _has_selfies, SELFIES_ERR if not _has_selfies else None),
(SMILESTokenizerChEMBL, True, None),
(SMILESTokenizerEnamine, True, None),
(SMILESTokenizerGuacaMol, True, None),
# (AISTokenizer, _has_AIS, AIS_ERR if not _has_AIS else None),
# (SAFETokenizer, _has_SAFE, SAFE_ERR if not _has_SAFE else None),
],
)
def test_smiles_based_tokenizers(tokenizer, available, error):
if not available:
pytest.skip(f"Skipping {tokenizer.__name__} test because the required module is not available: {error}")
pytest.skip(
f"Skipping {tokenizer.__name__} test because the required module is not available: {error}"
)
for smiles in multiple_smiles:
t = tokenizer()
tokens = t.tokenize(smiles)
Expand All @@ -74,4 +84,3 @@ def test_smiles_based_tokenizers(tokenizer, available, error):
assert isinstance(tokens[0], str)
decoded_smiles = t.untokenize(tokens)
assert decoded_smiles == smiles

Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Tutorial: Understanding the SMILES Environment
# Tutorial: Understanding the Token Environment

---

In this tutorial, we will demonstrate how to create an AceGen environment for smiles generation. We will also explain
how to interact with it using TensorDicts and how to understand its expected inputs and outputs.
In this tutorial, we will demonstrate how to create an AceGen environment for SMILES generation.
The same approach could be used for other chemical languages.
We will also explain how to interact with this environment using TensorDicts and how to understand its expected inputs and outputs.

## Prerequisite Knowledge on TorchRL and Tensordict

Expand All @@ -15,10 +16,9 @@ dictionary of tensors as input and return a dictionary of tensors as output.

---

## What is the SMILES environment?
## What is the token environment?

The SMILES environment is a Tensordict-compatible environment for molecule generation with SMILES, and a key component
of the AceGen library. In particular, it is the component that manages the segment of the RL loop responsible for
The token environment is a Tensordict-compatible environment for molecule generation with chemical languages, and a key component of the AceGen library. In particular, it is the component that manages the segment of the RL loop responsible for
providing observations in response to the agent's actions. This environment class inherits from the TorchRL base
environment component ``EnvBase``, providing a range of advantages that include input and output data transformations,
compatibility with Gym-based APIs, efficient vectorized options (enabling the generation of multiple molecules in parallel),
Expand All @@ -27,11 +27,11 @@ environment, all TorchRL components become available for creating potential RL s

---

## How to create a SMILES environment?
## How to create a token environment?

### Create a vocabulary

To create a SMILES environment, we first need to create a vocabulary. The vocabulary maps characters to indices and
To create a token environment, we first need to create a vocabulary. The vocabulary maps characters to indices and
vice versa. There are 3 ways to create a vocabulary in AceGen.

1. Create a vocabulary from a list of characters
Expand Down Expand Up @@ -86,9 +86,9 @@ env = TokenEnv(

---

## How to interact with the SMILES environment?
## How to interact with the token environment?

To start exploring how to use the SMILES environment, we can create an initial observation
To start exploring how to use the token environment, we can create an initial observation

```python
initial_td = env.reset()
Expand Down Expand Up @@ -226,9 +226,9 @@ print([vocab1.decode(seq) for seq in rollout["action"].numpy()])

---

## What are the exact expected inputs and outputs of the SMILES environment?
## What are the exact expected inputs and outputs of the token environment?

We can better understand the expected inputs and outputs of the SMILES environment by running the following code
We can better understand the expected inputs and outputs of the token environment by running the following code
snippets, which will print the full action, observation, done, and reward specs.

```python
Expand Down

0 comments on commit dcae46d

Please sign in to comment.