Skip to content

Commit

Permalink
DiT: implement init_states and extend_step for DiT transformer
Browse files Browse the repository at this point in the history
Beyond vision, speech also uses it for generative model, and speech needs
streaming decoding.
  • Loading branch information
ds-hwang committed Dec 10, 2024
1 parent 8708ce1 commit 32af532
Show file tree
Hide file tree
Showing 2 changed files with 259 additions and 3 deletions.
128 changes: 128 additions & 0 deletions axlearn/common/dit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)
from axlearn.common.layers import Dropout, Embedding, LayerNormStateless, Linear, get_activation_fn
from axlearn.common.module import Module, Tensor
from axlearn.common.utils import NestedTensor, TensorSpec


def modulate(*, x, shift, scale):
Expand Down Expand Up @@ -430,6 +431,84 @@ def forward(
output = input + x
return output

def init_states(self, input_spec: TensorSpec) -> NestedTensor:
"""Initializes cache for autoregressive cached decoding.
Args:
input_spec: TensorSpec [batch, num_length, target_dim] corresponding to query vector.
Returns:
init_states: A Nested Tensor state depending on the `attention` layer implementation.
"""
states = dict()
states["attention"], _ = self.attention.init_states(
time_step=None, query=input_spec, attention_logit_biases=None
)
return states

def extend_step(
self,
cached_states: NestedTensor,
target: Tensor,
*,
shift: Optional[Tensor] = None,
scale: Optional[Tensor] = None,
gate: Optional[Tensor] = None,
) -> tuple[NestedTensor, Tensor]:
"""Computes the value vector given the query of the current step.
This function is used by autoregressive decoding.
Args:
cached_states: A `NestedTensor` object containing tensors which are the
results of previous attentions, and index used for fast decoding. Contains
"attention" cached states.
target: target tensor with shape [batch_size, step_length, target_dim].
shift: If provided, shifting the norm tensor with shape [batch_size, target_dim] and
scale should be provided.
scale: If provided, scaling the norm tensor with shape [batch_size, target_dim] and
shift should be provided.
gate: If provided, applying before the residual addition with shape
[batch_size, target_dim].
Returns:
A tuple (cached_states, output):
* cached_states: A NestedTensor of cache states.
* output: A output tensor of shape [batch, steps, target_dim]
"""
if (shift is None) != (scale is None):
raise ValueError("shift and scale must be both provided or both None.")

cfg = self.config
if cfg.structure == "prenorm":
x = self.norm(target)
elif cfg.structure == "hybridnorm":
x = self.prenorm(target)
elif cfg.structure == "postnorm":
x = target

if shift is not None and scale is not None:
x = modulate(x=x, shift=shift, scale=scale)

# It supports only the (sliding window) causal case, which is handled by attention itself.
attention_logit_biases = None
attn_states, attn_output = self.attention.extend_step(
cached_states=cached_states["attention"],
query=x,
attention_logit_biases=attention_logit_biases,
)
x = attn_output.data

if cfg.structure == "postnorm":
x = self.norm(x)
elif cfg.structure == "hybridnorm":
x = self.postnorm(x)

if gate is not None:
x = x * jnp.expand_dims(gate, 1)

output = target + x
return dict(attention=attn_states), output


class DiTBlock(BaseLayer):
"""The DiT block layer.
Expand Down Expand Up @@ -477,6 +556,55 @@ def forward(self, *, input: Tensor, condition: Tensor) -> Tensor:

return x

def init_states(self, input_spec: TensorSpec) -> NestedTensor:
"""Initializes cache for autoregressive cached decoding.
Args:
input_spec: TensorSpec [batch, target_length, target_dim] corresponding to query vector.
Returns:
init_states: A Nested Tensor state depending on the `attention` layer implementation.
"""
states = dict()
states["attention"] = self.attention.init_states(input_spec=input_spec)
return states

def extend_step(
self,
cached_states: NestedTensor,
target: Tensor,
*,
condition: Tensor,
) -> tuple[NestedTensor, Tensor]:
"""Computes the value vector given the query of the current step.
This function is used by autoregressive decoding.
Args:
cached_states: A `NestedTensor` object containing tensors which are the
results of previous attentions, and index used for fast decoding. Contains
"attention" cached states.
target: target tensor with shape [batch_size, step_length, input_dim].
condition: tensor with shape [batch_size, input_dim] for generating
layer norm shift, scale, and gate.
Returns:
A tuple (cached_states, output):
* cached_states: A NestedTensor of cache states.
* output: A output tensor of shape [batch, steps, target_dim]
"""
layer_norm_params = self.adaln(condition)
shift_attn, scale_attn, gate_attn = layer_norm_params[0:3]
shift_ffn, scale_ffn, gate_ffn = layer_norm_params[3:6]
attn_states, x = self.attention.extend_step(
cached_states=cached_states["attention"],
target=target,
shift=shift_attn,
scale=scale_attn,
gate=gate_attn,
)
x = self.feed_forward(input=x, shift=shift_ffn, scale=scale_ffn, gate=gate_ffn)
return dict(attention=attn_states), x


class DiTFinalLayer(BaseLayer):
"""The DiT final layer.
Expand Down
134 changes: 131 additions & 3 deletions axlearn/common/dit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
import numpy as np
import pytest
import torch
from absl.testing import parameterized
from absl.testing import absltest, parameterized
from timm.models.vision_transformer import Attention, Mlp, PatchEmbed
from torch import nn

from axlearn.common.attention_bias import NEG_INF
from axlearn.common.attention_bias import NEG_INF, causal_mask, sliding_window_causal_mask
from axlearn.common.config import config_for_function
from axlearn.common.dit import (
AdaptiveLayerNormModulation,
DiTAttentionLayer,
Expand All @@ -34,7 +35,7 @@
from axlearn.common.module import functional as F
from axlearn.common.test_utils import assert_allclose
from axlearn.common.torch_utils import parameters_from_torch_layer
from axlearn.common.utils import as_tensor
from axlearn.common.utils import TensorSpec, as_tensor
from axlearn.common.vision_transformer import ConvertToSequence


Expand Down Expand Up @@ -565,6 +566,73 @@ def test_dit_attn_optional_input_value_error(self):
)
assert_allclose(layer_output.shape, inputs.shape)

@parameterized.parameters(
[causal_mask, config_for_function(sliding_window_causal_mask).set(sliding_window_size=10)]
)
def test_dit_attn_extend_step(self, mask):
batch_size = 2
seq_len = 12
dim = 32
num_heads = 2
prng_key = jax.random.PRNGKey(123)
prng_key, data_key = jax.random.split(prng_key)
inputs = jax.random.normal(data_key, shape=(batch_size, seq_len, dim))
shift = jax.random.normal(data_key, shape=(batch_size, dim))
scale = jax.random.normal(data_key, shape=(batch_size, dim))
gate = jax.random.normal(data_key, shape=(batch_size, dim))

layer_cfg = DiTAttentionLayer.default_config().set(
name="test",
source_dim=dim,
target_dim=dim,
)
layer_cfg.attention.num_heads = num_heads
layer_cfg.attention.mask = mask

layer = layer_cfg.instantiate(parent=None)
prng_key, init_key = jax.random.split(prng_key)
layer_params = layer.initialize_parameters_recursively(init_key)

fwd_output, _ = F(
layer,
inputs=dict(
input=inputs,
shift=shift,
scale=scale,
gate=gate,
attention_logit_biases=None,
),
state=layer_params,
is_training=False,
prng_key=prng_key,
)

cached_states = layer.init_states(input_spec=TensorSpec(inputs.shape, inputs.dtype))
step_sizes = (1, 2, 3)
step_outputs = []
i = 0
while i < seq_len:
step_size = step_sizes[i % len(step_sizes)]
step_inputs = dict(
cached_states=cached_states,
target=inputs[:, i : i + step_size],
shift=shift,
scale=scale,
gate=gate,
)
i += step_size
(cached_states, step_output), _ = F(
layer,
inputs=step_inputs,
state=layer_params,
is_training=False,
prng_key=prng_key,
method="extend_step",
)
step_outputs.append(step_output)
step_outputs = jnp.concatenate(step_outputs, axis=1)
assert_allclose(step_outputs, fwd_output)


class TestDiTBlock(parameterized.TestCase):
"""Tests DiTBlock."""
Expand Down Expand Up @@ -604,6 +672,62 @@ def test_dit_block(self):

assert_allclose(layer_output, as_tensor(ref_output))

@parameterized.parameters(
[causal_mask, config_for_function(sliding_window_causal_mask).set(sliding_window_size=10)]
)
def test_dit_block_extend_step(self, mask):
batch_size = 2
seq_len = 12
dim = 32
num_heads = 2
prng_key = jax.random.PRNGKey(123)
prng_key, data_key = jax.random.split(prng_key)
inputs = jax.random.normal(data_key, shape=(batch_size, seq_len, dim))
condition = jax.random.normal(data_key, shape=(batch_size, dim))

layer_cfg = DiTBlock.default_config().set(name="test", input_dim=dim)
layer_cfg.attention.attention.num_heads = num_heads
layer_cfg.attention.attention.mask = mask

layer = layer_cfg.instantiate(parent=None)
prng_key, init_key = jax.random.split(prng_key)
layer_params = layer.initialize_parameters_recursively(init_key)

fwd_output, _ = F(
layer,
inputs=dict(
input=inputs,
condition=condition,
),
state=layer_params,
is_training=False,
prng_key=prng_key,
)

cached_states = layer.init_states(input_spec=TensorSpec(inputs.shape, inputs.dtype))
step_sizes = (1, 2, 3)
step_outputs = []
i = 0
while i < seq_len:
step_size = step_sizes[i % len(step_sizes)]
step_inputs = dict(
cached_states=cached_states,
target=inputs[:, i : i + step_size],
condition=condition,
)
i += step_size
(cached_states, step_output), _ = F(
layer,
inputs=step_inputs,
state=layer_params,
is_training=False,
prng_key=prng_key,
method="extend_step",
)
step_outputs.append(step_output)
step_outputs = jnp.concatenate(step_outputs, axis=1)
assert_allclose(step_outputs, fwd_output)


class TestDiTFinalLayer(parameterized.TestCase):
"""Tests DiTFinalLayer."""
Expand Down Expand Up @@ -682,3 +806,7 @@ def test_dit_patch_embed(self):
)

assert_allclose(layer_output, as_tensor(ref_output))


if __name__ == "__main__":
absltest.main()

0 comments on commit 32af532

Please sign in to comment.