diff --git a/axlearn/audio/decoder_asr.py b/axlearn/audio/decoder_asr.py
index 87feb12b..d704f18d 100644
--- a/axlearn/audio/decoder_asr.py
+++ b/axlearn/audio/decoder_asr.py
@@ -1139,7 +1139,7 @@ def forward(self, input_batch: Nested[Tensor]) -> tuple[Tensor, Nested[Tensor]]:
             paddings=input_batch["paddings"]
         )
         predict_outputs = self.decoder(
-            input_ids=input_batch["target"]["input_ids"],
+            input_batch=input_batch["target"],
            cross_attention_data=input_batch["inputs"],
            cross_attention_logit_biases=cross_attention_logit_biases,
        )
@@ -1241,7 +1241,7 @@ def beam_search_decode(

         with child_context("beam_search_decode", module=self.decoder):
             beam_search_outputs: decoding.BeamSearchOutputs = self.decoder.beam_search_decode(
-                prefix=input_batch["prefix"],
+                input_batch=input_batch,
                 max_sequence_length=max_decode_len,
                 num_decodes=num_decodes,
                 cross_attention_data=input_batch["inputs"],
@@ -1291,7 +1291,7 @@ def sample_decode(

         with child_context("sample_decode", module=self.decoder):
             sample_decode_outputs: decoding.SampleOutputs = self.decoder.sample_decode(
-                prefix=input_batch["prefix"],
+                input_batch=input_batch,
                 max_sequence_length=max_decode_len,
                 num_decodes=num_decodes,
                 cross_attention_data=input_batch["inputs"],
diff --git a/axlearn/audio/model_asr.py b/axlearn/audio/model_asr.py
index 42e88863..225ca45e 100644
--- a/axlearn/audio/model_asr.py
+++ b/axlearn/audio/model_asr.py
@@ -9,7 +9,7 @@
 from axlearn.common.base_encoder_decoder import BaseEncoderDecoderModel
 from axlearn.common.config import REQUIRED, Required, config_class
 from axlearn.common.module import Module
-from axlearn.common.utils import Nested, Tensor
+from axlearn.common.utils import Nested, Tensor, validate_contains_paths


 class ASRModel(BaseEncoderDecoderModel):
@@ -51,7 +51,7 @@ def predict(self, input_batch: Nested[Tensor]) -> Nested[Tensor]:
         Returns:
             A dict containing logits. The shape of logits depend on the decoder.
         """
-        self._validate_input_batch(input_batch, ["source", "target", "target_labels"])
+        validate_contains_paths(input_batch, ["source", "target", "target_labels"])
         # Encoder hidden states: [batch_size, source_len, dim].
         encoder_output = self.encoder(**input_batch["source"])
         logits = self.decoder.predict(
@@ -82,7 +82,7 @@ def forward(
             aux_outputs: A dict containing auxiliary outputs if `return_aux=True`, otherwise an
                 empty dict.
         """
-        self._validate_input_batch(input_batch, ["source", "target", "target_labels"])
+        validate_contains_paths(input_batch, ["source", "target", "target_labels"])
         # Encoder hidden states: [batch_size, source_len, dim].
         encoder_output = self.encoder(**input_batch["source"])
         loss, aux_outputs = self.decoder(
@@ -115,7 +115,7 @@ def beam_search_decode(
         Returns:
             Beam search decode outputs.
         """
-        self._validate_input_batch(input_batch, ["source/inputs", "source/paddings"])
+        validate_contains_paths(input_batch, ["source/inputs", "source/paddings"])
         encoder_output = self.encoder(**input_batch["source"])
         return self.decoder.beam_search_decode(
             input_batch=dict(
diff --git a/axlearn/common/adapter_torch_test.py b/axlearn/common/adapter_torch_test.py
index 46df266f..3f1e06c6 100644
--- a/axlearn/common/adapter_torch_test.py
+++ b/axlearn/common/adapter_torch_test.py
@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.

 """Tests PyTorch adapter layers."""
+
 # pylint: disable=too-many-lines
 import itertools
 from collections import OrderedDict
@@ -889,9 +890,11 @@ def test_transformer_embeddings_forward(
             jax.random.PRNGKey(0),
             state=axlearn_layer_state,
             inputs=dict(
-                inputs=jnp.asarray(input_ids),
-                token_type_ids=axlearn_token_type_ids,
-                positions=axlearn_positions,
+                input_batch=dict(
+                    inputs=jnp.asarray(input_ids),
+                    token_type_ids=axlearn_token_type_ids,
+                    positions=axlearn_positions,
+                ),
             ),
             is_training=False,
             method="forward",
@@ -971,7 +974,7 @@ def test_decoder_inference(self):
             axlearn_layer,
             jax.random.PRNGKey(0),
             state=axlearn_layer_state,
-            inputs=dict(input_ids=jnp.asarray(input_ids)),
+            inputs=dict(input_batch=dict(input_ids=jnp.asarray(input_ids))),
             is_training=False,
             method="forward",
         )[0]
diff --git a/axlearn/common/attention_test.py b/axlearn/common/attention_test.py
index a42184d1..d7b59f62 100644
--- a/axlearn/common/attention_test.py
+++ b/axlearn/common/attention_test.py
@@ -3445,7 +3445,7 @@ def _test_decoder_with_transformer(self, transformer_cfg: BaseTransformerLayer.C
         oh_indices = jax.nn.one_hot(prefix_length - 1, seq_len, dtype=prefix.dtype)
         prefix = prefix * (1 - oh_indices) + bos_id * oh_indices
         inputs = dict(
-            prefix=prefix,
+            input_batch=dict(prefix=prefix),
             max_sequence_length=seq_len,
             # cross_attention_data=None,
             # cross_attention_logit_biases=None,
diff --git a/axlearn/common/base_encoder_decoder.py b/axlearn/common/base_encoder_decoder.py
index 6b609ace..13ab3639 100644
--- a/axlearn/common/base_encoder_decoder.py
+++ b/axlearn/common/base_encoder_decoder.py
@@ -2,7 +2,6 @@

 """Base Encoder-Decoder model interface."""

-from collections.abc import Sequence
 from typing import Optional

 from axlearn.common.base_layer import BaseLayer
@@ -10,7 +9,7 @@
 from axlearn.common.config import REQUIRED, ConfigOr, Required, config_class
 from axlearn.common.decoding import BeamSearchOutputs, SampleOutputs
 from axlearn.common.logit_modifiers import LogitsToLogitsFn
-from axlearn.common.utils import Nested, Tensor, get_recursively
+from axlearn.common.utils import Nested, Tensor


 class BaseEncoderDecoderModel(BaseModel):
@@ -61,14 +60,6 @@ def predict(self, input_batch: Nested[Tensor]) -> Nested[Tensor]:
         """
         raise NotImplementedError(type(self))

-    def _validate_input_batch(self, input_batch: Nested[Tensor], paths: Sequence[str]):
-        """Raises ValueError if any of the given `paths` are not present in `input_batch`."""
-        for path in paths:
-            try:
-                get_recursively(input_batch, path)
-            except KeyError as e:
-                raise ValueError(f"Input batch is expected to contain '{path}'.") from e
-
     def beam_search_decode(
         self, input_batch: dict[str, Tensor], num_decodes: int, **kwargs
     ) -> BeamSearchOutputs:
diff --git a/axlearn/common/causal_lm.py b/axlearn/common/causal_lm.py
index 7a8dfca4..0513d871 100644
--- a/axlearn/common/causal_lm.py
+++ b/axlearn/common/causal_lm.py
@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.

 """Autoregressive decoder model, e.g. as seen in the GPT family."""
+
 import math
 import re
 from typing import Callable, Optional, Union
@@ -170,7 +171,7 @@ def beam_search_decode(
         with child_context("beam_search_decode", module=self.decoder):
             prefix = input_batch["prefix"]
             return self.decoder.beam_search_decode(
-                prefix=prefix,
+                input_batch=input_batch,
                 max_sequence_length=prefix.shape[-1],
                 num_decodes=num_decodes,
                 brevity_penalty=brevity_penalty,
@@ -203,7 +204,7 @@ def sample_decode(
         with child_context("sample_decode", module=self.decoder):
             prefix = input_batch["prefix"]
             return self.decoder.sample_decode(
-                prefix=prefix,
+                input_batch=input_batch,
                 max_sequence_length=prefix.shape[-1],
                 num_decodes=num_decodes,
                 logits_modifier=logits_modifier,
@@ -280,10 +281,14 @@ def predict(self, input_batch: dict[str, Tensor]) -> dict[str, Tensor]:
         input_positions: Optional[Tensor] = input_batch.get("input_positions")
         # Decoder hidden states: [batch_size, target_len, hidden_dim].
         decoder_output = self.decoder(
-            input_ids=input_ids,
-            token_type_ids=token_type_ids,
-            input_segment_ids=input_segment_ids,
-            positions=input_positions,
+            # TODO(markblee): Simplify by using consistent naming between `input_positions` and
+            # `positions`, `input_segment_ids` and `segment_ids`.
+            input_batch=dict(
+                input_ids=input_ids,
+                token_type_ids=token_type_ids,
+                input_segment_ids=input_segment_ids,
+                positions=input_positions,
+            ),
         )
         return decoder_output

diff --git a/axlearn/common/deberta_test.py b/axlearn/common/deberta_test.py
index 3ae59a4f..4bcfa376 100644
--- a/axlearn/common/deberta_test.py
+++ b/axlearn/common/deberta_test.py
@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.

 """Tests DeBERTa implementation."""
+
 # pylint: disable=no-self-use
 from types import SimpleNamespace
 from typing import Optional
@@ -483,7 +484,7 @@ def test_emb(self, query_len: int, **kwargs):
             is_training=False,
             prng_key=jax.random.PRNGKey(0),
             state=layer_params["encoder"]["emb"],
-            inputs=[input_ids],
+            inputs=dict(input_batch=dict(inputs=input_ids)),
         )
         ref_outputs = hf_layer.embeddings(as_torch_tensor(input_ids))
         self.assertNestedAllClose(test_outputs, ref_outputs)
diff --git a/axlearn/common/decoder.py b/axlearn/common/decoder.py
index 755a377a..eaa6ef6d 100644
--- a/axlearn/common/decoder.py
+++ b/axlearn/common/decoder.py
@@ -47,7 +47,13 @@
     current_context,
     new_output_collection,
 )
-from axlearn.common.utils import Nested, NestedTensor, TensorSpec, with_sharding_constraint
+from axlearn.common.utils import (
+    Nested,
+    NestedTensor,
+    TensorSpec,
+    validate_contains_paths,
+    with_sharding_constraint,
+)


 # TODO(markblee): Remove this when we have a better solution at the decoding loop level.
@@ -134,7 +140,7 @@ def init_states(self, *, batch_size: int, max_sequence_length: int) -> Nested[Te
         """

     def prefill_states(
-        self, *, time_step: Tensor, input_ids: Tensor, **kwargs
+        self, *, time_step: Tensor, input_batch: Nested[Tensor], **kwargs
     ) -> tuple[Nested[Tensor], Nested[Tensor]]:
         """Initializes cache for autoregressive cached decoding.

@@ -145,7 +151,7 @@
             time_step: A Tensor of shape [batch]. Each value is an index into the length
                 dimension indicating where decoding will start from. If `time_step` exceeds
                 `target_length`, reads consume the last token in the sequence, and writes are
                 no-ops.
-            input_ids: An integer Tensor of shape [batch, target_length].
+            input_batch: A nested Tensor. See corresponding implementation for details.
             kwargs: Additional kwargs for prefilling.

         Returns:
@@ -192,7 +198,7 @@ def __init__(self, cfg, *, decoder: BaseDecoder):
     def beam_search_decode(
         self,
         *,
-        prefix: Tensor,
+        input_batch: Nested[Tensor],
         max_sequence_length: int,
         num_decodes: int,
         cross_attention_data: Optional[Tensor] = None,
@@ -202,15 +208,19 @@
         """Perform beam search decoding.

         Args:
-            prefix: The prefix to use for prompting. A Tensor of shape [batch, max_prefix_length].
-                The prefix for each example in the batch should begin with a prompt token (e.g.
-                BOS).
+            input_batch: A dict containing:
+                prefix: The prefix to use for prompting of shape [batch, max_prefix_length].
+                    The prefix for each example in the batch should begin with a prompt token (e.g.
+                    BOS).
+                    The prefix will be padded with `cfg.pad_token_id` to `max_sequence_length`, thus
+                    it is expected that `max_prefix_length <= max_sequence_length`.
             max_sequence_length: The maximum sequence length of tokens to generate.
             num_decodes: The number of decoded sequences to return. These are the number of
                 hypotheses per batch example.
             cross_attention_data: A float Tensor of shape [batch_size, source_len, hidden_dim].
             cross_attention_logit_biases: A Tensor of shape [batch_size, target_len, source_len].
                 A -inf represents a disconnected position pair.
+                `target_len` should be broadcastable to `max_sequence_length`.
             brevity_penalty: Brevity penalty function for length normalization during beam search.

         Returns:
@@ -219,6 +229,9 @@
         Raises:
             ValueError: If pad_token_id is non-zero.
         """
+        validate_contains_paths(input_batch, paths=["prefix"])
+        prefix = input_batch["prefix"]
+
         cfg: DecodingLayer.Config = self.config
         tokens_to_scores_fn = self._tokens_to_scores(
             num_decodes=num_decodes,
@@ -230,9 +243,11 @@
             prefix, max_sequence_length=max_sequence_length, pad_id=cfg.pad_token_id
         )
         time_step = infer_initial_time_step(prefix, pad_id=cfg.pad_token_id)
+        prefill_batch = {**input_batch}
+        prefill_batch["input_ids"] = input_ids
         init_states, _ = self._decoder.prefill_states(
             time_step=time_step,
-            input_ids=input_ids,
+            input_batch=prefill_batch,
             cross_attention_data=cross_attention_data,
             cross_attention_logit_biases=cross_attention_logit_biases,
         )
@@ -250,7 +265,7 @@
     def sample_decode(
         self,
         *,
-        prefix: Tensor,
+        input_batch: Nested[Tensor],
         max_sequence_length: int,
         num_decodes: int,
         cross_attention_data: Optional[Tensor] = None,
@@ -261,15 +276,19 @@
         """Perform sample-based decoding.

         Args:
-            prefix: The prefix to use for prompting. Of shape [batch, max_prefix_length].
-                The prefix for each example in the batch should begin with a prompt token (e.g.
-                BOS).
+            input_batch: A dict containing:
+                prefix: The prefix to use for prompting of shape [batch, max_prefix_length].
+                    The prefix for each example in the batch should begin with a prompt token (e.g.
+                    BOS).
+                    The prefix will be padded with `cfg.pad_token_id` to `max_sequence_length`, thus
+                    it is expected that `max_prefix_length <= max_sequence_length`.
             max_sequence_length: The maximum sequence length of tokens to generate.
             num_decodes: The number of decoded sequences to return. These are the number of
                 hypotheses per batch example.
             cross_attention_data: A float Tensor of shape [batch_size, source_len, hidden_dim].
             cross_attention_logit_biases: A Tensor of shape [batch_size, target_len, source_len].
                 A -inf represents a disconnected position pair.
+                `target_len` should be broadcastable to `max_sequence_length`.
             logits_modifier: Function used to adjust the raw next-token logit distribution
                 values, to e.g. implement top-k/top-p/etc sampling (see `logit_modifiers`).
                 If None, do not modify the logits.
@@ -279,6 +298,9 @@
         Returns:
             The sample decoding outputs.
         """
+        validate_contains_paths(input_batch, paths=["prefix"])
+        prefix = input_batch["prefix"]
+
         cfg: DecodingLayer.Config = self.config
         logits_modifier = maybe_instantiate(logits_modifier)
         tokens_to_scores_fn = self._tokens_to_scores(
@@ -291,9 +313,11 @@
             prefix, max_sequence_length=max_sequence_length, pad_id=cfg.pad_token_id
         )
         time_step = infer_initial_time_step(prefix, pad_id=cfg.pad_token_id)
+        prefill_batch = {**input_batch}
+        prefill_batch["input_ids"] = input_ids
         init_states, init_outputs = self._decoder.prefill_states(
             time_step=time_step,
-            input_ids=input_ids,
+            input_batch=prefill_batch,
             cross_attention_data=cross_attention_data,
             cross_attention_logit_biases=cross_attention_logit_biases,
         )
@@ -467,16 +491,19 @@ def _forward_for_mode(
         self,
         *,
         mode: ForwardMode,
-        input_ids: Tensor,
+        input_batch: Nested[Tensor],
         self_attention_logit_biases: Optional[Tensor],
-        input_segment_ids: Optional[Tensor] = None,
-        token_type_ids: Optional[Tensor] = None,
         cross_attention_data: Optional[Tensor] = None,
         cross_attention_logit_biases: Optional[Tensor] = None,
-        positions: Optional[Tensor] = None,
         cached_states: Optional[NestedTensor] = None,
     ) -> tuple[Optional[NestedTensor], Tensor]:
-        x = self.emb(inputs=input_ids, token_type_ids=token_type_ids, positions=positions)
+        validate_contains_paths(input_batch, paths=["input_ids"])
+        input_segment_ids = input_batch.get("input_segment_ids", None)
+
+        emb_batch = {**input_batch}
+        emb_batch["inputs"] = emb_batch["input_ids"]
+        x = self.emb(input_batch=emb_batch)
+
         if mode == ForwardMode.FORWARD:
             transformer_state, x = (
                 None,
@@ -531,31 +558,30 @@ def forward(
     def forward(
         self,
-        input_ids: Tensor,
+        input_batch: Nested[Tensor],
         *,
-        input_segment_ids: Optional[Tensor] = None,
-        token_type_ids: Optional[Tensor] = None,
         cross_attention_data: Optional[Tensor] = None,
         cross_attention_logit_biases: Optional[Tensor] = None,
-        positions: Optional[Tensor] = None,
+        **kwargs,
     ) -> dict[str, Tensor]:
         """Computes decoder hidden states and logits from input ids and cross attention
         hidden states.

         Args:
-            input_ids: An int Tensor of shape [batch_size, target_len].
-                Values should be in the range [0, vocab_size).
-            input_segment_ids: An optional Tensor of same shape as `input_ids` with values in
-                [0, num_segments]. Tokens are only allowed to attend to other tokens within the same
-                segment. input_segment_ids == 0 represents paddings. If None, inferred from
-                input_ids != pad_token_id.
-            token_type_ids: An optional int Tensor of shape [batch_size, target_len].
-                Values should be in the range [0, type_vocab_size).
+            input_batch: A dict containing:
+                * input_ids: An int Tensor of shape [batch_size, target_len].
+                    Values should be in the range [0, vocab_size).
+                * input_segment_ids: An optional Tensor of same shape as `input_ids` with values in
+                    [0, num_segments]. Tokens are only allowed to attend to other tokens within the
+                    same segment. input_segment_ids == 0 represents paddings. If None, inferred from
+                    input_ids != pad_token_id.
+                * token_type_ids: An optional int Tensor of shape [batch_size, target_len].
+                    Values should be in the range [0, type_vocab_size).
+                * positions: An optional int Tensor of shape [batch_size, target_len].
+                    If None, assumed to be jnp.arange(target_len) for each sequence.
             cross_attention_data: A float Tensor of shape [batch_size, source_len, hidden_dim].
             cross_attention_logit_biases: A Tensor of shape [batch_size, target_len, source_len].
                 A -inf represents a disconnected position pair.
-            positions: An optional int Tensor of shape [batch_size, target_len].
-                If None, assumed to be jnp.arange(target_len) for each sequence.

         Returns:
             A dict containing:
                 hidden_states: A float Tensor of shape [batch_size, target_len, hidden_dim].
@@ -563,19 +589,22 @@ def forward(
                 logits: A float Tensor of shape [batch_size, target_len, num_classes], where
                     num_classes depends on the configured lm_head.
         """
+        validate_contains_paths(input_batch, paths=["input_ids"])
+        input_ids = input_batch["input_ids"]
+        input_segment_ids = input_batch.get("input_segment_ids", None)
+        positions = input_batch.get("positions", None)
+
         _, output = self._forward_for_mode(
             mode=ForwardMode.FORWARD,
-            input_ids=input_ids,
+            input_batch=input_batch,
             # [batch_size, num_heads, seq_len, seq_len].
             self_attention_logit_biases=self.compute_attention_logit_biases(
                 input_ids, segment_ids=input_segment_ids, positions=positions
             ),
-            input_segment_ids=input_segment_ids,
-            token_type_ids=token_type_ids,
             cross_attention_data=cross_attention_data,
             cross_attention_logit_biases=cross_attention_logit_biases,
-            positions=positions,
             cached_states=None,
+            **kwargs,
         )
         if self._output_logits_modifier is not None:
             output["logits"] = self._output_logits_modifier(output["logits"])
@@ -597,31 +626,55 @@ def init_states(self, *, batch_size: int, max_sequence_length: int) -> NestedTen
         )

     def prefill_states(
-        self, *, time_step: Tensor, input_ids: Tensor, **kwargs
-    ) -> tuple[NestedTensor, NestedTensor]:
-        """See `BaseDecoder.prefill_states` for details."""
+        self,
+        *,
+        time_step: Tensor,
+        input_batch: Nested[Tensor],
+        **kwargs,
+    ) -> tuple[Nested[Tensor], Nested[Tensor]]:
+        """See `BaseDecoder.prefill_states` for details.
+
+        Args:
+            time_step: A Tensor of shape [batch_size]. See `BaseDecoder.prefill_states` for details.
+            input_batch: See `forward` for details.
+            kwargs: See `forward` for details.
+
+        Returns:
+            See `BaseDecoder.prefill_states` for details.
+        """
+        validate_contains_paths(input_batch, paths=["input_ids"])
+        input_ids = input_batch["input_ids"]
+        input_segment_ids = input_batch.get("input_segment_ids", None)
+        positions = input_batch.get("positions", None)
+
         states, outputs = self._forward_for_mode(
             mode=ForwardMode.INIT_STATES,
             cached_states=dict(transformer_state=time_step),
-            input_ids=input_ids,
+            input_batch=input_batch,
             # TODO(markblee): Consider supporting packed inputs for more efficient prefilling.
-            self_attention_logit_biases=self.compute_attention_logit_biases(input_ids),
+            self_attention_logit_biases=self.compute_attention_logit_biases(
+                input_ids, segment_ids=input_segment_ids, positions=positions
+            ),
             **kwargs,
         )
         states = dict(time_step=time_step, input_ids=input_ids, **states)
         return states, outputs

     def extend_step(
-        self, *, cached_states: NestedTensor, input_ids: Tensor, **kwargs
-    ) -> tuple[NestedTensor, NestedTensor]:
+        self,
+        *,
+        cached_states: Nested[Tensor],
+        input_ids: Tensor,
+        **kwargs,
+    ) -> tuple[Nested[Tensor], Nested[Tensor]]:
         """See `BaseDecoder.extend_step` for details."""
-        time_step = cached_states["time_step"]
+        time_step: Tensor = cached_states["time_step"]
         assert time_step.ndim == 1
         # Update cached input_ids via "scatter via one-hot broadcast" trick.
         # Note: in the cases where `time_step` exceeds `target_len`, the update becomes a no-op.
         # --> [B, T].
-        cached_inputs = cached_states["input_ids"]
+        cached_inputs: Tensor = cached_states["input_ids"]
         target_len = cached_inputs.shape[-1]
         oh_indices = jax.nn.one_hot(time_step, target_len, dtype=input_ids.dtype)
         updated_inputs = cached_inputs * (1 - oh_indices) + input_ids * oh_indices
@@ -644,11 +697,19 @@
             mode="clip",
         )

+        input_segment_ids = kwargs.pop("input_segment_ids", None)
+        token_type_ids = kwargs.pop("token_type_ids", None)
+        positions = kwargs.pop("positions", jnp.expand_dims(time_step, 1))
+
         updated_states, outputs = self._forward_for_mode(
             mode=ForwardMode.EXTEND_STEP,
-            input_ids=input_ids,
+            input_batch=dict(
+                input_ids=input_ids,
+                input_segment_ids=input_segment_ids,
+                token_type_ids=token_type_ids,
+                positions=positions,
+            ),
             self_attention_logit_biases=self_attention_biases,
-            positions=jnp.expand_dims(time_step, 1),
             cached_states=cached_states,
             **kwargs,
         )
@@ -666,14 +727,14 @@ def extend_step(
     def beam_search_decode(
         self,
         *,
-        prefix: Tensor,
+        input_batch: Nested[Tensor],
         max_sequence_length: int,
         num_decodes: int,
         **kwargs,
     ):
         """See configured `decoding` implementation for details."""
         return self._decoding.beam_search_decode(
-            prefix=prefix,
+            input_batch=input_batch,
             max_sequence_length=max_sequence_length,
             num_decodes=num_decodes,
             **kwargs,
         )
@@ -681,14 +742,15 @@
     def sample_decode(
         self,
-        prefix: Tensor,
+        *,
+        input_batch: Nested[Tensor],
         max_sequence_length: int,
         num_decodes: int,
         **kwargs,
     ):
         """See configured `decoding` implementation for details."""
         return self._decoding.sample_decode(
-            prefix=prefix,
+            input_batch=input_batch,
             max_sequence_length=max_sequence_length,
             num_decodes=num_decodes,
             **kwargs,
         )
diff --git a/axlearn/common/decoder_test.py b/axlearn/common/decoder_test.py
index 6eef860d..73386428 100644
--- a/axlearn/common/decoder_test.py
+++ b/axlearn/common/decoder_test.py
@@ -111,7 +111,7 @@ def test_tied_lm_head_differs_from_untied(self):
         def layer_output(state, layer):
             return functional(
                 layer,
-                inputs=dict(input_ids=inputs),
+                inputs=dict(input_batch=dict(input_ids=inputs)),
                 state=state,
                 is_training=False,
                 prng_key=jax.random.PRNGKey(2),
@@ -200,7 +200,7 @@ def test_causal_attention(
         def layer_output(state, layer):
             return functional(
                 layer,
-                inputs=dict(input_ids=input_ids),
+                inputs=dict(input_batch=dict(input_ids=input_ids)),
                 state=state,
                 is_training=False,
                 prng_key=jax.random.PRNGKey(2),
@@ -221,7 +221,7 @@ def layer_output(state, layer):
         oh_indices = jax.nn.one_hot(prefix_length - 1, source_length, dtype=prefix.dtype)
         prefix = prefix * (1 - oh_indices) + ref_cfg.eos_token_id * oh_indices
         inputs = dict(
-            prefix=prefix,
+            input_batch=dict(prefix=prefix),
             max_sequence_length=source_length,
             num_decodes=2,
         )
@@ -305,7 +305,7 @@ def test_add_tensor_stats(self):

         _, output_collection = functional(
             layer,
-            inputs=dict(input_ids=inputs),
+            inputs=dict(input_batch=dict(input_ids=inputs)),
             state=layer_state,
             is_training=False,
             prng_key=jax.random.PRNGKey(2),
@@ -397,9 +397,11 @@ def test_extend_step(
         forward_outputs, _ = functional(
             layer,
             inputs=dict(
-                input_ids=input_ids,
-                input_segment_ids=jnp.ones_like(input_ids),
-                positions=jnp.arange(input_ids.shape[-1])[None, :],
+                input_batch=dict(
+                    input_ids=input_ids,
+                    input_segment_ids=jnp.ones_like(input_ids),
+                    positions=jnp.arange(input_ids.shape[-1])[None, :],
+                ),
                 cross_attention_data=cross_attention_data,
                 cross_attention_logit_biases=cross_attention_logit_biases,
             ),
@@ -412,7 +414,7 @@
             layer,
             inputs=dict(
                 time_step=time_step,
-                input_ids=input_ids,
+                input_batch=dict(input_ids=input_ids),
                 cross_attention_data=cross_attention_data,
                 cross_attention_logit_biases=cross_attention_logit_biases,
             ),
@@ -565,7 +567,7 @@ def test_decode(

         prefix = prefix * (1 - oh_indices) + bos_id * oh_indices
         inputs = dict(
-            prefix=prefix,
+            input_batch=dict(prefix=prefix),
             max_sequence_length=tgt_len,
             cross_attention_data=cross_attention_data,
             cross_attention_logit_biases=cross_attention_logit_biases,
@@ -682,7 +684,10 @@ def test_output_logits_modifier(self):
         )
         decoder_cfg.set(name="tmp", output_logits_modifier=output_logits_modifier)
         decoder = decoder_cfg.instantiate(parent=None)
-        chex.assert_trees_all_close(decoder(5 * jnp.ones(3)), dict(logits=17 * 5 * jnp.ones(3)))
+        chex.assert_trees_all_close(
+            decoder(input_batch=dict(input_ids=5 * jnp.ones(3))),
+            dict(logits=17 * 5 * jnp.ones(3)),
+        )

     def test_token_scores_match_between_decoded_and_prefix(self):
         """Test that token scores match between sample_decode passes.
@@ -734,7 +739,7 @@ def test_token_scores_match_between_decoded_and_prefix(self):
         outputs_1, _ = functional(
             decoder,
             inputs=dict(
-                prefix=prefix_1,
+                input_batch=dict(prefix=prefix_1),
                 max_sequence_length=target_length,
                 num_decodes=1,
                 # Don't stop decoding until target length is reached
@@ -755,7 +760,7 @@
         outputs_2, _ = functional(
             decoder,
             inputs=dict(
-                prefix=prefix_2,
+                input_batch=dict(prefix=prefix_2),
                 max_sequence_length=target_length,
                 num_decodes=1,
                 # Don't stop decoding until target length is reached
diff --git a/axlearn/common/embedding.py b/axlearn/common/embedding.py
index df7fbd4f..444add6f 100644
--- a/axlearn/common/embedding.py
+++ b/axlearn/common/embedding.py
@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.

 """Embedding layers."""
+
 from typing import Optional

 from jax import numpy as jnp
@@ -9,6 +10,7 @@
 from axlearn.common.config import REQUIRED, InstantiableConfig, Required, config_class
 from axlearn.common.layers import Dropout, Embedding
 from axlearn.common.module import Module, Tensor, child_context
+from axlearn.common.utils import Nested, validate_contains_paths


 class TransformerTextEmbeddings(BaseLayer):
@@ -41,37 +43,39 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]):
             self._add_child("norm", cfg.norm.set(input_dim=cfg.dim))
         self._add_child("dropout", cfg.dropout)

-    def forward(
-        self,
-        inputs: Tensor,
-        *,
-        token_type_ids: Optional[Tensor] = None,
-        positions: Optional[Tensor] = None,
-    ) -> Tensor:
+    def forward(self, input_batch: Nested[Tensor]) -> Tensor:
         """Computes input embeddings with positional embeddings.

         If token_type_ids is provided, we also add input type embeddings.

         Args:
-            inputs: arbitrary input tensor with general shape [batch_size, seq_len, ...] that
-                will be fed directly to `self.token_emb`.
-            token_type_ids: An optional int Tensor of shape [batch_size, seq_len].
-            positions: An optional int Tensor of shape [batch_size, seq_len].
-                If None, assumed to be jnp.arange(seq_len) for each sequence.
+            input_batch: A dict containing:
+                * inputs: An input tensor with general shape [batch_size, seq_len, ...] that will be
+                    fed directly to `self.token_emb`.
+                * token_type_ids: An optional int Tensor of shape [batch_size, seq_len].
+                * positions: An optional int Tensor of shape [batch_size, seq_len].
+                    If None, assumed to be jnp.arange(seq_len) for each sequence.

         Returns:
             A float Tensor of shape [batch_size, seq_len, hidden_dim]
         """
+        validate_contains_paths(input_batch, paths=["inputs"])
+        inputs = input_batch["inputs"]
+        token_type_ids = input_batch.get("token_type_ids", None)
+        positions = input_batch.get("positions", None)
+
+        cfg: TransformerTextEmbeddings.Config = self.config
+
         x = self.token_emb(inputs)
-        if self.config.type_emb is not None:
+        if cfg.type_emb is not None:
             if token_type_ids is None:
                 token_type_ids = jnp.zeros_like(inputs)
             x = x + self.type_emb(token_type_ids)
-        if self.config.pos_emb is not None:
+        if cfg.pos_emb is not None:
             if positions is None:
                 positions = jnp.arange(x.shape[1])
             x += self.pos_emb(positions)
-        if self.config.norm is not None:
+        if cfg.norm is not None:
             x = self.norm(x)
         x = self.dropout(x)
         return x
diff --git a/axlearn/common/embedding_test.py b/axlearn/common/embedding_test.py
index c749fad3..880afd13 100644
--- a/axlearn/common/embedding_test.py
+++ b/axlearn/common/embedding_test.py
@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.

 """Test embedding layers."""
+
 # pylint: disable=no-self-use

 import itertools
@@ -77,7 +78,7 @@ def test_against_hf_bert_embeddings(self, pos_emb_cls: type, use_explicit_positi
         test_hidden_states, ref_hidden_states = self._compute_layer_outputs(
             test_layer=layer,
             ref_layer=ref_layer,
-            test_inputs=test_inputs,
+            test_inputs=dict(input_batch=test_inputs),
             ref_inputs=dict(
                 input_ids=as_torch_tensor(input_ids),
             ),
diff --git a/axlearn/common/encoder.py b/axlearn/common/encoder.py
index 9846edbb..b36734cb 100644
--- a/axlearn/common/encoder.py
+++ b/axlearn/common/encoder.py
@@ -58,6 +58,7 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]):
         self._add_child("output", cfg.output.set(input_dim=cfg.dim))
         self._add_child("attention_mask", cfg.attention_mask)

+    # TODO(markblee): Generalize to support input_batch, similar to Decoder.
     def forward(
         self,
         input_ids: Tensor,
@@ -82,7 +83,9 @@ def forward(
             A Tensor of shape [batch_size, seq_len, hidden_dim].
         """
         # [batch_size, seq_len, hidden_dim].
-        x = self.emb(inputs=input_ids, token_type_ids=token_type_ids, positions=positions)
+        x = self.emb(
+            input_batch=dict(inputs=input_ids, token_type_ids=token_type_ids, positions=positions)
+        )
         # [batch_size, num_heads, seq_len, seq_len].
         attention_logit_biases = self.compute_attention_logit_biases(
             input_ids, segment_ids=input_segment_ids, positions=positions
@@ -231,7 +234,9 @@ def forward(
         batch_size, max_seq_len = input_ids.shape

         # [batch_size, seq_len, hidden_dim].
-        x = self.emb(inputs=input_ids, token_type_ids=token_type_ids, positions=positions)
+        x = self.emb(
+            input_batch=dict(inputs=input_ids, token_type_ids=token_type_ids, positions=positions)
+        )

         # Append optional cls tokens as used in CoCa.
         if cfg.num_cls_tokens > 0:
@@ -282,7 +287,9 @@ def prefill_states(
     ) -> tuple[NestedTensor, NestedTensor]:
         # Note: this follows `Decoder.prefill_states` closely. Refer to that method for details.
         # TODO(markblee): Possibly consolidate some of this with decoder.
-        x = self.emb(input_ids, token_type_ids=token_type_ids, positions=None)
+        x = self.emb(
+            input_batch=dict(inputs=input_ids, token_type_ids=token_type_ids, positions=None)
+        )
         transformer_state, x = self.transformer.init_states(
             time_step=time_step,
             data=x,
@@ -328,7 +335,11 @@ def extend_step(

         # [B, 1, D].
         x = self.emb(
-            input_ids, positions=jnp.expand_dims(time_step, 1), token_type_ids=token_type_ids
+            input_batch=dict(
+                inputs=input_ids,
+                positions=jnp.expand_dims(time_step, 1),
+                token_type_ids=token_type_ids,
+            )
         )
         updated_transformer_state, transformer_data = self.transformer.extend_step(
             cached_states=cached_states["transformer_state"],
diff --git a/axlearn/common/encoder_decoder.py b/axlearn/common/encoder_decoder.py
index 7a8a8001..6c0945d8 100644
--- a/axlearn/common/encoder_decoder.py
+++ b/axlearn/common/encoder_decoder.py
@@ -16,7 +16,7 @@
 from axlearn.common.loss import cross_entropy
 from axlearn.common.metrics import WeightedScalar
 from axlearn.common.module import Module, Tensor, child_context
-from axlearn.common.utils import Nested
+from axlearn.common.utils import Nested, validate_contains_paths


 class EncoderDecoderModel(BaseEncoderDecoderModel):
@@ -48,7 +48,7 @@ def forward(
         return_aux: bool = False,
     ) -> tuple[Tensor, Nested[Tensor]]:
         """See `BaseEncoderDecoderModel` docstring for details."""
-        self._validate_input_batch(input_batch, paths=["source", "target", "target_labels"])
+        validate_contains_paths(input_batch, paths=["source", "target", "target_labels"])
         predict_outputs = self.predict(input_batch)
         loss, aux_metrics = self._metrics(input_batch, predict_outputs=predict_outputs)
         if not predict_outputs.keys().isdisjoint(aux_metrics.keys()):
@@ -100,7 +100,7 @@ def predict(self, input_batch: dict[str, Tensor]) -> dict[str, Tensor]:
         Raises:
             ValueError: If source_segment_ids and target_segment_ids are not provided together.
         """
-        self._validate_input_batch(input_batch, paths=["source/input_ids", "target/input_ids"])
+        validate_contains_paths(input_batch, paths=["source/input_ids", "target/input_ids"])
         source_batch: dict[str, Tensor] = input_batch["source"]
         target_batch: dict[str, Tensor] = input_batch["target"]
         source_segment_ids: Optional[Tensor] = source_batch.get("input_segment_ids")
@@ -122,7 +122,7 @@
         )
         # Decoder hidden states: [batch_size, target_len, hidden_dim].
         decoder_output = self.decoder(
-            **target_batch,
+            target_batch,
             cross_attention_data=encoder_output,
             cross_attention_logit_biases=cross_attention_logit_biases,
         )
@@ -158,8 +158,8 @@ def _metrics(
             aux_outputs: A dict containing auxiliary metrics:
                 per_token_loss: A float Tensor of shape [batch_size, target_len].
         """
-        self._validate_input_batch(input_batch, paths=["target_labels"])
-        self._validate_input_batch(predict_outputs, paths=["logits"])
+        validate_contains_paths(input_batch, paths=["target_labels"])
+        validate_contains_paths(predict_outputs, paths=["logits"])

         logits: Tensor = predict_outputs["logits"]
         target_labels: Tensor = input_batch["target_labels"]
@@ -202,15 +202,14 @@ def beam_search_decode(
         Returns:
             Beam search outputs. See parent docstring for details.
         """
-        self._validate_input_batch(input_batch, paths=["prefix", "source/input_ids"])
-        prefix: Tensor = input_batch["prefix"]
+        validate_contains_paths(input_batch, paths=["prefix", "source/input_ids"])
         input_ids: Tensor = input_batch["source"]["input_ids"]
         encoder_output = self.encoder(input_ids=input_ids)
         cross_attention_logit_biases = self.compute_attention_logit_biases(input_ids)

         with child_context("beam_search_decode", module=self.decoder):
             return self.decoder.beam_search_decode(
-                prefix=prefix,
+                input_batch=input_batch,
                 num_decodes=num_decodes,
                 cross_attention_data=encoder_output,
                 cross_attention_logit_biases=cross_attention_logit_biases,
@@ -241,15 +240,14 @@ def sample_decode(
         Returns:
             Sample decoding outputs. See parent docstring for details.
         """
-        self._validate_input_batch(input_batch, paths=["prefix", "source/input_ids"])
-        prefix: Tensor = input_batch["prefix"]
+        validate_contains_paths(input_batch, paths=["prefix", "source/input_ids"])
         input_ids: Tensor = input_batch["source"]["input_ids"]
         encoder_output = self.encoder(input_ids=input_ids)
         cross_attention_logit_biases = self.compute_attention_logit_biases(input_ids)

         with child_context("sample_decode", module=self.decoder):
             return self.decoder.sample_decode(
-                prefix=prefix,
+                input_batch=input_batch,
                 num_decodes=num_decodes,
                 cross_attention_data=encoder_output,
                 cross_attention_logit_biases=cross_attention_logit_biases,
diff --git a/axlearn/common/encoder_decoder_test.py b/axlearn/common/encoder_decoder_test.py
index 3a114009..95aa3e6a 100644
--- a/axlearn/common/encoder_decoder_test.py
+++ b/axlearn/common/encoder_decoder_test.py
@@ -4,6 +4,7 @@
 import os
 from typing import Literal, Optional
+from unittest import mock

 import jax
 import numpy as np
@@ -272,9 +273,6 @@ def test_decode(
     def test_forward_key_conflict(self):
         # pylint: disable=unused-argument
         class DummyEncoderDecoderModel(EncoderDecoderModel):
-            def _validate_input_batch(self, *args, **kwargs):
-                pass
-
             def predict(self, *args, **kwargs):
                 return dict(x=1)

@@ -286,7 +284,10 @@
             encoder=cfg.encoder, decoder=cfg.decoder
         )
         model = cfg.set(name="test").instantiate(parent=None)
-        with self.assertRaisesRegex(KeyError, "conflict"):
+        with (
+            mock.patch(f"{utils.__name__}.validate_contains_paths"),
+            self.assertRaisesRegex(KeyError, "conflict"),
+        ):  # noqa: F821
             F(
                 model,
                 inputs=dict(
diff --git a/axlearn/common/encoder_test.py b/axlearn/common/encoder_test.py
index 80ec7bc0..b3394f3b 100644
--- a/axlearn/common/encoder_test.py
+++ b/axlearn/common/encoder_test.py
@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.

 """Tests encoder layers."""
+
 # pylint: disable=no-self-use

 from typing import Optional
@@ -159,9 +160,11 @@ def test_embeddings(self, num_segments: int):
             prng_key=jax.random.PRNGKey(123),
             state=parameters_from_torch_layer(self.hf_encoder)["encoder"]["emb"],
             inputs=dict(
-                inputs=source_ids,
-                token_type_ids=source_type_ids,
-                positions=source_positions,
+                input_batch=dict(
+                    inputs=source_ids,
+                    token_type_ids=source_type_ids,
+                    positions=source_positions,
+                ),
             ),
         )
         ref_outputs = self.hf_encoder.embeddings(
diff --git a/axlearn/common/multiway_transformer.py b/axlearn/common/multiway_transformer.py
index cbbe39f5..270a83a8 100644
--- a/axlearn/common/multiway_transformer.py
+++ b/axlearn/common/multiway_transformer.py
@@ -372,7 +372,7 @@ def get_visual_embed(

     def get_text_embed(self, data: Tensor, modality: int) -> Tensor:
         # Same text embeddings as Bert.
-        x = self.text_embed(inputs=data)
+        x = self.text_embed(input_batch=dict(inputs=data))
         # Add modality type embedding.
         x = x + self.modality_emb(jnp.full(x.shape[:2], modality))
         return x
diff --git a/axlearn/common/param_converter_test.py b/axlearn/common/param_converter_test.py
index ad269d67..34e15639 100644
--- a/axlearn/common/param_converter_test.py
+++ b/axlearn/common/param_converter_test.py
@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.

 """Tests param converter utils."""
+
 # pylint: disable=too-many-lines
 import os
 from typing import Any, Callable, Optional
@@ -270,7 +271,7 @@ def test_bert_embeddings(self):
         out, hf_out = self._compute_layer_outputs(
             test_layer=layer,
             ref_layer=hf_layer,
-            test_inputs=[inputs],
+            test_inputs=dict(input_batch=dict(inputs=inputs)),
             ref_inputs=as_torch_tensor(inputs),
         )
         self.assertNestedAllClose(out, hf_out)
@@ -303,7 +304,7 @@ def test_roberta_embeddings(self):
         out, hf_out = self._compute_layer_outputs(
             test_layer=layer,
             ref_layer=hf_layer,
-            test_inputs=[inputs],
+            test_inputs=dict(input_batch=dict(inputs=inputs)),
             ref_inputs=as_torch_tensor(inputs),
         )
         # Compare only at non-padding positions.
diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py
index 2c755a32..1401337f 100644
--- a/axlearn/common/utils.py
+++ b/axlearn/common/utils.py
@@ -1441,3 +1441,15 @@ def sequence_mask(*, lengths: Tensor, max_len: int, dtype: Optional[jnp.dtype] =
     # [..., 1]
     lengths = lengths[..., jnp.newaxis]
     return (sequence < lengths).astype(dtype)
+
+
+def validate_contains_paths(x: Nested[Tensor], paths: Sequence[str]):
+    """Raises ValueError if any of the given `paths` are not present in `x`."""
+    for path in paths:
+        try:
+            get_recursively(x, path)
+        except KeyError as e:
+            raise ValueError(
+                f"Input is expected to contain '{path}'; "
+                f"instead, it contains: '{jax.tree_structure(x)}'."
+            ) from e
diff --git a/axlearn/common/utils_test.py b/axlearn/common/utils_test.py
index 3b631155..f4c06b47 100644
--- a/axlearn/common/utils_test.py
+++ b/axlearn/common/utils_test.py
@@ -2,6 +2,7 @@

 """Tests common utils."""

+import contextlib
 import dataclasses
 import enum
 import sys
@@ -75,6 +76,7 @@
     set_recursively,
     split_prng_key,
     tree_paths,
+    validate_contains_paths,
     validate_float_dtype,
     vectorized_tree_map,
     with_sharding_constraint,
@@ -1777,5 +1779,30 @@ def test_every_other_process(self):
         self.assertNestedEqual(expected, replicate_to_local_data(batch))


+class ValidateContainsPathsTest(TestCase):
+    @parameterized.parameters(
+        # Missing path.
+        dict(
+            x={},
+            paths=["test"],
+            missing="test",
+        ),
+        # OK.
+        dict(x={"test": 123}, paths=["test"], missing=None),
+        # OK.
+        dict(x={"00": {"10": 123}}, paths=["00/10"], missing=None),
+        # Missing '00/11'.
+        dict(x={"00": {"10": 123}}, paths=["00/10", "00/11"], missing="00/11"),
+    )
+    def test_basic(self, x: Nested[Tensor], paths: Sequence[str], missing: Optional[str]):
+        if missing is not None:
+            ctx = self.assertRaisesRegex(ValueError, missing)
+        else:
+            ctx = contextlib.nullcontext()
+
+        with ctx:
+            validate_contains_paths(x, paths=paths)
+
+
 if __name__ == "__main__":
     absltest.main()
diff --git a/axlearn/vision/coca.py b/axlearn/vision/coca.py
index dbab386d..4c3300ff 100644
--- a/axlearn/vision/coca.py
+++ b/axlearn/vision/coca.py
@@ -47,7 +47,7 @@
 from axlearn.common.module import Module
 from axlearn.common.multi_stream_model import FusionNetwork, MultiStreamModel, StreamEncoder
 from axlearn.common.poolings import AttentionPooling, BasePoolingLayer, LastNTokenPooling
-from axlearn.common.utils import NestedTensor, Tensor, TensorSpec
+from axlearn.common.utils import Nested, NestedTensor, Tensor, TensorSpec, validate_contains_paths
 from axlearn.common.vision_transformer import VisionTransformer, layer_norm_config
 from axlearn.vision.clip import CLIPFusionNetwork

@@ -889,11 +889,28 @@ def prefill_states(
         self,
         *,
         time_step: Tensor,
-        input_ids: Tensor,
+        input_batch: Nested[Tensor],
         cross_attention_data: Optional[Tensor] = None,
         cross_attention_logit_biases: Optional[Tensor] = None,
     ) -> tuple[NestedTensor, NestedTensor]:
-        """See `BaseDecoder.prefill_states` for details."""
+        """See `BaseDecoder.prefill_states` for details.
+
+        Args:
+            time_step: A Tensor of shape [batch_size]. See `BaseDecoder.prefill_states` for details.
+            input_batch: A dict containing at minimum:
+                * input_ids: An int Tensor of shape [batch_size, seq_len].
+                    Values should be in the range [0, vocab_size), where `vocab_size` is commonly
+                    configured in `textual_encoder`.
+            cross_attention_data: A float Tensor of shape [batch_size, source_len, hidden_dim].
+            cross_attention_logit_biases: A Tensor of shape [batch_size, target_len, source_len].
+                A -inf represents a disconnected position pair.
+
+        Returns:
+            See `BaseDecoder.prefill_states` for details.
+        """
+        validate_contains_paths(input_batch, paths=["input_ids"])
+        input_ids = input_batch["input_ids"]
+
         textual_encoder_state, textual_encoder_output = self._stream_encoder[
             "textual_encoder"
         ].prefill_states(
@@ -978,14 +995,14 @@ def extend_step(
     def beam_search_decode(
         self,
         *,
-        prefix: Tensor,
+        input_batch: Nested[Tensor],
         max_sequence_length: int,
         num_decodes: int,
         **kwargs,
     ):
         """See configured `decoding` implementation for details."""
         return self._decoding.beam_search_decode(
-            prefix=prefix,
+            input_batch=input_batch,
             max_sequence_length=max_sequence_length,
             num_decodes=num_decodes,
             **kwargs,
         )
@@ -993,14 +1010,15 @@
     def sample_decode(
         self,
-        prefix: Tensor,
+        *,
+        input_batch: Nested[Tensor],
         max_sequence_length: int,
         num_decodes: int,
         **kwargs,
     ):
         """See configured `decoding` implementation for details."""
         return self._decoding.sample_decode(
-            prefix=prefix,
+            input_batch=input_batch,
             max_sequence_length=max_sequence_length,
             num_decodes=num_decodes,
             **kwargs,
         )
@@ -1044,9 +1062,9 @@ def predict_caption(

         if decode_method in ("beam_search_decode", "sample_decode"):
             output = getattr(self, decode_method)(
+                input_batch=input_batch,
                 max_sequence_length=max_sequence_length,
                 num_decodes=num_decodes,
-                prefix=input_batch["prefix"],
                 cross_attention_data=visual_features,
             )
         else:
diff --git a/axlearn/vision/coca_test.py b/axlearn/vision/coca_test.py
index 02d13cbe..8a847139 100644
--- a/axlearn/vision/coca_test.py
+++ b/axlearn/vision/coca_test.py
@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.

 """Tests CoCa implementations."""
+
 # pylint: disable=no-self-use

 from copy import deepcopy
@@ -553,7 +554,7 @@ def test_extend_step(self, use_cross_attention: bool, prefill_states: bool):
                 prng_key=jax.random.PRNGKey(456),
                 inputs=dict(
                     time_step=time_step,
-                    input_ids=tokenized_text,
+                    input_batch=dict(input_ids=tokenized_text),
                     cross_attention_data=cross_attention_data,
                 ),
                 method="prefill_states",
diff --git a/axlearn/vision/param_converter_test.py b/axlearn/vision/param_converter_test.py
index 3c1e358f..5ce77d6f 100644
--- a/axlearn/vision/param_converter_test.py
+++ b/axlearn/vision/param_converter_test.py
@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.

 """Tests vision param converter utils."""
+
 import jax
 import jax.numpy as jnp
 from absl.testing import parameterized
@@ -105,7 +106,7 @@ def test_clip_text_embeddings(self):
         out, hf_out = self._compute_layer_outputs(
             test_layer=layer,
             ref_layer=hf_layer,
-            test_inputs=[inputs],
+            test_inputs=dict(input_batch=dict(inputs=inputs)),
             ref_inputs=as_torch_tensor(inputs),
         )
         # Compare only at non-padding positions.
diff --git a/axlearn/vision/virtex.py b/axlearn/vision/virtex.py
index a057f183..49409de7 100644
--- a/axlearn/vision/virtex.py
+++ b/axlearn/vision/virtex.py
@@ -10,6 +10,7 @@
 https://arxiv.org/abs/2006.06666
 """
+
 from typing import Optional, Union

 import jax
@@ -24,7 +25,7 @@
 from axlearn.common.layers import Linear
 from axlearn.common.metrics import WeightedScalar
 from axlearn.common.module import Module, child_context
-from axlearn.common.utils import NestedTensor, get_recursively, tree_paths
+from axlearn.common.utils import NestedTensor, get_recursively, tree_paths, validate_contains_paths
 from axlearn.common.vision_transformer import VisionTransformer as ViTModel
 from axlearn.common.vision_transformer import named_model_configs as vit_named_model_configs
 from axlearn.vision.resnet import ResNet
@@ -235,7 +236,8 @@ def forward(
         # Decode caption.
         decoder_ids, decoder_labels = caption_tokens[:, :-1], caption_tokens[:, 1:]
         predictions: dict[str, Tensor] = self.textual(
-            decoder_ids, cross_attention_data=projected_visual_features
+            input_batch=dict(input_ids=decoder_ids),
+            cross_attention_data=projected_visual_features,
         )
         metrics = self._metrics(predictions["logits"], decoder_labels)
@@ -266,9 +268,9 @@ def caption(
         projected_visual_features = self.embed_image(image)
         with child_context("beam_search_decode", module=self.textual):
             output: BeamSearchOutputs = self.textual.beam_search_decode(
+                input_batch=dict(prefix=prefix),
                 max_sequence_length=max_sequence_length,
                 num_decodes=num_decodes,
-                prefix=prefix,
                 cross_attention_data=projected_visual_features,
             )
         return output
@@ -294,6 +296,7 @@ def beam_search_decode(
         Returns:
             The beam search outputs.
         """
+        validate_contains_paths(input_batch, paths=["image", "prefix"])
         return self.caption(
             image=input_batch["image"],
             prefix=input_batch["prefix"],
diff --git a/axlearn/vision/virtex_test.py b/axlearn/vision/virtex_test.py
index 0bee26e0..cf1b9476 100644
--- a/axlearn/vision/virtex_test.py
+++ b/axlearn/vision/virtex_test.py
@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.

 """Tests VirTex implementations."""
+
 # pylint: disable=no-self-use

 from collections.abc import Sequence