diff --git a/axlearn/common/attention_bias.py b/axlearn/common/attention_bias.py
index f6a29213..d6eb3986 100644
--- a/axlearn/common/attention_bias.py
+++ b/axlearn/common/attention_bias.py
@@ -489,14 +489,19 @@ def _bool_value(self) -> Optional[Tensor]:
             - If `target_positions` is None: [target_len, source_len]
             - Else: [batch, target_len, source_len].
         """
+        # [target_len, 1], [1, source_len]
         target_positions, source_positions = jnp.indices(self.shape, sparse=True)
         if self.target_positions is not None:
-            target_positions = self.target_positions
-            if target_positions.ndim == 1:
-                # pylint: disable-next=unsubscriptable-object
-                target_positions = target_positions[:, None] + jnp.arange(self.shape[0])
-            while target_positions.ndim < 3:
-                target_positions = target_positions[..., None]
+            target_positions = self.target_positions  # [batch]
+            if target_positions.ndim != 1:
+                raise ValueError(
+                    f"target_positions must be a rank 1 tensor, but {target_positions.shape}."
+                )
+            # [batch, 1] + [1, target_len] = [batch, target_len]
+            # pylint: disable-next=unsubscriptable-object
+            target_positions = target_positions[:, None] + jnp.arange(self.shape[0])[None, :]
+            target_positions = target_positions[:, :, None]  # [batch, target_len, 1]
+            source_positions = source_positions[None, :, :]  # [1, 1, source_len]
         return self.mask(target_positions, source_positions)  # pylint: disable=not-callable
 
     @classmethod
diff --git a/axlearn/common/attention_bias_test.py b/axlearn/common/attention_bias_test.py
index 0932df10..8b3bdecd 100644
--- a/axlearn/common/attention_bias_test.py
+++ b/axlearn/common/attention_bias_test.py
@@ -6,7 +6,7 @@
 import chex
 import jax.numpy as jnp
 import jax.util
-from absl.testing import parameterized
+from absl.testing import absltest, parameterized
 from jax.sharding import PartitionSpec
 
 from axlearn.common import attention_bias, test_utils
@@ -267,6 +267,25 @@ def test_mask_fn_attention_bias(self):
         expected = attention_bias.bool_to_bias(expected)[:, None, :]
         self.assertNestedEqual(bias.value(), expected)
 
+    def test_mask_fn_attention_bias_with_target_positions(self):
+        batch, target_len, source_len = 2, 5, 4
+        time_step = jnp.arange(batch)
+
+        def mask_fn(target_positions, source_positions):
+            self.assertEqual(target_positions.shape, (batch, target_len, 1))
+            self.assertEqual(source_positions.shape, (1, 1, source_len))
+            return attention_bias.causal_mask(target_positions, source_positions)
+
+        bias = attention_bias.MaskFnAttentionBias(
+            mask=mask_fn, shape=(target_len, source_len), target_positions=time_step
+        )
+        self.assertIsInstance(bias, attention_bias.MaskFnAttentionBias)
+
+        ref_bias = attention_bias.MaskFnAttentionBias(
+            attention_bias.causal_mask, shape=(target_len, source_len), target_positions=time_step
+        )
+        chex.assert_trees_all_close(bias.value(), ref_bias.value())
+
     def test_bool_tensor_attention_bias(self):
         bias = attention_bias.BoolTensorAttentionBias.from_tensor(jnp.ones((5, 7), dtype=bool))
         self.assertNestedEqual(
@@ -278,3 +297,7 @@ def test_astype(self):
         self.assertEqual(bias.value().dtype, jnp.float32)
         bias = bias.astype(jnp.bfloat16)
         self.assertEqual(bias.value().dtype, jnp.bfloat16)
+
+
+if __name__ == "__main__":
+    absltest.main()
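
For reference, a minimal standalone sketch (plain jax.numpy, illustrative variable names only, not part of the patch's API) of the broadcasting the updated `_bool_value` relies on: a rank-1 `target_positions` of shape [batch] is offset by the per-step range, expanded to [batch, target_len, 1], and compared against [1, 1, source_len] source positions under the causal rule.

import jax.numpy as jnp

batch, target_len, source_len = 2, 5, 4
time_step = jnp.arange(batch)  # per-example decode position, shape [batch]

# [batch, 1] + [1, target_len] -> [batch, target_len]: absolute position of each query.
target_positions = time_step[:, None] + jnp.arange(target_len)[None, :]
target_positions = target_positions[:, :, None]           # [batch, target_len, 1]
source_positions = jnp.arange(source_len)[None, None, :]  # [1, 1, source_len]

# Causal rule: a query may attend to keys at or before its own absolute position.
bool_mask = source_positions <= target_positions           # [batch, target_len, source_len]
assert bool_mask.shape == (batch, target_len, source_len)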