MaskFnAttentionBias._bool_value passes the same rank position tensors to mask_fn.

Previously, when target_positions was set, a rank-3 target_positions tensor and a rank-2
source_positions tensor were passed to mask_fn. From the perspective of downstream code
defining a mask_fn, this mismatch is surprising.
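For illustration, here is a minimal sketch of the contract this change establishes, written with plain jax.numpy rather than AXLearn internals; the `causal` helper below is a hypothetical stand-in for a user-defined mask_fn, and the shapes mirror those used in the new test.

```python
# Minimal sketch of the same-rank contract (illustrative; not AXLearn code).
import jax.numpy as jnp


def causal(target_positions, source_positions):
    # Both arguments arrive with the same rank; broadcasting aligns them.
    return target_positions >= source_positions


batch, target_len, source_len = 2, 5, 4
time_step = jnp.arange(batch)  # per-example decode positions, shape [batch]

# Shapes a mask_fn sees when target_positions is set:
#   target: [batch, target_len, 1], source: [1, 1, source_len].
target = (time_step[:, None] + jnp.arange(target_len))[:, :, None]
source = jnp.arange(source_len)[None, None, :]

mask = causal(target, source)
assert mask.shape == (batch, target_len, source_len)
```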
ds-hwang committed Dec 17, 2024
1 parent 01b762e commit 65f3827
Showing 2 changed files with 29 additions and 1 deletion.
5 changes: 5 additions & 0 deletions axlearn/common/attention_bias.py
@@ -440,6 +440,7 @@ def __call__(self, query_position: Tensor, key_position: Tensor) -> Tensor:
x = f(jnp.asarray([1,2]), jnp.asarray([3,4]))
assert x[0] == f(jnp.asarray(1), jnp.asarray(3))[None]
```
* Both tensors have the same rank (either 2 or 3), as batch dim is optional.
* If given non-scalar arguments of different shapes, the result must be the same if we
first broadcast the arguments against each other to make them have the same shape.
* Beyond requiring broadcastability, must not impose any constraints on the shapes of its
@@ -489,12 +490,16 @@ def _bool_value(self) -> Optional[Tensor]:
- If `target_positions` is None: [target_len, source_len]
- Else: [batch, target_len, source_len].
"""
        # [target_len, 1], [1, source_len]
        target_positions, source_positions = jnp.indices(self.shape, sparse=True)
        if self.target_positions is not None:
            target_positions = self.target_positions
            if target_positions.ndim == 1:
                # [batch, 1] + [target_len] = [batch, target_len]
                # pylint: disable-next=unsubscriptable-object
                target_positions = target_positions[:, None] + jnp.arange(self.shape[0])
            while source_positions.ndim < 3:
                source_positions = source_positions[None, ...]
            while target_positions.ndim < 3:
                target_positions = target_positions[..., None]
        return self.mask(target_positions, source_positions)  # pylint: disable=not-callable
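A quick shape walk-through of the rank normalization above, sketched with plain jax.numpy outside of AXLearn (the shape values are chosen to match the new test below; this is not the AXLearn implementation itself):

```python
# Illustrative shape walk-through (not the AXLearn implementation itself).
import jax.numpy as jnp

shape = (5, 4)  # (target_len, source_len)
target_positions, source_positions = jnp.indices(shape, sparse=True)
assert target_positions.shape == (5, 1) and source_positions.shape == (1, 4)

# With per-example time steps, target positions become [batch, target_len].
time_step = jnp.arange(2)  # [batch]
target_positions = time_step[:, None] + jnp.arange(shape[0])
assert target_positions.shape == (2, 5)

# Pad both tensors to rank 3 so mask_fn sees arguments of the same rank.
while source_positions.ndim < 3:
    source_positions = source_positions[None, ...]
while target_positions.ndim < 3:
    target_positions = target_positions[..., None]
assert source_positions.shape == (1, 1, 4)
assert target_positions.shape == (2, 5, 1)
```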
25 changes: 24 additions & 1 deletion axlearn/common/attention_bias_test.py
@@ -6,7 +6,7 @@
import chex
import jax.numpy as jnp
import jax.util
from absl.testing import parameterized
from absl.testing import absltest, parameterized
from jax.sharding import PartitionSpec

from axlearn.common import attention_bias, test_utils
@@ -267,6 +267,25 @@ def test_mask_fn_attention_bias(self):
        expected = attention_bias.bool_to_bias(expected)[:, None, :]
        self.assertNestedEqual(bias.value(), expected)

    def test_mask_fn_attention_bias_with_target_positions(self):
        # Ensure that MaskFnAttentionBias provides the mask_fn callback with target_positions and
        # source_positions tensors of the same rank.
        batch, target_len, source_len = 2, 5, 4
        time_step = jnp.arange(batch)

        def mask_fn(target_positions, source_positions):
            self.assertEqual(target_positions.shape, (batch, target_len, 1))
            self.assertEqual(source_positions.shape, (1, 1, source_len))
            return attention_bias.causal_mask(target_positions, source_positions)

        bias = attention_bias.MaskFnAttentionBias(
            mask=mask_fn, shape=(target_len, source_len), target_positions=time_step
        )
        ref_bias = attention_bias.MaskFnAttentionBias(
            attention_bias.causal_mask, shape=(target_len, source_len), target_positions=time_step
        )
        chex.assert_trees_all_close(bias.value(), ref_bias.value())

    def test_bool_tensor_attention_bias(self):
        bias = attention_bias.BoolTensorAttentionBias.from_tensor(jnp.ones((5, 7), dtype=bool))
        self.assertNestedEqual(
@@ -278,3 +297,7 @@ def test_astype(self):
        self.assertEqual(bias.value().dtype, jnp.float32)
        bias = bias.astype(jnp.bfloat16)
        self.assertEqual(bias.value().dtype, jnp.bfloat16)


if __name__ == "__main__":
    absltest.main()
