diff --git a/casanovo/denovo/model.py b/casanovo/denovo/model.py index 9e5a34f7..909bac00 100644 --- a/casanovo/denovo/model.py +++ b/casanovo/denovo/model.py @@ -268,6 +268,7 @@ def beam_search_decode( scores = einops.rearrange(scores, "B L V S -> (B S) L V") # The main decoding loop. + self.n_term = self.n_term.to(self.decoder.device) for step in range(0, self.max_length): # Terminate beams exceeding the precursor m/z tolerance and track # all finished beams (either terminated or stop token predicted).