diff --git a/mava/networks/utils/sable/decode.py b/mava/networks/utils/sable/decode.py index fb1520b2b..47edecf0f 100644 --- a/mava/networks/utils/sable/decode.py +++ b/mava/networks/utils/sable/decode.py @@ -265,4 +265,4 @@ def continuous_autoregressive_act( # Adds all except the last action to shifted_actions, as it is out of range shifted_actions = shifted_actions.at[:, i + 1, :].set(action[:, i, :], mode="drop") - return output_action.astype(jnp.int32), output_action_log, hstates + return output_action, output_action_log, hstates