Generalizes decoder by taking input batch. (apple#861)
markblee authored Nov 30, 2024
1 parent ffdf2d9 commit b1b6c25
Showing 24 changed files with 296 additions and 147 deletions.
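In short, this commit moves decoder entry points (forward, `beam_search_decode`, `sample_decode`) and their test call sites from separate tensor keyword arguments to a single `input_batch` dict. A minimal before/after sketch of the calling pattern; `decoder`, `target_ids`, `prefix`, and the decode lengths below are illustrative, not from the commit:

# Before: decoder calls passed individual tensors as kwargs.
outputs = decoder(input_ids=target_ids)
beams = decoder.beam_search_decode(prefix=prefix, max_sequence_length=64, num_decodes=4)

# After: calls pass one `input_batch` dict; the decoder reads the keys it
# needs (e.g. "input_ids" for forward, "prefix" for decoding).
outputs = decoder(input_batch=dict(input_ids=target_ids))
beams = decoder.beam_search_decode(
    input_batch=dict(prefix=prefix),
    max_sequence_length=64,
    num_decodes=4,
)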
6 changes: 3 additions & 3 deletions axlearn/audio/decoder_asr.py
@@ -1139,7 +1139,7 @@ def forward(self, input_batch: Nested[Tensor]) -> tuple[Tensor, Nested[Tensor]]:
         paddings=input_batch["paddings"]
     )
     predict_outputs = self.decoder(
-        input_ids=input_batch["target"]["input_ids"],
+        input_batch=input_batch["target"],
         cross_attention_data=input_batch["inputs"],
         cross_attention_logit_biases=cross_attention_logit_biases,
     )
@@ -1241,7 +1241,7 @@ def beam_search_decode(

     with child_context("beam_search_decode", module=self.decoder):
         beam_search_outputs: decoding.BeamSearchOutputs = self.decoder.beam_search_decode(
-            prefix=input_batch["prefix"],
+            input_batch=input_batch,
             max_sequence_length=max_decode_len,
             num_decodes=num_decodes,
             cross_attention_data=input_batch["inputs"],
@@ -1291,7 +1291,7 @@ def sample_decode(

     with child_context("sample_decode", module=self.decoder):
         sample_decode_outputs: decoding.SampleOutputs = self.decoder.sample_decode(
-            prefix=input_batch["prefix"],
+            input_batch=input_batch,
             max_sequence_length=max_decode_len,
             num_decodes=num_decodes,
             cross_attention_data=input_batch["inputs"],
8 changes: 4 additions & 4 deletions axlearn/audio/model_asr.py
@@ -9,7 +9,7 @@
 from axlearn.common.base_encoder_decoder import BaseEncoderDecoderModel
 from axlearn.common.config import REQUIRED, Required, config_class
 from axlearn.common.module import Module
-from axlearn.common.utils import Nested, Tensor
+from axlearn.common.utils import Nested, Tensor, validate_contains_paths


 class ASRModel(BaseEncoderDecoderModel):
@@ -51,7 +51,7 @@ def predict(self, input_batch: Nested[Tensor]) -> Nested[Tensor]:
         Returns:
             A dict containing logits. The shape of logits depend on the decoder.
         """
-        self._validate_input_batch(input_batch, ["source", "target", "target_labels"])
+        validate_contains_paths(input_batch, ["source", "target", "target_labels"])
         # Encoder hidden states: [batch_size, source_len, dim].
         encoder_output = self.encoder(**input_batch["source"])
         logits = self.decoder.predict(
@@ -82,7 +82,7 @@ def forward(
             aux_outputs: A dict containing auxiliary outputs if `return_aux=True`, otherwise an
                 empty dict.
         """
-        self._validate_input_batch(input_batch, ["source", "target", "target_labels"])
+        validate_contains_paths(input_batch, ["source", "target", "target_labels"])
         # Encoder hidden states: [batch_size, source_len, dim].
         encoder_output = self.encoder(**input_batch["source"])
         loss, aux_outputs = self.decoder(
@@ -115,7 +115,7 @@ def beam_search_decode(
         Returns:
             Beam search decode outputs.
         """
-        self._validate_input_batch(input_batch, ["source/inputs", "source/paddings"])
+        validate_contains_paths(input_batch, ["source/inputs", "source/paddings"])
         encoder_output = self.encoder(**input_batch["source"])
         return self.decoder.beam_search_decode(
             input_batch=dict(
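As a usage sketch at the model level, `beam_search_decode` now validates nested, slash-separated paths in the batch before encoding. The handle `model`, the shapes, and the padding dtype below are illustrative assumptions, and a real call may need more keys (e.g. a decode prefix) depending on the decoder:

import jax.numpy as jnp

# Hypothetical ASR inputs: filterbank features plus per-frame paddings.
source = dict(
    inputs=jnp.zeros([2, 1000, 80]),  # [batch_size, num_frames, num_filters].
    paddings=jnp.zeros([2, 1000], dtype=jnp.bool_),  # True for padded frames.
)
# validate_contains_paths(input_batch, ["source/inputs", "source/paddings"])
# resolves nested keys, i.e. input_batch["source"]["inputs"] must exist.
outputs = model.beam_search_decode(input_batch=dict(source=source), num_decodes=4)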
11 changes: 7 additions & 4 deletions axlearn/common/adapter_torch_test.py
@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.

 """Tests PyTorch adapter layers."""
+
 # pylint: disable=too-many-lines
 import itertools
 from collections import OrderedDict
@@ -889,9 +890,11 @@ def test_transformer_embeddings_forward(
         jax.random.PRNGKey(0),
         state=axlearn_layer_state,
         inputs=dict(
-            inputs=jnp.asarray(input_ids),
-            token_type_ids=axlearn_token_type_ids,
-            positions=axlearn_positions,
+            input_batch=dict(
+                inputs=jnp.asarray(input_ids),
+                token_type_ids=axlearn_token_type_ids,
+                positions=axlearn_positions,
+            ),
         ),
         is_training=False,
         method="forward",
@@ -971,7 +974,7 @@ def test_decoder_inference(self):
         axlearn_layer,
         jax.random.PRNGKey(0),
         state=axlearn_layer_state,
-        inputs=dict(input_ids=jnp.asarray(input_ids)),
+        inputs=dict(input_batch=dict(input_ids=jnp.asarray(input_ids))),
         is_training=False,
         method="forward",
     )[0]
2 changes: 1 addition & 1 deletion axlearn/common/attention_test.py
@@ -3445,7 +3445,7 @@ def _test_decoder_with_transformer(self, transformer_cfg: BaseTransformerLayer.C
     oh_indices = jax.nn.one_hot(prefix_length - 1, seq_len, dtype=prefix.dtype)
     prefix = prefix * (1 - oh_indices) + bos_id * oh_indices
     inputs = dict(
-        prefix=prefix,
+        input_batch=dict(prefix=prefix),
         max_sequence_length=seq_len,
         # cross_attention_data=None,
         # cross_attention_logit_biases=None,
11 changes: 1 addition & 10 deletions axlearn/common/base_encoder_decoder.py
@@ -2,15 +2,14 @@

 """Base Encoder-Decoder model interface."""

-from collections.abc import Sequence
 from typing import Optional

 from axlearn.common.base_layer import BaseLayer
 from axlearn.common.base_model import BaseModel
 from axlearn.common.config import REQUIRED, ConfigOr, Required, config_class
 from axlearn.common.decoding import BeamSearchOutputs, SampleOutputs
 from axlearn.common.logit_modifiers import LogitsToLogitsFn
-from axlearn.common.utils import Nested, Tensor, get_recursively
+from axlearn.common.utils import Nested, Tensor


 class BaseEncoderDecoderModel(BaseModel):
@@ -61,14 +60,6 @@ def predict(self, input_batch: Nested[Tensor]) -> Nested[Tensor]:
         """
         raise NotImplementedError(type(self))

-    def _validate_input_batch(self, input_batch: Nested[Tensor], paths: Sequence[str]):
-        """Raises ValueError if any of the given `paths` are not present in `input_batch`."""
-        for path in paths:
-            try:
-                get_recursively(input_batch, path)
-            except KeyError as e:
-                raise ValueError(f"Input batch is expected to contain '{path}'.") from e
-
     def beam_search_decode(
         self, input_batch: dict[str, Tensor], num_decodes: int, **kwargs
     ) -> BeamSearchOutputs:
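The removed `_validate_input_batch` method is replaced by `validate_contains_paths` from `axlearn.common.utils`, called as a free function at the sites above. A sketch of the moved helper, assuming it keeps the deleted method's logic and `get_recursively`'s slash-path semantics:

# Sketch of the helper as it might appear in axlearn/common/utils.py
# (inside utils itself these names are module-local, not imported).
from collections.abc import Sequence

from axlearn.common.utils import Nested, Tensor, get_recursively


def validate_contains_paths(input_batch: Nested[Tensor], paths: Sequence[str]):
    """Raises ValueError if any of the given `paths` are not present in `input_batch`."""
    for path in paths:
        try:
            # Slash-separated paths resolve nested dicts, so "source/inputs"
            # checks input_batch["source"]["inputs"].
            get_recursively(input_batch, path)
        except KeyError as e:
            raise ValueError(f"Input batch is expected to contain '{path}'.") from e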
17 changes: 11 additions & 6 deletions axlearn/common/causal_lm.py
@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.

 """Autoregressive decoder model, e.g. as seen in the GPT family."""
+
 import math
 import re
 from typing import Callable, Optional, Union
@@ -170,7 +171,7 @@ def beam_search_decode(
     with child_context("beam_search_decode", module=self.decoder):
         prefix = input_batch["prefix"]
         return self.decoder.beam_search_decode(
-            prefix=prefix,
+            input_batch=input_batch,
             max_sequence_length=prefix.shape[-1],
             num_decodes=num_decodes,
             brevity_penalty=brevity_penalty,
@@ -203,7 +204,7 @@ def sample_decode(
     with child_context("sample_decode", module=self.decoder):
         prefix = input_batch["prefix"]
         return self.decoder.sample_decode(
-            prefix=prefix,
+            input_batch=input_batch,
             max_sequence_length=prefix.shape[-1],
             num_decodes=num_decodes,
             logits_modifier=logits_modifier,
@@ -280,10 +281,14 @@ def predict(self, input_batch: dict[str, Tensor]) -> dict[str, Tensor]:
         input_positions: Optional[Tensor] = input_batch.get("input_positions")
         # Decoder hidden states: [batch_size, target_len, hidden_dim].
         decoder_output = self.decoder(
-            input_ids=input_ids,
-            token_type_ids=token_type_ids,
-            input_segment_ids=input_segment_ids,
-            positions=input_positions,
+            # TODO(markblee): Simplify by using consistent naming between `input_positions` and
+            # `positions`, `input_segment_ids` and `segment_ids`.
+            input_batch=dict(
+                input_ids=input_ids,
+                token_type_ids=token_type_ids,
+                input_segment_ids=input_segment_ids,
+                positions=input_positions,
+            ),
         )
         return decoder_output

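A caller-side sketch of the causal LM decode API after this change: the prefix travels inside `input_batch`, and the maximum decode length is taken from the prefix width (`prefix.shape[-1]`). The `model` handle and shapes below are illustrative:

import jax.numpy as jnp

# Hypothetical prefix of BOS-initialized ids: [batch_size=8, max_len=64].
prefix = jnp.zeros([8, 64], dtype=jnp.int32)
outputs = model.beam_search_decode(
    input_batch=dict(prefix=prefix),  # `prefix` now rides inside `input_batch`.
    num_decodes=4,  # Hypotheses per example.
)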
3 changes: 2 additions & 1 deletion axlearn/common/deberta_test.py
@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.

 """Tests DeBERTa implementation."""
+
 # pylint: disable=no-self-use
 from types import SimpleNamespace
 from typing import Optional
@@ -483,7 +484,7 @@ def test_emb(self, query_len: int, **kwargs):
         is_training=False,
         prng_key=jax.random.PRNGKey(0),
         state=layer_params["encoder"]["emb"],
-        inputs=[input_ids],
+        inputs=dict(input_batch=dict(inputs=input_ids)),
     )
     ref_outputs = hf_layer.embeddings(as_torch_tensor(input_ids))
     self.assertNestedAllClose(test_outputs, ref_outputs)
(Diffs for the remaining 17 of the 24 changed files are not shown.)
