Skip to content

Commit

Permalink
fix: update w_o shape
Browse files Browse the repository at this point in the history
  • Loading branch information
ArnolFokam committed Nov 29, 2024
1 parent cb6ff06 commit 6d1016f
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion mava/networks/retention.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def setup(self) -> None:
self.w_o = self.param(
"w_o",
nn.initializers.normal(stddev=1 / self.embed_dim),
(self.head_size, self.embed_dim),
(self.embed_dim, self.embed_dim),
)
self.group_norm = nn.GroupNorm(num_groups=self.n_head)

Expand Down

0 comments on commit 6d1016f

Please sign in to comment.