diff --git a/axlearn/audio/subsamplers.py b/axlearn/audio/subsamplers.py index bb1871eb9..378a1b294 100644 --- a/axlearn/audio/subsamplers.py +++ b/axlearn/audio/subsamplers.py @@ -7,7 +7,8 @@ from axlearn.common.base_layer import BaseLayer from axlearn.common.config import REQUIRED, Required, config_class -from axlearn.common.layers import BaseNormalizationLayer, Conv2DWith1DPadding, get_activation_fn +from axlearn.common.convolution import Conv2DWith1DPadding +from axlearn.common.layers import BaseNormalizationLayer, get_activation_fn from axlearn.common.module import Module from axlearn.common.utils import Tensor diff --git a/axlearn/common/conformer.py b/axlearn/common/conformer.py index 3899f9815..23c115856 100644 --- a/axlearn/common/conformer.py +++ b/axlearn/common/conformer.py @@ -30,9 +30,9 @@ ) from axlearn.common.base_layer import BaseLayer from axlearn.common.config import REQUIRED, InstantiableConfig, Required, config_class +from axlearn.common.convolution import Conv1D from axlearn.common.layers import ( BatchNorm, - Conv1D, Dropout, GroupNorm, LayerNorm, diff --git a/axlearn/common/convolution.py b/axlearn/common/convolution.py new file mode 100644 index 000000000..2b8d25fb5 --- /dev/null +++ b/axlearn/common/convolution.py @@ -0,0 +1,1792 @@ +# Copyright © 2024 Apple Inc. +# pylint: disable=too-many-lines +"""Convolution layers.""" + +from collections.abc import Sequence +from typing import Literal, Optional, Union + +import chex +import einops +import jax +from jax import numpy as jnp + +from axlearn.common.base_layer import BaseLayer, FactorizationSpec, ParameterSpec +from axlearn.common.config import REQUIRED, Required, config_class +from axlearn.common.module import nowrap +from axlearn.common.param_init import FanAxes +from axlearn.common.utils import Tensor + +# The padding type for jax.lax.conv_general_dilated API. Either the strings ‘SAME’, or ‘VALID’, or +# 'CAUSAL' or a sequence of n (low, high) integer pairs that give the padding to apply before and +# after each spatial dimension. The number of tuple is 1 for NHC, 2 for NHWC and 3 for NHWDC. +ConvPaddingType = Union[str, Sequence[tuple[int, int]]] + +SUPPORT_CONV_PADDING = ("SAME", "VALID", "CAUSAL") + + +def _check_conv_cfg( + *, + window: Sequence[int], + strides: Sequence[int], + padding: ConvPaddingType, + dilation: Optional[Sequence[int]], +): + if any(w < 1 for w in window): + raise ValueError(f"window ({window}) must be a positive integer.") + + if any(s < 1 for s in strides): + raise ValueError(f"strides ({strides}) must be a positive integer.") + + if isinstance(padding, str): + if padding not in SUPPORT_CONV_PADDING: + raise ValueError(f"{padding} padding is not supported.") + else: + padding_flattened = jax.tree.leaves(padding) + if any(p < 0 for p in padding_flattened): + raise ValueError("Negative padding is not supported") + + if dilation is not None and any(d < 1 for d in dilation): + raise ValueError(f"dilation ({dilation}) must be a positive integer.") + + +class BaseConv(BaseLayer): + """Base class for convolution layers.""" + + @config_class + class Config(BaseLayer.Config): + input_dim: Required[int] = REQUIRED # Input feature dim. 
+ + # pylint: disable-next=no-self-use + def _compute_fan_axes(self, name: str, parameter_spec: ParameterSpec) -> Optional[FanAxes]: + if not name.endswith("weight"): + return None + if len(parameter_spec.shape) < 2: + raise NotImplementedError( + "Default _compute_fan_axes requires weight parameters to have at least 2 axes " + f"shape({name}) = {parameter_spec.shape}" + ) + # All other axes represent receptive field. + return FanAxes(in_axis=-2, out_axis=-1) + + +# Copied from jax.lax._dilate_shape +# https://github.com/jax-ml/jax/blob/2d78b172266870bd755b039f6faa2056a51930f9/jax/_src/lax/lax.py#L5763 +def conv_dilate_window(*, window: Sequence[int], dilation: Optional[Sequence[int]]): + """Returns dilated effective window size. + + Args: + window: convolution window. + dilation: convolution dilation. + + Returns: + The dilated effective window size. + """ + if dilation is None or all(d == 1 for d in dilation): + return window + + return tuple(max(1 + d * (w - 1), 0) for w, d in zip(window, dilation)) + + +# Copied from subroutine in jax.lax.reduce_window. +# Extend lax.padtype_to_pads for CAUSAL. +def conv_explicit_padding( + *, + window: Sequence[int], + strides: Sequence[int], + padding: ConvPaddingType, + dilation: Optional[Sequence[int]] = None, +) -> ConvPaddingType: + """Returns the explicit padding for "SAME", "VALID", and "CAUSAL" modes. + + Each mode follows the formulas below: + * SAME: (pad_total//2, pad_total - pad_total//2) s.t. pad_total = window-1 + * VALID: (0, 0) + * CAUSAL: (window - stride, stride - 1) + + Note: In the above equation, `window` will be replaced with `dilate_window` when dilation > 1. + dilate_window = (window - 1) * dilation + 1. Check conv_dilate_window() + + For example, window=5 and stride=2, + * SAME: padding = (2, 2) + pad| |pad + paddings: 0 0|0 0 0 0 1 1|1 1 + |___^___| + |___^___| + |___^___| + + * VALID: padding = (0, 0) + | | + paddings: |0 0 0 0 1 1| + |^_______| + + * CAUSAL: padding = (3, 1) + pad | |pad + paddings: 0 0 0|0 0 0 0 1 1|1 + |_____^_| + |_____^_| + |_____^_| + + + For example, window=5, stride=2 and dilation=2 + -> dilate_window = 9 (== (window-1)*dilation + 1) and pad_total = 8 + * SAME: padding = (4, 4) + pad| |pad + paddings: 0 0 0 0|0 0 0 0 0 0 0 0 1 1|1 1 1 1 + |_______^_______| + |_______^_______| + |_______^_______| + |_______^_______| + |_______^_______| + + * VALID: padding = (0, 0) + | |pad + paddings: |0 0 0 0 0 0 0 0 1 1| + |^_______________| + + * CAUSAL: padding = (7, 1) + pad | |pad + paddings: 0 0 0 0 0 0 0|0 0 0 0 0 0 0 0 1 1|1 + |_____________^_| + |_____________^_| + |_____________^_| + |_____________^_| + |_____________^_| + + For "CAUSAL", the first component is time and treated as "CAUSAL", while the remaining + components are handled with "SAME" padding. + + Args: + window: convolution window. + strides: convolution strides. + padding: convolution padding. + dilation: convolution dilation. + + Returns: + The padding tuple. + + Raises: + ValueError: If padding is not supported. 
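+
+    Example: a doctest-style sketch (illustrative only; the values follow the SAME /
+    VALID / CAUSAL formulas above for window=5 and stride=2):
+        conv_explicit_padding(window=(5,), strides=(2,), padding="SAME", dilation=(1,))
+        # -> ((2, 2),)
+        conv_explicit_padding(window=(5,), strides=(2,), padding="CAUSAL", dilation=(1,))
+        # -> ((3, 1),)
+        conv_explicit_padding(window=(5,), strides=(2,), padding="CAUSAL", dilation=(2,))
+        # -> ((7, 1),)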
+ """ + if not isinstance(padding, str): + return padding + window = conv_dilate_window(window=window, dilation=dilation) + + def same_padding(window): + pad_total = tuple(w - 1 for w in window) + pad_left = tuple(pt // 2 for pt in pad_total) + pad_right = tuple(pt - pl for pt, pl in zip(pad_total, pad_left)) + return tuple(zip(pad_left, pad_right)) + + if padding == "SAME": + return same_padding(window) + elif padding == "VALID": + return ((0, 0),) * len(window) + elif padding == "CAUSAL": + causal_padding = ((window[0] - strides[0], strides[0] - 1),) + if len(window) > 1: + causal_padding += same_padding(window[1:]) + return causal_padding + else: + raise ValueError(f"{padding} padding is not supported.") + + +def conv_output_shape( + in_shape: Sequence[Optional[int]], + *, + window: Sequence[int], + strides: Sequence[int], + padding: ConvPaddingType, + dilation: Optional[Sequence[int]] = None, +) -> Sequence[int]: + """Returns output size for convolution. + + Follow https://www.tensorflow.org/api_docs/python/tf/nn/convolution + * SAME: ceil(in_size / stride) + * VALID: ceil((in_size - (window - 1) * dilation) / stride) + + Args: + in_shape: convolution lhs shape. + window: convolution window. + strides: convolution strides. + padding: convolution padding. + dilation: convolution dilation. + + Returns: + The output shape. + + Raises: + ValueError: If the length of in_shape, window, strides, and padding are not equal. + """ + if len(in_shape) != len(window) or len(in_shape) != len(strides): + raise ValueError( + f"len(in_shape) = {len(in_shape)} must be equal to " + f"len(window) = {len(window)} and len(strides) = {len(strides)}" + ) + + padding = conv_explicit_padding( + window=window, strides=strides, padding=padding, dilation=dilation + ) + pad_amount = tuple(sum(p) for p in padding) + dilate_window = conv_dilate_window(window=window, dilation=dilation) + + def output_shape(in_shape: Optional[int], dilate_window: int, pad_amount: int, stride: int): + if in_shape is None: + return None + numerator = max(in_shape + pad_amount - (dilate_window - 1), 0) + # ceil(numerator / stride) + return (numerator + stride - 1) // stride + + return tuple(map(output_shape, in_shape, dilate_window, pad_amount, strides)) + + +def compute_conv_paddings( + in_paddings: Tensor, + *, + window: int, + stride: int, + conv_padding: ConvPaddingType, + dilation: Optional[int] = None, + anchor: Optional[int] = None, +): + """Compute output paddings w.r.t. conv_padding. + + The output paddings value is determined by the padding value at the anchor point in the + window. If anchor is None, the default anchor point is the left time padding from conv + padding config. See `Conv2DWith1DPadding.Config` in details. + + Args: + in_paddings: A Tensor of shape [batch_size, seq_len]. + window: convolution window size of the time axis. + stride: convolution stride size of the time axis. + conv_padding: "SAME", "VALID", "CAUSAL" or ((left_time_padding, right_time_padding),) + dilation: convolution dilation size of the time axis. + anchor: an optional integer in the range of [left_time_padding, window - right_time_padding) + that specifies the anchor position within the convolution window that is used to + determine output paddings. Specifically, the output token is valid iff the input token + at the anchor position of the corresponding window is valid. + If None, anchor defaults to conv_padding[0] (i.e. left_time_padding). + + Returns: + out_paddings: A Tensor of shape [batch_size, seq_len]. 
+ + Raises: + ValueError: If anchor is not between left_time_padding and right_time_padding. + """ + chex.assert_rank(in_paddings, 2) + dilation = dilation or 1 + conv_padding = conv_explicit_padding( + window=(window,), strides=(stride,), padding=conv_padding, dilation=(dilation,) + ) + window = conv_dilate_window(window=(window,), dilation=(dilation,))[0] + left_pad, right_pad = conv_padding[0] + pad_total = window - 1 + + if anchor is None: + # valid_window = pad_total - left_pad - right_pad + # anchor_global = valid_window // 2 + # anchor = anchor_global + left_pad + anchor = left_pad + elif not left_pad <= anchor < window - right_pad: + raise ValueError(f"anchor ({anchor}) must in range [{left_pad}, {window - right_pad}).") + + # This is a method to avoid using jax.pad, by leveraging the property that the valid_window + # is always within the input sequence. + # Note: transform anchor from window frame to input sequence frame. + start_index = anchor - left_pad + valid_window = pad_total - left_pad - right_pad + valid_window_right_pad = valid_window - start_index + seq_len = in_paddings.shape[1] + limit_index = max(seq_len - valid_window_right_pad, start_index) + if seq_len < start_index: + start_index = 0 + limit_index = 0 + out_paddings = jax.lax.slice_in_dim( + in_paddings, start_index=start_index, limit_index=limit_index, stride=stride, axis=1 + ) + return out_paddings + + +class Conv1D(BaseConv): + """The 1D convolution layer. + + Kernel weights have the WIO layout and in the shape of (window, input_dim, output_dim). + Both inputs and outputs will be in the NWC layout. + """ + + @config_class + class Config(BaseConv.Config): + """Configures Conv1D.""" + + window: Required[int] = REQUIRED # The convolution window. + strides: int = 1 # The convolution strides. + # Paddings: "SAME", "VALID", "CAUSAL", or (left, right). + # For causal convolution, set padding to (window - 1, 0). + padding: ConvPaddingType = ((0, 0),) + output_dim: Required[int] = REQUIRED # Output feature dim. + bias: bool = True # Whether to add a bias. + # The number of groups in which the input is split along the channel axis. + # input_dim and output_dim must both be divisible by num_input_dim_groups. For example, + # - At num_input_dim_groups=1, all inputs are convolved to all outputs (the default). + # - At num_input_dim_groups=2, the operation is equivalent to concatenating two conv layers + # side by side, each seeing half the input and producing half the output channels. + # - At num_input_dim_groups=input_dim, each input channel is convolved with its own + # set of filters (of size output_dim / input_dim); if further output_dim == K * input_dim, + # where K is a positive integer, the operation is also known as a "depthwise convolution". + num_input_dim_groups: Optional[int] = 1 + # The convolution dilation, indicating dilation factor applied to the weight. It is also + # known as atrous convolution or dilated convolution. If None, assume 1. 
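+        # For example, window=3 with dilation=2 covers an effective (dilated) window of
+        # (window - 1) * dilation + 1 = 5 input frames.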
+ dilation: Optional[int] = None + + @classmethod + def default_config(cls): + cfg = super().default_config() + cfg.param_partition_spec = (None, None, "model") + return cfg + + def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: + cfg = self.config + dilation = cfg.dilation or 1 + _check_conv_cfg( + window=(cfg.window,), + strides=(cfg.strides,), + padding=cfg.padding, + dilation=(dilation,), + ) + if cfg.padding not in SUPPORT_CONV_PADDING: + left, right = cfg.padding[0] + if any(p < 0 for p in (left, right)): + raise NotImplementedError("Negative padding is not supported") + params = dict( + weight=ParameterSpec( + shape=[cfg.window, cfg.input_dim // cfg.num_input_dim_groups, cfg.output_dim], + mesh_axes=cfg.param_partition_spec, + ) + ) + if cfg.bias: + params["bias"] = ParameterSpec( + shape=[cfg.output_dim], mesh_axes=(cfg.param_partition_spec[-1],) + ) + return params + + def forward(self, x: Tensor) -> Tensor: + cfg = self.config + dilation = cfg.dilation or 1 + conv_padding = conv_explicit_padding( + window=(cfg.window,), + strides=(cfg.strides,), + padding=cfg.padding, + dilation=(dilation,), + ) + return self._conv(x=x, strides=(cfg.strides,), padding=conv_padding, dilation=(dilation,)) + + def _conv( + self, + x: Tensor, + *, + strides: Sequence[int], + padding: ConvPaddingType, + dilation: Optional[Sequence[int]], + ) -> Tensor: + cfg = self.config + output = jax.lax.conv_general_dilated( + lhs=x, + rhs=self.parameters["weight"], + window_strides=strides, + padding=padding, + rhs_dilation=dilation, + dimension_numbers=("NWC", "WIO", "NWC"), + feature_group_count=cfg.num_input_dim_groups, + ) + if cfg.bias: + output += self.parameters["bias"] + return output + + @nowrap + def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Optional[int]]: + cfg = self.config + if len(input_shape) != 3: + raise ValueError(f"We expect len(input_shape) = 3, but got {len(input_shape)}.") + if input_shape[-1] != cfg.input_dim: + raise ValueError( + f"input_shape[-1] = {input_shape[-1]} does not match " + f"cfg.input_dim = {cfg.input_dim}." + ) + + in_shape = input_shape[1:2] + dilation = cfg.dilation or 1 + out_shape = conv_output_shape( + in_shape, + window=(cfg.window,), + strides=(cfg.strides,), + padding=cfg.padding, + dilation=(dilation,), + ) + return [input_shape[0], *out_shape, cfg.output_dim] + + +class Conv1DWithPadding(Conv1D): + """The 1-D convolution with 1-D padding on the time axis.""" + + @config_class + class Config(Conv1D.Config): + """Configures Conv1DWithPadding.""" + + # An optional integer in the range of [left_time_padding, window - right_time_padding) + # that specifies the anchor position within the convolution window that is used to + # determine output paddings. Specifically, the output token is valid iff the input token + # at the anchor position of the corresponding window is valid. + # If None, defaults to left time padding. See Conv2DWith1DPadding more details. + anchor: Optional[int] = None + + # We add a kwargs "paddings" to the forward method. + # pylint: disable-next=arguments-differ + def forward(self, x: Tensor, *, paddings: Tensor) -> tuple[Tensor, Tensor]: + """Computes convolution outputs and paddings. + + Args: + x: A Tensor of shape [batch_size, seq_len, frequency, input_dim]. + paddings: 0/1 Tensor of shape [batch_size, seq_len]. + + Returns: + output: A Tensor of shape [batch_size, seq_len, frequency, output_dim]. + paddings: 0/1 Tensor of shape [batch_size, seq_len]. 
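+
+        Example: under a hypothetical config with window=3, strides=2, and
+        padding="CAUSAL", an input with seq_len=10 yields outputs with seq_len=5
+        (ceil(10 / 2)), and the returned paddings have shape [batch_size, 5].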
+ """ + cfg = self.config + chex.assert_rank(x, paddings.ndim + 1) + # Apply padding to the input. + x = x * (1 - paddings[..., None]) + + # Apply Conv1D. + output = super().forward(x) + + # Compute paddings conv output. + output_paddings = compute_conv_paddings( + paddings, + window=cfg.window, + stride=cfg.strides, + conv_padding=cfg.padding, + dilation=cfg.dilation, + anchor=cfg.anchor, + ) + # Apply padding to the outputs. + output = output * (1 - output_paddings[..., None]) + return output, output_paddings + + +# The accuracy of the output of this layer currently doesn't match that of PyTorch +# quite as closely as we would like. See layers_test.py:test_conv2d(). +class Conv2D(BaseConv): + """The 2-D convolution layer. + + Kernel weights have the HWIO layout and in the shape of (window[0], window[1], input_dim, + output_dim). Both inputs and outputs will be in the NHWC layout. + """ + + @config_class + class Config(BaseConv.Config): + """Configures Conv2D.""" + + window: tuple[int, int] = (1, 1) # The convolution window. + strides: tuple[int, int] = (1, 1) # The convolution strides. + # Paddings: "SAME", "VALID", "CAUSAL" or ((top, bottom), (left, right)). + # Note: Sequence models use the first component to represent time. + padding: ConvPaddingType = ((0, 0), (0, 0)) + # The convolution dilation. If None, assume all 1's. + dilation: Optional[tuple[int, int]] = None + output_dim: Required[int] = REQUIRED # Output feature dim. + bias: bool = True # Whether to add a bias. + # The number of groups in which the input is split along the channel axis. + # input_dim and output_dim must both be divisible by num_input_dim_groups. For example, + # - At num_input_dim_groups=1, all inputs are convolved to all outputs (the default). + # - At num_input_dim_groups=2, the operation is equivalent to concatenating two conv layers + # side by side, each seeing half the input and producing half the output channels. + # - At num_input_dim_groups=input_dim, each input channel is convolved with its own + # set of filters (of size output_dim / input_dim); if further output_dim == K * input_dim, + # where K is a positive integer, the operation is also known as a "depthwise convolution". 
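+        # For example, with input_dim=8, output_dim=8, and num_input_dim_groups=8, the
+        # weight has shape (window[0], window[1], 1, 8), i.e. a depthwise convolution.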
+ num_input_dim_groups: Optional[int] = 1 + + @classmethod + def default_config(cls): + cfg = super().default_config() + cfg.param_partition_spec = (None, None, None, None) + return cfg + + def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: + cfg = self.config + _check_conv_cfg( + window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation + ) + params = dict( + weight=ParameterSpec( + shape=list(cfg.window) + + [cfg.input_dim // cfg.num_input_dim_groups, cfg.output_dim], + mesh_axes=cfg.param_partition_spec, + factorization=FactorizationSpec(axes=(None, None, "row", "col")), + ) + ) + if cfg.bias: + params["bias"] = ParameterSpec( + shape=[cfg.output_dim], mesh_axes=(cfg.param_partition_spec[-1],) + ) + return params + + def forward(self, x: Tensor) -> Tensor: + cfg = self.config + conv_padding = conv_explicit_padding( + window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation + ) + return self._conv(x=x, strides=cfg.strides, padding=conv_padding, dilation=cfg.dilation) + + def _conv( + self, + x: Tensor, + *, + strides: Sequence[int], + padding: ConvPaddingType, + dilation: Optional[Sequence[int]], + ) -> Tensor: + cfg = self.config + output = jax.lax.conv_general_dilated( + lhs=x, + rhs=self.parameters["weight"], + window_strides=strides, + padding=padding, + rhs_dilation=dilation, + dimension_numbers=("NHWC", "HWIO", "NHWC"), + feature_group_count=cfg.num_input_dim_groups, + ) + if cfg.bias: + output += self.parameters["bias"] + return output + + @nowrap + def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Optional[int]]: + cfg = self.config + if len(input_shape) != 4: + raise ValueError(f"We expect len(input_shape) = 4, but got {len(input_shape)}.") + if input_shape[-1] != cfg.input_dim: + raise ValueError( + f"input_shape[-1] = {input_shape[-1]} does not match " + f"cfg.input_dim = {cfg.input_dim}." + ) + + in_shape = input_shape[1:3] + out_shape = conv_output_shape( + in_shape, + window=cfg.window, + strides=cfg.strides, + padding=cfg.padding, + dilation=cfg.dilation, + ) + return [input_shape[0], *out_shape, cfg.output_dim] + + +class Conv2DWith1DPadding(Conv2D): + """The 2-D convolution with 1-D padding on the time axis.""" + + @config_class + class Config(Conv2D.Config): + """Configures Conv2DWith1DPadding. + + The output paddings value is determined by the padding value at the anchor point in the + window. If anchor is None, the default anchor point is the left time padding from conv + padding config. + + For examples with window=5, + 1. "SAME" padding case, + * padding=(2,2): (0 0 0 0 0) + * anchor index is 2: (0 0 |0| 0 0) + pad | | pad + paddings: 0 0|0 0 0 1 1 1|1 1 + |___0___| + |___0___| + |___0___| + |___1___| + |___1___| + |___1___| + + 2. "VALID" padding case, + * padding=(0,0): (0 0 0 0 0) + * anchor index is 0: (|0| 0 0 0 0) + pad | | pad + paddings: |0 0 0 1 1 1| + |0_______| + |0_______| + + 3. The legacy "VALID" padding case, + * padding=(0,0) and anchor=4: (0 0 0 0 0) + * anchor index is 4: (0 0 0 0 |0|) + pad | | pad + paddings: |0 0 0 1 1 1| + |________1| + |________1| + + 4. "CAUSAL" padding case, + * padding=(4,0): (0 0 0 0 0) + * anchor index is 4: (0 0 0 0 |0|) + pad | | pad + paddings: 0 0 0 0|0 0 0 1 1 1| + |_______0| + |_______0| + |_______0| + |_______1| + |_______1| + |_______1| + + 5. 
"CAUSAL" with lookahead=1, + * padding=(3, 1): (0 0 0 0 0) + * anchor index is 3: (0 0 0 |0| 0) + pad | | pad + paddings: 0 0 0|0 0 0 1 1 1|1 + |_____0_| + |_____0_| + |_____0_| + |_____1_| + |_____1_| + |_____1_| + + 6. Arbitrary padding case, + * padding=(2,1): (0 0 0 0 0) + * anchor index is 2: (0 0 |0| 0 0) + pad | | pad + paddings: 0 0|0 0 0 1 1 1|1 + |___0___| + |___0___| + |___0___| + |___1___| + |___1___| + """ + + # An optional integer in the range of [left_time_padding, window - right_time_padding) + # that specifies the anchor position within the convolution window that is used to + # determine output paddings. Specifically, the output token is valid iff the input token + # at the anchor position of the corresponding window is valid. + # If None, defaults to left time padding. + anchor: Optional[int] = None + + # We add a kwargs "paddings" to the forward method. + # pylint: disable-next=arguments-differ + def forward(self, x: Tensor, *, paddings: Tensor) -> tuple[Tensor, Tensor]: + """Computes convolution outputs and paddings. + + Args: + x: A Tensor of shape [batch_size, seq_len, frequency, input_dim]. + paddings: 0/1 Tensor of shape [batch_size, seq_len]. + + Returns: + output: A Tensor of shape [batch_size, seq_len, frequency, output_dim]. + paddings: 0/1 Tensor of shape [batch_size, seq_len]. + """ + cfg = self.config + # Apply padding to the input. + assert len(x.shape) == len(paddings.shape) + 2 + x = x * (1 - paddings[..., None, None]) + + # Apply Conv2D. + output = super().forward(x) + # Compute paddings conv output. + dilation = 1 if cfg.dilation is None else cfg.dilation[0] + output_paddings = compute_conv_paddings( + paddings, + window=cfg.window[0], + stride=cfg.strides[0], + conv_padding=cfg.padding, + dilation=dilation, + anchor=cfg.anchor, + ) + # Apply padding to the outputs. + output = output * (1 - output_paddings[..., None, None]) + return output, output_paddings + + +class Conv3D(BaseConv): + """The 3-D convolution layer. + + Kernel weights have the HWDIO layout and in the shape of (window[0], window[1], + window[2], input_dim, output_dim). Both inputs and outputs will be in the NHWDC layout. + """ + + @config_class + class Config(BaseConv.Config): + """Configures Conv3D.""" + + window: tuple[int, int, int] = (1, 1, 1) # The convolution window. + strides: tuple[int, int, int] = (1, 1, 1) # The convolution strides. + + # Paddings: "SAME" or "VALID, or ((top, bottom), (left, right), (front, back)) + padding: ConvPaddingType = ( + (0, 0), + (0, 0), + (0, 0), + ) + # The convolution dilation. If None, assume all 1's. + dilation: Optional[tuple[int, int, int]] = None + + output_dim: Required[int] = REQUIRED # Output feature dim. + bias: bool = True # Whether to add a bias. + + # The number of groups in which the input is split along the channel axis. + # input_dim and output_dim must both be divisible by num_input_dim_groups. For example, + # - At num_input_dim_groups=1, all inputs are convolved to all outputs (the default). + # - At num_input_dim_groups=2, the operation is equivalent to concatenating two conv layers + # side by side, each seeing half the input and producing half the output channels. + # - At num_input_dim_groups=input_dim, each input channel is convolved with its own + # set of filters (of size output_dim / input_dim); if further output_dim == K * input_dim, + # where K is a positive integer, the operation is also known as a "depthwise convolution". 
+ num_input_dim_groups: Optional[int] = 1 + + @classmethod + def default_config(cls): + cfg = super().default_config() + cfg.param_partition_spec = (None, None, None, None, None) + return cfg + + def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: + cfg = self.config + _check_conv_cfg( + window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation + ) + params = dict( + weight=ParameterSpec( + shape=list(cfg.window) + + [cfg.input_dim // cfg.num_input_dim_groups, cfg.output_dim], + mesh_axes=cfg.param_partition_spec, + factorization=FactorizationSpec(axes=(None, None, None, "row", "col")), + ) + ) + if cfg.bias: + params["bias"] = ParameterSpec( + shape=[cfg.output_dim], mesh_axes=(cfg.param_partition_spec[-1],) + ) + return params + + def forward(self, x: Tensor) -> Tensor: + cfg = self.config + conv_padding = conv_explicit_padding( + window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation + ) + return self._conv(x=x, strides=cfg.strides, padding=conv_padding, dilation=cfg.dilation) + + def _conv( + self, + x: Tensor, + *, + strides: Sequence[int], + padding: ConvPaddingType, + dilation: Optional[Sequence[int]], + ) -> Tensor: + cfg = self.config + output = jax.lax.conv_general_dilated( + lhs=x, + rhs=self.parameters["weight"], + window_strides=strides, + padding=padding, + rhs_dilation=dilation, + dimension_numbers=("NHWDC", "HWDIO", "NHWDC"), + feature_group_count=cfg.num_input_dim_groups, + ) + if cfg.bias: + output += self.parameters["bias"] + return output + + @nowrap + def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Optional[int]]: + cfg = self.config + if len(input_shape) != 5: + raise ValueError(f"We expect len(input_shape) = 5, but got {len(input_shape)}.") + if input_shape[-1] != cfg.input_dim: + raise ValueError( + f"input_shape[-1] = {input_shape[-1]} does not match " + f"cfg.input_dim = {cfg.input_dim}." + ) + + in_shape = input_shape[1:4] + out_shape = conv_output_shape( + in_shape, + window=cfg.window, + strides=cfg.strides, + padding=cfg.padding, + dilation=cfg.dilation, + ) + return [input_shape[0], *out_shape, cfg.output_dim] + + +############################## Transposed Convolution ############################################## + + +# Based on jax.lax.convolution._conv_transpose_padding, but ours is more intuitive. +def conv_transpose_explicit_padding( + *, + window: Sequence[int], + strides: Sequence[int], + padding: ConvPaddingType, + dilation: Sequence[int], +) -> ConvPaddingType: + """Convert str padding to tuple padding for conv_transpose. + + Each mode follows the formulas below, + * SAME: (min(window-1, ceil((w+s-2)/2)), max(stride-1, floor((w+s-2)/2))) + pad_total = window+stride-2 + when stride > window -> (window-1, stride-1) + * VALID: (window-1, max(stride-1, window-1)) + pad_total = window+stride-2 + max(window-stride, 0) + when stride > window -> (window-1, stride-1) + * CAUSAL: (window-1, stride-1) + pad_total = window+stride-2 + + Note: output_size = input_size*stride - (window+stride-2) + pad_total + = input_size*stride <- "SAME", "CAUSAL" + = input_size*stride + max(window-stride, 0) <- "VALID" + + Note: In the above equation, `window` will be replaced with `dilate_window` when dilation > 1. + dilate_window = (window - 1) * dilation + 1. Check conv_dilate_window() + + The following illustration demonstrates how Conv Transpose operates, assuming all kernel values + are set to 1 for simplicity in showcasing output values. 
+ + In the window=3 and stride=1 case, this function creates outputs as follows: + * "SAME" padding=(1, 1) + pad| |pad + paddings: 0|0 0 1 1|0 + 0 0 0 -> 0 + 0 0 1 -> 1 + 0 1 1 -> 2 + 1 1 0 -> 2 + + * "VALID" padding=(2, 2) + pad | |pad + paddings: 0 0|0 0 1 1|0 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 1 -> 1 + 0 1 1 -> 2 + 1 1 0 -> 2 + 1 0 0 -> 1 + + * "CAUSAL" padding=(2, 0) + pad | |pad + paddings: 0 0|0 0 1 1| + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 1 -> 1 + 0 1 1 -> 2 + + In the window=3 and stride=2 case, this function creates outputs as follows: + * "SAME" padding=(2, 1) + pad | |pad + paddings: 0 0|0 * 0 * 1 * 1|0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 1 -> 1 + 0 1 0 -> 1 + 1 0 1 -> 2 + 0 1 0 -> 1 + + * "VALID" padding=(2, 2) + pad | |pad + paddings: 0 0|0 * 0 * 1 * 1|0 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 1 -> 1 + 0 1 0 -> 1 + 1 0 1 -> 2 + 0 1 0 -> 1 + 1 0 0 -> 1 + + * "CAUSAL" padding=(2, 1) + pad | |pad + paddings: 0 0|0 * 0 * 1 * 1|0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 1 -> 1 + 0 1 0 -> 1 + 1 0 1 -> 2 + 0 1 0 -> 1 + + In the window=3 and stride=3 case, this function creates outputs as follows: + * "SAME", "VALID" and "CAUSAL" padding=(2, 2) + pad | |pad + paddings: 0 0|0 * * 0 * * 1 * * 1|0 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 1 -> 1 + 0 1 0 -> 1 + 1 0 0 -> 1 + 0 0 1 -> 1 + 0 1 0 -> 1 + 1 0 0 -> 1 + + In the window=3 and stride=4 case, this function creates outputs as follows: + * "SAME", "VALID" and "CAUSAL" padding=(2, 3) + pad | |pad + paddings: 0 0|0 * * * 0 * * * 1 * * * 1|0 0 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 1 -> 1 + 0 1 0 -> 1 + 1 0 0 -> 1 + 0 0 0 -> 0 + 0 0 1 -> 1 + 0 1 0 -> 1 + 1 0 0 -> 1 + 0 0 0 -> 0 + Here is how to compute output_size, given the above example, + 1. |_| -(window-1) + 2. |_______________________| (input_size-1)*stride + 1 + 3. |_| |___| + pad_total + + So, output_size = -(window-1) + (input_size-1)*stride + 1 + pad_total + = input_size*stride - (window+stride-2) + pad_total + = input_size*stride <- "SAME", "CAUSAL" + = input_size*stride + max(window-stride, 0) <- "VALID" + + OTHO, when dilation > 1, dilate_window = (window - 1) * dilation + 1. + For example, when window=3 and dilation=2, dilate_window=5. + + In the stride=2 case, this function creates outputs as follows: + * "SAME" padding=(3, 2) + pad | |pad + paddings: 0 0 0|0 * 0 * 1 * 1|0 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 1 -> 1 + 0 * 0 * 0 -> 0 + 0 * 1 * 1 -> 2 + 0 * 0 * 0 -> 0 + 1 * 1 * 0 -> 2 + + * "VALID" padding=(4, 4) + pad | |pad + paddings: 0 0 0 0|0 * 0 * 1 * 1|0 0 0 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 1 -> 1 + 0 * 0 * 0 -> 0 + 0 * 1 * 1 -> 2 + 0 * 0 * 0 -> 0 + 1 * 1 * 0 -> 2 + 0 * 0 * 0 -> 0 + 1 * 0 * 0 -> 1 + + * "CAUSAL" padding=(4, 1) + pad | |pad + paddings: 0 0 0 0|0 * 0 * 1 * 1|0 + 0 * 0 * 0 -> 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 1 -> 1 + 0 * 0 * 0 -> 0 + 0 * 1 * 1 -> 2 + 0 * 0 * 0 -> 0 + + For "CAUSAL", the first component is time and treated as "CAUSAL", while the remaining + components are handled with "SAME" padding. + + Args: + window: convolution window. + strides: transposed convolution strides. It's lhs_dilation, not window_stride. + padding: convolution padding. + dilation: convolution dilation, a.k.a rhs_dilation. + + Returns: + The padding tuple. 
+ + Raises: + ValueError: If padding is not supported. + """ + if not isinstance(padding, str): + return padding + + window = conv_dilate_window(window=window, dilation=dilation) + + def same_padding(window, strides): + pad_left = tuple(min(w - 1, (w + s - 1) // 2) for w, s in zip(window, strides)) + pad_right = tuple(max(s - 1, (w + s - 2) // 2) for w, s in zip(window, strides)) + return tuple(zip(pad_left, pad_right)) + + if padding == "SAME": + return same_padding(window, strides) + elif padding == "VALID": + pad_left = tuple(w - 1 for w in window) + pad_right = tuple(max(s - 1, w - 1) for w, s in zip(window, strides)) + return tuple(zip(pad_left, pad_right)) + elif padding == "CAUSAL": + causal_padding = ((window[0] - 1, strides[0] - 1),) + if len(window) > 1: + causal_padding += same_padding(window[1:], strides[1:]) + return causal_padding + else: + raise ValueError(f"{padding} padding is not supported.") + + +def conv_transpose_output_shape( + in_shape: Sequence[Optional[int]], + *, + window: Sequence[int], + strides: Sequence[int], + padding: ConvPaddingType, + dilation: Sequence[int], +) -> Sequence[int]: + """Returns output size for conv transpose. + + Each mode follows the formulas below, + * SAME: padding=(min(window-1, ceil((w+s-2)/2)), max(stride-1, floor((w+s-2)/2))) + pad_total = window+stride-2 + output_size = input_size*stride + * VALID: padding=(window-1, max(stride-1, window-1)) + pad_total = window+stride-2 + max(window-stride, 0) + output_size = input_size*stride + max(window-stride, 0) + * CAUSAL: padding=(window-1, stride-1) + pad_total = window+stride-2 + output_size = input_size*stride + + Note: In the above equation, `window` will be replaced with `dilate_window` when dilation > 1. + dilate_window = (window - 1) * dilation + 1. Check conv_dilate_window() + + Refer to + https://towardsdatascience.com/understand-transposed-convolutions-and-build-your-own-transposed-convolution-layer-from-scratch-4f5d97b2967 + + Args: + in_shape: convolution lhs shape. + window: convolution window. + strides: convolution strides. + padding: convolution padding. + dilation: convolution dilation. + + Returns: + The output shape. + + Raises: + ValueError: If the length of in_shape, window, strides, and padding are not equal. + """ + if len(in_shape) != len(window) or len(in_shape) != len(strides): + raise ValueError( + f"len(in_shape) = {len(in_shape)} must be equal to " + f"len(window) = {len(window)} and len(strides) = {len(strides)}" + ) + + window = conv_dilate_window(window=window, dilation=dilation) + + def output_shape(in_shape: Optional[int], window: int, stride: int): + if in_shape is None: + return None + + if padding == "SAME": + return in_shape * stride + elif padding == "VALID": + return in_shape * stride + max(window - stride, 0) + elif padding == "CAUSAL": + return in_shape * stride + else: + raise ValueError(f"{padding} padding is not supported.") + + return tuple(map(output_shape, in_shape, window, strides)) + + +def compute_conv_transpose_paddings( + in_paddings: Tensor, + *, + window: int, + stride: int, + conv_padding: ConvPaddingType, + dilation: int = 1, + anchor: Optional[int] = None, +): + """Compute output paddings w.r.t. conv_padding for conv transpose. + + The output paddings value is determined by the padding value at the anchor point in the + window. If anchor is None, the default anchor point is the left time padding from conv + padding config. See `Conv2DWith1DPadding.Config` in details. 
+ + In the window=3 and stride=1 case, this function creates paddings as follows: + + The following illustration demonstrates how Conv Transpose operates, assuming all kernel values + are set to 1 for simplicity in showcasing output values. + + In the window=3 and stride=1 case, this function creates outputs as follows: + * "SAME" padding=(1, 1) + pad| |pad + paddings: 0|0 0 1 1|1 + |_____| + * 0 * -> 0 + * 0 * -> 0 + * 1 * -> 1 + * 1 0 -> 1 + + * "VALID" padding=(2, 2) + pad | |pad + paddings: 0 0|0 0 1 1|1 1 + |_________| + * * 0 -> 0 + * * 0 -> 0 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + + * "CAUSAL" padding=(2, 0) + pad | |pad + paddings: 0 0|0 0 1 1| + |_____| + * * 0 -> 0 + * * 0 -> 0 + * * 1 -> 1 + * * 1 -> 1 + + In the window=3 and stride=2 case, this function creates outputs as follows: + * "SAME" padding=(2, 1) + pad | |pad + paddings: 0 0|0 * 0 * 1 * 1|1 + |_____________| + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + + * "VALID" padding=(2, 2) + pad | |pad + paddings: 0 0|0 * 0 * 1 * 1|1 1 + |_______________| + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + + * "CAUSAL" padding=(2, 1) + pad | |pad + paddings: 0 0|0 * 0 * 1 * 1|1 + |_____________| + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + + In the window=3 and stride=3 case, this function creates outputs as follows: + * "SAME", "VALID" and "CAUSAL" padding=(2, 2) + pad | |pad + paddings: 0 0|0 * * 0 * * 1 * * 1|1 1 + |_____________________| + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + + OTHO, when dilation > 1, dilate_window = (window - 1) * dilation + 1. + For example, when window=3 and dilation=2, dilate_window=5. + + In the stride=2 case, this function creates outputs as follows: + * "SAME" padding=(3, 2) + pad | |pad + paddings: 0 0 0|0 * 0 * 1 * 1|1 1 + |_____________| + * * * 0 * -> 0 + * * * 0 * -> 0 + * * * 0 * -> 0 + * * * 0 * -> 0 + * * * 1 * -> 1 + * * * 1 * -> 1 + * * * 1 * -> 1 + * * * 1 * -> 1 + + * "VALID" padding=(4, 4) + pad | |pad + paddings: 0 0 0 0|0 * 0 * 1 * 1|1 1 1 1 + |___________________| + * * * * 0 -> 0 + * * * * 0 -> 0 + * * * * 0 -> 0 + * * * * 0 -> 0 + * * * * 1 -> 1 + * * * * 1 -> 1 + * * * * 1 -> 1 + * * * * 1 -> 1 + * * * * 1 -> 1 + * * * * 1 -> 1 + * * * * 1 -> 1 + + * "CAUSAL" padding=(4, 1) + pad | |pad + paddings: 0 0 0 0|0 * 0 * 1 * 1|1 + |_____________| + * * * * 0 -> 0 + * * * * 0 -> 0 + * * * * 0 -> 0 + * * * * 0 -> 0 + * * * * 1 -> 1 + * * * * 1 -> 1 + * * * * 1 -> 1 + * * * * 1 -> 1 + + Args: + in_paddings: A Tensor of shape [batch_size, seq_len]. + window: convolution window size of the time axis. + stride: convolution stride size of the time axis. + conv_padding: "SAME", "VALID", "CAUSAL" or ((left_time_padding, right_time_padding),) + dilation: convolution dilation size of the time axis. + anchor: an optional integer in the range of [0, window) + that specifies the anchor position within the convolution window that is used to + determine output paddings. Specifically, the output token is valid iff the input token + at the anchor position of the corresponding window is valid. + If None, anchor defaults to conv_padding[0] (i.e. left_time_padding). + + Returns: + out_paddings: A Tensor of shape [batch_size, seq_len]. 
+ + Raises: + ValueError: If anchor is not between left_time_padding and window. + """ + + chex.assert_rank(in_paddings, 2) + conv_padding = conv_transpose_explicit_padding( + window=(window,), strides=(stride,), padding=conv_padding, dilation=(dilation,) + ) + window = conv_dilate_window(window=(window,), dilation=(dilation,))[0] + # Note: in transposed conv, left_pad + right_pad >= window - 1. + # See conv_transpose_explicit_padding(). + left_pad, right_pad = conv_padding[0] + + if anchor is None: + anchor = left_pad + # elif not left_pad <= anchor < window: + elif not anchor < window: + raise ValueError(f"anchor ({anchor}) must in range [0, {window}).") + + # Consider the case where window=3, strides=2, dilation=2, and padding="SAME" + # explicit padding=(3, 2) + # pad | |pad + # paddings: 0 0 0|0 * 0 * 1 * 1|1 1 + # |_____________| + # * * * 0 * -> 0 + # * * * 0 * -> 0 + # * * * 0 * -> 0 + # * * * 0 * -> 0 + # * * * 1 * -> 1 + # * * * 1 * -> 1 + # * * * 1 * -> 1 + # * * * 1 * -> 1 + + # |0 0 1 1| -> |0 * 0 * 1 * 1| + def dilate_paddings(paddings): + most, last = jnp.split(paddings, [paddings.shape[1] - 1], axis=1) + dilated = einops.repeat(most, "b t -> b (t s)", s=stride) + return jnp.concatenate([dilated, last], axis=1) + + in_paddings = dilate_paddings(in_paddings) + + # |0 * 0 * 1 * 1| -> 0 0 0|0 * 0 * 1 * 1|1 1 + # |_____________| which is |0 * 0 * 1 * 1|1 + window_pad_total = window - 1 # Note: we already check `anchor < window`` always. + window_right_pad = window_pad_total - anchor + assert window_right_pad >= 0, f"{anchor=} < {window=} always." + # Note: left_pad + right_pad >= window + stride - 2 >= window - 1 == anchor + window_right_pad + valid_right_pad = right_pad - window_right_pad + if valid_right_pad >= 0: + out_paddings = jnp.pad(in_paddings, ((0, 0), (0, valid_right_pad)), mode="edge") + else: + out_paddings = in_paddings[:, :valid_right_pad] + + start_index = anchor - left_pad + if start_index < 0: + out_paddings = jnp.pad(out_paddings, ((0, 0), (-start_index, 0)), mode="edge") + else: + out_paddings = out_paddings[:, start_index:] + return out_paddings + + +class Conv1DTranspose(BaseConv): + """The 1D transposed convolution layer.""" + + @config_class + class Config(BaseConv.Config): + """Configures Conv1DTranspose.""" + + window: int = 1 + strides: int = 1 + padding: Required[ConvPaddingType] = REQUIRED + dilation: int = 1 # Dilation for dilated Convolution. + output_dim: Required[int] = REQUIRED # Output feature dim. + bias: bool = True # Whether to add a bias. + + # An optional integer in the range of [0, window) + # that specifies the anchor position within the convolution window that is used to + # determine output paddings. Specifically, the output token is valid iff the input token + # at the anchor position of the corresponding window is valid. + # If None, defaults to left time padding. See compute_conv_transpose_paddings more details. 
+ anchor: Optional[int] = None + + @classmethod + def default_config(cls): + cfg = super().default_config() + cfg.param_partition_spec = (None, None, "model") + return cfg + + def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: + cfg = self.config + _check_conv_cfg( + window=(cfg.window,), + strides=(cfg.strides,), + padding=cfg.padding, + dilation=(cfg.dilation,), + ) + params = dict( + weight=ParameterSpec( + shape=(cfg.window, cfg.input_dim, cfg.output_dim), + mesh_axes=cfg.param_partition_spec, + factorization=FactorizationSpec(axes=(None, "row", "col")), + ) + ) + if cfg.bias: + params["bias"] = ParameterSpec( + shape=[cfg.output_dim], mesh_axes=(cfg.param_partition_spec[-1],) + ) + return params + + def forward( + self, x: Tensor, *, paddings: Optional[Tensor] = None + ) -> tuple[Tensor, Optional[Tensor]]: + cfg = self.config + conv_padding = conv_transpose_explicit_padding( + window=(cfg.window,), + strides=(cfg.strides,), + padding=cfg.padding, + dilation=(cfg.dilation,), + ) + + if paddings is not None: + chex.assert_rank(x, paddings.ndim + 1) + # Apply padding to the input. + x = x * (1 - paddings[..., None]) + + output = self._conv( + x=x, strides=(cfg.strides,), padding=conv_padding, dilation=(cfg.dilation,) + ) + + if paddings is None: + output_paddings = None + else: + # Compute paddings conv output. + output_paddings = compute_conv_transpose_paddings( + paddings, + window=cfg.window, + stride=cfg.strides, + conv_padding=cfg.padding, + dilation=cfg.dilation, + anchor=cfg.anchor, + ) + output = output * (1 - output_paddings[..., None]) + return output, output_paddings + + def _conv( + self, + x: Tensor, + *, + strides: Sequence[int], + padding: ConvPaddingType, + dilation: Sequence[int], + ) -> Tensor: + cfg = self.config + output = jax.lax.conv_transpose( + lhs=x, + rhs=self.parameters["weight"], + strides=strides, + padding=padding, + rhs_dilation=dilation, + dimension_numbers=("NWC", "WIO", "NWC"), + ) + if cfg.bias: + output += self.parameters["bias"] + return output + + def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Optional[int]]: + cfg = self.config + if len(input_shape) != 3: + raise ValueError(f"We expect len(input_shape) = 3, but got {len(input_shape)}.") + if input_shape[-1] != cfg.input_dim: + raise ValueError( + f"input_shape[-1] = {input_shape[-1]} does not match " + "cfg.input_dim = {cfg.input_dim}." + ) + + in_shape = input_shape[1:2] + out_shape = conv_transpose_output_shape( + in_shape, + window=(cfg.window,), + strides=(cfg.strides,), + padding=cfg.padding, + dilation=(cfg.dilation,), + ) + return [input_shape[0], *out_shape, cfg.output_dim] + + +# TODO(dhwang2): move to convolution transpose section. +class Conv2DTranspose(BaseConv): + """The 2-D transposed convolution layer.""" + + @config_class + class Config(BaseConv.Config): + """Configures Conv2DTranspose.""" + + window: tuple[int, int] = (1, 1) + strides: tuple[int, int] = (1, 1) + padding: Required[ConvPaddingType] = REQUIRED + dilation: tuple[int, int] = (1, 1) + output_dim: Required[int] = REQUIRED # Output feature dim. + bias: bool = True # Whether to add a bias. + # If True, kernel weights have the HWOI layout, following the format used by + # keras.layers.Conv2DTranspose. + # Otherwise, the standard layout HWIO is used, which is more efficient. 
+ transpose_kernel: bool = False + + @classmethod + def default_config(cls): + cfg = super().default_config() + if cfg.transpose_kernel: + cfg.param_partition_spec = (None, None, "model", None) + else: + cfg.param_partition_spec = (None, None, None, "model") + return cfg + + def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: + cfg = self.config + _check_conv_cfg( + window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation + ) + if cfg.transpose_kernel: + io_shape = (cfg.output_dim, cfg.input_dim) + else: + io_shape = (cfg.input_dim, cfg.output_dim) + params = dict( + weight=ParameterSpec( + shape=tuple(cfg.window) + io_shape, + mesh_axes=cfg.param_partition_spec, + factorization=FactorizationSpec(axes=(None, None, "row", "col")), + ) + ) + if cfg.bias: + params["bias"] = ParameterSpec( + shape=[cfg.output_dim], mesh_axes=(cfg.param_partition_spec[-1],) + ) + return params + + def forward(self, x: Tensor) -> Tensor: + cfg = self.config + conv_padding = conv_transpose_explicit_padding( + window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation + ) + return self._conv(x=x, strides=cfg.strides, padding=conv_padding, dilation=cfg.dilation) + + def _conv( + self, + x: Tensor, + *, + strides: Sequence[int], + padding: ConvPaddingType, + dilation: Sequence[int], + ) -> Tensor: + cfg = self.config + output = jax.lax.conv_transpose( + lhs=x, + rhs=self.parameters["weight"], + strides=strides, + padding=padding, + rhs_dilation=dilation, + dimension_numbers=("NHWC", "HWIO", "NHWC"), + transpose_kernel=cfg.transpose_kernel, + ) + if cfg.bias: + output += self.parameters["bias"] + return output + + @nowrap + def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Optional[int]]: + cfg = self.config + if len(input_shape) != 4: + raise ValueError(f"We expect len(input_shape) = 4, but got {len(input_shape)}.") + if input_shape[-1] != cfg.input_dim: + raise ValueError( + f"input_shape[-1] = {input_shape[-1]} does not match " + "cfg.input_dim = {cfg.input_dim}." + ) + + in_shape = input_shape[1:3] + out_shape = conv_transpose_output_shape( + in_shape, + window=cfg.window, + strides=cfg.strides, + padding=cfg.padding, + dilation=cfg.dilation, + ) + return [input_shape[0], *out_shape, cfg.output_dim] + + +class Conv2DTransposeWith1DPadding(Conv2DTranspose): + """The 2-D convolution transpose with 1-D padding on the time axis.""" + + @config_class + class Config(Conv2DTranspose.Config): + """Configures Conv2DTransposeWith1DPadding.""" + + transpose_kernel: bool = False + # An optional integer in the range of [0, window) + # that specifies the anchor position within the convolution window that is used to + # determine output paddings. Specifically, the output token is valid iff the input token + # at the anchor position of the corresponding window is valid. + # If None, defaults to left time padding. See compute_conv_transpose_paddings more details. + anchor: Optional[int] = None + + @classmethod + def default_config(cls): + cfg = super().default_config() + cfg.transpose_kernel = False # Choose better one unlike parent. + return cfg + + # We add a kwargs "paddings" to the forward method. + # pylint: disable-next=arguments-differ + def forward(self, x: Tensor, *, paddings: Tensor) -> tuple[Tensor, Tensor]: + """Computes convolution outputs and paddings. + + Args: + x: A Tensor of shape [batch_size, seq_len, frequency, input_dim]. + paddings: 0/1 Tensor of shape [batch_size, seq_len]. 
+ + Returns: + output: A Tensor of shape [batch_size, seq_len, frequency, output_dim]. + paddings: 0/1 Tensor of shape [batch_size, seq_len]. + """ + cfg = self.config + # Apply padding to the input. + assert len(x.shape) == len(paddings.shape) + 2 + x = x * (1 - paddings[..., None, None]) + + # Apply Conv2D. + output = super().forward(x) + # Compute paddings conv output. + output_paddings = compute_conv_transpose_paddings( + paddings, + window=cfg.window[0], + stride=cfg.strides[0], + conv_padding=cfg.padding, + dilation=cfg.dilation[0], + anchor=cfg.anchor, + ) + # Apply padding to the outputs. + output = output * (1 - output_paddings[..., None, None]) + return output, output_paddings + + +class Conv3DTranspose(BaseConv): + """The 3-D convolution transpose layer.""" + + @config_class + class Config(BaseConv.Config): + """Configures Conv3DTranspose.""" + + window: tuple[int, int, int] = (1, 1, 1) # The convolution window. + strides: tuple[int, int, int] = (1, 1, 1) # The convolution strides. + # Paddings: "SAME", "VALID or "CAUSAL", or ((top, bottom), (left, right), (front, back)) + padding: Required[ConvPaddingType] = REQUIRED + dilation: tuple[int, int, int] = (1, 1, 1) # The convolution dilation. + + output_dim: Required[int] = REQUIRED # Output feature dim. + bias: bool = True # Whether to add a bias. + + @classmethod + def default_config(cls): + cfg = super().default_config() + cfg.param_partition_spec = (None, None, None, None, "model") + return cfg + + def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: + cfg = self.config + _check_conv_cfg( + window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation + ) + params = dict( + weight=ParameterSpec( + shape=cfg.window + (cfg.input_dim, cfg.output_dim), + mesh_axes=cfg.param_partition_spec, + factorization=FactorizationSpec(axes=(None, None, None, "row", "col")), + ) + ) + if cfg.bias: + params["bias"] = ParameterSpec( + shape=[cfg.output_dim], mesh_axes=(cfg.param_partition_spec[-1],) + ) + return params + + def forward(self, x: Tensor) -> Tensor: + cfg = self.config + conv_padding = conv_transpose_explicit_padding( + window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation + ) + return self._conv(x=x, strides=cfg.strides, padding=conv_padding, dilation=cfg.dilation) + + def _conv( + self, + x: Tensor, + *, + strides: Sequence[int], + padding: ConvPaddingType, + dilation: Sequence[int], + ) -> Tensor: + cfg = self.config + output = jax.lax.conv_transpose( + lhs=x, + rhs=self.parameters["weight"], + strides=strides, + padding=padding, + rhs_dilation=dilation, + dimension_numbers=("NHWDC", "HWDIO", "NHWDC"), + ) + if cfg.bias: + output += self.parameters["bias"] + return output + + @nowrap + def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Optional[int]]: + cfg = self.config + if len(input_shape) != 5: + raise ValueError(f"We expect len(input_shape) = 5, but got {len(input_shape)}.") + if input_shape[-1] != cfg.input_dim: + raise ValueError( + f"input_shape[-1] = {input_shape[-1]} does not match " + f"cfg.input_dim = {cfg.input_dim}." 
+ ) + + in_shape = input_shape[1:4] + out_shape = conv_transpose_output_shape( + in_shape, + window=cfg.window, + strides=cfg.strides, + padding=cfg.padding, + dilation=cfg.dilation, + ) + return [input_shape[0], *out_shape, cfg.output_dim] + + +############################## Others ############################################################## + + +class StackOverTime(BaseLayer): + """Stack inputs along the time axis. + + StackOverTime behaves the same as Conv2DWith1DPadding w.r.t. paddings along the time axis. + Please refer to the docstring of Conv2DWith1DPadding to understand how the padding work + including "SAME", "VALID", and "CAUSAL" literals. The padding anchor is set to `left padding`. + """ + + @config_class + class Config(BaseLayer.Config): + """Configures StackOverTime.""" + + stride: Required[int] = REQUIRED # Number of frames to stack. + + # Number of paddings to apply along the time axis. The two integers specify the amount + # of leading and trailing padding, respectively. Alternatively, this can be a + # convolution padding literals type such as 'SAME', 'VALID', or 'CAUSAL'. + # Note: For backward compatibility, the default is set to VALID, but in most cases, + # CAUSAL is more appropriate as it preserves the sequence length. + padding: Union[tuple[int, int], Literal["SAME", "VALID", "CAUSAL"]] = "VALID" + + def forward(self, inputs: Tensor, *, paddings: Tensor) -> tuple[Tensor, Tensor]: + """Stacks stride number of frames into one frame along the time axis. + + Args: + inputs: Tensor of shape [batch, time, input_dim]. + paddings: 0/1 Tensor of shape [batch, time], paddings of the input sequences. + + Returns: + stacked_inputs: Tensor of shape [batch, time // stride, input_dim * stride]. + stacked_paddings: 0/1 Tensor of shape [batch, time // stride]. An output frame + is padding if at least one of the stacked input frames is padding. + + Raises: + ValueError: If stride is <= 1. + """ + cfg = self.config + if cfg.stride <= 1: + raise ValueError(f"stride should be greater than 1, but got {cfg.stride}.") + + # For the last partial frame. + inputs = inputs * (1 - paddings)[:, :, None] + + padding = cfg.padding + if isinstance(padding, str): + padding = conv_explicit_padding( + window=(cfg.stride,), strides=(cfg.stride,), padding=padding, dilation=(1,) + )[0] + inputs = jnp.pad(inputs, ((0, 0), padding, (0, 0)), constant_values=0) + + batch_size, seq_len, input_dim = inputs.shape + output_length = seq_len // cfg.stride + new_shape = [batch_size, output_length, input_dim * cfg.stride] + # Stack inputs over the time dimension. + stacked_inputs = jnp.reshape(inputs[:, : output_length * cfg.stride, :], new_shape) + # An output frame is padding if at least one of the stacked input frames is padding. + stacked_paddings = compute_conv_paddings( + paddings, window=cfg.stride, stride=cfg.stride, conv_padding=(padding,) + ) + stacked_inputs = stacked_inputs * (1 - stacked_paddings)[:, :, None] + return stacked_inputs, stacked_paddings + + @nowrap + def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Optional[int]]: + """Computes stacked output shape. + + Args: + input_shape: The input dimensions are (batch, time, feature_dim). + If the value of the dimension is not available, use None. + + Returns: + The output shape. The dimensions are (batch, time, feature_dim). 
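+
+        For example, with stride=2 and the default "VALID" padding, an input_shape of
+        [8, 7, 16] maps to [8, 3, 32]; with "CAUSAL" padding (explicit padding (0, 1))
+        it maps to [8, 4, 32].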
+ """ + cfg = self.config + batch_size, seq_len, input_dim = input_shape + padding = cfg.padding + if isinstance(padding, tuple): + padding = (padding,) + out_shape = conv_output_shape( + [seq_len], window=(cfg.stride,), strides=(cfg.stride,), padding=padding, dilation=(1,) + ) + return [batch_size, *out_shape, input_dim * cfg.stride] diff --git a/axlearn/common/convolution_test.py b/axlearn/common/convolution_test.py new file mode 100644 index 000000000..14626daac --- /dev/null +++ b/axlearn/common/convolution_test.py @@ -0,0 +1,1897 @@ +# Copyright © 2024 Apple Inc. +"""Tests convolution layers.""" + +# pylint: disable=no-self-use +from typing import Optional, Union + +import einops +import jax.random +import numpy as np +import torch +from absl.testing import absltest, parameterized +from jax import numpy as jnp + +from axlearn.common import convolution, utils +from axlearn.common.convolution import ( + Conv1D, + Conv1DWithPadding, + Conv2D, + Conv2DTranspose, + Conv2DWith1DPadding, + Conv3D, + ConvPaddingType, + StackOverTime, + compute_conv_paddings, +) +from axlearn.common.module import functional as F +from axlearn.common.param_converter import as_torch_tensor +from axlearn.common.test_utils import TestCase, assert_allclose +from axlearn.common.utils import shapes + + +def _copy(src: jnp.ndarray, dst: torch.nn.Parameter): + with torch.no_grad(): + src = np.asarray(src).copy() + src = torch.as_tensor(src) + dst.copy_(src) + + +class ConvTest(TestCase): + @parameterized.parameters((1, 1, 1), (1, 2, 1), (2, 1, 2), (3, 1, 3), (3, 2, 5)) + def test_conv_dilate_window(self, window, dilation, expected): + effective_window = convolution.conv_dilate_window(window=(window,), dilation=(dilation,))[0] + self.assertEqual(effective_window, expected) + + @parameterized.parameters( + (10, 3, 1, "SAME", 1, 10), + (10, 3, 2, "SAME", 1, 5), + (10, 3, 1, "SAME", 2, 10), + (10, 3, 2, "SAME", 2, 5), + (10, 3, 1, "VALID", 1, 8), + (10, 3, 2, "VALID", 1, 4), + (10, 3, 1, "VALID", 2, 6), + (10, 3, 2, "VALID", 2, 3), + (10, 3, 1, "CAUSAL", 1, 10), + (10, 3, 2, "CAUSAL", 1, 5), + (10, 3, 1, "CAUSAL", 2, 10), + (10, 3, 2, "CAUSAL", 2, 5), + ) + def test_conv_output_shape(self, in_shape, window, strides, padding, dilation, expected): + out_shape = convolution.conv_output_shape( + in_shape=(in_shape,), + window=(window,), + strides=(strides,), + padding=padding, + dilation=(dilation,), + )[0] + self.assertEqual(out_shape, expected) + + @parameterized.parameters( + ([0, 0, 0, 1], [0, 0, 0, 1], 1, "SAME"), + ([0], [], 1, "VALID"), + ([0, 0], [], 1, "VALID"), + ([0, 0, 0], [0], 1, "VALID"), + ([0, 0, 0, 0], [0, 0], 1, "VALID"), + ([0, 0, 0, 1], [0, 0], 1, "VALID"), + ([0, 0, 0, 0], [0], 2, "VALID"), + ([0, 0, 0, 1], [0], 2, "VALID"), + ([0, 0, 1, 1], [0], 2, "VALID"), + ([0, 0, 0, 0, 0], [0, 0], 2, "VALID"), + ([0, 0, 0, 0, 1], [0, 0], 2, "VALID"), + ([0, 0, 0, 1, 1], [0, 0], 2, "VALID"), + ([0, 0, 1, 1, 1], [0, 1], 2, "VALID"), + ([0, 0, 0, 0, 0, 0], [0, 0], 2, "VALID"), + ([0, 0, 0, 0, 0, 1], [0, 0], 2, "VALID"), + ([0, 0, 0, 0, 1, 1], [0, 0], 2, "VALID"), + ([0, 0, 0, 1, 1, 1], [0, 0], 2, "VALID"), + ([0, 0, 1, 1, 1, 1], [0, 1], 2, "VALID"), + ) + def test_conv_padding(self, input_paddings, expected_paddings, stride: int, padding_cfg: str): + """Tests conv_output_shape() with SAME and VALID padding cfg.""" + # This test is from lingvo + # https://github.com/tensorflow/lingvo/blob/master/lingvo/core/conv_layers_with_time_padding_test.py#L157. 
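+        # Each case above pairs an input padding row with the expected output paddings
+        # for a window-3 convolution at the given stride and padding mode.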
+ window = 3 + out_paddings = compute_conv_paddings( + jnp.array([input_paddings]), window=window, stride=stride, conv_padding=padding_cfg + ) + assert_allclose(out_paddings[0], expected_paddings) + + @parameterized.parameters( + (5, 1, "SAME", 1, (2, 2)), + (5, 2, "SAME", 1, (2, 2)), + (5, 3, "SAME", 1, (2, 2)), + (5, 1, "SAME", 2, (4, 4)), + (5, 2, "SAME", 2, (4, 4)), + (5, 3, "SAME", 2, (4, 4)), + (5, 1, "VALID", 1, (0, 0)), + (5, 2, "VALID", 1, (0, 0)), + (5, 3, "VALID", 1, (0, 0)), + (5, 1, "VALID", 2, (0, 0)), + (5, 2, "VALID", 2, (0, 0)), + (5, 3, "VALID", 2, (0, 0)), + (5, 1, "CAUSAL", 1, (4, 0)), + (5, 2, "CAUSAL", 1, (3, 1)), + (5, 3, "CAUSAL", 1, (2, 2)), + (5, 1, "CAUSAL", 2, (8, 0)), + (5, 2, "CAUSAL", 2, (7, 1)), + (5, 3, "CAUSAL", 2, (6, 2)), + ) + def test_conv_explicit_padding( + self, window: int, stride: int, padding: ConvPaddingType, dilation: int, expected + ): + """Tests the cases in conv_explicit_padding() description.""" + explicit_padding = convolution.conv_explicit_padding( + window=(window,), + strides=(stride,), + padding=padding, + dilation=(dilation,), + ) + assert_allclose(explicit_padding[0], expected) + + @parameterized.parameters( + (5, 1, "SAME", [0, 0, 0, 0, 1, 1]), + (5, 2, "SAME", [0, 0, 1]), + (5, 1, "VALID", [0, 0]), + (5, 2, "VALID", [0]), + (5, 1, "SAME", [0, 0, 0, 0, 1, 1]), + (5, 2, "SAME", [0, 0, 1]), + ) + def test_conv_output_1d_padding_simple( + self, window: int, stride: int, padding: ConvPaddingType, expected + ): + """Tests the cases in conv_explicit_padding() description.""" + paddings = jnp.array([[0, 0, 0, 0, 1, 1]]) + out_paddings = compute_conv_paddings( + paddings, window=window, stride=stride, conv_padding=padding + ) + assert_allclose(out_paddings[0], expected) + + @parameterized.parameters( + ([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0]), + ([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 0, 0]), + ([1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 0, 0]), + ([1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 0]), + ([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1]), + ([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 1, 1]), + ([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 1, 1]), + ([0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 1]), + ) + def test_conv_output_1d_padding_causal(self, in_paddings, expected): + """Test the below cases. + + The formula for CAUSAL padding is `(window - stride, stride - 1)`. + With window=15 and stride=6, padding is (9, 5). + Below are examples illustrating how input paddings are transformed into output + paddings across different scenarios. + + left_pad | input paddings -> outputs paddings + 1) |1 1 1|1 1 1|1 1 1|0 0 0|0 0 0|0 0 0|0 0 0|0 0 0|0 0 0| -> 0 0 0 + 2) |1 1 1|1 1 1|1 1 1|1 0 0|0 0 0|0 0 0|0 0 0|0 0 0|0 0 0| -> 1 0 0 + 3) |1 1 1|1 1 1|1 1 1|1 1 1|1 1 1|0 0 0|0 0 0|0 0 0|0 0 0| -> 1 0 0 + 4) |1 1 1|1 1 1|1 1 1|1 1 1|1 1 1|1 0 0|0 0 0|0 0 0|0 0 0| -> 1 1 0 + 5) |1 1 1|1 1 1|1 1 1|1 1 1|1 1 1|1 1 1|1 1 1|1 1 1|1 1 1| -> 1 1 1 + 6) |1 1 1|1 1 1|1 1 1|0 1 1|1 1 1|1 1 1|1 1 1|1 1 1|1 1 1| -> 0 1 1 + 7) |1 1 1|1 1 1|1 1 1|0 0 0|0 0 0|1 1 1|1 1 1|1 1 1|1 1 1| -> 0 1 1 + 8) |1 1 1|1 1 1|1 1 1|0 0 0|0 0 0|0 1 1|1 1 1|1 1 1|1 1 1| -> 0 0 1 + |_________________^_________| + |_________________^_________| + |_________________^_________| + + Let's take a closer look at case 7). 
In case 7), the first window component fully + covers all 0s, so the first component of the output padding should be the last + 0 component, meaning the second component is 1. + + In case 8), however, the first window component does not cover all 0s, so the + next component should also be 0. If the second component were 1, information + from the last partial window of the input would be lost. + + In general, the anchor point should be the next position after the right edge + of the previous window. Since the anchor is defined by the left pad, + `left_pad = window - stride`, and `right_pad = (window - 1) - left_pad`, + simplifying to `right_pad = stride - 1`. + """ + window = 15 + stride = 6 + padding = "CAUSAL" + explicit_padding = convolution.conv_explicit_padding( + window=(window,), strides=(stride,), padding=padding, dilation=(1,) + ) + assert_allclose(explicit_padding[0], (9, 5)) + + in_paddings = jnp.array([in_paddings]) + out_paddings = compute_conv_paddings( + in_paddings, window=window, stride=stride, conv_padding=padding + )[0] + assert_allclose(out_paddings, expected) + + @parameterized.parameters( + (3, 1, ((1, 1),), "SAME"), + (3, 1, ((0, 0),), "VALID"), + (3, 1, ((2, 0),), "CAUSAL"), + (3, 2, ((1, 1),), "SAME"), + (3, 2, ((0, 0),), "VALID"), + (3, 2, ((1, 1),), "CAUSAL"), + (5, 2, ((2, 2),), "SAME"), + (5, 2, ((0, 0),), "VALID"), + (5, 2, ((3, 1),), "CAUSAL"), + ) + def test_conv_output_1d_padding_against_str_padding( + self, window: int, stride: int, padding: ConvPaddingType, ref_padding: ConvPaddingType + ): + """Tests conv_output_shape() with explicit padding cfg.""" + batch_size = 5 + seq_len = 5 + paddings = jnp.triu(jnp.ones((batch_size, seq_len)), k=1) + + explicit_padding = convolution.conv_explicit_padding( + window=(window,), strides=(stride,), padding=ref_padding, dilation=(1,) + ) + assert_allclose(explicit_padding, padding[:1]) + + out_paddings = compute_conv_paddings( + paddings, window=window, stride=stride, conv_padding=padding + ) + ref_paddings = compute_conv_paddings( + paddings, window=window, stride=stride, conv_padding=ref_padding + ) + assert_allclose(out_paddings, ref_paddings) + + @parameterized.parameters( + ("SAME", 1, [0, 0, 0, 0, 1, 1], [0, 0, 1]), + ("VALID", 1, [0, 0, 0, 0, 1, 1], [0]), + ("CAUSAL", 1, [0, 0, 0, 0, 1, 1], [0, 0, 1]), + ("SAME", 2, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1], [0, 0, 0, 0, 1]), + ("VALID", 2, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1], [0]), + ("CAUSAL", 2, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1], [0, 0, 0, 0, 1]), + ) + def test_compute_conv_paddings_with_dilation( + self, padding: ConvPaddingType, dilation: int, paddings, expected + ): + """Tests compute_conv_paddings() as described in conv_explicit_padding().""" + window, stride = 5, 2 + out_paddings = compute_conv_paddings( + jnp.array([paddings]), + window=window, + stride=stride, + conv_padding=padding, + dilation=dilation, + )[0] + assert_allclose(out_paddings, expected) + + @parameterized.parameters( + (5, "SAME", None, [0, 0, 0, 1, 1, 1]), + (5, "SAME", 1, ValueError), + (5, "SAME", 2, [0, 0, 0, 1, 1, 1]), + (5, "SAME", 3, ValueError), + (5, ((1, 1),), None, [0, 0, 0, 1]), + (5, ((1, 1),), 0, ValueError), + (5, ((1, 1),), 1, [0, 0, 0, 1]), + (5, ((1, 1),), 2, [0, 0, 1, 1]), + (5, ((1, 1),), 3, [0, 1, 1, 1]), + (5, ((1, 1),), 4, ValueError), + (5, "VALID", None, [0, 0]), + (5, "VALID", 0, [0, 0]), + (5, "VALID", 1, [0, 0]), + (5, "VALID", 2, [0, 1]), + (5, "VALID", 3, [1, 1]), + (5, "VALID", 4, [1, 1]), + (5, "CAUSAL", None, [0, 0, 0, 1, 1, 1]), + (5, "CAUSAL", 3, ValueError), + (5, 
"CAUSAL", 4, [0, 0, 0, 1, 1, 1]), + (5, "CAUSAL", 5, ValueError), + ) + def test_conv_output_1d_padding_with_anchor(self, window, padding, anchor, expected_paddings): + input_paddings = [0, 0, 0, 1, 1, 1] + try: + out_paddings = compute_conv_paddings( + jnp.array([input_paddings]), + window=window, + stride=1, + conv_padding=padding, + anchor=anchor, + ) + assert_allclose(out_paddings[0], expected_paddings) + except ValueError as e: + self.assertTrue(isinstance(e, expected_paddings)) + + @parameterized.named_parameters( + ("w3s1d1_VALID", 3, 1, "VALID", None), + ("w3s1d2_VALID", 3, 1, "VALID", 2), + ("w3s1d1_SAME", 3, 1, "SAME", None), + ("w4s1d1_SAME", 4, 1, "SAME", None), + ("w4s1d3_SAME", 4, 1, "SAME", 3), + ("w4s1d1_CAUSAL", 4, 1, ((3, 0),), None), + ("w4s1d5_CAUSAL", 4, 1, ((3, 0),), 5), + ) + def test_conv1d( + self, + window: int, + strides: int, + padding: ConvPaddingType, + dilation: Optional[int], + ): + input_dim, output_dim = 4, 6 + cfg = Conv1D.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + strides=strides, + padding=padding, + dilation=dilation, + ) + layer: Conv1D = cfg.instantiate(parent=None) + # Initialize layer parameters. + prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + layer_params = layer.initialize_parameters_recursively(init_key) + self.assertEqual( + dict(weight=(window, input_dim, output_dim), bias=(output_dim,)), + shapes(layer_params), + ) + bias = layer_params["bias"] + assert_allclose(bias, jnp.zeros_like(bias)) + # Randomize bias. + layer_params["bias"] = jax.random.normal( + jax.random.PRNGKey(45), shape=bias.shape, dtype=bias.dtype + ) + + # Random inputs. + prng_key, input_key = jax.random.split(prng_key) + inputs = jax.random.normal(input_key, [2, 17, input_dim]) + # Compute layer outputs. + outputs, _ = F( + layer, + inputs=(inputs,), + is_training=True, + state=layer_params, + prng_key=prng_key, + ) + output_shape = layer.output_shape(input_shape=inputs.shape) + assert_allclose(outputs.shape, output_shape) + + # Compute ref outputs. + if isinstance(padding, str): + ref_padding = padding.lower() + ref_inputs = inputs + else: + # torch.nn.Conv1d does not support asymmetric padding, so pad manually and use "valid". + ref_padding = "valid" + ref_inputs = jnp.pad(inputs, ((0, 0), padding[0], (0, 0))) + ref = torch.nn.Conv1d( + in_channels=input_dim, + out_channels=output_dim, + groups=1, + kernel_size=window, + stride=strides, + padding=ref_padding, + dilation=1 if dilation is None else dilation, + ) + # torch.nn.Linear.weight is of shape (output_dim, input_dim, kernel_size...). + _copy(layer_params["weight"].transpose(2, 1, 0), ref.weight) + _copy(layer_params["bias"], ref.bias) + ref_outputs = ref(as_torch_tensor(ref_inputs.transpose(0, 2, 1))) + assert_allclose(outputs, ref_outputs.detach().numpy().transpose(0, 2, 1)) + + @parameterized.named_parameters( + ("w3s1_VALID", 3, 1, "VALID"), + ("w3s1_SAME", 3, 1, "SAME"), + ("w4s1_SAME", 4, 1, "SAME"), + ("w4s1_CAUSAL", 4, 1, ((3, 0),)), + ) + def test_depthwise_conv1d( + self, + window: int, + strides: int, + padding: ConvPaddingType, + ): + input_dim = 4 + cfg = Conv1D.default_config().set( + name="test", + input_dim=input_dim, + output_dim=input_dim, + num_input_dim_groups=input_dim, + window=window, + strides=strides, + padding=padding, + ) + layer: Conv1D = cfg.instantiate(parent=None) + + # Initialize layer parameters. 
+ prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + layer_params = layer.initialize_parameters_recursively(init_key) + self.assertEqual( + dict(weight=(window, 1, input_dim), bias=(input_dim,)), + shapes(layer_params), + ) + bias = layer_params["bias"] + assert_allclose(bias, jnp.zeros_like(bias)) + # Randomize bias. + layer_params["bias"] = jax.random.normal( + jax.random.PRNGKey(45), shape=bias.shape, dtype=bias.dtype + ) + + # Random inputs. + prng_key, input_key = jax.random.split(prng_key) + inputs = jax.random.normal(input_key, [2, 7, input_dim]) + + # Compute layer outputs. + outputs, _ = F( + layer, + inputs=(inputs,), + is_training=True, + state=layer_params, + prng_key=prng_key, + ) + output_shape = layer.output_shape(input_shape=inputs.shape) + assert_allclose(outputs.shape, output_shape) + + # Compute ref outputs. + if isinstance(padding, str): + ref_padding = padding.lower() + ref_inputs = inputs + else: + # torch.nn.Conv1d does not support asymmetric padding, so pad manually and use "valid". + ref_padding = "valid" + ref_inputs = jnp.pad(inputs, ((0, 0), padding[0], (0, 0))) + ref = torch.nn.Conv1d( + in_channels=input_dim, + out_channels=input_dim, + groups=input_dim, + kernel_size=window, + stride=strides, + padding=ref_padding, + ) + # torch.nn.Linear.weight is of shape (output_dim, input_dim, kernel_size...). + _copy(layer_params["weight"].transpose(2, 1, 0), ref.weight) + _copy(layer_params["bias"], ref.bias) + ref_outputs = ref(as_torch_tensor(ref_inputs.transpose(0, 2, 1))) + assert_allclose(outputs, ref_outputs.detach().numpy().transpose(0, 2, 1)) + + # Fails if tolerance is made smaller. + @parameterized.named_parameters( + { + "testcase_name": "1x1", + "window": (1, 1), + "strides": (1, 1), + "padding": "VALID", + "num_input_dim_groups": 1, + }, + { + "testcase_name": "2x2_VALID", + "window": (2, 2), + "strides": (1, 1), + "padding": "VALID", + "num_input_dim_groups": 1, + }, + { + "testcase_name": "2x2_SAME", + "window": (2, 2), + "strides": (1, 1), + "padding": "SAME", + "num_input_dim_groups": 1, + }, + { + "testcase_name": "2x2_S2_VALID", + "window": (2, 2), + "strides": (2, 2), + "padding": "VALID", + "num_input_dim_groups": 1, + }, + { + "testcase_name": "3x3_VALID", + "window": (3, 3), + "strides": (1, 1), + "padding": "VALID", + "num_input_dim_groups": 1, + }, + { + "testcase_name": "3x3_SAME", + "window": (3, 3), + "strides": (1, 1), + "padding": "SAME", + "num_input_dim_groups": 1, + }, + { + "testcase_name": "3x3_S2_VALID", + "window": (3, 3), + "strides": (2, 2), + "padding": "VALID", + "num_input_dim_groups": 1, + }, + { + "testcase_name": "3x3_S2_PADDING1", + "window": (3, 3), + "strides": (2, 2), + "padding": (1, 1), + "num_input_dim_groups": 1, + }, + { + "testcase_name": "3x3_GROUPS4", + "window": (3, 3), + "strides": (1, 1), + "padding": "SAME", + "num_input_dim_groups": 4, + }, + ) + def test_conv2d( + self, + window: tuple[int, int], + strides: tuple[int, int], + padding: Union[str, tuple[int, int]], + num_input_dim_groups: int, + ): + input_dim, output_dim = 256, 128 + if isinstance(padding, tuple): + conv_padding = ((padding[0], padding[0]), (padding[1], padding[1])) + else: + conv_padding = padding + cfg = Conv2D.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + strides=strides, + padding=conv_padding, + num_input_dim_groups=num_input_dim_groups, + ) + layer: Conv2D = cfg.instantiate(parent=None) + + # Initialize layer parameters. 
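+        # With input channel groups, each filter only sees input_dim // num_input_dim_groups
+        # channels, which is reflected in the expected weight shape below.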
+ prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + layer_params = layer.initialize_parameters_recursively(init_key) + self.assertEqual( + dict( + weight=(window[0], window[1], input_dim // num_input_dim_groups, output_dim), + bias=(output_dim,), + ), + shapes(layer_params), + ) + bias = layer_params["bias"] + assert_allclose(bias, jnp.zeros_like(bias)) + # Randomize bias. + layer_params["bias"] = jax.random.normal( + jax.random.PRNGKey(45), shape=bias.shape, dtype=bias.dtype + ) + + # Random inputs. + prng_key, input_key = jax.random.split(prng_key) + inputs = jax.random.normal(input_key, [2, 10, 7, input_dim]) + + # Compute layer outputs. + outputs, _ = F( + layer, + inputs=(inputs,), + is_training=True, + state=layer_params, + prng_key=prng_key, + ) + + # Compute ref outputs. + ref_padding = padding.lower() if isinstance(padding, str) else padding + ref = torch.nn.Conv2d( + in_channels=input_dim, + out_channels=output_dim, + kernel_size=window, + stride=strides, + padding=ref_padding, + groups=num_input_dim_groups, + ) + # torch.nn.Linear.weight is of shape (output_dim, input_dim, kernel_size...). + _copy(layer_params["weight"].transpose(3, 2, 0, 1), ref.weight) + _copy(layer_params["bias"], ref.bias) + ref_outputs = ref(as_torch_tensor(inputs.transpose(0, 3, 1, 2))) + # We currently don't match PyTorch as closely as we would like. + assert_allclose(outputs, ref_outputs.detach().numpy().transpose(0, 2, 3, 1), atol=4e-6) + # Tests output_shape. + output_shape = layer.output_shape(input_shape=inputs.shape) + assert_allclose(outputs.shape, output_shape) + + @parameterized.named_parameters( + ("1x1", (1, 1), (1, 1), "VALID", None), + ("2x2_VALID", (2, 2), (1, 1), "VALID", None), + ("2x2_SAME", (2, 2), (1, 1), "SAME", None), + ("2x2_CAUSAL", (2, 2), (1, 1), "CAUSAL", None), + ("2x2_S2_VALID", (2, 2), (2, 2), "VALID", None), + ("2x2_S2_CAUSAL", (2, 2), (2, 2), "CAUSAL", None), + ("3x3_VALID", (3, 3), (1, 1), "VALID", None), + ("3x3_VALID_A0", (3, 3), (1, 1), "VALID", 0), + ("3x3_VALID_A1", (3, 3), (1, 1), "VALID", 1), + ("3x3_VALID_A2", (3, 3), (1, 1), "VALID", 2), + ("3x3_SAME", (3, 3), (1, 1), "SAME", None), + ("3x3_CAUSAL", (3, 3), (1, 1), "CAUSAL", None), + ("3x3_S2_VALID", (3, 3), (2, 2), "VALID", None), + ("3x3_S2_CAUSAL", (3, 3), (2, 2), "CAUSAL", None), + ("3x3_S2_PADDING1", (3, 3), (2, 2), (1, 1), None), + ) + def test_conv2d_with_1d_padding( + self, + window: tuple[int, int], + strides: tuple[int, int], + padding: Union[str, tuple[int, int]], + anchor: Optional[int], + ): + """Tests that Conv2DWith1DPadding has consistent outputs under different padding lengths. + + Generates a batch of input sequences. Pads the sequences under different lengths. + Checks that the outputs are the same. + """ + input_dim, input_channel, output_dim = 4, 7, 6 + if isinstance(padding, tuple): + conv_padding = ((padding[0], padding[0]), (padding[1], padding[1])) + else: + conv_padding = padding + cfg = Conv2DWith1DPadding.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + strides=strides, + padding=conv_padding, + anchor=anchor, + ) + layer: Conv2DWith1DPadding = cfg.instantiate(parent=None) + + # Initialize layer parameters. 
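+        # The loop further below re-batches prefixes of these sequences and overwrites the
+        # padded positions with random data; outputs at valid positions must still match the
+        # full-length reference outputs.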
+ prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + layer_params = layer.initialize_parameters_recursively(init_key) + self.assertEqual( + dict(weight=(window[0], window[1], input_dim, output_dim), bias=(output_dim,)), + shapes(layer_params), + ) + # Generate a batch of 10 input sequences. + batch_size, max_seq_len = 10, 10 + + prng_key, input_key = jax.random.split(prng_key) + inputs = ( + jax.random.normal(input_key, [batch_size, max_seq_len, input_channel, input_dim]) * 100 + ) + + # The 10 sequences have length 1 to 10. + paddings = jnp.triu(jnp.ones((batch_size, max_seq_len)), k=1) + + # Compute layer outputs. + (ref_outputs, ref_paddings), _ = F( + layer, + inputs=dict(x=inputs, paddings=paddings), + is_training=True, + state=layer_params, + prng_key=prng_key, + ) + output_shape = layer.output_shape(input_shape=inputs.shape) + assert_allclose(ref_outputs.shape, output_shape) + + random_keys = jax.random.split(input_key, num=2 * max_seq_len) + for seq_len in range(1, max_seq_len): + # We create a new batch. The time axis of the new batch is of length seq_len. + permute_idx = jax.random.permutation(random_keys[2 * (seq_len - 1)], seq_len) + inputs_batch = jnp.take_along_axis(inputs, permute_idx[:, None, None, None], axis=0)[ + :, :seq_len + ] + paddings_batch = jnp.take_along_axis(paddings, permute_idx[:, None], axis=0)[ + :, :seq_len + ] + + # Generate random data at padding positions. + random_data = ( + jax.random.normal( + random_keys[2 * seq_len - 1], + [len(permute_idx), seq_len, input_channel, input_dim], + ) + * 1000 + ) + inputs_new_batch = jnp.where( + paddings_batch[:, :, None, None], random_data, inputs_batch + ) + + (outputs_batch, output_paddings_batch), _ = F( + layer, + inputs=dict(x=inputs_new_batch, paddings=paddings_batch), + is_training=True, + state=layer_params, + prng_key=prng_key, + ) + output_len = output_paddings_batch.shape[1] + if output_len > 0: + assert_allclose( + outputs_batch, + jnp.take_along_axis(ref_outputs, permute_idx[:, None, None, None], axis=0)[ + :, :output_len + ], + ) + assert_allclose( + output_paddings_batch, + jnp.take_along_axis(ref_paddings, permute_idx[:, None], axis=0)[:, :output_len], + ) + + @parameterized.named_parameters( + ("1_S1", 1, 1, "VALID", None), + ("2_S1_VALID", 2, 1, "VALID", None), + ("2_S2_SAME", 2, 2, "SAME", None), + ("2_S_CAUSAL", 2, 1, "CAUSAL", None), + ("2_S2_VALID", 2, 2, "VALID", None), + ("2_S2_CAUSAL", 2, 2, "CAUSAL", None), + ("3_S1_VALID", 3, 1, "VALID", None), + ("3_S1_VALID_A0", 3, 1, "VALID", 0), + ("3_S1_VALID_A1", 3, 1, "VALID", 1), + ("3_S1_VALID_A2", 3, 1, "VALID", 2), + ("3_S1_SAME", 3, 1, "SAME", None), + ("3_S1_CAUSAL", 3, 1, "CAUSAL", None), + ("3_S2_VALID", 3, 2, "VALID", None), + ("3_S2_CAUSAL", 3, 2, "CAUSAL", None), + ) + def test_conv1d_against_conv2d_with_1d_padding( + self, + window: int, + strides: int, + padding: ConvPaddingType, + anchor: Optional[int], + ): + input_dim, output_dim = 4, 6 + ref_cfg = Conv2DWith1DPadding.default_config().set( + name="ref", + input_dim=input_dim, + output_dim=output_dim, + window=(window, 1), + strides=(strides, 1), + padding=padding, + anchor=anchor, + ) + ref_layer = ref_cfg.instantiate(parent=None) + + test_cfg = Conv1DWithPadding.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + strides=strides, + padding=padding, + anchor=anchor, + ) + test_layer = test_cfg.instantiate(parent=None) + + # Initialize layer parameters. 
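+        # The Conv1D weight is the Conv2DWith1DPadding weight with the singleton width axis
+        # squeezed out, hence the "t 1 i o -> t i o" rearrange below.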
+ prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + state = ref_layer.initialize_parameters_recursively(init_key) + test_state = dict( + bias=state["bias"], weight=einops.rearrange(state["weight"], "t 1 i o -> t i o") + ) + + # Generate a batch of 10 input sequences. + batch_size, max_seq_len = 10, 10 + + prng_key, input_key = jax.random.split(prng_key) + inputs = jax.random.normal(input_key, [batch_size, max_seq_len, input_dim]) + # The 10 sequences have length 1 to 10. + paddings = jnp.triu(jnp.ones((batch_size, max_seq_len)), k=1) + + (test_outputs, test_paddings), _ = F( + test_layer, + inputs=dict(x=inputs, paddings=paddings), + is_training=True, + state=test_state, + prng_key=prng_key, + ) + output_shape = test_layer.output_shape(input_shape=inputs.shape) + assert_allclose(test_outputs.shape, output_shape) + + inputs = einops.rearrange(inputs, "b t i -> b t 1 i") + (ref_outputs, ref_paddings), _ = F( + ref_layer, + inputs=dict(x=inputs, paddings=paddings), + is_training=True, + state=state, + prng_key=prng_key, + ) + output_shape = ref_layer.output_shape(input_shape=inputs.shape) + assert_allclose(ref_outputs.shape, output_shape) + ref_outputs = einops.rearrange(ref_outputs, "b t 1 o -> b t o") + + assert_allclose(ref_paddings, test_paddings) + assert_allclose(ref_outputs, test_outputs) + + @parameterized.named_parameters( + { + "testcase_name": "1x1x1", + "window": (1, 1, 1), + "strides": (1, 1, 1), + "padding": "VALID", + "num_input_dim_groups": 1, + }, + { + "testcase_name": "2x2x2_VALID", + "window": (2, 2, 2), + "strides": (1, 1, 1), + "padding": "VALID", + "num_input_dim_groups": 1, + }, + { + "testcase_name": "2x2x2_SAME", + "window": (2, 2, 2), + "strides": (1, 1, 1), + "padding": "SAME", + "num_input_dim_groups": 1, + }, + { + "testcase_name": "2x2x2_S2_VALID", + "window": (2, 2, 2), + "strides": (2, 2, 2), + "padding": "VALID", + "num_input_dim_groups": 1, + }, + { + "testcase_name": "3x3x3_VALID", + "window": (3, 3, 3), + "strides": (1, 1, 1), + "padding": "VALID", + "num_input_dim_groups": 1, + }, + { + "testcase_name": "3x3x3_SAME", + "window": (3, 3, 3), + "strides": (1, 1, 1), + "padding": "SAME", + "num_input_dim_groups": 1, + }, + { + "testcase_name": "3x3x3_S2_VALID", + "window": (3, 3, 3), + "strides": (2, 2, 2), + "padding": "VALID", + "num_input_dim_groups": 1, + }, + { + "testcase_name": "3x3x3_S2_PADDING1", + "window": (3, 3, 3), + "strides": (2, 2, 2), + "padding": (1, 1, 1), + "num_input_dim_groups": 1, + }, + { + "testcase_name": "3x3x3_GROUPS4", + "window": (3, 3, 3), + "strides": (1, 1, 1), + "padding": "SAME", + "num_input_dim_groups": 4, + }, + ) + def test_conv3d( + self, + window: tuple[int, int], + strides: tuple[int, int], + padding: Union[str, tuple[int, int]], + num_input_dim_groups: int, + ): + input_dim, output_dim = 4, 8 + if isinstance(padding, tuple): + conv_padding = ( + (padding[0], padding[0]), + (padding[1], padding[1]), + (padding[2], padding[2]), + ) + else: + conv_padding = padding + cfg = Conv3D.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + strides=strides, + padding=conv_padding, + num_input_dim_groups=num_input_dim_groups, + ) + layer: Conv3D = cfg.instantiate(parent=None) + + # Initialize layer parameters. 
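+        # Conv3D weights are laid out as (H, W, D, I, O), where I is
+        # input_dim // num_input_dim_groups.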
+ prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + layer_params = layer.initialize_parameters_recursively(init_key) + expected = dict( + weight=(window[0], window[1], window[2], input_dim // num_input_dim_groups, output_dim), + bias=(output_dim,), + ) + self.assertEqual( + expected, + shapes(layer_params), + ) + bias = layer_params["bias"] + assert_allclose(bias, jnp.zeros_like(bias)) + # Randomize bias. + layer_params["bias"] = jax.random.normal( + jax.random.PRNGKey(45), shape=bias.shape, dtype=bias.dtype + ) + + # Random inputs. + prng_key, input_key = jax.random.split(prng_key) + + batch_size = 2 + inputs = jax.random.normal(input_key, [batch_size, 10, 7, 4, input_dim]) + + # Compute layer outputs. + outputs, _ = F( + layer, + inputs=(inputs,), + is_training=True, + state=layer_params, + prng_key=prng_key, + ) + + # Compute ref outputs. + ref_padding = padding.lower() if isinstance(padding, str) else padding + ref = torch.nn.Conv3d( + in_channels=input_dim, + out_channels=output_dim, + kernel_size=window, + stride=strides, + padding=ref_padding, + groups=num_input_dim_groups, + ) + + # weight.shape: (H, W, D, I, O) + # ref.weight.shape: (O, I, H, W, D) + _copy(layer_params["weight"].transpose(4, 3, 0, 1, 2), ref.weight) + _copy(layer_params["bias"], ref.bias) + + ref_outputs = ref(as_torch_tensor(inputs.transpose(0, 4, 1, 2, 3))) + assert_allclose(outputs, ref_outputs.detach().numpy().transpose(0, 2, 3, 4, 1)) + + # Tests output_shape. + output_shape = layer.output_shape(input_shape=inputs.shape) + assert_allclose(outputs.shape, output_shape) + + +class ConvTransposeTest(TestCase): + CONVT_EXPLICIT_PADDING_PARAMS = [ + (3, 1, "SAME", 1, (1, 1)), + (3, 2, "SAME", 1, (2, 1)), + (3, 3, "SAME", 1, (2, 2)), + (3, 4, "SAME", 1, (2, 3)), + (3, 1, "SAME", 2, (2, 2)), + (3, 2, "SAME", 2, (3, 2)), + (3, 3, "SAME", 2, (3, 3)), + (3, 1, "VALID", 1, (2, 2)), + (3, 2, "VALID", 1, (2, 2)), + (3, 3, "VALID", 1, (2, 2)), + (3, 4, "VALID", 1, (2, 3)), + (3, 1, "VALID", 2, (4, 4)), + (3, 2, "VALID", 2, (4, 4)), + (3, 3, "VALID", 2, (4, 4)), + (3, 1, "CAUSAL", 1, (2, 0)), + (3, 2, "CAUSAL", 1, (2, 1)), + (3, 3, "CAUSAL", 1, (2, 2)), + (3, 4, "CAUSAL", 1, (2, 3)), + (3, 1, "CAUSAL", 2, (4, 0)), + (3, 2, "CAUSAL", 2, (4, 1)), + (3, 3, "CAUSAL", 2, (4, 2)), + ] + + @parameterized.parameters(*CONVT_EXPLICIT_PADDING_PARAMS) + def test_conv_transpose_explicit_padding(self, window, strides, padding, dilation, expected): + """Tests the cases in conv_transpose_explicit_padding() description.""" + explicit_padding = convolution.conv_transpose_explicit_padding( + window=(window,), + strides=(strides,), + padding=padding, + dilation=(dilation,), + ) + assert_allclose(explicit_padding[0], expected) + + @parameterized.parameters(*CONVT_EXPLICIT_PADDING_PARAMS) + def test_conv_transpose_explicit_padding_against_jax( + self, window, strides, padding, dilation, expected + ): + """Compare with jax.lax.convolution._conv_transpose_padding().""" + if padding == "CAUSAL": + self.skipTest("Causal padding is not supported in JAX.") + + # Copied from jax.lax.convolution._conv_transpose_padding. 
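+        # It returns the (pad_before, pad_after) pair that jax.lax.conv_transpose applies for
+        # "SAME"/"VALID", which is compared against conv_transpose_explicit_padding() below.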
+ def _conv_transpose_padding(k, s, padding): + if padding == "SAME": + pad_len = k + s - 2 + if s > k - 1: + pad_a = k - 1 + else: + pad_a = int(np.ceil(pad_len / 2)) + elif padding == "VALID": + pad_len = k + s - 2 + max(k - s, 0) + pad_a = k - 1 + else: + raise ValueError("Padding mode must be `SAME` or `VALID`.") + pad_b = pad_len - pad_a + return pad_a, pad_b + + dilate_window = convolution.conv_dilate_window(window=(window,), dilation=(dilation,))[0] + ref_padding = _conv_transpose_padding(dilate_window, strides, padding) + + explicit_padding = convolution.conv_transpose_explicit_padding( + window=(window,), + strides=(strides,), + padding=padding, + dilation=(dilation,), + ) + + assert_allclose(explicit_padding[0], ref_padding) + assert_allclose(expected, ref_padding) + + @parameterized.parameters( + (3, 1, "SAME", 1, 4), + (3, 2, "SAME", 1, 8), + (3, 3, "SAME", 1, 12), + (3, 4, "SAME", 1, 16), + (3, 1, "SAME", 2, 4), + (3, 2, "SAME", 2, 8), + (3, 3, "SAME", 2, 12), + (3, 1, "VALID", 1, 6), + (3, 2, "VALID", 1, 9), + (3, 3, "VALID", 1, 12), + (3, 4, "VALID", 1, 16), + (3, 1, "VALID", 2, 8), + (3, 2, "VALID", 2, 11), + (3, 3, "VALID", 2, 14), + (3, 1, "CAUSAL", 1, 4), + (3, 2, "CAUSAL", 1, 8), + (3, 3, "CAUSAL", 1, 12), + (3, 4, "CAUSAL", 1, 16), + (3, 1, "CAUSAL", 2, 4), + (3, 2, "CAUSAL", 2, 8), + (3, 3, "CAUSAL", 2, 12), + ) + def test_conv_transpose_output_shape(self, window, strides, padding, dilation, expected): + """Tests the cases in conv_transpose_explicit_padding() description.""" + out_shape = convolution.conv_transpose_output_shape( + in_shape=(4,), + window=(window,), + strides=(strides,), + padding=padding, + dilation=(dilation,), + ) + assert_allclose(out_shape[0], expected) + + @parameterized.parameters( + (3, 1, "SAME", 1, [0, 0, 1, 1]), + (3, 2, "SAME", 1, [0, 0, 0, 0, 1, 1, 1, 1]), + (3, 3, "SAME", 1, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), + (3, 4, "SAME", 1, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]), + (3, 1, "SAME", 2, [0, 0, 1, 1]), + (3, 2, "SAME", 2, [0, 0, 0, 0, 1, 1, 1, 1]), + (3, 3, "SAME", 2, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), + (3, 1, "VALID", 1, [0, 0, 1, 1, 1, 1]), + (3, 2, "VALID", 1, [0, 0, 0, 0, 1, 1, 1, 1, 1]), + (3, 3, "VALID", 1, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), + (3, 4, "VALID", 1, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]), + (3, 1, "VALID", 2, [0, 0, 1, 1, 1, 1, 1, 1]), + (3, 2, "VALID", 2, [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]), + (3, 3, "VALID", 2, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]), + (3, 1, "CAUSAL", 1, [0, 0, 1, 1]), + (3, 2, "CAUSAL", 1, [0, 0, 0, 0, 1, 1, 1, 1]), + (3, 3, "CAUSAL", 1, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), + (3, 4, "CAUSAL", 1, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]), + (3, 1, "CAUSAL", 2, [0, 0, 1, 1]), + (3, 2, "CAUSAL", 2, [0, 0, 0, 0, 1, 1, 1, 1]), + (3, 3, "CAUSAL", 2, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), + ) + def test_compute_conv_transpose_paddings(self, window, strides, padding, dilation, expected): + """Tests the cases in conv_transpose_explicit_padding() description.""" + in_paddings = jnp.array([0, 0, 1, 1], dtype=jnp.float32)[None, :] + out_paddings = convolution.compute_conv_transpose_paddings( + in_paddings, window=window, stride=strides, conv_padding=padding, dilation=dilation + ) + expected = jnp.array(expected).astype(out_paddings.dtype) + self.assertNestedEqual(out_paddings[0], expected) + + @parameterized.product( + window=[1, 3], + strides=[1, 2, 3], + padding=["SAME", "VALID", "CAUSAL"], + dilation=[1, 2], + value=[0, 1], + ) + def 
test_compute_conv_transpose_paddings_all0or1( + self, window, strides, padding, dilation, value + ): + """If in_paddings is all valid or invalid, out_paddings must be all valid or invalid.""" + in_paddings = jnp.full([1, 4], fill_value=value) + out_paddings = convolution.compute_conv_transpose_paddings( + in_paddings, window=window, stride=strides, conv_padding=padding, dilation=dilation + ) + expected = jnp.ones_like(out_paddings) * value + self.assertNestedEqual(out_paddings, expected) + + CONVT_PADDINGS_PARAMS = dict( + in_paddings=[ + [0, 0, 0, 0, 0], + [0, 0, 0, 1, 1], + [1, 0, 0, 0, 1], + [1, 1, 1, 1, 1], + [0, 0, 0, 1, 1, 1, 0, 0, 1, 1], + [1, 1, 0, 1, 1, 1, 1, 0, 0, 0], + ], + window=[1, 3], + padding=["SAME", "VALID", "CAUSAL"], + dilation=[1, 2], + ) + + @parameterized.product(**CONVT_PADDINGS_PARAMS, strides=[1, 2, 3]) + def test_compute_conv_transpose_paddings_with_conv_paddings( + self, in_paddings, window, strides, padding, dilation + ): + """Check if ConvT -> Conv preserves information.""" + in_paddings = jnp.array(in_paddings, dtype=jnp.float32)[None, :] + out_paddings = convolution.compute_conv_transpose_paddings( + in_paddings, window=window, stride=strides, conv_padding=padding, dilation=dilation + ) + + recon_paddings = convolution.compute_conv_paddings( + out_paddings, window=window, stride=strides, conv_padding=padding, dilation=dilation + ) + self.assertNestedEqual(recon_paddings[0], in_paddings[0]) + + @parameterized.product(**CONVT_PADDINGS_PARAMS) + def test_compute_conv_transpose_paddings_against_conv_paddings( + self, in_paddings, window, padding, dilation + ): + # compute_conv_transpose_paddings and compute_conv_paddings are same when window_stride=1 + # (stride of Conv2D) and lhs_dilation=1 (stride of Conv2DTranspose). + strides = 1 + if padding == "VALID": + # TODO(dhwang2,ruoming): Currently, anchor is pad_left but it should be the midpoint + # between [pad_left, pad_right). Otherwise, the consistency of VALID padding is broken. + # For reference, the midpoint in SAME and CAUSAL is left_pad. 
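+            # Place the anchor at the midpoint of the dilated window's un-padded span so that
+            # the conv and conv-transpose padding computations agree for "VALID".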
+ dilate_window = convolution.conv_dilate_window(window=(window,), dilation=(dilation,))[ + 0 + ] + conv_padding = convolution.conv_explicit_padding( + window=(window,), strides=(strides,), padding=padding, dilation=(dilation,) + )[0] + pad_left, pad_right = conv_padding + anchor_range = dilate_window - pad_left - pad_right + mid_point = anchor_range // 2 + anchor = pad_left + mid_point + else: + anchor = None + + in_paddings = jnp.array(in_paddings, dtype=jnp.float32)[None, :] + ref_paddings = convolution.compute_conv_paddings( + in_paddings, + window=window, + stride=strides, + conv_padding=padding, + dilation=dilation, + anchor=anchor, + ) + + test_paddings = convolution.compute_conv_transpose_paddings( + in_paddings, + window=window, + stride=strides, + conv_padding=padding, + dilation=dilation, + anchor=anchor, + ) + + if ref_paddings.shape != test_paddings.shape: + self.assertEqual(padding, "VALID") + dilate_window = convolution.conv_dilate_window(window=(window,), dilation=(dilation,))[ + 0 + ] + pad_left = dilate_window - 1 + test_paddings = test_paddings[:, pad_left:-pad_left] + + assert_allclose(ref_paddings, test_paddings) + + CONVT_PARAMS = [ + (3, 1, "SAME", 1, [0, 1, 2, 2]), + (3, 2, "SAME", 1, [0, 0, 0, 0, 1, 1, 2, 1]), + (3, 3, "SAME", 1, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), + (3, 4, "SAME", 1, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0]), + (3, 1, "SAME", 2, [1, 1, 1, 1]), + (3, 2, "SAME", 2, [0, 0, 0, 1, 0, 2, 0, 2]), + (3, 3, "SAME", 2, [0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0]), + (3, 1, "VALID", 1, [0, 0, 1, 2, 2, 1]), + (3, 2, "VALID", 1, [0, 0, 0, 0, 1, 1, 2, 1, 1]), + (3, 3, "VALID", 1, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), + (3, 4, "VALID", 1, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0]), + (3, 1, "VALID", 2, [0, 0, 1, 1, 1, 1, 1, 1]), + (3, 2, "VALID", 2, [0, 0, 0, 0, 1, 0, 2, 0, 2, 0, 1]), + (3, 3, "VALID", 2, [0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1]), + (3, 1, "CAUSAL", 1, [0, 0, 1, 2]), + (3, 2, "CAUSAL", 1, [0, 0, 0, 0, 1, 1, 2, 1]), + (3, 3, "CAUSAL", 1, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), + (3, 4, "CAUSAL", 1, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0]), + (3, 1, "CAUSAL", 2, [0, 0, 1, 1]), + (3, 2, "CAUSAL", 2, [0, 0, 0, 0, 1, 0, 2, 0]), + (3, 3, "CAUSAL", 2, [0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1]), + ] + + @parameterized.parameters(*CONVT_PARAMS) + def test_conv1d_transpose_simple(self, window, strides, padding, dilation, expected): + """Tests the cases in conv_transpose_explicit_padding() description.""" + input_dim, output_dim = 1, 1 + inputs = jnp.array([0, 0, 1, 1], dtype=jnp.float32)[None, :, None] + cfg = convolution.Conv1DTranspose.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + strides=strides, + padding=padding, + dilation=dilation, + bias=False, + ) + layer = cfg.instantiate(parent=None) + prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + layer_params = layer.initialize_parameters_recursively(init_key) + self.assertEqual(dict(weight=(window, input_dim, output_dim)), shapes(layer_params)) + layer_params["weight"] = jnp.ones_like(layer_params["weight"]) + + (outputs, paddings), _ = F( + layer, inputs=dict(x=inputs), is_training=True, state=layer_params, prng_key=prng_key + ) + out_shape = layer.output_shape(input_shape=inputs.shape) + assert_allclose(outputs.shape, out_shape) + self.assertIsNone(paddings) + expected = jnp.array(expected).astype(outputs.dtype) + self.assertNestedEqual(outputs[0, :, 0], expected) + + 
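+    # Note: the *_transpose_simple tests use an all-ones kernel on the input [0, 0, 1, 1], so
+    # each expected output value in CONVT_PARAMS counts how many input 1s overlap that output
+    # position once the input has been dilated by `strides`.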
@parameterized.parameters(*CONVT_PARAMS) + def test_conv2d_transpose_simple(self, window, strides, padding, dilation, expected): + """Tests the cases in conv_transpose_explicit_padding() description.""" + window = (window, 1) + strides = (strides, 1) + dilation = (dilation, 1) + input_dim, output_dim = 1, 1 + inputs = jnp.array([0, 0, 1, 1], dtype=jnp.float32)[None, :, None, None] + cfg = convolution.Conv2DTranspose.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + strides=strides, + padding=padding, + dilation=dilation, + bias=False, + ) + layer = cfg.instantiate(parent=None) + prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + layer_params = layer.initialize_parameters_recursively(init_key) + self.assertEqual(dict(weight=(*window, input_dim, output_dim)), shapes(layer_params)) + layer_params["weight"] = jnp.ones_like(layer_params["weight"]) + + outputs, _ = F( + layer, inputs=dict(x=inputs), is_training=True, state=layer_params, prng_key=prng_key + ) + out_shape = layer.output_shape(input_shape=inputs.shape) + assert_allclose(outputs.shape, out_shape) + expected = jnp.array(expected).astype(outputs.dtype) + self.assertNestedEqual(outputs[0, :, 0, 0], expected) + + @parameterized.named_parameters( + { + "testcase_name": "2x2", + "window": (2, 2), + "strides": (1, 1), + "padding": "VALID", + }, + { + "testcase_name": "2x2_S2", + "window": (2, 2), + "strides": (2, 2), + "padding": "VALID", + }, + { + "testcase_name": "3x3_S2", + "window": (3, 3), + "strides": (2, 2), + "padding": "VALID", + }, + ) + def test_conv2d_transpose_against_pytorch( + self, + window: tuple[int, int], + strides: tuple[int, int], + padding: Union[str, tuple[int, int]], + ): + input_dim, output_dim = 4, 8 + if isinstance(padding, tuple): + deconv_padding = ((padding[0], padding[0]), (padding[1], padding[1])) + else: + deconv_padding = padding + cfg = Conv2DTranspose.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + strides=strides, + padding=deconv_padding, + transpose_kernel=True, + ) + layer: Conv2DTranspose = cfg.instantiate(parent=None) + + # Initialize layer parameters. + prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + layer_params = layer.initialize_parameters_recursively(init_key) + self.assertEqual( + dict( + weight=(window[0], window[1], output_dim, input_dim), + bias=(output_dim,), + ), + shapes(layer_params), + ) + bias = layer_params["bias"] + assert_allclose(bias, jnp.zeros_like(bias)) + # Randomize bias. + layer_params["bias"] = jax.random.normal( + jax.random.PRNGKey(45), shape=bias.shape, dtype=bias.dtype + ) + + # Random inputs. + prng_key, input_key = jax.random.split(prng_key) + inputs = jax.random.normal(input_key, [2, 10, 7, input_dim]) + # Compute layer outputs. + outputs, _ = F( + layer, + inputs=(inputs,), + is_training=True, + state=layer_params, + prng_key=prng_key, + ) + + # Compute ref outputs. + if isinstance(padding, tuple): + ref_padding = padding[0] + elif isinstance(padding, str): + ref_padding = padding.lower() + if ref_padding == "valid": + ref_padding = 0 + else: + ref_padding = 0 + + ref = torch.nn.ConvTranspose2d( + in_channels=input_dim, + out_channels=output_dim, + kernel_size=window, + stride=strides, + padding=ref_padding, + ) + # torch.nn.Linear.weight is of shape (output_dim, input_dim, kernel_size...). 
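+        # More precisely, torch.nn.ConvTranspose2d.weight is (in_channels, out_channels, kH, kW),
+        # hence the (3, 2, 0, 1) transpose from our (H, W, O, I) layout with transpose_kernel=True.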
+ _copy(layer_params["weight"].transpose(3, 2, 0, 1), ref.weight) + _copy(layer_params["bias"], ref.bias) + ref_outputs = ref(as_torch_tensor(inputs.transpose(0, 3, 1, 2))) + assert_allclose(outputs, ref_outputs.detach().numpy().transpose(0, 2, 3, 1)) + # Tests output_shape. + output_shape = layer.output_shape(input_shape=inputs.shape) + assert_allclose(outputs.shape, output_shape) + + @parameterized.parameters(*CONVT_PARAMS) + def test_conv3d_transpose_simple(self, window, strides, padding, dilation, expected): + """Tests the cases in conv_transpose_explicit_padding() description.""" + window = (window, 1, 1) + strides = (strides, 1, 1) + dilation = (dilation, 1, 1) + input_dim, output_dim = 1, 1 + inputs = jnp.array([0, 0, 1, 1], dtype=jnp.float32)[None, :, None, None, None] + cfg = convolution.Conv3DTranspose.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + strides=strides, + padding=padding, + dilation=dilation, + bias=False, + ) + layer = cfg.instantiate(parent=None) + prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + layer_params = layer.initialize_parameters_recursively(init_key) + self.assertEqual(dict(weight=(*window, input_dim, output_dim)), shapes(layer_params)) + layer_params["weight"] = jnp.ones_like(layer_params["weight"]) + + outputs, _ = F( + layer, inputs=dict(x=inputs), is_training=True, state=layer_params, prng_key=prng_key + ) + out_shape = layer.output_shape(input_shape=inputs.shape) + assert_allclose(outputs.shape, out_shape) + expected = jnp.array(expected).astype(outputs.dtype) + self.assertNestedEqual(outputs[0, :, 0, 0, 0], expected) + + @parameterized.product(window=(1, 3, 5), padding=("SAME", "VALID", "CAUSAL"), dilation=(1, 2)) + def test_conv1d_transpose_against_conv1d(self, window, padding, dilation): + # Conv1D and Conv1DTranspose are same when window_stride=1 + # (stride of Conv1D) and lhs_dilation=1 (stride of Conv1DTranspose). + input_dim, output_dim = 4, 6 + ref_cfg = Conv1D.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + padding=padding, + dilation=dilation, + ) + ref_layer = ref_cfg.instantiate(parent=None) + + test_cfg = convolution.Conv1DTranspose.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + padding=padding, + dilation=dilation, + ) + test_layer = test_cfg.instantiate(parent=None) + + # Initialize layer parameters. + prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + ref_states = ref_layer.initialize_parameters_recursively(init_key) + + # Random inputs. + prng_key, input_key = jax.random.split(prng_key) + inputs = jax.random.normal(input_key, [2, 17, input_dim]) + # Compute layer outputs. 
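+        # Both F() calls below reuse ref_states: Conv1D and Conv1DTranspose share the
+        # (window, input_dim, output_dim) weight layout, and with stride 1 they compute the
+        # same convolution.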
+ ref_outputs, _ = F( + ref_layer, inputs=dict(x=inputs), is_training=True, state=ref_states, prng_key=prng_key + ) + + (test_outputs, _), _ = F( + test_layer, inputs=dict(x=inputs), is_training=True, state=ref_states, prng_key=prng_key + ) + if ref_outputs.shape != test_outputs.shape: + self.assertEqual(padding, "VALID") + dilate_window = convolution.conv_dilate_window(window=(window,), dilation=(dilation,))[ + 0 + ] + pad_left = dilate_window - 1 + test_outputs = test_outputs[:, pad_left:-pad_left] + assert_allclose(ref_outputs, test_outputs) + + @parameterized.product(window=(1, 3, 5), padding=("SAME", "VALID", "CAUSAL"), dilation=(1, 2)) + def test_conv2d_transpose_against_conv2d(self, window, padding, dilation): + # Conv2D and Conv2DTranspose are same when window_stride=1 + # (stride of Conv2D) and lhs_dilation=1 (stride of Conv2DTranspose). + window = (window, window) + dilation = (dilation, dilation) + input_dim, output_dim = 4, 6 + ref_cfg = Conv2D.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + padding=padding, + dilation=dilation, + ) + ref_layer = ref_cfg.instantiate(parent=None) + + test_cfg = convolution.Conv2DTranspose.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + padding=padding, + dilation=dilation, + transpose_kernel=False, + ) + test_layer = test_cfg.instantiate(parent=None) + + # Initialize layer parameters. + prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + ref_states = ref_layer.initialize_parameters_recursively(init_key) + + # Random inputs. + prng_key, input_key = jax.random.split(prng_key) + width, height = 12, 13 + inputs = jax.random.normal(input_key, [2, width, height, input_dim]) + # Compute layer outputs. + ref_outputs, _ = F( + ref_layer, + inputs=dict(x=inputs), + is_training=True, + state=ref_states, + prng_key=prng_key, + ) + + test_outputs, _ = F( + test_layer, + inputs=dict(x=inputs), + is_training=True, + state=ref_states, + prng_key=prng_key, + ) + if ref_outputs.shape != test_outputs.shape: + self.assertEqual(padding, "VALID") + dilate_window = convolution.conv_dilate_window(window=window, dilation=dilation) + pad_left = tuple(w - 1 for w in dilate_window) + test_outputs = test_outputs[:, pad_left[0] : -pad_left[0], pad_left[1] : -pad_left[1]] + + assert_allclose(ref_outputs, test_outputs) + + @parameterized.product(window=(1, 3, 5), padding=("SAME", "VALID", "CAUSAL"), dilation=(1, 2)) + def test_conv2d_transpose_against_conv2d_with_paddings(self, window, padding, dilation): + # Conv2DWith1DPadding and Conv2DTransposeWith1DPadding are same when window_stride=1 + # (stride of Conv2D) and lhs_dilation=1 (stride of Conv2DTranspose). + window = (window, window) + dilation = (dilation, dilation) + input_dim, output_dim = 4, 6 + if padding == "VALID": + # TODO(dhwang2,ruoming): Currently, anchor is pad_left but it should be the midpoint + # between [pad_left, pad_right). Otherwise, the consistency of VALID padding is broken. + # For reference, the midpoint in SAME and CAUSAL is left_pad. 
+ strides = (1, 1) + dilate_window = convolution.conv_dilate_window(window=window, dilation=dilation)[0] + conv_padding = convolution.conv_explicit_padding( + window=window, strides=strides, padding=padding, dilation=dilation + ) + pad_left, pad_right = conv_padding[0] + anchor_range = dilate_window - pad_left - pad_right + mid_point = anchor_range // 2 + anchor = pad_left + mid_point + else: + anchor = None + + ref_cfg = Conv2DWith1DPadding.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + padding=padding, + dilation=dilation, + anchor=anchor, + ) + ref_layer = ref_cfg.instantiate(parent=None) + + test_cfg = convolution.Conv2DTransposeWith1DPadding.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + padding=padding, + dilation=dilation, + anchor=anchor, + ) + test_layer = test_cfg.instantiate(parent=None) + + # Initialize layer parameters. + prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + ref_states = ref_layer.initialize_parameters_recursively(init_key) + + # Random inputs. + prng_key, input_key = jax.random.split(prng_key) + width, height = 12, 13 + inputs = jax.random.normal(input_key, [2, width, height, input_dim]) + paddings = jnp.zeros([2, width], dtype=inputs.dtype).at[:, -2:].set(1) + # Compute layer outputs. + (ref_outputs, ref_paddings), _ = F( + ref_layer, + inputs=dict(x=inputs, paddings=paddings), + is_training=True, + state=ref_states, + prng_key=prng_key, + ) + + (test_outputs, test_paddings), _ = F( + test_layer, + inputs=dict(x=inputs, paddings=paddings), + is_training=True, + state=ref_states, + prng_key=prng_key, + ) + if ref_outputs.shape != test_outputs.shape: + self.assertEqual(padding, "VALID") + dilate_window = convolution.conv_dilate_window(window=window, dilation=dilation) + pad_left = tuple(w - 1 for w in dilate_window) + test_outputs = test_outputs[:, pad_left[0] : -pad_left[0], pad_left[1] : -pad_left[1]] + test_paddings = test_paddings[:, pad_left[0] : -pad_left[0]] + + assert_allclose(ref_outputs, test_outputs) + assert_allclose(ref_paddings, test_paddings) + + @parameterized.product(window=(1, 3, 5), padding=("SAME", "VALID", "CAUSAL"), dilation=(1, 2)) + def test_conv3d_transpose_against_conv3d(self, window, padding, dilation): + # Conv3D and Conv3DTranspose are same when window_stride=1 + # (stride of Conv3D) and lhs_dilation=1 (stride of Conv3DTranspose). + window = (window, window, window) + dilation = (dilation, dilation, dilation) + input_dim, output_dim = 4, 6 + ref_cfg = Conv3D.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + padding=padding, + dilation=dilation, + ) + ref_layer = ref_cfg.instantiate(parent=None) + + test_cfg = convolution.Conv3DTranspose.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + padding=padding, + dilation=dilation, + ) + test_layer = test_cfg.instantiate(parent=None) + + # Initialize layer parameters. + prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + ref_states = ref_layer.initialize_parameters_recursively(init_key) + + # Random inputs. + prng_key, input_key = jax.random.split(prng_key) + width, height, depth = 9, 8, 7 + inputs = jax.random.normal(input_key, [2, width, height, depth, input_dim]) + # Compute layer outputs. 
+ ref_outputs, _ = F( + ref_layer, + inputs=dict(x=inputs), + is_training=True, + state=ref_states, + prng_key=prng_key, + ) + + test_outputs, _ = F( + test_layer, + inputs=dict(x=inputs), + is_training=True, + state=ref_states, + prng_key=prng_key, + ) + if ref_outputs.shape != test_outputs.shape: + self.assertEqual(padding, "VALID") + dilate_window = convolution.conv_dilate_window(window=window, dilation=dilation) + pad_left = tuple(w - 1 for w in dilate_window) + test_outputs = test_outputs[ + :, + pad_left[0] : -pad_left[0], + pad_left[1] : -pad_left[1], + pad_left[2] : -pad_left[2], + ] + + assert_allclose(ref_outputs, test_outputs) + + @parameterized.product( + window=(1, 3, 5), + strides=(1, 2), + padding=("SAME", "VALID", "CAUSAL"), + dilation=(1, 2), + anchor=(None, 1), + ) + def test_conv1d_transpose_against_conv2d_transpose_with_1d_padding( + self, + window, + strides, + padding: ConvPaddingType, + dilation, + anchor, + ): + if anchor is not None: + dilate_window = convolution.conv_dilate_window(window=(window,), dilation=(dilation,))[ + 0 + ] + anchor = dilate_window - 1 + + input_dim, output_dim = 4, 6 + ref_cfg = convolution.Conv2DTransposeWith1DPadding.default_config().set( + name="ref", + input_dim=input_dim, + output_dim=output_dim, + window=(window, 1), + strides=(strides, 1), + padding=padding, + dilation=(dilation, 1), + anchor=anchor, + ) + ref_layer = ref_cfg.instantiate(parent=None) + + test_cfg = convolution.Conv1DTranspose.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + strides=strides, + padding=padding, + dilation=dilation, + anchor=anchor, + ) + test_layer = test_cfg.instantiate(parent=None) + + # Initialize layer parameters. + prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + state = ref_layer.initialize_parameters_recursively(init_key) + test_state = dict( + bias=state["bias"], weight=einops.rearrange(state["weight"], "t 1 i o -> t i o") + ) + + # Generate a batch of 10 input sequences. + batch_size, max_seq_len = 10, 10 + + prng_key, input_key = jax.random.split(prng_key) + inputs = jax.random.normal(input_key, [batch_size, max_seq_len, input_dim]) + # The 10 sequences have length 1 to 10. + paddings = jnp.triu(jnp.ones((batch_size, max_seq_len)), k=1) + + (test_outputs, test_paddings), _ = F( + test_layer, + inputs=dict(x=inputs, paddings=paddings), + is_training=True, + state=test_state, + prng_key=prng_key, + ) + + inputs = einops.rearrange(inputs, "b t i -> b t 1 i") + (ref_outputs, ref_paddings), _ = F( + ref_layer, + inputs=dict(x=inputs, paddings=paddings), + is_training=True, + state=state, + prng_key=prng_key, + ) + ref_outputs = einops.rearrange(ref_outputs, "b t 1 o -> b t o") + + assert_allclose(ref_paddings, test_paddings) + assert_allclose(ref_outputs, test_outputs) + + +class StackOverTimeTest(TestCase): + @parameterized.parameters( + ( + 2, + (0, 0), + [[[1, 1, 2, 2], [3, 3, 4, 4]], [[7, 7, 8, 8], [0, 0, 0, 0]]], + [[0, 0], [0, 1]], + ), + ( + 3, + (0, 0), + [[[1, 1, 2, 2, 3, 3]], [[7, 7, 8, 8, 0, 0]]], + [[0], [0]], + ), + ( + 3, + (2, 0), + [[[0, 0, 0, 0, 1, 1], [2, 2, 3, 3, 4, 4]], [[0, 0, 0, 0, 7, 7], [0, 0, 0, 0, 0, 0]]], + [[0, 0], [0, 1]], + ), + ) + def test_stack_over_time(self, stride, pad, expected_outputs, expected_output_paddings): + # Input shape [2, 5, 2]. 
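+        # E.g. with stride=2 and no extra padding, frames (1, 2) and (3, 4) are stacked, the
+        # trailing frame 5 is dropped, and an output frame is marked as padding if any of its
+        # stacked input frames is padding.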
+ inputs = jnp.array( + [[[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]], [[7, 7], [8, 8], [0, 0], [0, 0], [0, 0]]], + dtype=jnp.float32, + ) + paddings = jnp.array([[0, 0, 0, 0, 0], [0, 0, 1, 1, 1]]) + layer: StackOverTime = ( + StackOverTime.default_config() + .set( + name="test", + stride=stride, + padding=pad, + ) + .instantiate(parent=None) + ) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) + (outputs, output_paddings), _ = F( + layer, + inputs=dict(inputs=inputs, paddings=paddings), + is_training=False, + state=layer_params, + prng_key=jax.random.PRNGKey(5), + ) + output_shape = layer.output_shape(input_shape=inputs.shape) + assert_allclose(outputs.shape, output_shape) + assert_allclose(jnp.array(expected_outputs, dtype=jnp.float32), outputs) + assert_allclose(jnp.array(expected_output_paddings, dtype=jnp.int32), output_paddings) + + def test_stack_over_time_data_change(self): + """Tests that the stacked outputs is masked with the output paddings.""" + np.random.seed(500) + inputs = np.random.normal(size=[2, 21, 16]) + paddings = np.ones([2, 21], dtype=np.float32) + paddings[0, :9] = 0 + paddings[1, :14] = 0 + inputs = inputs * (1 - paddings)[:, :, None] + + layer: StackOverTime = ( + StackOverTime.default_config() + .set( + name="test", + stride=2, + ) + .instantiate(parent=None) + ) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) + (outputs, output_paddings), _ = F( + layer, + inputs=dict(inputs=inputs, paddings=paddings), + is_training=False, + state=layer_params, + prng_key=jax.random.PRNGKey(5), + ) + output_shape = layer.output_shape(input_shape=inputs.shape) + assert_allclose(outputs.shape, output_shape) + assert_allclose(np.array([5, 7], dtype=np.float32), np.sum(1 - output_paddings, axis=1)) + assert_allclose(np.sum(inputs**2, (1, 2)), np.sum(outputs**2, (1, 2))) + + @parameterized.product(stride=(2, 3, 4), pad=("VALID", "SAME", "CAUSAL")) + def test_stack_consistent_outputs(self, stride, pad): + """Tests that StackOverTime has consistent outputs under different padding lengths.""" + batch_size, input_dim = 2, 1 + input_length = 7 + layer: StackOverTime = ( + StackOverTime.default_config() + .set( + name="test", + stride=stride, + padding=pad, + ) + .instantiate(parent=None) + ) + expected_output_length = layer.output_shape(input_shape=[1, input_length, 1])[1] + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) + for ll in range(4, 11): + # Batch with another example of length ll. + length = max(input_length, ll) + inputs = jnp.ones([batch_size, length, input_dim]) + paddings = jnp.arange(length)[None, :] >= jnp.array([input_length, ll])[:, None] + (outputs, output_paddings), _ = F( + layer, + inputs=dict(inputs=inputs, paddings=paddings), + is_training=False, + state=layer_params, + prng_key=jax.random.PRNGKey(5), + ) + output_shape = layer.output_shape(input_shape=inputs.shape) + assert_allclose(outputs.shape, output_shape) + if pad != "VALID": # VALID doesn't preserve length. + self.assertEqual(expected_output_length, np.sum(1 - output_paddings, axis=1)[0]) + + @parameterized.parameters(((0, 1), (0, 0)), ((1, 1), (3, 0)), ((1, 1), (0, 3))) + def test_stack_vs_conv2d_output_len_match(self, conv_padding, stack_padding): + # Note that to get the same output length, we need to pad the sequence differently + # for convolution and stacking layer. 
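+        # Two Conv2DWith1DPadding layers with time stride 2 downsample time by 4 overall,
+        # which is compared against a single StackOverTime with stride 4.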
+ for audio_seq_len in [16000, 16160, 16320, 16480, 16640, 16800, 16960, 17120]: + sampling_rate, window_size_ms, window_step_ms = 16000, 25, 10 + window_size = window_size_ms * sampling_rate // 1000 + window_step = window_step_ms * sampling_rate // 1000 + seq_len = max(audio_seq_len - window_size, 0) // window_step + 1 + conv_layer: Conv2DWith1DPadding = ( + Conv2DWith1DPadding.default_config() + .set( + name="test_conv", + input_dim=3, + output_dim=3, + window=(3, 3), + strides=(2, 2), + padding=(conv_padding, (0, 1)), + ) + .instantiate(parent=None) + ) + stack_layer: StackOverTime = ( + StackOverTime.default_config() + .set(name="test_stack", stride=4, padding=stack_padding) + .instantiate(parent=None) + ) + # Computes downsampler output shape. + down_sample_shape1 = conv_layer.output_shape(input_shape=[None, seq_len, 80, 3]) + down_sample_shape2 = conv_layer.output_shape(input_shape=down_sample_shape1) + + # Computes stack output shape. + stack_shape = stack_layer.output_shape(input_shape=[None, seq_len, 80]) + # Tests that the sequence length dimension matches. + self.assertEqual(down_sample_shape2[1], stack_shape[1]) + + +if __name__ == "__main__": + with utils.numeric_checks(True): + absltest.main() diff --git a/axlearn/common/layers.py b/axlearn/common/layers.py index 6736fb580..cc9798afd 100644 --- a/axlearn/common/layers.py +++ b/axlearn/common/layers.py @@ -19,16 +19,15 @@ import enum from collections.abc import Sequence -from typing import Any, Callable, Literal, Optional, Union +from typing import Any, Callable, Optional, Union -import chex -import einops import jax from absl import logging from jax import nn from jax import numpy as jnp from jax.sharding import PartitionSpec +from axlearn.common import convolution from axlearn.common.base_layer import BaseLayer, FactorizationSpec, ParameterNoise, ParameterSpec from axlearn.common.config import ( REQUIRED, @@ -38,6 +37,7 @@ UnknownFieldError, config_class, ) +from axlearn.common.convolution import Conv2D from axlearn.common.loss import binary_cross_entropy, categorical_hinge_loss, cross_entropy from axlearn.common.metrics import WeightedScalar from axlearn.common.metrics_classification import precision_recall_f_score @@ -58,12 +58,15 @@ with_sharding_constraint, ) -# The padding type for jax.lax.conv_general_dilated API. Either the strings ‘SAME’, or ‘VALID’, or -# 'CAUSAL' or a sequence of n (low, high) integer pairs that give the padding to apply before and -# after each spatial dimension. The number of tuple is 1 for NHC, 2 for NHWC and 3 for NHWDC. -ConvPaddingType = Union[str, Sequence[tuple[int, int]]] - -SUPPORT_CONV_PADDING = ("SAME", "VALID", "CAUSAL") +# TODO(dhwang2): remove them. +# DEPRECATED: Avoid using this; instead, directly import convolution.py. Aliases for convolution are +# provided for backward compatibility. 
+ConvPaddingType = convolution.ConvPaddingType +Conv1D = convolution.Conv1D +Conv2D = convolution.Conv2D +Conv2DTranspose = convolution.Conv2DTranspose +Conv2DWith1DPadding = convolution.Conv2DWith1DPadding +Conv3D = convolution.Conv3D def get_activation_fn(name) -> Callable[[Tensor], Tensor]: @@ -765,1693 +768,6 @@ def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Opti return [input_shape[0], output_height, output_width, input_shape[3]] -############################## Convolution ######################################################### - - -def _check_conv_cfg( - *, - window: Sequence[int], - strides: Sequence[int], - padding: ConvPaddingType, - dilation: Optional[Sequence[int]], -): - if any(w < 1 for w in window): - raise ValueError(f"window ({window}) must be a positive integer.") - - if any(s < 1 for s in strides): - raise ValueError(f"strides ({strides}) must be a positive integer.") - - if isinstance(padding, str): - if padding not in SUPPORT_CONV_PADDING: - raise ValueError(f"{padding} padding is not supported.") - else: - padding_flattened = jax.tree.leaves(padding) - if any(p < 0 for p in padding_flattened): - raise ValueError("Negative padding is not supported") - - if dilation is not None and any(d < 1 for d in dilation): - raise ValueError(f"dilation ({dilation}) must be a positive integer.") - - -class BaseConv(BaseLayer): - """Base class for convolution layers.""" - - @config_class - class Config(BaseLayer.Config): - input_dim: Required[int] = REQUIRED # Input feature dim. - - # pylint: disable-next=no-self-use - def _compute_fan_axes(self, name: str, parameter_spec: ParameterSpec) -> Optional[FanAxes]: - if not name.endswith("weight"): - return None - if len(parameter_spec.shape) < 2: - raise NotImplementedError( - "Default _compute_fan_axes requires weight parameters to have at least 2 axes " - f"shape({name}) = {parameter_spec.shape}" - ) - # All other axes represent receptive field. - return FanAxes(in_axis=-2, out_axis=-1) - - -# Copied from jax.lax._dilate_shape -# https://github.com/jax-ml/jax/blob/2d78b172266870bd755b039f6faa2056a51930f9/jax/_src/lax/lax.py#L5763 -def conv_dilate_window(*, window: Sequence[int], dilation: Optional[Sequence[int]]): - """Returns dilated effective window size. - - Args: - window: convolution window. - dilation: convolution dilation. - - Returns: - The dilated effective window size. - """ - if dilation is None or all(d == 1 for d in dilation): - return window - - return tuple(max(1 + d * (w - 1), 0) for w, d in zip(window, dilation)) - - -# Copied from subroutine in jax.lax.reduce_window. -# Extend lax.padtype_to_pads for CAUSAL. -def conv_explicit_padding( - *, - window: Sequence[int], - strides: Sequence[int], - padding: ConvPaddingType, - dilation: Optional[Sequence[int]] = None, -) -> ConvPaddingType: - """Returns the explicit padding for "SAME", "VALID", and "CAUSAL" modes. - - Each mode follows the formulas below: - * SAME: (pad_total//2, pad_total - pad_total//2) s.t. pad_total = window-1 - * VALID: (0, 0) - * CAUSAL: (window - stride, stride - 1) - - Note: In the above equation, `window` will be replaced with `dilate_window` when dilation > 1. - dilate_window = (window - 1) * dilation + 1. 
Check conv_dilate_window() - - For example, window=5 and stride=2, - * SAME: padding = (2, 2) - pad| |pad - paddings: 0 0|0 0 0 0 1 1|1 1 - |___^___| - |___^___| - |___^___| - - * VALID: padding = (0, 0) - | | - paddings: |0 0 0 0 1 1| - |^_______| - - * CAUSAL: padding = (3, 1) - pad | |pad - paddings: 0 0 0|0 0 0 0 1 1|1 - |_____^_| - |_____^_| - |_____^_| - - - For example, window=5, stride=2 and dilation=2 - -> dilate_window = 9 (== (window-1)*dilation + 1) and pad_total = 8 - * SAME: padding = (4, 4) - pad| |pad - paddings: 0 0 0 0|0 0 0 0 0 0 0 0 1 1|1 1 1 1 - |_______^_______| - |_______^_______| - |_______^_______| - |_______^_______| - |_______^_______| - - * VALID: padding = (0, 0) - | |pad - paddings: |0 0 0 0 0 0 0 0 1 1| - |^_______________| - - * CAUSAL: padding = (7, 1) - pad | |pad - paddings: 0 0 0 0 0 0 0|0 0 0 0 0 0 0 0 1 1|1 - |_____________^_| - |_____________^_| - |_____________^_| - |_____________^_| - |_____________^_| - - For "CAUSAL", the first component is time and treated as "CAUSAL", while the remaining - components are handled with "SAME" padding. - - Args: - window: convolution window. - strides: convolution strides. - padding: convolution padding. - dilation: convolution dilation. - - Returns: - The padding tuple. - - Raises: - ValueError: If padding is not supported. - """ - if not isinstance(padding, str): - return padding - window = conv_dilate_window(window=window, dilation=dilation) - - def same_padding(window): - pad_total = tuple(w - 1 for w in window) - pad_left = tuple(pt // 2 for pt in pad_total) - pad_right = tuple(pt - pl for pt, pl in zip(pad_total, pad_left)) - return tuple(zip(pad_left, pad_right)) - - if padding == "SAME": - return same_padding(window) - elif padding == "VALID": - return ((0, 0),) * len(window) - elif padding == "CAUSAL": - causal_padding = ((window[0] - strides[0], strides[0] - 1),) - if len(window) > 1: - causal_padding += same_padding(window[1:]) - return causal_padding - else: - raise ValueError(f"{padding} padding is not supported.") - - -def conv_output_shape( - in_shape: Sequence[Optional[int]], - *, - window: Sequence[int], - strides: Sequence[int], - padding: ConvPaddingType, - dilation: Optional[Sequence[int]] = None, -) -> Sequence[int]: - """Returns output size for convolution. - - Follow https://www.tensorflow.org/api_docs/python/tf/nn/convolution - * SAME: ceil(in_size / stride) - * VALID: ceil((in_size - (window - 1) * dilation) / stride) - - Args: - in_shape: convolution lhs shape. - window: convolution window. - strides: convolution strides. - padding: convolution padding. - dilation: convolution dilation. - - Returns: - The output shape. - - Raises: - ValueError: If the length of in_shape, window, strides, and padding are not equal. 
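A quick, self-contained check of the SAME/VALID/CAUSAL formulas illustrated above (plain Python arithmetic, independent of the layer code; the values reproduce the window=5, stride=2 examples, with and without dilation):

def explicit_padding(window: int, stride: int, mode: str, dilation: int = 1) -> tuple:
    # Dilated effective window, per conv_dilate_window().
    w = (window - 1) * dilation + 1
    if mode == "SAME":
        total = w - 1
        return (total // 2, total - total // 2)
    if mode == "VALID":
        return (0, 0)
    if mode == "CAUSAL":
        return (w - stride, stride - 1)
    raise ValueError(f"{mode} padding is not supported.")

assert explicit_padding(5, 2, "SAME") == (2, 2)
assert explicit_padding(5, 2, "VALID") == (0, 0)
assert explicit_padding(5, 2, "CAUSAL") == (3, 1)
assert explicit_padding(5, 2, "SAME", dilation=2) == (4, 4)
assert explicit_padding(5, 2, "CAUSAL", dilation=2) == (7, 1)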
- """ - if len(in_shape) != len(window) or len(in_shape) != len(strides): - raise ValueError( - f"len(in_shape) = {len(in_shape)} must be equal to " - f"len(window) = {len(window)} and len(strides) = {len(strides)}" - ) - - padding = conv_explicit_padding( - window=window, strides=strides, padding=padding, dilation=dilation - ) - pad_amount = tuple(sum(p) for p in padding) - dilate_window = conv_dilate_window(window=window, dilation=dilation) - - def output_shape(in_shape: Optional[int], dilate_window: int, pad_amount: int, stride: int): - if in_shape is None: - return None - numerator = max(in_shape + pad_amount - (dilate_window - 1), 0) - # ceil(numerator / stride) - return (numerator + stride - 1) // stride - - return tuple(map(output_shape, in_shape, dilate_window, pad_amount, strides)) - - -# The accuracy of the output of this layer currently doesn't match that of PyTorch -# quite as closely as we would like. See layers_test.py:test_conv2d(). -class Conv2D(BaseConv): - """The 2-D convolution layer. - - Kernel weights have the HWIO layout and in the shape of (window[0], window[1], input_dim, - output_dim). Both inputs and outputs will be in the NHWC layout. - """ - - @config_class - class Config(BaseConv.Config): - """Configures Conv2D.""" - - window: tuple[int, int] = (1, 1) # The convolution window. - strides: tuple[int, int] = (1, 1) # The convolution strides. - # Paddings: "SAME", "VALID", "CAUSAL" or ((top, bottom), (left, right)). - # Note: Sequence models use the first component to represent time. - padding: ConvPaddingType = ((0, 0), (0, 0)) - # The convolution dilation. If None, assume all 1's. - dilation: Optional[tuple[int, int]] = None - output_dim: Required[int] = REQUIRED # Output feature dim. - bias: bool = True # Whether to add a bias. - # The number of groups in which the input is split along the channel axis. - # input_dim and output_dim must both be divisible by num_input_dim_groups. For example, - # - At num_input_dim_groups=1, all inputs are convolved to all outputs (the default). - # - At num_input_dim_groups=2, the operation is equivalent to concatenating two conv layers - # side by side, each seeing half the input and producing half the output channels. - # - At num_input_dim_groups=input_dim, each input channel is convolved with its own - # set of filters (of size output_dim / input_dim); if further output_dim == K * input_dim, - # where K is a positive integer, the operation is also known as a "depthwise convolution". 
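The SAME/VALID output-length rules quoted above can likewise be checked with a few lines of arithmetic; this sketch assumes the documented ceil-division formulas (the values mirror the conv_output_shape test cases further below):

import math

def out_len(in_len: int, window: int, stride: int, mode: str, dilation: int = 1) -> int:
    w = (window - 1) * dilation + 1  # dilated effective window
    if mode == "VALID":
        return max(math.ceil((in_len - (w - 1)) / stride), 0)
    # "SAME" and "CAUSAL" both preserve ceil(in_len / stride).
    return math.ceil(in_len / stride)

assert out_len(10, 3, 1, "SAME") == 10
assert out_len(10, 3, 2, "SAME") == 5
assert out_len(10, 3, 1, "VALID") == 8
assert out_len(10, 3, 2, "VALID", dilation=2) == 3
assert out_len(10, 3, 2, "CAUSAL") == 5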
- num_input_dim_groups: Optional[int] = 1 - - @classmethod - def default_config(cls): - cfg = super().default_config() - cfg.param_partition_spec = (None, None, None, None) - return cfg - - def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: - cfg = self.config - _check_conv_cfg( - window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation - ) - params = dict( - weight=ParameterSpec( - shape=list(cfg.window) - + [cfg.input_dim // cfg.num_input_dim_groups, cfg.output_dim], - mesh_axes=cfg.param_partition_spec, - factorization=FactorizationSpec(axes=(None, None, "row", "col")), - ) - ) - if cfg.bias: - params["bias"] = ParameterSpec( - shape=[cfg.output_dim], mesh_axes=(cfg.param_partition_spec[-1],) - ) - return params - - def forward(self, x: Tensor) -> Tensor: - cfg = self.config - conv_padding = conv_explicit_padding( - window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation - ) - return self._conv(x=x, strides=cfg.strides, padding=conv_padding, dilation=cfg.dilation) - - def _conv( - self, - x: Tensor, - *, - strides: Sequence[int], - padding: ConvPaddingType, - dilation: Optional[Sequence[int]], - ) -> Tensor: - cfg = self.config - output = jax.lax.conv_general_dilated( - lhs=x, - rhs=self.parameters["weight"], - window_strides=strides, - padding=padding, - rhs_dilation=dilation, - dimension_numbers=("NHWC", "HWIO", "NHWC"), - feature_group_count=cfg.num_input_dim_groups, - ) - if cfg.bias: - output += self.parameters["bias"] - return output - - @nowrap - def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Optional[int]]: - cfg = self.config - if len(input_shape) != 4: - raise ValueError(f"We expect len(input_shape) = 4, but got {len(input_shape)}.") - if input_shape[-1] != cfg.input_dim: - raise ValueError( - f"input_shape[-1] = {input_shape[-1]} does not match " - f"cfg.input_dim = {cfg.input_dim}." - ) - - in_shape = input_shape[1:3] - out_shape = conv_output_shape( - in_shape, - window=cfg.window, - strides=cfg.strides, - padding=cfg.padding, - dilation=cfg.dilation, - ) - return [input_shape[0], *out_shape, cfg.output_dim] - - -def compute_conv_paddings( - in_paddings: Tensor, - *, - window: int, - stride: int, - conv_padding: ConvPaddingType, - dilation: Optional[int] = None, - anchor: Optional[int] = None, -): - """Compute output paddings w.r.t. conv_padding. - - The output paddings value is determined by the padding value at the anchor point in the - window. If anchor is None, the default anchor point is the left time padding from conv - padding config. See `Conv2DWith1DPadding.Config` in details. - - Args: - in_paddings: A Tensor of shape [batch_size, seq_len]. - window: convolution window size of the time axis. - stride: convolution stride size of the time axis. - conv_padding: "SAME", "VALID", "CAUSAL" or ((left_time_padding, right_time_padding),) - dilation: convolution dilation size of the time axis. - anchor: an optional integer in the range of [left_time_padding, window - right_time_padding) - that specifies the anchor position within the convolution window that is used to - determine output paddings. Specifically, the output token is valid iff the input token - at the anchor position of the corresponding window is valid. - If None, anchor defaults to conv_padding[0] (i.e. left_time_padding). - - Returns: - out_paddings: A Tensor of shape [batch_size, seq_len]. - - Raises: - ValueError: If anchor is not between left_time_padding and right_time_padding. 
-    """
-    chex.assert_rank(in_paddings, 2)
-    dilation = dilation or 1
-    conv_padding = conv_explicit_padding(
-        window=(window,), strides=(stride,), padding=conv_padding, dilation=(dilation,)
-    )
-    window = conv_dilate_window(window=(window,), dilation=(dilation,))[0]
-    left_pad, right_pad = conv_padding[0]
-    pad_total = window - 1
-
-    if anchor is None:
-        # valid_window = pad_total - left_pad - right_pad
-        # anchor_global = valid_window // 2
-        # anchor = anchor_global + left_pad
-        anchor = left_pad
-    elif not left_pad <= anchor < window - right_pad:
-        raise ValueError(f"anchor ({anchor}) must be in range [{left_pad}, {window - right_pad}).")
-
-    # This avoids jax.pad by leveraging the property that the valid_window
-    # is always within the input sequence.
-    # Note: transform anchor from window frame to input sequence frame.
-    start_index = anchor - left_pad
-    valid_window = pad_total - left_pad - right_pad
-    valid_window_right_pad = valid_window - start_index
-    seq_len = in_paddings.shape[1]
-    limit_index = max(seq_len - valid_window_right_pad, start_index)
-    if seq_len < start_index:
-        start_index = 0
-        limit_index = 0
-    out_paddings = jax.lax.slice_in_dim(
-        in_paddings, start_index=start_index, limit_index=limit_index, stride=stride, axis=1
-    )
-    return out_paddings
-
-
-# TODO(dhwang2): move to convolution transpose section.
-class Conv2DTranspose(BaseConv):
-    """The 2-D transposed convolution layer."""
-
-    @config_class
-    class Config(BaseConv.Config):
-        """Configures Conv2DTranspose."""
-
-        window: tuple[int, int] = (1, 1)
-        strides: tuple[int, int] = (1, 1)
-        padding: Required[ConvPaddingType] = REQUIRED
-        dilation: tuple[int, int] = (1, 1)
-        output_dim: Required[int] = REQUIRED  # Output feature dim.
-        bias: bool = True  # Whether to add a bias.
-        # If True, kernel weights have the HWOI layout, following the format used by
-        # keras.layers.Conv2DTranspose.
-        # Otherwise, the standard layout HWIO is used, which is more efficient.
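A short usage sketch of the padding propagation itself (assuming compute_conv_paddings is exported from axlearn.common.convolution after this move; the expected values follow the layers_test.py table below):

import jax.numpy as jnp
from axlearn.common.convolution import compute_conv_paddings

paddings = jnp.array([[0, 0, 0, 0, 1, 1]])
# window=5, stride=2: "SAME" keeps ceil(6 / 2) = 3 output frames, "VALID" keeps only 1.
same_out = compute_conv_paddings(paddings, window=5, stride=2, conv_padding="SAME")
valid_out = compute_conv_paddings(paddings, window=5, stride=2, conv_padding="VALID")
assert same_out.tolist() == [[0, 0, 1]]
assert valid_out.tolist() == [[0]]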
- transpose_kernel: bool = False - - @classmethod - def default_config(cls): - cfg = super().default_config() - if cfg.transpose_kernel: - cfg.param_partition_spec = (None, None, "model", None) - else: - cfg.param_partition_spec = (None, None, None, "model") - return cfg - - def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: - cfg = self.config - _check_conv_cfg( - window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation - ) - if cfg.transpose_kernel: - io_shape = (cfg.output_dim, cfg.input_dim) - else: - io_shape = (cfg.input_dim, cfg.output_dim) - params = dict( - weight=ParameterSpec( - shape=tuple(cfg.window) + io_shape, - mesh_axes=cfg.param_partition_spec, - factorization=FactorizationSpec(axes=(None, None, "row", "col")), - ) - ) - if cfg.bias: - params["bias"] = ParameterSpec( - shape=[cfg.output_dim], mesh_axes=(cfg.param_partition_spec[-1],) - ) - return params - - def forward(self, x: Tensor) -> Tensor: - cfg = self.config - conv_padding = conv_transpose_explicit_padding( - window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation - ) - return self._conv(x=x, strides=cfg.strides, padding=conv_padding, dilation=cfg.dilation) - - def _conv( - self, - x: Tensor, - *, - strides: Sequence[int], - padding: ConvPaddingType, - dilation: Sequence[int], - ) -> Tensor: - cfg = self.config - output = jax.lax.conv_transpose( - lhs=x, - rhs=self.parameters["weight"], - strides=strides, - padding=padding, - rhs_dilation=dilation, - dimension_numbers=("NHWC", "HWIO", "NHWC"), - transpose_kernel=cfg.transpose_kernel, - ) - if cfg.bias: - output += self.parameters["bias"] - return output - - @nowrap - def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Optional[int]]: - cfg = self.config - if len(input_shape) != 4: - raise ValueError(f"We expect len(input_shape) = 4, but got {len(input_shape)}.") - if input_shape[-1] != cfg.input_dim: - raise ValueError( - f"input_shape[-1] = {input_shape[-1]} does not match " - "cfg.input_dim = {cfg.input_dim}." - ) - - in_shape = input_shape[1:3] - out_shape = conv_transpose_output_shape( - in_shape, - window=cfg.window, - strides=cfg.strides, - padding=cfg.padding, - dilation=cfg.dilation, - ) - return [input_shape[0], *out_shape, cfg.output_dim] - - -class Conv2DWith1DPadding(Conv2D): - """The 2-D convolution with 1-D padding on the time axis.""" - - @config_class - class Config(Conv2D.Config): - """Configures Conv2DWith1DPadding. - - The output paddings value is determined by the padding value at the anchor point in the - window. If anchor is None, the default anchor point is the left time padding from conv - padding config. - - For examples with window=5, - 1. "SAME" padding case, - * padding=(2,2): (0 0 0 0 0) - * anchor index is 2: (0 0 |0| 0 0) - pad | | pad - paddings: 0 0|0 0 0 1 1 1|1 1 - |___0___| - |___0___| - |___0___| - |___1___| - |___1___| - |___1___| - - 2. "VALID" padding case, - * padding=(0,0): (0 0 0 0 0) - * anchor index is 0: (|0| 0 0 0 0) - pad | | pad - paddings: |0 0 0 1 1 1| - |0_______| - |0_______| - - 3. The legacy "VALID" padding case, - * padding=(0,0) and anchor=4: (0 0 0 0 0) - * anchor index is 4: (0 0 0 0 |0|) - pad | | pad - paddings: |0 0 0 1 1 1| - |________1| - |________1| - - 4. "CAUSAL" padding case, - * padding=(4,0): (0 0 0 0 0) - * anchor index is 4: (0 0 0 0 |0|) - pad | | pad - paddings: 0 0 0 0|0 0 0 1 1 1| - |_______0| - |_______0| - |_______0| - |_______1| - |_______1| - |_______1| - - 5. 
"CAUSAL" with lookahead=1, - * padding=(3, 1): (0 0 0 0 0) - * anchor index is 3: (0 0 0 |0| 0) - pad | | pad - paddings: 0 0 0|0 0 0 1 1 1|1 - |_____0_| - |_____0_| - |_____0_| - |_____1_| - |_____1_| - |_____1_| - - 6. Arbitrary padding case, - * padding=(2,1): (0 0 0 0 0) - * anchor index is 2: (0 0 |0| 0 0) - pad | | pad - paddings: 0 0|0 0 0 1 1 1|1 - |___0___| - |___0___| - |___0___| - |___1___| - |___1___| - """ - - # An optional integer in the range of [left_time_padding, window - right_time_padding) - # that specifies the anchor position within the convolution window that is used to - # determine output paddings. Specifically, the output token is valid iff the input token - # at the anchor position of the corresponding window is valid. - # If None, defaults to left time padding. - anchor: Optional[int] = None - - # We add a kwargs "paddings" to the forward method. - # pylint: disable-next=arguments-differ - def forward(self, x: Tensor, *, paddings: Tensor) -> tuple[Tensor, Tensor]: - """Computes convolution outputs and paddings. - - Args: - x: A Tensor of shape [batch_size, seq_len, frequency, input_dim]. - paddings: 0/1 Tensor of shape [batch_size, seq_len]. - - Returns: - output: A Tensor of shape [batch_size, seq_len, frequency, output_dim]. - paddings: 0/1 Tensor of shape [batch_size, seq_len]. - """ - cfg = self.config - # Apply padding to the input. - assert len(x.shape) == len(paddings.shape) + 2 - x = x * (1 - paddings[..., None, None]) - - # Apply Conv2D. - output = super().forward(x) - # Compute paddings conv output. - dilation = 1 if cfg.dilation is None else cfg.dilation[0] - output_paddings = compute_conv_paddings( - paddings, - window=cfg.window[0], - stride=cfg.strides[0], - conv_padding=cfg.padding, - dilation=dilation, - anchor=cfg.anchor, - ) - # Apply padding to the outputs. - output = output * (1 - output_paddings[..., None, None]) - return output, output_paddings - - -class Conv3D(BaseConv): - """The 3-D convolution layer. - - Kernel weights have the HWDIO layout and in the shape of (window[0], window[1], - window[2], input_dim, output_dim). Both inputs and outputs will be in the NHWDC layout. - """ - - @config_class - class Config(BaseConv.Config): - """Configures Conv3D.""" - - window: tuple[int, int, int] = (1, 1, 1) # The convolution window. - strides: tuple[int, int, int] = (1, 1, 1) # The convolution strides. - - # Paddings: "SAME" or "VALID, or ((top, bottom), (left, right), (front, back)) - padding: ConvPaddingType = ( - (0, 0), - (0, 0), - (0, 0), - ) - # The convolution dilation. If None, assume all 1's. - dilation: Optional[tuple[int, int, int]] = None - - output_dim: Required[int] = REQUIRED # Output feature dim. - bias: bool = True # Whether to add a bias. - - # The number of groups in which the input is split along the channel axis. - # input_dim and output_dim must both be divisible by num_input_dim_groups. For example, - # - At num_input_dim_groups=1, all inputs are convolved to all outputs (the default). - # - At num_input_dim_groups=2, the operation is equivalent to concatenating two conv layers - # side by side, each seeing half the input and producing half the output channels. - # - At num_input_dim_groups=input_dim, each input channel is convolved with its own - # set of filters (of size output_dim / input_dim); if further output_dim == K * input_dim, - # where K is a positive integer, the operation is also known as a "depthwise convolution". 
- num_input_dim_groups: Optional[int] = 1 - - @classmethod - def default_config(cls): - cfg = super().default_config() - cfg.param_partition_spec = (None, None, None, None, None) - return cfg - - def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: - cfg = self.config - _check_conv_cfg( - window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation - ) - params = dict( - weight=ParameterSpec( - shape=list(cfg.window) - + [cfg.input_dim // cfg.num_input_dim_groups, cfg.output_dim], - mesh_axes=cfg.param_partition_spec, - factorization=FactorizationSpec(axes=(None, None, None, "row", "col")), - ) - ) - if cfg.bias: - params["bias"] = ParameterSpec( - shape=[cfg.output_dim], mesh_axes=(cfg.param_partition_spec[-1],) - ) - return params - - def forward(self, x: Tensor) -> Tensor: - cfg = self.config - conv_padding = conv_explicit_padding( - window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation - ) - return self._conv(x=x, strides=cfg.strides, padding=conv_padding, dilation=cfg.dilation) - - def _conv( - self, - x: Tensor, - *, - strides: Sequence[int], - padding: ConvPaddingType, - dilation: Optional[Sequence[int]], - ) -> Tensor: - cfg = self.config - output = jax.lax.conv_general_dilated( - lhs=x, - rhs=self.parameters["weight"], - window_strides=strides, - padding=padding, - rhs_dilation=dilation, - dimension_numbers=("NHWDC", "HWDIO", "NHWDC"), - feature_group_count=cfg.num_input_dim_groups, - ) - if cfg.bias: - output += self.parameters["bias"] - return output - - @nowrap - def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Optional[int]]: - cfg = self.config - if len(input_shape) != 5: - raise ValueError(f"We expect len(input_shape) = 5, but got {len(input_shape)}.") - if input_shape[-1] != cfg.input_dim: - raise ValueError( - f"input_shape[-1] = {input_shape[-1]} does not match " - f"cfg.input_dim = {cfg.input_dim}." - ) - - in_shape = input_shape[1:4] - out_shape = conv_output_shape( - in_shape, - window=cfg.window, - strides=cfg.strides, - padding=cfg.padding, - dilation=cfg.dilation, - ) - return [input_shape[0], *out_shape, cfg.output_dim] - - -class Conv1D(BaseConv): - """The 1D convolution layer. - - Kernel weights have the WIO layout and in the shape of (window, input_dim, output_dim). - Both inputs and outputs will be in the NWC layout. - """ - - @config_class - class Config(BaseConv.Config): - """Configures Conv1D.""" - - window: Required[int] = REQUIRED # The convolution window. - strides: int = 1 # The convolution strides. - # Paddings: "SAME", "VALID", "CAUSAL", or (left, right). - # For causal convolution, set padding to (window - 1, 0). - padding: ConvPaddingType = ((0, 0),) - output_dim: Required[int] = REQUIRED # Output feature dim. - bias: bool = True # Whether to add a bias. - # The number of groups in which the input is split along the channel axis. - # input_dim and output_dim must both be divisible by num_input_dim_groups. For example, - # - At num_input_dim_groups=1, all inputs are convolved to all outputs (the default). - # - At num_input_dim_groups=2, the operation is equivalent to concatenating two conv layers - # side by side, each seeing half the input and producing half the output channels. - # - At num_input_dim_groups=input_dim, each input channel is convolved with its own - # set of filters (of size output_dim / input_dim); if further output_dim == K * input_dim, - # where K is a positive integer, the operation is also known as a "depthwise convolution". 
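As a concrete illustration of num_input_dim_groups, a depthwise 1-D convolution can be configured as below (a sketch, assuming Conv1D is importable from axlearn.common.convolution after this move):

import jax
from axlearn.common.convolution import Conv1D

input_dim = 8
cfg = Conv1D.default_config().set(
    name="depthwise",
    input_dim=input_dim,
    output_dim=input_dim,  # One filter per input channel.
    window=3,
    padding="CAUSAL",
    num_input_dim_groups=input_dim,  # Depthwise: each channel is convolved independently.
)
layer = cfg.instantiate(parent=None)
state = layer.initialize_parameters_recursively(jax.random.PRNGKey(0))
# The weight shape is (window, input_dim // num_input_dim_groups, output_dim) == (3, 1, 8).
assert state["weight"].shape == (3, 1, 8)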
- num_input_dim_groups: Optional[int] = 1 - # The convolution dilation, indicating dilation factor applied to the weight. It is also - # known as atrous convolution or dilated convolution. If None, assume 1. - dilation: Optional[int] = None - - @classmethod - def default_config(cls): - cfg = super().default_config() - cfg.param_partition_spec = (None, None, "model") - return cfg - - def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: - cfg = self.config - dilation = cfg.dilation or 1 - _check_conv_cfg( - window=(cfg.window,), - strides=(cfg.strides,), - padding=cfg.padding, - dilation=(dilation,), - ) - if cfg.padding not in SUPPORT_CONV_PADDING: - left, right = cfg.padding[0] - if any(p < 0 for p in (left, right)): - raise NotImplementedError("Negative padding is not supported") - params = dict( - weight=ParameterSpec( - shape=[cfg.window, cfg.input_dim // cfg.num_input_dim_groups, cfg.output_dim], - mesh_axes=cfg.param_partition_spec, - ) - ) - if cfg.bias: - params["bias"] = ParameterSpec( - shape=[cfg.output_dim], mesh_axes=(cfg.param_partition_spec[-1],) - ) - return params - - def forward(self, x: Tensor) -> Tensor: - cfg = self.config - dilation = cfg.dilation or 1 - conv_padding = conv_explicit_padding( - window=(cfg.window,), - strides=(cfg.strides,), - padding=cfg.padding, - dilation=(dilation,), - ) - return self._conv(x=x, strides=(cfg.strides,), padding=conv_padding, dilation=(dilation,)) - - def _conv( - self, - x: Tensor, - *, - strides: Sequence[int], - padding: ConvPaddingType, - dilation: Optional[Sequence[int]], - ) -> Tensor: - cfg = self.config - output = jax.lax.conv_general_dilated( - lhs=x, - rhs=self.parameters["weight"], - window_strides=strides, - padding=padding, - rhs_dilation=dilation, - dimension_numbers=("NWC", "WIO", "NWC"), - feature_group_count=cfg.num_input_dim_groups, - ) - if cfg.bias: - output += self.parameters["bias"] - return output - - @nowrap - def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Optional[int]]: - cfg = self.config - if len(input_shape) != 3: - raise ValueError(f"We expect len(input_shape) = 3, but got {len(input_shape)}.") - if input_shape[-1] != cfg.input_dim: - raise ValueError( - f"input_shape[-1] = {input_shape[-1]} does not match " - f"cfg.input_dim = {cfg.input_dim}." - ) - - in_shape = input_shape[1:2] - dilation = cfg.dilation or 1 - out_shape = conv_output_shape( - in_shape, - window=(cfg.window,), - strides=(cfg.strides,), - padding=cfg.padding, - dilation=(dilation,), - ) - return [input_shape[0], *out_shape, cfg.output_dim] - - -class Conv1DWithPadding(Conv1D): - """The 1-D convolution with 1-D padding on the time axis.""" - - @config_class - class Config(Conv1D.Config): - """Configures Conv1DWithPadding.""" - - # An optional integer in the range of [left_time_padding, window - right_time_padding) - # that specifies the anchor position within the convolution window that is used to - # determine output paddings. Specifically, the output token is valid iff the input token - # at the anchor position of the corresponding window is valid. - # If None, defaults to left time padding. See Conv2DWith1DPadding more details. - anchor: Optional[int] = None - - # We add a kwargs "paddings" to the forward method. - # pylint: disable-next=arguments-differ - def forward(self, x: Tensor, *, paddings: Tensor) -> tuple[Tensor, Tensor]: - """Computes convolution outputs and paddings. - - Args: - x: A Tensor of shape [batch_size, seq_len, frequency, input_dim]. 
- paddings: 0/1 Tensor of shape [batch_size, seq_len]. - - Returns: - output: A Tensor of shape [batch_size, seq_len, frequency, output_dim]. - paddings: 0/1 Tensor of shape [batch_size, seq_len]. - """ - cfg = self.config - chex.assert_rank(x, paddings.ndim + 1) - # Apply padding to the input. - x = x * (1 - paddings[..., None]) - - # Apply Conv1D. - output = super().forward(x) - - # Compute paddings conv output. - output_paddings = compute_conv_paddings( - paddings, - window=cfg.window, - stride=cfg.strides, - conv_padding=cfg.padding, - dilation=cfg.dilation, - anchor=cfg.anchor, - ) - # Apply padding to the outputs. - output = output * (1 - output_paddings[..., None]) - return output, output_paddings - - -############################## Transposed Convolution ############################################## - - -# Based on jax.lax.convolution._conv_transpose_padding, but ours is more intuitive. -def conv_transpose_explicit_padding( - *, - window: Sequence[int], - strides: Sequence[int], - padding: ConvPaddingType, - dilation: Sequence[int], -) -> ConvPaddingType: - """Convert str padding to tuple padding for conv_transpose. - - Each mode follows the formulas below, - * SAME: (min(window-1, ceil((w+s-2)/2)), max(stride-1, floor((w+s-2)/2))) - pad_total = window+stride-2 - when stride > window -> (window-1, stride-1) - * VALID: (window-1, max(stride-1, window-1)) - pad_total = window+stride-2 + max(window-stride, 0) - when stride > window -> (window-1, stride-1) - * CAUSAL: (window-1, stride-1) - pad_total = window+stride-2 - - Note: output_size = input_size*stride - (window+stride-2) + pad_total - = input_size*stride <- "SAME", "CAUSAL" - = input_size*stride + max(window-stride, 0) <- "VALID" - - Note: In the above equation, `window` will be replaced with `dilate_window` when dilation > 1. - dilate_window = (window - 1) * dilation + 1. Check conv_dilate_window() - - The following illustration demonstrates how Conv Transpose operates, assuming all kernel values - are set to 1 for simplicity in showcasing output values. 
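The SAME/VALID/CAUSAL formulas above reduce to simple integer arithmetic; this standalone sketch reproduces the explicit paddings used in the diagrams that follow:

def transpose_explicit_padding(window: int, stride: int, mode: str) -> tuple:
    if mode == "SAME":
        return (min(window - 1, (window + stride - 1) // 2),
                max(stride - 1, (window + stride - 2) // 2))
    if mode == "VALID":
        return (window - 1, max(stride - 1, window - 1))
    if mode == "CAUSAL":
        return (window - 1, stride - 1)
    raise ValueError(f"{mode} padding is not supported.")

assert transpose_explicit_padding(3, 1, "SAME") == (1, 1)
assert transpose_explicit_padding(3, 1, "VALID") == (2, 2)
assert transpose_explicit_padding(3, 1, "CAUSAL") == (2, 0)
assert transpose_explicit_padding(3, 2, "SAME") == (2, 1)
assert transpose_explicit_padding(3, 4, "SAME") == (2, 3)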
- - In the window=3 and stride=1 case, this function creates outputs as follows: - * "SAME" padding=(1, 1) - pad| |pad - paddings: 0|0 0 1 1|0 - 0 0 0 -> 0 - 0 0 1 -> 1 - 0 1 1 -> 2 - 1 1 0 -> 2 - - * "VALID" padding=(2, 2) - pad | |pad - paddings: 0 0|0 0 1 1|0 0 - 0 0 0 -> 0 - 0 0 0 -> 0 - 0 0 1 -> 1 - 0 1 1 -> 2 - 1 1 0 -> 2 - 1 0 0 -> 1 - - * "CAUSAL" padding=(2, 0) - pad | |pad - paddings: 0 0|0 0 1 1| - 0 0 0 -> 0 - 0 0 0 -> 0 - 0 0 1 -> 1 - 0 1 1 -> 2 - - In the window=3 and stride=2 case, this function creates outputs as follows: - * "SAME" padding=(2, 1) - pad | |pad - paddings: 0 0|0 * 0 * 1 * 1|0 - 0 0 0 -> 0 - 0 0 0 -> 0 - 0 0 0 -> 0 - 0 0 0 -> 0 - 0 0 1 -> 1 - 0 1 0 -> 1 - 1 0 1 -> 2 - 0 1 0 -> 1 - - * "VALID" padding=(2, 2) - pad | |pad - paddings: 0 0|0 * 0 * 1 * 1|0 0 - 0 0 0 -> 0 - 0 0 0 -> 0 - 0 0 0 -> 0 - 0 0 0 -> 0 - 0 0 1 -> 1 - 0 1 0 -> 1 - 1 0 1 -> 2 - 0 1 0 -> 1 - 1 0 0 -> 1 - - * "CAUSAL" padding=(2, 1) - pad | |pad - paddings: 0 0|0 * 0 * 1 * 1|0 - 0 0 0 -> 0 - 0 0 0 -> 0 - 0 0 0 -> 0 - 0 0 0 -> 0 - 0 0 1 -> 1 - 0 1 0 -> 1 - 1 0 1 -> 2 - 0 1 0 -> 1 - - In the window=3 and stride=3 case, this function creates outputs as follows: - * "SAME", "VALID" and "CAUSAL" padding=(2, 2) - pad | |pad - paddings: 0 0|0 * * 0 * * 1 * * 1|0 0 - 0 0 0 -> 0 - 0 0 0 -> 0 - 0 0 0 -> 0 - 0 0 0 -> 0 - 0 0 0 -> 0 - 0 0 0 -> 0 - 0 0 1 -> 1 - 0 1 0 -> 1 - 1 0 0 -> 1 - 0 0 1 -> 1 - 0 1 0 -> 1 - 1 0 0 -> 1 - - In the window=3 and stride=4 case, this function creates outputs as follows: - * "SAME", "VALID" and "CAUSAL" padding=(2, 3) - pad | |pad - paddings: 0 0|0 * * * 0 * * * 1 * * * 1|0 0 0 - 0 0 0 -> 0 - 0 0 0 -> 0 - 0 0 0 -> 0 - 0 0 0 -> 0 - 0 0 0 -> 0 - 0 0 0 -> 0 - 0 0 0 -> 0 - 0 0 0 -> 0 - 0 0 1 -> 1 - 0 1 0 -> 1 - 1 0 0 -> 1 - 0 0 0 -> 0 - 0 0 1 -> 1 - 0 1 0 -> 1 - 1 0 0 -> 1 - 0 0 0 -> 0 - Here is how to compute output_size, given the above example, - 1. |_| -(window-1) - 2. |_______________________| (input_size-1)*stride + 1 - 3. |_| |___| + pad_total - - So, output_size = -(window-1) + (input_size-1)*stride + 1 + pad_total - = input_size*stride - (window+stride-2) + pad_total - = input_size*stride <- "SAME", "CAUSAL" - = input_size*stride + max(window-stride, 0) <- "VALID" - - OTHO, when dilation > 1, dilate_window = (window - 1) * dilation + 1. - For example, when window=3 and dilation=2, dilate_window=5. - - In the stride=2 case, this function creates outputs as follows: - * "SAME" padding=(3, 2) - pad | |pad - paddings: 0 0 0|0 * 0 * 1 * 1|0 0 - 0 * 0 * 0 -> 0 - 0 * 0 * 0 -> 0 - 0 * 0 * 0 -> 0 - 0 * 0 * 1 -> 1 - 0 * 0 * 0 -> 0 - 0 * 1 * 1 -> 2 - 0 * 0 * 0 -> 0 - 1 * 1 * 0 -> 2 - - * "VALID" padding=(4, 4) - pad | |pad - paddings: 0 0 0 0|0 * 0 * 1 * 1|0 0 0 0 - 0 * 0 * 0 -> 0 - 0 * 0 * 0 -> 0 - 0 * 0 * 0 -> 0 - 0 * 0 * 0 -> 0 - 0 * 0 * 1 -> 1 - 0 * 0 * 0 -> 0 - 0 * 1 * 1 -> 2 - 0 * 0 * 0 -> 0 - 1 * 1 * 0 -> 2 - 0 * 0 * 0 -> 0 - 1 * 0 * 0 -> 1 - - * "CAUSAL" padding=(4, 1) - pad | |pad - paddings: 0 0 0 0|0 * 0 * 1 * 1|0 - 0 * 0 * 0 -> 0 - 0 * 0 * 0 -> 0 - 0 * 0 * 0 -> 0 - 0 * 0 * 0 -> 0 - 0 * 0 * 1 -> 1 - 0 * 0 * 0 -> 0 - 0 * 1 * 1 -> 2 - 0 * 0 * 0 -> 0 - - For "CAUSAL", the first component is time and treated as "CAUSAL", while the remaining - components are handled with "SAME" padding. - - Args: - window: convolution window. - strides: transposed convolution strides. It's lhs_dilation, not window_stride. - padding: convolution padding. - dilation: convolution dilation, a.k.a rhs_dilation. - - Returns: - The padding tuple. 
- - Raises: - ValueError: If padding is not supported. - """ - if not isinstance(padding, str): - return padding - - window = conv_dilate_window(window=window, dilation=dilation) - - def same_padding(window, strides): - pad_left = tuple(min(w - 1, (w + s - 1) // 2) for w, s in zip(window, strides)) - pad_right = tuple(max(s - 1, (w + s - 2) // 2) for w, s in zip(window, strides)) - return tuple(zip(pad_left, pad_right)) - - if padding == "SAME": - return same_padding(window, strides) - elif padding == "VALID": - pad_left = tuple(w - 1 for w in window) - pad_right = tuple(max(s - 1, w - 1) for w, s in zip(window, strides)) - return tuple(zip(pad_left, pad_right)) - elif padding == "CAUSAL": - causal_padding = ((window[0] - 1, strides[0] - 1),) - if len(window) > 1: - causal_padding += same_padding(window[1:], strides[1:]) - return causal_padding - else: - raise ValueError(f"{padding} padding is not supported.") - - -def conv_transpose_output_shape( - in_shape: Sequence[Optional[int]], - *, - window: Sequence[int], - strides: Sequence[int], - padding: ConvPaddingType, - dilation: Sequence[int], -) -> Sequence[int]: - """Returns output size for conv transpose. - - Each mode follows the formulas below, - * SAME: padding=(min(window-1, ceil((w+s-2)/2)), max(stride-1, floor((w+s-2)/2))) - pad_total = window+stride-2 - output_size = input_size*stride - * VALID: padding=(window-1, max(stride-1, window-1)) - pad_total = window+stride-2 + max(window-stride, 0) - output_size = input_size*stride + max(window-stride, 0) - * CAUSAL: padding=(window-1, stride-1) - pad_total = window+stride-2 - output_size = input_size*stride - - Note: In the above equation, `window` will be replaced with `dilate_window` when dilation > 1. - dilate_window = (window - 1) * dilation + 1. Check conv_dilate_window() - - Refer to - https://towardsdatascience.com/understand-transposed-convolutions-and-build-your-own-transposed-convolution-layer-from-scratch-4f5d97b2967 - - Args: - in_shape: convolution lhs shape. - window: convolution window. - strides: convolution strides. - padding: convolution padding. - dilation: convolution dilation. - - Returns: - The output shape. - - Raises: - ValueError: If the length of in_shape, window, strides, and padding are not equal. - """ - if len(in_shape) != len(window) or len(in_shape) != len(strides): - raise ValueError( - f"len(in_shape) = {len(in_shape)} must be equal to " - f"len(window) = {len(window)} and len(strides) = {len(strides)}" - ) - - window = conv_dilate_window(window=window, dilation=dilation) - - def output_shape(in_shape: Optional[int], window: int, stride: int): - if in_shape is None: - return None - - if padding == "SAME": - return in_shape * stride - elif padding == "VALID": - return in_shape * stride + max(window - stride, 0) - elif padding == "CAUSAL": - return in_shape * stride - else: - raise ValueError(f"{padding} padding is not supported.") - - return tuple(map(output_shape, in_shape, window, strides)) - - -def compute_conv_transpose_paddings( - in_paddings: Tensor, - *, - window: int, - stride: int, - conv_padding: ConvPaddingType, - dilation: int = 1, - anchor: Optional[int] = None, -): - """Compute output paddings w.r.t. conv_padding for conv transpose. - - The output paddings value is determined by the padding value at the anchor point in the - window. If anchor is None, the default anchor point is the left time padding from conv - padding config. See `Conv2DWith1DPadding.Config` in details. 
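Similarly, the conv_transpose_output_shape rules can be checked directly (plain Python; the window=3, stride=2 values match the diagrams above):

def transpose_out_len(in_len: int, window: int, stride: int, mode: str) -> int:
    if mode == "VALID":
        return in_len * stride + max(window - stride, 0)
    # "SAME" and "CAUSAL" both give exactly in_len * stride.
    return in_len * stride

# Four input frames, as in the |0 * 0 * 1 * 1| diagrams above.
assert transpose_out_len(4, 3, 2, "SAME") == 8
assert transpose_out_len(4, 3, 2, "VALID") == 9
assert transpose_out_len(4, 3, 2, "CAUSAL") == 8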
- - In the window=3 and stride=1 case, this function creates paddings as follows: - - The following illustration demonstrates how Conv Transpose operates, assuming all kernel values - are set to 1 for simplicity in showcasing output values. - - In the window=3 and stride=1 case, this function creates outputs as follows: - * "SAME" padding=(1, 1) - pad| |pad - paddings: 0|0 0 1 1|1 - |_____| - * 0 * -> 0 - * 0 * -> 0 - * 1 * -> 1 - * 1 0 -> 1 - - * "VALID" padding=(2, 2) - pad | |pad - paddings: 0 0|0 0 1 1|1 1 - |_________| - * * 0 -> 0 - * * 0 -> 0 - * * 1 -> 1 - * * 1 -> 1 - * * 1 -> 1 - * * 1 -> 1 - - * "CAUSAL" padding=(2, 0) - pad | |pad - paddings: 0 0|0 0 1 1| - |_____| - * * 0 -> 0 - * * 0 -> 0 - * * 1 -> 1 - * * 1 -> 1 - - In the window=3 and stride=2 case, this function creates outputs as follows: - * "SAME" padding=(2, 1) - pad | |pad - paddings: 0 0|0 * 0 * 1 * 1|1 - |_____________| - * * 0 -> 0 - * * 0 -> 0 - * * 0 -> 0 - * * 0 -> 0 - * * 1 -> 1 - * * 1 -> 1 - * * 1 -> 1 - * * 1 -> 1 - - * "VALID" padding=(2, 2) - pad | |pad - paddings: 0 0|0 * 0 * 1 * 1|1 1 - |_______________| - * * 0 -> 0 - * * 0 -> 0 - * * 0 -> 0 - * * 0 -> 0 - * * 1 -> 1 - * * 1 -> 1 - * * 1 -> 1 - * * 1 -> 1 - * * 1 -> 1 - - * "CAUSAL" padding=(2, 1) - pad | |pad - paddings: 0 0|0 * 0 * 1 * 1|1 - |_____________| - * * 0 -> 0 - * * 0 -> 0 - * * 0 -> 0 - * * 0 -> 0 - * * 1 -> 1 - * * 1 -> 1 - * * 1 -> 1 - * * 1 -> 1 - - In the window=3 and stride=3 case, this function creates outputs as follows: - * "SAME", "VALID" and "CAUSAL" padding=(2, 2) - pad | |pad - paddings: 0 0|0 * * 0 * * 1 * * 1|1 1 - |_____________________| - * * 0 -> 0 - * * 0 -> 0 - * * 0 -> 0 - * * 0 -> 0 - * * 0 -> 0 - * * 0 -> 0 - * * 1 -> 1 - * * 1 -> 1 - * * 1 -> 1 - * * 1 -> 1 - * * 1 -> 1 - * * 1 -> 1 - - OTHO, when dilation > 1, dilate_window = (window - 1) * dilation + 1. - For example, when window=3 and dilation=2, dilate_window=5. - - In the stride=2 case, this function creates outputs as follows: - * "SAME" padding=(3, 2) - pad | |pad - paddings: 0 0 0|0 * 0 * 1 * 1|1 1 - |_____________| - * * * 0 * -> 0 - * * * 0 * -> 0 - * * * 0 * -> 0 - * * * 0 * -> 0 - * * * 1 * -> 1 - * * * 1 * -> 1 - * * * 1 * -> 1 - * * * 1 * -> 1 - - * "VALID" padding=(4, 4) - pad | |pad - paddings: 0 0 0 0|0 * 0 * 1 * 1|1 1 1 1 - |___________________| - * * * * 0 -> 0 - * * * * 0 -> 0 - * * * * 0 -> 0 - * * * * 0 -> 0 - * * * * 1 -> 1 - * * * * 1 -> 1 - * * * * 1 -> 1 - * * * * 1 -> 1 - * * * * 1 -> 1 - * * * * 1 -> 1 - * * * * 1 -> 1 - - * "CAUSAL" padding=(4, 1) - pad | |pad - paddings: 0 0 0 0|0 * 0 * 1 * 1|1 - |_____________| - * * * * 0 -> 0 - * * * * 0 -> 0 - * * * * 0 -> 0 - * * * * 0 -> 0 - * * * * 1 -> 1 - * * * * 1 -> 1 - * * * * 1 -> 1 - * * * * 1 -> 1 - - Args: - in_paddings: A Tensor of shape [batch_size, seq_len]. - window: convolution window size of the time axis. - stride: convolution stride size of the time axis. - conv_padding: "SAME", "VALID", "CAUSAL" or ((left_time_padding, right_time_padding),) - dilation: convolution dilation size of the time axis. - anchor: an optional integer in the range of [0, window) - that specifies the anchor position within the convolution window that is used to - determine output paddings. Specifically, the output token is valid iff the input token - at the anchor position of the corresponding window is valid. - If None, anchor defaults to conv_padding[0] (i.e. left_time_padding). - - Returns: - out_paddings: A Tensor of shape [batch_size, seq_len]. 
-
-    Raises:
-        ValueError: If anchor is not in the range [0, window).
-    """
-
-    chex.assert_rank(in_paddings, 2)
-    conv_padding = conv_transpose_explicit_padding(
-        window=(window,), strides=(stride,), padding=conv_padding, dilation=(dilation,)
-    )
-    window = conv_dilate_window(window=(window,), dilation=(dilation,))[0]
-    # Note: in transposed conv, left_pad + right_pad >= window - 1.
-    # See conv_transpose_explicit_padding().
-    left_pad, right_pad = conv_padding[0]
-
-    if anchor is None:
-        anchor = left_pad
-    # elif not left_pad <= anchor < window:
-    elif not anchor < window:
-        raise ValueError(f"anchor ({anchor}) must be in range [0, {window}).")
-
-    # Consider the case where window=3, strides=2, dilation=2, and padding="SAME"
-    # explicit padding=(3, 2)
-    #          pad |             |pad
-    # paddings: 0 0 0|0 * 0 * 1 * 1|1 1
-    #               |_____________|
-    #             * * * 0 * -> 0
-    #             * * * 0 * -> 0
-    #             * * * 0 * -> 0
-    #             * * * 0 * -> 0
-    #             * * * 1 * -> 1
-    #             * * * 1 * -> 1
-    #             * * * 1 * -> 1
-    #             * * * 1 * -> 1
-
-    # |0 0 1 1| -> |0 * 0 * 1 * 1|
-    def dilate_paddings(paddings):
-        most, last = jnp.split(paddings, [paddings.shape[1] - 1], axis=1)
-        dilated = einops.repeat(most, "b t -> b (t s)", s=stride)
-        return jnp.concatenate([dilated, last], axis=1)
-
-    in_paddings = dilate_paddings(in_paddings)
-
-    # |0 * 0 * 1 * 1| -> 0 0 0|0 * 0 * 1 * 1|1 1
-    #                         |_____________| which is |0 * 0 * 1 * 1|1
-    window_pad_total = window - 1  # Note: we already checked `anchor < window` above.
-    window_right_pad = window_pad_total - anchor
-    assert window_right_pad >= 0, f"{anchor=} < {window=} always."
-    # Note: left_pad + right_pad >= window + stride - 2 >= window - 1 == anchor + window_right_pad
-    valid_right_pad = right_pad - window_right_pad
-    if valid_right_pad >= 0:
-        out_paddings = jnp.pad(in_paddings, ((0, 0), (0, valid_right_pad)), mode="edge")
-    else:
-        out_paddings = in_paddings[:, :valid_right_pad]
-
-    start_index = anchor - left_pad
-    if start_index < 0:
-        out_paddings = jnp.pad(out_paddings, ((0, 0), (-start_index, 0)), mode="edge")
-    else:
-        out_paddings = out_paddings[:, start_index:]
-    return out_paddings
-
-
-class Conv1DTranspose(BaseConv):
-    """The 1D transposed convolution layer."""
-
-    @config_class
-    class Config(BaseConv.Config):
-        """Configures Conv1DTranspose."""
-
-        window: int = 1
-        strides: int = 1
-        padding: Required[ConvPaddingType] = REQUIRED
-        dilation: int = 1  # Dilation for dilated convolution.
-        output_dim: Required[int] = REQUIRED  # Output feature dim.
-        bias: bool = True  # Whether to add a bias.
-
-        # An optional integer in the range of [0, window)
-        # that specifies the anchor position within the convolution window that is used to
-        # determine output paddings. Specifically, the output token is valid iff the input token
-        # at the anchor position of the corresponding window is valid.
-        # If None, defaults to left time padding. See compute_conv_transpose_paddings for more details.
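The dilate_paddings helper above repeats every frame except the last one `stride` times; here is a tiny standalone check of that behavior (using einops directly, outside the layer code):

import jax.numpy as jnp
import einops

def dilate_paddings(paddings, stride):
    most, last = jnp.split(paddings, [paddings.shape[1] - 1], axis=1)
    dilated = einops.repeat(most, "b t -> b (t s)", s=stride)
    return jnp.concatenate([dilated, last], axis=1)

out = dilate_paddings(jnp.array([[0, 0, 1, 1]]), stride=2)
# 4 frames become (4 - 1) * 2 + 1 = 7 frames; the final frame is not repeated.
assert out.tolist() == [[0, 0, 0, 0, 1, 1, 1]]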
- anchor: Optional[int] = None - - @classmethod - def default_config(cls): - cfg = super().default_config() - cfg.param_partition_spec = (None, None, "model") - return cfg - - def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: - cfg = self.config - _check_conv_cfg( - window=(cfg.window,), - strides=(cfg.strides,), - padding=cfg.padding, - dilation=(cfg.dilation,), - ) - params = dict( - weight=ParameterSpec( - shape=(cfg.window, cfg.input_dim, cfg.output_dim), - mesh_axes=cfg.param_partition_spec, - factorization=FactorizationSpec(axes=(None, "row", "col")), - ) - ) - if cfg.bias: - params["bias"] = ParameterSpec( - shape=[cfg.output_dim], mesh_axes=(cfg.param_partition_spec[-1],) - ) - return params - - def forward( - self, x: Tensor, *, paddings: Optional[Tensor] = None - ) -> tuple[Tensor, Optional[Tensor]]: - cfg = self.config - conv_padding = conv_transpose_explicit_padding( - window=(cfg.window,), - strides=(cfg.strides,), - padding=cfg.padding, - dilation=(cfg.dilation,), - ) - - if paddings is not None: - chex.assert_rank(x, paddings.ndim + 1) - # Apply padding to the input. - x = x * (1 - paddings[..., None]) - - output = self._conv( - x=x, strides=(cfg.strides,), padding=conv_padding, dilation=(cfg.dilation,) - ) - - if paddings is None: - output_paddings = None - else: - # Compute paddings conv output. - output_paddings = compute_conv_transpose_paddings( - paddings, - window=cfg.window, - stride=cfg.strides, - conv_padding=cfg.padding, - dilation=cfg.dilation, - anchor=cfg.anchor, - ) - output = output * (1 - output_paddings[..., None]) - return output, output_paddings - - def _conv( - self, - x: Tensor, - *, - strides: Sequence[int], - padding: ConvPaddingType, - dilation: Sequence[int], - ) -> Tensor: - cfg = self.config - output = jax.lax.conv_transpose( - lhs=x, - rhs=self.parameters["weight"], - strides=strides, - padding=padding, - rhs_dilation=dilation, - dimension_numbers=("NWC", "WIO", "NWC"), - ) - if cfg.bias: - output += self.parameters["bias"] - return output - - def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Optional[int]]: - cfg = self.config - if len(input_shape) != 3: - raise ValueError(f"We expect len(input_shape) = 3, but got {len(input_shape)}.") - if input_shape[-1] != cfg.input_dim: - raise ValueError( - f"input_shape[-1] = {input_shape[-1]} does not match " - "cfg.input_dim = {cfg.input_dim}." - ) - - in_shape = input_shape[1:2] - out_shape = conv_transpose_output_shape( - in_shape, - window=(cfg.window,), - strides=(cfg.strides,), - padding=cfg.padding, - dilation=(cfg.dilation,), - ) - return [input_shape[0], *out_shape, cfg.output_dim] - - -class Conv2DTransposeWith1DPadding(Conv2DTranspose): - """The 2-D convolution transpose with 1-D padding on the time axis.""" - - @config_class - class Config(Conv2DTranspose.Config): - """Configures Conv2DTransposeWith1DPadding.""" - - transpose_kernel: bool = False - # An optional integer in the range of [0, window) - # that specifies the anchor position within the convolution window that is used to - # determine output paddings. Specifically, the output token is valid iff the input token - # at the anchor position of the corresponding window is valid. - # If None, defaults to left time padding. See compute_conv_transpose_paddings more details. - anchor: Optional[int] = None - - @classmethod - def default_config(cls): - cfg = super().default_config() - cfg.transpose_kernel = False # Choose better one unlike parent. 
- return cfg - - # We add a kwargs "paddings" to the forward method. - # pylint: disable-next=arguments-differ - def forward(self, x: Tensor, *, paddings: Tensor) -> tuple[Tensor, Tensor]: - """Computes convolution outputs and paddings. - - Args: - x: A Tensor of shape [batch_size, seq_len, frequency, input_dim]. - paddings: 0/1 Tensor of shape [batch_size, seq_len]. - - Returns: - output: A Tensor of shape [batch_size, seq_len, frequency, output_dim]. - paddings: 0/1 Tensor of shape [batch_size, seq_len]. - """ - cfg = self.config - # Apply padding to the input. - assert len(x.shape) == len(paddings.shape) + 2 - x = x * (1 - paddings[..., None, None]) - - # Apply Conv2D. - output = super().forward(x) - # Compute paddings conv output. - output_paddings = compute_conv_transpose_paddings( - paddings, - window=cfg.window[0], - stride=cfg.strides[0], - conv_padding=cfg.padding, - dilation=cfg.dilation[0], - anchor=cfg.anchor, - ) - # Apply padding to the outputs. - output = output * (1 - output_paddings[..., None, None]) - return output, output_paddings - - -class Conv3DTranspose(BaseConv): - """The 3-D convolution transpose layer.""" - - @config_class - class Config(BaseConv.Config): - """Configures Conv3DTranspose.""" - - window: tuple[int, int, int] = (1, 1, 1) # The convolution window. - strides: tuple[int, int, int] = (1, 1, 1) # The convolution strides. - # Paddings: "SAME", "VALID or "CAUSAL", or ((top, bottom), (left, right), (front, back)) - padding: Required[ConvPaddingType] = REQUIRED - dilation: tuple[int, int, int] = (1, 1, 1) # The convolution dilation. - - output_dim: Required[int] = REQUIRED # Output feature dim. - bias: bool = True # Whether to add a bias. - - @classmethod - def default_config(cls): - cfg = super().default_config() - cfg.param_partition_spec = (None, None, None, None, "model") - return cfg - - def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: - cfg = self.config - _check_conv_cfg( - window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation - ) - params = dict( - weight=ParameterSpec( - shape=cfg.window + (cfg.input_dim, cfg.output_dim), - mesh_axes=cfg.param_partition_spec, - factorization=FactorizationSpec(axes=(None, None, None, "row", "col")), - ) - ) - if cfg.bias: - params["bias"] = ParameterSpec( - shape=[cfg.output_dim], mesh_axes=(cfg.param_partition_spec[-1],) - ) - return params - - def forward(self, x: Tensor) -> Tensor: - cfg = self.config - conv_padding = conv_transpose_explicit_padding( - window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation - ) - return self._conv(x=x, strides=cfg.strides, padding=conv_padding, dilation=cfg.dilation) - - def _conv( - self, - x: Tensor, - *, - strides: Sequence[int], - padding: ConvPaddingType, - dilation: Sequence[int], - ) -> Tensor: - cfg = self.config - output = jax.lax.conv_transpose( - lhs=x, - rhs=self.parameters["weight"], - strides=strides, - padding=padding, - rhs_dilation=dilation, - dimension_numbers=("NHWDC", "HWDIO", "NHWDC"), - ) - if cfg.bias: - output += self.parameters["bias"] - return output - - @nowrap - def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Optional[int]]: - cfg = self.config - if len(input_shape) != 5: - raise ValueError(f"We expect len(input_shape) = 5, but got {len(input_shape)}.") - if input_shape[-1] != cfg.input_dim: - raise ValueError( - f"input_shape[-1] = {input_shape[-1]} does not match " - f"cfg.input_dim = {cfg.input_dim}." 
- ) - - in_shape = input_shape[1:4] - out_shape = conv_transpose_output_shape( - in_shape, - window=cfg.window, - strides=cfg.strides, - padding=cfg.padding, - dilation=cfg.dilation, - ) - return [input_shape[0], *out_shape, cfg.output_dim] - - -############################## Others ############################################################## - - class Embedding(BaseLayer): """Implements an embedding lookup function. @@ -2928,90 +1244,6 @@ def forward(self, inputs: Tensor) -> Tensor: return get_activation_fn(cfg.gating)(x) * inputs -class StackOverTime(BaseLayer): - """Stack inputs along the time axis. - - StackOverTime behaves the same as Conv2DWith1DPadding w.r.t. paddings along the time axis. - Please refer to the docstring of Conv2DWith1DPadding to understand how the padding work - including "SAME", "VALID", and "CAUSAL" literals. The padding anchor is set to `left padding`. - """ - - @config_class - class Config(BaseLayer.Config): - """Configures StackOverTime.""" - - stride: Required[int] = REQUIRED # Number of frames to stack. - - # Number of paddings to apply along the time axis. The two integers specify the amount - # of leading and trailing padding, respectively. Alternatively, this can be a - # convolution padding literals type such as 'SAME', 'VALID', or 'CAUSAL'. - # Note: For backward compatibility, the default is set to VALID, but in most cases, - # CAUSAL is more appropriate as it preserves the sequence length. - padding: Union[tuple[int, int], Literal["SAME", "VALID", "CAUSAL"]] = "VALID" - - def forward(self, inputs: Tensor, *, paddings: Tensor) -> tuple[Tensor, Tensor]: - """Stacks stride number of frames into one frame along the time axis. - - Args: - inputs: Tensor of shape [batch, time, input_dim]. - paddings: 0/1 Tensor of shape [batch, time], paddings of the input sequences. - - Returns: - stacked_inputs: Tensor of shape [batch, time // stride, input_dim * stride]. - stacked_paddings: 0/1 Tensor of shape [batch, time // stride]. An output frame - is padding if at least one of the stacked input frames is padding. - - Raises: - ValueError: If stride is <= 1. - """ - cfg = self.config - if cfg.stride <= 1: - raise ValueError(f"stride should be greater than 1, but got {cfg.stride}.") - - # For the last partial frame. - inputs = inputs * (1 - paddings)[:, :, None] - - padding = cfg.padding - if isinstance(padding, str): - padding = conv_explicit_padding( - window=(cfg.stride,), strides=(cfg.stride,), padding=padding, dilation=(1,) - )[0] - inputs = jnp.pad(inputs, ((0, 0), padding, (0, 0)), constant_values=0) - - batch_size, seq_len, input_dim = inputs.shape - output_length = seq_len // cfg.stride - new_shape = [batch_size, output_length, input_dim * cfg.stride] - # Stack inputs over the time dimension. - stacked_inputs = jnp.reshape(inputs[:, : output_length * cfg.stride, :], new_shape) - # An output frame is padding if at least one of the stacked input frames is padding. - stacked_paddings = compute_conv_paddings( - paddings, window=cfg.stride, stride=cfg.stride, conv_padding=(padding,) - ) - stacked_inputs = stacked_inputs * (1 - stacked_paddings)[:, :, None] - return stacked_inputs, stacked_paddings - - @nowrap - def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Optional[int]]: - """Computes stacked output shape. - - Args: - input_shape: The input dimensions are (batch, time, feature_dim). - If the value of the dimension is not available, use None. - - Returns: - The output shape. The dimensions are (batch, time, feature_dim). 
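The core of the StackOverTime forward pass above is a pad-then-reshape; this standalone sketch shows only the stacking step for stride=2 (illustrative, not the layer itself):

import jax.numpy as jnp

inputs = jnp.arange(12, dtype=jnp.float32).reshape(1, 6, 2)  # [batch=1, time=6, input_dim=2]
stride = 2
out_len = inputs.shape[1] // stride
stacked = inputs[:, : out_len * stride, :].reshape(1, out_len, 2 * stride)
# Each output frame concatenates `stride` consecutive input frames along the feature axis.
assert stacked.shape == (1, 3, 4)
assert stacked[0, 0].tolist() == [0.0, 1.0, 2.0, 3.0]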
- """ - cfg = self.config - batch_size, seq_len, input_dim = input_shape - padding = cfg.padding - if isinstance(padding, tuple): - padding = (padding,) - out_shape = conv_output_shape( - [seq_len], window=(cfg.stride,), strides=(cfg.stride,), padding=padding, dilation=(1,) - ) - return [batch_size, *out_shape, input_dim * cfg.stride] - - class MultiLinear(BaseLayer): """A linear layer with multiple outputs.""" diff --git a/axlearn/common/layers_test.py b/axlearn/common/layers_test.py index d57dccc9e..c8b4c11f4 100644 --- a/axlearn/common/layers_test.py +++ b/axlearn/common/layers_test.py @@ -11,7 +11,7 @@ """Tests basic layers.""" -# pylint: disable=no-self-use,too-many-lines,too-many-public-methods +# pylint: disable=no-self-use import copy import itertools import math @@ -19,7 +19,6 @@ from functools import partial from typing import Optional, Union -import einops import jax.random import numpy as np import tensorflow as tf @@ -30,7 +29,7 @@ from sklearn.metrics import precision_score as sklearn_precision_score from sklearn.metrics import recall_score as sklearn_recall_score -from axlearn.common import layers, module, utils +from axlearn.common import module, utils from axlearn.common.base_layer import BaseLayer, ParameterSpec from axlearn.common.config import config_class from axlearn.common.decoder import Decoder @@ -39,13 +38,6 @@ BinaryClassificationMetric, CategoricalHingeLossMetric, ClassificationMetric, - Conv1D, - Conv1DWithPadding, - Conv2D, - Conv2DTranspose, - Conv2DWith1DPadding, - Conv3D, - ConvPaddingType, DropToken, Embedding, GroupNorm, @@ -61,12 +53,10 @@ RMSNorm, SeparableSpaceTimePositionalEmbedding, SqueezeExcitation, - StackOverTime, StochasticDepth, UnitNormLinear, VariationalNoise, _compute_moments_with_paddings, - compute_conv_paddings, get_activation_fn, get_stochastic_depth_linear_rate, set_bias_recursively, @@ -90,7 +80,6 @@ def _copy(src: jnp.ndarray, dst: torch.nn.Parameter): dst.copy_(src) -# pylint: disable=too-many-public-methods class LayerTest(TestCase, tf.test.TestCase): @parameterized.parameters( "linear", @@ -754,1703 +743,6 @@ def test_maxpool2d( output_shape = layer.output_shape(input_shape=inputs.shape) self.assertAllEqual(outputs.shape, output_shape) - # Fails if tolerance is made smaller. 
- @parameterized.named_parameters( - { - "testcase_name": "1x1", - "window": (1, 1), - "strides": (1, 1), - "padding": "VALID", - "num_input_dim_groups": 1, - }, - { - "testcase_name": "2x2_VALID", - "window": (2, 2), - "strides": (1, 1), - "padding": "VALID", - "num_input_dim_groups": 1, - }, - { - "testcase_name": "2x2_SAME", - "window": (2, 2), - "strides": (1, 1), - "padding": "SAME", - "num_input_dim_groups": 1, - }, - { - "testcase_name": "2x2_S2_VALID", - "window": (2, 2), - "strides": (2, 2), - "padding": "VALID", - "num_input_dim_groups": 1, - }, - { - "testcase_name": "3x3_VALID", - "window": (3, 3), - "strides": (1, 1), - "padding": "VALID", - "num_input_dim_groups": 1, - }, - { - "testcase_name": "3x3_SAME", - "window": (3, 3), - "strides": (1, 1), - "padding": "SAME", - "num_input_dim_groups": 1, - }, - { - "testcase_name": "3x3_S2_VALID", - "window": (3, 3), - "strides": (2, 2), - "padding": "VALID", - "num_input_dim_groups": 1, - }, - { - "testcase_name": "3x3_S2_PADDING1", - "window": (3, 3), - "strides": (2, 2), - "padding": (1, 1), - "num_input_dim_groups": 1, - }, - { - "testcase_name": "3x3_GROUPS4", - "window": (3, 3), - "strides": (1, 1), - "padding": "SAME", - "num_input_dim_groups": 4, - }, - ) - def test_conv2d( - self, - window: tuple[int, int], - strides: tuple[int, int], - padding: Union[str, tuple[int, int]], - num_input_dim_groups: int, - ): - input_dim, output_dim = 256, 128 - if isinstance(padding, tuple): - conv_padding = ((padding[0], padding[0]), (padding[1], padding[1])) - else: - conv_padding = padding - cfg = Conv2D.default_config().set( - name="test", - input_dim=input_dim, - output_dim=output_dim, - window=window, - strides=strides, - padding=conv_padding, - num_input_dim_groups=num_input_dim_groups, - ) - layer: Conv2D = cfg.instantiate(parent=None) - - # Initialize layer parameters. - prng_key = jax.random.PRNGKey(123) - prng_key, init_key = jax.random.split(prng_key) - layer_params = layer.initialize_parameters_recursively(init_key) - self.assertEqual( - dict( - weight=(window[0], window[1], input_dim // num_input_dim_groups, output_dim), - bias=(output_dim,), - ), - shapes(layer_params), - ) - bias = layer_params["bias"] - assert_allclose(bias, jnp.zeros_like(bias)) - # Randomize bias. - layer_params["bias"] = jax.random.normal( - jax.random.PRNGKey(45), shape=bias.shape, dtype=bias.dtype - ) - - # Random inputs. - prng_key, input_key = jax.random.split(prng_key) - inputs = jax.random.normal(input_key, [2, 10, 7, input_dim]) - - # Compute layer outputs. - outputs, _ = F( - layer, - inputs=(inputs,), - is_training=True, - state=layer_params, - prng_key=prng_key, - ) - - # Compute ref outputs. - ref_padding = padding.lower() if isinstance(padding, str) else padding - ref = torch.nn.Conv2d( - in_channels=input_dim, - out_channels=output_dim, - kernel_size=window, - stride=strides, - padding=ref_padding, - groups=num_input_dim_groups, - ) - # torch.nn.Linear.weight is of shape (output_dim, input_dim, kernel_size...). - _copy(layer_params["weight"].transpose(3, 2, 0, 1), ref.weight) - _copy(layer_params["bias"], ref.bias) - ref_outputs = ref(as_torch_tensor(inputs.transpose(0, 3, 1, 2))) - # We currently don't match PyTorch as closely as we would like. - assert_allclose(outputs, ref_outputs.detach().numpy().transpose(0, 2, 3, 1), atol=4e-6) - # Tests output_shape. 
- output_shape = layer.output_shape(input_shape=inputs.shape) - self.assertAllEqual(outputs.shape, output_shape) - - @parameterized.parameters((1, 1, 1), (1, 2, 1), (2, 1, 2), (3, 1, 3), (3, 2, 5)) - def test_conv_dilate_window(self, window, dilation, expected): - effective_window = layers.conv_dilate_window(window=(window,), dilation=(dilation,))[0] - self.assertEqual(effective_window, expected) - - @parameterized.parameters( - (10, 3, 1, "SAME", 1, 10), - (10, 3, 2, "SAME", 1, 5), - (10, 3, 1, "SAME", 2, 10), - (10, 3, 2, "SAME", 2, 5), - (10, 3, 1, "VALID", 1, 8), - (10, 3, 2, "VALID", 1, 4), - (10, 3, 1, "VALID", 2, 6), - (10, 3, 2, "VALID", 2, 3), - (10, 3, 1, "CAUSAL", 1, 10), - (10, 3, 2, "CAUSAL", 1, 5), - (10, 3, 1, "CAUSAL", 2, 10), - (10, 3, 2, "CAUSAL", 2, 5), - ) - def test_conv_output_shape(self, in_shape, window, strides, padding, dilation, expected): - out_shape = layers.conv_output_shape( - in_shape=(in_shape,), - window=(window,), - strides=(strides,), - padding=padding, - dilation=(dilation,), - )[0] - self.assertEqual(out_shape, expected) - - @parameterized.parameters( - ([0, 0, 0, 1], [0, 0, 0, 1], 1, "SAME"), - ([0], [], 1, "VALID"), - ([0, 0], [], 1, "VALID"), - ([0, 0, 0], [0], 1, "VALID"), - ([0, 0, 0, 0], [0, 0], 1, "VALID"), - ([0, 0, 0, 1], [0, 0], 1, "VALID"), - ([0, 0, 0, 0], [0], 2, "VALID"), - ([0, 0, 0, 1], [0], 2, "VALID"), - ([0, 0, 1, 1], [0], 2, "VALID"), - ([0, 0, 0, 0, 0], [0, 0], 2, "VALID"), - ([0, 0, 0, 0, 1], [0, 0], 2, "VALID"), - ([0, 0, 0, 1, 1], [0, 0], 2, "VALID"), - ([0, 0, 1, 1, 1], [0, 1], 2, "VALID"), - ([0, 0, 0, 0, 0, 0], [0, 0], 2, "VALID"), - ([0, 0, 0, 0, 0, 1], [0, 0], 2, "VALID"), - ([0, 0, 0, 0, 1, 1], [0, 0], 2, "VALID"), - ([0, 0, 0, 1, 1, 1], [0, 0], 2, "VALID"), - ([0, 0, 1, 1, 1, 1], [0, 1], 2, "VALID"), - ) - def test_conv_padding(self, input_paddings, expected_paddings, stride: int, padding_cfg: str): - """Tests conv_output_shape() with SAME and VALID padding cfg.""" - # This test is from lingvo - # https://github.com/tensorflow/lingvo/blob/master/lingvo/core/conv_layers_with_time_padding_test.py#L157. 
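# A rough way to read the VALID cases above: with the default anchor (the left
# edge of the first window), out_paddings[t] == in_paddings[t * stride]; e.g.
# window=3, stride=2 samples positions 0 and 2 of [0, 0, 1, 1, 1], giving [0, 1].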
- window = 3 - out_paddings = compute_conv_paddings( - jnp.array([input_paddings]), window=window, stride=stride, conv_padding=padding_cfg - ) - assert_allclose(out_paddings[0], expected_paddings) - - @parameterized.parameters( - (5, 1, "SAME", 1, (2, 2)), - (5, 2, "SAME", 1, (2, 2)), - (5, 3, "SAME", 1, (2, 2)), - (5, 1, "SAME", 2, (4, 4)), - (5, 2, "SAME", 2, (4, 4)), - (5, 3, "SAME", 2, (4, 4)), - (5, 1, "VALID", 1, (0, 0)), - (5, 2, "VALID", 1, (0, 0)), - (5, 3, "VALID", 1, (0, 0)), - (5, 1, "VALID", 2, (0, 0)), - (5, 2, "VALID", 2, (0, 0)), - (5, 3, "VALID", 2, (0, 0)), - (5, 1, "CAUSAL", 1, (4, 0)), - (5, 2, "CAUSAL", 1, (3, 1)), - (5, 3, "CAUSAL", 1, (2, 2)), - (5, 1, "CAUSAL", 2, (8, 0)), - (5, 2, "CAUSAL", 2, (7, 1)), - (5, 3, "CAUSAL", 2, (6, 2)), - ) - def test_conv_explicit_padding( - self, window: int, stride: int, padding: ConvPaddingType, dilation: int, expected - ): - """Tests the cases in conv_explicit_padding() description.""" - explicit_padding = layers.conv_explicit_padding( - window=(window,), - strides=(stride,), - padding=padding, - dilation=(dilation,), - ) - self.assertAllEqual(explicit_padding[0], expected) - - @parameterized.parameters( - (5, 1, "SAME", [0, 0, 0, 0, 1, 1]), - (5, 2, "SAME", [0, 0, 1]), - (5, 1, "VALID", [0, 0]), - (5, 2, "VALID", [0]), - (5, 1, "SAME", [0, 0, 0, 0, 1, 1]), - (5, 2, "SAME", [0, 0, 1]), - ) - def test_conv_output_1d_padding_simple( - self, window: int, stride: int, padding: ConvPaddingType, expected - ): - """Tests the cases in conv_explicit_padding() description.""" - paddings = jnp.array([[0, 0, 0, 0, 1, 1]]) - out_paddings = compute_conv_paddings( - paddings, window=window, stride=stride, conv_padding=padding - ) - self.assertAllEqual(out_paddings[0], expected) - - @parameterized.parameters( - ([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0]), - ([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 0, 0]), - ([1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 0, 0]), - ([1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 0]), - ([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1]), - ([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 1, 1]), - ([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 1, 1]), - ([0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 1]), - ) - def test_conv_output_1d_padding_causal(self, in_paddings, expected): - """Test the below cases. - - The formula for CAUSAL padding is `(window - stride, stride - 1)`. - With window=15 and stride=6, padding is (9, 5). - Below are examples illustrating how input paddings are transformed into output - paddings across different scenarios. - - left_pad | input paddings -> outputs paddings - 1) |1 1 1|1 1 1|1 1 1|0 0 0|0 0 0|0 0 0|0 0 0|0 0 0|0 0 0| -> 0 0 0 - 2) |1 1 1|1 1 1|1 1 1|1 0 0|0 0 0|0 0 0|0 0 0|0 0 0|0 0 0| -> 1 0 0 - 3) |1 1 1|1 1 1|1 1 1|1 1 1|1 1 1|0 0 0|0 0 0|0 0 0|0 0 0| -> 1 0 0 - 4) |1 1 1|1 1 1|1 1 1|1 1 1|1 1 1|1 0 0|0 0 0|0 0 0|0 0 0| -> 1 1 0 - 5) |1 1 1|1 1 1|1 1 1|1 1 1|1 1 1|1 1 1|1 1 1|1 1 1|1 1 1| -> 1 1 1 - 6) |1 1 1|1 1 1|1 1 1|0 1 1|1 1 1|1 1 1|1 1 1|1 1 1|1 1 1| -> 0 1 1 - 7) |1 1 1|1 1 1|1 1 1|0 0 0|0 0 0|1 1 1|1 1 1|1 1 1|1 1 1| -> 0 1 1 - 8) |1 1 1|1 1 1|1 1 1|0 0 0|0 0 0|0 1 1|1 1 1|1 1 1|1 1 1| -> 0 0 1 - |_________________^_________| - |_________________^_________| - |_________________^_________| - - Let's take a closer look at case 7). 
In case 7), the first window component fully - covers all 0s, so the first component of the output padding should be the last - 0 component, meaning the second component is 1. - - In case 8), however, the first window component does not cover all 0s, so the - next component should also be 0. If the second component were 1, information - from the last partial window of the input would be lost. - - In general, the anchor point should be the next position after the right edge - of the previous window. Since the anchor is defined by the left pad, - `left_pad = window - stride`, and `right_pad = (window - 1) - left_pad`, - simplifying to `right_pad = stride - 1`. - """ - window = 15 - stride = 6 - padding = "CAUSAL" - explicit_padding = layers.conv_explicit_padding( - window=(window,), strides=(stride,), padding=padding, dilation=(1,) - ) - self.assertAllEqual(explicit_padding[0], (9, 5)) - - in_paddings = jnp.array([in_paddings]) - out_paddings = compute_conv_paddings( - in_paddings, window=window, stride=stride, conv_padding=padding - )[0] - self.assertAllEqual(out_paddings, expected) - - @parameterized.parameters( - (3, 1, ((1, 1),), "SAME"), - (3, 1, ((0, 0),), "VALID"), - (3, 1, ((2, 0),), "CAUSAL"), - (3, 2, ((1, 1),), "SAME"), - (3, 2, ((0, 0),), "VALID"), - (3, 2, ((1, 1),), "CAUSAL"), - (5, 2, ((2, 2),), "SAME"), - (5, 2, ((0, 0),), "VALID"), - (5, 2, ((3, 1),), "CAUSAL"), - ) - def test_conv_output_1d_padding_against_str_padding( - self, window: int, stride: int, padding: ConvPaddingType, ref_padding: ConvPaddingType - ): - """Tests conv_output_shape() with explicit padding cfg.""" - batch_size = 5 - seq_len = 5 - paddings = jnp.triu(jnp.ones((batch_size, seq_len)), k=1) - - explicit_padding = layers.conv_explicit_padding( - window=(window,), strides=(stride,), padding=ref_padding, dilation=(1,) - ) - self.assertAllEqual(explicit_padding, padding[:1]) - - out_paddings = compute_conv_paddings( - paddings, window=window, stride=stride, conv_padding=padding - ) - ref_paddings = compute_conv_paddings( - paddings, window=window, stride=stride, conv_padding=ref_padding - ) - self.assertAllEqual(out_paddings, ref_paddings) - - @parameterized.parameters( - ("SAME", 1, [0, 0, 0, 0, 1, 1], [0, 0, 1]), - ("VALID", 1, [0, 0, 0, 0, 1, 1], [0]), - ("CAUSAL", 1, [0, 0, 0, 0, 1, 1], [0, 0, 1]), - ("SAME", 2, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1], [0, 0, 0, 0, 1]), - ("VALID", 2, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1], [0]), - ("CAUSAL", 2, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1], [0, 0, 0, 0, 1]), - ) - def test_compute_conv_paddings_with_dilation( - self, padding: ConvPaddingType, dilation: int, paddings, expected - ): - """Tests compute_conv_paddings() as described in conv_explicit_padding().""" - window, stride = 5, 2 - out_paddings = compute_conv_paddings( - jnp.array([paddings]), - window=window, - stride=stride, - conv_padding=padding, - dilation=dilation, - )[0] - self.assertAllEqual(out_paddings, expected) - - @parameterized.parameters( - (5, "SAME", None, [0, 0, 0, 1, 1, 1]), - (5, "SAME", 1, ValueError), - (5, "SAME", 2, [0, 0, 0, 1, 1, 1]), - (5, "SAME", 3, ValueError), - (5, ((1, 1),), None, [0, 0, 0, 1]), - (5, ((1, 1),), 0, ValueError), - (5, ((1, 1),), 1, [0, 0, 0, 1]), - (5, ((1, 1),), 2, [0, 0, 1, 1]), - (5, ((1, 1),), 3, [0, 1, 1, 1]), - (5, ((1, 1),), 4, ValueError), - (5, "VALID", None, [0, 0]), - (5, "VALID", 0, [0, 0]), - (5, "VALID", 1, [0, 0]), - (5, "VALID", 2, [0, 1]), - (5, "VALID", 3, [1, 1]), - (5, "VALID", 4, [1, 1]), - (5, "CAUSAL", None, [0, 0, 0, 1, 1, 1]), - (5, "CAUSAL", 3, ValueError), - 
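# The pattern in these cases suggests the anchor must fall inside the un-padded
# span of the first window, i.e. pad_left <= anchor <= window - 1 - pad_right;
# for CAUSAL with window=5 that span is just {4}, so 3 and 5 raise while 4 passes.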
(5, "CAUSAL", 4, [0, 0, 0, 1, 1, 1]), - (5, "CAUSAL", 5, ValueError), - ) - def test_conv_output_1d_padding_with_anchor(self, window, padding, anchor, expected_paddings): - input_paddings = [0, 0, 0, 1, 1, 1] - try: - out_paddings = compute_conv_paddings( - jnp.array([input_paddings]), - window=window, - stride=1, - conv_padding=padding, - anchor=anchor, - ) - assert_allclose(out_paddings[0], expected_paddings) - except ValueError as e: - self.assertTrue(isinstance(e, expected_paddings)) - - @parameterized.named_parameters( - ("1x1", (1, 1), (1, 1), "VALID", None), - ("2x2_VALID", (2, 2), (1, 1), "VALID", None), - ("2x2_SAME", (2, 2), (1, 1), "SAME", None), - ("2x2_CAUSAL", (2, 2), (1, 1), "CAUSAL", None), - ("2x2_S2_VALID", (2, 2), (2, 2), "VALID", None), - ("2x2_S2_CAUSAL", (2, 2), (2, 2), "CAUSAL", None), - ("3x3_VALID", (3, 3), (1, 1), "VALID", None), - ("3x3_VALID_A0", (3, 3), (1, 1), "VALID", 0), - ("3x3_VALID_A1", (3, 3), (1, 1), "VALID", 1), - ("3x3_VALID_A2", (3, 3), (1, 1), "VALID", 2), - ("3x3_SAME", (3, 3), (1, 1), "SAME", None), - ("3x3_CAUSAL", (3, 3), (1, 1), "CAUSAL", None), - ("3x3_S2_VALID", (3, 3), (2, 2), "VALID", None), - ("3x3_S2_CAUSAL", (3, 3), (2, 2), "CAUSAL", None), - ("3x3_S2_PADDING1", (3, 3), (2, 2), (1, 1), None), - ) - def test_conv2d_with_1d_padding( - self, - window: tuple[int, int], - strides: tuple[int, int], - padding: Union[str, tuple[int, int]], - anchor: Optional[int], - ): - """Tests that Conv2DWith1DPadding has consistent outputs under different padding lengths. - - Generates a batch of input sequences. Pads the sequences under different lengths. - Checks that the outputs are the same. - """ - input_dim, input_channel, output_dim = 4, 7, 6 - if isinstance(padding, tuple): - conv_padding = ((padding[0], padding[0]), (padding[1], padding[1])) - else: - conv_padding = padding - cfg = Conv2DWith1DPadding.default_config().set( - name="test", - input_dim=input_dim, - output_dim=output_dim, - window=window, - strides=strides, - padding=conv_padding, - anchor=anchor, - ) - layer: Conv2DWith1DPadding = cfg.instantiate(parent=None) - - # Initialize layer parameters. - prng_key = jax.random.PRNGKey(123) - prng_key, init_key = jax.random.split(prng_key) - layer_params = layer.initialize_parameters_recursively(init_key) - self.assertEqual( - dict(weight=(window[0], window[1], input_dim, output_dim), bias=(output_dim,)), - shapes(layer_params), - ) - # Generate a batch of 10 input sequences. - batch_size, max_seq_len = 10, 10 - - prng_key, input_key = jax.random.split(prng_key) - inputs = ( - jax.random.normal(input_key, [batch_size, max_seq_len, input_channel, input_dim]) * 100 - ) - - # The 10 sequences have length 1 to 10. - paddings = jnp.triu(jnp.ones((batch_size, max_seq_len)), k=1) - - # Compute layer outputs. - (ref_outputs, ref_paddings), _ = F( - layer, - inputs=dict(x=inputs, paddings=paddings), - is_training=True, - state=layer_params, - prng_key=prng_key, - ) - output_shape = layer.output_shape(input_shape=inputs.shape) - self.assertAllEqual(ref_outputs.shape, output_shape) - - random_keys = jax.random.split(input_key, num=2 * max_seq_len) - for seq_len in range(1, max_seq_len): - # We create a new batch. The time axis of the new batch is of length seq_len. 
- permute_idx = jax.random.permutation(random_keys[2 * (seq_len - 1)], seq_len) - inputs_batch = jnp.take_along_axis(inputs, permute_idx[:, None, None, None], axis=0)[ - :, :seq_len - ] - paddings_batch = jnp.take_along_axis(paddings, permute_idx[:, None], axis=0)[ - :, :seq_len - ] - - # Generate random data at padding positions. - random_data = ( - jax.random.normal( - random_keys[2 * seq_len - 1], - [len(permute_idx), seq_len, input_channel, input_dim], - ) - * 1000 - ) - inputs_new_batch = jnp.where( - paddings_batch[:, :, None, None], random_data, inputs_batch - ) - - (outputs_batch, output_paddings_batch), _ = F( - layer, - inputs=dict(x=inputs_new_batch, paddings=paddings_batch), - is_training=True, - state=layer_params, - prng_key=prng_key, - ) - output_len = output_paddings_batch.shape[1] - if output_len > 0: - assert_allclose( - outputs_batch, - jnp.take_along_axis(ref_outputs, permute_idx[:, None, None, None], axis=0)[ - :, :output_len - ], - ) - self.assertAllEqual( - output_paddings_batch, - jnp.take_along_axis(ref_paddings, permute_idx[:, None], axis=0)[:, :output_len], - ) - - @parameterized.named_parameters( - ("1_S1", 1, 1, "VALID", None), - ("2_S1_VALID", 2, 1, "VALID", None), - ("2_S2_SAME", 2, 2, "SAME", None), - ("2_S_CAUSAL", 2, 1, "CAUSAL", None), - ("2_S2_VALID", 2, 2, "VALID", None), - ("2_S2_CAUSAL", 2, 2, "CAUSAL", None), - ("3_S1_VALID", 3, 1, "VALID", None), - ("3_S1_VALID_A0", 3, 1, "VALID", 0), - ("3_S1_VALID_A1", 3, 1, "VALID", 1), - ("3_S1_VALID_A2", 3, 1, "VALID", 2), - ("3_S1_SAME", 3, 1, "SAME", None), - ("3_S1_CAUSAL", 3, 1, "CAUSAL", None), - ("3_S2_VALID", 3, 2, "VALID", None), - ("3_S2_CAUSAL", 3, 2, "CAUSAL", None), - ) - def test_conv1d_against_conv2d_with_1d_padding( - self, - window: int, - strides: int, - padding: ConvPaddingType, - anchor: Optional[int], - ): - input_dim, output_dim = 4, 6 - ref_cfg = Conv2DWith1DPadding.default_config().set( - name="ref", - input_dim=input_dim, - output_dim=output_dim, - window=(window, 1), - strides=(strides, 1), - padding=padding, - anchor=anchor, - ) - ref_layer = ref_cfg.instantiate(parent=None) - - test_cfg = Conv1DWithPadding.default_config().set( - name="test", - input_dim=input_dim, - output_dim=output_dim, - window=window, - strides=strides, - padding=padding, - anchor=anchor, - ) - test_layer = test_cfg.instantiate(parent=None) - - # Initialize layer parameters. - prng_key = jax.random.PRNGKey(123) - prng_key, init_key = jax.random.split(prng_key) - state = ref_layer.initialize_parameters_recursively(init_key) - test_state = dict( - bias=state["bias"], weight=einops.rearrange(state["weight"], "t 1 i o -> t i o") - ) - - # Generate a batch of 10 input sequences. - batch_size, max_seq_len = 10, 10 - - prng_key, input_key = jax.random.split(prng_key) - inputs = jax.random.normal(input_key, [batch_size, max_seq_len, input_dim]) - # The 10 sequences have length 1 to 10. 
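# (jnp.triu(ones, k=1) sets paddings[i, j] = 1 for all j > i, so row i keeps
# exactly i + 1 valid frames: row 0 has length 1, row 9 has length 10.)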
- paddings = jnp.triu(jnp.ones((batch_size, max_seq_len)), k=1) - - (test_outputs, test_paddings), _ = F( - test_layer, - inputs=dict(x=inputs, paddings=paddings), - is_training=True, - state=test_state, - prng_key=prng_key, - ) - output_shape = test_layer.output_shape(input_shape=inputs.shape) - self.assertAllEqual(test_outputs.shape, output_shape) - - inputs = einops.rearrange(inputs, "b t i -> b t 1 i") - (ref_outputs, ref_paddings), _ = F( - ref_layer, - inputs=dict(x=inputs, paddings=paddings), - is_training=True, - state=state, - prng_key=prng_key, - ) - output_shape = ref_layer.output_shape(input_shape=inputs.shape) - self.assertAllEqual(ref_outputs.shape, output_shape) - ref_outputs = einops.rearrange(ref_outputs, "b t 1 o -> b t o") - - assert_allclose(ref_paddings, test_paddings) - assert_allclose(ref_outputs, test_outputs) - - @parameterized.named_parameters( - { - "testcase_name": "2x2", - "window": (2, 2), - "strides": (1, 1), - "padding": "VALID", - }, - { - "testcase_name": "2x2_S2", - "window": (2, 2), - "strides": (2, 2), - "padding": "VALID", - }, - { - "testcase_name": "3x3_S2", - "window": (3, 3), - "strides": (2, 2), - "padding": "VALID", - }, - ) - def test_conv2d_transpose_against_pytorch( - self, - window: tuple[int, int], - strides: tuple[int, int], - padding: Union[str, tuple[int, int]], - ): - input_dim, output_dim = 4, 8 - if isinstance(padding, tuple): - deconv_padding = ((padding[0], padding[0]), (padding[1], padding[1])) - else: - deconv_padding = padding - cfg = Conv2DTranspose.default_config().set( - name="test", - input_dim=input_dim, - output_dim=output_dim, - window=window, - strides=strides, - padding=deconv_padding, - transpose_kernel=True, - ) - layer: Conv2DTranspose = cfg.instantiate(parent=None) - - # Initialize layer parameters. - prng_key = jax.random.PRNGKey(123) - prng_key, init_key = jax.random.split(prng_key) - layer_params = layer.initialize_parameters_recursively(init_key) - self.assertEqual( - dict( - weight=(window[0], window[1], output_dim, input_dim), - bias=(output_dim,), - ), - shapes(layer_params), - ) - bias = layer_params["bias"] - assert_allclose(bias, jnp.zeros_like(bias)) - # Randomize bias. - layer_params["bias"] = jax.random.normal( - jax.random.PRNGKey(45), shape=bias.shape, dtype=bias.dtype - ) - - # Random inputs. - prng_key, input_key = jax.random.split(prng_key) - inputs = jax.random.normal(input_key, [2, 10, 7, input_dim]) - # Compute layer outputs. - outputs, _ = F( - layer, - inputs=(inputs,), - is_training=True, - state=layer_params, - prng_key=prng_key, - ) - - # Compute ref outputs. - if isinstance(padding, tuple): - ref_padding = padding[0] - elif isinstance(padding, str): - ref_padding = padding.lower() - if ref_padding == "valid": - ref_padding = 0 - else: - ref_padding = 0 - - ref = torch.nn.ConvTranspose2d( - in_channels=input_dim, - out_channels=output_dim, - kernel_size=window, - stride=strides, - padding=ref_padding, - ) - # torch.nn.Linear.weight is of shape (output_dim, input_dim, kernel_size...). - _copy(layer_params["weight"].transpose(3, 2, 0, 1), ref.weight) - _copy(layer_params["bias"], ref.bias) - ref_outputs = ref(as_torch_tensor(inputs.transpose(0, 3, 1, 2))) - assert_allclose(outputs, ref_outputs.detach().numpy().transpose(0, 2, 3, 1)) - # Tests output_shape. 
- output_shape = layer.output_shape(input_shape=inputs.shape) - self.assertAllEqual(outputs.shape, output_shape) - - @parameterized.named_parameters( - { - "testcase_name": "1x1x1", - "window": (1, 1, 1), - "strides": (1, 1, 1), - "padding": "VALID", - "num_input_dim_groups": 1, - }, - { - "testcase_name": "2x2x2_VALID", - "window": (2, 2, 2), - "strides": (1, 1, 1), - "padding": "VALID", - "num_input_dim_groups": 1, - }, - { - "testcase_name": "2x2x2_SAME", - "window": (2, 2, 2), - "strides": (1, 1, 1), - "padding": "SAME", - "num_input_dim_groups": 1, - }, - { - "testcase_name": "2x2x2_S2_VALID", - "window": (2, 2, 2), - "strides": (2, 2, 2), - "padding": "VALID", - "num_input_dim_groups": 1, - }, - { - "testcase_name": "3x3x3_VALID", - "window": (3, 3, 3), - "strides": (1, 1, 1), - "padding": "VALID", - "num_input_dim_groups": 1, - }, - { - "testcase_name": "3x3x3_SAME", - "window": (3, 3, 3), - "strides": (1, 1, 1), - "padding": "SAME", - "num_input_dim_groups": 1, - }, - { - "testcase_name": "3x3x3_S2_VALID", - "window": (3, 3, 3), - "strides": (2, 2, 2), - "padding": "VALID", - "num_input_dim_groups": 1, - }, - { - "testcase_name": "3x3x3_S2_PADDING1", - "window": (3, 3, 3), - "strides": (2, 2, 2), - "padding": (1, 1, 1), - "num_input_dim_groups": 1, - }, - { - "testcase_name": "3x3x3_GROUPS4", - "window": (3, 3, 3), - "strides": (1, 1, 1), - "padding": "SAME", - "num_input_dim_groups": 4, - }, - ) - def test_conv3d( - self, - window: tuple[int, int], - strides: tuple[int, int], - padding: Union[str, tuple[int, int]], - num_input_dim_groups: int, - ): - input_dim, output_dim = 4, 8 - if isinstance(padding, tuple): - conv_padding = ( - (padding[0], padding[0]), - (padding[1], padding[1]), - (padding[2], padding[2]), - ) - else: - conv_padding = padding - cfg = Conv3D.default_config().set( - name="test", - input_dim=input_dim, - output_dim=output_dim, - window=window, - strides=strides, - padding=conv_padding, - num_input_dim_groups=num_input_dim_groups, - ) - layer: Conv3D = cfg.instantiate(parent=None) - - # Initialize layer parameters. - prng_key = jax.random.PRNGKey(123) - prng_key, init_key = jax.random.split(prng_key) - layer_params = layer.initialize_parameters_recursively(init_key) - expected = dict( - weight=(window[0], window[1], window[2], input_dim // num_input_dim_groups, output_dim), - bias=(output_dim,), - ) - self.assertEqual( - expected, - shapes(layer_params), - ) - bias = layer_params["bias"] - assert_allclose(bias, jnp.zeros_like(bias)) - # Randomize bias. - layer_params["bias"] = jax.random.normal( - jax.random.PRNGKey(45), shape=bias.shape, dtype=bias.dtype - ) - - # Random inputs. - prng_key, input_key = jax.random.split(prng_key) - - batch_size = 2 - inputs = jax.random.normal(input_key, [batch_size, 10, 7, 4, input_dim]) - - # Compute layer outputs. - outputs, _ = F( - layer, - inputs=(inputs,), - is_training=True, - state=layer_params, - prng_key=prng_key, - ) - - # Compute ref outputs. 
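# (axlearn's Conv3D takes channels-last inputs while torch.nn.Conv3d expects
# channels-first, hence the input/output transposes around the reference call.)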
- ref_padding = padding.lower() if isinstance(padding, str) else padding - ref = torch.nn.Conv3d( - in_channels=input_dim, - out_channels=output_dim, - kernel_size=window, - stride=strides, - padding=ref_padding, - groups=num_input_dim_groups, - ) - - # weight.shape: (H, W, D, I, O) - # ref.weight.shape: (O, I, H, W, D) - _copy(layer_params["weight"].transpose(4, 3, 0, 1, 2), ref.weight) - _copy(layer_params["bias"], ref.bias) - - ref_outputs = ref(as_torch_tensor(inputs.transpose(0, 4, 1, 2, 3))) - assert_allclose(outputs, ref_outputs.detach().numpy().transpose(0, 2, 3, 4, 1)) - - # Tests output_shape. - output_shape = layer.output_shape(input_shape=inputs.shape) - self.assertAllEqual(outputs.shape, output_shape) - - @parameterized.named_parameters( - ("w3s1d1_VALID", 3, 1, "VALID", None), - ("w3s1d2_VALID", 3, 1, "VALID", 2), - ("w3s1d1_SAME", 3, 1, "SAME", None), - ("w4s1d1_SAME", 4, 1, "SAME", None), - ("w4s1d3_SAME", 4, 1, "SAME", 3), - ("w4s1d1_CAUSAL", 4, 1, ((3, 0),), None), - ("w4s1d5_CAUSAL", 4, 1, ((3, 0),), 5), - ) - def test_conv1d( - self, - window: int, - strides: int, - padding: ConvPaddingType, - dilation: Optional[int], - ): - input_dim, output_dim = 4, 6 - cfg = Conv1D.default_config().set( - name="test", - input_dim=input_dim, - output_dim=output_dim, - window=window, - strides=strides, - padding=padding, - dilation=dilation, - ) - layer: Conv1D = cfg.instantiate(parent=None) - # Initialize layer parameters. - prng_key = jax.random.PRNGKey(123) - prng_key, init_key = jax.random.split(prng_key) - layer_params = layer.initialize_parameters_recursively(init_key) - self.assertEqual( - dict(weight=(window, input_dim, output_dim), bias=(output_dim,)), - shapes(layer_params), - ) - bias = layer_params["bias"] - assert_allclose(bias, jnp.zeros_like(bias)) - # Randomize bias. - layer_params["bias"] = jax.random.normal( - jax.random.PRNGKey(45), shape=bias.shape, dtype=bias.dtype - ) - - # Random inputs. - prng_key, input_key = jax.random.split(prng_key) - inputs = jax.random.normal(input_key, [2, 17, input_dim]) - # Compute layer outputs. - outputs, _ = F( - layer, - inputs=(inputs,), - is_training=True, - state=layer_params, - prng_key=prng_key, - ) - output_shape = layer.output_shape(input_shape=inputs.shape) - self.assertAllEqual(outputs.shape, output_shape) - - # Compute ref outputs. - if isinstance(padding, str): - ref_padding = padding.lower() - ref_inputs = inputs - else: - # torch.nn.Conv1d does not support asymmetric padding, so pad manually and use "valid". - ref_padding = "valid" - ref_inputs = jnp.pad(inputs, ((0, 0), padding[0], (0, 0))) - ref = torch.nn.Conv1d( - in_channels=input_dim, - out_channels=output_dim, - groups=1, - kernel_size=window, - stride=strides, - padding=ref_padding, - dilation=1 if dilation is None else dilation, - ) - # torch.nn.Linear.weight is of shape (output_dim, input_dim, kernel_size...). 
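# (axlearn's Conv1D weight is laid out as (window, input_dim, output_dim), while
# the torch reference expects (out_channels, in_channels, kernel_size), hence the
# transpose(2, 1, 0) below.)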
- _copy(layer_params["weight"].transpose(2, 1, 0), ref.weight) - _copy(layer_params["bias"], ref.bias) - ref_outputs = ref(as_torch_tensor(ref_inputs.transpose(0, 2, 1))) - assert_allclose(outputs, ref_outputs.detach().numpy().transpose(0, 2, 1)) - - @parameterized.named_parameters( - ("w3s1_VALID", 3, 1, "VALID"), - ("w3s1_SAME", 3, 1, "SAME"), - ("w4s1_SAME", 4, 1, "SAME"), - ("w4s1_CAUSAL", 4, 1, ((3, 0),)), - ) - def test_depthwise_conv1d( - self, - window: int, - strides: int, - padding: ConvPaddingType, - ): - input_dim = 4 - cfg = Conv1D.default_config().set( - name="test", - input_dim=input_dim, - output_dim=input_dim, - num_input_dim_groups=input_dim, - window=window, - strides=strides, - padding=padding, - ) - layer: Conv1D = cfg.instantiate(parent=None) - - # Initialize layer parameters. - prng_key = jax.random.PRNGKey(123) - prng_key, init_key = jax.random.split(prng_key) - layer_params = layer.initialize_parameters_recursively(init_key) - self.assertEqual( - dict(weight=(window, 1, input_dim), bias=(input_dim,)), - shapes(layer_params), - ) - bias = layer_params["bias"] - assert_allclose(bias, jnp.zeros_like(bias)) - # Randomize bias. - layer_params["bias"] = jax.random.normal( - jax.random.PRNGKey(45), shape=bias.shape, dtype=bias.dtype - ) - - # Random inputs. - prng_key, input_key = jax.random.split(prng_key) - inputs = jax.random.normal(input_key, [2, 7, input_dim]) - - # Compute layer outputs. - outputs, _ = F( - layer, - inputs=(inputs,), - is_training=True, - state=layer_params, - prng_key=prng_key, - ) - output_shape = layer.output_shape(input_shape=inputs.shape) - self.assertAllEqual(outputs.shape, output_shape) - - # Compute ref outputs. - if isinstance(padding, str): - ref_padding = padding.lower() - ref_inputs = inputs - else: - # torch.nn.Conv1d does not support asymmetric padding, so pad manually and use "valid". - ref_padding = "valid" - ref_inputs = jnp.pad(inputs, ((0, 0), padding[0], (0, 0))) - ref = torch.nn.Conv1d( - in_channels=input_dim, - out_channels=input_dim, - groups=input_dim, - kernel_size=window, - stride=strides, - padding=ref_padding, - ) - # torch.nn.Linear.weight is of shape (output_dim, input_dim, kernel_size...). 
- _copy(layer_params["weight"].transpose(2, 1, 0), ref.weight) - _copy(layer_params["bias"], ref.bias) - ref_outputs = ref(as_torch_tensor(ref_inputs.transpose(0, 2, 1))) - assert_allclose(outputs, ref_outputs.detach().numpy().transpose(0, 2, 1)) - - ############################## Transposed Convolution ########################################## - - CONVT_EXPLICIT_PADDING_PARAMS = [ - (3, 1, "SAME", 1, (1, 1)), - (3, 2, "SAME", 1, (2, 1)), - (3, 3, "SAME", 1, (2, 2)), - (3, 4, "SAME", 1, (2, 3)), - (3, 1, "SAME", 2, (2, 2)), - (3, 2, "SAME", 2, (3, 2)), - (3, 3, "SAME", 2, (3, 3)), - (3, 1, "VALID", 1, (2, 2)), - (3, 2, "VALID", 1, (2, 2)), - (3, 3, "VALID", 1, (2, 2)), - (3, 4, "VALID", 1, (2, 3)), - (3, 1, "VALID", 2, (4, 4)), - (3, 2, "VALID", 2, (4, 4)), - (3, 3, "VALID", 2, (4, 4)), - (3, 1, "CAUSAL", 1, (2, 0)), - (3, 2, "CAUSAL", 1, (2, 1)), - (3, 3, "CAUSAL", 1, (2, 2)), - (3, 4, "CAUSAL", 1, (2, 3)), - (3, 1, "CAUSAL", 2, (4, 0)), - (3, 2, "CAUSAL", 2, (4, 1)), - (3, 3, "CAUSAL", 2, (4, 2)), - ] - - @parameterized.parameters(*CONVT_EXPLICIT_PADDING_PARAMS) - def test_conv_transpose_explicit_padding(self, window, strides, padding, dilation, expected): - """Tests the cases in conv_transpose_explicit_padding() description.""" - explicit_padding = layers.conv_transpose_explicit_padding( - window=(window,), - strides=(strides,), - padding=padding, - dilation=(dilation,), - ) - self.assertAllEqual(explicit_padding[0], expected) - - @parameterized.parameters(*CONVT_EXPLICIT_PADDING_PARAMS) - def test_conv_transpose_explicit_padding_against_jax( - self, window, strides, padding, dilation, expected - ): - """Compare with jax.lax.convolution._conv_transpose_padding().""" - if padding == "CAUSAL": - self.skipTest("Causal padding is not supported in JAX.") - - # Copied from jax.lax.convolution._conv_transpose_padding. 
- def _conv_transpose_padding(k, s, padding): - if padding == "SAME": - pad_len = k + s - 2 - if s > k - 1: - pad_a = k - 1 - else: - pad_a = int(np.ceil(pad_len / 2)) - elif padding == "VALID": - pad_len = k + s - 2 + max(k - s, 0) - pad_a = k - 1 - else: - raise ValueError("Padding mode must be `SAME` or `VALID`.") - pad_b = pad_len - pad_a - return pad_a, pad_b - - dilate_window = layers.conv_dilate_window(window=(window,), dilation=(dilation,))[0] - ref_padding = _conv_transpose_padding(dilate_window, strides, padding) - - explicit_padding = layers.conv_transpose_explicit_padding( - window=(window,), - strides=(strides,), - padding=padding, - dilation=(dilation,), - ) - - self.assertAllEqual(explicit_padding[0], ref_padding) - self.assertAllEqual(expected, ref_padding) - - @parameterized.parameters( - (3, 1, "SAME", 1, 4), - (3, 2, "SAME", 1, 8), - (3, 3, "SAME", 1, 12), - (3, 4, "SAME", 1, 16), - (3, 1, "SAME", 2, 4), - (3, 2, "SAME", 2, 8), - (3, 3, "SAME", 2, 12), - (3, 1, "VALID", 1, 6), - (3, 2, "VALID", 1, 9), - (3, 3, "VALID", 1, 12), - (3, 4, "VALID", 1, 16), - (3, 1, "VALID", 2, 8), - (3, 2, "VALID", 2, 11), - (3, 3, "VALID", 2, 14), - (3, 1, "CAUSAL", 1, 4), - (3, 2, "CAUSAL", 1, 8), - (3, 3, "CAUSAL", 1, 12), - (3, 4, "CAUSAL", 1, 16), - (3, 1, "CAUSAL", 2, 4), - (3, 2, "CAUSAL", 2, 8), - (3, 3, "CAUSAL", 2, 12), - ) - def test_conv_transpose_output_shape(self, window, strides, padding, dilation, expected): - """Tests the cases in conv_transpose_explicit_padding() description.""" - out_shape = layers.conv_transpose_output_shape( - in_shape=(4,), - window=(window,), - strides=(strides,), - padding=padding, - dilation=(dilation,), - ) - self.assertAllEqual(out_shape[0], expected) - - @parameterized.parameters( - (3, 1, "SAME", 1, [0, 0, 1, 1]), - (3, 2, "SAME", 1, [0, 0, 0, 0, 1, 1, 1, 1]), - (3, 3, "SAME", 1, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), - (3, 4, "SAME", 1, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]), - (3, 1, "SAME", 2, [0, 0, 1, 1]), - (3, 2, "SAME", 2, [0, 0, 0, 0, 1, 1, 1, 1]), - (3, 3, "SAME", 2, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), - (3, 1, "VALID", 1, [0, 0, 1, 1, 1, 1]), - (3, 2, "VALID", 1, [0, 0, 0, 0, 1, 1, 1, 1, 1]), - (3, 3, "VALID", 1, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), - (3, 4, "VALID", 1, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]), - (3, 1, "VALID", 2, [0, 0, 1, 1, 1, 1, 1, 1]), - (3, 2, "VALID", 2, [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]), - (3, 3, "VALID", 2, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]), - (3, 1, "CAUSAL", 1, [0, 0, 1, 1]), - (3, 2, "CAUSAL", 1, [0, 0, 0, 0, 1, 1, 1, 1]), - (3, 3, "CAUSAL", 1, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), - (3, 4, "CAUSAL", 1, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]), - (3, 1, "CAUSAL", 2, [0, 0, 1, 1]), - (3, 2, "CAUSAL", 2, [0, 0, 0, 0, 1, 1, 1, 1]), - (3, 3, "CAUSAL", 2, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), - ) - def test_compute_conv_transpose_paddings(self, window, strides, padding, dilation, expected): - """Tests the cases in conv_transpose_explicit_padding() description.""" - in_paddings = jnp.array([0, 0, 1, 1], dtype=jnp.float32)[None, :] - out_paddings = layers.compute_conv_transpose_paddings( - in_paddings, window=window, stride=strides, conv_padding=padding, dilation=dilation - ) - expected = jnp.array(expected).astype(out_paddings.dtype) - self.assertNestedEqual(out_paddings[0], expected) - - @parameterized.product( - window=[1, 3], - strides=[1, 2, 3], - padding=["SAME", "VALID", "CAUSAL"], - dilation=[1, 2], - value=[0, 1], - ) - def 
test_compute_conv_transpose_paddings_all0or1( - self, window, strides, padding, dilation, value - ): - """If in_paddings is all valid or invalid, out_paddings must be all valid or invalid.""" - in_paddings = jnp.full([1, 4], fill_value=value) - out_paddings = layers.compute_conv_transpose_paddings( - in_paddings, window=window, stride=strides, conv_padding=padding, dilation=dilation - ) - expected = jnp.ones_like(out_paddings) * value - self.assertNestedEqual(out_paddings, expected) - - CONVT_PADDINGS_PARAMS = dict( - in_paddings=[ - [0, 0, 0, 0, 0], - [0, 0, 0, 1, 1], - [1, 0, 0, 0, 1], - [1, 1, 1, 1, 1], - [0, 0, 0, 1, 1, 1, 0, 0, 1, 1], - [1, 1, 0, 1, 1, 1, 1, 0, 0, 0], - ], - window=[1, 3], - padding=["SAME", "VALID", "CAUSAL"], - dilation=[1, 2], - ) - - @parameterized.product(**CONVT_PADDINGS_PARAMS, strides=[1, 2, 3]) - def test_compute_conv_transpose_paddings_with_conv_paddings( - self, in_paddings, window, strides, padding, dilation - ): - """Check if ConvT -> Conv preserves information.""" - in_paddings = jnp.array(in_paddings, dtype=jnp.float32)[None, :] - out_paddings = layers.compute_conv_transpose_paddings( - in_paddings, window=window, stride=strides, conv_padding=padding, dilation=dilation - ) - - recon_paddings = layers.compute_conv_paddings( - out_paddings, window=window, stride=strides, conv_padding=padding, dilation=dilation - ) - self.assertNestedEqual(recon_paddings[0], in_paddings[0]) - - @parameterized.product(**CONVT_PADDINGS_PARAMS) - def test_compute_conv_transpose_paddings_against_conv_paddings( - self, in_paddings, window, padding, dilation - ): - # compute_conv_transpose_paddings and compute_conv_paddings are same when window_stride=1 - # (stride of Conv2D) and lhs_dilation=1 (stride of Conv2DTranspose). - strides = 1 - if padding == "VALID": - # TODO(dhwang2,ruoming): Currently, anchor is pad_left but it should be the midpoint - # between [pad_left, pad_right). Otherwise, the consistency of VALID padding is broken. - # For reference, the midpoint in SAME and CAUSAL is left_pad. 
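# Worked example for the anchor below (illustrative): window=3, dilation=2 gives
# dilate_window = 1 + 2 * (3 - 1) = 5; VALID padding is (0, 0), so anchor_range = 5,
# mid_point = 2, and anchor = 0 + 2 = 2.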
- dilate_window = layers.conv_dilate_window(window=(window,), dilation=(dilation,))[0] - conv_padding = layers.conv_explicit_padding( - window=(window,), strides=(strides,), padding=padding, dilation=(dilation,) - )[0] - pad_left, pad_right = conv_padding - anchor_range = dilate_window - pad_left - pad_right - mid_point = anchor_range // 2 - anchor = pad_left + mid_point - else: - anchor = None - - in_paddings = jnp.array(in_paddings, dtype=jnp.float32)[None, :] - ref_paddings = layers.compute_conv_paddings( - in_paddings, - window=window, - stride=strides, - conv_padding=padding, - dilation=dilation, - anchor=anchor, - ) - - test_paddings = layers.compute_conv_transpose_paddings( - in_paddings, - window=window, - stride=strides, - conv_padding=padding, - dilation=dilation, - anchor=anchor, - ) - - if ref_paddings.shape != test_paddings.shape: - self.assertEqual(padding, "VALID") - dilate_window = layers.conv_dilate_window(window=(window,), dilation=(dilation,))[0] - pad_left = dilate_window - 1 - test_paddings = test_paddings[:, pad_left:-pad_left] - - assert_allclose(ref_paddings, test_paddings) - - CONVT_PARAMS = [ - (3, 1, "SAME", 1, [0, 1, 2, 2]), - (3, 2, "SAME", 1, [0, 0, 0, 0, 1, 1, 2, 1]), - (3, 3, "SAME", 1, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), - (3, 4, "SAME", 1, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0]), - (3, 1, "SAME", 2, [1, 1, 1, 1]), - (3, 2, "SAME", 2, [0, 0, 0, 1, 0, 2, 0, 2]), - (3, 3, "SAME", 2, [0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0]), - (3, 1, "VALID", 1, [0, 0, 1, 2, 2, 1]), - (3, 2, "VALID", 1, [0, 0, 0, 0, 1, 1, 2, 1, 1]), - (3, 3, "VALID", 1, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), - (3, 4, "VALID", 1, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0]), - (3, 1, "VALID", 2, [0, 0, 1, 1, 1, 1, 1, 1]), - (3, 2, "VALID", 2, [0, 0, 0, 0, 1, 0, 2, 0, 2, 0, 1]), - (3, 3, "VALID", 2, [0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1]), - (3, 1, "CAUSAL", 1, [0, 0, 1, 2]), - (3, 2, "CAUSAL", 1, [0, 0, 0, 0, 1, 1, 2, 1]), - (3, 3, "CAUSAL", 1, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), - (3, 4, "CAUSAL", 1, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0]), - (3, 1, "CAUSAL", 2, [0, 0, 1, 1]), - (3, 2, "CAUSAL", 2, [0, 0, 0, 0, 1, 0, 2, 0]), - (3, 3, "CAUSAL", 2, [0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1]), - ] - - @parameterized.parameters(*CONVT_PARAMS) - def test_conv1d_transpose_simple(self, window, strides, padding, dilation, expected): - """Tests the cases in conv_transpose_explicit_padding() description.""" - input_dim, output_dim = 1, 1 - inputs = jnp.array([0, 0, 1, 1], dtype=jnp.float32)[None, :, None] - cfg = layers.Conv1DTranspose.default_config().set( - name="test", - input_dim=input_dim, - output_dim=output_dim, - window=window, - strides=strides, - padding=padding, - dilation=dilation, - bias=False, - ) - layer = cfg.instantiate(parent=None) - prng_key = jax.random.PRNGKey(123) - prng_key, init_key = jax.random.split(prng_key) - layer_params = layer.initialize_parameters_recursively(init_key) - self.assertEqual(dict(weight=(window, input_dim, output_dim)), shapes(layer_params)) - layer_params["weight"] = jnp.ones_like(layer_params["weight"]) - - (outputs, paddings), _ = F( - layer, inputs=dict(x=inputs), is_training=True, state=layer_params, prng_key=prng_key - ) - out_shape = layer.output_shape(input_shape=inputs.shape) - self.assertAllEqual(outputs.shape, out_shape) - self.assertIsNone(paddings) - expected = jnp.array(expected).astype(outputs.dtype) - self.assertNestedEqual(outputs[0, :, 0], expected) - - @parameterized.parameters(*CONVT_PARAMS) - def 
test_conv2d_transpose_simple(self, window, strides, padding, dilation, expected): - """Tests the cases in conv_transpose_explicit_padding() description.""" - window = (window, 1) - strides = (strides, 1) - dilation = (dilation, 1) - input_dim, output_dim = 1, 1 - inputs = jnp.array([0, 0, 1, 1], dtype=jnp.float32)[None, :, None, None] - cfg = layers.Conv2DTranspose.default_config().set( - name="test", - input_dim=input_dim, - output_dim=output_dim, - window=window, - strides=strides, - padding=padding, - dilation=dilation, - bias=False, - ) - layer = cfg.instantiate(parent=None) - prng_key = jax.random.PRNGKey(123) - prng_key, init_key = jax.random.split(prng_key) - layer_params = layer.initialize_parameters_recursively(init_key) - self.assertEqual(dict(weight=(*window, input_dim, output_dim)), shapes(layer_params)) - layer_params["weight"] = jnp.ones_like(layer_params["weight"]) - - outputs, _ = F( - layer, inputs=dict(x=inputs), is_training=True, state=layer_params, prng_key=prng_key - ) - out_shape = layer.output_shape(input_shape=inputs.shape) - self.assertAllEqual(outputs.shape, out_shape) - expected = jnp.array(expected).astype(outputs.dtype) - self.assertNestedEqual(outputs[0, :, 0, 0], expected) - - @parameterized.parameters(*CONVT_PARAMS) - def test_conv3d_transpose_simple(self, window, strides, padding, dilation, expected): - """Tests the cases in conv_transpose_explicit_padding() description.""" - window = (window, 1, 1) - strides = (strides, 1, 1) - dilation = (dilation, 1, 1) - input_dim, output_dim = 1, 1 - inputs = jnp.array([0, 0, 1, 1], dtype=jnp.float32)[None, :, None, None, None] - cfg = layers.Conv3DTranspose.default_config().set( - name="test", - input_dim=input_dim, - output_dim=output_dim, - window=window, - strides=strides, - padding=padding, - dilation=dilation, - bias=False, - ) - layer = cfg.instantiate(parent=None) - prng_key = jax.random.PRNGKey(123) - prng_key, init_key = jax.random.split(prng_key) - layer_params = layer.initialize_parameters_recursively(init_key) - self.assertEqual(dict(weight=(*window, input_dim, output_dim)), shapes(layer_params)) - layer_params["weight"] = jnp.ones_like(layer_params["weight"]) - - outputs, _ = F( - layer, inputs=dict(x=inputs), is_training=True, state=layer_params, prng_key=prng_key - ) - out_shape = layer.output_shape(input_shape=inputs.shape) - self.assertAllEqual(outputs.shape, out_shape) - expected = jnp.array(expected).astype(outputs.dtype) - self.assertNestedEqual(outputs[0, :, 0, 0, 0], expected) - - @parameterized.product(window=(1, 3, 5), padding=("SAME", "VALID", "CAUSAL"), dilation=(1, 2)) - def test_conv1d_transpose_against_conv1d(self, window, padding, dilation): - # Conv1D and Conv1DTranspose are same when window_stride=1 - # (stride of Conv1D) and lhs_dilation=1 (stride of Conv1DTranspose). - input_dim, output_dim = 4, 6 - ref_cfg = Conv1D.default_config().set( - name="test", - input_dim=input_dim, - output_dim=output_dim, - window=window, - padding=padding, - dilation=dilation, - ) - ref_layer = ref_cfg.instantiate(parent=None) - - test_cfg = layers.Conv1DTranspose.default_config().set( - name="test", - input_dim=input_dim, - output_dim=output_dim, - window=window, - padding=padding, - dilation=dilation, - ) - test_layer = test_cfg.instantiate(parent=None) - - # Initialize layer parameters. - prng_key = jax.random.PRNGKey(123) - prng_key, init_key = jax.random.split(prng_key) - ref_states = ref_layer.initialize_parameters_recursively(init_key) - - # Random inputs. 
- prng_key, input_key = jax.random.split(prng_key) - inputs = jax.random.normal(input_key, [2, 17, input_dim]) - # Compute layer outputs. - ref_outputs, _ = F( - ref_layer, inputs=dict(x=inputs), is_training=True, state=ref_states, prng_key=prng_key - ) - - (test_outputs, _), _ = F( - test_layer, inputs=dict(x=inputs), is_training=True, state=ref_states, prng_key=prng_key - ) - if ref_outputs.shape != test_outputs.shape: - self.assertEqual(padding, "VALID") - dilate_window = layers.conv_dilate_window(window=(window,), dilation=(dilation,))[0] - pad_left = dilate_window - 1 - test_outputs = test_outputs[:, pad_left:-pad_left] - assert_allclose(ref_outputs, test_outputs) - - @parameterized.product(window=(1, 3, 5), padding=("SAME", "VALID", "CAUSAL"), dilation=(1, 2)) - def test_conv2d_transpose_against_conv2d(self, window, padding, dilation): - # Conv2D and Conv2DTranspose are same when window_stride=1 - # (stride of Conv2D) and lhs_dilation=1 (stride of Conv2DTranspose). - window = (window, window) - dilation = (dilation, dilation) - input_dim, output_dim = 4, 6 - ref_cfg = Conv2D.default_config().set( - name="test", - input_dim=input_dim, - output_dim=output_dim, - window=window, - padding=padding, - dilation=dilation, - ) - ref_layer = ref_cfg.instantiate(parent=None) - - test_cfg = layers.Conv2DTranspose.default_config().set( - name="test", - input_dim=input_dim, - output_dim=output_dim, - window=window, - padding=padding, - dilation=dilation, - transpose_kernel=False, - ) - test_layer = test_cfg.instantiate(parent=None) - - # Initialize layer parameters. - prng_key = jax.random.PRNGKey(123) - prng_key, init_key = jax.random.split(prng_key) - ref_states = ref_layer.initialize_parameters_recursively(init_key) - - # Random inputs. - prng_key, input_key = jax.random.split(prng_key) - width, height = 12, 13 - inputs = jax.random.normal(input_key, [2, width, height, input_dim]) - # Compute layer outputs. - ref_outputs, _ = F( - ref_layer, - inputs=dict(x=inputs), - is_training=True, - state=ref_states, - prng_key=prng_key, - ) - - test_outputs, _ = F( - test_layer, - inputs=dict(x=inputs), - is_training=True, - state=ref_states, - prng_key=prng_key, - ) - if ref_outputs.shape != test_outputs.shape: - self.assertEqual(padding, "VALID") - dilate_window = layers.conv_dilate_window(window=window, dilation=dilation) - pad_left = tuple(w - 1 for w in dilate_window) - test_outputs = test_outputs[:, pad_left[0] : -pad_left[0], pad_left[1] : -pad_left[1]] - - assert_allclose(ref_outputs, test_outputs) - - @parameterized.product(window=(1, 3, 5), padding=("SAME", "VALID", "CAUSAL"), dilation=(1, 2)) - def test_conv2d_transpose_against_conv2d_with_paddings(self, window, padding, dilation): - # Conv2DWith1DPadding and Conv2DTransposeWith1DPadding are same when window_stride=1 - # (stride of Conv2D) and lhs_dilation=1 (stride of Conv2DTranspose). - window = (window, window) - dilation = (dilation, dilation) - input_dim, output_dim = 4, 6 - if padding == "VALID": - # TODO(dhwang2,ruoming): Currently, anchor is pad_left but it should be the midpoint - # between [pad_left, pad_right). Otherwise, the consistency of VALID padding is broken. - # For reference, the midpoint in SAME and CAUSAL is left_pad. 
- strides = (1, 1) - dilate_window = layers.conv_dilate_window(window=window, dilation=dilation)[0] - conv_padding = layers.conv_explicit_padding( - window=window, strides=strides, padding=padding, dilation=dilation - ) - pad_left, pad_right = conv_padding[0] - anchor_range = dilate_window - pad_left - pad_right - mid_point = anchor_range // 2 - anchor = pad_left + mid_point - else: - anchor = None - - ref_cfg = Conv2DWith1DPadding.default_config().set( - name="test", - input_dim=input_dim, - output_dim=output_dim, - window=window, - padding=padding, - dilation=dilation, - anchor=anchor, - ) - ref_layer = ref_cfg.instantiate(parent=None) - - test_cfg = layers.Conv2DTransposeWith1DPadding.default_config().set( - name="test", - input_dim=input_dim, - output_dim=output_dim, - window=window, - padding=padding, - dilation=dilation, - anchor=anchor, - ) - test_layer = test_cfg.instantiate(parent=None) - - # Initialize layer parameters. - prng_key = jax.random.PRNGKey(123) - prng_key, init_key = jax.random.split(prng_key) - ref_states = ref_layer.initialize_parameters_recursively(init_key) - - # Random inputs. - prng_key, input_key = jax.random.split(prng_key) - width, height = 12, 13 - inputs = jax.random.normal(input_key, [2, width, height, input_dim]) - paddings = jnp.zeros([2, width], dtype=inputs.dtype).at[:, -2:].set(1) - # Compute layer outputs. - (ref_outputs, ref_paddings), _ = F( - ref_layer, - inputs=dict(x=inputs, paddings=paddings), - is_training=True, - state=ref_states, - prng_key=prng_key, - ) - - (test_outputs, test_paddings), _ = F( - test_layer, - inputs=dict(x=inputs, paddings=paddings), - is_training=True, - state=ref_states, - prng_key=prng_key, - ) - if ref_outputs.shape != test_outputs.shape: - self.assertEqual(padding, "VALID") - dilate_window = layers.conv_dilate_window(window=window, dilation=dilation) - pad_left = tuple(w - 1 for w in dilate_window) - test_outputs = test_outputs[:, pad_left[0] : -pad_left[0], pad_left[1] : -pad_left[1]] - test_paddings = test_paddings[:, pad_left[0] : -pad_left[0]] - - assert_allclose(ref_outputs, test_outputs) - assert_allclose(ref_paddings, test_paddings) - - @parameterized.product(window=(1, 3, 5), padding=("SAME", "VALID", "CAUSAL"), dilation=(1, 2)) - def test_conv3d_transpose_against_conv3d(self, window, padding, dilation): - # Conv3D and Conv3DTranspose are same when window_stride=1 - # (stride of Conv3D) and lhs_dilation=1 (stride of Conv3DTranspose). - window = (window, window, window) - dilation = (dilation, dilation, dilation) - input_dim, output_dim = 4, 6 - ref_cfg = Conv3D.default_config().set( - name="test", - input_dim=input_dim, - output_dim=output_dim, - window=window, - padding=padding, - dilation=dilation, - ) - ref_layer = ref_cfg.instantiate(parent=None) - - test_cfg = layers.Conv3DTranspose.default_config().set( - name="test", - input_dim=input_dim, - output_dim=output_dim, - window=window, - padding=padding, - dilation=dilation, - ) - test_layer = test_cfg.instantiate(parent=None) - - # Initialize layer parameters. - prng_key = jax.random.PRNGKey(123) - prng_key, init_key = jax.random.split(prng_key) - ref_states = ref_layer.initialize_parameters_recursively(init_key) - - # Random inputs. - prng_key, input_key = jax.random.split(prng_key) - width, height, depth = 9, 8, 7 - inputs = jax.random.normal(input_key, [2, width, height, depth, input_dim]) - # Compute layer outputs. 
- ref_outputs, _ = F( - ref_layer, - inputs=dict(x=inputs), - is_training=True, - state=ref_states, - prng_key=prng_key, - ) - - test_outputs, _ = F( - test_layer, - inputs=dict(x=inputs), - is_training=True, - state=ref_states, - prng_key=prng_key, - ) - if ref_outputs.shape != test_outputs.shape: - self.assertEqual(padding, "VALID") - dilate_window = layers.conv_dilate_window(window=window, dilation=dilation) - pad_left = tuple(w - 1 for w in dilate_window) - test_outputs = test_outputs[ - :, - pad_left[0] : -pad_left[0], - pad_left[1] : -pad_left[1], - pad_left[2] : -pad_left[2], - ] - - assert_allclose(ref_outputs, test_outputs) - - @parameterized.product( - window=(1, 3, 5), - strides=(1, 2), - padding=("SAME", "VALID", "CAUSAL"), - dilation=(1, 2), - anchor=(None, 1), - ) - def test_conv1d_transpose_against_conv2d_transpose_with_1d_padding( - self, - window, - strides, - padding: ConvPaddingType, - dilation, - anchor, - ): - if anchor is not None: - dilate_window = layers.conv_dilate_window(window=(window,), dilation=(dilation,))[0] - anchor = dilate_window - 1 - - input_dim, output_dim = 4, 6 - ref_cfg = layers.Conv2DTransposeWith1DPadding.default_config().set( - name="ref", - input_dim=input_dim, - output_dim=output_dim, - window=(window, 1), - strides=(strides, 1), - padding=padding, - dilation=(dilation, 1), - anchor=anchor, - ) - ref_layer = ref_cfg.instantiate(parent=None) - - test_cfg = layers.Conv1DTranspose.default_config().set( - name="test", - input_dim=input_dim, - output_dim=output_dim, - window=window, - strides=strides, - padding=padding, - dilation=dilation, - anchor=anchor, - ) - test_layer = test_cfg.instantiate(parent=None) - - # Initialize layer parameters. - prng_key = jax.random.PRNGKey(123) - prng_key, init_key = jax.random.split(prng_key) - state = ref_layer.initialize_parameters_recursively(init_key) - test_state = dict( - bias=state["bias"], weight=einops.rearrange(state["weight"], "t 1 i o -> t i o") - ) - - # Generate a batch of 10 input sequences. - batch_size, max_seq_len = 10, 10 - - prng_key, input_key = jax.random.split(prng_key) - inputs = jax.random.normal(input_key, [batch_size, max_seq_len, input_dim]) - # The 10 sequences have length 1 to 10. - paddings = jnp.triu(jnp.ones((batch_size, max_seq_len)), k=1) - - (test_outputs, test_paddings), _ = F( - test_layer, - inputs=dict(x=inputs, paddings=paddings), - is_training=True, - state=test_state, - prng_key=prng_key, - ) - - inputs = einops.rearrange(inputs, "b t i -> b t 1 i") - (ref_outputs, ref_paddings), _ = F( - ref_layer, - inputs=dict(x=inputs, paddings=paddings), - is_training=True, - state=state, - prng_key=prng_key, - ) - ref_outputs = einops.rearrange(ref_outputs, "b t 1 o -> b t o") - - assert_allclose(ref_paddings, test_paddings) - assert_allclose(ref_outputs, test_outputs) - @parameterized.parameters( itertools.product( (None, 0.0, 0.2, 1.0, -0.1), @@ -2739,153 +1031,6 @@ def test_drop_tokens(self, drop_rate, num_cls_tokens): assert_allclose(outputs.shape, [batch_size, len_tokens, dim]) - @parameterized.parameters( - ( - 2, - (0, 0), - [[[1, 1, 2, 2], [3, 3, 4, 4]], [[7, 7, 8, 8], [0, 0, 0, 0]]], - [[0, 0], [0, 1]], - ), - ( - 3, - (0, 0), - [[[1, 1, 2, 2, 3, 3]], [[7, 7, 8, 8, 0, 0]]], - [[0], [0]], - ), - ( - 3, - (2, 0), - [[[0, 0, 0, 0, 1, 1], [2, 2, 3, 3, 4, 4]], [[0, 0, 0, 0, 7, 7], [0, 0, 0, 0, 0, 0]]], - [[0, 0], [0, 1]], - ), - ) - def test_stack_over_time(self, stride, pad, expected_outputs, expected_output_paddings): - # Input shape [2, 5, 2]. 
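# (StackOverTime with stride s concatenates s consecutive frames along the feature
# axis; with stride=2 and padding (0, 0), the [2, 5, 2] input below becomes
# [2, 2, 4] and the odd trailing frame is dropped.)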
- inputs = jnp.array( - [[[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]], [[7, 7], [8, 8], [0, 0], [0, 0], [0, 0]]], - dtype=jnp.float32, - ) - paddings = jnp.array([[0, 0, 0, 0, 0], [0, 0, 1, 1, 1]]) - layer: StackOverTime = ( - StackOverTime.default_config() - .set( - name="test", - stride=stride, - padding=pad, - ) - .instantiate(parent=None) - ) - layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) - (outputs, output_paddings), _ = F( - layer, - inputs=dict(inputs=inputs, paddings=paddings), - is_training=False, - state=layer_params, - prng_key=jax.random.PRNGKey(5), - ) - output_shape = layer.output_shape(input_shape=inputs.shape) - self.assertAllEqual(outputs.shape, output_shape) - self.assertAllClose(jnp.array(expected_outputs, dtype=jnp.float32), outputs) - self.assertAllClose(jnp.array(expected_output_paddings, dtype=jnp.int32), output_paddings) - - def test_stack_over_time_data_change(self): - """Tests that the stacked outputs is masked with the output paddings.""" - np.random.seed(500) - inputs = np.random.normal(size=[2, 21, 16]) - paddings = np.ones([2, 21], dtype=np.float32) - paddings[0, :9] = 0 - paddings[1, :14] = 0 - inputs = inputs * (1 - paddings)[:, :, None] - - layer: StackOverTime = ( - StackOverTime.default_config() - .set( - name="test", - stride=2, - ) - .instantiate(parent=None) - ) - layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) - (outputs, output_paddings), _ = F( - layer, - inputs=dict(inputs=inputs, paddings=paddings), - is_training=False, - state=layer_params, - prng_key=jax.random.PRNGKey(5), - ) - output_shape = layer.output_shape(input_shape=inputs.shape) - self.assertAllEqual(outputs.shape, output_shape) - self.assertAllEqual(np.array([5, 7], dtype=np.float32), np.sum(1 - output_paddings, axis=1)) - self.assertAllClose(np.sum(inputs**2, (1, 2)), np.sum(outputs**2, (1, 2))) - - @parameterized.product(stride=(2, 3, 4), pad=("VALID", "SAME", "CAUSAL")) - def test_stack_consistent_outputs(self, stride, pad): - """Tests that StackOverTime has consistent outputs under different padding lengths.""" - batch_size, input_dim = 2, 1 - input_length = 7 - layer: StackOverTime = ( - StackOverTime.default_config() - .set( - name="test", - stride=stride, - padding=pad, - ) - .instantiate(parent=None) - ) - expected_output_length = layer.output_shape(input_shape=[1, input_length, 1])[1] - layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) - for ll in range(4, 11): - # Batch with another example of length ll. - length = max(input_length, ll) - inputs = jnp.ones([batch_size, length, input_dim]) - paddings = jnp.arange(length)[None, :] >= jnp.array([input_length, ll])[:, None] - (outputs, output_paddings), _ = F( - layer, - inputs=dict(inputs=inputs, paddings=paddings), - is_training=False, - state=layer_params, - prng_key=jax.random.PRNGKey(5), - ) - output_shape = layer.output_shape(input_shape=inputs.shape) - self.assertAllEqual(outputs.shape, output_shape) - if pad != "VALID": # VALID doesn't preserve length. - self.assertEqual(expected_output_length, np.sum(1 - output_paddings, axis=1)[0]) - - @parameterized.parameters(((0, 1), (0, 0)), ((1, 1), (3, 0)), ((1, 1), (0, 3))) - def test_stack_vs_conv2d_output_len_match(self, conv_padding, stack_padding): - # Note that to get the same output length, we need to pad the sequence differently - # for convolution and stacking layer. 
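# (Both paths downsample time by a factor of 4 overall: two stride-2 convolutions
# versus one stride-4 stack, so with matching padding the output lengths agree.)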
- for audio_seq_len in [16000, 16160, 16320, 16480, 16640, 16800, 16960, 17120]: - sampling_rate, window_size_ms, window_step_ms = 16000, 25, 10 - window_size = window_size_ms * sampling_rate // 1000 - window_step = window_step_ms * sampling_rate // 1000 - seq_len = max(audio_seq_len - window_size, 0) // window_step + 1 - conv_layer: Conv2DWith1DPadding = ( - Conv2DWith1DPadding.default_config() - .set( - name="test_conv", - input_dim=3, - output_dim=3, - window=(3, 3), - strides=(2, 2), - padding=(conv_padding, (0, 1)), - ) - .instantiate(parent=None) - ) - stack_layer: StackOverTime = ( - StackOverTime.default_config() - .set(name="test_stack", stride=4, padding=stack_padding) - .instantiate(parent=None) - ) - # Computes downsampler output shape. - down_sample_shape1 = conv_layer.output_shape(input_shape=[None, seq_len, 80, 3]) - down_sample_shape2 = conv_layer.output_shape(input_shape=down_sample_shape1) - - # Computes stack output shape. - stack_shape = stack_layer.output_shape(input_shape=[None, seq_len, 80]) - # Tests that the sequence length dimension matches. - self.assertEqual(down_sample_shape2[1], stack_shape[1]) - def test_multilinear_fan_axes(self): input_dim, num_outputs, output_dim = 3, 4, 6 layer: MultiLinear = ( diff --git a/axlearn/common/param_converter.py b/axlearn/common/param_converter.py index 0a8220092..e5afa1b0f 100644 --- a/axlearn/common/param_converter.py +++ b/axlearn/common/param_converter.py @@ -48,6 +48,7 @@ from axlearn.common.base_layer import BaseLayer from axlearn.common.bert import BertModel, BertPooler, BertSequenceClassificationHead from axlearn.common.causal_lm import Model as CausalLMModel +from axlearn.common.convolution import Conv2D from axlearn.common.deberta import DeBERTaV2Encoder from axlearn.common.decoder import Decoder from axlearn.common.dit import ( @@ -61,7 +62,7 @@ ) from axlearn.common.embedding import TransformerTextEmbeddings from axlearn.common.encoder import Encoder, EncoderModel -from axlearn.common.layers import Conv2D, Embedding, LayerNorm, LayerNormStateless, Linear, RMSNorm +from axlearn.common.layers import Embedding, LayerNorm, LayerNormStateless, Linear, RMSNorm from axlearn.common.t5 import T5Decoder, T5Encoder, T5EncoderDecoderModel from axlearn.common.text_encoder import TextEmbeddingEncoder from axlearn.common.utils import NestedTensor, Tensor, VDict, as_tensor diff --git a/axlearn/common/ssm.py b/axlearn/common/ssm.py index df48f5baf..0e8ef7980 100644 --- a/axlearn/common/ssm.py +++ b/axlearn/common/ssm.py @@ -43,7 +43,8 @@ ) from axlearn.common.base_layer import BaseLayer, ParameterSpec from axlearn.common.config import REQUIRED, InstantiableConfig, Required, config_class -from axlearn.common.layers import Conv1D, Linear, MultiLinear, RMSNorm +from axlearn.common.convolution import Conv1D +from axlearn.common.layers import Linear, MultiLinear, RMSNorm from axlearn.common.module import Module from axlearn.common.param_init import FanAxes, Initializer, Shape, constant_initializer, uniform from axlearn.common.ssm_kernels.mamba_kernels import compute_mamba_scan diff --git a/axlearn/common/state_builder_test.py b/axlearn/common/state_builder_test.py index fb6e39fb9..6c317c89f 100644 --- a/axlearn/common/state_builder_test.py +++ b/axlearn/common/state_builder_test.py @@ -27,8 +27,9 @@ config_class, config_for_function, ) +from axlearn.common.convolution import Conv2D from axlearn.common.input_fake import FakeLmInput -from axlearn.common.layers import Conv2D, Linear +from axlearn.common.layers import Linear from 
diff --git a/axlearn/common/vision_transformer.py b/axlearn/common/vision_transformer.py
index 20141e489..2d39bacd0 100644
--- a/axlearn/common/vision_transformer.py
+++ b/axlearn/common/vision_transformer.py
@@ -36,8 +36,8 @@
 )
 from axlearn.common.base_layer import BaseLayer, ParameterSpec
 from axlearn.common.config import REQUIRED, InstantiableConfig, Required, config_class
+from axlearn.common.convolution import Conv2D
 from axlearn.common.layers import (
-    Conv2D,
     Dropout,
     DropToken,
     L2Norm,
diff --git a/axlearn/experiments/audio/conformer/librispeech_trainer.py b/axlearn/experiments/audio/conformer/librispeech_trainer.py
index e35b59282..3b475e4b1 100644
--- a/axlearn/experiments/audio/conformer/librispeech_trainer.py
+++ b/axlearn/experiments/audio/conformer/librispeech_trainer.py
@@ -44,12 +44,12 @@
 from axlearn.common import learner
 from axlearn.common.checkpointer import every_n_steps_policy as save_every_n_steps
 from axlearn.common.config import InstantiableConfig, config_for_class, config_for_function
+from axlearn.common.convolution import Conv2DWith1DPadding
 from axlearn.common.decoding import PrefixMerger
 from axlearn.common.evaler import SpmdEvaler
 from axlearn.common.evaler import every_n_steps_policy as eval_every_n_steps
 from axlearn.common.input_fake import fake_speech_source, fake_text_source
 from axlearn.common.input_tf_data import BuildDatasetFn, Input, batch, tfds_dataset
-from axlearn.common.layers import Conv2DWith1DPadding
 from axlearn.common.trainer import SpmdTrainer
 from axlearn.common.utils import get_data_dir
 from axlearn.experiments.audio.conformer.common import (
diff --git a/axlearn/experiments/testdata/axlearn.experiments.audio.conformer.librispeech_trainer/conformer-l-rnnt.txt b/axlearn/experiments/testdata/axlearn.experiments.audio.conformer.librispeech_trainer/conformer-l-rnnt.txt
index 637605743..65599bfc3 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.audio.conformer.librispeech_trainer/conformer-l-rnnt.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.audio.conformer.librispeech_trainer/conformer-l-rnnt.txt
@@ -410,7 +410,7 @@ model.encoder.context.context.layer.ff_start.stochastic_depth.mode: 'row'
 model.encoder.context.context.layer.ff_start.structure: 'prenorm'
 model.encoder.context.context.layer.klass: 'axlearn.common.conformer.ConformerLayer'
 model.encoder.context.context.layer.lconv.conv.bias: False
-model.encoder.context.context.layer.lconv.conv.klass: 'axlearn.common.layers.Conv1D'
+model.encoder.context.context.layer.lconv.conv.klass: 'axlearn.common.convolution.Conv1D'
 model.encoder.context.context.layer.lconv.conv.num_input_dim_groups: 1
 model.encoder.context.context.layer.lconv.conv.padding: 'SAME'
 model.encoder.context.context.layer.lconv.conv.param_partition_spec[0]: None
@@ -525,7 +525,7 @@ model.encoder.feature.klass: 'axlearn.audio.encoder_asr.SpeechFeatureLayer'
 model.encoder.feature.output_dim: 512
 model.encoder.feature.subsampler.activation: 'nn.relu'
 model.encoder.feature.subsampler.conv.bias: True
-model.encoder.feature.subsampler.conv.klass: 'axlearn.common.layers.Conv2DWith1DPadding'
+model.encoder.feature.subsampler.conv.klass: 'axlearn.common.convolution.Conv2DWith1DPadding'
 model.encoder.feature.subsampler.conv.num_input_dim_groups: 1
 model.encoder.feature.subsampler.conv.padding[0][0]: 1
 model.encoder.feature.subsampler.conv.padding[0][1]: 1
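The golden-config updates above (and the analogous ones below) follow mechanically from the module move: a layer config records the fully qualified path of its target class, so relocating the class changes the string in the regenerated testdata. A minimal sketch, assuming the config returned by `default_config()` exposes the target class as `klass` (illustrative, not part of this diff):

```python
from axlearn.common.convolution import Conv2DWith1DPadding

cfg = Conv2DWith1DPadding.default_config()
# The class now lives in axlearn.common.convolution, which is the path the
# regenerated testdata records for `...subsampler.conv.klass`.
print(f"{cfg.klass.__module__}.{cfg.klass.__name__}")
# -> axlearn.common.convolution.Conv2DWith1DPadding
```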
diff --git a/axlearn/experiments/testdata/axlearn.experiments.audio.conformer.librispeech_trainer/conformer-test-ctc.txt b/axlearn/experiments/testdata/axlearn.experiments.audio.conformer.librispeech_trainer/conformer-test-ctc.txt
index 7d35af9b9..4e23c338d 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.audio.conformer.librispeech_trainer/conformer-test-ctc.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.audio.conformer.librispeech_trainer/conformer-test-ctc.txt
@@ -116,7 +116,7 @@ model.encoder.context.context.layer.ff_start.stochastic_depth.mode: 'row'
 model.encoder.context.context.layer.ff_start.structure: 'prenorm'
 model.encoder.context.context.layer.klass: 'axlearn.common.conformer.ConformerLayer'
 model.encoder.context.context.layer.lconv.conv.bias: False
-model.encoder.context.context.layer.lconv.conv.klass: 'axlearn.common.layers.Conv1D'
+model.encoder.context.context.layer.lconv.conv.klass: 'axlearn.common.convolution.Conv1D'
 model.encoder.context.context.layer.lconv.conv.num_input_dim_groups: 1
 model.encoder.context.context.layer.lconv.conv.padding: 'SAME'
 model.encoder.context.context.layer.lconv.conv.param_partition_spec[0]: None
@@ -227,7 +227,7 @@ model.encoder.feature.klass: 'axlearn.audio.encoder_asr.SpeechFeatureLayer'
 model.encoder.feature.output_dim: 4
 model.encoder.feature.subsampler.activation: 'nn.relu'
 model.encoder.feature.subsampler.conv.bias: True
-model.encoder.feature.subsampler.conv.klass: 'axlearn.common.layers.Conv2DWith1DPadding'
+model.encoder.feature.subsampler.conv.klass: 'axlearn.common.convolution.Conv2DWith1DPadding'
 model.encoder.feature.subsampler.conv.num_input_dim_groups: 1
 model.encoder.feature.subsampler.conv.padding[0][0]: 1
 model.encoder.feature.subsampler.conv.padding[0][1]: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-101.txt b/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-101.txt
index 892c7aec9..b6841e553 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-101.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-101.txt
@@ -124,7 +124,7 @@ model.backbone.param_init.init_by_param_name['.*weight$'].scale: 1.4142135623730
 model.backbone.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
 model.backbone.stage.block.activation: 'nn.relu'
 model.backbone.stage.block.conv.bias: False
-model.backbone.stage.block.conv.klass: 'axlearn.common.layers.Conv2D'
+model.backbone.stage.block.conv.klass: 'axlearn.common.convolution.Conv2D'
 model.backbone.stage.block.conv.num_input_dim_groups: 1
 model.backbone.stage.block.conv.padding[0][0]: 1
 model.backbone.stage.block.conv.padding[0][1]: 1
@@ -169,7 +169,7 @@ model.backbone.stage.klass: 'axlearn.vision.resnet.ResNetStage'
 model.backbone.stage.stride: 1
 model.backbone.stem.activation: 'nn.relu'
 model.backbone.stem.conv.bias: False
-model.backbone.stem.conv.klass: 'axlearn.common.layers.Conv2D'
+model.backbone.stem.conv.klass: 'axlearn.common.convolution.Conv2D'
 model.backbone.stem.conv.num_input_dim_groups: 1
 model.backbone.stem.conv.padding[0][0]: 3
 model.backbone.stem.conv.padding[0][1]: 3
diff --git a/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-152.txt b/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-152.txt
index 8f17ef8c8..c1254ca15 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-152.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-152.txt
@@ -124,7 +124,7 @@ model.backbone.param_init.init_by_param_name['.*weight$'].scale: 1.4142135623730
 model.backbone.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
 model.backbone.stage.block.activation: 'nn.relu'
 model.backbone.stage.block.conv.bias: False
-model.backbone.stage.block.conv.klass: 'axlearn.common.layers.Conv2D'
+model.backbone.stage.block.conv.klass: 'axlearn.common.convolution.Conv2D'
 model.backbone.stage.block.conv.num_input_dim_groups: 1
 model.backbone.stage.block.conv.padding[0][0]: 1
 model.backbone.stage.block.conv.padding[0][1]: 1
@@ -169,7 +169,7 @@ model.backbone.stage.klass: 'axlearn.vision.resnet.ResNetStage'
 model.backbone.stage.stride: 1
 model.backbone.stem.activation: 'nn.relu'
 model.backbone.stem.conv.bias: False
-model.backbone.stem.conv.klass: 'axlearn.common.layers.Conv2D'
+model.backbone.stem.conv.klass: 'axlearn.common.convolution.Conv2D'
 model.backbone.stem.conv.num_input_dim_groups: 1
 model.backbone.stem.conv.padding[0][0]: 3
 model.backbone.stem.conv.padding[0][1]: 3
diff --git a/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-18.txt b/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-18.txt
index 38979a796..e6ed60f09 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-18.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-18.txt
@@ -124,7 +124,7 @@ model.backbone.param_init.init_by_param_name['.*weight$'].scale: 1.4142135623730
 model.backbone.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
 model.backbone.stage.block.activation: 'nn.relu'
 model.backbone.stage.block.conv.bias: False
-model.backbone.stage.block.conv.klass: 'axlearn.common.layers.Conv2D'
+model.backbone.stage.block.conv.klass: 'axlearn.common.convolution.Conv2D'
 model.backbone.stage.block.conv.num_input_dim_groups: 1
 model.backbone.stage.block.conv.padding[0][0]: 1
 model.backbone.stage.block.conv.padding[0][1]: 1
@@ -169,7 +169,7 @@ model.backbone.stage.klass: 'axlearn.vision.resnet.ResNetStage'
 model.backbone.stage.stride: 1
 model.backbone.stem.activation: 'nn.relu'
 model.backbone.stem.conv.bias: False
-model.backbone.stem.conv.klass: 'axlearn.common.layers.Conv2D'
+model.backbone.stem.conv.klass: 'axlearn.common.convolution.Conv2D'
 model.backbone.stem.conv.num_input_dim_groups: 1
 model.backbone.stem.conv.padding[0][0]: 3
 model.backbone.stem.conv.padding[0][1]: 3
diff --git a/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-34.txt b/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-34.txt
index d7b030459..0dcd1daee 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-34.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-34.txt
@@ -124,7 +124,7 @@ model.backbone.param_init.init_by_param_name['.*weight$'].scale: 1.4142135623730
 model.backbone.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
 model.backbone.stage.block.activation: 'nn.relu'
 model.backbone.stage.block.conv.bias: False
-model.backbone.stage.block.conv.klass: 'axlearn.common.layers.Conv2D'
+model.backbone.stage.block.conv.klass: 'axlearn.common.convolution.Conv2D'
 model.backbone.stage.block.conv.num_input_dim_groups: 1
 model.backbone.stage.block.conv.padding[0][0]: 1
 model.backbone.stage.block.conv.padding[0][1]: 1
@@ -169,7 +169,7 @@ model.backbone.stage.klass: 'axlearn.vision.resnet.ResNetStage'
 model.backbone.stage.stride: 1
 model.backbone.stem.activation: 'nn.relu'
 model.backbone.stem.conv.bias: False
-model.backbone.stem.conv.klass: 'axlearn.common.layers.Conv2D'
+model.backbone.stem.conv.klass: 'axlearn.common.convolution.Conv2D'
 model.backbone.stem.conv.num_input_dim_groups: 1
 model.backbone.stem.conv.padding[0][0]: 3
 model.backbone.stem.conv.padding[0][1]: 3
diff --git a/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-50-ema.txt b/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-50-ema.txt
index 1ca91ff37..eb5d1b56d 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-50-ema.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-50-ema.txt
@@ -125,7 +125,7 @@ model.backbone.param_init.init_by_param_name['.*weight$'].scale: 1.4142135623730
 model.backbone.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
 model.backbone.stage.block.activation: 'nn.relu'
 model.backbone.stage.block.conv.bias: False
-model.backbone.stage.block.conv.klass: 'axlearn.common.layers.Conv2D'
+model.backbone.stage.block.conv.klass: 'axlearn.common.convolution.Conv2D'
 model.backbone.stage.block.conv.num_input_dim_groups: 1
 model.backbone.stage.block.conv.padding[0][0]: 1
 model.backbone.stage.block.conv.padding[0][1]: 1
@@ -170,7 +170,7 @@ model.backbone.stage.klass: 'axlearn.vision.resnet.ResNetStage'
 model.backbone.stage.stride: 1
 model.backbone.stem.activation: 'nn.relu'
 model.backbone.stem.conv.bias: False
-model.backbone.stem.conv.klass: 'axlearn.common.layers.Conv2D'
+model.backbone.stem.conv.klass: 'axlearn.common.convolution.Conv2D'
 model.backbone.stem.conv.num_input_dim_groups: 1
 model.backbone.stem.conv.padding[0][0]: 3
 model.backbone.stem.conv.padding[0][1]: 3
diff --git a/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-50.txt b/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-50.txt
index 4735af512..dcce086bb 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-50.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-50.txt
@@ -124,7 +124,7 @@ model.backbone.param_init.init_by_param_name['.*weight$'].scale: 1.4142135623730
 model.backbone.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
 model.backbone.stage.block.activation: 'nn.relu'
 model.backbone.stage.block.conv.bias: False
-model.backbone.stage.block.conv.klass: 'axlearn.common.layers.Conv2D'
+model.backbone.stage.block.conv.klass: 'axlearn.common.convolution.Conv2D'
 model.backbone.stage.block.conv.num_input_dim_groups: 1
 model.backbone.stage.block.conv.padding[0][0]: 1
 model.backbone.stage.block.conv.padding[0][1]: 1
@@ -169,7 +169,7 @@ model.backbone.stage.klass: 'axlearn.vision.resnet.ResNetStage'
 model.backbone.stage.stride: 1
 model.backbone.stem.activation: 'nn.relu'
 model.backbone.stem.conv.bias: False
-model.backbone.stem.conv.klass: 'axlearn.common.layers.Conv2D'
+model.backbone.stem.conv.klass: 'axlearn.common.convolution.Conv2D'
 model.backbone.stem.conv.num_input_dim_groups: 1
 model.backbone.stem.conv.padding[0][0]: 3
 model.backbone.stem.conv.padding[0][1]: 3
diff --git a/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-Test.txt b/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-Test.txt
index 040b121e2..de1357d0b 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-Test.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-Test.txt
@@ -120,7 +120,7 @@ model.backbone.param_init.init_by_param_name['.*weight$'].scale: 1.4142135623730
 model.backbone.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
 model.backbone.stage.block.activation: 'nn.relu'
 model.backbone.stage.block.conv.bias: False
-model.backbone.stage.block.conv.klass: 'axlearn.common.layers.Conv2D'
+model.backbone.stage.block.conv.klass: 'axlearn.common.convolution.Conv2D'
 model.backbone.stage.block.conv.num_input_dim_groups: 1
 model.backbone.stage.block.conv.padding[0][0]: 1
 model.backbone.stage.block.conv.padding[0][1]: 1
@@ -165,7 +165,7 @@ model.backbone.stage.klass: 'axlearn.vision.resnet.ResNetStage'
 model.backbone.stage.stride: 1
 model.backbone.stem.activation: 'nn.relu'
 model.backbone.stem.conv.bias: False
-model.backbone.stem.conv.klass: 'axlearn.common.layers.Conv2D'
+model.backbone.stem.conv.klass: 'axlearn.common.convolution.Conv2D'
 model.backbone.stem.conv.num_input_dim_groups: 1
 model.backbone.stem.conv.padding[0][0]: 3
 model.backbone.stem.conv.padding[0][1]: 3
diff --git a/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-Testb.txt b/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-Testb.txt
index a5347b16e..c70afa266 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-Testb.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.vision.resnet.imagenet_trainer/ResNet-Testb.txt
@@ -121,7 +121,7 @@ model.backbone.param_init.init_by_param_name['.*weight$'].scale: 1.4142135623730
 model.backbone.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
 model.backbone.stage.block.activation: 'nn.relu'
 model.backbone.stage.block.conv.bias: False
-model.backbone.stage.block.conv.klass: 'axlearn.common.layers.Conv2D'
+model.backbone.stage.block.conv.klass: 'axlearn.common.convolution.Conv2D'
 model.backbone.stage.block.conv.num_input_dim_groups: 1
 model.backbone.stage.block.conv.padding[0][0]: 1
 model.backbone.stage.block.conv.padding[0][1]: 1
@@ -166,7 +166,7 @@ model.backbone.stage.klass: 'axlearn.vision.resnet.ResNetStage'
 model.backbone.stage.stride: 1
 model.backbone.stem.activation: 'nn.relu'
 model.backbone.stem.conv.bias: False
-model.backbone.stem.conv.klass: 'axlearn.common.layers.Conv2D'
+model.backbone.stem.conv.klass: 'axlearn.common.convolution.Conv2D'
 model.backbone.stem.conv.num_input_dim_groups: 1
 model.backbone.stem.conv.padding[0][0]: 3
 model.backbone.stem.conv.padding[0][1]: 3
diff --git a/axlearn/vision/detection_heads.py b/axlearn/vision/detection_heads.py
index 5dd0da94c..4576ce597 100644
--- a/axlearn/vision/detection_heads.py
+++ b/axlearn/vision/detection_heads.py
@@ -16,7 +16,8 @@
 from axlearn.common import struct
 from axlearn.common.base_layer import BaseLayer
 from axlearn.common.config import REQUIRED, InstantiableConfig, Required, config_class
-from axlearn.common.layers import BatchNorm, Conv2D, Linear, get_activation_fn
+from axlearn.common.convolution import Conv2D
+from axlearn.common.layers import BatchNorm, Linear, get_activation_fn
 from axlearn.common.module import Module, Tensor, child_context
 from axlearn.common.param_init import PARAM_REGEXP_WEIGHT, DefaultInitializer, WeightInitializer
diff --git a/axlearn/vision/efficientdet.py b/axlearn/vision/efficientdet.py
index ab569b8e3..b378ab36b 100644
--- a/axlearn/vision/efficientdet.py
+++ b/axlearn/vision/efficientdet.py
@@ -13,7 +13,8 @@
 from axlearn.common.base_layer import BaseLayer
 from axlearn.common.config import REQUIRED, InstantiableConfig, Required, config_class
-from axlearn.common.layers import BatchNorm, Conv2D, get_activation_fn, set_norm_recursively
+from axlearn.common.convolution import Conv2D
+from axlearn.common.layers import BatchNorm, get_activation_fn, set_norm_recursively
 from axlearn.common.module import Module, Tensor, child_context
 from axlearn.common.param_init import (
     PARAM_REGEXP_BIAS,
diff --git a/axlearn/vision/fpn.py b/axlearn/vision/fpn.py
index 19efa7e8f..ae1517215 100644
--- a/axlearn/vision/fpn.py
+++ b/axlearn/vision/fpn.py
@@ -30,15 +30,8 @@
 from axlearn.common.base_layer import BaseLayer, ParameterSpec
 from axlearn.common.config import REQUIRED, InstantiableConfig, Required, config_class
-from axlearn.common.layers import (
-    BatchNorm,
-    Conv2D,
-    Conv2DTranspose,
-    LayerNorm,
-    MaxPool2D,
-    get_activation_fn,
-    normalize_sum,
-)
+from axlearn.common.convolution import Conv2D, Conv2DTranspose
+from axlearn.common.layers import BatchNorm, LayerNorm, MaxPool2D, get_activation_fn, normalize_sum
 from axlearn.common.module import Module
 from axlearn.common.param_init import (
     PARAM_REGEXP_WEIGHT,
diff --git a/axlearn/vision/mobilenets.py b/axlearn/vision/mobilenets.py
index b17beb53b..0b762e129 100644
--- a/axlearn/vision/mobilenets.py
+++ b/axlearn/vision/mobilenets.py
@@ -33,9 +33,9 @@
 from axlearn.common.base_layer import BaseLayer
 from axlearn.common.base_model import BaseModel
 from axlearn.common.config import REQUIRED, InstantiableConfig, Required, config_class
+from axlearn.common.convolution import Conv2D
 from axlearn.common.layers import (
     BatchNorm,
-    Conv2D,
     Linear,
     SqueezeExcitation,
     StochasticDepth,
diff --git a/axlearn/vision/mobilenets_blocks.py b/axlearn/vision/mobilenets_blocks.py
index 4e36eca80..4e0d61ea5 100644
--- a/axlearn/vision/mobilenets_blocks.py
+++ b/axlearn/vision/mobilenets_blocks.py
@@ -28,7 +28,8 @@
 from axlearn.common.base_layer import BaseLayer
 from axlearn.common.config import REQUIRED, InstantiableConfig, Required, config_class
-from axlearn.common.layers import BatchNorm, Conv2D, get_activation_fn
+from axlearn.common.convolution import Conv2D
+from axlearn.common.layers import BatchNorm, get_activation_fn
 from axlearn.common.module import Module
 from axlearn.common.param_init import PerGroupInitializer
 from axlearn.common.utils import Tensor
diff --git a/axlearn/vision/resnet.py b/axlearn/vision/resnet.py
index 316f3d338..eeb41664d 100644
--- a/axlearn/vision/resnet.py
+++ b/axlearn/vision/resnet.py
@@ -33,9 +33,9 @@
 from axlearn.common import param_init
 from axlearn.common.base_layer import BaseLayer
 from axlearn.common.config import REQUIRED, InstantiableConfig, Required, config_class
+from axlearn.common.convolution import Conv2D
 from axlearn.common.layers import (
     BatchNorm,
-    Conv2D,
     SqueezeExcitation,
     StochasticDepth,
     get_activation_fn,
diff --git a/axlearn/vision/retinanet.py b/axlearn/vision/retinanet.py
index 78020f543..1ff4dcf6f 100644
--- a/axlearn/vision/retinanet.py
+++ b/axlearn/vision/retinanet.py
@@ -23,7 +23,8 @@
     config_class,
     config_for_class,
 )
-from axlearn.common.layers import BatchNorm, Conv2D, get_activation_fn
+from axlearn.common.convolution import Conv2D
+from axlearn.common.layers import BatchNorm, get_activation_fn
 from axlearn.common.loss import ReductionMethod, focal_loss, huber_loss
 from axlearn.common.module import Module, NestedTensor, Tensor, child_context
 from axlearn.common.param_init import (