Skip to content

Commit

Permalink
tokenizer tests
Browse files Browse the repository at this point in the history
  • Loading branch information
albertbou92 committed Aug 8, 2024
1 parent f3581f1 commit f11dc58
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/unittest/install_dependencies.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ python -m pip install --upgrade pip
python -m pip install flake8 pytest pytest-cov hydra-core tqdm
python -m pip install torch torchvision
python -m pip install transformers promptsmiles torchrl rdkit==2023.3.3 MolScore # causal-conv1d>=1.4.0 mamba-ssm==1.2.2
pip install deepsmiles selfies smi2sdf smi2svg # atomInSmiles safe
pip install deepsmiles selfies smi2sdf smi2svg safe-mol # atomInSmiles

# Verify installations
python -c "import transformers; print(transformers.__version__)"
Expand Down
2 changes: 1 addition & 1 deletion .github/unittest/install_dependencies_nightly.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ python -m pip install --upgrade pip
python -m pip install flake8 pytest pytest-cov hydra-core tqdm
python -m pip install torch torchvision
python -m pip install transformers promptsmiles torchrl rdkit==2023.3.3 MolScore # causal-conv1d>=1.4.0 mamba-ssm==1.2.2
pip install deepsmiles selfies smi2sdf smi2svg # atomInSmiles safe
pip install deepsmiles selfies smi2sdf smi2svg safe-mol # atomInSmiles

# Verify installations
python -c "import transformers; print(transformers.__version__)"
Expand Down
35 changes: 23 additions & 12 deletions tests/test_tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

from acegen.vocabulary.tokenizers import (
AISTokenizer,
SAFETokenizer,
DeepSMILESTokenizer,
SAFETokenizer,
SELFIESTokenizer,
SMILESTokenizerChEMBL,
SMILESTokenizerEnamine,
SMILESTokenizerGuacaMol,
SmiZipTokenizer,
)

try:
Expand Down Expand Up @@ -54,18 +55,29 @@
"CC1=CC=CC=C1", # Toluene (C7H8)
]

@pytest.mark.parametrize("tokenizer, available, error", [
(DeepSMILESTokenizer, _has_deepsmiles, DEEPSMILES_ERR if not _has_deepsmiles else None),
(SELFIESTokenizer, _has_selfies, SELFIES_ERR if not _has_selfies else None),
(SMILESTokenizerChEMBL, True, None),
(SMILESTokenizerEnamine, True, None),
(SMILESTokenizerGuacaMol, True, None),
# (AISTokenizer, _has_AIS, AIS_ERR if not _has_AIS else None),
# (SAFETokenizer, _has_SAFE, SAFE_ERR if not _has_SAFE else None),
])

@pytest.mark.parametrize(
"tokenizer, available, error",
[
(
DeepSMILESTokenizer,
_has_deepsmiles,
DEEPSMILES_ERR if not _has_deepsmiles else None,
),
(SELFIESTokenizer, _has_selfies, SELFIES_ERR if not _has_selfies else None),
(SMILESTokenizerChEMBL, True, None),
(SMILESTokenizerEnamine, True, None),
(SMILESTokenizerGuacaMol, True, None),
(AISTokenizer, _has_AIS, AIS_ERR if not _has_AIS else None),
(SAFETokenizer, _has_SAFE, SAFE_ERR if not _has_SAFE else None),
(SmiZipTokenizer, _has_smizip, SMIZIP_ERR if not _has_smizip else None),
],
)
def test_smiles_based_tokenizers(tokenizer, available, error):
if not available:
pytest.skip(f"Skipping {tokenizer.__name__} test because the required module is not available: {error}")
pytest.skip(
f"Skipping {tokenizer.__name__} test because the required module is not available: {error}"
)
for smiles in multiple_smiles:
t = tokenizer()
tokens = t.tokenize(smiles)
Expand All @@ -74,4 +86,3 @@ def test_smiles_based_tokenizers(tokenizer, available, error):
assert isinstance(tokens[0], str)
decoded_smiles = t.untokenize(tokens)
assert decoded_smiles == smiles

0 comments on commit f11dc58

Please sign in to comment.