Skip to content

Commit

Permalink
Reversed peptide aa scores hotfix (#417)
Browse files Browse the repository at this point in the history
* reverse aa scores hotfix

* reverse aa scores hotfix
  • Loading branch information
Lilferrit authored Dec 13, 2024
1 parent 9e3f3d1 commit 23c02f6
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
2 changes: 1 addition & 1 deletion casanovo/denovo/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,7 @@ def _get_top_peptide(
yield [
(
pep_score,
aa_scores,
aa_scores[::-1] if self.decoder.reverse else aa_scores,
"".join(self.decoder.detokenize(pred_tokens)),
)
for pep_score, _, aa_scores, pred_tokens in heapq.nlargest(
Expand Down
19 changes: 19 additions & 0 deletions tests/unit_tests/test_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,25 @@ def test_beam_search_decode():
[pep[-1] for pep in list(model._get_top_peptide(test_cache))[0]]
) == {"PEPK", "PEPP"}

# Test reverse aa scores when decoder is reversed
pred_cache = {
0: [(1.0, 0.42, np.array([1.0, 0.0]), torch.Tensor([4, 14]))]
}

model.decoder.reverse = True
top_peptides = list(model._get_top_peptide(pred_cache))
assert len(top_peptides) == 1
assert len(top_peptides[0]) == 1
assert np.allclose(top_peptides[0][0][1], np.array([0.0, 1.0]))
assert top_peptides[0][0][2] == "EP"

model.decoder.reverse = False
top_peptides = list(model._get_top_peptide(pred_cache))
assert len(top_peptides) == 1
assert len(top_peptides[0]) == 1
assert np.allclose(top_peptides[0][0][1], np.array([1.0, 0.0]))
assert top_peptides[0][0][2] == "PE"

# Test _get_topk_beams().
# Set scores to proceed generating the unfinished beam.
step = 4
Expand Down

0 comments on commit 23c02f6

Please sign in to comment.