Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonDuToit committed Nov 5, 2024
1 parent 45df968 commit 1cf8534
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 5 deletions.
4 changes: 2 additions & 2 deletions mava/systems/ppo/anakin/ff_ippo.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def learner_setup(
)
key, step_keys = jax.random.split(key)
opt_states = OptStates(actor_opt_state, critic_opt_state)
replicate_learner = (params, opt_states, step_keys)
replicate_learner = (params, opt_states, step_keys, dones)

# Duplicate learner for update_batch_size.
broadcast = lambda x: jnp.broadcast_to(x, (config.system.update_batch_size, *x.shape))
Expand All @@ -420,7 +420,7 @@ def learner_setup(
replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=jax.devices())

# Initialise learner state.
params, opt_states, step_keys = replicate_learner
params, opt_states, step_keys, dones = replicate_learner
init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps, dones)

return learn, actor_network, init_learner_state
Expand Down
5 changes: 3 additions & 2 deletions mava/systems/ppo/anakin/ff_mappo.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,9 @@ def learner_setup(
(config.arch.num_envs, config.system.num_agents),
dtype=bool,
)
key, step_keys = jax.random.split(key)
opt_states = OptStates(actor_opt_state, critic_opt_state)
replicate_learner = (params, opt_states, step_keys)
replicate_learner = (params, opt_states, step_keys, dones)

# Duplicate learner for update_batch_size.
broadcast = lambda x: jnp.broadcast_to(x, (config.system.update_batch_size, *x.shape))
Expand All @@ -403,7 +404,7 @@ def learner_setup(
replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=jax.devices())

# Initialise learner state.
params, opt_states, step_keys = replicate_learner
params, opt_states, step_keys, dones = replicate_learner
init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps, dones)

return learn, actor_network, init_learner_state
Expand Down
1 change: 1 addition & 0 deletions mava/systems/ppo/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class LearnerState(NamedTuple):
key: chex.PRNGKey
env_state: State
timestep: TimeStep
dones: Done


class RNNLearnerState(NamedTuple):
Expand Down
1 change: 0 additions & 1 deletion mava/utils/gae.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ def _get_advantages(
) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array], chex.Array]:
gae, next_value, next_done = carry
done, value, reward = transition.done, transition.value, transition.reward
gamma = gamma
delta = reward + gamma * next_value * (1 - next_done) - value
gae = delta + gamma * gae_lambda * (1 - next_done) * gae
return (gae, value, done), gae
Expand Down

0 comments on commit 1cf8534

Please sign in to comment.