diff --git a/rxnmapper/core.py b/rxnmapper/core.py
index 202fbac..260066d 100644
--- a/rxnmapper/core.py
+++ b/rxnmapper/core.py
@@ -128,6 +128,14 @@ def convert_batch_to_attns(
             return_tensors="pt",
         )
         parsed_input = {k: v.to(self.device) for k, v in encoded_ids.items()}
+
+        max_input_length = parsed_input["input_ids"].shape[1]
+        max_supported_by_model = self.model.config.max_position_embeddings
+        if max_input_length > max_supported_by_model:
+            raise ValueError(
+                f"Reaction SMILES has {max_input_length} tokens, should be at most {max_supported_by_model}."
+            )
+
         with torch.no_grad():
             output = self.model(**parsed_input)
         attentions = output[2]
diff --git a/tests/test_mapper.py b/tests/test_mapper.py
index 86789af..e93d3eb 100644
--- a/tests/test_mapper.py
+++ b/tests/test_mapper.py
@@ -137,3 +137,15 @@ def test_reaction_with_asterisks(rxn_mapper: RXNMapper):
     results = rxn_mapper.get_attention_guided_atom_maps(rxns, canonicalize_rxns=False)
 
     assert_correct_maps(results, expected)
+
+
+def test_too_long_reaction_smiles_produce_exception_with_understandable_error_message(
+    rxn_mapper: RXNMapper,
+):
+    # dummy reaction with 1 + 3 + 500 * 2 + 3 + 1 = 1008 tokens
+    rxn = "C=C" + "[C+][C-]" * 500 + ">>CC"
+
+    with pytest.raises(ValueError) as excinfo:
+        _ = rxn_mapper.get_attention_guided_atom_maps([rxn], canonicalize_rxns=False)
+
+    assert "Reaction SMILES has 1008 tokens, should be" in str(excinfo.value)