diff --git a/src/rxn/chemutils/smiles_augmenter.py b/src/rxn/chemutils/smiles_augmenter.py index acd47f0..bf344c1 100644 --- a/src/rxn/chemutils/smiles_augmenter.py +++ b/src/rxn/chemutils/smiles_augmenter.py @@ -1,8 +1,12 @@ +import logging import random from typing import Callable, List from .miscellaneous import apply_to_any_smiles, apply_to_smiles_groups +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + class SmilesAugmenter: """ @@ -15,6 +19,7 @@ def __init__( augmentation_fn: Callable[[str], str], augmentation_probability: float = 1.0, shuffle: bool = True, + ignore_exceptions: bool = True, ): """ Args: @@ -23,10 +28,14 @@ def __init__( augmentation_probability: Probability with which to augment individual SMILES strings. shuffle: Whether to shuffle the order of the compounds. + ignore_exceptions: Whether to ignore the error (and return the + original string) when an augmentation fails. If False, exceptions + will be propagated. """ self.augmentation_fn = augmentation_fn self.augmentation_probability = augmentation_probability self.shuffle = shuffle + self.ignore_exceptions = ignore_exceptions def augment(self, smiles: str, number_augmentations: int) -> List[str]: """ @@ -61,7 +70,14 @@ def _augment_with_probability(self, smiles: str) -> str: self.augmentation_probability == 1.0 or random.uniform(0, 1) <= self.augmentation_probability ): - return self.augmentation_fn(smiles) + try: + return self.augmentation_fn(smiles) + except Exception as e: + if self.ignore_exceptions: + logger.warning(f"Augmentation failed for {smiles}: {e}") + return smiles + else: + raise # no augmentation return smiles diff --git a/tests/test_smiles_augmenter.py b/tests/test_smiles_augmenter.py index 71729ae..ed0ef0d 100644 --- a/tests/test_smiles_augmenter.py +++ b/tests/test_smiles_augmenter.py @@ -1,8 +1,10 @@ import random +import pytest from rxn.utilities.basic import identity from rxn.utilities.containers import all_identical +from rxn.chemutils.exceptions import InvalidSmiles from rxn.chemutils.smiles_augmenter import SmilesAugmenter from rxn.chemutils.smiles_randomization import randomize_smiles_rotated @@ -136,3 +138,19 @@ def test_reproducibility() -> None: # sampling one more time without resetting the seed -> results change results.append(augmenter.augment(rxn_smiles_3, 5)) assert not all_identical(results) + + +def test_augmentation_errors() -> None: + augmenter = SmilesAugmenter( + augmentation_fn=randomize_smiles_rotated, ignore_exceptions=True + ) + + invalid_smiles = "thisisinvalid" + + # When errors are ignored: returns the original input + assert augmenter.augment(invalid_smiles, 1) == [invalid_smiles] + + # When errors are not ignored: raises an exception + augmenter.ignore_exceptions = False + with pytest.raises(InvalidSmiles): + _ = augmenter.augment(invalid_smiles, 1)