fix: local variable sin referenced before assignment
lvdongyi committed Nov 11, 2024
1 parent: 2bb7115 · commit: dd5c52d
Showing 1 changed file with 4 additions and 3 deletions.
paddlenlp/transformers/llama/modeling.py: 7 changes (4 additions & 3 deletions), file mode 100755 → 100644
@@ -1023,7 +1023,7 @@ def forward(
                     "with a layer index."
                 )
             kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
-
+        sin = cos = None
         if self.config.rope:
             if self.reshard_layer is not None:
                 batch_size, seq_length, _, _ = query_states.shape
@@ -1066,11 +1066,12 @@ def forward(
                 cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

             query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-
+        cache_kwargs = {}
         # [bs, seq_len, num_head, head_dim]
         if past_key_value is not None:
             # reuse k, v, self_attention
-            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            if sin is not None and cos is not None:
+                cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

             if self.kv_indices is not None:
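Why this matters: before this commit, cos and sin were assigned only inside the "if self.config.rope:" branch, so building cache_kwargs = {"sin": sin, "cos": cos} afterwards raised UnboundLocalError ("local variable referenced before assignment") whenever rotary embeddings were disabled. Initializing sin = cos = None up front and only adding them to cache_kwargs when both are set removes that failure mode. Below is a minimal, self-contained sketch of the same pattern; the names forward_step, use_rope, and cache are illustrative stand-ins, not PaddleNLP API.

    # Sketch of the bug pattern and the guard introduced by this commit.
    # forward_step, use_rope, and cache are hypothetical names used only
    # for illustration; they do not exist in PaddleNLP.

    def forward_step(use_rope, cache=None):
        sin = cos = None              # fix: make sure the names always exist
        if use_rope:
            cos, sin = 0.5, 0.5       # stand-in for self.rotary_emb(...)

        cache_kwargs = {}             # fix: default to an empty dict
        if cache is not None:
            if sin is not None and cos is not None:
                cache_kwargs = {"sin": sin, "cos": cos}  # specific to RoPE models
            cache.update(cache_kwargs)
        return cache_kwargs

    # Without the guard, the equivalent of use_rope=False with a non-None cache
    # raised UnboundLocalError because sin/cos were never assigned.
    print(forward_step(use_rope=False, cache={}))  # -> {}
    print(forward_step(use_rope=True, cache={}))   # -> {'sin': 0.5, 'cos': 0.5}

The guarded version keeps cache updates working for non-RoPE configurations while leaving the RoPE path unchanged: the cache still receives sin/cos exactly when they were computed.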
