Unify init and prefill for attention layers. (apple#860)
* Unify init and prefill for attention layers.

* Fix some types and docstrings.
markblee authored Nov 27, 2024
1 parent bad0f0f commit fc761b0
Showing 11 changed files with 644 additions and 531 deletions.
622 changes: 294 additions & 328 deletions axlearn/common/attention.py

Large diffs are not rendered by default.
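Since the attention.py diff is not rendered here, the following is a minimal, self-contained sketch of the unified contract implied by the test changes below: a single init_states method now takes time_step plus a query (a TensorSpec when only shapes are known) and returns a (cache, output) tuple, with output being None when initializing from scratch (time_step=None) and prefill outputs otherwise, replacing the old split between init_states(target_batch_size=..., target_max_len=...) and prefill_states(time_step=..., query=...). Every name in the sketch (CacheSketch, init_states_sketch, the identity "projection") is hypothetical and is not AXLearn's actual implementation.

from typing import NamedTuple, Optional

import jax.numpy as jnp
from jax import Array


class CacheSketch(NamedTuple):
    """Hypothetical KV cache layout; AXLearn's real cached state differs."""

    time_step: Array  # [batch]
    key: Array  # [batch, max_len, dim]
    value: Array  # [batch, max_len, dim]


def init_states_sketch(
    *, time_step: Optional[Array], query: Array, max_len: int
) -> tuple[CacheSketch, Optional[Array]]:
    """Unified init/prefill, mirroring the new single `init_states` entry point.

    With time_step=None, only an empty cache is built (the old `init_states`);
    in the real API a TensorSpec carrying shape/dtype suffices for this branch.
    With a concrete time_step, the cache is also prefilled from `query` and the
    layer outputs are returned (the old `prefill_states`).
    """
    batch, seq_len, dim = query.shape
    empty = CacheSketch(
        time_step=jnp.zeros([batch], dtype=jnp.int32),
        key=jnp.zeros([batch, max_len, dim], dtype=query.dtype),
        value=jnp.zeros([batch, max_len, dim], dtype=query.dtype),
    )
    if time_step is None:
        return empty, None  # No output when initializing from scratch.
    # Prefill: write an identity "projection" of the query into the cache.
    padded = jnp.pad(query, ((0, 0), (0, max_len - seq_len), (0, 0)))
    return CacheSketch(time_step=time_step, key=padded, value=padded), query


# Usage mirroring the calling convention exercised in the tests below.
q = jnp.ones([2, 6, 8])
cache, out = init_states_sketch(time_step=None, query=q, max_len=6)
assert out is None
cache, out = init_states_sketch(
    time_step=jnp.array([3, 3], dtype=jnp.int32), query=q, max_len=6
)
assert out.shape == (2, 6, 8)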

141 changes: 123 additions & 18 deletions axlearn/common/attention_test.py
@@ -116,6 +116,7 @@
Nested,
PartitionSpec,
Tensor,
TensorSpec,
VDict,
as_tensor,
flatten_items,
@@ -1472,7 +1473,10 @@ def test_repeated_extend_step(self, layer_cls: type[attention.BaseQKVLinear], ex
inputs=dict(query=query),
)

cache_state = layer.init_states(target_batch_size=batch_size, target_max_len=tgt_len)
cache_state, init_output = layer.init_states(
time_step=None, query=TensorSpec([batch_size, tgt_len])
)
self.assertIsNone(init_output)
step_querys = []
step_keys = step_values = None
for t in range(0, tgt_len, extend_step_len):
@@ -1531,18 +1535,19 @@ def __init__(self, cfg: Config, *, parent: Module):
qkv_linear = parent.qkv_linear
state = qkv_linear.initialize_parameters_recursively(jax.random.PRNGKey(0))

# Check dtypes from init_states
cache, _ = F(
# Check dtypes from init_states.
(cache, init_output), _ = F(
qkv_linear,
prng_key=jax.random.PRNGKey(0),
state=state,
inputs=dict(
target_batch_size=target_batch_size,
target_max_len=target_max_len,
time_step=None,
query=TensorSpec([target_batch_size, target_max_len]),
),
method="init_states",
is_training=False,
)
self.assertIsNone(init_output)
self.assertEqual(cache["key"].dtype, dtype)
self.assertEqual(cache["value"].dtype, dtype)

@@ -1562,7 +1567,7 @@ def __init__(self, cfg: Config, *, parent: Module):
prng_key=jax.random.PRNGKey(0),
state=state,
inputs=dict(time_step=time_step, query=query),
method="prefill_states",
method="init_states",
is_training=False,
)
self.assertEqual(init_state["key"].dtype, dtype)
@@ -2448,9 +2453,14 @@ def _test_extend_step(
inputs=inputs,
)

initial_state = layer.init_states(
target_batch_size=batch_size, target_max_len=tgt_len, kv_state=kv_state
initial_state, initial_output = layer.init_states(
time_step=None,
query=TensorSpec([batch_size, tgt_len]),
kv_state=kv_state,
# This is unused for initializing state from scratch.
attention_logit_biases=None,
)
self.assertIsNone(initial_output)
if kv_state is None:
for k in ["key", "value"]:
# Check that the cache dtype is inferred as the layer dtype.
@@ -2619,7 +2629,7 @@ def _test_prefill_states(
attention_logit_biases=attention_logit_biases,
return_aux=return_aux,
),
method="prefill_states",
method="init_states",
)

# Check time_step and shapes of state.
@@ -3227,6 +3237,96 @@ def test_multihead_attention_xl(self):
)


class TransformerAttentionLayerTest(TestCase):
@parameterized.parameters([False, True])
def test_forward_vs_extend_step(self, with_source: bool):
init_prng, target_prng, source_prng = jax.random.split(jax.random.PRNGKey(0), 3)

model_dim = 8
layer_kwargs = dict(target_dim=model_dim, source_dim=model_dim)
cfg: TransformerAttentionLayer.Config = TransformerAttentionLayer.default_config().set(
**layer_kwargs
)
cfg.attention.set(num_heads=2, mask=causal_mask)
layer: TransformerAttentionLayer = cfg.set(name="test").instantiate(parent=None)
layer_params = layer.initialize_parameters_recursively(prng_key=init_prng)

batch, decode_len = 2, 6
target = jax.random.uniform(target_prng, shape=[batch, decode_len, model_dim])
input_kwargs = {}

if with_source:
input_kwargs.update(
source=jax.random.uniform(source_prng, shape=[batch, decode_len, model_dim])
)

forward_outputs, _ = F(
layer,
inputs=dict(target=jnp.asarray(target), **input_kwargs),
state=layer_params,
is_training=True,
prng_key=jax.random.PRNGKey(0),
)

for start_time_step in (-1, 0, 2, decode_len):
if start_time_step < 0:
(cached_states, init_outputs), _ = F(
layer,
inputs=dict(
time_step=None,
target=TensorSpec(target.shape, target.dtype),
**input_kwargs,
),
state=layer_params,
is_training=True,
prng_key=jax.random.PRNGKey(0),
method="init_states",
)
self.assertIsNone(init_outputs)
data = jnp.zeros([batch, decode_len, model_dim])
start_time_step = 0
else:
(cached_states, prefill_outputs), _ = F(
layer,
inputs=dict(
time_step=jnp.array([start_time_step] * batch, dtype=jnp.int32),
target=target,
**input_kwargs,
),
state=layer_params,
is_training=True,
prng_key=jax.random.PRNGKey(0),
method="init_states",
)
data = prefill_outputs.data

data = jnp.einsum("btd->tbd", data)

for time_step in range(start_time_step, decode_len):
extend_kwargs = {}
for k, v in input_kwargs.items():
extend_kwargs[k] = jnp.asarray(v[:, time_step : time_step + 1, :])

(cached_states, extend_outputs), _ = F(
layer,
inputs=dict(
target=jnp.asarray(target[:, time_step : time_step + 1, :]),
cached_states=cached_states,
**extend_kwargs,
),
state=layer_params,
is_training=True,
prng_key=jax.random.PRNGKey(0),
method="extend_step",
)
data = data.at[time_step].set(jnp.squeeze(extend_outputs.data, axis=1))

data = jnp.einsum("tbd->btd", data)

# Prefill + extend_step == forward.
assert_allclose(forward_outputs.data, data)


class TransformerFeedForwardLayerTest(TestCase):
@parameterized.parameters(
dict(rms_norm_summary=[]),
@@ -3392,20 +3492,21 @@ def _test_forward_vs_extend_step(
for start_time_step in (-1, 0, 2, tgt_len):
if start_time_step > tgt_len:
continue
print(f"start_time_step={start_time_step}")
print(f"start_time_step={start_time_step} layer={type(layer)}")
if start_time_step < 0:
cached_states, _ = F(
(cached_states, init_outputs), _ = F(
layer,
inputs=dict(
target_batch_size=batch_size,
target_max_len=tgt_len,
time_step=None,
data=TensorSpec([batch_size, tgt_len]),
**input_kwargs,
),
state=layer_params,
is_training=True,
prng_key=jax.random.PRNGKey(0),
method="init_states",
)
self.assertIsNone(init_outputs)
decoder_output = jnp.zeros_like(target)
start_time_step = 0
else:
@@ -3419,7 +3520,7 @@ def _test_forward_vs_extend_step(
state=layer_params,
is_training=True,
prng_key=jax.random.PRNGKey(0),
method="prefill_states",
method="init_states",
)
decoder_output = prefill_outputs.data
# Transpose to [tgt_len, batch_size, model_dim].
@@ -3850,7 +3951,7 @@ def test_transformer_extend_step(self, transformer_type, layer_type):
batch_size, src_len, tgt_len = 10, 4, 6
num_dec_layers, model_dim, num_heads = 3, 16, 4

cfg = transformer_type.default_config().set(
cfg: BaseStackedTransformerLayer.Config = transformer_type.default_config().set(
name="test",
input_dim=model_dim,
num_layers=num_dec_layers,
@@ -3872,7 +3973,7 @@ def test_transformer_extend_step(self, transformer_type, layer_type):
layer_cfg.feed_forward.hidden_dim = model_dim * 4

# Instantiate transformer stack.
layer = cfg.instantiate(parent=None)
layer: BaseStackedTransformerLayer = cfg.instantiate(parent=None)
layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123))

target = jax.random.normal(jax.random.PRNGKey(123), [batch_size, tgt_len, model_dim])
@@ -3897,7 +3998,11 @@ def test_transformer_extend_step(self, transformer_type, layer_type):
is_training=False,
prng_key=jax.random.PRNGKey(0),
)
initial_state = layer.init_states(target_batch_size=batch_size, target_max_len=tgt_len)
initial_state, initial_output = layer.init_states(
time_step=None,
data=TensorSpec([batch_size, tgt_len]),
)
self.assertIsNone(initial_output)
inputs = dict(
cached_states=initial_state, cross_attention_data=source, return_aux=return_aux
)
@@ -4036,7 +4141,7 @@ def test_transformer_prefill_states(self, transformer_type, layer_type):
cross_attention_logit_biases=cross_attention_logit_biases,
return_aux=return_aux,
),
method="prefill_states",
method="init_states",
)

# Zero-out outputs starting from initial time_step, and test that we can recover the full
12 changes: 7 additions & 5 deletions axlearn/common/decoder.py
@@ -47,7 +47,7 @@
current_context,
new_output_collection,
)
from axlearn.common.utils import Nested, NestedTensor, with_sharding_constraint
from axlearn.common.utils import Nested, NestedTensor, TensorSpec, with_sharding_constraint


# TODO(markblee): Remove this when we have a better solution at the decoding loop level.
@@ -492,7 +492,7 @@ def _forward_for_mode(
assert cached_states is not None
if input_segment_ids is not None:
raise ValueError("input_segment_ids is not supported in INIT_STATES.")
transformer_state, x = self.transformer.prefill_states(
transformer_state, x = self.transformer.init_states(
time_step=cached_states["transformer_state"],
data=x,
self_attention_logit_biases=self_attention_logit_biases,
@@ -584,10 +584,12 @@ def forward(
def init_states(self, *, batch_size: int, max_sequence_length: int) -> NestedTensor:
"""See `BaseDecoder.init_states` for details."""
cfg: Decoder.Config = self.config
init_state, _ = self.transformer.init_states(
time_step=None,
data=TensorSpec([batch_size, max_sequence_length, cfg.dim]),
)
return dict(
transformer_state=self.transformer.init_states(
target_batch_size=batch_size, target_max_len=max_sequence_length
),
transformer_state=init_state,
input_ids=jnp.full(
(batch_size, max_sequence_length), cfg.pad_token_id, dtype=jnp.int32
),
16 changes: 10 additions & 6 deletions axlearn/common/encoder.py
@@ -1,6 +1,7 @@
# Copyright © 2023 Apple Inc.

"""Encoder layers."""

import math
from typing import Optional

@@ -20,7 +21,7 @@
from axlearn.common.embedding import TransformerTextEmbeddings
from axlearn.common.layers import BaseClassificationHead, set_dropout_rate_recursively
from axlearn.common.module import Module, Tensor, child_context
from axlearn.common.utils import NestedTensor
from axlearn.common.utils import NestedTensor, TensorSpec


class Encoder(BaseLayer):
@@ -167,12 +168,15 @@ def init_states(self, *, batch_size: int, max_sequence_length: int) -> NestedTen
Returns:
The cache as a `NestedTensor` with key and value initialized.
"""
cfg: CausalEncoder.Config = self.config
init_state, _ = self.transformer.init_states(
time_step=None,
data=TensorSpec([batch_size, max_sequence_length, cfg.dim]),
)
return dict(
transformer_state=self.transformer.init_states(
target_batch_size=batch_size, target_max_len=max_sequence_length
),
transformer_state=init_state,
input_ids=jnp.full(
(batch_size, max_sequence_length), self.config.pad_token_id, dtype=jnp.int32
(batch_size, max_sequence_length), cfg.pad_token_id, dtype=jnp.int32
),
time_step=jnp.zeros(batch_size, dtype=jnp.int32),
)
@@ -279,7 +283,7 @@ def prefill_states(
# Note: this follows `Decoder.prefill_states` closely. Refer to that method for details.
# TODO(markblee): Possibly consolidate some of this with decoder.
x = self.emb(input_ids, token_type_ids=token_type_ids, positions=None)
transformer_state, x = self.transformer.prefill_states(
transformer_state, x = self.transformer.init_states(
time_step=time_step,
data=x,
self_attention_logit_biases=self.compute_attention_logit_biases(input_ids),
17 changes: 13 additions & 4 deletions axlearn/common/flash_attention/layer_test.py
@@ -38,6 +38,7 @@
from axlearn.common.module import Module
from axlearn.common.module import functional as F
from axlearn.common.test_utils import TestCase, is_supported_mesh_shape
from axlearn.common.utils import TensorSpec


def _fake_inputs(
@@ -650,12 +651,20 @@ def test_extend_step(
)

# Prepare initial states.
initial_state = test_layer.init_states(
target_batch_size=batch, target_max_len=seq_len, kv_state=kv_state
initial_state, initial_output = test_layer.init_states(
time_step=None,
query=TensorSpec([batch, seq_len]),
kv_state=kv_state,
attention_logit_biases=None,
)
ref_initial_state = ref_layer.init_states(
target_batch_size=batch, target_max_len=seq_len, kv_state=kv_state
ref_initial_state, ref_initial_output = ref_layer.init_states(
time_step=None,
query=TensorSpec([batch, seq_len]),
kv_state=kv_state,
attention_logit_biases=None,
)
self.assertIsNone(initial_output)
self.assertIsNone(ref_initial_output)
for k in ["key", "value"]:
self.assertEqual(ref_initial_state["i_proj"][k].dtype, dtype)
self.assertEqual(initial_state["i_proj"][k].dtype, dtype)
10 changes: 6 additions & 4 deletions axlearn/common/lora_test.py
@@ -26,7 +26,7 @@
from axlearn.common.module import functional as F
from axlearn.common.param_converter import as_torch_tensor
from axlearn.common.test_utils import TestCase, assert_allclose
from axlearn.common.utils import Tensor
from axlearn.common.utils import Tensor, TensorSpec


class LoraLinearTest(TestCase):
@@ -233,9 +233,11 @@ def test_extend_step(self, layer):
q_proj, k_proj, v_proj = outputs
forward_outputs = jnp.stack([q_proj, k_proj, v_proj])

initial_cache_state = layer.init_states(
target_batch_size=batch_size, target_max_len=seq_len
initial_cache_state, init_output = layer.init_states(
time_step=None,
query=TensorSpec([batch_size, seq_len]),
)
self.assertIsNone(init_output)

decoder_inputs = dict(cached_states=initial_cache_state)
decoder_outputs = jnp.zeros(shape=[seq_len, 3, batch_size, num_heads, per_head_dim])
@@ -305,7 +307,7 @@ def test_prefill_states(self):
is_training=False,
prng_key=jax.random.PRNGKey(456),
inputs=dict(time_step=time_step, query=inputs),
method="prefill_states",
method="init_states",
)
time_step_mask = jnp.arange(seq_len) < time_step[:, None]
# [batch, tgt_len, num_heads, per_head_dim].