Add Sable [Continuous actions] #1127

Merged · 8 commits · Nov 13, 2024

2 changes: 1 addition & 1 deletion mava/configs/default/ff_sable.yaml
@@ -3,7 +3,7 @@ defaults:
- arch: anakin
- system: sable/ff_sable
- network: ff_retention
- env: rware # [cleaner, connector, gigastep, lbf, rware, smax]
- env: rware # [cleaner, connector, gigastep, lbf, rware, smax, mabrax]
- _self_

hydra:
2 changes: 1 addition & 1 deletion mava/configs/default/rec_sable.yaml
@@ -3,7 +3,7 @@ defaults:
- arch: anakin
- system: sable/rec_sable
- network: rec_retention
- env: rware # [cleaner, connector, gigastep, lbf, rware, smax]
- env: rware # [cleaner, connector, gigastep, lbf, rware, smax, mabrax]
- _self_

hydra:
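
The two config changes above simply add `mabrax` (continuous actions) to the list of environments Sable can be launched with, e.g. via the usual Hydra override `env=mabrax` on the command line. A minimal sketch using Hydra's compose API, illustrative only and not part of this PR (the config path and override name are assumed from the defaults lists above):

# Illustrative only: load the ff_sable config with the MaBrax environment selected.
from hydra import compose, initialize

with initialize(version_base=None, config_path="mava/configs/default"):
    cfg = compose(config_name="ff_sable", overrides=["env=mabrax"])
print(cfg.env)
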
34 changes: 22 additions & 12 deletions mava/networks/sable_network.py
@@ -26,8 +26,10 @@
from mava.networks.torsos import SwiGLU
from mava.networks.utils.sable import (
act_encoder_fn,
autoregressive_act,
train_decoder_fn,
continuous_autoregressive_act,
continuous_train_decoder_fn,
discrete_autoregressive_act,
discrete_train_decoder_fn,
train_encoder_fn,
)
from mava.systems.sable.types import HiddenStates, SableNetworkConfig
@@ -352,7 +354,7 @@ class SableNetwork(nn.Module):
action_space_type: str = _DISCRETE

def setup(self) -> None:
if self.action_space_type not in [_DISCRETE]:
if self.action_space_type not in [_DISCRETE, _CONTINUOUS]:
raise ValueError(f"Invalid action space type: {self.action_space_type}")

assert (
@@ -385,15 +387,27 @@ def setup(self) -> None:
train_encoder_fn,
chunk_size=self.memory_config.chunk_size,
)
self.train_decoder_fn = partial(
train_decoder_fn, n_agents=self.n_agents, chunk_size=self.memory_config.chunk_size
)

self.act_encoder_fn = partial(
act_encoder_fn,
chunk_size=self.n_agents_per_chunk,
)
self.autoregressive_act = autoregressive_act
if self.action_space_type == _CONTINUOUS:
self.train_decoder_fn = partial(
continuous_train_decoder_fn,
n_agents=self.n_agents,
chunk_size=self.memory_config.chunk_size,
action_dim=self.action_dim,
)
self.autoregressive_act = partial(
continuous_autoregressive_act, action_dim=self.action_dim
)
else:
self.train_decoder_fn = partial(
discrete_train_decoder_fn,
n_agents=self.n_agents,
chunk_size=self.memory_config.chunk_size,
)
self.autoregressive_act = discrete_autoregressive_act # type: ignore

def __call__(
self,
@@ -424,9 +438,7 @@ def __call__(
rng_key=rng_key,
)

action_log = jnp.squeeze(action_log, axis=-1)
value = jnp.squeeze(value, axis=-1)
entropy = jnp.squeeze(entropy, axis=-1)
return value, action_log, entropy

def get_actions(
@@ -467,7 +479,5 @@ def get_actions(
decoder_cross_retn=updated_dec_hs[1],
)

output_actions = jnp.squeeze(output_actions, axis=-1)
output_actions_log = jnp.squeeze(output_actions_log, axis=-1)
value = jnp.squeeze(value, axis=-1)
return output_actions, output_actions_log, value, updated_hs
6 changes: 4 additions & 2 deletions mava/networks/utils/sable/__init__.py
@@ -14,8 +14,10 @@
# ruff: noqa: F401

from mava.networks.utils.sable.decode import (
autoregressive_act,
train_decoder_fn,
continuous_autoregressive_act,
continuous_train_decoder_fn,
discrete_autoregressive_act,
discrete_train_decoder_fn,
)
from mava.networks.utils.sable.encode import (
act_encoder_fn,
141 changes: 132 additions & 9 deletions mava/networks/utils/sable/decode.py
@@ -18,16 +18,22 @@
import distrax
import jax
import jax.numpy as jnp
import tensorflow_probability.substrates.jax.distributions as tfd
from flax import linen as nn

from mava.networks.distributions import TanhTransformedDistribution

# General shapes legend:
# B: batch size
# S: sequence length
# A: number of actions
# N: number of agents

# Constant to avoid numerical instability
_MIN_SCALE = 1e-3


def train_decoder_fn(
def discrete_train_decoder_fn(
decoder: nn.Module,
obs_rep: chex.Array,
action: chex.Array,
@@ -43,7 +49,7 @@ def train_decoder_fn(
# Delete `rng_key` since it is not used in discrete action space
del rng_key

shifted_actions = get_shifted_actions(action, legal_actions, n_agents=n_agents)
shifted_actions = get_shifted_discrete_actions(action, legal_actions, n_agents=n_agents)
logit = jnp.zeros_like(legal_actions, dtype=jnp.float32)

# Apply the decoder per chunk
@@ -73,14 +79,14 @@

distribution = distrax.Categorical(logits=masked_logits)
action_log_prob = distribution.log_prob(action)
action_log_prob = jnp.expand_dims(action_log_prob, axis=-1)
entropy = jnp.expand_dims(distribution.entropy(), axis=-1)

return action_log_prob, entropy
return action_log_prob, distribution.entropy()
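
To make the shape change above concrete, here is an illustrative sketch (not part of the diff) of how the masked Categorical yields per-step log-probs and entropies directly, without the trailing singleton dimension that was previously added back with `expand_dims`:

# Toy shapes: B=1, S=2, A=3; illegal actions masked with a large negative logit.
import distrax
import jax.numpy as jnp

masked_logits = jnp.array([[[0.5, -1e9, 0.2], [0.1, 0.3, -1e9]]])  # (B, S, A)
action = jnp.array([[0, 1]])                                       # (B, S)
distribution = distrax.Categorical(logits=masked_logits)
action_log_prob = distribution.log_prob(action)                    # (B, S)
entropy = distribution.entropy()                                   # (B, S)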


def get_shifted_actions(action: chex.Array, legal_actions: chex.Array, n_agents: int) -> chex.Array:
"""Get the shifted action sequence for predicting the next action."""
def get_shifted_discrete_actions(
action: chex.Array, legal_actions: chex.Array, n_agents: int
) -> chex.Array:
"""Get the shifted discrete action sequence for predicting the next action."""
B, S, A = legal_actions.shape

# Create a shifted action sequence for predicting the next action
@@ -102,7 +108,7 @@ def get_shifted_actions(action: chex.Array, legal_actions: chex.Array, n_agents:
return shifted_actions


def autoregressive_act(
def discrete_autoregressive_act(
decoder: nn.Module,
obs_rep: chex.Array,
hstates: chex.Array,
@@ -141,5 +147,122 @@ def autoregressive_act(
shifted_actions = shifted_actions.at[:, i + 1, 1:].set(
jax.nn.one_hot(action[:, 0], A), mode="drop"
)
output_actions = output_action.astype(jnp.int32)
output_actions = jnp.squeeze(output_actions, axis=-1)
output_action_log = jnp.squeeze(output_action_log, axis=-1)
return output_actions, output_action_log, hstates


def continuous_train_decoder_fn(
decoder: nn.Module,
obs_rep: chex.Array,
action: chex.Array,
legal_actions: chex.Array,
hstates: chex.Array,
dones: chex.Array,
step_count: chex.Array,
n_agents: int,
chunk_size: int,
action_dim: int,
rng_key: Optional[chex.PRNGKey] = None,
) -> Tuple[chex.Array, chex.Array]:
"""Parallel action sampling for discrete action spaces."""
# Delete `legal_actions` since it is not used in continuous action space
del legal_actions

B, S, _ = action.shape
shifted_actions = get_shifted_continuous_actions(action, action_dim, n_agents=n_agents)
act_mean = jnp.zeros((B, S, action_dim), dtype=jnp.float32)

# Apply the decoder per chunk
num_chunks = shifted_actions.shape[1] // chunk_size
for chunk_id in range(0, num_chunks):
start_idx = chunk_id * chunk_size
end_idx = (chunk_id + 1) * chunk_size
# Chunk obs_rep, shifted_actions, dones, and step_count
chunked_obs_rep = obs_rep[:, start_idx:end_idx]
chunk_shifted_actions = shifted_actions[:, start_idx:end_idx]
chunk_dones = dones[:, start_idx:end_idx]
chunk_step_count = step_count[:, start_idx:end_idx]
chunked_act_mean, hstates = decoder(
action=chunk_shifted_actions,
obs_rep=chunked_obs_rep,
hstates=hstates,
dones=chunk_dones,
step_count=chunk_step_count,
)
act_mean = act_mean.at[:, start_idx:end_idx].set(chunked_act_mean)

action_std = jax.nn.softplus(decoder.log_std) + _MIN_SCALE

base_distribution = tfd.Normal(loc=act_mean, scale=action_std)
distribution = tfd.Independent(
TanhTransformedDistribution(base_distribution),
reinterpreted_batch_ndims=1,
)

action_log_prob = distribution.log_prob(action)
entropy = distribution.entropy(seed=rng_key)

return action_log_prob, entropy
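
As an aside, a minimal sketch (illustrative, not part of the diff) of the squashed policy distribution built above: a Normal, squashed into (-1, 1) by `TanhTransformedDistribution`, then wrapped in `Independent` so log-probs sum over the action dimension. The entropy of the transformed distribution is estimated from a sample, which is why a seed is passed:

import jax
import jax.numpy as jnp
import tensorflow_probability.substrates.jax.distributions as tfd
from mava.networks.distributions import TanhTransformedDistribution

key = jax.random.PRNGKey(0)
act_mean = jnp.zeros((1, 1, 3))                    # (B, S, action_dim)
action_std = jax.nn.softplus(jnp.zeros(3)) + 1e-3  # positive scale

distribution = tfd.Independent(
    TanhTransformedDistribution(tfd.Normal(loc=act_mean, scale=action_std)),
    reinterpreted_batch_ndims=1,                   # sum over action_dim
)
action = distribution.sample(seed=key)             # squashed into (-1, 1)
log_prob = distribution.log_prob(action)           # (B, S)
entropy = distribution.entropy(seed=key)           # sampled estimate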


def get_shifted_continuous_actions(
action: chex.Array, action_dim: int, n_agents: int
) -> chex.Array:
"""Get the shifted continuous action sequence for predicting the next action."""
B, S, _ = action.shape

shifted_actions = jnp.zeros((B, S, action_dim))
start_timestep_token = jnp.zeros(action_dim)
shifted_actions = shifted_actions.at[:, 1:, :].set(action[:, :-1, :])
shifted_actions = shifted_actions.at[:, ::n_agents, :].set(start_timestep_token)

return shifted_actions
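
A small worked example (illustrative, not part of the diff) of what the shifting does: with 2 agents the sequence is flattened as [agent0_t0, agent1_t0, agent0_t1, agent1_t1], each agent is conditioned on the previous agent's action within its timestep, and the first agent of every timestep sees the zero start-of-timestep token:

import jax.numpy as jnp

action = jnp.array([[[1.0], [2.0], [3.0], [4.0]]])  # (B=1, S=4, action_dim=1)
shifted = get_shifted_continuous_actions(action, action_dim=1, n_agents=2)
# shifted == [[[0.], [1.], [0.], [3.]]]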


def continuous_autoregressive_act(
decoder: nn.Module,
obs_rep: chex.Array,
hstates: chex.Array,
legal_actions: chex.Array,
step_count: chex.Array,
action_dim: int,
key: chex.PRNGKey,
) -> Tuple[chex.Array, chex.Array, chex.Array]:
# Delete `legal_actions` since it is not used in continuous action space
del legal_actions

B, N = step_count.shape
shifted_actions = jnp.zeros((B, N, action_dim))
output_action = jnp.zeros((B, N, action_dim))
output_action_log = jnp.zeros((B, N))

# Apply the decoder autoregressively
for i in range(N):
act_mean, hstates = decoder.recurrent(
action=shifted_actions[:, i : i + 1, :],
obs_rep=obs_rep[:, i : i + 1, :],
hstates=hstates,
step_count=step_count[:, i : i + 1],
)
action_std = jax.nn.softplus(decoder.log_std) + _MIN_SCALE

key, sample_key = jax.random.split(key)

base_distribution = tfd.Normal(loc=act_mean, scale=action_std)
distribution = tfd.Independent(
TanhTransformedDistribution(base_distribution),
reinterpreted_batch_ndims=1,
)

# With the tanh-squashed distribution, the sampled action is already squashed, so no separate raw action is kept.
action = distribution.sample(seed=sample_key)
action_log = distribution.log_prob(action)

output_action = output_action.at[:, i, :].set(action[:, i, :])
output_action_log = output_action_log.at[:, i].set(action_log[:, i])
# Write the action into shifted_actions for the next agent; the final write is out of range and dropped (mode="drop")
shifted_actions = shifted_actions.at[:, i + 1, :].set(action[:, i, :], mode="drop")

return output_action.astype(jnp.int32), output_action_log, hstates
return output_action, output_action_log, hstates
22 changes: 10 additions & 12 deletions mava/systems/sable/anakin/ff_sable.py
@@ -178,14 +178,15 @@ def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple:
"""Update the network for a single minibatch."""
# UNPACK TRAIN STATE AND BATCH INFO
params, opt_state = train_state
params, opt_state, key = train_state
traj_batch, advantages, targets = batch_info

def _loss_fn(
params: Params,
traj_batch: Transition,
gae: chex.Array,
value_targets: chex.Array,
rng_key: chex.PRNGKey,
) -> Tuple:
"""Calculate Sable loss."""
# RERUN NETWORK
@@ -194,6 +195,7 @@ def _loss_fn(
observation=traj_batch.obs,
action=traj_batch.action,
dones=traj_batch.done,
rng_key=rng_key,
)

# CALCULATE ACTOR LOSS
@@ -231,13 +233,9 @@ def _loss_fn(
return total_loss, (loss_actor, entropy, value_loss)

# CALCULATE LOSS AND GRADIENTS
key, entropy_key = jax.random.split(key)
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
loss_info, grads = grad_fn(
params,
traj_batch,
advantages,
targets,
)
loss_info, grads = grad_fn(params, traj_batch, advantages, targets, entropy_key)

# Compute the parallel mean (pmean) over the batch.
# This calculation is inspired by the Anakin architecture demo notebook.
@@ -263,7 +261,7 @@ def _loss_fn(
"entropy": entropy,
}

return (new_params, new_opt_state), loss_info
return (new_params, new_opt_state, key), loss_info

(
params,
@@ -275,7 +273,7 @@ def _loss_fn(
) = update_state

# SHUFFLE MINIBATCHES
key, batch_shuffle_key, agent_shuffle_key = jax.random.split(key, 3)
key, batch_shuffle_key, agent_shuffle_key, entropy_key = jax.random.split(key, 4)

# Shuffle batch
batch_size = config.system.rollout_length * config.arch.num_envs
@@ -295,9 +293,9 @@ def _loss_fn(
)

# UPDATE MINIBATCHES
(params, opt_states), loss_info = jax.lax.scan(
(params, opt_states, entropy_key), loss_info = jax.lax.scan(
_update_minibatch,
(params, opt_states),
(params, opt_states, entropy_key),
minibatches,
)
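
For clarity, a minimal sketch (illustrative, not part of the diff) of the pattern used here: the PRNG key is carried through `jax.lax.scan` and split once per minibatch, so each entropy estimate sees fresh randomness:

import jax

def _step(carry, _):
    key = carry
    key, subkey = jax.random.split(key)        # fresh key per minibatch
    return key, jax.random.normal(subkey, ())  # stand-in for the minibatch update

final_key, outputs = jax.lax.scan(_step, jax.random.PRNGKey(0), xs=None, length=4)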

@@ -381,7 +379,7 @@ def learner_setup(
key, net_key = keys

# Get number of agents and actions.
action_dim = int(env.action_spec().num_values[0])
action_dim = env.action_dim
n_agents = env.action_spec().shape[0]
config.system.num_agents = n_agents
config.system.num_actions = action_dim