diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py
index b1a9a7e3..f30d0fe3 100644
--- a/axlearn/common/attention.py
+++ b/axlearn/common/attention.py
@@ -109,6 +109,7 @@
     check_numerics,
     flatten_items,
     get_or_none,
+    replicate_sharding,
     shapes,
     split_prng_key,
 )
@@ -864,19 +865,20 @@ def extend_step(
         # Ensure that we accumulate using the original dtype.
         k_proj = k_proj.astype(cached_key.dtype)
         v_proj = v_proj.astype(cached_value.dtype)
+        # Ensure sharding of kv_proj.
+        k_proj = replicate_sharding(source=cached_key, target=k_proj)
+        v_proj = replicate_sharding(source=cached_value, target=v_proj)
+
+        # Function to update the cache for a single batch element.
+        def update_single(cached_kv_slice, kv_proj_slice, time_idx):
+            return jax.lax.dynamic_update_slice_in_dim(
+                cached_kv_slice, kv_proj_slice, time_idx, axis=0
+            )
-        # TODO(dhwang2): jax.lax.dynamic_update_slice_in_dim is generally faster than advanced
-        # indexing, but an unusual slowdown was observed, with RLHF sampling taking up to
-        # 3 hours per run. Investigate and fix it.
-        # Note: All X_idx are small, so generating them on-demand is not costly.
-        b, _, n, h = cached_key.shape
-        b_idx = jnp.arange(b)[:, None, None, None]
-        t_idx = (jnp.arange(k_proj.shape[1])[None] + time_step[:, None])[:, :, None, None]
-        n_idx = jnp.arange(n)[None, None, :, None]
-        h_idx = jnp.arange(h)[None, None, None, :]
-        k_proj = cached_key.at[b_idx, t_idx, n_idx, h_idx].set(k_proj)
-        v_proj = cached_value.at[b_idx, t_idx, n_idx, h_idx].set(v_proj)
-
+        # Use jax.vmap to vectorize over the batch dimension.
+        vmap_update = jax.vmap(update_single)
+        k_proj = vmap_update(cached_key, k_proj, time_step)
+        v_proj = vmap_update(cached_value, v_proj, time_step)
         updated_state.update(key=k_proj, value=v_proj)
         return updated_state, self.Output(query=q_proj, key=k_proj, value=v_proj)
diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py
index 1401337f..3214b2f9 100644
--- a/axlearn/common/utils.py
+++ b/axlearn/common/utils.py
@@ -444,6 +444,14 @@ def with_sharding_constraint(x, shardings):
     return jax.lax.with_sharding_constraint(x, shardings)
 
 
+def replicate_sharding(*, source: Tensor, target: Tensor):
+    if hasattr(source, "sharding"):
+        sharding_spec = source.sharding
+        return with_sharding_constraint(target, sharding_spec)
+    else:
+        return target
+
+
 def replicate_to_local_data(x: NestedTensor) -> NestedTensor:
     """Replicates and converts Tensors in `x` to local DeviceArrays.
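Note (not part of the patch): the per-batch cache update above relies on vmapping jax.lax.dynamic_update_slice_in_dim over the batch dimension, so each example writes its new KV slice at its own time step. The standalone sketch below illustrates that pattern; the shapes, names, and values are illustrative assumptions, not taken from the diff.

import jax
import jax.numpy as jnp


def update_single(cached_slice, new_slice, time_idx):
    # Write `new_slice` into `cached_slice` along the time axis, starting at `time_idx`.
    return jax.lax.dynamic_update_slice_in_dim(cached_slice, new_slice, time_idx, axis=0)


# Assumed layout: cache is [batch, seq, heads, dim], new projections are
# [batch, step_len, heads, dim], and time_step holds one start index per example.
batch, seq, heads, dim, step_len = 2, 8, 4, 16, 1
cached_key = jnp.zeros((batch, seq, heads, dim))
k_proj = jnp.ones((batch, step_len, heads, dim))
time_step = jnp.array([3, 5])

# vmap over the leading (batch) axis of all three arguments.
updated_key = jax.vmap(update_single)(cached_key, k_proj, time_step)
assert updated_key.shape == cached_key.shape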
diff --git a/axlearn/common/utils_test.py b/axlearn/common/utils_test.py
index f4c06b47..a53abcb0 100644
--- a/axlearn/common/utils_test.py
+++ b/axlearn/common/utils_test.py
@@ -5,12 +5,14 @@
 import contextlib
 import dataclasses
 import enum
+import math
 import sys
 from collections import OrderedDict
 from collections.abc import Iterable, Sequence
 from typing import Any, NamedTuple, Optional, Union
 
 # pylint: disable=no-self-use
+import chex
 import jax
 import jaxlib
 import numpy as np
@@ -20,7 +22,7 @@
 from absl.testing import absltest, parameterized
 from jax import numpy as jnp
 from jax.experimental import checkify, mesh_utils
-from jax.sharding import PartitionSpec
+from jax.sharding import NamedSharding, PartitionSpec
 
 from axlearn.common import learner, optimizers, serialization, struct, utils
 from axlearn.common.base_layer import BaseLayer, FactorizationSpec, ParameterSpec
@@ -70,6 +72,7 @@
     match_regex_rules,
     prune_tree,
     pytree_children,
+    replicate_sharding,
     replicate_to_local_data,
     runtime_checks,
     set_data_dir,
@@ -1698,6 +1701,52 @@ def test_length(self):
         self.assertEqual(2, len(HybridMeshShape(ici_mesh_shape=(1, 2), dcn_mesh_shape=(3, 4))))
 
 
+class CpuShardingTest(TestCase):
+    """Tests sharding utils using fake cpu devices."""
+
+    def setUp(self):
+        chex.set_n_cpu_devices(8)
+        super().setUp()
+
+    def _create_mesh(self):
+        mesh_shape = [2, 2, 2]
+        self.assertEqual(math.prod(mesh_shape), jax.device_count())
+        mesh_axis_names = ["data", "fsdp", "model"]
+        device_mesh = mesh_utils.create_device_mesh(mesh_shape=mesh_shape)
+        mesh = jax.sharding.Mesh(devices=device_mesh, axis_names=mesh_axis_names)
+        return mesh
+
+    def test_with_sharding_constraint(self):
+        inputs = jnp.ones((4, 10))
+        jax.debug.visualize_array_sharding(inputs)
+
+        mesh = self._create_mesh()
+        data_pspec = PartitionSpec("data")
+        data_sharding = NamedSharding(mesh, data_pspec)
+        ref_shard_inputs = jax.tree.map(lambda x: jax.device_put(x, data_sharding), inputs)
+        jax.debug.visualize_array_sharding(ref_shard_inputs)
+
+        with mesh:
+            test_shard_inputs = with_sharding_constraint(inputs, data_sharding)
+            jax.debug.visualize_array_sharding(test_shard_inputs)
+        self.assertEqual(ref_shard_inputs.sharding, test_shard_inputs.sharding)
+
+    def test_replicate_sharding(self):
+        inputs = jnp.ones((4, 10))
+        jax.debug.visualize_array_sharding(inputs)
+
+        mesh = self._create_mesh()
+        data_pspec = PartitionSpec("data")
+        data_sharding = NamedSharding(mesh, data_pspec)
+        ref_shard_inputs = jax.tree.map(lambda x: jax.device_put(x, data_sharding), inputs)
+        jax.debug.visualize_array_sharding(ref_shard_inputs)
+
+        with mesh:
+            test_shard_inputs = replicate_sharding(source=ref_shard_inputs, target=inputs)
+            jax.debug.visualize_array_sharding(test_shard_inputs)
+        self.assertEqual(ref_shard_inputs.sharding, test_shard_inputs.sharding)
+
+
 class HostToGlobalArrayTest(TestCase):
     """Tests host_to_global_device_array."""
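Note (not part of the patch): the new test exercises `replicate_sharding`, which copies the sharding of a source array onto a target by applying a sharding constraint when the source carries one. A minimal standalone sketch of the same flow on fake CPU devices follows, mirroring the test above with plain JAX calls instead of the axlearn helper; the mesh shape and axis names are illustrative assumptions.

import chex

chex.set_n_cpu_devices(8)  # Must run before JAX initializes its backends.

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = mesh_utils.create_device_mesh((2, 2, 2))
mesh = Mesh(devices, axis_names=("data", "fsdp", "model"))

# Reference array that already carries a sharding, and an unsharded target.
source = jax.device_put(jnp.ones((4, 10)), NamedSharding(mesh, PartitionSpec("data")))
target = jnp.ones((4, 10))

with mesh:
    # Equivalent of replicate_sharding(source=source, target=target): constrain the
    # target to the source's sharding when `source.sharding` is available.
    constrained = jax.lax.with_sharding_constraint(target, source.sharding)

print(constrained.sharding)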