Skip to content

Commit

Permalink
feat: update shifting action method in autoregressive act
Browse files Browse the repository at this point in the history
Co-authored-by: Sasha Abramowitz <[email protected]>
  • Loading branch information
OmaymaMahjoub committed Nov 6, 2024
1 parent 945937c commit 0dd0eab
Showing 1 changed file with 3 additions and 7 deletions.
10 changes: 3 additions & 7 deletions mava/networks/utils/sable/encoder_decoder_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,13 +259,9 @@ def autoregressive_act(
output_action = output_action.at[:, i, :].set(action)
output_action_log = output_action_log.at[:, i, :].set(action_log)

update_shifted_action = i + 1 < A
shifted_actions = jax.lax.cond(
update_shifted_action,
lambda action=action, i=i, shifted_actions=shifted_actions: shifted_actions.at[
:, i + 1, 1:
].set(jax.nn.one_hot(action[:, 0], N)),
lambda shifted_actions=shifted_actions: shifted_actions,
# Adds all except the last action to shifted_actions, as it is out of range.
shifted_actions = shifted_actions.at[:, i + 1, 1:].set(
jax.nn.one_hot(action[:, 0], N), mode="drop"
)

return output_action.astype(jnp.int32), output_action_log, hstates

0 comments on commit 0dd0eab

Please sign in to comment.