diff --git a/axlearn/common/layers.py b/axlearn/common/layers.py index eb1f89a0..6dd51bcd 100644 --- a/axlearn/common/layers.py +++ b/axlearn/common/layers.py @@ -22,6 +22,7 @@ from typing import Any, Callable, Literal, Optional, Union import chex +import einops import jax from absl import logging from jax import nn @@ -712,19 +713,6 @@ def forward(self, x: Tensor) -> Tensor: return super().forward(x) -def _check_conv_cfg(*, padding: ConvPaddingType, strides: Sequence[int]): - if any(s < 1 for s in strides): - raise NotImplementedError(f"strides ({strides}) must be a positive integer.") - - if isinstance(padding, str): - if padding not in SUPPORT_CONV_PADDING: - raise NotImplementedError(f"{padding} padding is not supported.") - else: - padding_flattened = (p for p_tuple in padding for p in p_tuple) - if any(p < 0 for p in padding_flattened): - raise NotImplementedError("Negative padding is not supported") - - class MaxPool2D(BaseLayer): """A wrapper for the 2D max pooling layer.""" @@ -777,6 +765,34 @@ def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Opti return [input_shape[0], output_height, output_width, input_shape[3]] +############################## Convolution ######################################################### + + +def _check_conv_cfg( + *, + window: Sequence[int], + strides: Sequence[int], + padding: ConvPaddingType, + dilation: Sequence[int], +): + if any(w < 1 for w in window): + raise NotImplementedError(f"window ({window}) must be a positive integer.") + + if any(s < 1 for s in strides): + raise NotImplementedError(f"strides ({strides}) must be a positive integer.") + + if isinstance(padding, str): + if padding not in SUPPORT_CONV_PADDING: + raise NotImplementedError(f"{padding} padding is not supported.") + else: + padding_flattened = (p for p_tuple in padding for p in p_tuple) + if any(p < 0 for p in padding_flattened): + raise NotImplementedError("Negative padding is not supported") + + if any(d < 1 for d in dilation): + raise NotImplementedError(f"dilation ({dilation}) must be a positive integer.") + + class BaseConv(BaseLayer): """Base class for convolution layers.""" @@ -799,7 +815,7 @@ def _compute_fan_axes(self, name: str, parameter_spec: ParameterSpec) -> Optiona # Copied from jax.lax._dilate_shape # https://github.com/jax-ml/jax/blob/2d78b172266870bd755b039f6faa2056a51930f9/jax/_src/lax/lax.py#L5763 -def conv_dilate_window(*, window: Sequence[int], dilation: Optional[Sequence[int]] = None): +def conv_dilate_window(*, window: Sequence[int], dilation: Sequence[int]): """Returns dilated effective window size. Args: @@ -809,7 +825,7 @@ def conv_dilate_window(*, window: Sequence[int], dilation: Optional[Sequence[int Returns: The dilated effective window size. """ - if dilation is None or all(d == 1 for d in dilation): + if all(d == 1 for d in dilation): return window return tuple(max(1 + d * (w - 1), 0) for w, d in zip(window, dilation)) @@ -822,15 +838,17 @@ def conv_explicit_padding( window: Sequence[int], strides: Sequence[int], padding: ConvPaddingType, - dilation: Optional[Sequence[int]] = None, + dilation: Sequence[int], ) -> ConvPaddingType: """Returns the explicit padding for "SAME", "VALID", and "CAUSAL" modes. Each mode follows the formulas below: - * SAME: (pad_total//2, pad_total - pad_total//2) + * SAME: (pad_total//2, pad_total - pad_total//2) s.t. pad_total = window-1 * VALID: (0, 0) - * CAUSAL: (dilate_window - stride, stride - 1) - s.t. dilate_window = (window - 1) * dilation + 1. Check conv_dilate_window() + * CAUSAL: (window - stride, stride - 1) + + Note: In the above equation, `window` can be replaced with `dilate_window` when dilation > 1. + dilate_window = (window - 1) * dilation + 1. Check conv_dilate_window() For example, window=5 and stride=2, * SAME: padding = (2, 2) @@ -895,30 +913,22 @@ def conv_explicit_padding( """ if not isinstance(padding, str): return padding + window = conv_dilate_window(window=window, dilation=dilation) - if dilation is None: - dilation = (1,) * len(window) - - def same_padding(window, dilation): - dilate_window = conv_dilate_window(window=window, dilation=dilation) - pad_total = tuple(w - 1 for w in dilate_window) + def same_padding(window): + pad_total = tuple(w - 1 for w in window) pad_left = tuple(pt // 2 for pt in pad_total) pad_right = tuple(pt - pl for pt, pl in zip(pad_total, pad_left)) return tuple(zip(pad_left, pad_right)) if padding == "SAME": - return same_padding(window, dilation) + return same_padding(window) elif padding == "VALID": return ((0, 0),) * len(window) elif padding == "CAUSAL": - dilate_window = conv_dilate_window(window=window[:1], dilation=dilation[:1])[0] - stride = strides[0] - pad_left = dilate_window - stride - pad_right = stride - 1 - assert pad_left + pad_right == dilate_window - 1 - causal_padding = ((pad_left, pad_right),) + causal_padding = ((window[0] - strides[0], strides[0] - 1),) if len(window) > 1: - causal_padding += same_padding(window[1:], dilation[1:]) + causal_padding += same_padding(window[1:]) return causal_padding else: raise ValueError(f"{padding} padding is not supported.") @@ -930,7 +940,7 @@ def conv_output_shape( window: Sequence[int], strides: Sequence[int], padding: ConvPaddingType, - dilation: Optional[Sequence[int]] = None, + dilation: Sequence[int], ) -> Sequence[int]: """Returns output size for convolution. @@ -991,6 +1001,7 @@ class Config(BaseConv.Config): # Paddings: "SAME", "VALID", "CAUSAL" or ((top, bottom), (left, right)). # Note: Sequence models use the first component to represent time. padding: ConvPaddingType = ((0, 0), (0, 0)) + dilation: tuple[int, int] = (1, 1) # The convolution dilation. output_dim: Required[int] = REQUIRED # Output feature dim. bias: bool = True # Whether to add a bias. # The number of groups in which the input is split along the channel axis. @@ -1011,7 +1022,9 @@ def default_config(cls): def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: cfg = self.config - _check_conv_cfg(padding=cfg.padding, strides=cfg.strides) + _check_conv_cfg( + window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation + ) params = dict( weight=ParameterSpec( shape=list(cfg.window) @@ -1029,7 +1042,7 @@ def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: def forward(self, x: Tensor) -> Tensor: cfg = self.config conv_padding = conv_explicit_padding( - window=cfg.window, strides=cfg.strides, padding=cfg.padding + window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation ) output = jax.lax.conv_general_dilated( lhs=x, @@ -1037,6 +1050,7 @@ def forward(self, x: Tensor) -> Tensor: window_strides=cfg.strides, dimension_numbers=("NHWC", "HWIO", "NHWC"), padding=conv_padding, + rhs_dilation=cfg.dilation, feature_group_count=cfg.num_input_dim_groups, ) if cfg.bias: @@ -1056,7 +1070,11 @@ def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Opti in_shape = input_shape[1:3] out_shape = conv_output_shape( - in_shape, window=cfg.window, strides=cfg.strides, padding=cfg.padding + in_shape, + window=cfg.window, + strides=cfg.strides, + padding=cfg.padding, + dilation=cfg.dilation, ) return [input_shape[0], *out_shape, cfg.output_dim] @@ -1067,7 +1085,7 @@ def compute_conv_paddings( window: int, stride: int, conv_padding: ConvPaddingType, - dilation: Optional[int] = None, + dilation: int = 1, anchor: Optional[int] = None, ): """Compute output paddings w.r.t. conv_padding. @@ -1095,7 +1113,6 @@ def compute_conv_paddings( ValueError: If anchor is not between left_time_padding and right_time_padding. """ chex.assert_rank(in_paddings, 2) - dilation = dilation or 1 conv_padding = conv_explicit_padding( window=(window,), strides=(stride,), padding=conv_padding, dilation=(dilation,) ) @@ -1128,93 +1145,6 @@ def compute_conv_paddings( return out_paddings -class Conv2DTranspose(BaseConv): - """The 2-D transposed convolution layer. - - Kernel weights have the HWIO layout and in the shape of (window[0], window[1], output_dim, - input_dim). Both inputs and outputs will be in the NHWC layout. - """ - - @config_class - class Config(BaseConv.Config): - """Configures Conv2DTranspose.""" - - window: tuple[int, int] = (1, 1) - strides: tuple[int, int] = (1, 1) - padding: ConvPaddingType = ((0, 0), (0, 0)) - output_dim: Required[int] = REQUIRED # Output feature dim. - bias: bool = True # Whether to add a bias. - - @classmethod - def default_config(cls): - cfg = super().default_config() - cfg.param_partition_spec = (None, None, None, None) - return cfg - - def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: - cfg = self.config - _check_conv_cfg(padding=cfg.padding, strides=cfg.strides) - params = dict( - weight=ParameterSpec( - shape=list(cfg.window) + [cfg.output_dim, cfg.input_dim], - mesh_axes=cfg.param_partition_spec, - factorization=FactorizationSpec(axes=(None, None, "row", "col")), - ) - ) - if cfg.bias: - params["bias"] = ParameterSpec( - shape=[cfg.output_dim], mesh_axes=(cfg.param_partition_spec[-1],) - ) - return params - - def forward(self, x: Tensor) -> Tensor: - cfg = self.config - output = jax.lax.conv_transpose( - lhs=x, - rhs=self.parameters["weight"], - strides=cfg.strides, - dimension_numbers=("NHWC", "HWIO", "NHWC"), - padding=cfg.padding, - # if True flips spatial axes and swaps the input/output channel axes of the kernel. - # This makes the output of this function identical to the gradient-derived functions - # like keras.layers.Conv2DTranspose applied to the same kernel. - transpose_kernel=True, - ) - if cfg.bias: - output += self.parameters["bias"] - return output - - @nowrap - def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Optional[int]]: - cfg = self.config - if len(input_shape) != 4: - raise ValueError(f"We expect len(input_shape) = 4, but got {len(input_shape)}.") - if input_shape[-1] != cfg.input_dim: - raise ValueError( - f"input_shape[-1] = {input_shape[-1]} does not match " - "cfg.input_dim = {cfg.input_dim}." - ) - input_height, input_width = input_shape[1:3] - output_height, output_width = None, None - - if cfg.padding == "SAME": - if cfg.padding == "SAME" and any(s > 1 for s in cfg.strides): - raise NotImplementedError("SAME padding does not support strides > 1") - if input_height is not None: - output_height = input_height * cfg.strides[0] - if input_width is not None: - output_width = input_width * cfg.strides[0] - elif cfg.padding == "VALID": - if input_height is not None: - output_height = input_height * cfg.strides[0] + max( - cfg.window[0] - cfg.strides[0], 0 - ) - if input_width is not None: - output_width = input_width * cfg.strides[1] + max(cfg.window[1] - cfg.strides[1], 0) - - return [input_shape[0], output_height, output_width, cfg.output_dim] - - class Conv2DWith1DPadding(Conv2D): """The 2-D convolution with 1-D padding on the time axis.""" @@ -1324,6 +1254,7 @@ def forward(self, x: Tensor, *, paddings: Tensor) -> tuple[Tensor, Tensor]: window=cfg.window[0], stride=cfg.strides[0], conv_padding=cfg.padding, + dilation=cfg.dilation[0], anchor=cfg.anchor, ) # Apply padding to the outputs. @@ -1331,98 +1262,6 @@ def forward(self, x: Tensor, *, paddings: Tensor) -> tuple[Tensor, Tensor]: return output, output_paddings -class Conv3D(BaseConv): - """The 3-D convolution layer. - - Kernel weights have the HWDIO layout and in the shape of (window[0], window[1], - window[2], input_dim, output_dim). Both inputs and outputs will be in the NHWDC layout. - """ - - @config_class - class Config(BaseConv.Config): - """Configures Conv3D.""" - - window: tuple[int, int, int] = (1, 1, 1) # The convolution window. - strides: tuple[int, int, int] = (1, 1, 1) # The convolution strides. - - # Paddings: "SAME" or "VALID, or ((top, bottom), (left, right), (front, back)) - padding: ConvPaddingType = ( - (0, 0), - (0, 0), - (0, 0), - ) - - output_dim: Required[int] = REQUIRED # Output feature dim. - bias: bool = True # Whether to add a bias. - - # The number of groups in which the input is split along the channel axis. - # input_dim and output_dim must both be divisible by num_input_dim_groups. For example, - # - At num_input_dim_groups=1, all inputs are convolved to all outputs (the default). - # - At num_input_dim_groups=2, the operation is equivalent to concatenating two conv layers - # side by side, each seeing half the input and producing half the output channels. - # - At num_input_dim_groups=input_dim, each input channel is convolved with its own - # set of filters (of size output_dim / input_dim); if further output_dim == K * input_dim, - # where K is a positive integer, the operation is also known as a "depthwise convolution". - num_input_dim_groups: Optional[int] = 1 - - @classmethod - def default_config(cls): - cfg = super().default_config() - cfg.param_partition_spec = (None, None, None, None, None) - return cfg - - def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: - cfg = self.config - _check_conv_cfg(padding=cfg.padding, strides=cfg.strides) - params = dict( - weight=ParameterSpec( - shape=list(cfg.window) - + [cfg.input_dim // cfg.num_input_dim_groups, cfg.output_dim], - mesh_axes=cfg.param_partition_spec, - factorization=FactorizationSpec(axes=(None, None, None, "row", "col")), - ) - ) - if cfg.bias: - params["bias"] = ParameterSpec( - shape=[cfg.output_dim], mesh_axes=(cfg.param_partition_spec[-1],) - ) - return params - - def forward(self, x: Tensor) -> Tensor: - cfg = self.config - conv_padding = conv_explicit_padding( - window=cfg.window, strides=cfg.strides, padding=cfg.padding - ) - output = jax.lax.conv_general_dilated( - lhs=x, - rhs=self.parameters["weight"], - window_strides=cfg.strides, - dimension_numbers=("NHWDC", "HWDIO", "NHWDC"), - padding=conv_padding, - feature_group_count=cfg.num_input_dim_groups, - ) - if cfg.bias: - output += self.parameters["bias"] - return output - - @nowrap - def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Optional[int]]: - cfg = self.config - if len(input_shape) != 5: - raise ValueError(f"We expect len(input_shape) = 5, but got {len(input_shape)}.") - if input_shape[-1] != cfg.input_dim: - raise ValueError( - f"input_shape[-1] = {input_shape[-1]} does not match " - f"cfg.input_dim = {cfg.input_dim}." - ) - - in_shape = input_shape[1:4] - out_shape = conv_output_shape( - in_shape, window=cfg.window, strides=cfg.strides, padding=cfg.padding - ) - return [input_shape[0], *out_shape, cfg.output_dim] - - class Conv1D(BaseConv): """The 1D convolution layer. @@ -1450,12 +1289,9 @@ class Config(BaseConv.Config): # set of filters (of size output_dim / input_dim); if further output_dim == K * input_dim, # where K is a positive integer, the operation is also known as a "depthwise convolution". num_input_dim_groups: Optional[int] = 1 - # LHS_dilation is also known as transposed convolution. It is either None, or an int - # indicating the dilation factor applied on the input. - lhs_dilation: Optional[int] = None # RHS_dilation is also known as atrous convolution. It is either None, or an int indicating # dilation factor applied to the weight. - rhs_dilation: Optional[int] = None + dilation: int = 1 @classmethod def default_config(cls): @@ -1463,14 +1299,14 @@ def default_config(cls): cfg.param_partition_spec = (None, None, "model") return cfg - def __init__(self, cfg: Config, *, parent: Optional[Module]): - super().__init__(cfg, parent=parent) - # Check lhs_dilation and padding setting compatibility. - if cfg.lhs_dilation is not None and cfg.lhs_dilation != 1 and isinstance(cfg.padding, str): - raise ValueError("String padding is not supported for LHS dilation.") - def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: cfg = self.config + _check_conv_cfg( + window=(cfg.window,), + strides=(cfg.strides,), + padding=cfg.padding, + dilation=(cfg.dilation,), + ) if cfg.padding not in SUPPORT_CONV_PADDING: left, right = cfg.padding[0] if any(p < 0 for p in (left, right)): @@ -1489,11 +1325,12 @@ def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: def forward(self, x: Tensor) -> Tensor: cfg = self.config - dilation = cfg.rhs_dilation or 1 conv_padding = conv_explicit_padding( - window=(cfg.window,), strides=(cfg.strides,), padding=cfg.padding, dilation=(dilation,) + window=(cfg.window,), + strides=(cfg.strides,), + padding=cfg.padding, + dilation=(cfg.dilation,), ) - transpose_dilation = cfg.lhs_dilation or 1 output = jax.lax.conv_general_dilated( lhs=x, rhs=self.parameters["weight"], @@ -1501,8 +1338,7 @@ def forward(self, x: Tensor) -> Tensor: dimension_numbers=("NWC", "WIO", "NWC"), padding=conv_padding, feature_group_count=cfg.num_input_dim_groups, - lhs_dilation=(transpose_dilation,), - rhs_dilation=(dilation,), + rhs_dilation=(cfg.dilation,), ) if cfg.bias: output += self.parameters["bias"] @@ -1544,19 +1380,13 @@ def forward(self, x: Tensor, *, paddings: Tensor) -> tuple[Tensor, Tensor]: # Apply Conv1D. output = super().forward(x) - # TODO(dhwang2): Implement Conv1DTranspose separately for lhs_dilation. It's problematic - # for lhs_dilation (Conv Transpose) and rhs_dilation (Dilated Convolution) to be part of - # the same class. Not only are they never used together, but their combined usage would - # result in undefined behavior. Additionally, the logic for handling explicit padding and - # paddings is fundamentally different between them, so supporting both in a single class - # makes the code error-prone. # Compute paddings conv output. output_paddings = compute_conv_paddings( paddings, window=cfg.window, stride=cfg.strides, conv_padding=cfg.padding, - dilation=cfg.rhs_dilation, + dilation=cfg.dilation, anchor=cfg.anchor, ) # Apply padding to the outputs. @@ -1580,6 +1410,7 @@ class Config(BaseConv.Config): # Paddings: "SAME", "VALID", "CAUSAL" or (left, right). # For causal convolution, set padding to (window - 1, 0). padding: ConvPaddingType = ((0, 0),) + dilation: int = 1 # The convolution dilation. bias: bool = True # Whether to add a bias. @classmethod @@ -1590,6 +1421,12 @@ def default_config(cls): def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: cfg = self.config + _check_conv_cfg( + window=(cfg.window,), + strides=(cfg.strides,), + padding=cfg.padding, + dilation=(cfg.dilation,), + ) if cfg.padding not in SUPPORT_CONV_PADDING: left, right = cfg.padding[0] if any(p < 0 for p in (left, right)): @@ -1613,7 +1450,10 @@ def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: def forward(self, x: Tensor) -> Tensor: cfg = self.config conv_padding = conv_explicit_padding( - window=(cfg.window,), strides=(cfg.strides,), padding=cfg.padding + window=(cfg.window,), + strides=(cfg.strides,), + padding=cfg.padding, + dilation=(cfg.dilation,), ) output = jax.lax.conv_general_dilated( lhs=x, @@ -1621,6 +1461,7 @@ def forward(self, x: Tensor) -> Tensor: window_strides=(cfg.strides,), dimension_numbers=("NWC", "WIO", "NWC"), padding=conv_padding, + rhs_dilation=(cfg.dilation,), feature_group_count=cfg.input_dim, ) if cfg.bias: @@ -1628,6 +1469,969 @@ def forward(self, x: Tensor) -> Tensor: return output +class Conv3D(BaseConv): + """The 3-D convolution layer. + + Kernel weights have the HWDIO layout and in the shape of (window[0], window[1], + window[2], input_dim, output_dim). Both inputs and outputs will be in the NHWDC layout. + """ + + @config_class + class Config(BaseConv.Config): + """Configures Conv3D.""" + + window: tuple[int, int, int] = (1, 1, 1) # The convolution window. + strides: tuple[int, int, int] = (1, 1, 1) # The convolution strides. + + # Paddings: "SAME" or "VALID, or ((top, bottom), (left, right), (front, back)) + padding: ConvPaddingType = ( + (0, 0), + (0, 0), + (0, 0), + ) + dilation: tuple[int, int, int] = (1, 1, 1) # The convolution dilation. + + output_dim: Required[int] = REQUIRED # Output feature dim. + bias: bool = True # Whether to add a bias. + + # The number of groups in which the input is split along the channel axis. + # input_dim and output_dim must both be divisible by num_input_dim_groups. For example, + # - At num_input_dim_groups=1, all inputs are convolved to all outputs (the default). + # - At num_input_dim_groups=2, the operation is equivalent to concatenating two conv layers + # side by side, each seeing half the input and producing half the output channels. + # - At num_input_dim_groups=input_dim, each input channel is convolved with its own + # set of filters (of size output_dim / input_dim); if further output_dim == K * input_dim, + # where K is a positive integer, the operation is also known as a "depthwise convolution". + num_input_dim_groups: Optional[int] = 1 + + @classmethod + def default_config(cls): + cfg = super().default_config() + cfg.param_partition_spec = (None, None, None, None, None) + return cfg + + def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: + cfg = self.config + _check_conv_cfg( + window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation + ) + params = dict( + weight=ParameterSpec( + shape=list(cfg.window) + + [cfg.input_dim // cfg.num_input_dim_groups, cfg.output_dim], + mesh_axes=cfg.param_partition_spec, + factorization=FactorizationSpec(axes=(None, None, None, "row", "col")), + ) + ) + if cfg.bias: + params["bias"] = ParameterSpec( + shape=[cfg.output_dim], mesh_axes=(cfg.param_partition_spec[-1],) + ) + return params + + def forward(self, x: Tensor) -> Tensor: + cfg = self.config + conv_padding = conv_explicit_padding( + window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation + ) + output = jax.lax.conv_general_dilated( + lhs=x, + rhs=self.parameters["weight"], + window_strides=cfg.strides, + dimension_numbers=("NHWDC", "HWDIO", "NHWDC"), + padding=conv_padding, + rhs_dilation=cfg.dilation, + feature_group_count=cfg.num_input_dim_groups, + ) + if cfg.bias: + output += self.parameters["bias"] + return output + + @nowrap + def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Optional[int]]: + cfg = self.config + if len(input_shape) != 5: + raise ValueError(f"We expect len(input_shape) = 5, but got {len(input_shape)}.") + if input_shape[-1] != cfg.input_dim: + raise ValueError( + f"input_shape[-1] = {input_shape[-1]} does not match " + f"cfg.input_dim = {cfg.input_dim}." + ) + + in_shape = input_shape[1:4] + out_shape = conv_output_shape( + in_shape, + window=cfg.window, + strides=cfg.strides, + padding=cfg.padding, + dilation=cfg.dilation, + ) + return [input_shape[0], *out_shape, cfg.output_dim] + + +############################## Transposed Convolution ############################################## + + +# Based on jax.lax.convolution._conv_transpose_padding, but ours is more intuitive. +def convt_explicit_padding( + *, + window: Sequence[int], + strides: Sequence[int], + padding: ConvPaddingType, + dilation: Sequence[int], +) -> ConvPaddingType: + """Convert str padding to tuple padding for conv_transpose. + + Each mode follows the formulas below, + * SAME: (min(window-1, ceil((w+s-2)/2)), max(stride-1, floor((w+s-2)/2))) + pad_total = window+stride-2 + when stride > window -> (window-1, stride-1) + * VALID: (window-1, max(stride-1, window-1)) + pad_total = window+stride-2 + max(window-stride, 0) + when stride > window -> (window-1, stride-1) + * CAUSAL: (window-1, stride-1) + pad_total = window+stride-2 + + Note: output_size = input_size*stride - (window+stride-2) + pad_total + = input_size*stride <- "SAME", "CAUSAL" + = input_size*stride + max(window-stride, 0) <- "VALID" + + Note: In the above equation, `window` can be replaced with `dilate_window` when dilation > 1. + dilate_window = (window - 1) * dilation + 1. Check conv_dilate_window() + + The following illustration demonstrates how Conv Transpose operates, assuming all kernel values + are set to 1 for simplicity in showcasing output values. + + In the window=3 and stride=1 case, this function creates outputs as follows: + * "SAME" padding=(1, 1) + pad| |pad + paddings: 0|0 0 1 1|0 + 0 0 0 -> 0 + 0 0 1 -> 1 + 0 1 1 -> 2 + 1 1 0 -> 2 + + * "VALID" padding=(2, 2) + pad | |pad + paddings: 0 0|0 0 1 1|0 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 1 -> 1 + 0 1 1 -> 2 + 1 1 0 -> 2 + 1 0 0 -> 1 + + * "CAUSAL" padding=(2, 0) + pad | |pad + paddings: 0 0|0 0 1 1| + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 1 -> 1 + 0 1 1 -> 2 + + In the window=3 and stride=2 case, this function creates outputs as follows: + * "SAME" padding=(2, 1) + pad | |pad + paddings: 0 0|0 * 0 * 1 * 1|0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 1 -> 1 + 0 1 0 -> 1 + 1 0 1 -> 2 + 0 1 0 -> 1 + + * "VALID" padding=(2, 2) + pad | |pad + paddings: 0 0|0 * 0 * 1 * 1|0 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 1 -> 1 + 0 1 0 -> 1 + 1 0 1 -> 2 + 0 1 0 -> 1 + 1 0 0 -> 1 + + * "CAUSAL" padding=(2, 1) + pad | |pad + paddings: 0 0|0 * 0 * 1 * 1|0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 1 -> 1 + 0 1 0 -> 1 + 1 0 1 -> 2 + 0 1 0 -> 1 + + In the window=3 and stride=3 case, this function creates outputs as follows: + * "SAME", "VALID" and "CAUSAL" padding=(2, 2) + pad | |pad + paddings: 0 0|0 * * 0 * * 1 * * 1|0 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 1 -> 1 + 0 1 0 -> 1 + 1 0 0 -> 1 + 0 0 1 -> 1 + 0 1 0 -> 1 + 1 0 0 -> 1 + + In the window=3 and stride=4 case, this function creates outputs as follows: + * "SAME", "VALID" and "CAUSAL" padding=(2, 3) + pad | |pad + paddings: 0 0|0 * * * 0 * * * 1 * * * 1|0 0 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 0 -> 0 + 0 0 1 -> 1 + 0 1 0 -> 1 + 1 0 0 -> 1 + 0 0 0 -> 0 + 0 0 1 -> 1 + 0 1 0 -> 1 + 1 0 0 -> 1 + 0 0 0 -> 0 + Here is how to compute output_size, given the above example, + 1. |_| -(window-1) + 2. |_______________________| (input_size-1)*stride + 1 + 3. |_| |___| + pad_total + + So, output_size = -(window-1) + (input_size-1)*stride + 1 + pad_total + = input_size*stride - (window+stride-2) + pad_total + = input_size*stride <- "SAME", "CAUSAL" + = input_size*stride + max(window-stride, 0) <- "VALID" + + OTHO, when dilation > 1, dilate_window = (window - 1) * dilation + 1. + For example, when window=3 and dilation=2, dilate_window=5. + + In the stride=2 case, this function creates outputs as follows: + * "SAME" padding=(3, 2) + pad | |pad + paddings: 0 0 0|0 * 0 * 1 * 1|0 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 1 -> 1 + 0 * 0 * 0 -> 0 + 0 * 1 * 1 -> 2 + 0 * 0 * 0 -> 0 + 1 * 1 * 0 -> 2 + + * "VALID" padding=(4, 4) + pad | |pad + paddings: 0 0 0 0|0 * 0 * 1 * 1|0 0 0 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 1 -> 1 + 0 * 0 * 0 -> 0 + 0 * 1 * 1 -> 2 + 0 * 0 * 0 -> 0 + 1 * 1 * 0 -> 2 + 0 * 0 * 0 -> 0 + 1 * 0 * 0 -> 1 + + * "CAUSAL" padding=(4, 1) + pad | |pad + paddings: 0 0 0 0|0 * 0 * 1 * 1|0 0 0 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 0 -> 0 + 0 * 0 * 1 -> 1 + 0 * 0 * 0 -> 0 + 0 * 1 * 1 -> 2 + 0 * 0 * 0 -> 0 + 1 * 1 * 0 -> 2 + 0 * 0 * 0 -> 0 + 1 * 0 * 0 -> 1 + + For "CAUSAL", the first component is time and treated as "CAUSAL", while the remaining + components are handled with "SAME" padding. + + Args: + window: convolution window. + strides: transposed convolution strides. It's lhs_dilation, not window_stride. + padding: convolution padding. + dilation: convolution dilation, a.k.a rhs_dilation. + + Returns: + The padding tuple. + + Raises: + ValueError: If padding is not supported. + """ + if not isinstance(padding, str): + return padding + + window = conv_dilate_window(window=window, dilation=dilation) + + def same_padding(window, strides): + pad_left = tuple(min(w - 1, (w + s - 1) // 2) for w, s in zip(window, strides)) + pad_right = tuple(max(s - 1, (w + s - 2) // 2) for w, s in zip(window, strides)) + return tuple(zip(pad_left, pad_right)) + + if padding == "SAME": + return same_padding(window, strides) + elif padding == "VALID": + pad_left = tuple(w - 1 for w in window) + pad_right = tuple(max(s - 1, w - 1) for w, s in zip(window, strides)) + return tuple(zip(pad_left, pad_right)) + elif padding == "CAUSAL": + causal_padding = ((window[0] - 1, strides[0] - 1),) + if len(window) > 1: + causal_padding += same_padding(window[1:], strides[1:]) + return causal_padding + else: + raise ValueError(f"{padding} padding is not supported.") + + +def convt_output_shape( + in_shape: Sequence[Optional[int]], + *, + window: Sequence[int], + strides: Sequence[int], + padding: ConvPaddingType, + dilation: Sequence[int], +) -> Sequence[int]: + """Returns output size for conv transpose. + + Each mode follows the formulas below, + * SAME: padding=(min(window-1, ceil((w+s-2)/2)), max(stride-1, floor((w+s-2)/2))) + pad_total = window+stride-2 + output_size = input_size*stride + * VALID: padding=(window-1, max(stride-1, window-1)) + pad_total = window+stride-2 + max(window-stride, 0) + output_size = input_size*stride + max(window-stride, 0) + * CAUSAL: padding=(window-1, stride-1) + pad_total = window+stride-2 + output_size = input_size*stride + + Note: In the above equation, `window` can be replaced with `dilate_window` when dilation > 1. + dilate_window = (window - 1) * dilation + 1. Check conv_dilate_window() + + Refer to + https://towardsdatascience.com/understand-transposed-convolutions-and-build-your-own-transposed-convolution-layer-from-scratch-4f5d97b2967 + + Args: + in_shape: convolution lhs shape. + window: convolution window. + strides: convolution strides. + padding: convolution padding. + dilation: convolution dilation. + + Returns: + The output shape. + + Raises: + ValueError: If the length of in_shape, window, strides, and padding are not equal. + """ + if len(in_shape) != len(window) or len(in_shape) != len(strides): + raise ValueError( + f"len(in_shape) = {len(in_shape)} must be equal to " + f"len(window) = {len(window)} and len(strides) = {len(strides)}" + ) + + window = conv_dilate_window(window=window, dilation=dilation) + + def output_shape(in_shape: Optional[int], window: int, stride: int): + if in_shape is None: + return None + + if padding == "SAME": + return in_shape * stride + elif padding == "VALID": + return in_shape * stride + max(window - stride, 0) + elif padding == "CAUSAL": + return in_shape * stride + else: + raise ValueError(f"{padding} padding is not supported.") + + return tuple(map(output_shape, in_shape, window, strides)) + + +def compute_convt_paddings( + in_paddings: Tensor, + *, + window: int, + stride: int, + conv_padding: ConvPaddingType, + dilation: int = 1, + anchor: Optional[int] = None, +): + """Compute output paddings w.r.t. conv_padding for conv transpose. + + The output paddings value is determined by the padding value at the anchor point in the + window. If anchor is None, the default anchor point is the left time padding from conv + padding config. See `Conv2DWith1DPadding.Config` in details. + + In the window=3 and stride=1 case, this function creates paddings as follows: + + The following illustration demonstrates how Conv Transpose operates, assuming all kernel values + are set to 1 for simplicity in showcasing output values. + + In the window=3 and stride=1 case, this function creates outputs as follows: + * "SAME" padding=(1, 1) + pad| |pad + paddings: 0|0 0 1 1|1 + |_____| + * 0 * -> 0 + * 0 * -> 0 + * 1 * -> 1 + * 1 0 -> 1 + + * "VALID" padding=(2, 2) + pad | |pad + paddings: 0 0|0 0 1 1|1 1 + |_________| + * * 0 -> 0 + * * 0 -> 0 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + + * "CAUSAL" padding=(2, 0) + pad | |pad + paddings: 0 0|0 0 1 1| + |_____| + * * 0 -> 0 + * * 0 -> 0 + * * 1 -> 1 + * * 1 -> 1 + + In the window=3 and stride=2 case, this function creates outputs as follows: + * "SAME" padding=(2, 1) + pad | |pad + paddings: 0 0|0 * 0 * 1 * 1|1 + |_____________| + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + + * "VALID" padding=(2, 2) + pad | |pad + paddings: 0 0|0 * 0 * 1 * 1|1 1 + |_______________| + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + + * "CAUSAL" padding=(2, 1) + pad | |pad + paddings: 0 0|0 * 0 * 1 * 1|1 + |_____________| + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + + In the window=3 and stride=3 case, this function creates outputs as follows: + * "SAME", "VALID" and "CAUSAL" padding=(2, 2) + pad | |pad + paddings: 0 0|0 * * 0 * * 1 * * 1|1 1 + |_____________________| + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 0 -> 0 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + * * 1 -> 1 + + OTHO, when dilation > 1, dilate_window = (window - 1) * dilation + 1. + For example, when window=3 and dilation=2, dilate_window=5. + + In the stride=2 case, this function creates outputs as follows: + * "SAME" padding=(3, 2) + pad | |pad + paddings: 0 0 0|0 * 0 * 1 * 1|1 1 + |_____________| + * * * 0 * -> 0 + * * * 0 * -> 0 + * * * 0 * -> 0 + * * * 0 * -> 0 + * * * 1 * -> 1 + * * * 1 * -> 1 + * * * 1 * -> 1 + * * * 1 * -> 1 + + * "VALID" padding=(4, 4) + pad | |pad + paddings: 0 0 0 0|0 * 0 * 1 * 1|1 1 1 1 + |___________________| + * * * * 0 -> 0 + * * * * 0 -> 0 + * * * * 0 -> 0 + * * * * 0 -> 0 + * * * * 1 -> 1 + * * * * 1 -> 1 + * * * * 1 -> 1 + * * * * 1 -> 1 + * * * * 1 -> 1 + * * * * 1 -> 1 + * * * * 1 -> 1 + + * "CAUSAL" padding=(4, 1) + pad | |pad + paddings: 0 0 0 0|0 * 0 * 1 * 1|1 + |_____________| + * * * * 0 -> 0 + * * * * 0 -> 0 + * * * * 0 -> 0 + * * * * 0 -> 0 + * * * * 1 -> 1 + * * * * 1 -> 1 + * * * * 1 -> 1 + * * * * 1 -> 1 + + Args: + in_paddings: A Tensor of shape [batch_size, seq_len]. + window: convolution window size of the time axis. + stride: convolution stride size of the time axis. + conv_padding: "SAME", "VALID", "CAUSAL" or ((left_time_padding, right_time_padding),) + dilation: convolution dilation size of the time axis. + anchor: an optional integer in the range of [0, window) + that specifies the anchor position within the convolution window that is used to + determine output paddings. Specifically, the output token is valid iff the input token + at the anchor position of the corresponding window is valid. + If None, anchor defaults to conv_padding[0] (i.e. left_time_padding). + + Returns: + out_paddings: A Tensor of shape [batch_size, seq_len]. + + Raises: + ValueError: If anchor is not between left_time_padding and window. + """ + + chex.assert_rank(in_paddings, 2) + conv_padding = convt_explicit_padding( + window=(window,), strides=(stride,), padding=conv_padding, dilation=(dilation,) + ) + window = conv_dilate_window(window=(window,), dilation=(dilation,))[0] + # Note: in transposed conv, left_pad + right_pad >= window - 1. See convt_explicit_padding(). + left_pad, right_pad = conv_padding[0] + + if anchor is None: + anchor = left_pad + # elif not left_pad <= anchor < window: + elif not anchor < window: + raise ValueError(f"anchor ({anchor}) must in range [0, {window}).") + + # Consider the case where window=3, strides=2, dilation=2, and padding="SAME" + # explicit padding=(3, 2) + # pad | |pad + # paddings: 0 0 0|0 * 0 * 1 * 1|1 1 + # |_____________| + # * * * 0 * -> 0 + # * * * 0 * -> 0 + # * * * 0 * -> 0 + # * * * 0 * -> 0 + # * * * 1 * -> 1 + # * * * 1 * -> 1 + # * * * 1 * -> 1 + # * * * 1 * -> 1 + + # |0 0 1 1| -> |0 * 0 * 1 * 1| + def dilate_paddings(paddings): + most, last = jnp.split(paddings, [paddings.shape[1] - 1], axis=1) + dilated = einops.repeat(most, "b t -> b (t s)", s=stride) + return jnp.concatenate([dilated, last], axis=1) + + in_paddings = dilate_paddings(in_paddings) + + # |0 * 0 * 1 * 1| -> 0 0 0|0 * 0 * 1 * 1|1 1 + # |_____________| which is |0 * 0 * 1 * 1|1 + window_pad_total = window - 1 # Note: we already check `anchor < window`` always. + window_right_pad = window_pad_total - anchor + assert window_right_pad >= 0, f"{anchor=} < {window=} always." + # Note: left_pad + right_pad >= window + stride - 2 >= window - 1 == anchor + window_right_pad + valid_right_pad = right_pad - window_right_pad + if valid_right_pad >= 0: + out_paddings = jnp.pad(in_paddings, ((0, 0), (0, valid_right_pad)), mode="edge") + else: + out_paddings = in_paddings[:, :valid_right_pad] + + start_index = anchor - left_pad + if start_index < 0: + out_paddings = jnp.pad(out_paddings, ((0, 0), (-start_index, 0)), mode="edge") + else: + out_paddings = out_paddings[:, start_index:] + return out_paddings + + +class Conv1DTranspose(BaseConv): + """The 1D transposed convolution layer.""" + + @config_class + class Config(BaseConv.Config): + """Configures Conv1DTranspose.""" + + window: int = 1 + strides: int = 1 + padding: ConvPaddingType = "SAME" + dilation: int = 1 # Dilation for dilated Convolution. + output_dim: Required[int] = REQUIRED # Output feature dim. + bias: bool = True # Whether to add a bias. + + # An optional integer in the range of [0, window) + # that specifies the anchor position within the convolution window that is used to + # determine output paddings. Specifically, the output token is valid iff the input token + # at the anchor position of the corresponding window is valid. + # If None, defaults to left time padding. See compute_convt_paddings more details. + anchor: Optional[int] = None + + @classmethod + def default_config(cls): + cfg = super().default_config() + cfg.param_partition_spec = (None, None, "model") + return cfg + + def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: + cfg = self.config + _check_conv_cfg( + window=(cfg.window,), + strides=(cfg.strides,), + padding=cfg.padding, + dilation=(cfg.dilation,), + ) + params = dict( + weight=ParameterSpec( + shape=(cfg.window, cfg.input_dim, cfg.output_dim), + mesh_axes=cfg.param_partition_spec, + factorization=FactorizationSpec(axes=(None, "row", "col")), + ) + ) + if cfg.bias: + params["bias"] = ParameterSpec( + shape=[cfg.output_dim], mesh_axes=(cfg.param_partition_spec[-1],) + ) + return params + + def forward( + self, x: Tensor, *, paddings: Optional[Tensor] = None + ) -> tuple[Tensor, Optional[Tensor]]: + cfg = self.config + conv_padding = convt_explicit_padding( + window=(cfg.window,), + strides=(cfg.strides,), + padding=cfg.padding, + dilation=(cfg.dilation,), + ) + + if paddings is not None: + chex.assert_rank(x, paddings.ndim + 1) + # Apply padding to the input. + x = x * (1 - paddings[..., None]) + + output = jax.lax.conv_transpose( + lhs=x, + rhs=self.parameters["weight"], + strides=(cfg.strides,), + padding=conv_padding, + rhs_dilation=(cfg.dilation,), + dimension_numbers=("NWC", "WIO", "NWC"), + ) + if cfg.bias: + output += self.parameters["bias"] + + if paddings is None: + output_paddings = None + else: + # Compute paddings conv output. + output_paddings = compute_convt_paddings( + paddings, + window=cfg.window, + stride=cfg.strides, + conv_padding=cfg.padding, + dilation=cfg.dilation, + anchor=cfg.anchor, + ) + output = output * (1 - output_paddings[..., None]) + return output, output_paddings + + def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Optional[int]]: + cfg = self.config + if len(input_shape) != 3: + raise ValueError(f"We expect len(input_shape) = 3, but got {len(input_shape)}.") + if input_shape[-1] != cfg.input_dim: + raise ValueError( + f"input_shape[-1] = {input_shape[-1]} does not match " + "cfg.input_dim = {cfg.input_dim}." + ) + + in_shape = input_shape[1:2] + out_shape = convt_output_shape( + in_shape, + window=(cfg.window,), + strides=(cfg.strides,), + padding=cfg.padding, + dilation=(cfg.dilation,), + ) + return [input_shape[0], *out_shape, cfg.output_dim] + + +class Conv2DTranspose(BaseConv): + """The 2-D transposed convolution layer.""" + + @config_class + class Config(BaseConv.Config): + """Configures Conv2DTranspose.""" + + window: tuple[int, int] = (1, 1) + strides: tuple[int, int] = (1, 1) + padding: ConvPaddingType = "SAME" + dilation: tuple[int, int] = (1, 1) + output_dim: Required[int] = REQUIRED # Output feature dim. + bias: bool = True # Whether to add a bias. + # True for backward compatibility, but False is more computationally efficient. + # If True, flips spatial axes and swaps the input/output channel axes of the kernel. + # This makes the output of this function identical to the gradient-derived functions + # like keras.layers.Conv2DTranspose applied to the same kernel. + transpose_kernel: bool = True + + @classmethod + def default_config(cls): + cfg = super().default_config() + if cfg.transpose_kernel: + cfg.param_partition_spec = (None, None, "model", None) + else: + cfg.param_partition_spec = (None, None, None, "model") + return cfg + + def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: + cfg = self.config + _check_conv_cfg( + window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation + ) + if cfg.transpose_kernel: + io_shape = (cfg.output_dim, cfg.input_dim) + else: + io_shape = (cfg.input_dim, cfg.output_dim) + params = dict( + weight=ParameterSpec( + shape=cfg.window + io_shape, + mesh_axes=cfg.param_partition_spec, + factorization=FactorizationSpec(axes=(None, None, "row", "col")), + ) + ) + if cfg.bias: + params["bias"] = ParameterSpec( + shape=[cfg.output_dim], mesh_axes=(cfg.param_partition_spec[-1],) + ) + return params + + def forward(self, x: Tensor) -> Tensor: + cfg = self.config + conv_padding = convt_explicit_padding( + window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation + ) + output = jax.lax.conv_transpose( + lhs=x, + rhs=self.parameters["weight"], + strides=cfg.strides, + padding=conv_padding, + rhs_dilation=cfg.dilation, + dimension_numbers=("NHWC", "HWIO", "NHWC"), + transpose_kernel=cfg.transpose_kernel, + ) + if cfg.bias: + output += self.parameters["bias"] + return output + + @nowrap + def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Optional[int]]: + cfg = self.config + if len(input_shape) != 4: + raise ValueError(f"We expect len(input_shape) = 4, but got {len(input_shape)}.") + if input_shape[-1] != cfg.input_dim: + raise ValueError( + f"input_shape[-1] = {input_shape[-1]} does not match " + "cfg.input_dim = {cfg.input_dim}." + ) + + in_shape = input_shape[1:3] + out_shape = convt_output_shape( + in_shape, + window=cfg.window, + strides=cfg.strides, + padding=cfg.padding, + dilation=cfg.dilation, + ) + return [input_shape[0], *out_shape, cfg.output_dim] + + +class Conv2DTransposeWith1DPadding(Conv2DTranspose): + """The 2-D convolution transpose with 1-D padding on the time axis.""" + + @config_class + class Config(Conv2DTranspose.Config): + """Configures Conv2DTransposeWith1DPadding.""" + + transpose_kernel: bool = False + # An optional integer in the range of [0, window) + # that specifies the anchor position within the convolution window that is used to + # determine output paddings. Specifically, the output token is valid iff the input token + # at the anchor position of the corresponding window is valid. + # If None, defaults to left time padding. See compute_convt_paddings more details. + anchor: Optional[int] = None + + @classmethod + def default_config(cls): + cfg = super().default_config() + cfg.transpose_kernel = False # Choose better one unlike parent. + return cfg + + # We add a kwargs "paddings" to the forward method. + # pylint: disable-next=arguments-differ + def forward(self, x: Tensor, *, paddings: Tensor) -> tuple[Tensor, Tensor]: + """Computes convolution outputs and paddings. + + Args: + x: A Tensor of shape [batch_size, seq_len, frequency, input_dim]. + paddings: 0/1 Tensor of shape [batch_size, seq_len]. + + Returns: + output: A Tensor of shape [batch_size, seq_len, frequency, output_dim]. + paddings: 0/1 Tensor of shape [batch_size, seq_len]. + """ + cfg = self.config + # Apply padding to the input. + assert len(x.shape) == len(paddings.shape) + 2 + x = x * (1 - paddings[..., None, None]) + + # Apply Conv2D. + output = super().forward(x) + # Compute paddings conv output. + output_paddings = compute_convt_paddings( + paddings, + window=cfg.window[0], + stride=cfg.strides[0], + conv_padding=cfg.padding, + dilation=cfg.dilation[0], + anchor=cfg.anchor, + ) + # Apply padding to the outputs. + output = output * (1 - output_paddings[..., None, None]) + return output, output_paddings + + +class Conv3DTranspose(BaseConv): + """The 3-D convolution transpose layer.""" + + @config_class + class Config(BaseConv.Config): + """Configures Conv3DTranspose.""" + + window: tuple[int, int, int] = (1, 1, 1) # The convolution window. + strides: tuple[int, int, int] = (1, 1, 1) # The convolution strides. + # Paddings: "SAME" or "VALID, or ((top, bottom), (left, right), (front, back)) + padding: ConvPaddingType = "SAME" + dilation: tuple[int, int, int] = (1, 1, 1) # The convolution dilation. + + output_dim: Required[int] = REQUIRED # Output feature dim. + bias: bool = True # Whether to add a bias. + + @classmethod + def default_config(cls): + cfg = super().default_config() + cfg.param_partition_spec = (None, None, None, None, "model") + return cfg + + def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: + cfg = self.config + _check_conv_cfg( + window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation + ) + params = dict( + weight=ParameterSpec( + shape=cfg.window + (cfg.input_dim, cfg.output_dim), + mesh_axes=cfg.param_partition_spec, + factorization=FactorizationSpec(axes=(None, None, None, "row", "col")), + ) + ) + if cfg.bias: + params["bias"] = ParameterSpec( + shape=[cfg.output_dim], mesh_axes=(cfg.param_partition_spec[-1],) + ) + return params + + def forward(self, x: Tensor) -> Tensor: + cfg = self.config + conv_padding = convt_explicit_padding( + window=cfg.window, strides=cfg.strides, padding=cfg.padding, dilation=cfg.dilation + ) + output = jax.lax.conv_transpose( + lhs=x, + rhs=self.parameters["weight"], + strides=cfg.strides, + padding=conv_padding, + rhs_dilation=cfg.dilation, + dimension_numbers=("NHWDC", "HWDIO", "NHWDC"), + ) + if cfg.bias: + output += self.parameters["bias"] + return output + + @nowrap + def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Optional[int]]: + cfg = self.config + if len(input_shape) != 5: + raise ValueError(f"We expect len(input_shape) = 5, but got {len(input_shape)}.") + if input_shape[-1] != cfg.input_dim: + raise ValueError( + f"input_shape[-1] = {input_shape[-1]} does not match " + f"cfg.input_dim = {cfg.input_dim}." + ) + + in_shape = input_shape[1:4] + out_shape = convt_output_shape( + in_shape, + window=cfg.window, + strides=cfg.strides, + padding=cfg.padding, + dilation=cfg.dilation, + ) + return [input_shape[0], *out_shape, cfg.output_dim] + + +############################## Others ############################################################## + + class Embedding(BaseLayer): """Implements an embedding lookup function. @@ -2150,7 +2954,7 @@ def forward(self, inputs: Tensor, *, paddings: Tensor) -> tuple[Tensor, Tensor]: padding = cfg.padding if isinstance(padding, str): padding = conv_explicit_padding( - window=(cfg.stride,), strides=(cfg.stride,), padding=padding + window=(cfg.stride,), strides=(cfg.stride,), padding=padding, dilation=(1,) )[0] inputs = jnp.pad(inputs, ((0, 0), padding, (0, 0)), constant_values=0) @@ -2183,7 +2987,7 @@ def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Opti if isinstance(padding, tuple): padding = (padding,) out_shape = conv_output_shape( - [seq_len], window=(cfg.stride,), strides=(cfg.stride,), padding=padding + [seq_len], window=(cfg.stride,), strides=(cfg.stride,), padding=padding, dilation=(1,) ) return [batch_size, *out_shape, input_dim * cfg.stride] diff --git a/axlearn/common/layers_test.py b/axlearn/common/layers_test.py index 1c5ffc5f..ca9791d2 100644 --- a/axlearn/common/layers_test.py +++ b/axlearn/common/layers_test.py @@ -924,36 +924,6 @@ def test_conv_output_shape(self, in_shape, window, strides, padding, dilation, e )[0] self.assertEqual(out_shape, expected) - @parameterized.parameters( - ([0, 0, 0, 1], [0, 0, 0, 1], 1, "SAME"), - ([0], [], 1, "VALID"), - ([0, 0], [], 1, "VALID"), - ([0, 0, 0], [0], 1, "VALID"), - ([0, 0, 0, 0], [0, 0], 1, "VALID"), - ([0, 0, 0, 1], [0, 0], 1, "VALID"), - ([0, 0, 0, 0], [0], 2, "VALID"), - ([0, 0, 0, 1], [0], 2, "VALID"), - ([0, 0, 1, 1], [0], 2, "VALID"), - ([0, 0, 0, 0, 0], [0, 0], 2, "VALID"), - ([0, 0, 0, 0, 1], [0, 0], 2, "VALID"), - ([0, 0, 0, 1, 1], [0, 0], 2, "VALID"), - ([0, 0, 1, 1, 1], [0, 1], 2, "VALID"), - ([0, 0, 0, 0, 0, 0], [0, 0], 2, "VALID"), - ([0, 0, 0, 0, 0, 1], [0, 0], 2, "VALID"), - ([0, 0, 0, 0, 1, 1], [0, 0], 2, "VALID"), - ([0, 0, 0, 1, 1, 1], [0, 0], 2, "VALID"), - ([0, 0, 1, 1, 1, 1], [0, 1], 2, "VALID"), - ) - def test_conv_padding(self, input_paddings, expected_paddings, stride: int, padding_cfg: str): - """Tests conv_output_shape() with SAME and VALID padding cfg.""" - # This test is from lingvo - # https://github.com/tensorflow/lingvo/blob/master/lingvo/core/conv_layers_with_time_padding_test.py#L157. - window = 3 - out_paddings = compute_conv_paddings( - jnp.array([input_paddings]), window=window, stride=stride, conv_padding=padding_cfg - ) - assert_allclose(out_paddings[0], expected_paddings) - @parameterized.parameters( (5, 1, "SAME", 1, (2, 2)), (5, 2, "SAME", 1, (2, 2)), @@ -986,6 +956,36 @@ def test_conv_explicit_padding( ) self.assertAllEqual(explicit_padding[0], expected) + @parameterized.parameters( + ([0, 0, 0, 1], [0, 0, 0, 1], 1, "SAME"), + ([0], [], 1, "VALID"), + ([0, 0], [], 1, "VALID"), + ([0, 0, 0], [0], 1, "VALID"), + ([0, 0, 0, 0], [0, 0], 1, "VALID"), + ([0, 0, 0, 1], [0, 0], 1, "VALID"), + ([0, 0, 0, 0], [0], 2, "VALID"), + ([0, 0, 0, 1], [0], 2, "VALID"), + ([0, 0, 1, 1], [0], 2, "VALID"), + ([0, 0, 0, 0, 0], [0, 0], 2, "VALID"), + ([0, 0, 0, 0, 1], [0, 0], 2, "VALID"), + ([0, 0, 0, 1, 1], [0, 0], 2, "VALID"), + ([0, 0, 1, 1, 1], [0, 1], 2, "VALID"), + ([0, 0, 0, 0, 0, 0], [0, 0], 2, "VALID"), + ([0, 0, 0, 0, 0, 1], [0, 0], 2, "VALID"), + ([0, 0, 0, 0, 1, 1], [0, 0], 2, "VALID"), + ([0, 0, 0, 1, 1, 1], [0, 0], 2, "VALID"), + ([0, 0, 1, 1, 1, 1], [0, 1], 2, "VALID"), + ) + def test_conv_padding(self, input_paddings, expected_paddings, stride: int, padding_cfg: str): + """Tests conv_output_shape() with SAME and VALID padding cfg.""" + # This test is from lingvo + # https://github.com/tensorflow/lingvo/blob/master/lingvo/core/conv_layers_with_time_padding_test.py#L157. + window = 3 + out_paddings = compute_conv_paddings( + jnp.array([input_paddings]), window=window, stride=stride, conv_padding=padding_cfg + ) + assert_allclose(out_paddings[0], expected_paddings) + @parameterized.parameters( (5, 1, "SAME", [0, 0, 0, 0, 1, 1]), (5, 2, "SAME", [0, 0, 1]), @@ -1052,7 +1052,7 @@ def test_conv_output_1d_padding_causal(self, in_paddings, expected): stride = 6 padding = "CAUSAL" explicit_padding = layers.conv_explicit_padding( - window=(window,), strides=(stride,), padding=padding + window=(window,), strides=(stride,), padding=padding, dilation=(1,) ) self.assertAllEqual(explicit_padding[0], (9, 5)) @@ -1082,7 +1082,7 @@ def test_conv_output_1d_padding_against_str_padding( paddings = jnp.triu(jnp.ones((batch_size, seq_len)), k=1) explicit_padding = layers.conv_explicit_padding( - window=(window,), strides=(stride,), padding=ref_padding + window=(window,), strides=(stride,), padding=ref_padding, dilation=(1,) ) self.assertAllEqual(explicit_padding, padding[:1]) @@ -1350,103 +1350,6 @@ def test_conv1d_against_conv2d_with_1d_padding( assert_allclose(ref_paddings, test_paddings) assert_allclose(ref_outputs, test_outputs) - @parameterized.named_parameters( - { - "testcase_name": "2x2", - "window": (2, 2), - "strides": (1, 1), - "padding": "VALID", - }, - { - "testcase_name": "2x2_S2", - "window": (2, 2), - "strides": (2, 2), - "padding": "VALID", - }, - { - "testcase_name": "3x3_S2", - "window": (3, 3), - "strides": (2, 2), - "padding": "VALID", - }, - ) - def test_deconv2d( - self, - window: tuple[int, int], - strides: tuple[int, int], - padding: Union[str, tuple[int, int]], - ): - input_dim, output_dim = 4, 8 - if isinstance(padding, tuple): - deconv_padding = ((padding[0], padding[0]), (padding[1], padding[1])) - else: - deconv_padding = padding - cfg = Conv2DTranspose.default_config().set( - name="test", - input_dim=input_dim, - output_dim=output_dim, - window=window, - strides=strides, - padding=deconv_padding, - ) - layer: Conv2DTranspose = cfg.instantiate(parent=None) - - # Initialize layer parameters. - prng_key = jax.random.PRNGKey(123) - prng_key, init_key = jax.random.split(prng_key) - layer_params = layer.initialize_parameters_recursively(init_key) - self.assertEqual( - dict( - weight=(window[0], window[1], output_dim, input_dim), - bias=(output_dim,), - ), - shapes(layer_params), - ) - bias = layer_params["bias"] - assert_allclose(bias, jnp.zeros_like(bias)) - # Randomize bias. - layer_params["bias"] = jax.random.normal( - jax.random.PRNGKey(45), shape=bias.shape, dtype=bias.dtype - ) - - # Random inputs. - prng_key, input_key = jax.random.split(prng_key) - inputs = jax.random.normal(input_key, [2, 10, 7, input_dim]) - # Compute layer outputs. - outputs, _ = F( - layer, - inputs=(inputs,), - is_training=True, - state=layer_params, - prng_key=prng_key, - ) - - # Compute ref outputs. - if isinstance(padding, tuple): - ref_padding = padding[0] - elif isinstance(padding, str): - ref_padding = padding.lower() - if ref_padding == "valid": - ref_padding = 0 - else: - ref_padding = 0 - - ref = torch.nn.ConvTranspose2d( - in_channels=input_dim, - out_channels=output_dim, - kernel_size=window, - stride=strides, - padding=ref_padding, - ) - # torch.nn.Linear.weight is of shape (output_dim, input_dim, kernel_size...). - _copy(layer_params["weight"].transpose(3, 2, 0, 1), ref.weight) - _copy(layer_params["bias"], ref.bias) - ref_outputs = ref(as_torch_tensor(inputs.transpose(0, 3, 1, 2))) - assert_allclose(outputs, ref_outputs.detach().numpy().transpose(0, 2, 3, 1)) - # Tests output_shape. - output_shape = layer.output_shape(input_shape=inputs.shape) - self.assertAllEqual(outputs.shape, output_shape) - @parameterized.named_parameters( { "testcase_name": "1x1x1", @@ -1597,12 +1500,12 @@ def test_conv3d( self.assertAllEqual(outputs.shape, output_shape) @parameterized.named_parameters( - ("w3s1d1_VALID", 3, 1, "VALID", None), + ("w3s1d1_VALID", 3, 1, "VALID", 1), ("w3s1d2_VALID", 3, 1, "VALID", 2), - ("w3s1d1_SAME", 3, 1, "SAME", None), - ("w4s1d1_SAME", 4, 1, "SAME", None), + ("w3s1d1_SAME", 3, 1, "SAME", 1), + ("w4s1d1_SAME", 4, 1, "SAME", 1), ("w4s1d3_SAME", 4, 1, "SAME", 3), - ("w4s1d1_CAUSAL", 4, 1, ((3, 0),), None), + ("w4s1d1_CAUSAL", 4, 1, ((3, 0),), 1), ("w4s1d5_CAUSAL", 4, 1, ((3, 0),), 5), ) def test_conv1d( @@ -1610,7 +1513,7 @@ def test_conv1d( window: int, strides: int, padding: ConvPaddingType, - dilation: Optional[int] = None, + dilation: int, ): input_dim, output_dim = 4, 6 cfg = Conv1D.default_config().set( @@ -1620,7 +1523,7 @@ def test_conv1d( window=window, strides=strides, padding=padding, - rhs_dilation=dilation, + dilation=dilation, ) layer: Conv1D = cfg.instantiate(parent=None) # Initialize layer parameters. @@ -1673,70 +1576,6 @@ def test_conv1d( ref_outputs = ref(as_torch_tensor(ref_inputs.transpose(0, 2, 1))) assert_allclose(outputs, ref_outputs.detach().numpy().transpose(0, 2, 1)) - @parameterized.named_parameters( - ("w3s1d3_pad1", 3, 1, (1, 1), 3), - ("w3s1d3_pad1_causal", 3, 1, (1, 0), 3), - ("w3s2d3_pad1", 3, 2, (1, 1), 3), - ("w3s2d1_pad1", 3, 2, (1, 1), None), - ) - def test_lhs_dilation_conv1d( - self, - window: int, - strides: int, - padding: tuple[int, int], - dilation: Optional[int], - ): - input_dim, output_dim = 4, 6 - cfg = Conv1D.default_config().set( - name="test", - input_dim=input_dim, - output_dim=output_dim, - window=window, - strides=strides, - padding=(padding,), - lhs_dilation=dilation, - ) - layer: Conv1D = cfg.instantiate(parent=None) - prng_key = jax.random.PRNGKey(123) - prng_key, init_key = jax.random.split(prng_key) - layer_params = layer.initialize_parameters_recursively(init_key) - self.assertEqual( - dict(weight=(window, input_dim, output_dim), bias=(output_dim,)), - shapes(layer_params), - ) - bias = layer_params["bias"] - assert_allclose(bias, jnp.zeros_like(bias)) - # Randomize bias. - layer_params["bias"] = jax.random.normal( - jax.random.PRNGKey(45), shape=bias.shape, dtype=bias.dtype - ) - - prng_key, input_key = jax.random.split(prng_key) - input_length = 7 - inputs = jax.random.normal(input_key, [2, input_length, input_dim]) - outputs, _ = F( - layer, inputs=(inputs,), is_training=True, state=layer_params, prng_key=prng_key - ) - expected_length = input_length + padding[0] + padding[1] - int((window + 1) / 2) - if dilation is not None: - expected_length = expected_length + (input_length - 1) * (dilation - 1) - expected_length = math.floor((expected_length - 1) / strides + 1) - self.assertNestedEqual(outputs.shape, (2, expected_length, output_dim)) - - def test_lhs_dilation_not_using_string_padding_conv1d(self): - input_dim, output_dim = 4, 6 - cfg = Conv1D.default_config().set( - name="test", - input_dim=input_dim, - output_dim=output_dim, - window=3, - strides=1, - padding="SAME", - lhs_dilation=2, - ) - with self.assertRaises(ValueError): - cfg.instantiate(parent=None) - @parameterized.named_parameters( ("w3s1_VALID", 3, 1, "VALID"), ("w3s1_SAME", 3, 1, "SAME"), @@ -1809,6 +1648,756 @@ def test_depthwise_conv1d( ref_outputs = ref(as_torch_tensor(ref_inputs.transpose(0, 2, 1))) assert_allclose(outputs, ref_outputs.detach().numpy().transpose(0, 2, 1)) + ############################## Transposed Convolution ########################################## + + @parameterized.parameters( + (3, 1, "SAME", 1, (1, 1)), + (3, 2, "SAME", 1, (2, 1)), + (3, 3, "SAME", 1, (2, 2)), + (3, 4, "SAME", 1, (2, 3)), + (3, 1, "SAME", 2, (2, 2)), + (3, 2, "SAME", 2, (3, 2)), + (3, 3, "SAME", 2, (3, 3)), + (3, 1, "VALID", 1, (2, 2)), + (3, 2, "VALID", 1, (2, 2)), + (3, 3, "VALID", 1, (2, 2)), + (3, 4, "VALID", 1, (2, 3)), + (3, 1, "VALID", 2, (4, 4)), + (3, 2, "VALID", 2, (4, 4)), + (3, 3, "VALID", 2, (4, 4)), + (3, 1, "CAUSAL", 1, (2, 0)), + (3, 2, "CAUSAL", 1, (2, 1)), + (3, 3, "CAUSAL", 1, (2, 2)), + (3, 4, "CAUSAL", 1, (2, 3)), + (3, 1, "CAUSAL", 2, (4, 0)), + (3, 2, "CAUSAL", 2, (4, 1)), + (3, 3, "CAUSAL", 2, (4, 2)), + ) + def test_convt_explicit_padding(self, window, strides, padding, dilation, expected): + """Tests the cases in convt_explicit_padding() description.""" + explicit_padding = layers.convt_explicit_padding( + window=(window,), + strides=(strides,), + padding=padding, + dilation=(dilation,), + ) + self.assertAllEqual(explicit_padding[0], expected) + + @parameterized.parameters( + (3, 1, "SAME", 1, 4), + (3, 2, "SAME", 1, 8), + (3, 3, "SAME", 1, 12), + (3, 4, "SAME", 1, 16), + (3, 1, "SAME", 2, 4), + (3, 2, "SAME", 2, 8), + (3, 3, "SAME", 2, 12), + (3, 1, "VALID", 1, 6), + (3, 2, "VALID", 1, 9), + (3, 3, "VALID", 1, 12), + (3, 4, "VALID", 1, 16), + (3, 1, "VALID", 2, 8), + (3, 2, "VALID", 2, 11), + (3, 3, "VALID", 2, 14), + (3, 1, "CAUSAL", 1, 4), + (3, 2, "CAUSAL", 1, 8), + (3, 3, "CAUSAL", 1, 12), + (3, 4, "CAUSAL", 1, 16), + (3, 1, "CAUSAL", 2, 4), + (3, 2, "CAUSAL", 2, 8), + (3, 3, "CAUSAL", 2, 12), + ) + def test_convt_output_shape(self, window, strides, padding, dilation, expected): + """Tests the cases in convt_explicit_padding() description.""" + out_shape = layers.convt_output_shape( + in_shape=(4,), + window=(window,), + strides=(strides,), + padding=padding, + dilation=(dilation,), + ) + self.assertAllEqual(out_shape[0], expected) + + @parameterized.parameters( + (3, 1, "SAME", 1, [0, 0, 1, 1]), + (3, 2, "SAME", 1, [0, 0, 0, 0, 1, 1, 1, 1]), + (3, 3, "SAME", 1, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), + (3, 4, "SAME", 1, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]), + (3, 1, "SAME", 2, [0, 0, 1, 1]), + (3, 2, "SAME", 2, [0, 0, 0, 0, 1, 1, 1, 1]), + (3, 3, "SAME", 2, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), + (3, 1, "VALID", 1, [0, 0, 1, 1, 1, 1]), + (3, 2, "VALID", 1, [0, 0, 0, 0, 1, 1, 1, 1, 1]), + (3, 3, "VALID", 1, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), + (3, 4, "VALID", 1, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]), + (3, 1, "VALID", 2, [0, 0, 1, 1, 1, 1, 1, 1]), + (3, 2, "VALID", 2, [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]), + (3, 3, "VALID", 2, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]), + (3, 1, "CAUSAL", 1, [0, 0, 1, 1]), + (3, 2, "CAUSAL", 1, [0, 0, 0, 0, 1, 1, 1, 1]), + (3, 3, "CAUSAL", 1, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), + (3, 4, "CAUSAL", 1, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]), + (3, 1, "CAUSAL", 2, [0, 0, 1, 1]), + (3, 2, "CAUSAL", 2, [0, 0, 0, 0, 1, 1, 1, 1]), + (3, 3, "CAUSAL", 2, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), + ) + def test_compute_convt_paddings(self, window, strides, padding, dilation, expected): + """Tests the cases in convt_explicit_padding() description.""" + in_paddings = jnp.array([0, 0, 1, 1], dtype=jnp.float32)[None, :] + out_paddings = layers.compute_convt_paddings( + in_paddings, window=window, stride=strides, conv_padding=padding, dilation=dilation + ) + expected = jnp.array(expected).astype(out_paddings.dtype) + self.assertNestedEqual(out_paddings[0], expected) + + @parameterized.product( + window=[1, 3], + strides=[1, 2, 3], + padding=["SAME", "VALID", "CAUSAL"], + dilation=[1, 2], + value=[0, 1], + ) + def test_compute_convt_paddings_all0or1(self, window, strides, padding, dilation, value): + """If in_paddings is all valid or invalid, out_paddings must be all valid or invalid.""" + in_paddings = jnp.full([1, 4], fill_value=value) + out_paddings = layers.compute_convt_paddings( + in_paddings, window=window, stride=strides, conv_padding=padding, dilation=dilation + ) + expected = jnp.ones_like(out_paddings) * value + self.assertNestedEqual(out_paddings, expected) + + CONVT_PADDINGS_PARAMS = dict( + in_paddings=[ + [0, 0, 0, 0, 0], + [0, 0, 0, 1, 1], + [1, 0, 0, 0, 1], + [1, 1, 1, 1, 1], + [0, 0, 0, 1, 1, 1, 0, 0, 1, 1], + [1, 1, 0, 1, 1, 1, 1, 0, 0, 0], + ], + window=[1, 3], + padding=["SAME", "VALID", "CAUSAL"], + dilation=[1, 2], + ) + + @parameterized.product(**CONVT_PADDINGS_PARAMS, strides=[1, 2, 3]) + def test_compute_convt_paddings_with_conv_paddings( + self, in_paddings, window, strides, padding, dilation + ): + """Check if ConvT -> Conv preserves information.""" + in_paddings = jnp.array(in_paddings, dtype=jnp.float32)[None, :] + out_paddings = layers.compute_convt_paddings( + in_paddings, window=window, stride=strides, conv_padding=padding, dilation=dilation + ) + + recon_paddings = layers.compute_conv_paddings( + out_paddings, window=window, stride=strides, conv_padding=padding, dilation=dilation + ) + self.assertNestedEqual(recon_paddings[0], in_paddings[0]) + + @parameterized.product(**CONVT_PADDINGS_PARAMS) + def test_compute_convt_paddings_against_conv_paddings( + self, in_paddings, window, padding, dilation + ): + # compute_convt_paddings and compute_conv_paddings are same when window_stride=1 + # (stride of Conv2D) and lhs_dilation=1 (stride of Conv2DTranspose). + strides = 1 + if padding == "VALID": + # TODO(dhwang2,ruoming): Currently, anchor is pad_left but it should be the midpoint + # between [pad_left, pad_right). Otherwise, the consistency of VALID padding is broken. + # For reference, the midpoint in SAME and CAUSAL is left_pad. + dilate_window = layers.conv_dilate_window(window=(window,), dilation=(dilation,))[0] + conv_padding = layers.conv_explicit_padding( + window=(window,), strides=(strides,), padding=padding, dilation=(dilation,) + )[0] + pad_left, pad_right = conv_padding + anchor_range = dilate_window - pad_left - pad_right + mid_point = anchor_range // 2 + anchor = pad_left + mid_point + else: + anchor = None + + in_paddings = jnp.array(in_paddings, dtype=jnp.float32)[None, :] + ref_paddings = layers.compute_conv_paddings( + in_paddings, + window=window, + stride=strides, + conv_padding=padding, + dilation=dilation, + anchor=anchor, + ) + + test_paddings = layers.compute_convt_paddings( + in_paddings, + window=window, + stride=strides, + conv_padding=padding, + dilation=dilation, + anchor=anchor, + ) + + if ref_paddings.shape != test_paddings.shape: + self.assertEqual(padding, "VALID") + dilate_window = layers.conv_dilate_window(window=(window,), dilation=(dilation,))[0] + pad_left = dilate_window - 1 + test_paddings = test_paddings[:, pad_left:-pad_left] + + assert_allclose(ref_paddings, test_paddings) + + CONVT_PARAMS = [ + (3, 1, "SAME", 1, [0, 1, 2, 2]), + (3, 2, "SAME", 1, [0, 0, 0, 0, 1, 1, 2, 1]), + (3, 3, "SAME", 1, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), + (3, 4, "SAME", 1, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0]), + (3, 1, "SAME", 2, [1, 1, 1, 1]), + (3, 2, "SAME", 2, [0, 0, 0, 1, 0, 2, 0, 2]), + (3, 3, "SAME", 2, [0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0]), + (3, 1, "VALID", 1, [0, 0, 1, 2, 2, 1]), + (3, 2, "VALID", 1, [0, 0, 0, 0, 1, 1, 2, 1, 1]), + (3, 3, "VALID", 1, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), + (3, 4, "VALID", 1, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0]), + (3, 1, "VALID", 2, [0, 0, 1, 1, 1, 1, 1, 1]), + (3, 2, "VALID", 2, [0, 0, 0, 0, 1, 0, 2, 0, 2, 0, 1]), + (3, 3, "VALID", 2, [0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1]), + (3, 1, "CAUSAL", 1, [0, 0, 1, 2]), + (3, 2, "CAUSAL", 1, [0, 0, 0, 0, 1, 1, 2, 1]), + (3, 3, "CAUSAL", 1, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]), + (3, 4, "CAUSAL", 1, [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0]), + (3, 1, "CAUSAL", 2, [0, 0, 1, 1]), + (3, 2, "CAUSAL", 2, [0, 0, 0, 0, 1, 0, 2, 0]), + (3, 3, "CAUSAL", 2, [0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1]), + ] + + @parameterized.parameters(*CONVT_PARAMS) + def test_conv1d_transpose_simple(self, window, strides, padding, dilation, expected): + """Tests the cases in convt_explicit_padding() description.""" + input_dim, output_dim = 1, 1 + inputs = jnp.array([0, 0, 1, 1], dtype=jnp.float32)[None, :, None] + cfg = layers.Conv1DTranspose.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + strides=strides, + padding=padding, + dilation=dilation, + bias=False, + ) + layer = cfg.instantiate(parent=None) + prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + layer_params = layer.initialize_parameters_recursively(init_key) + self.assertEqual(dict(weight=(window, input_dim, output_dim)), shapes(layer_params)) + layer_params["weight"] = jnp.ones_like(layer_params["weight"]) + + (outputs, paddings), _ = F( + layer, inputs=dict(x=inputs), is_training=True, state=layer_params, prng_key=prng_key + ) + out_shape = layer.output_shape(input_shape=inputs.shape) + self.assertAllEqual(outputs.shape, out_shape) + self.assertIsNone(paddings) + expected = jnp.array(expected).astype(outputs.dtype) + self.assertNestedEqual(outputs[0, :, 0], expected) + + @parameterized.parameters(*CONVT_PARAMS) + def test_conv2d_transpose_simple(self, window, strides, padding, dilation, expected): + """Tests the cases in convt_explicit_padding() description.""" + window = (window, 1) + strides = (strides, 1) + dilation = (dilation, 1) + input_dim, output_dim = 1, 1 + inputs = jnp.array([0, 0, 1, 1], dtype=jnp.float32)[None, :, None, None] + cfg = layers.Conv2DTranspose.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + strides=strides, + padding=padding, + dilation=dilation, + bias=False, + ) + layer = cfg.instantiate(parent=None) + prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + layer_params = layer.initialize_parameters_recursively(init_key) + self.assertEqual(dict(weight=(*window, input_dim, output_dim)), shapes(layer_params)) + layer_params["weight"] = jnp.ones_like(layer_params["weight"]) + + outputs, _ = F( + layer, inputs=dict(x=inputs), is_training=True, state=layer_params, prng_key=prng_key + ) + out_shape = layer.output_shape(input_shape=inputs.shape) + self.assertAllEqual(outputs.shape, out_shape) + expected = jnp.array(expected).astype(outputs.dtype) + self.assertNestedEqual(outputs[0, :, 0, 0], expected) + + @parameterized.parameters(*CONVT_PARAMS) + def test_conv3d_transpose_simple(self, window, strides, padding, dilation, expected): + """Tests the cases in convt_explicit_padding() description.""" + window = (window, 1, 1) + strides = (strides, 1, 1) + dilation = (dilation, 1, 1) + input_dim, output_dim = 1, 1 + inputs = jnp.array([0, 0, 1, 1], dtype=jnp.float32)[None, :, None, None, None] + cfg = layers.Conv3DTranspose.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + strides=strides, + padding=padding, + dilation=dilation, + bias=False, + ) + layer = cfg.instantiate(parent=None) + prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + layer_params = layer.initialize_parameters_recursively(init_key) + self.assertEqual(dict(weight=(*window, input_dim, output_dim)), shapes(layer_params)) + layer_params["weight"] = jnp.ones_like(layer_params["weight"]) + + outputs, _ = F( + layer, inputs=dict(x=inputs), is_training=True, state=layer_params, prng_key=prng_key + ) + out_shape = layer.output_shape(input_shape=inputs.shape) + self.assertAllEqual(outputs.shape, out_shape) + expected = jnp.array(expected).astype(outputs.dtype) + self.assertNestedEqual(outputs[0, :, 0, 0, 0], expected) + + @parameterized.product(window=(1, 3, 5), padding=("SAME", "VALID", "CAUSAL"), dilation=(1, 2)) + def test_conv1d_transpose_against_conv1d(self, window, padding, dilation): + # Conv1D and Conv1DTranspose are same when window_stride=1 + # (stride of Conv1D) and lhs_dilation=1 (stride of Conv1DTranspose). + input_dim, output_dim = 4, 6 + ref_cfg = Conv1D.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + padding=padding, + dilation=dilation, + ) + ref_layer = ref_cfg.instantiate(parent=None) + + test_cfg = layers.Conv1DTranspose.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + padding=padding, + dilation=dilation, + ) + test_layer = test_cfg.instantiate(parent=None) + + # Initialize layer parameters. + prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + ref_states = ref_layer.initialize_parameters_recursively(init_key) + + # Random inputs. + prng_key, input_key = jax.random.split(prng_key) + inputs = jax.random.normal(input_key, [2, 17, input_dim]) + # Compute layer outputs. + ref_outputs, _ = F( + ref_layer, inputs=dict(x=inputs), is_training=True, state=ref_states, prng_key=prng_key + ) + + (test_outputs, _), _ = F( + test_layer, inputs=dict(x=inputs), is_training=True, state=ref_states, prng_key=prng_key + ) + if ref_outputs.shape != test_outputs.shape: + self.assertEqual(padding, "VALID") + dilate_window = layers.conv_dilate_window(window=(window,), dilation=(dilation,))[0] + pad_left = dilate_window - 1 + test_outputs = test_outputs[:, pad_left:-pad_left] + assert_allclose(ref_outputs, test_outputs) + + @parameterized.product(window=(1, 3, 5), padding=("SAME", "VALID", "CAUSAL"), dilation=(1, 2)) + def test_conv2d_transpose_against_conv2d(self, window, padding, dilation): + # Conv2D and Conv2DTranspose are same when window_stride=1 + # (stride of Conv2D) and lhs_dilation=1 (stride of Conv2DTranspose). + window = (window, window) + dilation = (dilation, dilation) + input_dim, output_dim = 4, 6 + ref_cfg = Conv2D.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + padding=padding, + dilation=dilation, + ) + ref_layer = ref_cfg.instantiate(parent=None) + + test_cfg = layers.Conv2DTranspose.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + padding=padding, + dilation=dilation, + transpose_kernel=False, + ) + test_layer = test_cfg.instantiate(parent=None) + + # Initialize layer parameters. + prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + ref_states = ref_layer.initialize_parameters_recursively(init_key) + + # Random inputs. + prng_key, input_key = jax.random.split(prng_key) + width, height = 12, 13 + inputs = jax.random.normal(input_key, [2, width, height, input_dim]) + # Compute layer outputs. + ref_outputs, _ = F( + ref_layer, + inputs=dict(x=inputs), + is_training=True, + state=ref_states, + prng_key=prng_key, + ) + + test_outputs, _ = F( + test_layer, + inputs=dict(x=inputs), + is_training=True, + state=ref_states, + prng_key=prng_key, + ) + if ref_outputs.shape != test_outputs.shape: + self.assertEqual(padding, "VALID") + dilate_window = layers.conv_dilate_window(window=window, dilation=dilation) + pad_left = tuple(w - 1 for w in dilate_window) + test_outputs = test_outputs[:, pad_left[0] : -pad_left[0], pad_left[1] : -pad_left[1]] + + assert_allclose(ref_outputs, test_outputs) + + @parameterized.product(window=(1, 3, 5), padding=("SAME", "VALID", "CAUSAL"), dilation=(1, 2)) + def test_conv2d_transpose_against_conv2d_with_paddings(self, window, padding, dilation): + # Conv2DWith1DPadding and Conv2DTransposeWith1DPadding are same when window_stride=1 + # (stride of Conv2D) and lhs_dilation=1 (stride of Conv2DTranspose). + window = (window, window) + dilation = (dilation, dilation) + input_dim, output_dim = 4, 6 + if padding == "VALID": + # TODO(dhwang2,ruoming): Currently, anchor is pad_left but it should be the midpoint + # between [pad_left, pad_right). Otherwise, the consistency of VALID padding is broken. + # For reference, the midpoint in SAME and CAUSAL is left_pad. + strides = (1, 1) + dilate_window = layers.conv_dilate_window(window=window, dilation=dilation)[0] + conv_padding = layers.conv_explicit_padding( + window=window, strides=strides, padding=padding, dilation=dilation + ) + pad_left, pad_right = conv_padding[0] + anchor_range = dilate_window - pad_left - pad_right + mid_point = anchor_range // 2 + anchor = pad_left + mid_point + else: + anchor = None + + ref_cfg = Conv2DWith1DPadding.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + padding=padding, + dilation=dilation, + anchor=anchor, + ) + ref_layer = ref_cfg.instantiate(parent=None) + + test_cfg = layers.Conv2DTransposeWith1DPadding.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + padding=padding, + dilation=dilation, + anchor=anchor, + transpose_kernel=False, + ) + test_layer = test_cfg.instantiate(parent=None) + + # Initialize layer parameters. + prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + ref_states = ref_layer.initialize_parameters_recursively(init_key) + + # Random inputs. + prng_key, input_key = jax.random.split(prng_key) + width, height = 12, 13 + inputs = jax.random.normal(input_key, [2, width, height, input_dim]) + paddings = jnp.zeros([2, width], dtype=inputs.dtype).at[:, -2:].set(1) + # Compute layer outputs. + (ref_outputs, ref_paddings), _ = F( + ref_layer, + inputs=dict(x=inputs, paddings=paddings), + is_training=True, + state=ref_states, + prng_key=prng_key, + ) + + (test_outputs, test_paddings), _ = F( + test_layer, + inputs=dict(x=inputs, paddings=paddings), + is_training=True, + state=ref_states, + prng_key=prng_key, + ) + if ref_outputs.shape != test_outputs.shape: + self.assertEqual(padding, "VALID") + dilate_window = layers.conv_dilate_window(window=window, dilation=dilation) + pad_left = tuple(w - 1 for w in dilate_window) + test_outputs = test_outputs[:, pad_left[0] : -pad_left[0], pad_left[1] : -pad_left[1]] + test_paddings = test_paddings[:, pad_left[0] : -pad_left[0]] + + assert_allclose(ref_outputs, test_outputs) + assert_allclose(ref_paddings, test_paddings) + + @parameterized.product(window=(1, 3, 5), padding=("SAME", "VALID", "CAUSAL"), dilation=(1, 2)) + def test_conv3d_transpose_against_conv3d(self, window, padding, dilation): + # Conv3D and Conv3DTranspose are same when window_stride=1 + # (stride of Conv3D) and lhs_dilation=1 (stride of Conv3DTranspose). + window = (window, window, window) + dilation = (dilation, dilation, dilation) + input_dim, output_dim = 4, 6 + ref_cfg = Conv3D.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + padding=padding, + dilation=dilation, + ) + ref_layer = ref_cfg.instantiate(parent=None) + + test_cfg = layers.Conv3DTranspose.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + padding=padding, + dilation=dilation, + ) + test_layer = test_cfg.instantiate(parent=None) + + # Initialize layer parameters. + prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + ref_states = ref_layer.initialize_parameters_recursively(init_key) + + # Random inputs. + prng_key, input_key = jax.random.split(prng_key) + width, height, depth = 9, 8, 7 + inputs = jax.random.normal(input_key, [2, width, height, depth, input_dim]) + # Compute layer outputs. + ref_outputs, _ = F( + ref_layer, + inputs=dict(x=inputs), + is_training=True, + state=ref_states, + prng_key=prng_key, + ) + + test_outputs, _ = F( + test_layer, + inputs=dict(x=inputs), + is_training=True, + state=ref_states, + prng_key=prng_key, + ) + if ref_outputs.shape != test_outputs.shape: + self.assertEqual(padding, "VALID") + dilate_window = layers.conv_dilate_window(window=window, dilation=dilation) + pad_left = tuple(w - 1 for w in dilate_window) + test_outputs = test_outputs[ + :, + pad_left[0] : -pad_left[0], + pad_left[1] : -pad_left[1], + pad_left[2] : -pad_left[2], + ] + + assert_allclose(ref_outputs, test_outputs) + + @parameterized.product( + window=(1, 3, 5), + strides=(1, 2), + padding=("SAME", "VALID", "CAUSAL"), + dilation=(1, 2), + anchor=(None, 1), + ) + def test_conv1d_transpose_against_conv2d_transpose_with_1d_padding( + self, + window, + strides, + padding: ConvPaddingType, + dilation, + anchor, + ): + if anchor is not None: + dilate_window = layers.conv_dilate_window(window=(window,), dilation=(dilation,))[0] + anchor = dilate_window - 1 + + input_dim, output_dim = 4, 6 + ref_cfg = layers.Conv2DTransposeWith1DPadding.default_config().set( + name="ref", + input_dim=input_dim, + output_dim=output_dim, + window=(window, 1), + strides=(strides, 1), + padding=padding, + dilation=(dilation, 1), + anchor=anchor, + ) + ref_layer = ref_cfg.instantiate(parent=None) + + test_cfg = layers.Conv1DTranspose.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + strides=strides, + padding=padding, + dilation=dilation, + anchor=anchor, + ) + test_layer = test_cfg.instantiate(parent=None) + + # Initialize layer parameters. + prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + state = ref_layer.initialize_parameters_recursively(init_key) + test_state = dict( + bias=state["bias"], weight=einops.rearrange(state["weight"], "t 1 i o -> t i o") + ) + + # Generate a batch of 10 input sequences. + batch_size, max_seq_len = 10, 10 + + prng_key, input_key = jax.random.split(prng_key) + inputs = jax.random.normal(input_key, [batch_size, max_seq_len, input_dim]) + # The 10 sequences have length 1 to 10. + paddings = jnp.triu(jnp.ones((batch_size, max_seq_len)), k=1) + + (test_outputs, test_paddings), _ = F( + test_layer, + inputs=dict(x=inputs, paddings=paddings), + is_training=True, + state=test_state, + prng_key=prng_key, + ) + + inputs = einops.rearrange(inputs, "b t i -> b t 1 i") + (ref_outputs, ref_paddings), _ = F( + ref_layer, + inputs=dict(x=inputs, paddings=paddings), + is_training=True, + state=state, + prng_key=prng_key, + ) + ref_outputs = einops.rearrange(ref_outputs, "b t 1 o -> b t o") + + assert_allclose(ref_paddings, test_paddings) + assert_allclose(ref_outputs, test_outputs) + + @parameterized.named_parameters( + { + "testcase_name": "2x2", + "window": (2, 2), + "strides": (1, 1), + "padding": "VALID", + }, + { + "testcase_name": "2x2_S2", + "window": (2, 2), + "strides": (2, 2), + "padding": "VALID", + }, + { + "testcase_name": "3x3_S2", + "window": (3, 3), + "strides": (2, 2), + "padding": "VALID", + }, + ) + def test_conv2d_transpose_against_pytorch( + self, + window: tuple[int, int], + strides: tuple[int, int], + padding: Union[str, tuple[int, int]], + ): + input_dim, output_dim = 4, 8 + if isinstance(padding, tuple): + deconv_padding = ((padding[0], padding[0]), (padding[1], padding[1])) + else: + deconv_padding = padding + cfg = Conv2DTranspose.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + strides=strides, + padding=deconv_padding, + ) + layer: Conv2DTranspose = cfg.instantiate(parent=None) + + # Initialize layer parameters. + prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + layer_params = layer.initialize_parameters_recursively(init_key) + self.assertEqual( + dict( + weight=(window[0], window[1], output_dim, input_dim), + bias=(output_dim,), + ), + shapes(layer_params), + ) + bias = layer_params["bias"] + assert_allclose(bias, jnp.zeros_like(bias)) + # Randomize bias. + layer_params["bias"] = jax.random.normal( + jax.random.PRNGKey(45), shape=bias.shape, dtype=bias.dtype + ) + + # Random inputs. + prng_key, input_key = jax.random.split(prng_key) + inputs = jax.random.normal(input_key, [2, 10, 7, input_dim]) + # Compute layer outputs. + outputs, _ = F( + layer, + inputs=(inputs,), + is_training=True, + state=layer_params, + prng_key=prng_key, + ) + + # Compute ref outputs. + if isinstance(padding, tuple): + ref_padding = padding[0] + elif isinstance(padding, str): + ref_padding = padding.lower() + if ref_padding == "valid": + ref_padding = 0 + else: + ref_padding = 0 + + ref = torch.nn.ConvTranspose2d( + in_channels=input_dim, + out_channels=output_dim, + kernel_size=window, + stride=strides, + padding=ref_padding, + ) + # torch.nn.Linear.weight is of shape (output_dim, input_dim, kernel_size...). + _copy(layer_params["weight"].transpose(3, 2, 0, 1), ref.weight) + _copy(layer_params["bias"], ref.bias) + ref_outputs = ref(as_torch_tensor(inputs.transpose(0, 3, 1, 2))) + assert_allclose(outputs, ref_outputs.detach().numpy().transpose(0, 2, 3, 1)) + # Tests output_shape. + output_shape = layer.output_shape(input_shape=inputs.shape) + self.assertAllEqual(outputs.shape, output_shape) + @parameterized.parameters( itertools.product( (None, 0.0, 0.2, 1.0, -0.1),