Skip to content

Commit

Permalink
more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
albertbou92 committed Aug 7, 2024
1 parent a03eb5c commit f3581f1
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 4 deletions.
3 changes: 1 addition & 2 deletions .github/unittest/install_dependencies.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@ python -m pip install --upgrade pip
# Install dependencies
python -m pip install flake8 pytest pytest-cov hydra-core tqdm
python -m pip install torch torchvision

# Ensure dependencies are installed in the right order
python -m pip install transformers promptsmiles torchrl rdkit==2023.3.3 MolScore # causal-conv1d>=1.4.0 mamba-ssm==1.2.2
pip install deepsmiles selfies smi2sdf smi2svg # atomInSmiles safe

# Verify installations
python -c "import transformers; print(transformers.__version__)"
Expand Down
5 changes: 3 additions & 2 deletions .github/unittest/install_dependencies_nightly.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,21 @@ python -m pip install --upgrade pip
# Install dependencies
python -m pip install flake8 pytest pytest-cov hydra-core tqdm
python -m pip install torch torchvision

# Ensure dependencies are installed in the right order
python -m pip install transformers promptsmiles torchrl rdkit==2023.3.3 MolScore # causal-conv1d>=1.4.0 mamba-ssm==1.2.2
pip install deepsmiles selfies smi2sdf smi2svg # atomInSmiles safe

# Verify installations
python -c "import transformers; print(transformers.__version__)"
python -c "import promptsmiles"
# python -c "import mamba_ssm; print('mamba-ssm:', mamba_ssm.__version__)" # Assuming mamba-ssm imports as mamba

# Install local package
cd ../acegen-open
pip install -e .
pip uninstall --yes torchrl
pip uninstall --yes tensordict

# Install torchrl and tensordict nightly
cd ..
python -m pip install git+https://github.com/pytorch-labs/tensordict.git
git clone https://github.com/pytorch/rl.git
Expand Down
77 changes: 77 additions & 0 deletions tests/test_tokenizers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import pytest

from acegen.vocabulary.tokenizers import (
AISTokenizer,
SAFETokenizer,
DeepSMILESTokenizer,
SELFIESTokenizer,
SMILESTokenizerChEMBL,
SMILESTokenizerEnamine,
SMILESTokenizerGuacaMol,
)

try:
import deepsmiles

_has_deepsmiles = True
except ImportError as err:
_has_deepsmiles = False
DEEPSMILES_ERR = err
try:
import selfies

_has_selfies = True
except ImportError as err:
_has_selfies = False
SELFIES_ERR = err
try:
import smizip

_has_smizip = True
except ImportError as err:
_has_smizip = False
SMIZIP_ERR = err
try:
import atomInSmiles as AIS

_has_AIS = True
except ImportError as err:
_has_AIS = False
AIS_ERR = err
try:
import safe

_has_SAFE = True
except ImportError as err:
_has_SAFE = False
SAFE_ERR = err

multiple_smiles = [
"CCO", # Ethanol (C2H5OH)
"CCN(CC)CC", # Triethylamine (C6H15N)
"CC(=O)OC(C)C", # Diethyl carbonate (C7H14O3)
"CC(C)C", # Isobutane (C4H10)
"CC1=CC=CC=C1", # Toluene (C7H8)
]

@pytest.mark.parametrize("tokenizer, available, error", [
(DeepSMILESTokenizer, _has_deepsmiles, DEEPSMILES_ERR if not _has_deepsmiles else None),
(SELFIESTokenizer, _has_selfies, SELFIES_ERR if not _has_selfies else None),
(SMILESTokenizerChEMBL, True, None),
(SMILESTokenizerEnamine, True, None),
(SMILESTokenizerGuacaMol, True, None),
# (AISTokenizer, _has_AIS, AIS_ERR if not _has_AIS else None),
# (SAFETokenizer, _has_SAFE, SAFE_ERR if not _has_SAFE else None),
])
def test_smiles_based_tokenizers(tokenizer, available, error):
if not available:
pytest.skip(f"Skipping {tokenizer.__name__} test because the required module is not available: {error}")
for smiles in multiple_smiles:
t = tokenizer()
tokens = t.tokenize(smiles)
assert len(tokens) > 0
assert isinstance(tokens, list)
assert isinstance(tokens[0], str)
decoded_smiles = t.untokenize(tokens)
assert decoded_smiles == smiles

0 comments on commit f3581f1

Please sign in to comment.