Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
lvdongyi committed Dec 6, 2024
1 parent 7d7904e commit 257b3b3
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions paddlenlp/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ def reorder_cache(self, beam_idx: paddle.Tensor):
for layer_idx in range(len(self.key_cache)):
if self.key_cache[layer_idx].numel() != 0:
device = self.key_cache[layer_idx].place
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(beam_idx.to(device), 0)
if self.value_cache[layer_idx].numel() != 0:
device = self.value_cache[layer_idx].place
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(beam_idx.to(device), 0)

@property
def seen_tokens(self):
Expand Down Expand Up @@ -474,8 +474,8 @@ def __getitem__(self, layer_idx: int) -> List[Tuple[paddle.Tensor]]:
# Now deal with beam search ops which were delayed
if self.beam_idx is not None:
self.beam_idx = self.beam_idx.to(original_device)
key_tensor = key_tensor.index_select(0, self.beam_idx)
value_tensor = value_tensor.index_select(0, self.beam_idx)
key_tensor = key_tensor.index_select(self.beam_idx, 0)
value_tensor = value_tensor.index_select(self.beam_idx, 0)
# Prefetch the next layer
self.prefetch_layer((layer_idx + 1) % len(self))
return (key_tensor, value_tensor)
Expand Down

0 comments on commit 257b3b3

Please sign in to comment.