Support fine-grained activation sharding. (apple#21) (apple#881)
ptoulme-aws authored Dec 17, 2024
1 parent 4c8f8c3 commit 01b762e
Showing 3 changed files with 130 additions and 1 deletion.
20 changes: 19 additions & 1 deletion axlearn/common/layers.py
@@ -54,6 +54,7 @@
from axlearn.common.utils import (
NestedTensor,
Tensor,
maybe_shard,
partial_with_fn_metadata,
with_sharding_constraint,
)
@@ -331,6 +332,10 @@ class Config(BaseNormalizationLayer.Config):
eps: float = 1e-8
# Cast input to this dtype for the 'forward' call. If None, do not cast.
forward_dtype: Optional[jnp.dtype] = jnp.float32
# If not None, how to partition input activation values.
input_partition_spec: Optional[tuple[Optional[str]]] = None
# If not None, how to partition output activation values.
output_partition_spec: Optional[tuple[Optional[str]]] = None

def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]:
cfg = self.config
@@ -341,13 +346,15 @@ def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]:
def forward(self, x: Tensor, *, paddings: Optional[Tensor] = None) -> Tensor:
del paddings # paddings do not affect LayerNorm results
cfg = self.config
x = maybe_shard(x, cfg.input_partition_spec)
x_dtype = x.dtype
if cfg.forward_dtype is not None:
x = x.astype(cfg.forward_dtype)
moment2 = (x * x).mean(axis=-1, keepdims=True)
x = x * jax.lax.rsqrt(moment2 + cfg.eps)
x = x.astype(x_dtype)
x = x * self.parameters["scale"]
x = maybe_shard(x, cfg.output_partition_spec)
return x


@@ -780,6 +787,12 @@ class Config(BaseLayer.Config):

num_embeddings: Required[int] = REQUIRED # Maximum number of embeddings in table.
dim: Required[int] = REQUIRED # Embedding vector dimensionality.
# If not None, how to partition input activation values.
input_partition_spec: Optional[tuple[Optional[str]]] = None
# If not None, how to partition embedding table.
embedding_partition_spec: Optional[tuple[Optional[str]]] = None
# If not None, how to partition output activation values.
output_partition_spec: Optional[tuple[Optional[str]]] = None

@classmethod
def default_config(cls):
@@ -814,8 +827,13 @@ def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]:
)

def forward(self, x: Tensor) -> Tensor:
cfg = self.config
x = maybe_shard(x, cfg.input_partition_spec)
emb = self.parameters["weight"]
return emb[x]
emb = maybe_shard(emb, cfg.embedding_partition_spec)
activation = emb[x]
activation = maybe_shard(activation, cfg.output_partition_spec)
return activation

def attend(self, x: Tensor) -> Tensor:
"""Apply query array 'x' to the embedding weight array.
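
For illustration, a minimal sketch of how the new partition-spec options on RMSNorm and Embedding might be configured. The mesh axis names "fsdp" and "model" mirror the tests below; the dimensions and the particular spec values are hypothetical, and an appropriate device mesh is assumed to already be active.

# Hypothetical configuration sketch; assumes a device mesh with ("fsdp", "model") axes.
norm_cfg = RMSNorm.default_config().set(
    name="norm",
    input_dim=1024,
    # Constrain the [batch, seq, hidden] activations entering the layer.
    input_partition_spec=("fsdp", None, None),
    # Constrain the normalized output the same way (or differently, if desired).
    output_partition_spec=("fsdp", None, None),
)
emb_cfg = Embedding.default_config().set(
    name="embed",
    num_embeddings=32000,
    dim=1024,
    input_partition_spec=("fsdp", None),            # [batch, seq] token ids.
    embedding_partition_spec=(None, "model"),       # [num_embeddings, dim] table.
    output_partition_spec=("fsdp", None, "model"),  # [batch, seq, dim] lookup result.
)
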
105 changes: 105 additions & 0 deletions axlearn/common/layers_test.py
@@ -18,6 +18,7 @@
from collections.abc import Sequence
from functools import partial
from typing import Optional, Union
from unittest import mock

import jax.random
import numpy as np
@@ -507,6 +508,57 @@ def test_rms_norm(self):
# The output_norm should be close to 2 * sqrt(dim).
assert_allclose(output_norm, np.ones_like(output_norm) * 2.0 * math.sqrt(dim))

@mock.patch("axlearn.common.utils.with_sharding_constraint")
def test_rms_norm_partition_specs_constraint(self, mock_with_sharding_constraint):
# Configure mock to return its input.
mock_with_sharding_constraint.side_effect = lambda x, *args: x

dim = 6
cfg = RMSNorm.default_config().set(
name="norm",
input_dim=dim,
input_partition_spec=("fsdp", "model", None),
output_partition_spec=("fsdp", None, None),
)
layer: RMSNorm = cfg.instantiate(parent=None)

# Initialize layer parameters.
prng_key = jax.random.PRNGKey(123)
prng_key, init_key = jax.random.split(prng_key)
layer_params = layer.initialize_parameters_recursively(init_key)

# Random inputs.
prng_key, input_key = jax.random.split(prng_key)
inputs = jax.random.normal(input_key, [2, 3, dim])

# Run forward pass.
outputs, _ = F(
layer,
inputs=(inputs,),
is_training=True,
state=layer_params,
prng_key=prng_key,
)

# Verify with_sharding_constraint calls.
calls = mock_with_sharding_constraint.call_args_list
# Should be called twice - once for input, once for output.
self.assertEqual(len(calls), 2)

# 1. Input tensor constraint.
input_spec = calls[0].args[1]
self.assertEqual(input_spec, ("fsdp", "model", None))
self.assertEqual(calls[0].args[0].shape, (2, 3, dim))
self.assertEqual(calls[0].args[0].dtype, jnp.float32)
np.testing.assert_array_equal(calls[0].args[0], inputs)

# 2. Output tensor constraint.
output_spec = calls[1].args[1]
self.assertEqual(output_spec, ("fsdp", None, None))
self.assertEqual(calls[1].args[0].shape, (2, 3, dim))
self.assertEqual(calls[1].args[0].dtype, jnp.float32)
np.testing.assert_array_equal(calls[1].args[0], outputs)

def test_l2_norm(self):
cfg = L2Norm.default_config().set(name="norm")
layer: L2Norm = cfg.instantiate(parent=None)
@@ -1207,6 +1259,59 @@ def test_embed_attend(self, seq_len, dim, num_embeddings, is_training):
)[0]
assert_allclose(jnp.dot(x, state["weight"].T), actual_attends)

@mock.patch("axlearn.common.utils.with_sharding_constraint")
def test_embed_partition_specs_constraint(self, mock_with_sharding_constraint):
# Configure mock to return its input.
mock_with_sharding_constraint.side_effect = lambda x, *args: x

dim = 16
num_embeddings = 100
seq_len = 5
rng = jax.random.PRNGKey(1)

# Configure embedding with partition specs.
cfg = Embedding.default_config().set(
name="embed",
dim=dim,
num_embeddings=num_embeddings,
input_partition_spec=("fsdp", None),
output_partition_spec=("fsdp", "model"),
embedding_partition_spec=("model", "fsdp"),
)

# Instantiate embedding.
emb = cfg.instantiate(parent=None)
state = emb.initialize_parameters_recursively(rng)

# Test lookup functionality.
ixs = jax.random.randint(rng, minval=0, maxval=num_embeddings, shape=(3, seq_len))
actual_embeds, _ = module.functional(emb, rng, state=state, inputs=[ixs], is_training=True)

# Verify with_sharding_constraint was called in correct order with proper specs.
calls = mock_with_sharding_constraint.call_args_list
self.assertEqual(len(calls), 3)

# 1. Input activation constraint (indices tensor).
input_spec = calls[0].args[1]
self.assertEqual(input_spec, ("fsdp", None))
self.assertEqual(calls[0].args[0].shape, (3, seq_len))
self.assertEqual(calls[0].args[0].dtype, jnp.int32)
np.testing.assert_array_equal(calls[0].args[0], ixs)

# 2. Embedding weight constraint.
weight_spec = calls[1].args[1]
self.assertEqual(weight_spec, ("model", "fsdp"))
self.assertEqual(calls[1].args[0].shape, (num_embeddings, dim))
self.assertEqual(calls[1].args[0].dtype, jnp.float32)
np.testing.assert_array_equal(calls[1].args[0], state["weight"])

# 3. Output activation constraint (after lookup).
output_spec = calls[2].args[1]
self.assertEqual(output_spec, ("fsdp", "model"))
self.assertEqual(calls[2].args[0].shape, (3, seq_len, dim))
self.assertEqual(calls[2].args[0].dtype, jnp.float32)
np.testing.assert_array_equal(calls[2].args[0], actual_embeds)


class BiasLayer(BaseLayer):
"""A test layer with bias."""
6 changes: 6 additions & 0 deletions axlearn/common/utils.py
@@ -444,6 +444,12 @@ def with_sharding_constraint(x, shardings):
return jax.lax.with_sharding_constraint(x, shardings)


def maybe_shard(x: NestedTensor, partition_spec: Optional[PartitionSpec]) -> NestedTensor:
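"""Applies `partition_spec` to `x` as a sharding constraint; returns `x` unchanged if `partition_spec` is None."""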
if partition_spec is None:
return x
return with_sharding_constraint(x, PartitionSpec(*partition_spec))


def replicate_to_local_data(x: NestedTensor) -> NestedTensor:
"""Replicates and converts Tensors in `x` to local DeviceArrays.
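
As a usage note, maybe_shard is a no-op when given None and otherwise wraps the tuple in a PartitionSpec before delegating to with_sharding_constraint. A small hypothetical example, assuming it runs under an active mesh with "fsdp" and "model" axes:

# Hypothetical usage; assumes an active device mesh with axes ("fsdp", "model").
x = jnp.zeros((8, 128, 1024))
x = maybe_shard(x, ("fsdp", None, "model"))  # like with_sharding_constraint(x, PartitionSpec("fsdp", None, "model"))
x = maybe_shard(x, None)                     # returns x unchanged
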
