diff --git a/axlearn/common/attention_bias.py b/axlearn/common/attention_bias.py index 478f00d1..d605242b 100644 --- a/axlearn/common/attention_bias.py +++ b/axlearn/common/attention_bias.py @@ -440,6 +440,7 @@ def __call__(self, query_position: Tensor, key_position: Tensor) -> Tensor: x = f(jnp.asarray([1,2]), jnp.asarray([3,4])) assert x[0] == f(jnp.asarray(1), jnp.asarray(3))[None] ``` + * Both tensors have the same rank (either 2 or 3), as batch dim is optional. * If given non-scalar arguments of different shapes, the result must be the same if we first broadcast the arguments against each other to make them have the same shape. * Beyond requiring broadcastability, must not impose any constraints on the shapes of its @@ -494,14 +495,14 @@ def _bool_value(self) -> Optional[Tensor]: NotImplementedError. If `target_positions.ndim not in [1,2]`. """ target_positions, source_positions = jnp.indices(self.shape, sparse=True) - # Shape: [batch, target_len, source_len]. + # Shape: [1, target_len, 1], [1, 1, source_len]. target_positions, source_positions = target_positions[None], source_positions[None] if self.target_positions is not None: target_positions = self.target_positions if target_positions.ndim not in [1, 2]: raise NotImplementedError(f"Shape of target_positions: {target_positions.shape}.") if target_positions.ndim == 1: - # Shape: [batch, target_len]. + # Shape: [batch, 1] + [target_len] = [batch, target_len] # pylint: disable-next=unsubscriptable-object target_positions = target_positions[:, None] + jnp.arange(self.shape[0]) elif target_positions.ndim == 2: diff --git a/axlearn/common/attention_bias_test.py b/axlearn/common/attention_bias_test.py index 531d6d22..358c6291 100644 --- a/axlearn/common/attention_bias_test.py +++ b/axlearn/common/attention_bias_test.py @@ -6,7 +6,7 @@ import chex import jax.numpy as jnp import jax.util -from absl.testing import parameterized +from absl.testing import absltest, parameterized from jax.sharding import PartitionSpec from axlearn.common import attention_bias, test_utils @@ -287,6 +287,25 @@ def test_mask_fn_attention_bias_target_positions_ndim(self): ) self.assertNestedEqual(bias.bool_value(), expected) + def test_mask_fn_attention_bias_with_target_positions(self): + # Ensure that MaskFnAttentionBias provides the mask_fn callback with target_positions and + # source_positions tensors of the same rank. + batch, target_len, source_len = 2, 5, 4 + time_step = jnp.arange(batch) + + def mask_fn(target_positions, source_positions): + self.assertEqual(target_positions.shape, (batch, target_len, 1)) + self.assertEqual(source_positions.shape, (1, 1, source_len)) + return attention_bias.causal_mask(target_positions, source_positions) + + bias = attention_bias.MaskFnAttentionBias( + mask=mask_fn, shape=(target_len, source_len), target_positions=time_step + ) + ref_bias = attention_bias.MaskFnAttentionBias( + attention_bias.causal_mask, shape=(target_len, source_len), target_positions=time_step + ) + chex.assert_trees_all_close(bias.value(), ref_bias.value()) + def test_bool_tensor_attention_bias(self): bias = attention_bias.BoolTensorAttentionBias.from_tensor(jnp.ones((5, 7), dtype=bool)) self.assertNestedEqual( @@ -298,3 +317,7 @@ def test_astype(self): self.assertEqual(bias.value().dtype, jnp.float32) bias = bias.astype(jnp.bfloat16) self.assertEqual(bias.value().dtype, jnp.bfloat16) + + +if __name__ == "__main__": + absltest.main()