MaskFnAttentionBias._bool_value passes the same rank position tensors to mask_fn.

Previously, when target_positions was set, a rank-3 target_positions and a rank-2 source_positions were passed to mask_fn. From the perspective of downstream code defining a mask_fn, this rank mismatch is a surprise. With this change, both position tensors are passed to mask_fn with the same rank (rank 3).
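To illustrate the shape contract, a minimal standalone sketch with placeholder sizes (not code from this commit; the sliding_window_mask helper below is a hypothetical downstream mask_fn):

import jax.numpy as jnp

batch, target_len, source_len = 2, 5, 4

# Old behavior: mask_fn received tensors of different rank.
old_target = jnp.zeros((batch, target_len, 1), dtype=jnp.int32)  # rank 3
old_source = jnp.arange(source_len)[None, :]                     # rank 2: [1, source_len]

# New behavior: both arguments arrive with rank 3, so a mask_fn can
# treat them symmetrically.
new_target = old_target                                          # [batch, target_len, 1]
new_source = jnp.arange(source_len)[None, None, :]               # [1, 1, source_len]

def sliding_window_mask(target_positions, source_positions, window=2):
    # Hypothetical downstream mask_fn, not part of axlearn.
    diff = target_positions - source_positions
    return (diff >= 0) & (diff <= window)

assert sliding_window_mask(new_target, new_source).shape == (batch, target_len, source_len)

Both the old and the new pairs broadcast to [batch, target_len, source_len]; the fix only removes the rank asymmetry that a mask_fn inspecting ndim of its arguments would trip over.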
ds-hwang committed Dec 12, 2024
1 parent 73625c9 commit 50da898
Showing 2 changed files with 35 additions and 7 deletions.
17 changes: 11 additions & 6 deletions axlearn/common/attention_bias.py
@@ -489,14 +489,19 @@ def _bool_value(self) -> Optional[Tensor]:
             - If `target_positions` is None: [target_len, source_len]
             - Else: [batch, target_len, source_len].
         """
         # [target_len, 1], [1, source_len]
         target_positions, source_positions = jnp.indices(self.shape, sparse=True)
         if self.target_positions is not None:
-            target_positions = self.target_positions
-            if target_positions.ndim == 1:
-                # pylint: disable-next=unsubscriptable-object
-                target_positions = target_positions[:, None] + jnp.arange(self.shape[0])
-            while target_positions.ndim < 3:
-                target_positions = target_positions[..., None]
+            target_positions = self.target_positions  # [batch]
+            if target_positions.ndim != 1:
+                raise ValueError(
+                    f"target_positions must be a rank 1 tensor, but {target_positions.shape}."
+                )
+            # [batch, 1] + [1, target_len] = [batch, target_len]
+            # pylint: disable-next=unsubscriptable-object
+            target_positions = target_positions[:, None] + jnp.arange(self.shape[0])[None, :]
+            target_positions = target_positions[:, :, None]  # [batch, target_len, 1]
+            source_positions = source_positions[None, :, :]  # [1, 1, source_len]
         return self.mask(target_positions, source_positions)  # pylint: disable=not-callable
 
     @classmethod
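The new code path can be checked by hand. A minimal sketch that mirrors the updated _bool_value logic with plain jnp ops (placeholder sizes; this is not a call into axlearn itself):

import jax.numpy as jnp

batch, target_len, source_len = 2, 5, 4
shape = (target_len, source_len)
time_step = jnp.array([0, 3])  # rank-1 [batch], e.g. per-example decode positions

# Sparse index grids: [target_len, 1] and [1, source_len].
_, source_positions = jnp.indices(shape, sparse=True)

# [batch, 1] + [1, target_len] = [batch, target_len]
target_positions = time_step[:, None] + jnp.arange(target_len)[None, :]
target_positions = target_positions[:, :, None]   # [batch, target_len, 1]
source_positions = source_positions[None, :, :]   # [1, 1, source_len]

# Both operands are rank 3; a causal comparison broadcasts to
# [batch, target_len, source_len].
bool_mask = target_positions >= source_positions
assert bool_mask.shape == (batch, target_len, source_len)

These are the same shapes asserted inside mask_fn in the new test below.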
25 changes: 24 additions & 1 deletion axlearn/common/attention_bias_test.py
@@ -6,7 +6,7 @@
 import chex
 import jax.numpy as jnp
 import jax.util
-from absl.testing import parameterized
+from absl.testing import absltest, parameterized
 from jax.sharding import PartitionSpec
 
 from axlearn.common import attention_bias, test_utils
@@ -267,6 +267,25 @@ def test_mask_fn_attention_bias(self):
         expected = attention_bias.bool_to_bias(expected)[:, None, :]
         self.assertNestedEqual(bias.value(), expected)
 
+    def test_mask_fn_attention_bias_with_target_positions(self):
+        batch, target_len, source_len = 2, 5, 4
+        time_step = jnp.arange(batch)
+
+        def mask_fn(target_positions, source_positions):
+            self.assertEqual(target_positions.shape, (batch, target_len, 1))
+            self.assertEqual(source_positions.shape, (1, 1, source_len))
+            return attention_bias.causal_mask(target_positions, source_positions)
+
+        bias = attention_bias.MaskFnAttentionBias(
+            mask=mask_fn, shape=(target_len, source_len), target_positions=time_step
+        )
+        self.assertIsInstance(bias, attention_bias.MaskFnAttentionBias)
+
+        ref_bias = attention_bias.MaskFnAttentionBias(
+            attention_bias.causal_mask, shape=(target_len, source_len), target_positions=time_step
+        )
+        chex.assert_trees_all_close(bias.value(), ref_bias.value())
+
     def test_bool_tensor_attention_bias(self):
         bias = attention_bias.BoolTensorAttentionBias.from_tensor(jnp.ones((5, 7), dtype=bool))
         self.assertNestedEqual(
@@ -278,3 +297,7 @@ def test_astype(self):
         self.assertEqual(bias.value().dtype, jnp.float32)
         bias = bias.astype(jnp.bfloat16)
         self.assertEqual(bias.value().dtype, jnp.bfloat16)
+
+
+if __name__ == "__main__":
+    absltest.main()