diff --git a/axlearn/common/layers.py b/axlearn/common/layers.py index eb1f89a0..e30c8392 100644 --- a/axlearn/common/layers.py +++ b/axlearn/common/layers.py @@ -22,6 +22,7 @@ from typing import Any, Callable, Literal, Optional, Union import chex +import einops import jax from absl import logging from jax import nn @@ -712,19 +713,6 @@ def forward(self, x: Tensor) -> Tensor: return super().forward(x) -def _check_conv_cfg(*, padding: ConvPaddingType, strides: Sequence[int]): - if any(s < 1 for s in strides): - raise NotImplementedError(f"strides ({strides}) must be a positive integer.") - - if isinstance(padding, str): - if padding not in SUPPORT_CONV_PADDING: - raise NotImplementedError(f"{padding} padding is not supported.") - else: - padding_flattened = (p for p_tuple in padding for p in p_tuple) - if any(p < 0 for p in padding_flattened): - raise NotImplementedError("Negative padding is not supported") - - class MaxPool2D(BaseLayer): """A wrapper for the 2D max pooling layer.""" @@ -777,6 +765,34 @@ def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Opti return [input_shape[0], output_height, output_width, input_shape[3]] +############################## Convolution ######################################################### + + +def _check_conv_cfg( + *, + window: Sequence[int], + strides: Sequence[int], + padding: ConvPaddingType, + dilation: Optional[Sequence[int]], +): + if any(w < 1 for w in window): + raise NotImplementedError(f"window ({window}) must be a positive integer.") + + if any(s < 1 for s in strides): + raise NotImplementedError(f"strides ({strides}) must be a positive integer.") + + if isinstance(padding, str): + if padding not in SUPPORT_CONV_PADDING: + raise NotImplementedError(f"{padding} padding is not supported.") + else: + padding_flattened = (p for p_tuple in padding for p in p_tuple) + if any(p < 0 for p in padding_flattened): + raise NotImplementedError("Negative padding is not supported") + + if dilation is not None and any(d < 1 for d in dilation): + raise NotImplementedError(f"dilation ({dilation}) must be a positive integer.") + + class BaseConv(BaseLayer): """Base class for convolution layers.""" @@ -799,7 +815,7 @@ def _compute_fan_axes(self, name: str, parameter_spec: ParameterSpec) -> Optiona # Copied from jax.lax._dilate_shape # https://github.com/jax-ml/jax/blob/2d78b172266870bd755b039f6faa2056a51930f9/jax/_src/lax/lax.py#L5763 -def conv_dilate_window(*, window: Sequence[int], dilation: Optional[Sequence[int]] = None): +def conv_dilate_window(*, window: Sequence[int], dilation: Optional[Sequence[int]]): """Returns dilated effective window size. Args: @@ -827,10 +843,12 @@ def conv_explicit_padding( """Returns the explicit padding for "SAME", "VALID", and "CAUSAL" modes. Each mode follows the formulas below: - * SAME: (pad_total//2, pad_total - pad_total//2) + * SAME: (pad_total//2, pad_total - pad_total//2) s.t. pad_total = window-1 * VALID: (0, 0) - * CAUSAL: (dilate_window - stride, stride - 1) - s.t. dilate_window = (window - 1) * dilation + 1. Check conv_dilate_window() + * CAUSAL: (window - stride, stride - 1) + + Note: In the above equation, `window` can be replaced with `dilate_window` when dilation > 1. + dilate_window = (window - 1) * dilation + 1. Check conv_dilate_window() For example, window=5 and stride=2, * SAME: padding = (2, 2) @@ -895,30 +913,22 @@ def conv_explicit_padding( """ if not isinstance(padding, str): return padding + window = conv_dilate_window(window=window, dilation=dilation) - if dilation is None: - dilation = (1,) * len(window) - - def same_padding(window, dilation): - dilate_window = conv_dilate_window(window=window, dilation=dilation) - pad_total = tuple(w - 1 for w in dilate_window) + def same_padding(window): + pad_total = tuple(w - 1 for w in window) pad_left = tuple(pt // 2 for pt in pad_total) pad_right = tuple(pt - pl for pt, pl in zip(pad_total, pad_left)) return tuple(zip(pad_left, pad_right)) if padding == "SAME": - return same_padding(window, dilation) + return same_padding(window) elif padding == "VALID": return ((0, 0),) * len(window) elif padding == "CAUSAL": - dilate_window = conv_dilate_window(window=window[:1], dilation=dilation[:1])[0] - stride = strides[0] - pad_left = dilate_window - stride - pad_right = stride - 1 - assert pad_left + pad_right == dilate_window - 1 - causal_padding = ((pad_left, pad_right),) + causal_padding = ((window[0] - strides[0], strides[0] - 1),) if len(window) > 1: - causal_padding += same_padding(window[1:], dilation[1:]) + causal_padding += same_padding(window[1:]) return causal_padding else: raise ValueError(f"{padding} padding is not supported.") @@ -991,6 +1001,8 @@ class Config(BaseConv.Config): # Paddings: "SAME", "VALID", "CAUSAL" or ((top, bottom), (left, right)). # Note: Sequence models use the first component to represent time. padding: ConvPaddingType = ((0, 0), (0, 0)) + # The convolution dilation. If None, assume all 1's. + dilation: Optional[tuple[int, int]] = None output_dim: Required[int] = REQUIRED # Output feature dim. bias: bool = True # Whether to add a bias. # The number of groups in which the input is split along the channel axis. @@ -1011,7 +1023,9 @@ def default_config(cls): def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: cfg = self.config - _check_conv_cfg(padding=cfg.padding, strides=cfg.strides) + _check_conv_cfg( + window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation + ) params = dict( weight=ParameterSpec( shape=list(cfg.window) @@ -1029,14 +1043,26 @@ def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: def forward(self, x: Tensor) -> Tensor: cfg = self.config conv_padding = conv_explicit_padding( - window=cfg.window, strides=cfg.strides, padding=cfg.padding + window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation ) + return self._conv(x=x, strides=cfg.strides, padding=conv_padding, dilation=cfg.dilation) + + def _conv( + self, + x: Tensor, + *, + strides: Sequence[int], + padding: ConvPaddingType, + dilation: Optional[Sequence[int]], + ) -> Tensor: + cfg = self.config output = jax.lax.conv_general_dilated( lhs=x, rhs=self.parameters["weight"], - window_strides=cfg.strides, + window_strides=strides, + padding=padding, + rhs_dilation=dilation, dimension_numbers=("NHWC", "HWIO", "NHWC"), - padding=conv_padding, feature_group_count=cfg.num_input_dim_groups, ) if cfg.bias: @@ -1056,7 +1082,11 @@ def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Opti in_shape = input_shape[1:3] out_shape = conv_output_shape( - in_shape, window=cfg.window, strides=cfg.strides, padding=cfg.padding + in_shape, + window=cfg.window, + strides=cfg.strides, + padding=cfg.padding, + dilation=cfg.dilation, ) return [input_shape[0], *out_shape, cfg.output_dim] @@ -1128,6 +1158,7 @@ def compute_conv_paddings( return out_paddings +# TODO(dhwang2): move to convolution transpose section. class Conv2DTranspose(BaseConv): """The 2-D transposed convolution layer. @@ -1141,22 +1172,37 @@ class Config(BaseConv.Config): window: tuple[int, int] = (1, 1) strides: tuple[int, int] = (1, 1) - padding: ConvPaddingType = ((0, 0), (0, 0)) + padding: ConvPaddingType = "SAME" + dilation: tuple[int, int] = (1, 1) output_dim: Required[int] = REQUIRED # Output feature dim. bias: bool = True # Whether to add a bias. + # True for backward compatibility, but False is more computationally efficient. + # If True, flips spatial axes and swaps the input/output channel axes of the kernel. + # This makes the output of this function identical to the gradient-derived functions + # like keras.layers.Conv2DTranspose applied to the same kernel. + transpose_kernel: bool = True @classmethod def default_config(cls): cfg = super().default_config() - cfg.param_partition_spec = (None, None, None, None) + if cfg.transpose_kernel: + cfg.param_partition_spec = (None, None, "model", None) + else: + cfg.param_partition_spec = (None, None, None, "model") return cfg def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: cfg = self.config - _check_conv_cfg(padding=cfg.padding, strides=cfg.strides) + _check_conv_cfg( + window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation + ) + if cfg.transpose_kernel: + io_shape = (cfg.output_dim, cfg.input_dim) + else: + io_shape = (cfg.input_dim, cfg.output_dim) params = dict( weight=ParameterSpec( - shape=list(cfg.window) + [cfg.output_dim, cfg.input_dim], + shape=tuple(cfg.window) + io_shape, mesh_axes=cfg.param_partition_spec, factorization=FactorizationSpec(axes=(None, None, "row", "col")), ) @@ -1168,17 +1214,29 @@ def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: return params def forward(self, x: Tensor) -> Tensor: + cfg = self.config + conv_padding = conv_transpose_explicit_padding( + window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation + ) + return self._conv(x=x, strides=cfg.strides, padding=conv_padding, dilation=cfg.dilation) + + def _conv( + self, + x: Tensor, + *, + strides: Sequence[int], + padding: ConvPaddingType, + dilation: Sequence[int], + ) -> Tensor: cfg = self.config output = jax.lax.conv_transpose( lhs=x, rhs=self.parameters["weight"], - strides=cfg.strides, + strides=strides, + padding=padding, + rhs_dilation=dilation, dimension_numbers=("NHWC", "HWIO", "NHWC"), - padding=cfg.padding, - # if True flips spatial axes and swaps the input/output channel axes of the kernel. - # This makes the output of this function identical to the gradient-derived functions - # like keras.layers.Conv2DTranspose applied to the same kernel. - transpose_kernel=True, + transpose_kernel=cfg.transpose_kernel, ) if cfg.bias: output += self.parameters["bias"] @@ -1194,25 +1252,16 @@ def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Opti f"input_shape[-1] = {input_shape[-1]} does not match " "cfg.input_dim = {cfg.input_dim}." ) - input_height, input_width = input_shape[1:3] - output_height, output_width = None, None - - if cfg.padding == "SAME": - if cfg.padding == "SAME" and any(s > 1 for s in cfg.strides): - raise NotImplementedError("SAME padding does not support strides > 1") - if input_height is not None: - output_height = input_height * cfg.strides[0] - if input_width is not None: - output_width = input_width * cfg.strides[0] - elif cfg.padding == "VALID": - if input_height is not None: - output_height = input_height * cfg.strides[0] + max( - cfg.window[0] - cfg.strides[0], 0 - ) - if input_width is not None: - output_width = input_width * cfg.strides[1] + max(cfg.window[1] - cfg.strides[1], 0) - return [input_shape[0], output_height, output_width, cfg.output_dim] + in_shape = input_shape[1:3] + out_shape = conv_transpose_output_shape( + in_shape, + window=cfg.window, + strides=cfg.strides, + padding=cfg.padding, + dilation=cfg.dilation, + ) + return [input_shape[0], *out_shape, cfg.output_dim] class Conv2DWith1DPadding(Conv2D): @@ -1319,11 +1368,13 @@ def forward(self, x: Tensor, *, paddings: Tensor) -> tuple[Tensor, Tensor]: # Apply Conv2D. output = super().forward(x) # Compute paddings conv output. + dilation = 1 if cfg.dilation is None else cfg.dilation[0] output_paddings = compute_conv_paddings( paddings, window=cfg.window[0], stride=cfg.strides[0], conv_padding=cfg.padding, + dilation=dilation, anchor=cfg.anchor, ) # Apply padding to the outputs. @@ -1351,6 +1402,8 @@ class Config(BaseConv.Config): (0, 0), (0, 0), ) + # The convolution dilation. If None, assume all 1's. + dilation: Optional[tuple[int, int, int]] = None output_dim: Required[int] = REQUIRED # Output feature dim. bias: bool = True # Whether to add a bias. @@ -1373,7 +1426,9 @@ def default_config(cls): def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: cfg = self.config - _check_conv_cfg(padding=cfg.padding, strides=cfg.strides) + _check_conv_cfg( + window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation + ) params = dict( weight=ParameterSpec( shape=list(cfg.window) @@ -1391,14 +1446,26 @@ def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: def forward(self, x: Tensor) -> Tensor: cfg = self.config conv_padding = conv_explicit_padding( - window=cfg.window, strides=cfg.strides, padding=cfg.padding + window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation ) + return self._conv(x=x, strides=cfg.strides, padding=conv_padding, dilation=cfg.dilation) + + def _conv( + self, + x: Tensor, + *, + strides: Sequence[int], + padding: ConvPaddingType, + dilation: Optional[Sequence[int]], + ) -> Tensor: + cfg = self.config output = jax.lax.conv_general_dilated( lhs=x, rhs=self.parameters["weight"], - window_strides=cfg.strides, + window_strides=strides, + padding=padding, + rhs_dilation=dilation, dimension_numbers=("NHWDC", "HWDIO", "NHWDC"), - padding=conv_padding, feature_group_count=cfg.num_input_dim_groups, ) if cfg.bias: @@ -1418,7 +1485,11 @@ def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Opti in_shape = input_shape[1:4] out_shape = conv_output_shape( - in_shape, window=cfg.window, strides=cfg.strides, padding=cfg.padding + in_shape, + window=cfg.window, + strides=cfg.strides, + padding=cfg.padding, + dilation=cfg.dilation, ) return [input_shape[0], *out_shape, cfg.output_dim] @@ -1450,12 +1521,9 @@ class Config(BaseConv.Config): # set of filters (of size output_dim / input_dim); if further output_dim == K * input_dim, # where K is a positive integer, the operation is also known as a "depthwise convolution". num_input_dim_groups: Optional[int] = 1 - # LHS_dilation is also known as transposed convolution. It is either None, or an int - # indicating the dilation factor applied on the input. - lhs_dilation: Optional[int] = None - # RHS_dilation is also known as atrous convolution. It is either None, or an int indicating - # dilation factor applied to the weight. - rhs_dilation: Optional[int] = None + # The convolution dilation, indicating dilation factor applied to the weight. It is also + # known as atrous convolution or dilated convolution. If None, assume 1. + dilation: Optional[int] = None @classmethod def default_config(cls): @@ -1463,14 +1531,15 @@ def default_config(cls): cfg.param_partition_spec = (None, None, "model") return cfg - def __init__(self, cfg: Config, *, parent: Optional[Module]): - super().__init__(cfg, parent=parent) - # Check lhs_dilation and padding setting compatibility. - if cfg.lhs_dilation is not None and cfg.lhs_dilation != 1 and isinstance(cfg.padding, str): - raise ValueError("String padding is not supported for LHS dilation.") - def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: cfg = self.config + dilation = cfg.dilation or 1 + _check_conv_cfg( + window=(cfg.window,), + strides=(cfg.strides,), + padding=cfg.padding, + dilation=(dilation,), + ) if cfg.padding not in SUPPORT_CONV_PADDING: left, right = cfg.padding[0] if any(p < 0 for p in (left, right)): @@ -1489,25 +1558,59 @@ def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: def forward(self, x: Tensor) -> Tensor: cfg = self.config - dilation = cfg.rhs_dilation or 1 + dilation = cfg.dilation or 1 conv_padding = conv_explicit_padding( - window=(cfg.window,), strides=(cfg.strides,), padding=cfg.padding, dilation=(dilation,) + window=(cfg.window,), + strides=(cfg.strides,), + padding=cfg.padding, + dilation=(dilation,), ) - transpose_dilation = cfg.lhs_dilation or 1 + return self._conv(x=x, strides=(cfg.strides,), padding=conv_padding, dilation=(dilation,)) + + def _conv( + self, + x: Tensor, + *, + strides: Sequence[int], + padding: ConvPaddingType, + dilation: Optional[Sequence[int]], + ) -> Tensor: + cfg = self.config output = jax.lax.conv_general_dilated( lhs=x, rhs=self.parameters["weight"], - window_strides=(cfg.strides,), + window_strides=strides, + padding=padding, + rhs_dilation=dilation, dimension_numbers=("NWC", "WIO", "NWC"), - padding=conv_padding, feature_group_count=cfg.num_input_dim_groups, - lhs_dilation=(transpose_dilation,), - rhs_dilation=(dilation,), ) if cfg.bias: output += self.parameters["bias"] return output + @nowrap + def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Optional[int]]: + cfg = self.config + if len(input_shape) != 3: + raise ValueError(f"We expect len(input_shape) = 3, but got {len(input_shape)}.") + if input_shape[-1] != cfg.input_dim: + raise ValueError( + f"input_shape[-1] = {input_shape[-1]} does not match " + f"cfg.input_dim = {cfg.input_dim}." + ) + + in_shape = input_shape[1:2] + dilation = cfg.dilation or 1 + out_shape = conv_output_shape( + in_shape, + window=(cfg.window,), + strides=(cfg.strides,), + padding=cfg.padding, + dilation=(dilation,), + ) + return [input_shape[0], *out_shape, cfg.output_dim] + class Conv1DWithPadding(Conv1D): """The 1-D convolution with 1-D padding on the time axis.""" @@ -1544,19 +1647,13 @@ def forward(self, x: Tensor, *, paddings: Tensor) -> tuple[Tensor, Tensor]: # Apply Conv1D. output = super().forward(x) - # TODO(dhwang2): Implement Conv1DTranspose separately for lhs_dilation. It's problematic - # for lhs_dilation (Conv Transpose) and rhs_dilation (Dilated Convolution) to be part of - # the same class. Not only are they never used together, but their combined usage would - # result in undefined behavior. Additionally, the logic for handling explicit padding and - # paddings is fundamentally different between them, so supporting both in a single class - # makes the code error-prone. # Compute paddings conv output. output_paddings = compute_conv_paddings( paddings, window=cfg.window, stride=cfg.strides, conv_padding=cfg.padding, - dilation=cfg.rhs_dilation, + dilation=cfg.dilation, anchor=cfg.anchor, ) # Apply padding to the outputs. @@ -1564,7 +1661,7 @@ def forward(self, x: Tensor, *, paddings: Tensor) -> tuple[Tensor, Tensor]: return output, output_paddings -class DepthwiseConv1D(BaseConv): +class DepthwiseConv1D(Conv1D): """The 1-D depth-wise convolution layer. Kernel weights have the WIO layout and in the shape of (window, 1, output_dim=input_dim). @@ -1572,16 +1669,562 @@ class DepthwiseConv1D(BaseConv): """ @config_class - class Config(BaseConv.Config): + class Config(Conv1D.Config): """Configures DepthwiseConv1D.""" - window: Required[int] = REQUIRED # The convolution window. - strides: int = 1 # The convolution strides. - # Paddings: "SAME", "VALID", "CAUSAL" or (left, right). - # For causal convolution, set padding to (window - 1, 0). - padding: ConvPaddingType = ((0, 0),) + output_dim: Optional[int] = None + num_input_dim_groups: Optional[int] = None + + def __init__(self, cfg: Config, *, parent: Module): + cfg.num_input_dim_groups = cfg.input_dim + cfg.output_dim = cfg.input_dim + super().__init__(cfg, parent=parent) + + +############################## Transposed Convolution ############################################## + + +# Based on jax.lax.convolution._conv_transpose_padding, but ours is more intuitive. +def conv_transpose_explicit_padding( + *, + window: Sequence[int], + strides: Sequence[int], + padding: ConvPaddingType, + dilation: Sequence[int], +) -> ConvPaddingType: + """Convert str padding to tuple padding for conv_transpose. + + Each mode follows the formulas below, + * SAME: (min(window-1, ceil((w+s-2)/2)), max(stride-1, floor((w+s-2)/2))) + pad_total = window+stride-2 + when stride > window -> (window-1, stride-1) + * VALID: (window-1, max(stride-1, window-1)) + pad_total = window+stride-2 + max(window-stride, 0) + when stride > window -> (window-1, stride-1) + * CAUSAL: (window-1, stride-1) + pad_total = window+stride-2 + + Note: output_size = input_size*stride - (window+stride-2) + pad_total + = input_size*stride <- "SAME", "CAUSAL" + = input_size*stride + max(window-stride, 0) <- "VALID" + + Note: In the above equation, `window` can be replaced with `dilate_window` when dilation > 1. + dilate_window = (window - 1) * dilation + 1. Check conv_dilate_window() + + The following illustration demonstrates how Conv Transpose operates, assuming all kernel values + are set to 1 for simplicity in showcasing output values. + + In the window=3 and stride=1 case, this function creates outputs as follows: + * "SAME" padding=(1, 1) + pad| |pad + paddings: 0|0 0 1 1|0 + 0 0 0 -> 0 + 0 0 1 -> 1 + 0 1 1 -> 2 + 1 1 0 -> 2 + + * "VALID" padding=(2, 2) + pad | |pad + paddings: 0 0|0 0 1 1|0 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 1 -> 1 + 0 1 1 -> 2 + 1 1 0 -> 2 + 1 0 0 -> 1 + + * "CAUSAL" padding=(2, 0) + pad | |pad + paddings: 0 0|0 0 1 1| + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 1 -> 1 + 0 1 1 -> 2 + + In the window=3 and stride=2 case, this function creates outputs as follows: + * "SAME" padding=(2, 1) + pad | |pad + paddings: 0 0|0 * 0 * 1 * 1|0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 1 -> 1 + 0 1 0 -> 1 + 1 0 1 -> 2 + 0 1 0 -> 1 + + * "VALID" padding=(2, 2) + pad | |pad + paddings: 0 0|0 * 0 * 1 * 1|0 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 1 -> 1 + 0 1 0 -> 1 + 1 0 1 -> 2 + 0 1 0 -> 1 + 1 0 0 -> 1 + + * "CAUSAL" padding=(2, 1) + pad | |pad + paddings: 0 0|0 * 0 * 1 * 1|0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 1 -> 1 + 0 1 0 -> 1 + 1 0 1 -> 2 + 0 1 0 -> 1 + + In the window=3 and stride=3 case, this function creates outputs as follows: + * "SAME", "VALID" and "CAUSAL" padding=(2, 2) + pad | |pad + paddings: 0 0|0 * * 0 * * 1 * * 1|0 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 1 -> 1 + 0 1 0 -> 1 + 1 0 0 -> 1 + 0 0 1 -> 1 + 0 1 0 -> 1 + 1 0 0 -> 1 + + In the window=3 and stride=4 case, this function creates outputs as follows: + * "SAME", "VALID" and "CAUSAL" padding=(2, 3) + pad | |pad + paddings: 0 0|0 * * * 0 * * * 1 * * * 1|0 0 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 1 -> 1 + 0 1 0 -> 1 + 1 0 0 -> 1 + 0 0 0 -> 0 + 0 0 1 -> 1 + 0 1 0 -> 1 + 1 0 0 -> 1 + 0 0 0 -> 0 + Here is how to compute output_size, given the above example, + 1. |_| -(window-1) + 2. |_______________________| (input_size-1)*stride + 1 + 3. |_| |___| + pad_total + + So, output_size = -(window-1) + (input_size-1)*stride + 1 + pad_total + = input_size*stride - (window+stride-2) + pad_total + = input_size*stride <- "SAME", "CAUSAL" + = input_size*stride + max(window-stride, 0) <- "VALID" + + OTHO, when dilation > 1, dilate_window = (window - 1) * dilation + 1. + For example, when window=3 and dilation=2, dilate_window=5. + + In the stride=2 case, this function creates outputs as follows: + * "SAME" padding=(3, 2) + pad | |pad + paddings: 0 0 0|0 * 0 * 1 * 1|0 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 1 -> 1 + 0 * 0 * 0 -> 0 + 0 * 1 * 1 -> 2 + 0 * 0 * 0 -> 0 + 1 * 1 * 0 -> 2 + + * "VALID" padding=(4, 4) + pad | |pad + paddings: 0 0 0 0|0 * 0 * 1 * 1|0 0 0 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 1 -> 1 + 0 * 0 * 0 -> 0 + 0 * 1 * 1 -> 2 + 0 * 0 * 0 -> 0 + 1 * 1 * 0 -> 2 + 0 * 0 * 0 -> 0 + 1 * 0 * 0 -> 1 + + * "CAUSAL" padding=(4, 1) + pad | |pad + paddings: 0 0 0 0|0 * 0 * 1 * 1|0 + 0 * 0 * 0 -> 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 1 -> 1 + 0 * 0 * 0 -> 0 + 0 * 1 * 1 -> 2 + 0 * 0 * 0 -> 0 + + For "CAUSAL", the first component is time and treated as "CAUSAL", while the remaining + components are handled with "SAME" padding. + + Args: + window: convolution window. + strides: transposed convolution strides. It's lhs_dilation, not window_stride. + padding: convolution padding. + dilation: convolution dilation, a.k.a rhs_dilation. + + Returns: + The padding tuple. + + Raises: + ValueError: If padding is not supported. + """ + if not isinstance(padding, str): + return padding + + window = conv_dilate_window(window=window, dilation=dilation) + + def same_padding(window, strides): + pad_left = tuple(min(w - 1, (w + s - 1) // 2) for w, s in zip(window, strides)) + pad_right = tuple(max(s - 1, (w + s - 2) // 2) for w, s in zip(window, strides)) + return tuple(zip(pad_left, pad_right)) + + if padding == "SAME": + return same_padding(window, strides) + elif padding == "VALID": + pad_left = tuple(w - 1 for w in window) + pad_right = tuple(max(s - 1, w - 1) for w, s in zip(window, strides)) + return tuple(zip(pad_left, pad_right)) + elif padding == "CAUSAL": + causal_padding = ((window[0] - 1, strides[0] - 1),) + if len(window) > 1: + causal_padding += same_padding(window[1:], strides[1:]) + return causal_padding + else: + raise ValueError(f"{padding} padding is not supported.") + + +def conv_transpose_output_shape( + in_shape: Sequence[Optional[int]], + *, + window: Sequence[int], + strides: Sequence[int], + padding: ConvPaddingType, + dilation: Sequence[int], +) -> Sequence[int]: + """Returns output size for conv transpose. + + Each mode follows the formulas below, + * SAME: padding=(min(window-1, ceil((w+s-2)/2)), max(stride-1, floor((w+s-2)/2))) + pad_total = window+stride-2 + output_size = input_size*stride + * VALID: padding=(window-1, max(stride-1, window-1)) + pad_total = window+stride-2 + max(window-stride, 0) + output_size = input_size*stride + max(window-stride, 0) + * CAUSAL: padding=(window-1, stride-1) + pad_total = window+stride-2 + output_size = input_size*stride + + Note: In the above equation, `window` can be replaced with `dilate_window` when dilation > 1. + dilate_window = (window - 1) * dilation + 1. Check conv_dilate_window() + + Refer to + https://towardsdatascience.com/understand-transposed-convolutions-and-build-your-own-transposed-convolution-layer-from-scratch-4f5d97b2967 + + Args: + in_shape: convolution lhs shape. + window: convolution window. + strides: convolution strides. + padding: convolution padding. + dilation: convolution dilation. + + Returns: + The output shape. + + Raises: + ValueError: If the length of in_shape, window, strides, and padding are not equal. + """ + if len(in_shape) != len(window) or len(in_shape) != len(strides): + raise ValueError( + f"len(in_shape) = {len(in_shape)} must be equal to " + f"len(window) = {len(window)} and len(strides) = {len(strides)}" + ) + + window = conv_dilate_window(window=window, dilation=dilation) + + def output_shape(in_shape: Optional[int], window: int, stride: int): + if in_shape is None: + return None + + if padding == "SAME": + return in_shape * stride + elif padding == "VALID": + return in_shape * stride + max(window - stride, 0) + elif padding == "CAUSAL": + return in_shape * stride + else: + raise ValueError(f"{padding} padding is not supported.") + + return tuple(map(output_shape, in_shape, window, strides)) + + +def compute_conv_transpose_paddings( + in_paddings: Tensor, + *, + window: int, + stride: int, + conv_padding: ConvPaddingType, + dilation: int = 1, + anchor: Optional[int] = None, +): + """Compute output paddings w.r.t. conv_padding for conv transpose. + + The output paddings value is determined by the padding value at the anchor point in the + window. If anchor is None, the default anchor point is the left time padding from conv + padding config. See `Conv2DWith1DPadding.Config` in details. + + In the window=3 and stride=1 case, this function creates paddings as follows: + + The following illustration demonstrates how Conv Transpose operates, assuming all kernel values + are set to 1 for simplicity in showcasing output values. + + In the window=3 and stride=1 case, this function creates outputs as follows: + * "SAME" padding=(1, 1) + pad| |pad + paddings: 0|0 0 1 1|1 + |_____| + * 0 * -> 0 + * 0 * -> 0 + * 1 * -> 1 + * 1 0 -> 1 + + * "VALID" padding=(2, 2) + pad | |pad + paddings: 0 0|0 0 1 1|1 1 + |_________| + * * 0 -> 0 + * * 0 -> 0 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + + * "CAUSAL" padding=(2, 0) + pad | |pad + paddings: 0 0|0 0 1 1| + |_____| + * * 0 -> 0 + * * 0 -> 0 + * * 1 -> 1 + * * 1 -> 1 + + In the window=3 and stride=2 case, this function creates outputs as follows: + * "SAME" padding=(2, 1) + pad | |pad + paddings: 0 0|0 * 0 * 1 * 1|1 + |_____________| + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + + * "VALID" padding=(2, 2) + pad | |pad + paddings: 0 0|0 * 0 * 1 * 1|1 1 + |_______________| + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + + * "CAUSAL" padding=(2, 1) + pad | |pad + paddings: 0 0|0 * 0 * 1 * 1|1 + |_____________| + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + + In the window=3 and stride=3 case, this function creates outputs as follows: + * "SAME", "VALID" and "CAUSAL" padding=(2, 2) + pad | |pad + paddings: 0 0|0 * * 0 * * 1 * * 1|1 1 + |_____________________| + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + + OTHO, when dilation > 1, dilate_window = (window - 1) * dilation + 1. + For example, when window=3 and dilation=2, dilate_window=5. + + In the stride=2 case, this function creates outputs as follows: + * "SAME" padding=(3, 2) + pad | |pad + paddings: 0 0 0|0 * 0 * 1 * 1|1 1 + |_____________| + * * * 0 * -> 0 + * * * 0 * -> 0 + * * * 0 * -> 0 + * * * 0 * -> 0 + * * * 1 * -> 1 + * * * 1 * -> 1 + * * * 1 * -> 1 + * * * 1 * -> 1 + + * "VALID" padding=(4, 4) + pad | |pad + paddings: 0 0 0 0|0 * 0 * 1 * 1|1 1 1 1 + |___________________| + * * * * 0 -> 0 + * * * * 0 -> 0 + * * * * 0 -> 0 + * * * * 0 -> 0 + * * * * 1 -> 1 + * * * * 1 -> 1 + * * * * 1 -> 1 + * * * * 1 -> 1 + * * * * 1 -> 1 + * * * * 1 -> 1 + * * * * 1 -> 1 + + * "CAUSAL" padding=(4, 1) + pad | |pad + paddings: 0 0 0 0|0 * 0 * 1 * 1|1 + |_____________| + * * * * 0 -> 0 + * * * * 0 -> 0 + * * * * 0 -> 0 + * * * * 0 -> 0 + * * * * 1 -> 1 + * * * * 1 -> 1 + * * * * 1 -> 1 + * * * * 1 -> 1 + + Args: + in_paddings: A Tensor of shape [batch_size, seq_len]. + window: convolution window size of the time axis. + stride: convolution stride size of the time axis. + conv_padding: "SAME", "VALID", "CAUSAL" or ((left_time_padding, right_time_padding),) + dilation: convolution dilation size of the time axis. + anchor: an optional integer in the range of [0, window) + that specifies the anchor position within the convolution window that is used to + determine output paddings. Specifically, the output token is valid iff the input token + at the anchor position of the corresponding window is valid. + If None, anchor defaults to conv_padding[0] (i.e. left_time_padding). + + Returns: + out_paddings: A Tensor of shape [batch_size, seq_len]. + + Raises: + ValueError: If anchor is not between left_time_padding and window. + """ + + chex.assert_rank(in_paddings, 2) + conv_padding = conv_transpose_explicit_padding( + window=(window,), strides=(stride,), padding=conv_padding, dilation=(dilation,) + ) + window = conv_dilate_window(window=(window,), dilation=(dilation,))[0] + # Note: in transposed conv, left_pad + right_pad >= window - 1. + # See conv_transpose_explicit_padding(). + left_pad, right_pad = conv_padding[0] + + if anchor is None: + anchor = left_pad + # elif not left_pad <= anchor < window: + elif not anchor < window: + raise ValueError(f"anchor ({anchor}) must in range [0, {window}).") + + # Consider the case where window=3, strides=2, dilation=2, and padding="SAME" + # explicit padding=(3, 2) + # pad | |pad + # paddings: 0 0 0|0 * 0 * 1 * 1|1 1 + # |_____________| + # * * * 0 * -> 0 + # * * * 0 * -> 0 + # * * * 0 * -> 0 + # * * * 0 * -> 0 + # * * * 1 * -> 1 + # * * * 1 * -> 1 + # * * * 1 * -> 1 + # * * * 1 * -> 1 + + # |0 0 1 1| -> |0 * 0 * 1 * 1| + def dilate_paddings(paddings): + most, last = jnp.split(paddings, [paddings.shape[1] - 1], axis=1) + dilated = einops.repeat(most, "b t -> b (t s)", s=stride) + return jnp.concatenate([dilated, last], axis=1) + + in_paddings = dilate_paddings(in_paddings) + + # |0 * 0 * 1 * 1| -> 0 0 0|0 * 0 * 1 * 1|1 1 + # |_____________| which is |0 * 0 * 1 * 1|1 + window_pad_total = window - 1 # Note: we already check `anchor < window`` always. + window_right_pad = window_pad_total - anchor + assert window_right_pad >= 0, f"{anchor=} < {window=} always." + # Note: left_pad + right_pad >= window + stride - 2 >= window - 1 == anchor + window_right_pad + valid_right_pad = right_pad - window_right_pad + if valid_right_pad >= 0: + out_paddings = jnp.pad(in_paddings, ((0, 0), (0, valid_right_pad)), mode="edge") + else: + out_paddings = in_paddings[:, :valid_right_pad] + + start_index = anchor - left_pad + if start_index < 0: + out_paddings = jnp.pad(out_paddings, ((0, 0), (-start_index, 0)), mode="edge") + else: + out_paddings = out_paddings[:, start_index:] + return out_paddings + + +class Conv1DTranspose(BaseConv): + """The 1D transposed convolution layer.""" + + @config_class + class Config(BaseConv.Config): + """Configures Conv1DTranspose.""" + + window: int = 1 + strides: int = 1 + padding: ConvPaddingType = "SAME" + dilation: int = 1 # Dilation for dilated Convolution. + output_dim: Required[int] = REQUIRED # Output feature dim. bias: bool = True # Whether to add a bias. + # An optional integer in the range of [0, window) + # that specifies the anchor position within the convolution window that is used to + # determine output paddings. Specifically, the output token is valid iff the input token + # at the anchor position of the corresponding window is valid. + # If None, defaults to left time padding. See compute_conv_transpose_paddings more details. + anchor: Optional[int] = None + @classmethod def default_config(cls): cfg = super().default_config() @@ -1590,43 +2233,249 @@ def default_config(cls): def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: cfg = self.config - if cfg.padding not in SUPPORT_CONV_PADDING: - left, right = cfg.padding[0] - if any(p < 0 for p in (left, right)): - raise NotImplementedError("Negative padding is not supported") + _check_conv_cfg( + window=(cfg.window,), + strides=(cfg.strides,), + padding=cfg.padding, + dilation=(cfg.dilation,), + ) params = dict( weight=ParameterSpec( - # https://www.tensorflow.org/xla/operation_semantics#conv_convolution: - # The input feature dimension of rhs needs to be equal to the lhs input feature - # dimension divided by feature_group_count (so it already has the size of a group - # of input features). - shape=[cfg.window, 1, cfg.input_dim], + shape=(cfg.window, cfg.input_dim, cfg.output_dim), mesh_axes=cfg.param_partition_spec, + factorization=FactorizationSpec(axes=(None, "row", "col")), ) ) if cfg.bias: params["bias"] = ParameterSpec( - shape=[cfg.input_dim], mesh_axes=(cfg.param_partition_spec[-1],) + shape=[cfg.output_dim], mesh_axes=(cfg.param_partition_spec[-1],) ) return params - def forward(self, x: Tensor) -> Tensor: + def forward( + self, x: Tensor, *, paddings: Optional[Tensor] = None + ) -> tuple[Tensor, Optional[Tensor]]: cfg = self.config - conv_padding = conv_explicit_padding( - window=(cfg.window,), strides=(cfg.strides,), padding=cfg.padding + conv_padding = conv_transpose_explicit_padding( + window=(cfg.window,), + strides=(cfg.strides,), + padding=cfg.padding, + dilation=(cfg.dilation,), ) - output = jax.lax.conv_general_dilated( + + if paddings is not None: + chex.assert_rank(x, paddings.ndim + 1) + # Apply padding to the input. + x = x * (1 - paddings[..., None]) + + output = self._conv( + x=x, strides=(cfg.strides,), padding=conv_padding, dilation=(cfg.dilation,) + ) + + if paddings is None: + output_paddings = None + else: + # Compute paddings conv output. + output_paddings = compute_conv_transpose_paddings( + paddings, + window=cfg.window, + stride=cfg.strides, + conv_padding=cfg.padding, + dilation=cfg.dilation, + anchor=cfg.anchor, + ) + output = output * (1 - output_paddings[..., None]) + return output, output_paddings + + def _conv( + self, + x: Tensor, + *, + strides: Sequence[int], + padding: ConvPaddingType, + dilation: Sequence[int], + ) -> Tensor: + cfg = self.config + output = jax.lax.conv_transpose( lhs=x, rhs=self.parameters["weight"], - window_strides=(cfg.strides,), + strides=strides, + padding=padding, + rhs_dilation=dilation, dimension_numbers=("NWC", "WIO", "NWC"), - padding=conv_padding, - feature_group_count=cfg.input_dim, ) if cfg.bias: output += self.parameters["bias"] return output + def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Optional[int]]: + cfg = self.config + if len(input_shape) != 3: + raise ValueError(f"We expect len(input_shape) = 3, but got {len(input_shape)}.") + if input_shape[-1] != cfg.input_dim: + raise ValueError( + f"input_shape[-1] = {input_shape[-1]} does not match " + "cfg.input_dim = {cfg.input_dim}." + ) + + in_shape = input_shape[1:2] + out_shape = conv_transpose_output_shape( + in_shape, + window=(cfg.window,), + strides=(cfg.strides,), + padding=cfg.padding, + dilation=(cfg.dilation,), + ) + return [input_shape[0], *out_shape, cfg.output_dim] + + +class Conv2DTransposeWith1DPadding(Conv2DTranspose): + """The 2-D convolution transpose with 1-D padding on the time axis.""" + + @config_class + class Config(Conv2DTranspose.Config): + """Configures Conv2DTransposeWith1DPadding.""" + + transpose_kernel: bool = False + # An optional integer in the range of [0, window) + # that specifies the anchor position within the convolution window that is used to + # determine output paddings. Specifically, the output token is valid iff the input token + # at the anchor position of the corresponding window is valid. + # If None, defaults to left time padding. See compute_conv_transpose_paddings more details. + anchor: Optional[int] = None + + @classmethod + def default_config(cls): + cfg = super().default_config() + cfg.transpose_kernel = False # Choose better one unlike parent. + return cfg + + # We add a kwargs "paddings" to the forward method. + # pylint: disable-next=arguments-differ + def forward(self, x: Tensor, *, paddings: Tensor) -> tuple[Tensor, Tensor]: + """Computes convolution outputs and paddings. + + Args: + x: A Tensor of shape [batch_size, seq_len, frequency, input_dim]. + paddings: 0/1 Tensor of shape [batch_size, seq_len]. + + Returns: + output: A Tensor of shape [batch_size, seq_len, frequency, output_dim]. + paddings: 0/1 Tensor of shape [batch_size, seq_len]. + """ + cfg = self.config + # Apply padding to the input. + assert len(x.shape) == len(paddings.shape) + 2 + x = x * (1 - paddings[..., None, None]) + + # Apply Conv2D. + output = super().forward(x) + # Compute paddings conv output. + output_paddings = compute_conv_transpose_paddings( + paddings, + window=cfg.window[0], + stride=cfg.strides[0], + conv_padding=cfg.padding, + dilation=cfg.dilation[0], + anchor=cfg.anchor, + ) + # Apply padding to the outputs. + output = output * (1 - output_paddings[..., None, None]) + return output, output_paddings + + +class Conv3DTranspose(BaseConv): + """The 3-D convolution transpose layer.""" + + @config_class + class Config(BaseConv.Config): + """Configures Conv3DTranspose.""" + + window: tuple[int, int, int] = (1, 1, 1) # The convolution window. + strides: tuple[int, int, int] = (1, 1, 1) # The convolution strides. + # Paddings: "SAME" or "VALID, or ((top, bottom), (left, right), (front, back)) + padding: ConvPaddingType = "SAME" + dilation: tuple[int, int, int] = (1, 1, 1) # The convolution dilation. + + output_dim: Required[int] = REQUIRED # Output feature dim. + bias: bool = True # Whether to add a bias. + + @classmethod + def default_config(cls): + cfg = super().default_config() + cfg.param_partition_spec = (None, None, None, None, "model") + return cfg + + def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: + cfg = self.config + _check_conv_cfg( + window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation + ) + params = dict( + weight=ParameterSpec( + shape=cfg.window + (cfg.input_dim, cfg.output_dim), + mesh_axes=cfg.param_partition_spec, + factorization=FactorizationSpec(axes=(None, None, None, "row", "col")), + ) + ) + if cfg.bias: + params["bias"] = ParameterSpec( + shape=[cfg.output_dim], mesh_axes=(cfg.param_partition_spec[-1],) + ) + return params + + def forward(self, x: Tensor) -> Tensor: + cfg = self.config + conv_padding = conv_transpose_explicit_padding( + window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation + ) + return self._conv(x=x, strides=cfg.strides, padding=conv_padding, dilation=cfg.dilation) + + def _conv( + self, + x: Tensor, + *, + strides: Sequence[int], + padding: ConvPaddingType, + dilation: Sequence[int], + ) -> Tensor: + cfg = self.config + output = jax.lax.conv_transpose( + lhs=x, + rhs=self.parameters["weight"], + strides=strides, + padding=padding, + rhs_dilation=dilation, + dimension_numbers=("NHWDC", "HWDIO", "NHWDC"), + ) + if cfg.bias: + output += self.parameters["bias"] + return output + + @nowrap + def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Optional[int]]: + cfg = self.config + if len(input_shape) != 5: + raise ValueError(f"We expect len(input_shape) = 5, but got {len(input_shape)}.") + if input_shape[-1] != cfg.input_dim: + raise ValueError( + f"input_shape[-1] = {input_shape[-1]} does not match " + f"cfg.input_dim = {cfg.input_dim}." + ) + + in_shape = input_shape[1:4] + out_shape = conv_transpose_output_shape( + in_shape, + window=cfg.window, + strides=cfg.strides, + padding=cfg.padding, + dilation=cfg.dilation, + ) + return [input_shape[0], *out_shape, cfg.output_dim] + + +############################## Others ############################################################## + class Embedding(BaseLayer): """Implements an embedding lookup function. @@ -2150,7 +2999,7 @@ def forward(self, inputs: Tensor, *, paddings: Tensor) -> tuple[Tensor, Tensor]: padding = cfg.padding if isinstance(padding, str): padding = conv_explicit_padding( - window=(cfg.stride,), strides=(cfg.stride,), padding=padding + window=(cfg.stride,), strides=(cfg.stride,), padding=padding, dilation=(1,) )[0] inputs = jnp.pad(inputs, ((0, 0), padding, (0, 0)), constant_values=0) @@ -2183,7 +3032,7 @@ def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Opti if isinstance(padding, tuple): padding = (padding,) out_shape = conv_output_shape( - [seq_len], window=(cfg.stride,), strides=(cfg.stride,), padding=padding + [seq_len], window=(cfg.stride,), strides=(cfg.stride,), padding=padding, dilation=(1,) ) return [batch_size, *out_shape, input_dim * cfg.stride] diff --git a/axlearn/common/layers_test.py b/axlearn/common/layers_test.py index 1c5ffc5f..98dc6bec 100644 --- a/axlearn/common/layers_test.py +++ b/axlearn/common/layers_test.py @@ -1052,7 +1052,7 @@ def test_conv_output_1d_padding_causal(self, in_paddings, expected): stride = 6 padding = "CAUSAL" explicit_padding = layers.conv_explicit_padding( - window=(window,), strides=(stride,), padding=padding + window=(window,), strides=(stride,), padding=padding, dilation=(1,) ) self.assertAllEqual(explicit_padding[0], (9, 5)) @@ -1082,7 +1082,7 @@ def test_conv_output_1d_padding_against_str_padding( paddings = jnp.triu(jnp.ones((batch_size, seq_len)), k=1) explicit_padding = layers.conv_explicit_padding( - window=(window,), strides=(stride,), padding=ref_padding + window=(window,), strides=(stride,), padding=ref_padding, dilation=(1,) ) self.assertAllEqual(explicit_padding, padding[:1]) @@ -1224,6 +1224,9 @@ def test_conv2d_with_1d_padding( state=layer_params, prng_key=prng_key, ) + output_shape = layer.output_shape(input_shape=inputs.shape) + self.assertAllEqual(ref_outputs.shape, output_shape) + random_keys = jax.random.split(input_key, num=2 * max_seq_len) for seq_len in range(1, max_seq_len): # We create a new batch. The time axis of the new batch is of length seq_len. @@ -1336,6 +1339,8 @@ def test_conv1d_against_conv2d_with_1d_padding( state=test_state, prng_key=prng_key, ) + output_shape = test_layer.output_shape(input_shape=inputs.shape) + self.assertAllEqual(test_outputs.shape, output_shape) inputs = einops.rearrange(inputs, "b t i -> b t 1 i") (ref_outputs, ref_paddings), _ = F( @@ -1345,6 +1350,8 @@ def test_conv1d_against_conv2d_with_1d_padding( state=state, prng_key=prng_key, ) + output_shape = ref_layer.output_shape(input_shape=inputs.shape) + self.assertAllEqual(ref_outputs.shape, output_shape) ref_outputs = einops.rearrange(ref_outputs, "b t 1 o -> b t o") assert_allclose(ref_paddings, test_paddings) @@ -1370,7 +1377,7 @@ def test_conv1d_against_conv2d_with_1d_padding( "padding": "VALID", }, ) - def test_deconv2d( + def test_conv2d_transpose_against_pytorch( self, window: tuple[int, int], strides: tuple[int, int], @@ -1610,7 +1617,7 @@ def test_conv1d( window: int, strides: int, padding: ConvPaddingType, - dilation: Optional[int] = None, + dilation: Optional[int], ): input_dim, output_dim = 4, 6 cfg = Conv1D.default_config().set( @@ -1620,7 +1627,7 @@ def test_conv1d( window=window, strides=strides, padding=padding, - rhs_dilation=dilation, + dilation=dilation, ) layer: Conv1D = cfg.instantiate(parent=None) # Initialize layer parameters. @@ -1649,6 +1656,8 @@ def test_conv1d( state=layer_params, prng_key=prng_key, ) + output_shape = layer.output_shape(input_shape=inputs.shape) + self.assertAllEqual(outputs.shape, output_shape) # Compute ref outputs. if isinstance(padding, str): @@ -1673,70 +1682,6 @@ def test_conv1d( ref_outputs = ref(as_torch_tensor(ref_inputs.transpose(0, 2, 1))) assert_allclose(outputs, ref_outputs.detach().numpy().transpose(0, 2, 1)) - @parameterized.named_parameters( - ("w3s1d3_pad1", 3, 1, (1, 1), 3), - ("w3s1d3_pad1_causal", 3, 1, (1, 0), 3), - ("w3s2d3_pad1", 3, 2, (1, 1), 3), - ("w3s2d1_pad1", 3, 2, (1, 1), None), - ) - def test_lhs_dilation_conv1d( - self, - window: int, - strides: int, - padding: tuple[int, int], - dilation: Optional[int], - ): - input_dim, output_dim = 4, 6 - cfg = Conv1D.default_config().set( - name="test", - input_dim=input_dim, - output_dim=output_dim, - window=window, - strides=strides, - padding=(padding,), - lhs_dilation=dilation, - ) - layer: Conv1D = cfg.instantiate(parent=None) - prng_key = jax.random.PRNGKey(123) - prng_key, init_key = jax.random.split(prng_key) - layer_params = layer.initialize_parameters_recursively(init_key) - self.assertEqual( - dict(weight=(window, input_dim, output_dim), bias=(output_dim,)), - shapes(layer_params), - ) - bias = layer_params["bias"] - assert_allclose(bias, jnp.zeros_like(bias)) - # Randomize bias. - layer_params["bias"] = jax.random.normal( - jax.random.PRNGKey(45), shape=bias.shape, dtype=bias.dtype - ) - - prng_key, input_key = jax.random.split(prng_key) - input_length = 7 - inputs = jax.random.normal(input_key, [2, input_length, input_dim]) - outputs, _ = F( - layer, inputs=(inputs,), is_training=True, state=layer_params, prng_key=prng_key - ) - expected_length = input_length + padding[0] + padding[1] - int((window + 1) / 2) - if dilation is not None: - expected_length = expected_length + (input_length - 1) * (dilation - 1) - expected_length = math.floor((expected_length - 1) / strides + 1) - self.assertNestedEqual(outputs.shape, (2, expected_length, output_dim)) - - def test_lhs_dilation_not_using_string_padding_conv1d(self): - input_dim, output_dim = 4, 6 - cfg = Conv1D.default_config().set( - name="test", - input_dim=input_dim, - output_dim=output_dim, - window=3, - strides=1, - padding="SAME", - lhs_dilation=2, - ) - with self.assertRaises(ValueError): - cfg.instantiate(parent=None) - @parameterized.named_parameters( ("w3s1_VALID", 3, 1, "VALID"), ("w3s1_SAME", 3, 1, "SAME"), @@ -1786,6 +1731,8 @@ def test_depthwise_conv1d( state=layer_params, prng_key=prng_key, ) + output_shape = layer.output_shape(input_shape=inputs.shape) + self.assertAllEqual(outputs.shape, output_shape) # Compute ref outputs. if isinstance(padding, str): @@ -1809,6 +1756,661 @@ def test_depthwise_conv1d( ref_outputs = ref(as_torch_tensor(ref_inputs.transpose(0, 2, 1))) assert_allclose(outputs, ref_outputs.detach().numpy().transpose(0, 2, 1)) + ############################## Transposed Convolution ########################################## + + @parameterized.parameters( + (3, 1, "SAME", 1, (1, 1)), + (3, 2, "SAME", 1, (2, 1)), + (3, 3, "SAME", 1, (2, 2)), + (3, 4, "SAME", 1, (2, 3)), + (3, 1, "SAME", 2, (2, 2)), + (3, 2, "SAME", 2, (3, 2)), + (3, 3, "SAME", 2, (3, 3)), + (3, 1, "VALID", 1, (2, 2)), + (3, 2, "VALID", 1, (2, 2)), + (3, 3, "VALID", 1, (2, 2)), + (3, 4, "VALID", 1, (2, 3)), + (3, 1, "VALID", 2, (4, 4)), + (3, 2, "VALID", 2, (4, 4)), + (3, 3, "VALID", 2, (4, 4)), + (3, 1, "CAUSAL", 1, (2, 0)), + (3, 2, "CAUSAL", 1, (2, 1)), + (3, 3, "CAUSAL", 1, (2, 2)), + (3, 4, "CAUSAL", 1, (2, 3)), + (3, 1, "CAUSAL", 2, (4, 0)), + (3, 2, "CAUSAL", 2, (4, 1)), + (3, 3, "CAUSAL", 2, (4, 2)), + ) + def test_conv_transpose_explicit_padding(self, window, strides, padding, dilation, expected): + """Tests the cases in conv_transpose_explicit_padding() description.""" + explicit_padding = layers.conv_transpose_explicit_padding( + window=(window,), + strides=(strides,), + padding=padding, + dilation=(dilation,), + ) + self.assertAllEqual(explicit_padding[0], expected) + + @parameterized.parameters( + (3, 1, "SAME", 1, 4), + (3, 2, "SAME", 1, 8), + (3, 3, "SAME", 1, 12), + (3, 4, "SAME", 1, 16), + (3, 1, "SAME", 2, 4), + (3, 2, "SAME", 2, 8), + (3, 3, "SAME", 2, 12), + (3, 1, "VALID", 1, 6), + (3, 2, "VALID", 1, 9), + (3, 3, "VALID", 1, 12), + (3, 4, "VALID", 1, 16), + (3, 1, "VALID", 2, 8), + (3, 2, "VALID", 2, 11), + (3, 3, "VALID", 2, 14), + (3, 1, "CAUSAL", 1, 4), + (3, 2, "CAUSAL", 1, 8), + (3, 3, "CAUSAL", 1, 12), + (3, 4, "CAUSAL", 1, 16), + (3, 1, "CAUSAL", 2, 4), + (3, 2, "CAUSAL", 2, 8), + (3, 3, "CAUSAL", 2, 12), + ) + def test_conv_transpose_output_shape(self, window, strides, padding, dilation, expected): + """Tests the cases in conv_transpose_explicit_padding() description.""" + out_shape = layers.conv_transpose_output_shape( + in_shape=(4,), + window=(window,), + strides=(strides,), + padding=padding, + dilation=(dilation,), + ) + self.assertAllEqual(out_shape[0], expected) + + @parameterized.parameters( + (3, 1, "SAME", 1, [0, 0, 1, 1]), + (3, 2, "SAME", 1, [0, 0, 0, 0, 1, 1, 1, 1]), + (3, 3, "SAME", 1, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), + (3, 4, "SAME", 1, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]), + (3, 1, "SAME", 2, [0, 0, 1, 1]), + (3, 2, "SAME", 2, [0, 0, 0, 0, 1, 1, 1, 1]), + (3, 3, "SAME", 2, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), + (3, 1, "VALID", 1, [0, 0, 1, 1, 1, 1]), + (3, 2, "VALID", 1, [0, 0, 0, 0, 1, 1, 1, 1, 1]), + (3, 3, "VALID", 1, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), + (3, 4, "VALID", 1, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]), + (3, 1, "VALID", 2, [0, 0, 1, 1, 1, 1, 1, 1]), + (3, 2, "VALID", 2, [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]), + (3, 3, "VALID", 2, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]), + (3, 1, "CAUSAL", 1, [0, 0, 1, 1]), + (3, 2, "CAUSAL", 1, [0, 0, 0, 0, 1, 1, 1, 1]), + (3, 3, "CAUSAL", 1, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), + (3, 4, "CAUSAL", 1, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]), + (3, 1, "CAUSAL", 2, [0, 0, 1, 1]), + (3, 2, "CAUSAL", 2, [0, 0, 0, 0, 1, 1, 1, 1]), + (3, 3, "CAUSAL", 2, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), + ) + def test_compute_conv_transpose_paddings(self, window, strides, padding, dilation, expected): + """Tests the cases in conv_transpose_explicit_padding() description.""" + in_paddings = jnp.array([0, 0, 1, 1], dtype=jnp.float32)[None, :] + out_paddings = layers.compute_conv_transpose_paddings( + in_paddings, window=window, stride=strides, conv_padding=padding, dilation=dilation + ) + expected = jnp.array(expected).astype(out_paddings.dtype) + self.assertNestedEqual(out_paddings[0], expected) + + @parameterized.product( + window=[1, 3], + strides=[1, 2, 3], + padding=["SAME", "VALID", "CAUSAL"], + dilation=[1, 2], + value=[0, 1], + ) + def test_compute_conv_transpose_paddings_all0or1( + self, window, strides, padding, dilation, value + ): + """If in_paddings is all valid or invalid, out_paddings must be all valid or invalid.""" + in_paddings = jnp.full([1, 4], fill_value=value) + out_paddings = layers.compute_conv_transpose_paddings( + in_paddings, window=window, stride=strides, conv_padding=padding, dilation=dilation + ) + expected = jnp.ones_like(out_paddings) * value + self.assertNestedEqual(out_paddings, expected) + + CONVT_PADDINGS_PARAMS = dict( + in_paddings=[ + [0, 0, 0, 0, 0], + [0, 0, 0, 1, 1], + [1, 0, 0, 0, 1], + [1, 1, 1, 1, 1], + [0, 0, 0, 1, 1, 1, 0, 0, 1, 1], + [1, 1, 0, 1, 1, 1, 1, 0, 0, 0], + ], + window=[1, 3], + padding=["SAME", "VALID", "CAUSAL"], + dilation=[1, 2], + ) + + @parameterized.product(**CONVT_PADDINGS_PARAMS, strides=[1, 2, 3]) + def test_compute_conv_transpose_paddings_with_conv_paddings( + self, in_paddings, window, strides, padding, dilation + ): + """Check if ConvT -> Conv preserves information.""" + in_paddings = jnp.array(in_paddings, dtype=jnp.float32)[None, :] + out_paddings = layers.compute_conv_transpose_paddings( + in_paddings, window=window, stride=strides, conv_padding=padding, dilation=dilation + ) + + recon_paddings = layers.compute_conv_paddings( + out_paddings, window=window, stride=strides, conv_padding=padding, dilation=dilation + ) + self.assertNestedEqual(recon_paddings[0], in_paddings[0]) + + @parameterized.product(**CONVT_PADDINGS_PARAMS) + def test_compute_conv_transpose_paddings_against_conv_paddings( + self, in_paddings, window, padding, dilation + ): + # compute_conv_transpose_paddings and compute_conv_paddings are same when window_stride=1 + # (stride of Conv2D) and lhs_dilation=1 (stride of Conv2DTranspose). + strides = 1 + if padding == "VALID": + # TODO(dhwang2,ruoming): Currently, anchor is pad_left but it should be the midpoint + # between [pad_left, pad_right). Otherwise, the consistency of VALID padding is broken. + # For reference, the midpoint in SAME and CAUSAL is left_pad. + dilate_window = layers.conv_dilate_window(window=(window,), dilation=(dilation,))[0] + conv_padding = layers.conv_explicit_padding( + window=(window,), strides=(strides,), padding=padding, dilation=(dilation,) + )[0] + pad_left, pad_right = conv_padding + anchor_range = dilate_window - pad_left - pad_right + mid_point = anchor_range // 2 + anchor = pad_left + mid_point + else: + anchor = None + + in_paddings = jnp.array(in_paddings, dtype=jnp.float32)[None, :] + ref_paddings = layers.compute_conv_paddings( + in_paddings, + window=window, + stride=strides, + conv_padding=padding, + dilation=dilation, + anchor=anchor, + ) + + test_paddings = layers.compute_conv_transpose_paddings( + in_paddings, + window=window, + stride=strides, + conv_padding=padding, + dilation=dilation, + anchor=anchor, + ) + + if ref_paddings.shape != test_paddings.shape: + self.assertEqual(padding, "VALID") + dilate_window = layers.conv_dilate_window(window=(window,), dilation=(dilation,))[0] + pad_left = dilate_window - 1 + test_paddings = test_paddings[:, pad_left:-pad_left] + + assert_allclose(ref_paddings, test_paddings) + + CONVT_PARAMS = [ + (3, 1, "SAME", 1, [0, 1, 2, 2]), + (3, 2, "SAME", 1, [0, 0, 0, 0, 1, 1, 2, 1]), + (3, 3, "SAME", 1, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), + (3, 4, "SAME", 1, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0]), + (3, 1, "SAME", 2, [1, 1, 1, 1]), + (3, 2, "SAME", 2, [0, 0, 0, 1, 0, 2, 0, 2]), + (3, 3, "SAME", 2, [0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0]), + (3, 1, "VALID", 1, [0, 0, 1, 2, 2, 1]), + (3, 2, "VALID", 1, [0, 0, 0, 0, 1, 1, 2, 1, 1]), + (3, 3, "VALID", 1, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), + (3, 4, "VALID", 1, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0]), + (3, 1, "VALID", 2, [0, 0, 1, 1, 1, 1, 1, 1]), + (3, 2, "VALID", 2, [0, 0, 0, 0, 1, 0, 2, 0, 2, 0, 1]), + (3, 3, "VALID", 2, [0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1]), + (3, 1, "CAUSAL", 1, [0, 0, 1, 2]), + (3, 2, "CAUSAL", 1, [0, 0, 0, 0, 1, 1, 2, 1]), + (3, 3, "CAUSAL", 1, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), + (3, 4, "CAUSAL", 1, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0]), + (3, 1, "CAUSAL", 2, [0, 0, 1, 1]), + (3, 2, "CAUSAL", 2, [0, 0, 0, 0, 1, 0, 2, 0]), + (3, 3, "CAUSAL", 2, [0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1]), + ] + + @parameterized.parameters(*CONVT_PARAMS) + def test_conv1d_transpose_simple(self, window, strides, padding, dilation, expected): + """Tests the cases in conv_transpose_explicit_padding() description.""" + input_dim, output_dim = 1, 1 + inputs = jnp.array([0, 0, 1, 1], dtype=jnp.float32)[None, :, None] + cfg = layers.Conv1DTranspose.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + strides=strides, + padding=padding, + dilation=dilation, + bias=False, + ) + layer = cfg.instantiate(parent=None) + prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + layer_params = layer.initialize_parameters_recursively(init_key) + self.assertEqual(dict(weight=(window, input_dim, output_dim)), shapes(layer_params)) + layer_params["weight"] = jnp.ones_like(layer_params["weight"]) + + (outputs, paddings), _ = F( + layer, inputs=dict(x=inputs), is_training=True, state=layer_params, prng_key=prng_key + ) + out_shape = layer.output_shape(input_shape=inputs.shape) + self.assertAllEqual(outputs.shape, out_shape) + self.assertIsNone(paddings) + expected = jnp.array(expected).astype(outputs.dtype) + self.assertNestedEqual(outputs[0, :, 0], expected) + + @parameterized.parameters(*CONVT_PARAMS) + def test_conv2d_transpose_simple(self, window, strides, padding, dilation, expected): + """Tests the cases in conv_transpose_explicit_padding() description.""" + window = (window, 1) + strides = (strides, 1) + dilation = (dilation, 1) + input_dim, output_dim = 1, 1 + inputs = jnp.array([0, 0, 1, 1], dtype=jnp.float32)[None, :, None, None] + cfg = layers.Conv2DTranspose.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + strides=strides, + padding=padding, + dilation=dilation, + bias=False, + ) + layer = cfg.instantiate(parent=None) + prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + layer_params = layer.initialize_parameters_recursively(init_key) + self.assertEqual(dict(weight=(*window, input_dim, output_dim)), shapes(layer_params)) + layer_params["weight"] = jnp.ones_like(layer_params["weight"]) + + outputs, _ = F( + layer, inputs=dict(x=inputs), is_training=True, state=layer_params, prng_key=prng_key + ) + out_shape = layer.output_shape(input_shape=inputs.shape) + self.assertAllEqual(outputs.shape, out_shape) + expected = jnp.array(expected).astype(outputs.dtype) + self.assertNestedEqual(outputs[0, :, 0, 0], expected) + + @parameterized.parameters(*CONVT_PARAMS) + def test_conv3d_transpose_simple(self, window, strides, padding, dilation, expected): + """Tests the cases in conv_transpose_explicit_padding() description.""" + window = (window, 1, 1) + strides = (strides, 1, 1) + dilation = (dilation, 1, 1) + input_dim, output_dim = 1, 1 + inputs = jnp.array([0, 0, 1, 1], dtype=jnp.float32)[None, :, None, None, None] + cfg = layers.Conv3DTranspose.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + strides=strides, + padding=padding, + dilation=dilation, + bias=False, + ) + layer = cfg.instantiate(parent=None) + prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + layer_params = layer.initialize_parameters_recursively(init_key) + self.assertEqual(dict(weight=(*window, input_dim, output_dim)), shapes(layer_params)) + layer_params["weight"] = jnp.ones_like(layer_params["weight"]) + + outputs, _ = F( + layer, inputs=dict(x=inputs), is_training=True, state=layer_params, prng_key=prng_key + ) + out_shape = layer.output_shape(input_shape=inputs.shape) + self.assertAllEqual(outputs.shape, out_shape) + expected = jnp.array(expected).astype(outputs.dtype) + self.assertNestedEqual(outputs[0, :, 0, 0, 0], expected) + + @parameterized.product(window=(1, 3, 5), padding=("SAME", "VALID", "CAUSAL"), dilation=(1, 2)) + def test_conv1d_transpose_against_conv1d(self, window, padding, dilation): + # Conv1D and Conv1DTranspose are same when window_stride=1 + # (stride of Conv1D) and lhs_dilation=1 (stride of Conv1DTranspose). + input_dim, output_dim = 4, 6 + ref_cfg = Conv1D.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + padding=padding, + dilation=dilation, + ) + ref_layer = ref_cfg.instantiate(parent=None) + + test_cfg = layers.Conv1DTranspose.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + padding=padding, + dilation=dilation, + ) + test_layer = test_cfg.instantiate(parent=None) + + # Initialize layer parameters. + prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + ref_states = ref_layer.initialize_parameters_recursively(init_key) + + # Random inputs. + prng_key, input_key = jax.random.split(prng_key) + inputs = jax.random.normal(input_key, [2, 17, input_dim]) + # Compute layer outputs. + ref_outputs, _ = F( + ref_layer, inputs=dict(x=inputs), is_training=True, state=ref_states, prng_key=prng_key + ) + + (test_outputs, _), _ = F( + test_layer, inputs=dict(x=inputs), is_training=True, state=ref_states, prng_key=prng_key + ) + if ref_outputs.shape != test_outputs.shape: + self.assertEqual(padding, "VALID") + dilate_window = layers.conv_dilate_window(window=(window,), dilation=(dilation,))[0] + pad_left = dilate_window - 1 + test_outputs = test_outputs[:, pad_left:-pad_left] + assert_allclose(ref_outputs, test_outputs) + + @parameterized.product(window=(1, 3, 5), padding=("SAME", "VALID", "CAUSAL"), dilation=(1, 2)) + def test_conv2d_transpose_against_conv2d(self, window, padding, dilation): + # Conv2D and Conv2DTranspose are same when window_stride=1 + # (stride of Conv2D) and lhs_dilation=1 (stride of Conv2DTranspose). + window = (window, window) + dilation = (dilation, dilation) + input_dim, output_dim = 4, 6 + ref_cfg = Conv2D.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + padding=padding, + dilation=dilation, + ) + ref_layer = ref_cfg.instantiate(parent=None) + + test_cfg = layers.Conv2DTranspose.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + padding=padding, + dilation=dilation, + transpose_kernel=False, + ) + test_layer = test_cfg.instantiate(parent=None) + + # Initialize layer parameters. + prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + ref_states = ref_layer.initialize_parameters_recursively(init_key) + + # Random inputs. + prng_key, input_key = jax.random.split(prng_key) + width, height = 12, 13 + inputs = jax.random.normal(input_key, [2, width, height, input_dim]) + # Compute layer outputs. + ref_outputs, _ = F( + ref_layer, + inputs=dict(x=inputs), + is_training=True, + state=ref_states, + prng_key=prng_key, + ) + + test_outputs, _ = F( + test_layer, + inputs=dict(x=inputs), + is_training=True, + state=ref_states, + prng_key=prng_key, + ) + if ref_outputs.shape != test_outputs.shape: + self.assertEqual(padding, "VALID") + dilate_window = layers.conv_dilate_window(window=window, dilation=dilation) + pad_left = tuple(w - 1 for w in dilate_window) + test_outputs = test_outputs[:, pad_left[0] : -pad_left[0], pad_left[1] : -pad_left[1]] + + assert_allclose(ref_outputs, test_outputs) + + @parameterized.product(window=(1, 3, 5), padding=("SAME", "VALID", "CAUSAL"), dilation=(1, 2)) + def test_conv2d_transpose_against_conv2d_with_paddings(self, window, padding, dilation): + # Conv2DWith1DPadding and Conv2DTransposeWith1DPadding are same when window_stride=1 + # (stride of Conv2D) and lhs_dilation=1 (stride of Conv2DTranspose). + window = (window, window) + dilation = (dilation, dilation) + input_dim, output_dim = 4, 6 + if padding == "VALID": + # TODO(dhwang2,ruoming): Currently, anchor is pad_left but it should be the midpoint + # between [pad_left, pad_right). Otherwise, the consistency of VALID padding is broken. + # For reference, the midpoint in SAME and CAUSAL is left_pad. + strides = (1, 1) + dilate_window = layers.conv_dilate_window(window=window, dilation=dilation)[0] + conv_padding = layers.conv_explicit_padding( + window=window, strides=strides, padding=padding, dilation=dilation + ) + pad_left, pad_right = conv_padding[0] + anchor_range = dilate_window - pad_left - pad_right + mid_point = anchor_range // 2 + anchor = pad_left + mid_point + else: + anchor = None + + ref_cfg = Conv2DWith1DPadding.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + padding=padding, + dilation=dilation, + anchor=anchor, + ) + ref_layer = ref_cfg.instantiate(parent=None) + + test_cfg = layers.Conv2DTransposeWith1DPadding.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + padding=padding, + dilation=dilation, + anchor=anchor, + transpose_kernel=False, + ) + test_layer = test_cfg.instantiate(parent=None) + + # Initialize layer parameters. + prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + ref_states = ref_layer.initialize_parameters_recursively(init_key) + + # Random inputs. + prng_key, input_key = jax.random.split(prng_key) + width, height = 12, 13 + inputs = jax.random.normal(input_key, [2, width, height, input_dim]) + paddings = jnp.zeros([2, width], dtype=inputs.dtype).at[:, -2:].set(1) + # Compute layer outputs. + (ref_outputs, ref_paddings), _ = F( + ref_layer, + inputs=dict(x=inputs, paddings=paddings), + is_training=True, + state=ref_states, + prng_key=prng_key, + ) + + (test_outputs, test_paddings), _ = F( + test_layer, + inputs=dict(x=inputs, paddings=paddings), + is_training=True, + state=ref_states, + prng_key=prng_key, + ) + if ref_outputs.shape != test_outputs.shape: + self.assertEqual(padding, "VALID") + dilate_window = layers.conv_dilate_window(window=window, dilation=dilation) + pad_left = tuple(w - 1 for w in dilate_window) + test_outputs = test_outputs[:, pad_left[0] : -pad_left[0], pad_left[1] : -pad_left[1]] + test_paddings = test_paddings[:, pad_left[0] : -pad_left[0]] + + assert_allclose(ref_outputs, test_outputs) + assert_allclose(ref_paddings, test_paddings) + + @parameterized.product(window=(1, 3, 5), padding=("SAME", "VALID", "CAUSAL"), dilation=(1, 2)) + def test_conv3d_transpose_against_conv3d(self, window, padding, dilation): + # Conv3D and Conv3DTranspose are same when window_stride=1 + # (stride of Conv3D) and lhs_dilation=1 (stride of Conv3DTranspose). + window = (window, window, window) + dilation = (dilation, dilation, dilation) + input_dim, output_dim = 4, 6 + ref_cfg = Conv3D.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + padding=padding, + dilation=dilation, + ) + ref_layer = ref_cfg.instantiate(parent=None) + + test_cfg = layers.Conv3DTranspose.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + padding=padding, + dilation=dilation, + ) + test_layer = test_cfg.instantiate(parent=None) + + # Initialize layer parameters. + prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + ref_states = ref_layer.initialize_parameters_recursively(init_key) + + # Random inputs. + prng_key, input_key = jax.random.split(prng_key) + width, height, depth = 9, 8, 7 + inputs = jax.random.normal(input_key, [2, width, height, depth, input_dim]) + # Compute layer outputs. + ref_outputs, _ = F( + ref_layer, + inputs=dict(x=inputs), + is_training=True, + state=ref_states, + prng_key=prng_key, + ) + + test_outputs, _ = F( + test_layer, + inputs=dict(x=inputs), + is_training=True, + state=ref_states, + prng_key=prng_key, + ) + if ref_outputs.shape != test_outputs.shape: + self.assertEqual(padding, "VALID") + dilate_window = layers.conv_dilate_window(window=window, dilation=dilation) + pad_left = tuple(w - 1 for w in dilate_window) + test_outputs = test_outputs[ + :, + pad_left[0] : -pad_left[0], + pad_left[1] : -pad_left[1], + pad_left[2] : -pad_left[2], + ] + + assert_allclose(ref_outputs, test_outputs) + + @parameterized.product( + window=(1, 3, 5), + strides=(1, 2), + padding=("SAME", "VALID", "CAUSAL"), + dilation=(1, 2), + anchor=(None, 1), + ) + def test_conv1d_transpose_against_conv2d_transpose_with_1d_padding( + self, + window, + strides, + padding: ConvPaddingType, + dilation, + anchor, + ): + if anchor is not None: + dilate_window = layers.conv_dilate_window(window=(window,), dilation=(dilation,))[0] + anchor = dilate_window - 1 + + input_dim, output_dim = 4, 6 + ref_cfg = layers.Conv2DTransposeWith1DPadding.default_config().set( + name="ref", + input_dim=input_dim, + output_dim=output_dim, + window=(window, 1), + strides=(strides, 1), + padding=padding, + dilation=(dilation, 1), + anchor=anchor, + ) + ref_layer = ref_cfg.instantiate(parent=None) + + test_cfg = layers.Conv1DTranspose.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + strides=strides, + padding=padding, + dilation=dilation, + anchor=anchor, + ) + test_layer = test_cfg.instantiate(parent=None) + + # Initialize layer parameters. + prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + state = ref_layer.initialize_parameters_recursively(init_key) + test_state = dict( + bias=state["bias"], weight=einops.rearrange(state["weight"], "t 1 i o -> t i o") + ) + + # Generate a batch of 10 input sequences. + batch_size, max_seq_len = 10, 10 + + prng_key, input_key = jax.random.split(prng_key) + inputs = jax.random.normal(input_key, [batch_size, max_seq_len, input_dim]) + # The 10 sequences have length 1 to 10. + paddings = jnp.triu(jnp.ones((batch_size, max_seq_len)), k=1) + + (test_outputs, test_paddings), _ = F( + test_layer, + inputs=dict(x=inputs, paddings=paddings), + is_training=True, + state=test_state, + prng_key=prng_key, + ) + + inputs = einops.rearrange(inputs, "b t i -> b t 1 i") + (ref_outputs, ref_paddings), _ = F( + ref_layer, + inputs=dict(x=inputs, paddings=paddings), + is_training=True, + state=state, + prng_key=prng_key, + ) + ref_outputs = einops.rearrange(ref_outputs, "b t 1 o -> b t o") + + assert_allclose(ref_paddings, test_paddings) + assert_allclose(ref_outputs, test_outputs) + @parameterized.parameters( itertools.product( (None, 0.0, 0.2, 1.0, -0.1),