diff --git a/CHANGELOG.md b/CHANGELOG.md
index 040e87e3..38b4f869 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -14,6 +14,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 
 - Config option `max_iters` has been renamed to `cosine_schedule_period_iters` to better reflect that it controls the number of iterations for the cosine half period of the learning rate.
 
+### Fixed
+
+- Fix beam search caching failure when multiple beams have an equal predicted peptide score by breaking ties randomly.
+
 ## [4.1.0] - 2024-02-16
 
 ### Changed
diff --git a/casanovo/denovo/model.py b/casanovo/denovo/model.py
index 50d43047..77df6df5 100644
--- a/casanovo/denovo/model.py
+++ b/casanovo/denovo/model.py
@@ -482,7 +482,9 @@ def _cache_finished_beams(
         step: int,
         beams_to_cache: torch.Tensor,
         beam_fits_precursor: torch.Tensor,
-        pred_cache: Dict[int, List[Tuple[float, np.ndarray, torch.Tensor]]],
+        pred_cache: Dict[
+            int, List[Tuple[float, float, np.ndarray, torch.Tensor]]
+        ],
     ):
         """
         Cache terminated beams.
@@ -503,11 +505,13 @@ def _cache_finished_beams(
         beam_fits_precursor: torch.Tensor of shape (n_spectra * n_beams)
             Boolean tensor indicating whether the beams are within the
             precursor m/z tolerance.
-        pred_cache : Dict[int, List[Tuple[float, np.ndarray, torch.Tensor]]]
+        pred_cache : Dict[
+            int, List[Tuple[float, float, np.ndarray, torch.Tensor]]
+        ]
             Priority queue with finished beams for each spectrum, ordered by
             peptide score. For each finished beam, a tuple with the (negated)
-            peptide score, amino acid-level scores, and the predicted tokens is
-            stored.
+            peptide score, a random tie-breaking float, the amino acid-level
+            scores, and the predicted tokens is stored.
         """
         for i in range(len(beams_to_cache)):
             if not beams_to_cache[i]:
@@ -548,7 +552,12 @@ def _cache_finished_beams(
                 heapadd = heapq.heappushpop
             heapadd(
                 pred_cache[spec_idx],
-                (peptide_score, aa_scores, torch.clone(pred_peptide)),
+                (
+                    peptide_score,
+                    np.random.random_sample(),
+                    aa_scores,
+                    torch.clone(pred_peptide),
+                ),
             )
 
     def _get_topk_beams(
@@ -646,17 +655,22 @@ def _get_topk_beams(
 
     def _get_top_peptide(
         self,
-        pred_cache: Dict[int, List[Tuple[float, np.ndarray, torch.Tensor]]],
+        pred_cache: Dict[
+            int, List[Tuple[float, float, np.ndarray, torch.Tensor]]
+        ],
     ) -> Iterable[List[Tuple[float, np.ndarray, str]]]:
         """
         Return the peptide with the highest confidence score for each spectrum.
 
         Parameters
         ----------
-        pred_cache : Dict[int, List[Tuple[float, np.ndarray, torch.Tensor]]]
+        pred_cache : Dict[
+            int, List[Tuple[float, float, np.ndarray, torch.Tensor]]
+        ]
             Priority queue with finished beams for each spectrum, ordered by
             peptide score. For each finished beam, a tuple with the peptide
-            score, amino acid-level scores, and the predicted tokens is stored.
+            score, a random tie-breaking float, the amino acid-level scores,
+            and the predicted tokens is stored.
 
         Returns
         -------
@@ -673,7 +687,7 @@ def _get_top_peptide(
                     aa_scores,
                     "".join(self.decoder.detokenize(pred_tokens)),
                 )
-                for pep_score, aa_scores, pred_tokens in heapq.nlargest(
+                for pep_score, _, aa_scores, pred_tokens in heapq.nlargest(
                     self.top_match, peptides
                 )
             ]
diff --git a/tests/unit_tests/test_unit.py b/tests/unit_tests/test_unit.py
index 61d61efa..f615a099 100644
--- a/tests/unit_tests/test_unit.py
+++ b/tests/unit_tests/test_unit.py
@@ -203,7 +203,7 @@ def test_beam_search_decode():
     )
     # Verify that the correct peptides have been cached.
     correct_cached = 0
-    for _, _, pep in pred_cache[0]:
+    for _, _, _, pep in pred_cache[0]:
         if torch.equal(pep, torch.tensor([4, 14, 4, 13])):
             correct_cached += 1
         elif torch.equal(pep, torch.tensor([4, 14, 4, 18])):
@@ -220,13 +220,13 @@ def test_beam_search_decode():
     # Return the candidate peptide with the highest score
     test_cache = collections.OrderedDict((i, []) for i in range(batch))
     heapq.heappush(
-        test_cache[0], (0.93, 4 * [0.93], torch.tensor([4, 14, 4, 19]))
+        test_cache[0], (0.93, 0.1, 4 * [0.93], torch.tensor([4, 14, 4, 19]))
     )
     heapq.heappush(
-        test_cache[0], (0.95, 4 * [0.95], torch.tensor([4, 14, 4, 13]))
+        test_cache[0], (0.95, 0.2, 4 * [0.95], torch.tensor([4, 14, 4, 13]))
     )
     heapq.heappush(
-        test_cache[0], (0.94, 4 * [0.94], torch.tensor([4, 14, 4, 4]))
+        test_cache[0], (0.94, 0.3, 4 * [0.94], torch.tensor([4, 14, 4, 4]))
     )
 
     assert list(model._get_top_peptide(test_cache))[0][0][-1] == "PEPK"
@@ -296,7 +296,7 @@ def test_beam_search_decode():
     )
     # Verify predictions with matching/non-matching precursor m/z.
     positive_score = negative_score = 0
-    for peptide_score, _, _ in pred_cache[0]:
+    for peptide_score, _, _, _ in pred_cache[0]:
         positive_score += peptide_score >= 0
         negative_score += peptide_score < 0
     assert positive_score == 2
@@ -435,7 +435,7 @@ def test_beam_search_decode():
     vocab = model.decoder.vocab_size + 1  # V
     step = 4
 
-    # Initialize dummyy scores and tokens.
+    # Initialize dummy scores and tokens.
     scores = torch.full(
         size=(batch, length, vocab, beam), fill_value=torch.nan
     )
@@ -467,6 +467,37 @@ def test_beam_search_decode():
     assert torch.equal(new_tokens[:, : step + 1], expected_tokens)
 
+    # Test that duplicate peptide scores don't lead to a conflict in the cache.
+    model = Spec2Pep(n_beams=5, residues="massivekb", min_peptide_len=3)
+    batch = 2  # B
+    beam = model.n_beams  # S
+    model.decoder.reverse = True
+    length = model.max_length + 1  # L
+    vocab = model.decoder.vocab_size + 1  # V
+    step = 4
+
+    # Simulate beams with identical amino acid scores but different tokens.
+    scores = torch.zeros(size=(batch * beam, length, vocab))
+    scores[: batch * beam, : step + 1, :] = torch.rand(1)
+    tokens = torch.zeros(batch * beam, length, dtype=torch.int64)
+    tokens[: batch * beam, :step] = torch.randint(
+        1, vocab, (batch * beam, step)
+    )
+
+    pred_cache = collections.OrderedDict((i, []) for i in range(batch))
+    model._cache_finished_beams(
+        tokens,
+        scores,
+        step,
+        torch.ones(batch * beam, dtype=torch.bool),
+        torch.ones(batch * beam, dtype=torch.bool),
+        pred_cache,
+    )
+    for beam_i, preds in pred_cache.items():
+        assert len(preds) == beam
+        peptide_scores = [pep[0] for pep in preds]
+        assert np.allclose(peptide_scores, peptide_scores[0])
+
 
 def test_eval_metrics():
     """
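For readers skimming the diff, the fix boils down to the standard heapq tie-breaking pattern: tuples in a heap are compared element by element, so when two finished beams have exactly the same peptide score, Python falls back to comparing the next elements, and comparing NumPy arrays or PyTorch tensors raises an "ambiguous truth value" error. Inserting a random float directly after the score guarantees the comparison is decided before the non-comparable elements are reached. The snippet below is a minimal standalone sketch of that pattern under those assumptions, not Casanovo's actual code; `cache`, `max_size`, and `add_entry` are illustrative names.

```python
import heapq

import numpy as np
import torch

# Illustrative bounded priority queue of (score, tie_breaker, aa_scores, tokens)
# tuples. Without the random tie_breaker, two entries with equal scores would
# force heapq to compare the np.ndarray / torch.Tensor elements, which raises
# an "ambiguous truth value" error.
cache, max_size = [], 3


def add_entry(score, aa_scores, tokens):
    """Keep the `max_size` highest-scoring entries, breaking score ties randomly."""
    heapadd = heapq.heappush if len(cache) < max_size else heapq.heappushpop
    heapadd(cache, (score, np.random.random_sample(), aa_scores, tokens))


# Equal peptide scores no longer crash the heap operations.
for tok in ([4, 14, 4, 13], [4, 14, 4, 18], [4, 14, 4, 19], [4, 14, 4, 4]):
    add_entry(0.95, np.full(4, 0.95), torch.tensor(tok))

# Tuple comparison stops at the tie-breaker, so the arrays are never compared.
best_score, _, best_aa_scores, best_tokens = max(cache)
print(best_score, best_tokens)
```

The same reasoning explains why `_get_top_peptide` now unpacks four tuple elements and discards the tie-breaker with `_`.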