From fc761b0b6e0937f1e6fe4aaae11a53a9275af7b9 Mon Sep 17 00:00:00 2001 From: Mark Lee Date: Tue, 26 Nov 2024 19:51:23 -0500 Subject: [PATCH] Unify init and prefill for attention layers. (#860) * Unify init and prefill for attention layers. * Fix some types and docstrings. --- axlearn/common/attention.py | 622 +++++++++---------- axlearn/common/attention_test.py | 141 ++++- axlearn/common/decoder.py | 12 +- axlearn/common/encoder.py | 16 +- axlearn/common/flash_attention/layer_test.py | 17 +- axlearn/common/lora_test.py | 10 +- axlearn/common/multiway_transformer.py | 45 +- axlearn/common/multiway_transformer_test.py | 10 +- axlearn/common/ssm.py | 267 ++++---- axlearn/common/ssm_test.py | 20 +- axlearn/vision/coca.py | 15 +- 11 files changed, 644 insertions(+), 531 deletions(-) diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py index 33785edf3..b1a9a7e3b 100644 --- a/axlearn/common/attention.py +++ b/axlearn/common/attention.py @@ -46,7 +46,6 @@ * Tokens are only allowed to attend to other tokens within the same segment. * segment_ids == 0 represents paddings. * None represents an all-one tensor, i.e. all positions are in the same segment. - """ # pylint: disable=abstract-method,too-many-lines @@ -105,6 +104,7 @@ NestedTensor, PartitionSpec, Tensor, + TensorSpec, VDict, check_numerics, flatten_items, @@ -225,37 +225,19 @@ def forward( def init_states( self, *, - target_batch_size: int, - target_max_len: int, - self_attention_kv_state: Optional[KVState] = None, - ) -> NestedTensor: - """Initializes cached states for incremental computation. - - Args: - target_batch_size: The batch size for target sequences. - target_max_len: The maximum number of tokens in a target sequence. - self_attention_kv_state: An optional KVState used for self-attention. - - Returns: - A nested tree of Tensors, which can be used as `cached_states` for the initial call - of `extend_step()`. - """ - raise NotImplementedError(type(self)) - - def prefill_states( - self, - *, - time_step: Tensor, - data: Tensor, + time_step: Optional[Tensor], + data: Union[Tensor, TensorSpec], self_attention_kv_state: Optional[KVState] = None, self_attention_logit_biases: Optional[Tensor] = None, cross_attention_data: Optional[Tensor] = None, cross_attention_logit_biases: Optional[Tensor] = None, - ) -> tuple[NestedTensor, Output]: + ) -> tuple[Nested[Tensor], Optional[Output]]: """Initializes cached states for incremental computation. - TODO(markblee): Rename to init_states once we add support for decoding at non-zero time - step. + The method supports initializing an empty cache as well as prefilling: + * To initialize an empty cache, specify `time_step=None`. + In this case, `data` is allowed to be a TensorSpec. + * To prefill, provide `time_step` and `data` as Tensors. Args: time_step: A Tensor of shape [batch]. Each value is an index into the length dimension @@ -270,11 +252,14 @@ def prefill_states( biases. Returns: - A nested tree of Tensors, which can be used as `cached_states` for the initial call - of `extend_step()`. - A BaseTransformerLayer.Output instance, where .data is of the same shape as `data`, - .self_attention_probs is of shape [batch, num_heads, target_length, target_length], and - .cross_attention_probs is of shape [batch, num_heads, target_length, source_length]. + A tuple (init_states, output): + * init_states: A nested tree of Tensors, which can be used as `cached_states` for the + initial call of `extend_step()`. + * output: In the prefill case, a BaseTransformerLayer.Output instance, where: + .data is of the same shape as `data`; + .self_attention_probs is of shape [batch, num_heads, target_length, target_length]; + .cross_attention_probs is of shape [batch, num_heads, target_length, source_length]. + Otherwise, if initializing cache from scratch, output will be None. """ raise NotImplementedError(type(self)) @@ -716,30 +701,86 @@ def num_kv_heads(self): def init_states( self, *, - target_batch_size: int, - target_max_len: int, + time_step: Optional[Tensor], + query: Union[Tensor, TensorSpec], + key: Optional[Tensor] = None, + value: Optional[Tensor] = None, kv_state: Optional[KVState] = None, - ) -> NestedTensor: - cfg = self.config + ) -> tuple[Nested[Tensor], Optional[Output]]: + """Initializes cache for autoregressive cached decoding. + + The method supports initializing an empty cache as well as prefilling: + * To initialize an empty cache, specify `time_step=None`. + In this case, `query` is allowed to be a TensorSpec. + * To prefill, provide `time_step` and `query` as Tensors. + + Args: + time_step: An optional Tensor of shape [batch]. Each value is an index into the length + dimension indicating where decoding will start from. + query: A Tensor or TensorSpec of shape [batch, target_length, target_dim] corresponding + to query vector at `time_step` indices. + For batch index `i`, only `query[i, :time_step[i], ...]` will affect subsequent + decoding. + key: An optional Tensor of shape [batch, source_length, source_dim]. + If None, will use `query`. + value: An optional Tensor of shape [batch, source_length, source_dim]. + If None, will use `query`. + kv_state: An optional KVState. If not None, both key and value must be None. + + Returns: + A tuple (init_states, output): + * init_states: A Nested Tensor state of `key`, `value` of shape + [batch, num_heads, per_head_dim, source_length], and `time_step` of shape [batch]. + * output: In the prefill case, an Output instance, where query is of size + [batch, target_length, num_heads, per_head_dim] and each of key, value are of dim + [batch, source_length, num_heads, per_head_dim]. + Otherwise, if initializing cache from scratch, output will be None. + + Raises: + ValueError: If key/value and kv_state are an invalid combination. + ValueError: If query and time_step are an invalid combination. + """ + cfg: BaseQKVLinear.Config = self.config # Default to base layer dtype for initialization if cache_dtype is None. dtype = cfg.cache_dtype or self.dtype() assert dtype is not None - cache = dict(time_step=jnp.zeros(target_batch_size, dtype=jnp.int32)) + if kv_state is not None and (key is not None or value is not None): + raise ValueError("kv_state should not be specified together with key/value.") + if time_step is not None and isinstance(query, TensorSpec): + raise ValueError("query must be a Tensor if time_step is provided.") + + output = None + # Always initialize to all 0's; if `time_step` is provided, we invoke `extend_step` below + # which updates the cache with the new `time_step`. + init_state = dict(time_step=jnp.zeros(query.shape[0], dtype=jnp.int32)) + # If `kv_state` is provided externally, we do not have to maintain key/value in cache. + # Otherwise, initialize the cache from provided query, key, value. if kv_state is None: - cache.update( - key=jnp.zeros( - shape=(target_batch_size, target_max_len, self.num_kv_heads, cfg.per_head_dim), - dtype=dtype, - ), - value=jnp.zeros( - shape=(target_batch_size, target_max_len, self.num_kv_heads, cfg.per_head_dim), - dtype=dtype, - ), + + def maybe_initialize(kv: Optional[Tensor]): + # [batch, source/target_len, num_kv_heads, per_head_dim]. + if kv is None: + kv = jnp.zeros( + (*query.shape[:2], self.num_kv_heads, cfg.per_head_dim), dtype=dtype + ) + else: + kv = jnp.reshape(kv, (*kv.shape[:2], self.num_kv_heads, cfg.per_head_dim)) + return kv + + init_state.update(key=maybe_initialize(key), value=maybe_initialize(value)) + + # If time_step is not provided, initialize an empty cache (i.e., all 0's). + # Otherwise, treat as prefill case and invoke `extend_step`. + if time_step is not None: + init_state, output = self.extend_step( + init_state, query, key=key, value=value, kv_state=kv_state ) - # TODO(sneha,markblee): Add sharding annotations for all elements in the cache. - return cache + # The time_step from `extend_step` includes full query length. + init_state["time_step"] = time_step + + return init_state, output def forward( self, @@ -768,68 +809,6 @@ def forward( """ raise NotImplementedError(type(self)) - def prefill_states( - self, - *, - time_step: Tensor, - query: Tensor, - key: Optional[Tensor] = None, - value: Optional[Tensor] = None, - kv_state: Optional[KVState] = None, - ) -> tuple[NestedTensor, Output]: - """Initializes cache for autoregressive cached decoding. - - TODO(markblee): Rename to init_states once we add support for decoding at non-zero time - step. - - Args: - time_step: A Tensor of shape [batch]. Each value is an index into the length dimension - indicating where decoding will start from. - query: Tensor of shape [batch, target_length, target_dim] corresponding to query vector - at `time_step` indices. For batch index `i`, only `query[i, :time_step[i], ...]` - will affect subsequent decoding. - key: An optional Tensor of shape [batch, source_length, source_dim]. If None, will use - `query`. - value: An optional Tensor of shape [batch, source_length, source_dim]. If None, will - use `query`. - kv_state: An optional KVState. If not None, both key and value must be None. - - Returns: - A `NestedTensor` state of `key`, `value` of shape - [batch, num_heads, per_head_dim, source_length], and `time_step` of shape [batch]. - An Output instance, where query is of size - [batch, target_length, num_heads, per_head_dim] and each of key, value are of dim - [batch, source_length, num_heads, per_head_dim]. - """ - cfg = self.config - # Default to base layer dtype for initialization if cache_dtype is None. - dtype = cfg.cache_dtype or self.dtype() - assert dtype is not None - - if kv_state is not None: - if key is not None or value is not None: - raise ValueError("kv_state should not be specified together with key/value") - kv_kwargs = dict(kv_state=kv_state) - else: - kv_kwargs = dict(key=key, value=value) - # In the prefill state, the time_step filtering is not provided in the QKV forward function, - # but in the time_step_mask defined below. - # Therefore, time_step argument for the forward is set as None. - q_proj, k_proj, v_proj = self.forward(query, **kv_kwargs) - - init_state = dict(time_step=time_step) - # If external kv_state is provided, we don't need to maintain key/value in cached_state. - if kv_state is None: - # Zero-out everything from time_step onwards. - # Being able to assume that non-filled cache values are 0 allows us to do a slightly - # more efficient update to `cached_{key,value}` in `extend_step`, by doing a simple add - # instead of a mask + add. - time_step_mask = (jnp.arange(k_proj.shape[1]) < time_step[:, None])[..., None, None] - k_proj = k_proj * time_step_mask - v_proj = v_proj * time_step_mask - init_state.update(key=k_proj.astype(dtype), value=v_proj.astype(dtype)) - return init_state, self.Output(query=q_proj, key=k_proj, value=v_proj) - def extend_step( self, cached_states: NestedTensor, @@ -1761,7 +1740,7 @@ def _forward_for_mode( self, *, mode: ForwardMode, - query: Tensor, + query: Union[Tensor, TensorSpec], key: Optional[Tensor] = None, value: Optional[Tensor] = None, kv_state: Optional[KVState] = None, @@ -1769,7 +1748,7 @@ def _forward_for_mode( segment_ids: Optional[Tensor] = None, cached_states: Optional[NestedTensor] = None, return_aux: Optional[set[str]] = None, - ) -> tuple[Optional[NestedTensor], Output]: + ) -> tuple[Nested[Tensor], Optional[Output]]: """Computes attention for the given query, key, value, and attention logit biases. If key and value are both None, computes self-attention using query. @@ -1777,19 +1756,21 @@ def _forward_for_mode( Args: mode: Configures whether `cached_states` are consumed or emitted. See `ForwardMode` for details. - query: A Tensor of shape [batch, target_length, target_dim]. + query: A Tensor or TensorSpec of shape [batch, target_length, target_dim]. key: An optional Tensor of shape [batch, source_length, source_dim]. value: An optional Tensor of shape [batch, source_length, source_dim]. kv_state: An optional KVState. If specified, both `key` and `value` should be None. attention_logit_biases: See ``On attention logit biases`` in the file comments. segment_ids: See ``On segment_ids`` in the file comments. - cached_states: Optional NestedTensor as produced by `prefill_states`. + cached_states: Optional NestedTensor as produced by `init_states`. return_aux: See comments on `Output`. Returns: - An optional NestedTensor of cache states, depending on `mode`. - An Output instance, where .data is of the same shape as query and .probs is of shape - [batch, num_heads, target_length, source_length]. + A tuple (cached_states, output): + * cached_states: An optional NestedTensor of cache states, depending on `mode`. + * output: An optional Output instance, where .data is of the same shape as query and + .probs is of shape [batch, num_heads, target_length, source_length]. + If initializing cache from scratch, output will be None. Raises: ValueError: If key & value are an invalid combination. @@ -1809,20 +1790,25 @@ def _forward_for_mode( kv_kwargs = dict(key=key, value=value) if mode == ForwardMode.FORWARD: - i_proj_state, (q_proj, k_proj, v_proj) = None, self.i_proj(query, **kv_kwargs) + i_proj_state, i_proj_output = None, self.i_proj(query, **kv_kwargs) elif mode == ForwardMode.INIT_STATES: assert cached_states is not None - i_proj_state, (q_proj, k_proj, v_proj) = self.i_proj.prefill_states( + i_proj_state, i_proj_output = self.i_proj.init_states( time_step=cached_states["i_proj"], query=query, **kv_kwargs ) elif mode == ForwardMode.EXTEND_STEP: assert cached_states is not None - i_proj_state, (q_proj, k_proj, v_proj) = self.i_proj.extend_step( + i_proj_state, i_proj_output = self.i_proj.extend_step( cached_states["i_proj"], query, **kv_kwargs ) else: raise ValueError(f"Unrecognized mode {mode}.") + if i_proj_output is None: + assert mode == ForwardMode.INIT_STATES + return dict(i_proj=i_proj_state), None + + q_proj, k_proj, v_proj = i_proj_output kv_state = KVState(k_proj=k_proj, v_proj=v_proj) q_proj = self._remat_name(q_proj, "q_proj") k_proj = self._remat_name(k_proj, "k_proj") @@ -2022,50 +2008,27 @@ def _compute_context(self, probs: Tensor, v_proj: Tensor) -> Tensor: def init_states( self, *, - target_batch_size: int, - target_max_len: int, - kv_state: Optional[KVState] = None, - ) -> NestedTensor: - """Initializes cache for autoregressive cached decoding. - - Args: - target_batch_size: The batch size of the target to be decoded. - target_max_len: The sequence length of the target to be decoded. - kv_state: An optional KVState. - - Returns: - The cache as a `NestedTensor` with key and value initialized. - """ - return dict( - i_proj=self.i_proj.init_states( - target_batch_size=target_batch_size, - target_max_len=target_max_len, - kv_state=kv_state, - ) - ) - - def prefill_states( - self, - *, - time_step: Tensor, - query: Tensor, + time_step: Optional[Tensor], + query: Union[Tensor, TensorSpec], key: Optional[Tensor] = None, value: Optional[Tensor] = None, kv_state: Optional[KVState] = None, attention_logit_biases: Optional[Tensor], return_aux: Optional[set[str]] = None, - ) -> tuple[NestedTensor, Output]: + ) -> tuple[Nested[Tensor], Optional[Output]]: """Initializes cache for autoregressive cached decoding. - TODO(markblee): Rename to init_states once we add support for decoding at non-zero time - step. + The method supports initializing an empty cache as well as prefilling: + * To initialize an empty cache, specify `time_step=None`. + In this case, `query` is allowed to be a TensorSpec. + * To prefill, provide `time_step` and `query` as Tensors. Args: - time_step: A Tensor of shape [B]. Each value is an index into the length dimension + time_step: A Tensor of shape [batch]. Each value is an index into the length dimension indicating where decoding will start from. - query: Tensor of shape [B, T, D] corresponding to query projection input vector - up to `time_step`. For batch index `i`, only `query[i, :time_step[i], ...]` - will affect subsequent decoding. + query: A Tensor or TensorSpec of shape [batch, target_length, target_dim] corresponding + to query projection input vector up to `time_step`. For batch index `i`, only + `query[i, :time_step[i], ...]` will affect subsequent decoding. key: Same description as `query`, but for the key projection input vector. Key and value have to both be tensors or both be None. If they are tensors, key and value are used as the unique input to the @@ -2077,9 +2040,12 @@ def prefill_states( return_aux: See comments on `Output`. Returns: - A `NestedTensor` state of key and value pair along with index updated at `time_step`. - An Output instance, where .data is of the same shape as query and .probs is of shape - [batch, num_heads, target_length, source_length]. + A tuple (init_states, output): + * init_states: A Nested Tensor state of key and value pair along with index updated at + `time_step`. + * output: In the prefill case, an Output instance, where .data is of the same shape as + query and .probs is of shape [batch, num_heads, target_length, source_length]. + Otherwise, if initializing cache from scratch, output will be None. """ return self._forward_for_mode( mode=ForwardMode.INIT_STATES, @@ -2601,30 +2567,32 @@ def _forward_for_mode( self, *, mode: ForwardMode, - target: Tensor, + target: Union[Tensor, TensorSpec], source: Optional[Union[Tensor, KVState]] = None, attention_logit_biases: Optional[Tensor] = None, segment_ids: Optional[Tensor] = None, cached_states: Optional[NestedTensor] = None, return_aux: Optional[set[str]] = None, - ) -> tuple[Optional[NestedTensor], Output]: + ) -> tuple[Optional[Nested[Tensor]], Optional[Output]]: """Computes either self-attention or cross-attention for the given target and source. Args: mode: Configures whether `cached_states` are consumed or emitted. See `ForwardMode` for details. - target: A Tensor of shape [batch, target_length, target_dim]. + target: A Tensor or TensorSpec of shape [batch, target_length, target_dim]. source: An optional KVState or Tensor of shape [batch, source_length, source_dim]. If None, uses norm(target) as source (self-attention). attention_logit_biases: See ``On attention logit biases`` in the file comments. segment_ids: segment_ids: See ``On segment_ids`` in the file comments. - cached_states: Optional NestedTensor as produced by `prefill_states`. + cached_states: Optional NestedTensor as produced by `init_states`. return_aux: See comments on `Output`. Returns: - An optional NestedTensor of cache states, depending on `mode`. - An Output instance, where .data is of the same shape as query and .probs is of shape - [batch, num_heads, target_length, source_length]. + A tuple (cached_states, output): + * cached_states: An optional Nested Tensor of cache states, depending on `mode`. + * output: An optional Output instance, where .data is of the same shape as query and + .probs is of shape [batch, num_heads, target_length, source_length]. + If initializing cache from scratch, output will be None. Raises: ValueError: If `mode` is unsupported. @@ -2655,7 +2623,7 @@ def attention_thunk(target: Tensor) -> tuple[Optional[NestedTensor], Tensor]: ) elif mode == ForwardMode.INIT_STATES: assert cached_states is not None - atten_state, atten_output = self.attention.prefill_states( + atten_state, atten_output = self.attention.init_states( time_step=cached_states["attention"], query=target, **kv_kwargs, @@ -2673,6 +2641,12 @@ def attention_thunk(target: Tensor) -> tuple[Optional[NestedTensor], Tensor]: raise ValueError(f"Unrecognized mode {mode}.") return atten_state, atten_output + if mode == ForwardMode.INIT_STATES: + assert cached_states is not None + if cached_states["attention"] is None: + atten_state, atten_output = attention_thunk(TensorSpec(target.shape, target.dtype)) + return dict(attention=atten_state), atten_output + if cfg.structure == "prenorm": skip_input = target # pre-norm: where normalization happens within the residual part. norm_target = self.norm(target) @@ -2736,41 +2710,18 @@ def forward( def init_states( self, *, - target_batch_size: int, - target_max_len: int, - kv_state: Optional[KVState] = None, - ) -> NestedTensor: - """Initializes cache for autoregressive cached decoding. - - Args: - target_batch_size: The batch size of the target to be decoded. - target_max_len: The sequence length of the target to be decoded. - kv_state: An optional KVState. - - Returns: - The cache as a `NestedTensor` with key and value initialized. - """ - return dict( - attention=self.attention.init_states( - target_batch_size=target_batch_size, - target_max_len=target_max_len, - kv_state=kv_state, - ) - ) - - def prefill_states( - self, - *, - time_step: NestedTensor, - target: Tensor, + time_step: Optional[Tensor], + target: Union[Tensor, TensorSpec], source: Optional[Union[Tensor, KVState]] = None, attention_logit_biases: Optional[Tensor] = None, return_aux: Optional[set[str]] = None, - ) -> tuple[NestedTensor, Output]: + ) -> tuple[Nested[Tensor], Optional[Output]]: """Initializes cache for autoregressive cached decoding. - TODO(markblee): Rename to init_states once we add support for decoding at non-zero time - step. + The method supports initializing an empty cache as well as prefilling: + * To initialize an empty cache, specify `time_step=None`. + In this case, `target` is allowed to be a TensorSpec. + * To prefill, provide `time_step` and `target` as Tensors. Args: time_step: A Tensor of shape [batch]. Each value is an index into the length dimension @@ -2784,9 +2735,11 @@ def prefill_states( return_aux: See comments on `Output`. Returns: - A `NestedTensor` state depending on the `attention` layer implementation. - An Output instance, where .data is of the same shape as query, .probs is of shape - [batch, num_heads, target_length, source_length]. + A tuple (init_states, output): + * init_states: A Nested Tensor state depending on the `attention` layer implementation. + * output: In the prefill case, an Output instance, where .data is of the same shape as + query, .probs is of shape [batch, num_heads, target_length, source_length]. + Otherwise, if initializing cache from scratch, output will be None. """ return self._forward_for_mode( mode=ForwardMode.INIT_STATES, @@ -2805,7 +2758,7 @@ def extend_step( source: Optional[Union[Tensor, KVState]] = None, attention_logit_biases: Optional[Tensor] = None, return_aux: Optional[set[str]] = None, - ) -> tuple[NestedTensor, Output]: + ) -> tuple[Nested[Tensor], Output]: """Computes the value vector given the query of the current step. This function is used by autoregressive decoding. @@ -2831,7 +2784,7 @@ def extend_step( Raises: NotImplementedError: If cfg.structure is not supported. """ - return self._forward_for_mode( + return self._forward_for_mode( # pytype: disable=bad-return-type mode=ForwardMode.EXTEND_STEP, target=target, source=source, @@ -3128,7 +3081,7 @@ def _forward_for_mode( self, *, mode: ForwardMode, - data: Tensor, + data: Union[Tensor, TensorSpec], self_attention_kv_state: Optional[KVState] = None, self_attention_logit_biases: Optional[Tensor] = None, cross_attention_data: Optional[Tensor] = None, @@ -3136,32 +3089,35 @@ def _forward_for_mode( target_segment_ids: Optional[Tensor] = None, cached_states: Optional[NestedTensor] = None, return_aux: Optional[set[str]] = None, - ) -> tuple[Optional[NestedTensor], BaseTransformerLayer.Output]: + ) -> tuple[Optional[NestedTensor], Optional[BaseTransformerLayer.Output]]: """Computes transformer layer outputs and self/cross-attention probabilities. Args: mode: Configures whether `cached_states` are consumed or emitted. See `ForwardMode` for details. - data: A Tensor of shape [batch, target_length, target_dim]. + data: A Tensor or TensorSpec of shape [batch, target_length, target_dim]. self_attention_kv_state: An optional KVState used for self-attention. self_attention_logit_biases: An optional Tensor representing the self-attention biases. cross_attention_data: An optional Tensor of shape [batch, source_length, source_dim]. cross_attention_logit_biases: An optional Tensor representing the cross-attention biases. target_segment_ids: See ``segment_ids`` in the file comments. - cached_states: Optional NestedTensor as produced by `prefill_states`. + cached_states: Optional NestedTensor as produced by `init_states`. return_aux: See comments on BaseTransformerLayer.forward. Returns: - An optional NestedTensor of cache states, depending on `mode`. - An Output instance, where .data is of the same shape as `data`, .self_attention_probs is - of shape [batch, num_heads, target_length, target_length], and .cross_attention_probs is - of shape [batch, num_heads, target_length, source_length]. + A tuple (cached_states, output): + * cached_states: An optional Nested Tensor of cache states, depending on `mode`. + * output: An optional Output instance, where .data is of the same shape as `data`, + .self_attention_probs is of shape [batch, num_heads, target_length, target_length]; + .cross_attention_probs is of shape [batch, num_heads, target_length, source_length]. + If initializing cache from scratch, output will be None. Raises: ValueError: If `mode` is unsupported. """ - self.vlog(3, "transformer.input=%s", data.sum()) + if isinstance(data, Tensor): + self.vlog(3, "transformer.input=%s", data.sum()) # pytype: disable=attribute-error self_attention_return_aux = set() cross_attention_return_aux = set() if return_aux: @@ -3186,7 +3142,7 @@ def _forward_for_mode( assert cached_states is not None if target_segment_ids is not None: raise NotImplementedError("target_segment_ids is not supported in INIT_STATES.") - self_atten_state, self_atten_outputs = self.self_attention.prefill_states( + self_atten_state, self_atten_outputs = self.self_attention.init_states( time_step=cached_states["self_attention"], target=data, source=self_attention_kv_state, @@ -3206,6 +3162,11 @@ def _forward_for_mode( ) else: raise ValueError(f"Unrecognized mode {mode}.") + + if self_atten_outputs is None: + assert mode == ForwardMode.INIT_STATES + return dict(self_attention=self_atten_state), self_atten_outputs + data = self_atten_outputs.data self.vlog(3, "self_attention.output=%s", data.sum()) if cross_attention_data is not None: @@ -3237,35 +3198,16 @@ def forward( **kwargs, ) -> BaseTransformerLayer.Output: _, output = self._forward_for_mode( - mode=ForwardMode.FORWARD, - data=data, - cached_states=None, - **kwargs, + mode=ForwardMode.FORWARD, data=data, cached_states=None, **kwargs ) return output def init_states( self, - *, - target_batch_size: int, - target_max_len: int, - self_attention_kv_state: Optional[KVState] = None, - ) -> NestedTensor: - return dict( - self_attention=self.self_attention.init_states( - target_batch_size=target_batch_size, - target_max_len=target_max_len, - kv_state=self_attention_kv_state, - ) - ) - - def prefill_states( - self, - *, - time_step: Tensor, - data: Tensor, + time_step: Optional[Tensor], + data: Union[Tensor, TensorSpec], **kwargs, - ) -> tuple[NestedTensor, BaseTransformerLayer.Output]: + ) -> tuple[Nested[Tensor], Optional[BaseTransformerLayer.Output]]: return self._forward_for_mode( mode=ForwardMode.INIT_STATES, cached_states=dict(self_attention=time_step), @@ -3279,7 +3221,7 @@ def extend_step( data: Tensor, **kwargs, ) -> tuple[NestedTensor, BaseTransformerLayer.Output]: - return self._forward_for_mode( + return self._forward_for_mode( # pytype:disable=bad-return-type mode=ForwardMode.EXTEND_STEP, cached_states=cached_states, data=data, @@ -3407,33 +3349,36 @@ def _forward_for_mode( self, *, mode: ForwardMode, - data: Tensor, + data: Union[Tensor, TensorSpec], cached_states: Optional[NestedTensor] = None, **kwargs, - ) -> tuple[Optional[NestedTensor], Tensor]: + ) -> tuple[Optional[Nested[Tensor]], Optional[Tensor]]: """Computes transformer layer outputs and self/cross-attention probabilities. Args: mode: Configures whether `cached_states` are consumed or emitted. See `ForwardMode` for details. data: A Tensor of shape [batch, target_length, target_dim]. - cached_states: Optional NestedTensor as produced by `prefill_states`. + cached_states: Optional NestedTensor as produced by `init_states`. Returns: - An optional NestedTensor of cache states, depending on `mode`. - An Output instance, where .data is of the same shape as `data`, .self_attention_probs is - of shape [batch, num_heads, target_length, target_length], and .cross_attention_probs is - of shape [batch, num_heads, target_length, source_length]. + A tuple (cached_states, output): + * cached_states: An optional NestedTensor of cache states, depending on `mode`. + * output: An Output instance, where .data is of the same shape as `data`; + .self_attention_probs is of shape [batch, num_heads, target_length, target_length]; + .cross_attention_probs is of shape [batch, num_heads, target_length, source_length]. + If initializing cache from scratch, output will be None. Raises: ValueError: If `mode` is unsupported. """ - self.vlog(3, "transformer.input=%s", data.sum()) + if isinstance(data, Tensor): + self.vlog(3, "transformer.input=%s", data.sum()) # pytype: disable=attribute-error if mode == ForwardMode.FORWARD: output = self.layer.forward(data=data, **kwargs) elif mode == ForwardMode.INIT_STATES: assert cached_states is not None - cached_states, output = self.layer.prefill_states( + cached_states, output = self.layer.init_states( time_step=cached_states["layer"], data=data, **kwargs, @@ -3447,6 +3392,11 @@ def _forward_for_mode( ) else: raise ValueError(f"Unrecognized mode {mode}.") + + if output is None: + assert mode == ForwardMode.INIT_STATES and cached_states["layer"] is None + return cached_states, output + skip_input = output.data data = self.adapter(output.data) data += skip_input @@ -3466,16 +3416,13 @@ def forward( ) return output - def init_states(self, **kwargs) -> NestedTensor: - return dict(layer=self.layer.init_states(**kwargs)) - - def prefill_states( + def init_states( self, *, - time_step: Tensor, - data: Tensor, + time_step: Optional[Tensor], + data: Union[Tensor, TensorSpec], **kwargs, - ) -> tuple[NestedTensor, BaseTransformerLayer.Output]: + ) -> tuple[Nested[Tensor], Optional[BaseTransformerLayer.Output]]: return self._forward_for_mode( mode=ForwardMode.INIT_STATES, cached_states=dict(layer=time_step), @@ -3489,7 +3436,7 @@ def extend_step( data: Tensor, **kwargs, ) -> tuple[NestedTensor, BaseTransformerLayer.Output]: - return self._forward_for_mode( + return self._forward_for_mode( # pytype: disable=bad-return-type mode=ForwardMode.EXTEND_STEP, cached_states=cached_states, data=data, @@ -3674,39 +3621,44 @@ def _forward_for_mode( self, *, mode: ForwardMode, - data: Tensor, - cached_states: Optional[NestedTensor] = None, + data: Union[Tensor, TensorSpec], + cached_states: Optional[Nested[Tensor]] = None, **layer_kwargs, - ) -> tuple[list[Optional[NestedTensor]], TransformerLayer.Output]: + ) -> tuple[list[Optional[Nested[Tensor]]], Optional[TransformerLayer.Output]]: """Computes transformer stack outputs. Args: mode: Configures whether `cached_states` are consumed or emitted. See `ForwardMode` for details. - data: A Tensor of shape [batch, target_length, target_dim]. - cached_states: Optional NestedTensor as produced by `prefill_states`. + data: A Tensor or TensorSpec of shape [batch, target_length, target_dim]. + cached_states: Optional Nested Tensor as produced by `init_states`. Returns: - (updated_cache_states, outputs), where - updated_cached_states is an optional NestedTensor of cache states, depending on `mode`; - outputs is an instance of Output (see comments on BaseStackedTransformerLayer). + A tuple (updated_cache_states, outputs): + * updated_cached_states: An optional NestedTensor of cache states, depending on `mode`; + * outputs: An optional instance of Output (see comments on BaseStackedTransformerLayer). Raises: ValueError: If `mode` is unsupported. """ all_layer_outputs = [] all_layer_states = [] + + # True iff we are initializing an empty cache (i.e., not prefilling). + cache_init = mode == ForwardMode.INIT_STATES and cached_states is None + for i, layer in enumerate(self._layers): # Prepare inputs to the current layer. if self._update_data is not None: data = self._update_data(data, all_layer_outputs) + # TODO(markblee): Consider folding into _update_data. self._update_layer_kwargs(layer_kwargs, all_layer_outputs=all_layer_outputs) if mode == ForwardMode.FORWARD: layer_states, layer_outputs = None, layer(data, **layer_kwargs) elif mode == ForwardMode.INIT_STATES: - assert cached_states is not None - layer_states, layer_outputs = layer.prefill_states( + # cached_states is allowed to be None in the case where we initialize from scratch. + layer_states, layer_outputs = layer.init_states( time_step=cached_states, data=data, **layer_kwargs, @@ -3720,11 +3672,36 @@ def _forward_for_mode( ) else: raise ValueError(f"Unrecognized mode {mode}.") - all_layer_outputs.append(layer_outputs) + all_layer_states.append(layer_states) + + # If initializing the cache from scratch, layer_outputs will be None. Further, `data` + # can be effectively treated as a TensorSpec, and thus does not need to be carried + # across layers. + if layer_outputs is None: + assert cache_init + continue + + all_layer_outputs.append(layer_outputs) data = layer_outputs.data - return all_layer_states, self._aggregate_layer_outputs(all_layer_outputs) + outputs = None if cache_init else self._aggregate_layer_outputs(all_layer_outputs) + return all_layer_states, outputs + + def init_states( + self, + *, + time_step: Optional[Tensor], + data: Union[Tensor, TensorSpec], + **layer_kwargs, + ) -> tuple[list[Nested[Tensor]], Optional[TransformerLayer.Output]]: + """See `BaseTransformerLayer.init_states` for details.""" + return self._forward_for_mode( + mode=ForwardMode.INIT_STATES, + cached_states=time_step, + data=data, + **layer_kwargs, + ) def _update_layer_kwargs( self, @@ -3771,31 +3748,13 @@ def forward( ) return output - def init_states(self, *args: Any, **kwargs: Any) -> NestedTensor: - # TODO(sneha): any better ds? - return [layer.init_states(*args, **kwargs) for layer in self._layers] - - def prefill_states( - self, - *, - time_step: Tensor, - data: Tensor, - **layer_kwargs, - ) -> tuple[list[NestedTensor], TransformerLayer.Output]: - return self._forward_for_mode( - mode=ForwardMode.INIT_STATES, - cached_states=time_step, - data=data, - **layer_kwargs, - ) - def extend_step( self, cached_states: list[NestedTensor], data: Tensor, **layer_kwargs, - ) -> tuple[list[NestedTensor], TransformerLayer.Output]: - return self._forward_for_mode( + ) -> tuple[list[Nested[Tensor]], TransformerLayer.Output]: + return self._forward_for_mode( # pytype: disable=bad-return-type mode=ForwardMode.EXTEND_STEP, cached_states=cached_states, data=data, @@ -3823,27 +3782,31 @@ def _forward_for_mode( self, *, mode: ForwardMode, - data: Tensor, - cached_states: Optional[NestedTensor] = None, + data: Union[Tensor, TensorSpec], + cached_states: Optional[Nested[Tensor]] = None, **layer_kwargs, - ) -> tuple[Optional[NestedTensor], TransformerLayer.Output]: + ) -> tuple[Optional[Nested[Tensor]], Optional[TransformerLayer.Output]]: """Computes transformer stack outputs. Args: mode: Configures whether `cached_states` are consumed or emitted. See `ForwardMode` for details. data: A Tensor of shape [batch, target_length, target_dim]. - cached_states: Optional NestedTensor as produced by `prefill_states`. + cached_states: Optional Nested Tensor as produced by `init_states`. + layer_kwargs: Additional kwargs to each layer. Returns: - (updated_cache_states, outputs), where - updated_cached_states is an optional NestedTensor of cache states, depending on `mode`; - outputs is an instance of Output (see comments on BaseStackedTransformerLayer). + A tuple (updated_cache_states, outputs): + * updated_cached_states: An optional NestedTensor of cache states, depending on `mode`; + * outputs: An optional instance of Output (see comments on BaseStackedTransformerLayer). Raises: ValueError: If `mode` is unsupported. """ - cfg = self.config + cfg: _TransformerRepeat.Config = self.config + + # True iff we are initializing an empty cache (i.e., not prefilling). + cache_init = mode == ForwardMode.INIT_STATES and cached_states is None if cached_states is not None: for path, value in flatten_items(cached_states): @@ -3853,25 +3816,29 @@ def layer_fn(carry, x_i): if mode == ForwardMode.FORWARD: layer_states, layer_outputs = None, self.layer(**carry, **layer_kwargs) elif mode == ForwardMode.INIT_STATES: - assert x_i is not None - layer_states, layer_outputs = self.layer.prefill_states( - time_step=x_i, - **carry, - **layer_kwargs, + # Note that x_i can be None if initializing an empty cache. This corresponds to the + # case where `cached_states=None`. + layer_states, layer_outputs = self.layer.init_states( + time_step=x_i, **carry, **layer_kwargs ) elif mode == ForwardMode.EXTEND_STEP: assert x_i is not None layer_states, layer_outputs = self.layer.extend_step( - cached_states=x_i, - **carry, - **layer_kwargs, + cached_states=x_i, **carry, **layer_kwargs ) else: raise ValueError(f"Unrecognized mode {mode}.") - ys = {k: v for k, v in layer_outputs._asdict().items() if k not in carry} + ys = {} if layer_states is not None: ys["cached_states"] = layer_states + + # If initializing the cache from scratch, layer_outputs will be None. + if layer_outputs is None: + assert cache_init + return carry, ys + + ys.update({k: v for k, v in layer_outputs._asdict().items() if k not in carry}) return {k: getattr(layer_outputs, k) for k in carry}, ys if cfg.carry is None: @@ -3879,11 +3846,16 @@ def layer_fn(carry, x_i): else: layer_kwargs["data"] = data carry = {k: layer_kwargs.pop(k) for k in cfg.carry} + repeat_outputs: Repeat.Output = self._run(layer_fn, carry=carry, xs=cached_states) carry = repeat_outputs.carry ys = repeat_outputs.ys updated_states = ys.pop("cached_states", None) + if cache_init: + assert ys == {} + return updated_states, None + for k in ("data", "self_attention_kv_state"): if k in carry: continue @@ -3910,26 +3882,31 @@ def forward( ) return output - def init_states(self, *args: Any, **kwargs: Any) -> NestedTensor: - cfg = self.config - - def layer_fn(_): - return self.layer.init_states(*args, **kwargs) - - return jax.vmap(layer_fn)(jnp.empty(cfg.num_layers)) - - def prefill_states( + def init_states( self, *, - time_step: Tensor, - data: Tensor, + time_step: Optional[Tensor], + data: Union[Tensor, TensorSpec], **layer_kwargs, - ) -> tuple[NestedTensor, TransformerLayer.Output]: - cfg = self.config + ) -> tuple[Nested[Tensor], Optional[TransformerLayer.Output]]: + cfg: _TransformerRepeat.Config = self.config + # time_step is allowed to be None if initializing an empty cache. + if time_step is not None: + time_step = jnp.tile(time_step, [cfg.num_layers, 1]) + + # In the repeat case, scan requires a Tensor rather than ShapeDtypeStruct. + # Use vmap rather than materializing the Tensor. + if isinstance(data, TensorSpec): + + def layer_fn(_): + return self.layer.init_states(time_step=time_step, data=data, **layer_kwargs) + + return jax.vmap(layer_fn)(jnp.empty(cfg.num_layers)) + return self._forward_for_mode( mode=ForwardMode.INIT_STATES, data=data, - cached_states=jnp.tile(time_step, [cfg.num_layers, 1]), + cached_states=time_step, **layer_kwargs, ) @@ -3939,7 +3916,7 @@ def extend_step( data: Tensor, **layer_kwargs, ) -> tuple[NestedTensor, TransformerLayer.Output]: - return self._forward_for_mode( + return self._forward_for_mode( # pytype: disable=bad-return-type mode=ForwardMode.EXTEND_STEP, data=data, cached_states=cached_states, @@ -3988,20 +3965,9 @@ def forward( ) -> TransformerLayer.Output: return self.repeat(data, **layer_kwargs) - def init_states(self, *args: Any, **kwargs: Any) -> NestedTensor: - return VDict(repeat=self.repeat.init_states(*args, **kwargs)) - - def prefill_states( - self, - *, - time_step: Tensor, - data: Tensor, - **layer_kwargs, - ) -> tuple[list[NestedTensor], TransformerLayer.Output]: - repeat_cached_states, output = self.repeat.prefill_states( - time_step=time_step, data=data, **layer_kwargs - ) - return VDict(repeat=repeat_cached_states), output + def init_states(self, *args, **kwargs): + cached_states, output = self.repeat.init_states(*args, **kwargs) + return VDict(repeat=cached_states), output def extend_step( self, diff --git a/axlearn/common/attention_test.py b/axlearn/common/attention_test.py index 692180b25..a42184d14 100644 --- a/axlearn/common/attention_test.py +++ b/axlearn/common/attention_test.py @@ -116,6 +116,7 @@ Nested, PartitionSpec, Tensor, + TensorSpec, VDict, as_tensor, flatten_items, @@ -1472,7 +1473,10 @@ def test_repeated_extend_step(self, layer_cls: type[attention.BaseQKVLinear], ex inputs=dict(query=query), ) - cache_state = layer.init_states(target_batch_size=batch_size, target_max_len=tgt_len) + cache_state, init_output = layer.init_states( + time_step=None, query=TensorSpec([batch_size, tgt_len]) + ) + self.assertIsNone(init_output) step_querys = [] step_keys = step_values = None for t in range(0, tgt_len, extend_step_len): @@ -1531,18 +1535,19 @@ def __init__(self, cfg: Config, *, parent: Module): qkv_linear = parent.qkv_linear state = qkv_linear.initialize_parameters_recursively(jax.random.PRNGKey(0)) - # Check dtypes from init_states - cache, _ = F( + # Check dtypes from init_states. + (cache, init_output), _ = F( qkv_linear, prng_key=jax.random.PRNGKey(0), state=state, inputs=dict( - target_batch_size=target_batch_size, - target_max_len=target_max_len, + time_step=None, + query=TensorSpec([target_batch_size, target_max_len]), ), method="init_states", is_training=False, ) + self.assertIsNone(init_output) self.assertEqual(cache["key"].dtype, dtype) self.assertEqual(cache["value"].dtype, dtype) @@ -1562,7 +1567,7 @@ def __init__(self, cfg: Config, *, parent: Module): prng_key=jax.random.PRNGKey(0), state=state, inputs=dict(time_step=time_step, query=query), - method="prefill_states", + method="init_states", is_training=False, ) self.assertEqual(init_state["key"].dtype, dtype) @@ -2448,9 +2453,14 @@ def _test_extend_step( inputs=inputs, ) - initial_state = layer.init_states( - target_batch_size=batch_size, target_max_len=tgt_len, kv_state=kv_state + initial_state, initial_output = layer.init_states( + time_step=None, + query=TensorSpec([batch_size, tgt_len]), + kv_state=kv_state, + # This is unused for initializing state from scratch. + attention_logit_biases=None, ) + self.assertIsNone(initial_output) if kv_state is None: for k in ["key", "value"]: # Check that the cache dtype is inferred as the layer dtype. @@ -2619,7 +2629,7 @@ def _test_prefill_states( attention_logit_biases=attention_logit_biases, return_aux=return_aux, ), - method="prefill_states", + method="init_states", ) # Check time_step and shapes of state. @@ -3227,6 +3237,96 @@ def test_multihead_attention_xl(self): ) +class TransformerAttentionLayerTest(TestCase): + @parameterized.parameters([False, True]) + def test_forward_vs_extend_step(self, with_source: bool): + init_prng, target_prng, source_prng = jax.random.split(jax.random.PRNGKey(0), 3) + + model_dim = 8 + layer_kwargs = dict(target_dim=model_dim, source_dim=model_dim) + cfg: TransformerAttentionLayer.Config = TransformerAttentionLayer.default_config().set( + **layer_kwargs + ) + cfg.attention.set(num_heads=2, mask=causal_mask) + layer: TransformerAttentionLayer = cfg.set(name="test").instantiate(parent=None) + layer_params = layer.initialize_parameters_recursively(prng_key=init_prng) + + batch, decode_len = 2, 6 + target = jax.random.uniform(target_prng, shape=[batch, decode_len, model_dim]) + input_kwargs = {} + + if with_source: + input_kwargs.update( + source=jax.random.uniform(source_prng, shape=[batch, decode_len, model_dim]) + ) + + forward_outputs, _ = F( + layer, + inputs=dict(target=jnp.asarray(target), **input_kwargs), + state=layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(0), + ) + + for start_time_step in (-1, 0, 2, decode_len): + if start_time_step < 0: + (cached_states, init_outputs), _ = F( + layer, + inputs=dict( + time_step=None, + target=TensorSpec(target.shape, target.dtype), + **input_kwargs, + ), + state=layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(0), + method="init_states", + ) + self.assertIsNone(init_outputs) + data = jnp.zeros([batch, decode_len, model_dim]) + start_time_step = 0 + else: + (cached_states, prefill_outputs), _ = F( + layer, + inputs=dict( + time_step=jnp.array([start_time_step] * batch, dtype=jnp.int32), + target=target, + **input_kwargs, + ), + state=layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(0), + method="init_states", + ) + data = prefill_outputs.data + + data = jnp.einsum("btd->tbd", data) + + for time_step in range(start_time_step, decode_len): + extend_kwargs = {} + for k, v in input_kwargs.items(): + extend_kwargs[k] = jnp.asarray(v[:, time_step : time_step + 1, :]) + + (cached_states, extend_outputs), _ = F( + layer, + inputs=dict( + target=jnp.asarray(target[:, time_step : time_step + 1, :]), + cached_states=cached_states, + **extend_kwargs, + ), + state=layer_params, + is_training=True, + prng_key=jax.random.PRNGKey(0), + method="extend_step", + ) + data = data.at[time_step].set(jnp.squeeze(extend_outputs.data, axis=1)) + + data = jnp.einsum("tbd->btd", data) + + # Prefill + extend_step == forward. + assert_allclose(forward_outputs.data, data) + + class TransformerFeedForwardLayerTest(TestCase): @parameterized.parameters( dict(rms_norm_summary=[]), @@ -3392,13 +3492,13 @@ def _test_forward_vs_extend_step( for start_time_step in (-1, 0, 2, tgt_len): if start_time_step > tgt_len: continue - print(f"start_time_step={start_time_step}") + print(f"start_time_step={start_time_step} layer={type(layer)}") if start_time_step < 0: - cached_states, _ = F( + (cached_states, init_outputs), _ = F( layer, inputs=dict( - target_batch_size=batch_size, - target_max_len=tgt_len, + time_step=None, + data=TensorSpec([batch_size, tgt_len]), **input_kwargs, ), state=layer_params, @@ -3406,6 +3506,7 @@ def _test_forward_vs_extend_step( prng_key=jax.random.PRNGKey(0), method="init_states", ) + self.assertIsNone(init_outputs) decoder_output = jnp.zeros_like(target) start_time_step = 0 else: @@ -3419,7 +3520,7 @@ def _test_forward_vs_extend_step( state=layer_params, is_training=True, prng_key=jax.random.PRNGKey(0), - method="prefill_states", + method="init_states", ) decoder_output = prefill_outputs.data # Transpose to [tgt_len, batch_size, model_dim]. @@ -3850,7 +3951,7 @@ def test_transformer_extend_step(self, transformer_type, layer_type): batch_size, src_len, tgt_len = 10, 4, 6 num_dec_layers, model_dim, num_heads = 3, 16, 4 - cfg = transformer_type.default_config().set( + cfg: BaseStackedTransformerLayer.Config = transformer_type.default_config().set( name="test", input_dim=model_dim, num_layers=num_dec_layers, @@ -3872,7 +3973,7 @@ def test_transformer_extend_step(self, transformer_type, layer_type): layer_cfg.feed_forward.hidden_dim = model_dim * 4 # Instantiate transformer stack. - layer = cfg.instantiate(parent=None) + layer: BaseStackedTransformerLayer = cfg.instantiate(parent=None) layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) target = jax.random.normal(jax.random.PRNGKey(123), [batch_size, tgt_len, model_dim]) @@ -3897,7 +3998,11 @@ def test_transformer_extend_step(self, transformer_type, layer_type): is_training=False, prng_key=jax.random.PRNGKey(0), ) - initial_state = layer.init_states(target_batch_size=batch_size, target_max_len=tgt_len) + initial_state, initial_output = layer.init_states( + time_step=None, + data=TensorSpec([batch_size, tgt_len]), + ) + self.assertIsNone(initial_output) inputs = dict( cached_states=initial_state, cross_attention_data=source, return_aux=return_aux ) @@ -4036,7 +4141,7 @@ def test_transformer_prefill_states(self, transformer_type, layer_type): cross_attention_logit_biases=cross_attention_logit_biases, return_aux=return_aux, ), - method="prefill_states", + method="init_states", ) # Zero-out outputs starting from initial time_step, and test that we can recover the full diff --git a/axlearn/common/decoder.py b/axlearn/common/decoder.py index 941347e5c..755a377a7 100644 --- a/axlearn/common/decoder.py +++ b/axlearn/common/decoder.py @@ -47,7 +47,7 @@ current_context, new_output_collection, ) -from axlearn.common.utils import Nested, NestedTensor, with_sharding_constraint +from axlearn.common.utils import Nested, NestedTensor, TensorSpec, with_sharding_constraint # TODO(markblee): Remove this when we have a better solution at the decoding loop level. @@ -492,7 +492,7 @@ def _forward_for_mode( assert cached_states is not None if input_segment_ids is not None: raise ValueError("input_segment_ids is not supported in INIT_STATES.") - transformer_state, x = self.transformer.prefill_states( + transformer_state, x = self.transformer.init_states( time_step=cached_states["transformer_state"], data=x, self_attention_logit_biases=self_attention_logit_biases, @@ -584,10 +584,12 @@ def forward( def init_states(self, *, batch_size: int, max_sequence_length: int) -> NestedTensor: """See `BaseDecoder.init_states` for details.""" cfg: Decoder.Config = self.config + init_state, _ = self.transformer.init_states( + time_step=None, + data=TensorSpec([batch_size, max_sequence_length, cfg.dim]), + ) return dict( - transformer_state=self.transformer.init_states( - target_batch_size=batch_size, target_max_len=max_sequence_length - ), + transformer_state=init_state, input_ids=jnp.full( (batch_size, max_sequence_length), cfg.pad_token_id, dtype=jnp.int32 ), diff --git a/axlearn/common/encoder.py b/axlearn/common/encoder.py index 0dd53defc..9846edbb9 100644 --- a/axlearn/common/encoder.py +++ b/axlearn/common/encoder.py @@ -1,6 +1,7 @@ # Copyright © 2023 Apple Inc. """Encoder layers.""" + import math from typing import Optional @@ -20,7 +21,7 @@ from axlearn.common.embedding import TransformerTextEmbeddings from axlearn.common.layers import BaseClassificationHead, set_dropout_rate_recursively from axlearn.common.module import Module, Tensor, child_context -from axlearn.common.utils import NestedTensor +from axlearn.common.utils import NestedTensor, TensorSpec class Encoder(BaseLayer): @@ -167,12 +168,15 @@ def init_states(self, *, batch_size: int, max_sequence_length: int) -> NestedTen Returns: The cache as a `NestedTensor` with key and value initialized. """ + cfg: CausalEncoder.Config = self.config + init_state, _ = self.transformer.init_states( + time_step=None, + data=TensorSpec([batch_size, max_sequence_length, cfg.dim]), + ) return dict( - transformer_state=self.transformer.init_states( - target_batch_size=batch_size, target_max_len=max_sequence_length - ), + transformer_state=init_state, input_ids=jnp.full( - (batch_size, max_sequence_length), self.config.pad_token_id, dtype=jnp.int32 + (batch_size, max_sequence_length), cfg.pad_token_id, dtype=jnp.int32 ), time_step=jnp.zeros(batch_size, dtype=jnp.int32), ) @@ -279,7 +283,7 @@ def prefill_states( # Note: this follows `Decoder.prefill_states` closely. Refer to that method for details. # TODO(markblee): Possibly consolidate some of this with decoder. x = self.emb(input_ids, token_type_ids=token_type_ids, positions=None) - transformer_state, x = self.transformer.prefill_states( + transformer_state, x = self.transformer.init_states( time_step=time_step, data=x, self_attention_logit_biases=self.compute_attention_logit_biases(input_ids), diff --git a/axlearn/common/flash_attention/layer_test.py b/axlearn/common/flash_attention/layer_test.py index 52d1bc4ec..e04d5e3de 100644 --- a/axlearn/common/flash_attention/layer_test.py +++ b/axlearn/common/flash_attention/layer_test.py @@ -38,6 +38,7 @@ from axlearn.common.module import Module from axlearn.common.module import functional as F from axlearn.common.test_utils import TestCase, is_supported_mesh_shape +from axlearn.common.utils import TensorSpec def _fake_inputs( @@ -650,12 +651,20 @@ def test_extend_step( ) # Prepare initial states. - initial_state = test_layer.init_states( - target_batch_size=batch, target_max_len=seq_len, kv_state=kv_state + initial_state, initial_output = test_layer.init_states( + time_step=None, + query=TensorSpec([batch, seq_len]), + kv_state=kv_state, + attention_logit_biases=None, ) - ref_initial_state = ref_layer.init_states( - target_batch_size=batch, target_max_len=seq_len, kv_state=kv_state + ref_initial_state, ref_inital_output = ref_layer.init_states( + time_step=None, + query=TensorSpec([batch, seq_len]), + kv_state=kv_state, + attention_logit_biases=None, ) + self.assertIsNone(initial_output) + self.assertIsNone(ref_inital_output) for k in ["key", "value"]: self.assertEqual(ref_initial_state["i_proj"][k].dtype, dtype) self.assertEqual(initial_state["i_proj"][k].dtype, dtype) diff --git a/axlearn/common/lora_test.py b/axlearn/common/lora_test.py index 6e367e045..02cf95847 100644 --- a/axlearn/common/lora_test.py +++ b/axlearn/common/lora_test.py @@ -26,7 +26,7 @@ from axlearn.common.module import functional as F from axlearn.common.param_converter import as_torch_tensor from axlearn.common.test_utils import TestCase, assert_allclose -from axlearn.common.utils import Tensor +from axlearn.common.utils import Tensor, TensorSpec class LoraLinearTest(TestCase): @@ -233,9 +233,11 @@ def test_extend_step(self, layer): q_proj, k_proj, v_proj = outputs forward_outputs = jnp.stack([q_proj, k_proj, v_proj]) - initial_cache_state = layer.init_states( - target_batch_size=batch_size, target_max_len=seq_len + initial_cache_state, init_output = layer.init_states( + time_step=None, + query=TensorSpec([batch_size, seq_len]), ) + self.assertIsNone(init_output) decoder_inputs = dict(cached_states=initial_cache_state) decoder_outputs = jnp.zeros(shape=[seq_len, 3, batch_size, num_heads, per_head_dim]) @@ -305,7 +307,7 @@ def test_prefill_states(self): is_training=False, prng_key=jax.random.PRNGKey(456), inputs=dict(time_step=time_step, query=inputs), - method="prefill_states", + method="init_states", ) time_step_mask = jnp.arange(seq_len) < time_step[:, None] # [batch, tgt_len, num_heads, per_head_dim]. diff --git a/axlearn/common/multiway_transformer.py b/axlearn/common/multiway_transformer.py index 79790daea..cbbe39f58 100644 --- a/axlearn/common/multiway_transformer.py +++ b/axlearn/common/multiway_transformer.py @@ -13,7 +13,8 @@ https://arxiv.org/pdf/2111.02358.pdf https://github.com/microsoft/unilm/tree/master/vlmo """ -from typing import Optional + +from typing import Optional, Union import numpy as np from jax import numpy as jnp @@ -35,6 +36,7 @@ from axlearn.common.module import Module, NestedTensor, Tensor, child_context from axlearn.common.param_init import PARAM_REGEXP_WEIGHT, GaussianInitializer from axlearn.common.poolings import BasePoolingLayer, FirstNTokenPooling +from axlearn.common.utils import Nested, TensorSpec from axlearn.common.vision_transformer import VisualEmbedding TEXT_MODALITY = 0 @@ -94,7 +96,7 @@ def _forward_for_mode( cross_attention_logit_biases: Optional[Tensor] = None, cached_states: Optional[NestedTensor] = None, return_aux: Optional[set[str]] = None, - ) -> tuple[Optional[NestedTensor], Tensor]: + ) -> tuple[Optional[Nested[Tensor]], Optional[Tensor]]: """Computes transformer layer outputs and self/cross-attention probabilities. Args: @@ -118,9 +120,9 @@ def _forward_for_mode( Raises: ValueError: If `mode` is unsupported. """ - cfg = self.config - self.vlog(3, "transformer.input=%s", data.sum()) + if isinstance(data, Tensor): + self.vlog(3, "transformer.input=%s", data.sum()) self_attention_return_aux = set() cross_attention_return_aux = set() if return_aux: @@ -131,13 +133,16 @@ def _forward_for_mode( if "cross_attention_probs" in return_aux: cross_attention_return_aux.add("probs") if mode == ForwardMode.FORWARD: - self_atten_state, self_atten_outputs = None, self.self_attention( - target=data, - attention_logit_biases=self_attention_logit_biases, - return_aux=self_attention_return_aux, + self_atten_state, self_atten_outputs = ( + None, + self.self_attention( + target=data, + attention_logit_biases=self_attention_logit_biases, + return_aux=self_attention_return_aux, + ), ) elif mode == ForwardMode.INIT_STATES: - self_atten_state, self_atten_outputs = self.self_attention.prefill_states( + self_atten_state, self_atten_outputs = self.self_attention.init_states( time_step=cached_states["self_attention"], target=data, attention_logit_biases=self_attention_logit_biases, @@ -152,6 +157,11 @@ def _forward_for_mode( ) else: raise ValueError(f"Unrecognized mode {mode}.") + + if self_atten_outputs is None: + assert mode == ForwardMode.INIT_STATES + return dict(self_attention=self_atten_state), None + data = self_atten_outputs.data self.vlog(3, "self_attention.output=%s", data.sum()) if cross_attention_data is not None: @@ -204,22 +214,13 @@ def forward( ) return output - def init_states(self, *, target_batch_size: int, target_max_len: int) -> NestedTensor: - return dict( - self_attention=self.self_attention.init_states( - target_batch_size=target_batch_size, target_max_len=target_max_len - ) - ) - - # pylint: disable-next=arguments-differ - def prefill_states( + def init_states( self, - *, - time_step: Tensor, - data: Tensor, + time_step: Optional[Tensor], + data: Union[Tensor, TensorSpec], feed_forward_index: int = 0, **kwargs, - ) -> tuple[NestedTensor, Output]: + ) -> tuple[Nested[Tensor], Optional[Output]]: return self._forward_for_mode( mode=ForwardMode.INIT_STATES, cached_states=dict(self_attention=time_step), diff --git a/axlearn/common/multiway_transformer_test.py b/axlearn/common/multiway_transformer_test.py index 13ace94fa..16aeba858 100644 --- a/axlearn/common/multiway_transformer_test.py +++ b/axlearn/common/multiway_transformer_test.py @@ -1,6 +1,7 @@ # Copyright © 2023 Apple Inc. """Tests Multiway transformer layers.""" + # pylint: disable=no-member,no-self-use,duplicate-code import jax import jax.numpy as jnp @@ -27,7 +28,7 @@ _set_model_config, ) from axlearn.common.test_utils import assert_allclose -from axlearn.common.utils import VDict, as_tensor, count_model_params +from axlearn.common.utils import TensorSpec, VDict, as_tensor, count_model_params from axlearn.vision import mask_generator @@ -145,7 +146,10 @@ def test_transformer_extend_step(self): is_training=False, prng_key=jax.random.PRNGKey(0), ) - initial_state = layer.init_states(target_batch_size=batch_size, target_max_len=tgt_len) + initial_state, initial_output = layer.init_states( + time_step=None, data=TensorSpec([batch_size, tgt_len]) + ) + self.assertIsNone(initial_output) inputs = dict( cached_states=initial_state, cross_attention_data=source, return_aux=return_aux ) @@ -262,7 +266,7 @@ def test_prefill_states(self, transformer_type): cross_attention_logit_biases=cross_attention_logit_biases, return_aux=return_aux, ), - method="prefill_states", + method="init_states", ) # Zero-out outputs starting from initial time_step, and test that we can recover the full diff --git a/axlearn/common/ssm.py b/axlearn/common/ssm.py index 7848058d9..df48f5baf 100644 --- a/axlearn/common/ssm.py +++ b/axlearn/common/ssm.py @@ -47,7 +47,7 @@ from axlearn.common.module import Module from axlearn.common.param_init import FanAxes, Initializer, Shape, constant_initializer, uniform from axlearn.common.ssm_kernels.mamba_kernels import compute_mamba_scan -from axlearn.common.utils import Nested, Tensor, with_sharding_constraint +from axlearn.common.utils import Nested, Tensor, TensorSpec, with_sharding_constraint class MambaDtProjInitializer(Initializer): @@ -553,10 +553,10 @@ class Config(BaseLayer.Config): # The recurrence implementation to use for full-sequence inputs. mamba_recurrence: BaseMambaRecurrence = LinearScanMambaRecurrence.default_config() # The recurrence implementation to use for inference. - inference_mamba_recurrence: ( - BaseMambaRecurrence - ) = LinearScanMambaRecurrence.default_config().set( - output_mode=MambaRecurrenceOutputMode.OUTPUTS_AND_STATES + inference_mamba_recurrence: BaseMambaRecurrence = ( + LinearScanMambaRecurrence.default_config().set( + output_mode=MambaRecurrenceOutputMode.OUTPUTS_AND_STATES + ) ) class MambaOutput(NamedTuple): @@ -749,7 +749,7 @@ def _forward_for_mode( Args: mode: Configures whether `cached_states` are consumed or emitted. query: A Tensor of shape [batch, target_length, target_dim]. - cached_states: Optional NestedTensor as produced by `prefill_states`. + cached_states: Optional Nested Tensor as produced by `init_states`. Returns: An optional NestedTensor of cached states, depending on `mode`. @@ -760,12 +760,13 @@ def _forward_for_mode( """ self.vlog(3, "mamba.input=%s", query.sum()) if mode == ForwardMode.FORWARD: - mamba_state, mamba_output = None, self._full_sequence_forward( - query, recurrence=self.recurrence + mamba_state, mamba_output = ( + None, + self._full_sequence_forward(query, recurrence=self.recurrence), ) elif mode == ForwardMode.INIT_STATES: assert cached_states is not None - mamba_state, mamba_output = self.prefill_states( + mamba_state, mamba_output = self.init_states( time_step=cached_states["mamba_layer"], query=query, ) @@ -789,54 +790,52 @@ def forward(self, query: Tensor) -> MambaOutput: _, output = self._forward_for_mode(mode=ForwardMode.FORWARD, query=query) return output - # pylint: disable=unused-argument - def init_states(self, *, target_batch_size: int, **_kwargs) -> Nested[Tensor]: - """Initializes cache for autoregressive cached decoding. - - Args: - target_batch_size: The batch size of the target to be decoded. - - Returns: - The cache as a Nested[Tensor]. - """ - cfg = self.config - dtype = cfg.cache_dtype or cfg.dtype - cache = dict( - conv_input=jnp.zeros((target_batch_size, cfg.conv.window, self.inner_dim), dtype=dtype), - state=jnp.zeros((target_batch_size, 1, cfg.state_dim, self.inner_dim), dtype=dtype), - time_step=jnp.zeros(target_batch_size, dtype=jnp.int32), - ) - return cache - - def prefill_states( + def init_states( self, *, - time_step: Tensor, - query: Tensor, - ) -> tuple[Nested[Tensor], MambaOutput]: + time_step: Optional[Tensor], + query: Union[Tensor, TensorSpec], + ) -> tuple[Nested[Tensor], Optional[MambaOutput]]: """Initializes cache for autoregressive cached decoding. + The method supports initializing an empty cache as well as prefilling: + * To initialize an empty cache, specify `time_step=None`. + In this case, `query` is allowed to be a TensorSpec. + * To prefill, provide `time_step` and `query` as Tensors. + Args: - time_step: A Tensor of shape [batch_size]. Each value is an index into the length - dimension indicating where decoding will start from. - query: Tensor of shape [batch, target_length, target_dim] corresponding to query vector - up to `time_step` indices. For batch index `i`, only `query[i, :time_step[i], ...]` - will affect subsequent decoding. + time_step: An optional Tensor of shape [batch_size]. Each value is an index into the + length dimension indicating where decoding will start from. + query: A Tensor or TensorSpec of shape [batch, target_length, target_dim] corresponding + to query vector up to `time_step` indices. For batch index `i`, only + `query[i, :time_step[i], ...]` will affect subsequent decoding. Returns: - A Nested[Tensor] containing the cached convolution input, ssm state, - and updated time_step. - A MambaOutput instance where .data is the same shape as query. + A tuple (init_states, output): + * init_states: A Nested Tensor containing the cached convolution input, ssm state, + and updated time_step. + * output: In the prefill case, a MambaOutput instance where .data is the same shape as + query. Otherwise, if initializing cache from scratch, output will be None. """ - cfg = self.config + cfg: MambaMixerLayer.Config = self.config dtype = cfg.cache_dtype or cfg.dtype + batch_size = query.shape[0] + + if time_step is None: + init_state = dict( + conv_input=jnp.zeros((batch_size, cfg.conv.window, self.inner_dim), dtype=dtype), + state=jnp.zeros((batch_size, 1, cfg.state_dim, self.inner_dim), dtype=dtype), + time_step=jnp.zeros(batch_size, dtype=jnp.int32), + ) + return init_state, None + output = self._full_sequence_forward(query, recurrence=self.inference_recurrence) conv_input, states = output.conv_input, output.states # Pad conv input so we can take the last window timesteps that precede time_step. padded_conv_input = jnp.pad( conv_input, ((0, 0), (cfg.conv.window, 0), (0, 0)) ) # [batch_size, target_length+window, input_dim] - batch_range = jnp.arange(conv_input.shape[0]) + batch_range = jnp.arange(batch_size) time_step_range = time_step[:, None] + jnp.arange(cfg.conv.window) conv_input_cache = padded_conv_input[batch_range[:, None], time_step_range] # Pad states so we can take the step preceding time_step, even if time_step is zero. @@ -1047,38 +1046,33 @@ def forward( """ raise NotImplementedError(type(self)) - def init_states(self, *, target_batch_size: int, target_max_len: int) -> Nested[Tensor]: - """Initializes cached states for incremental computation. - - Args: - target_batch_size: The batch size for target sequences. - target_max_len: The maximum number of tokens in a target sequence. - - Returns: - A nested tree of Tensors, which can be used as `cached_states` for the initial call - of `extend_step()`. - """ - raise NotImplementedError(type(self)) - - def prefill_states( + def init_states( self, *, - time_step: Tensor, - data: Tensor, + time_step: Optional[Tensor], + data: Union[Tensor, TensorSpec], **_kwargs, - ) -> tuple[Nested[Tensor], BaseTransformerLayer.Output]: + ): """Initializes cached states for incremental computation. + The method supports initializing an empty cache as well as prefilling: + * To initialize an empty cache, specify `time_step=None`. + In this case, `data` is allowed to be a TensorSpec. + * To prefill, provide `time_step` and `data` as Tensors. + Args: - time_step: A Tensor of shape [batch]. Each value is an index into the length dimension - indicating where decoding will start from. - data: A Tensor of shape [batch, target_length, input_dim]. For batch index `i`, only - `data[i, :time_step[i], ...]` will affect subsequent decoding. + time_step: An optional Tensor of shape [batch]. Each value is an index into the length + dimension indicating where decoding will start from. + data: A Tensor or TensorSpec of shape [batch, target_length, input_dim]. For batch index + `i`, only `data[i, :time_step[i], ...]` will affect subsequent decoding. Returns: - A nested tree of Tensors, which can be used as `cached_states` for the initial call - of `extend_step()`. - A BaseTransformerLayer.Output instance, where .data is of the same shape as `data`. + A tuple (init_states, output): + * init_states: A nested tree of Tensors, which can be used as `cached_states` for the + initial call of `extend_step()`. + * output: In the prefill case, a BaseTransformerLayer.Output instance, where .data is of + the same shape as `data`. Otherwise, if initializing cache from scratch, output will + be None. """ raise NotImplementedError(type(self)) @@ -1143,7 +1137,7 @@ def _forward_for_mode( self, *, mode: ForwardMode, - data: Tensor, + data: Union[Tensor, TensorSpec], cached_states: Optional[Nested[Tensor]] = None, **_kwargs, ) -> tuple[Optional[Nested[Tensor]], BaseTransformerLayer.Output]: @@ -1154,7 +1148,7 @@ def _forward_for_mode( mode: Configures whether `cached_states` are consumed or emitted. See `axlearn.common.attention.ForwardMode` for details. data: A Tensor of shape [batch, target_length, target_dim]. - cached_states: Optional NestedTensor as produced by `prefill_states`. + cached_states: Optional Nested Tensor as produced by `init_states`. Returns: An optional NestedTensor of cache states, depending on `mode`. @@ -1163,30 +1157,42 @@ def _forward_for_mode( Raises: ValueError: If `mode` is unsupported. """ - cfg = self.config + cfg: MambaBlock.Config = self.config + + def mamba_thunk(target): + if mode == ForwardMode.FORWARD: + state, output = None, self.mamba(query=target) + elif mode == ForwardMode.INIT_STATES: + assert cached_states is not None + state, output = self.mamba.init_states( + time_step=cached_states["mamba_block"], + query=target, + ) + elif mode == ForwardMode.EXTEND_STEP: + assert cached_states is not None + state, output = self.mamba.extend_step( + cached_states["mamba_block"], + target, + ) + else: + raise ValueError(f"Unrecognized mode {mode}.") + return state, output + + if mode == ForwardMode.INIT_STATES: + assert cached_states is not None + if cached_states["mamba_block"] is None: + state, _ = mamba_thunk(data) + return dict(mamba_block=state), None + skip_input = data if cfg.residual_mode == BlockResidualMode.FP32: skip_input = _at_least_float32(skip_input) target = self.norm(data) - - if mode == ForwardMode.FORWARD: - state, output = None, self.mamba(query=target) - elif mode == ForwardMode.INIT_STATES: - assert cached_states is not None - state, output = self.mamba.prefill_states( - time_step=cached_states["mamba_block"], - query=target, - ) - elif mode == ForwardMode.EXTEND_STEP: - assert cached_states is not None - state, output = self.mamba.extend_step( - cached_states["mamba_block"], - target, - ) - else: - raise ValueError(f"Unrecognized mode {mode}.") + state, output = mamba_thunk(target) output = (output.data + skip_input).astype(target.dtype) - return dict(mamba_block=state), self._to_transformer_output(data=output) + output = self._to_transformer_output(data=output) + + return dict(mamba_block=state), output def forward( self, @@ -1209,41 +1215,32 @@ def forward( ) return output - def init_states(self, *, target_batch_size: int, target_max_len: int) -> Nested[Tensor]: - """Initializes cache for autoregressive cached decoding. - - Args: - target_batch_size: The batch size of the target to be decoded. - target_max_len: The sequence length of the target to be decoded. - - Returns: - The cache as a `Nested[Tensor]`. - """ - return dict( - mamba_block=self.mamba.init_states( - target_batch_size=target_batch_size, target_max_len=target_max_len - ) - ) - - def prefill_states( + def init_states( self, *, - time_step: Nested[Tensor], - data: Tensor, + time_step: Optional[Tensor], + data: Union[Tensor, TensorSpec], **_kwargs, ) -> tuple[Nested[Tensor], BaseTransformerLayer.Output]: """Initializes cache for autoregressive cached decoding. + The method supports initializing an empty cache as well as prefilling: + * To initialize an empty cache, specify `time_step=None`. + In this case, `data` is allowed to be a TensorSpec. + * To prefill, provide `time_step` and `data` as Tensors. + Args: time_step: A Tensor of shape [batch]. Each value is an index into the length dimension indicating where decoding will start from. - data: Tensor of shape [batch, target_length, target_dim] corresponding to query vector - at `time_step` indices. For batch index `i`, only `target[i, :time_step[i], ...]` - will affect subsequent decoding. + data: A Tensor or TensorSpec of shape [batch, target_length, target_dim] corresponding + to query vector at `time_step` indices. For batch index `i`, only + `target[i, :time_step[i], ...]` will affect subsequent decoding. Returns: - A `NestedTensor` state depending on the `attention` layer implementation. - An Output instance, where .data is of the same shape as data. + A tuple (init_states, output): + * init_states: A Nested Tensor state depending on the `attention` layer implementation. + * output: In the prefill case, an Output instance, where .data is of the same shape as + data. Otherwise, if initializing cache from scratch, output will be None. """ return self._forward_for_mode( mode=ForwardMode.INIT_STATES, @@ -1300,32 +1297,44 @@ def _forward_for_mode( self, *, mode: ForwardMode, - data: Tensor, + data: Union[Tensor, TensorSpec], cached_states: Optional[Nested[Tensor]] = None, **_kwargs, ) -> tuple[Optional[Nested[Tensor]], BaseTransformerLayer.Output]: cfg = self.config + + def mamba_thunk(target): + if mode == ForwardMode.FORWARD: + state, output = None, self.mamba(query=target) + elif mode == ForwardMode.INIT_STATES: + assert cached_states is not None + state, output = self.mamba.init_states( + time_step=cached_states["mamba_block"], + query=target, + ) + elif mode == ForwardMode.EXTEND_STEP: + assert cached_states is not None + state, output = self.mamba.extend_step( + cached_states["mamba_block"], + target, + ) + else: + raise ValueError(f"Unrecognized mode {mode}.") + return state, output + + # Handle the case where we initialize cache from scratch. + # `data` can be effectively treated as a TensorSpec in this case, so norm doesn't apply. + if mode == ForwardMode.INIT_STATES: + assert cached_states is not None + if cached_states["mamba_block"] is None: + state, _ = mamba_thunk(TensorSpec(shape=data.shape, dtype=data.dtype)) + return dict(mamba_block=state), None + skip_input = data if cfg.residual_mode == BlockResidualMode.FP32: skip_input = _at_least_float32(skip_input) target = self.norm(data) - - if mode == ForwardMode.FORWARD: - state, output = None, self.mamba(query=target) - elif mode == ForwardMode.INIT_STATES: - assert cached_states is not None - state, output = self.mamba.prefill_states( - time_step=cached_states["mamba_block"], - query=target, - ) - elif mode == ForwardMode.EXTEND_STEP: - assert cached_states is not None - state, output = self.mamba.extend_step( - cached_states["mamba_block"], - target, - ) - else: - raise ValueError(f"Unrecognized mode {mode}.") + state, output = mamba_thunk(target) data = (output.data + skip_input).astype(target.dtype) output = self.feed_forward(data).astype(target.dtype) # Feed-forward norms its input. return dict(mamba_block=state), self._to_transformer_output(data=output) diff --git a/axlearn/common/ssm_test.py b/axlearn/common/ssm_test.py index 722f13d94..ab666e62d 100644 --- a/axlearn/common/ssm_test.py +++ b/axlearn/common/ssm_test.py @@ -15,6 +15,7 @@ """Tests Mamba and Jamba implementations.""" + import math from typing import Optional @@ -41,7 +42,7 @@ StackedSSMLayer, ) from axlearn.common.test_utils import TestCase, assert_allclose -from axlearn.common.utils import Nested, Tensor, cast_floats +from axlearn.common.utils import Nested, Tensor, TensorSpec, cast_floats # The following PyTorch Mamba implementations are adapted from: # https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/mamba/modeling_mamba.py @@ -407,7 +408,11 @@ def test_extend_step(self, dtype: jnp.dtype): prng_key=jax.random.PRNGKey(2), inputs=inputs, ) - initial_state = layer.init_states(target_batch_size=batch_size, target_max_len=tgt_len) + initial_state, initial_output = layer.init_states( + time_step=None, + query=TensorSpec([batch_size, tgt_len]), + ) + self.assertIsNone(initial_output) for k in ["conv_input", "state"]: self.assertEqual(initial_state[k].dtype, dtype) @@ -463,7 +468,7 @@ def test_prefill_states(self, dtype: jnp.dtype): is_training=False, prng_key=jax.random.PRNGKey(3), inputs=dict(time_step=time_step, query=query), - method="prefill_states", + method="init_states", ) self.assertTrue(jnp.all(time_step == initial_states["time_step"])) for k in ["conv_input", "state"]: @@ -532,7 +537,12 @@ def _test_extend_step(layer_cfg: InstantiableConfig, *, model_dim: int, dtype: j prng_key=jax.random.PRNGKey(2), inputs=inputs, ) - initial_state = layer.init_states(target_batch_size=batch_size, target_max_len=tgt_len) + if isinstance(layer, MambaMixerLayer): + init_kwargs = dict(query=TensorSpec([batch_size, tgt_len])) + else: + init_kwargs = dict(data=TensorSpec([batch_size, tgt_len])) + initial_state, initial_output = layer.init_states(time_step=None, **init_kwargs) + assert initial_output is None inputs = dict(cached_states=initial_state) decoder_output = jnp.zeros(shape=[tgt_len, batch_size, model_dim]) for t in range(tgt_len): @@ -588,7 +598,7 @@ def _test_prefill_states(layer_cfg: InstantiableConfig, *, model_dim: int, dtype data=query, self_attention_logit_biases=self_attention_logit_biases, ), - method="prefill_states", + method="init_states", ) # Zero-out outputs starting from initial time_step, and test that we can recover the full diff --git a/axlearn/vision/coca.py b/axlearn/vision/coca.py index 68abee4cf..dbab386db 100644 --- a/axlearn/vision/coca.py +++ b/axlearn/vision/coca.py @@ -47,7 +47,7 @@ from axlearn.common.module import Module from axlearn.common.multi_stream_model import FusionNetwork, MultiStreamModel, StreamEncoder from axlearn.common.poolings import AttentionPooling, BasePoolingLayer, LastNTokenPooling -from axlearn.common.utils import NestedTensor, Tensor +from axlearn.common.utils import NestedTensor, Tensor, TensorSpec from axlearn.common.vision_transformer import VisionTransformer, layer_norm_config from axlearn.vision.clip import CLIPFusionNetwork @@ -536,7 +536,7 @@ class Config(BaseLayer.Config): lm_head: Optional[CoCaLMHead.Config] = CoCaLMHead.default_config() - dim: Required[int] = None + dim: Required[int] = REQUIRED pad_token_id: int = 0 def __init__(self, cfg: Config, *, parent: Module): @@ -660,11 +660,12 @@ def init_states(self, *, batch_size: int, max_sequence_length: int) -> NestedTen Returns: The cache as a `NestedTensor` with key and value initialized. """ - return dict( - transformer_state=self.transformer.init_states( - target_batch_size=batch_size, target_max_len=max_sequence_length - ), + cfg = self.config + init_state, _ = self.transformer.init_states( + time_step=None, + data=TensorSpec([batch_size, max_sequence_length, cfg.dim]), ) + return dict(transformer_state=init_state) def prefill_states( self, @@ -676,7 +677,7 @@ def prefill_states( cross_attention_logit_biases: Optional[Tensor] = None, ) -> tuple[NestedTensor, NestedTensor]: cfg = self.config - transformer_state, transformer_data = self.transformer.prefill_states( + transformer_state, transformer_data = self.transformer.init_states( time_step=time_step, data=input_features, self_attention_logit_biases=self.attention_mask(