
Commit 15e591d

Merge pull request #1127 from instadeepai/feat/sable-cont
Add Sable [Continuous actions]
OmaymaMahjoub authored Nov 13, 2024
2 parents 6092dc6 + ab6aba5 commit 15e591d
Showing 8 changed files with 181 additions and 44 deletions.
2 changes: 1 addition & 1 deletion mava/configs/default/ff_sable.yaml
@@ -3,7 +3,7 @@ defaults:
- arch: anakin
- system: sable/ff_sable
- network: ff_retention
- env: rware # [cleaner, connector, gigastep, lbf, rware, smax]
- env: rware # [cleaner, connector, gigastep, lbf, rware, smax, mabrax]
- _self_

hydra:
2 changes: 1 addition & 1 deletion mava/configs/default/rec_sable.yaml
@@ -3,7 +3,7 @@ defaults:
- arch: anakin
- system: sable/rec_sable
- network: rec_retention
- env: rware # [cleaner, connector, gigastep, lbf, rware, smax]
- env: rware # [cleaner, connector, gigastep, lbf, rware, smax, mabrax]
- _self_

hydra:
34 changes: 22 additions & 12 deletions mava/networks/sable_network.py
@@ -26,8 +26,10 @@
from mava.networks.torsos import SwiGLU
from mava.networks.utils.sable import (
act_encoder_fn,
autoregressive_act,
train_decoder_fn,
continuous_autoregressive_act,
continuous_train_decoder_fn,
discrete_autoregressive_act,
discrete_train_decoder_fn,
train_encoder_fn,
)
from mava.systems.sable.types import HiddenStates, SableNetworkConfig
@@ -352,7 +354,7 @@ class SableNetwork(nn.Module):
action_space_type: str = _DISCRETE

def setup(self) -> None:
if self.action_space_type not in [_DISCRETE]:
if self.action_space_type not in [_DISCRETE, _CONTINUOUS]:
raise ValueError(f"Invalid action space type: {self.action_space_type}")

assert (
@@ -385,15 +387,27 @@ def setup(self) -> None:
train_encoder_fn,
chunk_size=self.memory_config.chunk_size,
)
self.train_decoder_fn = partial(
train_decoder_fn, n_agents=self.n_agents, chunk_size=self.memory_config.chunk_size
)

self.act_encoder_fn = partial(
act_encoder_fn,
chunk_size=self.n_agents_per_chunk,
)
self.autoregressive_act = autoregressive_act
if self.action_space_type == _CONTINUOUS:
self.train_decoder_fn = partial(
continuous_train_decoder_fn,
n_agents=self.n_agents,
chunk_size=self.memory_config.chunk_size,
action_dim=self.action_dim,
)
self.autoregressive_act = partial(
continuous_autoregressive_act, action_dim=self.action_dim
)
else:
self.train_decoder_fn = partial(
discrete_train_decoder_fn,
n_agents=self.n_agents,
chunk_size=self.memory_config.chunk_size,
)
self.autoregressive_act = discrete_autoregressive_act # type: ignore

def __call__(
self,
@@ -424,9 +438,7 @@ def __call__(
rng_key=rng_key,
)

action_log = jnp.squeeze(action_log, axis=-1)
value = jnp.squeeze(value, axis=-1)
entropy = jnp.squeeze(entropy, axis=-1)
return value, action_log, entropy

def get_actions(
@@ -467,7 +479,5 @@ def get_actions(
decoder_cross_retn=updated_dec_hs[1],
)

output_actions = jnp.squeeze(output_actions, axis=-1)
output_actions_log = jnp.squeeze(output_actions_log, axis=-1)
value = jnp.squeeze(value, axis=-1)
return output_actions, output_actions_log, value, updated_hs
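
Note (illustrative, not part of this commit): the branch added to `setup` above amounts to binding one of two decoder-function pairs with `functools.partial`, depending on the action space type. A minimal, self-contained sketch of that dispatch pattern, using hypothetical stand-in functions rather than the real decoders:

    from functools import partial

    # Hypothetical stand-ins for the decoder functions bound in setup(); the real ones
    # live in mava/networks/utils/sable/decode.py (shown further down in this diff).
    def discrete_train_decoder_fn(decoder, obs_rep, n_agents, chunk_size):
        return f"discrete: {n_agents} agents, chunk_size={chunk_size}"

    def continuous_train_decoder_fn(decoder, obs_rep, n_agents, chunk_size, action_dim):
        return f"continuous: {n_agents} agents, action_dim={action_dim}"

    def make_train_decoder_fn(action_space_type, n_agents, chunk_size, action_dim):
        # Mirrors the idea in SableNetwork.setup(): bind the static arguments once,
        # leaving only the per-call arguments (decoder, obs_rep, ...) open.
        if action_space_type == "continuous":
            return partial(
                continuous_train_decoder_fn,
                n_agents=n_agents,
                chunk_size=chunk_size,
                action_dim=action_dim,
            )
        return partial(discrete_train_decoder_fn, n_agents=n_agents, chunk_size=chunk_size)

    fn = make_train_decoder_fn("continuous", n_agents=3, chunk_size=6, action_dim=4)
    print(fn(decoder=None, obs_rep=None))  # -> "continuous: 3 agents, action_dim=4"
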
6 changes: 4 additions & 2 deletions mava/networks/utils/sable/__init__.py
@@ -14,8 +14,10 @@
# ruff: noqa: F401

from mava.networks.utils.sable.decode import (
autoregressive_act,
train_decoder_fn,
continuous_autoregressive_act,
continuous_train_decoder_fn,
discrete_autoregressive_act,
discrete_train_decoder_fn,
)
from mava.networks.utils.sable.encode import (
act_encoder_fn,
141 changes: 132 additions & 9 deletions mava/networks/utils/sable/decode.py
@@ -18,16 +18,22 @@
import distrax
import jax
import jax.numpy as jnp
import tensorflow_probability.substrates.jax.distributions as tfd
from flax import linen as nn

from mava.networks.distributions import TanhTransformedDistribution

# General shapes legend:
# B: batch size
# S: sequence length
# A: number of actions
# N: number of agents

# Constant to avoid numerical instability
_MIN_SCALE = 1e-3


def train_decoder_fn(
def discrete_train_decoder_fn(
decoder: nn.Module,
obs_rep: chex.Array,
action: chex.Array,
@@ -43,7 +49,7 @@ def train_decoder_fn(
# Delete `rng_key` since it is not used in discrete action space
del rng_key

shifted_actions = get_shifted_actions(action, legal_actions, n_agents=n_agents)
shifted_actions = get_shifted_discrete_actions(action, legal_actions, n_agents=n_agents)
logit = jnp.zeros_like(legal_actions, dtype=jnp.float32)

# Apply the decoder per chunk
@@ -73,14 +79,14 @@

distribution = distrax.Categorical(logits=masked_logits)
action_log_prob = distribution.log_prob(action)
action_log_prob = jnp.expand_dims(action_log_prob, axis=-1)
entropy = jnp.expand_dims(distribution.entropy(), axis=-1)

return action_log_prob, entropy
return action_log_prob, distribution.entropy()


def get_shifted_actions(action: chex.Array, legal_actions: chex.Array, n_agents: int) -> chex.Array:
"""Get the shifted action sequence for predicting the next action."""
def get_shifted_discrete_actions(
action: chex.Array, legal_actions: chex.Array, n_agents: int
) -> chex.Array:
"""Get the shifted discrete action sequence for predicting the next action."""
B, S, A = legal_actions.shape

# Create a shifted action sequence for predicting the next action
@@ -102,7 +108,7 @@ def get_shifted_actions(action: chex.Array, legal_actions: chex.Array, n_agents:
return shifted_actions


def autoregressive_act(
def discrete_autoregressive_act(
decoder: nn.Module,
obs_rep: chex.Array,
hstates: chex.Array,
@@ -141,5 +147,122 @@ def autoregressive_act(
shifted_actions = shifted_actions.at[:, i + 1, 1:].set(
jax.nn.one_hot(action[:, 0], A), mode="drop"
)
output_actions = output_action.astype(jnp.int32)
output_actions = jnp.squeeze(output_actions, axis=-1)
output_action_log = jnp.squeeze(output_action_log, axis=-1)
return output_actions, output_action_log, hstates


def continuous_train_decoder_fn(
decoder: nn.Module,
obs_rep: chex.Array,
action: chex.Array,
legal_actions: chex.Array,
hstates: chex.Array,
dones: chex.Array,
step_count: chex.Array,
n_agents: int,
chunk_size: int,
action_dim: int,
rng_key: Optional[chex.PRNGKey] = None,
) -> Tuple[chex.Array, chex.Array]:
"""Parallel action sampling for discrete action spaces."""
# Delete `legal_actions` since it is not used in continuous action space
del legal_actions

B, S, _ = action.shape
shifted_actions = get_shifted_continuous_actions(action, action_dim, n_agents=n_agents)
act_mean = jnp.zeros((B, S, action_dim), dtype=jnp.float32)

# Apply the decoder per chunk
num_chunks = shifted_actions.shape[1] // chunk_size
for chunk_id in range(0, num_chunks):
start_idx = chunk_id * chunk_size
end_idx = (chunk_id + 1) * chunk_size
# Chunk obs_rep, shifted_actions, dones, and step_count
chunked_obs_rep = obs_rep[:, start_idx:end_idx]
chunk_shifted_actions = shifted_actions[:, start_idx:end_idx]
chunk_dones = dones[:, start_idx:end_idx]
chunk_step_count = step_count[:, start_idx:end_idx]
chunked_act_mean, hstates = decoder(
action=chunk_shifted_actions,
obs_rep=chunked_obs_rep,
hstates=hstates,
dones=chunk_dones,
step_count=chunk_step_count,
)
act_mean = act_mean.at[:, start_idx:end_idx].set(chunked_act_mean)

action_std = jax.nn.softplus(decoder.log_std) + _MIN_SCALE

base_distribution = tfd.Normal(loc=act_mean, scale=action_std)
distribution = tfd.Independent(
TanhTransformedDistribution(base_distribution),
reinterpreted_batch_ndims=1,
)

action_log_prob = distribution.log_prob(action)
entropy = distribution.entropy(seed=rng_key)

return action_log_prob, entropy


def get_shifted_continuous_actions(
action: chex.Array, action_dim: int, n_agents: int
) -> chex.Array:
"""Get the shifted continuous action sequence for predicting the next action."""
B, S, _ = action.shape

shifted_actions = jnp.zeros((B, S, action_dim))
start_timestep_token = jnp.zeros(action_dim)
shifted_actions = shifted_actions.at[:, 1:, :].set(action[:, :-1, :])
shifted_actions = shifted_actions.at[:, ::n_agents, :].set(start_timestep_token)

return shifted_actions
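
A quick illustrative check of the layout this helper produces (toy values, not part of the commit): with two agents, the first slot of every timestep holds the zero start-of-timestep token and each remaining slot holds the previous agent's action.

    import jax.numpy as jnp
    from mava.networks.utils.sable.decode import get_shifted_continuous_actions

    # Toy input: batch 1, sequence of 4 agent-steps (2 timesteps x 2 agents), action_dim 2.
    action = jnp.arange(1, 9, dtype=jnp.float32).reshape(1, 4, 2)
    shifted = get_shifted_continuous_actions(action, action_dim=2, n_agents=2)
    # shifted[0] == [[0, 0], [1, 2], [0, 0], [5, 6]]:
    # indices 0 and 2 hold the start token, 1 and 3 hold the preceding agent's action.
    print(shifted)
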


def continuous_autoregressive_act(
decoder: nn.Module,
obs_rep: chex.Array,
hstates: chex.Array,
legal_actions: chex.Array,
step_count: chex.Array,
action_dim: int,
key: chex.PRNGKey,
) -> Tuple[chex.Array, chex.Array, chex.Array]:
# Delete `legal_actions` since it is not used in continuous action space
del legal_actions

B, N = step_count.shape
shifted_actions = jnp.zeros((B, N, action_dim))
output_action = jnp.zeros((B, N, action_dim))
output_action_log = jnp.zeros((B, N))

# Apply the decoder autoregressively
for i in range(N):
act_mean, hstates = decoder.recurrent(
action=shifted_actions[:, i : i + 1, :],
obs_rep=obs_rep[:, i : i + 1, :],
hstates=hstates,
step_count=step_count[:, i : i + 1],
)
action_std = jax.nn.softplus(decoder.log_std) + _MIN_SCALE

key, sample_key = jax.random.split(key)

base_distribution = tfd.Normal(loc=act_mean, scale=action_std)
distribution = tfd.Independent(
TanhTransformedDistribution(base_distribution),
reinterpreted_batch_ndims=1,
)

# The tanh-transformed distribution already returns the squashed action, so no separate raw action is needed.
action = distribution.sample(seed=sample_key)
action_log = distribution.log_prob(action)

output_action = output_action.at[:, i, :].set(action[:, i, :])
output_action_log = output_action_log.at[:, i].set(action_log[:, i])
# Store the current action as the next agent's input; the final agent's action is dropped, since index i + 1 is out of range (mode="drop")
shifted_actions = shifted_actions.at[:, i + 1, :].set(action[:, i, :], mode="drop")

return output_action.astype(jnp.int32), output_action_log, hstates
return output_action, output_action_log, hstates
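
For readers unfamiliar with the squashed-Gaussian head used in both continuous paths above, here is a small stand-alone sketch (not part of the commit) of how the policy distribution is assembled and queried; the shapes and the zero log_std are arbitrary placeholders.

    import jax
    import jax.numpy as jnp
    import tensorflow_probability.substrates.jax.distributions as tfd
    from mava.networks.distributions import TanhTransformedDistribution

    key = jax.random.PRNGKey(0)
    act_mean = jnp.zeros((4, 6, 3))               # (batch, sequence, action_dim), placeholder
    log_std = jnp.zeros(3)                        # stand-in for decoder.log_std
    action_std = jax.nn.softplus(log_std) + 1e-3  # mirrors softplus(log_std) + _MIN_SCALE

    distribution = tfd.Independent(
        TanhTransformedDistribution(tfd.Normal(loc=act_mean, scale=action_std)),
        reinterpreted_batch_ndims=1,              # sum log-probs over the action dimension
    )
    action = distribution.sample(seed=key)         # squashed into (-1, 1) by the tanh bijector
    action_log_prob = distribution.log_prob(action)  # shape (4, 6): one scalar per agent step
    entropy = distribution.entropy(seed=key)       # sampled estimate, hence the seed (as above)
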
22 changes: 10 additions & 12 deletions mava/systems/sable/anakin/ff_sable.py
@@ -178,14 +178,15 @@ def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple:
"""Update the network for a single minibatch."""
# UNPACK TRAIN STATE AND BATCH INFO
params, opt_state = train_state
params, opt_state, key = train_state
traj_batch, advantages, targets = batch_info

def _loss_fn(
params: Params,
traj_batch: Transition,
gae: chex.Array,
value_targets: chex.Array,
rng_key: chex.PRNGKey,
) -> Tuple:
"""Calculate Sable loss."""
# RERUN NETWORK
@@ -194,6 +195,7 @@ def _loss_fn(
observation=traj_batch.obs,
action=traj_batch.action,
dones=traj_batch.done,
rng_key=rng_key,
)

# CALCULATE ACTOR LOSS
@@ -231,13 +233,9 @@ def _loss_fn(
return total_loss, (loss_actor, entropy, value_loss)

# CALCULATE ACTOR LOSS
key, entropy_key = jax.random.split(key)
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
loss_info, grads = grad_fn(
params,
traj_batch,
advantages,
targets,
)
loss_info, grads = grad_fn(params, traj_batch, advantages, targets, entropy_key)

# Compute the parallel mean (pmean) over the batch.
# This calculation is inspired by the Anakin architecture demo notebook.
@@ -263,7 +261,7 @@ def _loss_fn(
"entropy": entropy,
}

return (new_params, new_opt_state), loss_info
return (new_params, new_opt_state, key), loss_info

(
params,
@@ -275,7 +273,7 @@ def _loss_fn(
) = update_state

# SHUFFLE MINIBATCHES
key, batch_shuffle_key, agent_shuffle_key = jax.random.split(key, 3)
key, batch_shuffle_key, agent_shuffle_key, entropy_key = jax.random.split(key, 4)

# Shuffle batch
batch_size = config.system.rollout_length * config.arch.num_envs
@@ -295,9 +293,9 @@ def _loss_fn(
)

# UPDATE MINIBATCHES
(params, opt_states), loss_info = jax.lax.scan(
(params, opt_states, entropy_key), loss_info = jax.lax.scan(
_update_minibatch,
(params, opt_states),
(params, opt_states, entropy_key),
minibatches,
)

@@ -381,7 +379,7 @@ def learner_setup(
key, net_key = keys

# Get number of agents and actions.
action_dim = int(env.action_spec().num_values[0])
action_dim = env.action_dim
n_agents = env.action_spec().shape[0]
config.system.num_agents = n_agents
config.system.num_actions = action_dim
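
Aside (illustrative, not part of the commit): the change above threads a PRNG key through the minibatch scan so the continuous entropy estimate can draw samples. The pattern in isolation, with a toy update standing in for the real loss step:

    import jax
    import jax.numpy as jnp

    def _update_minibatch(train_state, batch):
        params, key = train_state
        key, entropy_key = jax.random.split(key)  # fresh subkey per minibatch
        grads = jax.random.normal(entropy_key, params.shape) * batch.mean()  # toy "gradient"
        return (params - 1e-3 * grads, key), grads.mean()

    params = jnp.zeros(4)
    minibatches = jnp.ones((8, 2))  # 8 toy minibatches
    (params, key), _ = jax.lax.scan(
        _update_minibatch, (params, jax.random.PRNGKey(0)), minibatches
    )
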