diff --git a/axlearn/common/dit.py b/axlearn/common/dit.py index 2a63f9ed..766f96e9 100644 --- a/axlearn/common/dit.py +++ b/axlearn/common/dit.py @@ -27,6 +27,7 @@ ) from axlearn.common.layers import Dropout, Embedding, LayerNormStateless, Linear, get_activation_fn from axlearn.common.module import Module, Tensor +from axlearn.common.utils import NestedTensor, TensorSpec def modulate(*, x, shift, scale): @@ -430,6 +431,84 @@ def forward( output = input + x return output + def init_states(self, input_spec: TensorSpec) -> NestedTensor: + """Initializes cache for autoregressive cached decoding. + + Args: + input_spec: TensorSpec [batch, num_length, target_dim] corresponding to query vector. + + Returns: + init_states: A Nested Tensor state depending on the `attention` layer implementation. + """ + states = dict() + states["attention"], _ = self.attention.init_states( + time_step=None, query=input_spec, attention_logit_biases=None + ) + return states + + def extend_step( + self, + cached_states: NestedTensor, + target: Tensor, + *, + shift: Optional[Tensor] = None, + scale: Optional[Tensor] = None, + gate: Optional[Tensor] = None, + ) -> tuple[NestedTensor, Tensor]: + """Computes the value vector given the query of the current step. + This function is used by autoregressive decoding. + + Args: + cached_states: A `NestedTensor` object containing tensors which are the + results of previous attentions, and index used for fast decoding. Contains + "attention" cached states. + target: target tensor with shape [batch_size, step_length, target_dim]. + shift: If provided, shifting the norm tensor with shape [batch_size, target_dim] and + scale should be provided. + scale: If provided, scaling the norm tensor with shape [batch_size, target_dim] and + shift should be provided. + gate: If provided, applying before the residual addition with shape + [batch_size, target_dim]. + + Returns: + A tuple (cached_states, output): + * cached_states: A NestedTensor of cache states. + * output: A output tensor of shape [batch, steps, target_dim] + """ + if (shift is None) != (scale is None): + raise ValueError("shift and scale must be both provided or both None.") + + cfg = self.config + if cfg.structure == "prenorm": + x = self.norm(target) + elif cfg.structure == "hybridnorm": + x = self.prenorm(target) + elif cfg.structure == "postnorm": + x = target + + if shift is not None and scale is not None: + x = modulate(x=x, shift=shift, scale=scale) + + # It supports only the (sliding window) causal case, which is handled by attention itself. + attention_logit_biases = None + attn_states, attn_output = self.attention.extend_step( + cached_states=cached_states["attention"], + query=x, + attention_logit_biases=attention_logit_biases, + ) + x = attn_output.data + + if cfg.structure == "postnorm": + x = self.norm(x) + elif cfg.structure == "hybridnorm": + x = self.postnorm(x) + + if gate is not None: + x = x * jnp.expand_dims(gate, 1) + + output = target + x + return dict(attention=attn_states), output + class DiTBlock(BaseLayer): """The DiT block layer. @@ -477,6 +556,55 @@ def forward(self, *, input: Tensor, condition: Tensor) -> Tensor: return x + def init_states(self, input_spec: TensorSpec) -> NestedTensor: + """Initializes cache for autoregressive cached decoding. + + Args: + input_spec: TensorSpec [batch, target_length, target_dim] corresponding to query vector. + + Returns: + init_states: A Nested Tensor state depending on the `attention` layer implementation. + """ + states = dict() + states["attention"] = self.attention.init_states(input_spec=input_spec) + return states + + def extend_step( + self, + cached_states: NestedTensor, + target: Tensor, + *, + condition: Tensor, + ) -> tuple[NestedTensor, Tensor]: + """Computes the value vector given the query of the current step. + This function is used by autoregressive decoding. + + Args: + cached_states: A `NestedTensor` object containing tensors which are the + results of previous attentions, and index used for fast decoding. Contains + "attention" cached states. + target: target tensor with shape [batch_size, step_length, input_dim]. + condition: tensor with shape [batch_size, input_dim] for generating + layer norm shift, scale, and gate. + + Returns: + A tuple (cached_states, output): + * cached_states: A NestedTensor of cache states. + * output: A output tensor of shape [batch, steps, target_dim] + """ + layer_norm_params = self.adaln(condition) + shift_attn, scale_attn, gate_attn = layer_norm_params[0:3] + shift_ffn, scale_ffn, gate_ffn = layer_norm_params[3:6] + attn_states, x = self.attention.extend_step( + cached_states=cached_states["attention"], + target=target, + shift=shift_attn, + scale=scale_attn, + gate=gate_attn, + ) + x = self.feed_forward(input=x, shift=shift_ffn, scale=scale_ffn, gate=gate_ffn) + return dict(attention=attn_states), x + class DiTFinalLayer(BaseLayer): """The DiT final layer. diff --git a/axlearn/common/dit_test.py b/axlearn/common/dit_test.py index 5233c389..f4f9d4fa 100644 --- a/axlearn/common/dit_test.py +++ b/axlearn/common/dit_test.py @@ -16,11 +16,12 @@ import numpy as np import pytest import torch -from absl.testing import parameterized +from absl.testing import absltest, parameterized from timm.models.vision_transformer import Attention, Mlp, PatchEmbed from torch import nn -from axlearn.common.attention_bias import NEG_INF +from axlearn.common.attention_bias import NEG_INF, causal_mask, sliding_window_causal_mask +from axlearn.common.config import config_for_function from axlearn.common.dit import ( AdaptiveLayerNormModulation, DiTAttentionLayer, @@ -34,7 +35,7 @@ from axlearn.common.module import functional as F from axlearn.common.test_utils import assert_allclose from axlearn.common.torch_utils import parameters_from_torch_layer -from axlearn.common.utils import as_tensor +from axlearn.common.utils import TensorSpec, as_tensor from axlearn.common.vision_transformer import ConvertToSequence @@ -565,6 +566,73 @@ def test_dit_attn_optional_input_value_error(self): ) assert_allclose(layer_output.shape, inputs.shape) + @parameterized.parameters( + [causal_mask, config_for_function(sliding_window_causal_mask).set(sliding_window_size=10)] + ) + def test_dit_attn_extend_step(self, mask): + batch_size = 2 + seq_len = 12 + dim = 32 + num_heads = 2 + prng_key = jax.random.PRNGKey(123) + prng_key, data_key = jax.random.split(prng_key) + inputs = jax.random.normal(data_key, shape=(batch_size, seq_len, dim)) + shift = jax.random.normal(data_key, shape=(batch_size, dim)) + scale = jax.random.normal(data_key, shape=(batch_size, dim)) + gate = jax.random.normal(data_key, shape=(batch_size, dim)) + + layer_cfg = DiTAttentionLayer.default_config().set( + name="test", + source_dim=dim, + target_dim=dim, + ) + layer_cfg.attention.num_heads = num_heads + layer_cfg.attention.mask = mask + + layer = layer_cfg.instantiate(parent=None) + prng_key, init_key = jax.random.split(prng_key) + layer_params = layer.initialize_parameters_recursively(init_key) + + fwd_output, _ = F( + layer, + inputs=dict( + input=inputs, + shift=shift, + scale=scale, + gate=gate, + attention_logit_biases=None, + ), + state=layer_params, + is_training=False, + prng_key=prng_key, + ) + + cached_states = layer.init_states(input_spec=TensorSpec(inputs.shape, inputs.dtype)) + step_sizes = (1, 2, 3) + step_outputs = [] + i = 0 + while i < seq_len: + step_size = step_sizes[i % len(step_sizes)] + step_inputs = dict( + cached_states=cached_states, + target=inputs[:, i : i + step_size], + shift=shift, + scale=scale, + gate=gate, + ) + i += step_size + (cached_states, step_output), _ = F( + layer, + inputs=step_inputs, + state=layer_params, + is_training=False, + prng_key=prng_key, + method="extend_step", + ) + step_outputs.append(step_output) + step_outputs = jnp.concatenate(step_outputs, axis=1) + assert_allclose(step_outputs, fwd_output) + class TestDiTBlock(parameterized.TestCase): """Tests DiTBlock.""" @@ -604,6 +672,62 @@ def test_dit_block(self): assert_allclose(layer_output, as_tensor(ref_output)) + @parameterized.parameters( + [causal_mask, config_for_function(sliding_window_causal_mask).set(sliding_window_size=10)] + ) + def test_dit_block_extend_step(self, mask): + batch_size = 2 + seq_len = 12 + dim = 32 + num_heads = 2 + prng_key = jax.random.PRNGKey(123) + prng_key, data_key = jax.random.split(prng_key) + inputs = jax.random.normal(data_key, shape=(batch_size, seq_len, dim)) + condition = jax.random.normal(data_key, shape=(batch_size, dim)) + + layer_cfg = DiTBlock.default_config().set(name="test", input_dim=dim) + layer_cfg.attention.attention.num_heads = num_heads + layer_cfg.attention.attention.mask = mask + + layer = layer_cfg.instantiate(parent=None) + prng_key, init_key = jax.random.split(prng_key) + layer_params = layer.initialize_parameters_recursively(init_key) + + fwd_output, _ = F( + layer, + inputs=dict( + input=inputs, + condition=condition, + ), + state=layer_params, + is_training=False, + prng_key=prng_key, + ) + + cached_states = layer.init_states(input_spec=TensorSpec(inputs.shape, inputs.dtype)) + step_sizes = (1, 2, 3) + step_outputs = [] + i = 0 + while i < seq_len: + step_size = step_sizes[i % len(step_sizes)] + step_inputs = dict( + cached_states=cached_states, + target=inputs[:, i : i + step_size], + condition=condition, + ) + i += step_size + (cached_states, step_output), _ = F( + layer, + inputs=step_inputs, + state=layer_params, + is_training=False, + prng_key=prng_key, + method="extend_step", + ) + step_outputs.append(step_output) + step_outputs = jnp.concatenate(step_outputs, axis=1) + assert_allclose(step_outputs, fwd_output) + class TestDiTFinalLayer(parameterized.TestCase): """Tests DiTFinalLayer.""" @@ -682,3 +806,7 @@ def test_dit_patch_embed(self): ) assert_allclose(layer_output, as_tensor(ref_output)) + + +if __name__ == "__main__": + absltest.main()