From e5496606e7f8bd5f95f521cf7fa6f1886c6c13c4 Mon Sep 17 00:00:00 2001
From: OmaymaMahjoub
Date: Thu, 7 Nov 2024 13:52:20 +0000
Subject: [PATCH 1/5] feat: add continuous sable

---
 mava/configs/default/ff_sable.yaml | 2 +-
 mava/configs/default/rec_sable.yaml | 2 +-
 mava/networks/sable_network.py | 34 +++---
 mava/networks/utils/sable/__init__.py | 6 +-
 mava/networks/utils/sable/decode.py | 143 +++++++++++++++++++++++--
 mava/systems/sable/anakin/ff_sable.py | 22 ++--
 mava/systems/sable/anakin/rec_sable.py | 16 +--
 7 files changed, 183 insertions(+), 42 deletions(-)

diff --git a/mava/configs/default/ff_sable.yaml b/mava/configs/default/ff_sable.yaml
index bcf11797c..6455cf271 100644
--- a/mava/configs/default/ff_sable.yaml
+++ b/mava/configs/default/ff_sable.yaml
@@ -3,7 +3,7 @@ defaults:
   - arch: anakin
   - system: sable/ff_sable
   - network: ff_retention
-  - env: rware # [cleaner, connector, gigastep, lbf, rware, smax]
+  - env: rware # [cleaner, connector, gigastep, lbf, rware, smax, mabrax]
   - _self_
 
 hydra:
diff --git a/mava/configs/default/rec_sable.yaml b/mava/configs/default/rec_sable.yaml
index 7dbdbbbc8..6d956671f 100644
--- a/mava/configs/default/rec_sable.yaml
+++ b/mava/configs/default/rec_sable.yaml
@@ -3,7 +3,7 @@ defaults:
   - arch: anakin
   - system: sable/rec_sable
   - network: rec_retention
-  - env: rware # [cleaner, connector, gigastep, lbf, rware, smax]
+  - env: rware # [cleaner, connector, gigastep, lbf, rware, smax, mabrax]
   - _self_
 
 hydra:
diff --git a/mava/networks/sable_network.py b/mava/networks/sable_network.py
index e626bfc16..40b1fd615 100644
--- a/mava/networks/sable_network.py
+++ b/mava/networks/sable_network.py
@@ -26,8 +26,10 @@ from mava.networks.torsos import SwiGLU
 from mava.networks.utils.sable import (
     act_encoder_fn,
-    autoregressive_act,
-    train_decoder_fn,
+    continuous_autoregressive_act,
+    continuous_train_decoder_fn,
+    discrete_autoregressive_act,
+    discrete_train_decoder_fn,
     train_encoder_fn,
 )
 from mava.systems.sable.types import HiddenStates, SableNetworkConfig
@@ -352,7 +354,7 @@ class SableNetwork(nn.Module):
     action_space_type: str = _DISCRETE
 
     def setup(self) -> None:
-        if self.action_space_type not in [_DISCRETE]:
+        if self.action_space_type not in [_DISCRETE, _CONTINUOUS]:
             raise ValueError(f"Invalid action space type: {self.action_space_type}")
 
         assert (
@@ -385,15 +387,27 @@ def setup(self) -> None:
             train_encoder_fn,
             chunk_size=self.memory_config.chunk_size,
         )
-        self.train_decoder_fn = partial(
-            train_decoder_fn, n_agents=self.n_agents, chunk_size=self.memory_config.chunk_size
-        )
-
         self.act_encoder_fn = partial(
             act_encoder_fn,
             chunk_size=self.n_agents_per_chunk,
         )
-        self.autoregressive_act = autoregressive_act
+        if self.action_space_type == _CONTINUOUS:
+            self.train_decoder_fn = partial(
+                continuous_train_decoder_fn,
+                n_agents=self.n_agents,
+                chunk_size=self.memory_config.chunk_size,
+                action_dim=self.action_dim,
+            )
+            self.autoregressive_act = partial(
+                continuous_autoregressive_act, action_dim=self.action_dim
+            )
+        else:
+            self.train_decoder_fn = partial(
+                discrete_train_decoder_fn,
+                n_agents=self.n_agents,
+                chunk_size=self.memory_config.chunk_size,
+            )
+            self.autoregressive_act = discrete_autoregressive_act  # type: ignore
 
     def __call__(
         self,
@@ -424,9 +438,7 @@ def __call__(
             rng_key=rng_key,
         )
 
-        action_log = jnp.squeeze(action_log, axis=-1)
         value = jnp.squeeze(value, axis=-1)
-        entropy = jnp.squeeze(entropy, axis=-1)
         return value, action_log, entropy
 
     def get_actions(
@@ -467,7 +479,5 @@ def get_actions(
             decoder_cross_retn=updated_dec_hs[1],
         )
 
-        output_actions = jnp.squeeze(output_actions, axis=-1)
-        output_actions_log = jnp.squeeze(output_actions_log, axis=-1)
         value = jnp.squeeze(value, axis=-1)
         return output_actions, output_actions_log, value, updated_hs
diff --git a/mava/networks/utils/sable/__init__.py b/mava/networks/utils/sable/__init__.py
index d26b9f645..21b8a46f7 100644
--- a/mava/networks/utils/sable/__init__.py
+++ b/mava/networks/utils/sable/__init__.py
@@ -14,8 +14,10 @@
 # ruff: noqa: F401
 from mava.networks.utils.sable.decode import (
-    autoregressive_act,
-    train_decoder_fn,
+    continuous_autoregressive_act,
+    continuous_train_decoder_fn,
+    discrete_autoregressive_act,
+    discrete_train_decoder_fn,
 )
 from mava.networks.utils.sable.encode import (
     act_encoder_fn,
diff --git a/mava/networks/utils/sable/decode.py b/mava/networks/utils/sable/decode.py
index c9befeb36..d78e17bce 100644
--- a/mava/networks/utils/sable/decode.py
+++ b/mava/networks/utils/sable/decode.py
@@ -18,8 +18,11 @@
 import distrax
 import jax
 import jax.numpy as jnp
+import tensorflow_probability.substrates.jax.distributions as tfd
 from flax import linen as nn
 
+from mava.networks.distributions import TanhTransformedDistribution
+
 # General shapes legend:
 # B: batch size
 # S: sequence length
@@ -27,7 +30,7 @@
 # N: number of agents
 
 
-def train_decoder_fn(
+def discrete_train_decoder_fn(
     decoder: nn.Module,
     obs_rep: chex.Array,
     action: chex.Array,
@@ -43,7 +46,7 @@
     # Delete `rng_key` since it is not used in discrete action space
     del rng_key
 
-    shifted_actions = get_shifted_actions(action, legal_actions, n_agents=n_agents)
+    shifted_actions = get_shifted_discrete_actions(action, legal_actions, n_agents=n_agents)
     logit = jnp.zeros_like(legal_actions, dtype=jnp.float32)
 
     # Apply the decoder per chunk
@@ -73,14 +76,14 @@
     distribution = distrax.Categorical(logits=masked_logits)
     action_log_prob = distribution.log_prob(action)
-    action_log_prob = jnp.expand_dims(action_log_prob, axis=-1)
-    entropy = jnp.expand_dims(distribution.entropy(), axis=-1)
-    return action_log_prob, entropy
+    return action_log_prob, distribution.entropy()
 
 
-def get_shifted_actions(action: chex.Array, legal_actions: chex.Array, n_agents: int) -> chex.Array:
-    """Get the shifted action sequence for predicting the next action."""
+def get_shifted_discrete_actions(
+    action: chex.Array, legal_actions: chex.Array, n_agents: int
+) -> chex.Array:
+    """Get the shifted discrete action sequence for predicting the next action."""
     B, S, A = legal_actions.shape
 
     # Create a shifted action sequence for predicting the next action
@@ -102,7 +105,7 @@
     return shifted_actions
 
 
-def autoregressive_act(
+def discrete_autoregressive_act(
     decoder: nn.Module,
     obs_rep: chex.Array,
     hstates: chex.Array,
@@ -141,5 +144,129 @@
         shifted_actions = shifted_actions.at[:, i + 1, 1:].set(
             jax.nn.one_hot(action[:, 0], A), mode="drop"
         )
+    output_actions = output_action.astype(jnp.int32)
+    output_actions = jnp.squeeze(output_actions, axis=-1)
+    output_action_log = jnp.squeeze(output_action_log, axis=-1)
+    return output_actions, output_action_log, hstates
+
+
+def continuous_train_decoder_fn(
+    decoder: nn.Module,
+    obs_rep: chex.Array,
+    action: chex.Array,
+    legal_actions: chex.Array,
+    hstates: chex.Array,
+    dones: chex.Array,
+    step_count: chex.Array,
+    n_agents: int,
+    chunk_size: int,
+    action_dim: int,
+    rng_key: Optional[chex.PRNGKey] = None,
+) -> Tuple[chex.Array, chex.Array]:
"""Parallel action sampling for discrete action spaces.""" + # Delete `legal_actions` since it is not used in continuous action space + del legal_actions + + B, S, _ = action.shape + + # todo: double check if this needs to be a param + min_scale = 1e-3 + + shifted_actions = get_shifted_continuous_actions(action, action_dim, n_agents=n_agents) + act_mean = jnp.zeros((B, S, action_dim), dtype=jnp.float32) + + # Apply the decoder per chunk + num_chunks = shifted_actions.shape[1] // chunk_size + for chunk_id in range(0, num_chunks): + start_idx = chunk_id * chunk_size + end_idx = (chunk_id + 1) * chunk_size + # Chunk obs_rep, shifted_actions, dones, and step_count + chunked_obs_rep = obs_rep[:, start_idx:end_idx] + chunk_shifted_actions = shifted_actions[:, start_idx:end_idx] + chunk_dones = dones[:, start_idx:end_idx] + chunk_step_count = step_count[:, start_idx:end_idx] + chunked_act_mean, hstates = decoder( + action=chunk_shifted_actions, + obs_rep=chunked_obs_rep, + hstates=hstates, + dones=chunk_dones, + step_count=chunk_step_count, + ) + act_mean = act_mean.at[:, start_idx:end_idx].set(chunked_act_mean) + + action_std = jax.nn.softplus(decoder.log_std) + min_scale + + base_distribution = tfd.Normal(loc=act_mean, scale=action_std) + distribution = tfd.Independent( + TanhTransformedDistribution(base_distribution), + reinterpreted_batch_ndims=1, + ) + + action_log_prob = distribution.log_prob(action) + entropy = distribution.entropy(seed=rng_key) + + return action_log_prob, entropy + + +def get_shifted_continuous_actions( + action: chex.Array, action_dim: int, n_agents: int +) -> chex.Array: + """Get the shifted continuous action sequence for predicting the next action.""" + B, S, _ = action.shape + + shifted_actions = jnp.zeros((B, S, action_dim)) + start_timestep_token = jnp.zeros(action_dim) + shifted_actions = shifted_actions.at[:, 1:, :].set(action[:, :-1, :]) + shifted_actions = shifted_actions.at[:, ::n_agents, :].set(start_timestep_token) + + return shifted_actions + + +def continuous_autoregressive_act( + decoder: nn.Module, + obs_rep: chex.Array, + hstates: chex.Array, + legal_actions: chex.Array, + step_count: chex.Array, + action_dim: int, + key: chex.PRNGKey, +) -> Tuple[chex.Array, chex.Array, chex.Array]: + # Delete `legal_actions` since it is not used in continuous action space + del legal_actions + + B, N = step_count.shape + min_scale = 1e-3 + + shifted_actions = jnp.zeros((B, N, action_dim)) + + output_action = jnp.zeros((B, N, action_dim)) + output_action_log = jnp.zeros((B, N)) + + # Apply the decoder autoregressively + for i in range(N): + act_mean, hstates = decoder.recurrent( + action=shifted_actions[:, i : i + 1, :], + obs_rep=obs_rep[:, i : i + 1, :], + hstates=hstates, + step_count=step_count[:, i : i + 1], + ) + action_std = jax.nn.softplus(decoder.log_std) + min_scale + + key, sample_key = jax.random.split(key) + + base_distribution = tfd.Normal(loc=act_mean, scale=action_std) + distribution = tfd.Independent( + TanhTransformedDistribution(base_distribution), + reinterpreted_batch_ndims=1, + ) + + # the action and raw action are now just identical. 
+        action = distribution.sample(seed=sample_key)
+        action_log = distribution.log_prob(action)
+
+        output_action = output_action.at[:, i, :].set(action[:, i, :])
+        output_action_log = output_action_log.at[:, i].set(action_log[:, i])
+        # Adds all except the last action to shifted_actions, as it is out of range
+        shifted_actions = shifted_actions.at[:, i + 1, :].set(action[:, i, :], mode="drop")
 
     return output_action.astype(jnp.int32), output_action_log, hstates
diff --git a/mava/systems/sable/anakin/ff_sable.py b/mava/systems/sable/anakin/ff_sable.py
index bcd7dd3e0..42b749c38 100644
--- a/mava/systems/sable/anakin/ff_sable.py
+++ b/mava/systems/sable/anakin/ff_sable.py
@@ -178,7 +178,7 @@ def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
             def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple:
                 """Update the network for a single minibatch."""
                 # UNPACK TRAIN STATE AND BATCH INFO
-                params, opt_state = train_state
+                params, opt_state, key = train_state
                 traj_batch, advantages, targets = batch_info
 
                 def _loss_fn(
@@ -186,6 +186,7 @@ def _loss_fn(
                     traj_batch: Transition,
                     gae: chex.Array,
                    value_targets: chex.Array,
+                    rng_key: chex.PRNGKey,
                 ) -> Tuple:
                     """Calculate Sable loss."""
                     # RERUN NETWORK
@@ -194,6 +195,7 @@
                         observation=traj_batch.obs,
                         action=traj_batch.action,
                         dones=traj_batch.done,
+                        rng_key=rng_key,
                     )
 
                     # CALCULATE ACTOR LOSS
@@ -231,13 +233,9 @@
                     return total_loss, (loss_actor, entropy, value_loss)
 
                 # CALCULATE ACTOR LOSS
+                key, entropy_key = jax.random.split(key)
                 grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
-                loss_info, grads = grad_fn(
-                    params,
-                    traj_batch,
-                    advantages,
-                    targets,
-                )
+                loss_info, grads = grad_fn(params, traj_batch, advantages, targets, entropy_key)
 
                 # Compute the parallel mean (pmean) over the batch.
                 # This calculation is inspired by the Anakin architecture demo notebook.
@@ -263,7 +261,7 @@
                     "entropy": entropy,
                 }
 
-                return (new_params, new_opt_state), loss_info
+                return (new_params, new_opt_state, key), loss_info
 
             (
                 params,
@@ -275,7 +273,7 @@
             ) = update_state
 
             # SHUFFLE MINIBATCHES
-            key, batch_shuffle_key, agent_shuffle_key = jax.random.split(key, 3)
+            key, batch_shuffle_key, agent_shuffle_key, entropy_key = jax.random.split(key, 4)
 
             # Shuffle batch
             batch_size = config.system.rollout_length * config.arch.num_envs
@@ -295,9 +293,9 @@
             )
 
             # UPDATE MINIBATCHES
-            (params, opt_states), loss_info = jax.lax.scan(
+            (params, opt_states, entropy_key), loss_info = jax.lax.scan(
                 _update_minibatch,
-                (params, opt_states),
+                (params, opt_states, entropy_key),
                 minibatches,
             )
 
@@ -381,7 +379,7 @@ def learner_setup(
     key, net_key = keys
 
     # Get number of agents and actions.
-    action_dim = int(env.action_spec().num_values[0])
+    action_dim = env.action_dim
     n_agents = env.action_spec().shape[0]
     config.system.num_agents = n_agents
     config.system.num_actions = action_dim
diff --git a/mava/systems/sable/anakin/rec_sable.py b/mava/systems/sable/anakin/rec_sable.py
index 5f1a4c16e..c66176f28 100644
--- a/mava/systems/sable/anakin/rec_sable.py
+++ b/mava/systems/sable/anakin/rec_sable.py
@@ -192,7 +192,7 @@ def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
             def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple:
                 """Update the network for a single minibatch."""
                 # UNPACK TRAIN STATE AND BATCH INFO
-                params, opt_state = train_state
+                params, opt_state, key = train_state
                 traj_batch, advantages, targets, prev_hstates = batch_info
 
                 def _loss_fn(
@@ -201,6 +201,7 @@ def _loss_fn(
                     gae: chex.Array,
                    value_targets: chex.Array,
                     prev_hstates: HiddenStates,
+                    rng_key: chex.PRNGKey,
                 ) -> Tuple:
                     """Calculate Sable loss."""
                     # RERUN NETWORK
@@ -210,6 +211,7 @@
                         traj_batch.action,
                         prev_hstates,
                         traj_batch.done,
+                        rng_key,
                     )
 
                     # CALCULATE ACTOR LOSS
@@ -247,6 +249,7 @@
                     return total_loss, (loss_actor, entropy, value_loss)
 
                 # CALCULATE ACTOR LOSS
+                key, entropy_key = jax.random.split(key)
                 grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                 loss_info, grads = grad_fn(
                     params,
@@ -254,6 +257,7 @@
                     advantages,
                     targets,
                     prev_hstates,
+                    entropy_key,
                 )
 
                 # Compute the parallel mean (pmean) over the batch.
@@ -280,7 +284,7 @@
                     "entropy": entropy,
                 }
 
-                return (new_params, new_opt_state), loss_info
+                return (new_params, new_opt_state, key), loss_info
 
             (
                 params,
@@ -293,7 +297,7 @@
             ) = update_state
 
             # SHUFFLE MINIBATCHES
-            key, batch_shuffle_key, agent_shuffle_key = jax.random.split(key, 3)
+            key, batch_shuffle_key, agent_shuffle_key, entropy_key = jax.random.split(key, 4)
 
             # Shuffle batch
             batch_size = config.arch.num_envs
@@ -322,9 +326,9 @@
             )
 
             # UPDATE MINIBATCHES
-            (params, opt_states), loss_info = jax.lax.scan(
+            (params, opt_states, entropy_key), loss_info = jax.lax.scan(
                 _update_minibatch,
-                (params, opt_states),
+                (params, opt_states, entropy_key),
                 (*minibatches, prev_hs_minibatch),
             )
 
@@ -412,7 +416,7 @@ def learner_setup(
     key, net_key = keys
 
     # Get number of agents and actions.
-    action_dim = int(env.action_spec().num_values[0])
+    action_dim = env.action_dim
     n_agents = env.action_spec().shape[0]
     config.system.num_agents = n_agents
     config.system.num_actions = action_dim

From 7e4bcde6f423a18696a27ef92955076306330762 Mon Sep 17 00:00:00 2001
From: OmaymaMahjoub
Date: Thu, 7 Nov 2024 14:00:21 +0000
Subject: [PATCH 2/5] feat: add continuous envs to sable integration test

---
 test/integration_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/integration_test.py b/test/integration_test.py
index 419cf5799..2653f872d 100644
--- a/test/integration_test.py
+++ b/test/integration_test.py
@@ -82,7 +82,7 @@ def test_ppo_system(fast_config: dict, system_path: str) -> None:
 def test_sable_system(fast_config: dict, system_path: str) -> None:
     """Test all sable systems on random envs."""
     _, _, system_name = system_path.split(".")
-    env = random.choice(discrete_envs)
+    env = random.choice(continuous_envs + discrete_envs)
 
     with initialize(version_base=None, config_path=config_path):
         cfg = compose(config_name=f"{system_name}", overrides=[f"env={env}"])

From b735fb01112969f6662a912eec748aac5107133e Mon Sep 17 00:00:00 2001
From: OmaymaMahjoub
Date: Thu, 7 Nov 2024 14:06:52 +0000
Subject: [PATCH 3/5] feat: make min scale as global constant

---
 mava/networks/utils/sable/decode.py | 14 +++++---------
 1 file changed, 5 insertions(+), 9 deletions(-)

diff --git a/mava/networks/utils/sable/decode.py b/mava/networks/utils/sable/decode.py
index d78e17bce..fb1520b2b 100644
--- a/mava/networks/utils/sable/decode.py
+++ b/mava/networks/utils/sable/decode.py
@@ -29,6 +29,9 @@
 # A: number of actions
 # N: number of agents
 
+# Constant to avoid numerical instability
+_MIN_SCALE = 1e-3
+
 
 def discrete_train_decoder_fn(
     decoder: nn.Module,
@@ -168,10 +171,6 @@
     del legal_actions
 
     B, S, _ = action.shape
-
-    # todo: double check if this needs to be a param
-    min_scale = 1e-3
-
     shifted_actions = get_shifted_continuous_actions(action, action_dim, n_agents=n_agents)
     act_mean = jnp.zeros((B, S, action_dim), dtype=jnp.float32)
 
@@ -194,7 +193,7 @@
         )
         act_mean = act_mean.at[:, start_idx:end_idx].set(chunked_act_mean)
 
-    action_std = jax.nn.softplus(decoder.log_std) + min_scale
+    action_std = jax.nn.softplus(decoder.log_std) + _MIN_SCALE
 
     base_distribution = tfd.Normal(loc=act_mean, scale=action_std)
     distribution = tfd.Independent(
@@ -235,10 +234,7 @@
     del legal_actions
 
     B, N = step_count.shape
-    min_scale = 1e-3
-
     shifted_actions = jnp.zeros((B, N, action_dim))
-
     output_action = jnp.zeros((B, N, action_dim))
     output_action_log = jnp.zeros((B, N))
 
@@ -250,7 +246,7 @@
             step_count=step_count[:, i : i + 1],
         )
-        action_std = jax.nn.softplus(decoder.log_std) + min_scale
+        action_std = jax.nn.softplus(decoder.log_std) + _MIN_SCALE
 
         key, sample_key = jax.random.split(key)

From 03fc3d5fc97fadf26b31a71de558394e3459ea34 Mon Sep 17 00:00:00 2001
From: OmaymaMahjoub
Date: Fri, 8 Nov 2024 11:20:02 +0000
Subject: [PATCH 4/5] fix: add the timestep pos encoding flag fix

---
 mava/networks/retention.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mava/networks/retention.py b/mava/networks/retention.py
index 6f0d32df2..a041abf33 100644
--- a/mava/networks/retention.py
+++ b/mava/networks/retention.py
@@ -300,8 +300,9 @@ def recurrent(
         """Recurrent representation of the multi-scale retention mechanism"""
         B, S, _ = value_n.shape
 
-        # Positional encoding of the current step
-        key_n, query_n, value_n = self.pe(key_n, query_n, value_n, step_count)
+        # Positional encoding of the current step if enabled
+        if self.memory_config.timestep_positional_encoding:
+            key_n, query_n, value_n = self.pe(key_n, query_n, value_n, step_count)
 
         ret_output = jnp.zeros((B, S, self.head_size), dtype=value_n.dtype)
         for head in range(self.n_head):

From ab6aba516d0258ac9ba35e169772be72b9a5f412 Mon Sep 17 00:00:00 2001
From: OmaymaMahjoub
Date: Mon, 11 Nov 2024 09:27:09 +0000
Subject: [PATCH 5/5] fix: fixing cont autoregressive act

---
 mava/networks/utils/sable/decode.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mava/networks/utils/sable/decode.py b/mava/networks/utils/sable/decode.py
index fb1520b2b..47edecf0f 100644
--- a/mava/networks/utils/sable/decode.py
+++ b/mava/networks/utils/sable/decode.py
@@ -265,4 +265,4 @@ def continuous_autoregressive_act(
         # Adds all except the last action to shifted_actions, as it is out of range
         shifted_actions = shifted_actions.at[:, i + 1, :].set(action[:, i, :], mode="drop")
 
-    return output_action.astype(jnp.int32), output_action_log, hstates
+    return output_action, output_action_log, hstates
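
Reviewer note (illustrative only, not part of the patch series): the continuous decoder above builds a tanh-squashed Gaussian over the joint action vector, with a per-dimension standard deviation of softplus(log_std) + _MIN_SCALE. The sketch below reproduces that construction with stock TensorFlow Probability pieces; it swaps Mava's TanhTransformedDistribution wrapper for a plain Tanh bijector, and the function name and shapes are assumptions for demonstration, not the library's API.

# Minimal sketch of the continuous action head, assuming a plain TFP Tanh bijector.
import jax
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

_MIN_SCALE = 1e-3  # same constant the patches introduce to keep the std strictly positive


def action_distribution(act_mean: jnp.ndarray, log_std: jnp.ndarray) -> tfd.Distribution:
    # Per-dimension std via softplus, shifted away from zero for numerical stability.
    action_std = jax.nn.softplus(log_std) + _MIN_SCALE
    base = tfd.Normal(loc=act_mean, scale=action_std)
    # Squash each dimension into (-1, 1) and treat the last axis as one event.
    squashed = tfd.TransformedDistribution(base, tfb.Tanh())
    return tfd.Independent(squashed, reinterpreted_batch_ndims=1)


key = jax.random.PRNGKey(0)
mean = jnp.zeros((2, 3, 4))   # (batch, agents, action_dim) -- hypothetical shapes
log_std = jnp.zeros((4,))     # one learnable log-std per action dimension
dist = action_distribution(mean, log_std)
action = dist.sample(seed=key)      # values in (-1, 1)
log_prob = dist.log_prob(action)    # shape (2, 3)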