From 65f3827f9dc57b93802b686b715bfddd23743d04 Mon Sep 17 00:00:00 2001
From: Dongseong Hwang
Date: Wed, 11 Dec 2024 16:30:58 -0800
Subject: [PATCH] MaskFnAttentionBias._bool_value passes same-rank position
 tensors to mask_fn.

When target_positions is set, a rank-3 target_positions and a rank-2
source_positions are passed to mask_fn. From the perspective of a downstream
user defining mask_fn, this is a big surprise. Broadcast source_positions to
rank 3 as well, so that mask_fn always receives position tensors of the same
rank.
---
 axlearn/common/attention_bias.py      |  5 +++++
 axlearn/common/attention_bias_test.py | 25 ++++++++++++++++++++++++-
 2 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/axlearn/common/attention_bias.py b/axlearn/common/attention_bias.py
index f6a29213..8977f444 100644
--- a/axlearn/common/attention_bias.py
+++ b/axlearn/common/attention_bias.py
@@ -440,6 +440,7 @@ def __call__(self, query_position: Tensor, key_position: Tensor) -> Tensor:
           x = f(jnp.asarray([1,2]), jnp.asarray([3,4]))
           assert x[0] == f(jnp.asarray(1), jnp.asarray(3))[None]
           ```
+        * Both tensors have the same rank (either 2 or 3), as the batch dim is optional.
         * If given non-scalar arguments of different shapes, the result must be the same if we
           first broadcast the arguments against each other to make them have the same shape.
         * Beyond requiring broadcastability, must not impose any constraints on the shapes of its
@@ -489,12 +490,16 @@ def _bool_value(self) -> Optional[Tensor]:
             - If `target_positions` is None: [target_len, source_len]
             - Else: [batch, target_len, source_len].
         """
+        # [target_len, 1], [1, source_len]
         target_positions, source_positions = jnp.indices(self.shape, sparse=True)
         if self.target_positions is not None:
             target_positions = self.target_positions
             if target_positions.ndim == 1:
+                # [batch, 1] + [target_len] = [batch, target_len]
                 # pylint: disable-next=unsubscriptable-object
                 target_positions = target_positions[:, None] + jnp.arange(self.shape[0])
+            while source_positions.ndim < 3:
+                source_positions = source_positions[None, ...]
             while target_positions.ndim < 3:
                 target_positions = target_positions[..., None]
         return self.mask(target_positions, source_positions)  # pylint: disable=not-callable
diff --git a/axlearn/common/attention_bias_test.py b/axlearn/common/attention_bias_test.py
index 0932df10..ec445b78 100644
--- a/axlearn/common/attention_bias_test.py
+++ b/axlearn/common/attention_bias_test.py
@@ -6,7 +6,7 @@
 import chex
 import jax.numpy as jnp
 import jax.util
-from absl.testing import parameterized
+from absl.testing import absltest, parameterized
 from jax.sharding import PartitionSpec
 
 from axlearn.common import attention_bias, test_utils
@@ -267,6 +267,25 @@ def test_mask_fn_attention_bias(self):
         expected = attention_bias.bool_to_bias(expected)[:, None, :]
         self.assertNestedEqual(bias.value(), expected)
 
+    def test_mask_fn_attention_bias_with_target_positions(self):
+        # Ensure that MaskFnAttentionBias provides the mask_fn callback with target_positions and
+        # source_positions tensors of the same rank.
+        batch, target_len, source_len = 2, 5, 4
+        time_step = jnp.arange(batch)
+
+        def mask_fn(target_positions, source_positions):
+            self.assertEqual(target_positions.shape, (batch, target_len, 1))
+            self.assertEqual(source_positions.shape, (1, 1, source_len))
+            return attention_bias.causal_mask(target_positions, source_positions)
+
+        bias = attention_bias.MaskFnAttentionBias(
+            mask=mask_fn, shape=(target_len, source_len), target_positions=time_step
+        )
+        ref_bias = attention_bias.MaskFnAttentionBias(
+            attention_bias.causal_mask, shape=(target_len, source_len), target_positions=time_step
+        )
+        chex.assert_trees_all_close(bias.value(), ref_bias.value())
+
     def test_bool_tensor_attention_bias(self):
         bias = attention_bias.BoolTensorAttentionBias.from_tensor(jnp.ones((5, 7), dtype=bool))
         self.assertNestedEqual(
@@ -278,3 +297,7 @@ def test_astype(self):
         self.assertEqual(bias.value().dtype, jnp.float32)
         bias = bias.astype(jnp.bfloat16)
         self.assertEqual(bias.value().dtype, jnp.bfloat16)
+
+
+if __name__ == "__main__":
+    absltest.main()
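
A minimal sketch of the shapes a downstream mask_fn now receives when target_positions is set,
mirroring the broadcasting in _bool_value and the shape assertions in the new test; the variable
names below are illustrative only, not library API.
```
import jax.numpy as jnp

# Shapes mirror test_mask_fn_attention_bias_with_target_positions.
batch, target_len, source_len = 2, 5, 4
time_step = jnp.arange(batch)  # per-example decode position, shape [batch]

# Rough equivalent of what _bool_value builds before calling mask_fn:
_, source_positions = jnp.indices((target_len, source_len), sparse=True)  # [1, source_len]
target_positions = time_step[:, None] + jnp.arange(target_len)  # [batch, target_len]
target_positions = target_positions[..., None]  # [batch, target_len, 1]
source_positions = source_positions[None, ...]  # [1, 1, source_len]

# Both arguments are rank 3, so a causal mask_fn broadcasts cleanly to
# [batch, target_len, source_len] without any rank juggling on the user side.
mask = source_positions <= target_positions
assert mask.shape == (batch, target_len, source_len)
```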