From 666132d2622337d6c38f4f6cff2966d698a06b6a Mon Sep 17 00:00:00 2001
From: Wout Bittremieux
Date: Tue, 27 Feb 2024 07:58:22 +0100
Subject: [PATCH] Don't remove unit test

---
 tests/unit_tests/test_unit.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/tests/unit_tests/test_unit.py b/tests/unit_tests/test_unit.py
index 426134ed..f615a099 100644
--- a/tests/unit_tests/test_unit.py
+++ b/tests/unit_tests/test_unit.py
@@ -450,6 +450,23 @@ def test_beam_search_decode():
     tokens[0, :step] = torch.tensor([4, 14, 4, 28])
     tokens[1, :step] = torch.tensor([4, 14, 4, 1])
 
+    # Set finished beams array to allow decoding from only one beam.
+    test_finished_beams = torch.tensor([True, False])
+
+    new_tokens, new_scores = model._get_topk_beams(
+        tokens, scores, test_finished_beams, batch, step
+    )
+
+    # Only the second peptide should have a new token predicted.
+    expected_tokens = torch.tensor(
+        [
+            [4, 14, 4, 28, 0],
+            [4, 14, 4, 1, 3],
+        ]
+    )
+
+    assert torch.equal(new_tokens[:, : step + 1], expected_tokens)
+
     # Test that duplicate peptide scores don't lead to a conflict in the cache.
     model = Spec2Pep(n_beams=5, residues="massivekb", min_peptide_len=3)
     batch = 2  # B
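
The assertion restored above checks that _get_topk_beams only extends beams that are still active: the finished beam keeps a 0 (presumably padding) at the new decoding step, while the unfinished beam receives a newly predicted token (3). The snippet below is a minimal, self-contained sketch of that masking idea in plain PyTorch; the tensor shapes, the 0 fill value, and the expected output mirror the test above, but the variable names and the dummy next-token predictions are illustrative assumptions, not Casanovo's actual _get_topk_beams implementation.

    import torch

    # Minimal sketch (not Casanovo's implementation): extend only unfinished beams.
    step = 4
    tokens = torch.zeros(2, 6, dtype=torch.int64)
    tokens[0, :step] = torch.tensor([4, 14, 4, 28])  # finished beam
    tokens[1, :step] = torch.tensor([4, 14, 4, 1])   # still active

    finished_beams = torch.tensor([True, False])

    # Hypothetical next-token predictions for each beam.
    predicted_next = torch.tensor([7, 3])

    # Finished beams are frozen at the fill value; active beams are extended.
    tokens[:, step] = torch.where(
        finished_beams, torch.zeros_like(predicted_next), predicted_next
    )

    assert torch.equal(
        tokens[:, : step + 1],
        torch.tensor([[4, 14, 4, 28, 0], [4, 14, 4, 1, 3]]),
    )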