Skip to content

Commit

Permalink
Merge pull request #1130 from instadeepai/fix/sable-pos-encoding
Browse files Browse the repository at this point in the history
fix: limit timestep-pos-encoding to rec-Sable
  • Loading branch information
RuanJohn authored Nov 8, 2024
2 parents 73537c5 + 3ddcbff commit 6092dc6
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions mava/networks/retention.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,9 @@ def recurrent(
"""Recurrent representation of the multi-scale retention mechanism"""
B, S, _ = value_n.shape

# Positional encoding of the current step
key_n, query_n, value_n = self.pe(key_n, query_n, value_n, step_count)
# Positional encoding of the current step if enabled
if self.memory_config.timestep_positional_encoding:
key_n, query_n, value_n = self.pe(key_n, query_n, value_n, step_count)

ret_output = jnp.zeros((B, S, self.head_size), dtype=value_n.dtype)
for head in range(self.n_head):
Expand Down

0 comments on commit 6092dc6

Please sign in to comment.