diff --git a/tests/unit_tests/test_unit.py b/tests/unit_tests/test_unit.py index 426134ed..f615a099 100644 --- a/tests/unit_tests/test_unit.py +++ b/tests/unit_tests/test_unit.py @@ -450,6 +450,23 @@ def test_beam_search_decode(): tokens[0, :step] = torch.tensor([4, 14, 4, 28]) tokens[1, :step] = torch.tensor([4, 14, 4, 1]) + # Set finished beams array to allow decoding from only one beam. + test_finished_beams = torch.tensor([True, False]) + + new_tokens, new_scores = model._get_topk_beams( + tokens, scores, test_finished_beams, batch, step + ) + + # Only the second peptide should have a new token predicted. + expected_tokens = torch.tensor( + [ + [4, 14, 4, 28, 0], + [4, 14, 4, 1, 3], + ] + ) + + assert torch.equal(new_tokens[:, : step + 1], expected_tokens) + # Test that duplicate peptide scores don't lead to a conflict in the cache. model = Spec2Pep(n_beams=5, residues="massivekb", min_peptide_len=3) batch = 2 # B