Skip to content

Commit

Permalink
Implement ConvXDTranspose
Browse files Browse the repository at this point in the history
This PR implements unified transpose convolution covering 1D/2D/3D,
SAME/VALID/CAUSAL and arbitrary padding, arbitrary window, stride, and
dilation.

SAME and VALID is equivalent to jax.lax.conv_transpose(). CAUSAL is defined in
this PR.

Each Literal padding follows the formulas below,
* SAME: padding=(min(window-1, ceil((w+s-2)/2)), max(stride-1, floor((w+s-2)/2)))
    pad_total = window+stride-2
    when stride > window -> (window-1, stride-1)
* VALID: padding=(window-1, max(stride-1, window-1))
    pad_total = window+stride-2 + max(window-stride, 0)
    when stride > window -> (window-1, stride-1)
* CAUSAL: padding=(window-1, stride-1)
    pad_total = window+stride-2

Note: output_size = input_size*stride - (window+stride-2) + pad_total
                  = input_size*stride  <- "SAME", "CAUSAL"
                  = input_size*stride + max(window-stride, 0)  <- "VALID"

Note: In the above equation, `window` can be replaced with `dilate_window` when dilation > 1.
    dilate_window = (window - 1) * dilation + 1. Check conv_dilate_window()

The following illustration demonstrates how Conv Transpose operates, assuming all kernel values
are set to 1 for simplicity in showcasing output values.

In the window=3 and stride=1 case, this function creates outputs as follows:
* "SAME" padding=(1, 1)
                pad|       |pad
    paddings:     0|0 0 1 1|0
                  0 0 0  -> 0
                    0 0 1  -> 1
                      0 1 1  -> 2
                        1 1 0  -> 2

* "VALID" padding=(2, 2)
                pad  |       |pad
    paddings:     0 0|0 0 1 1|0 0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 1  -> 1
                        0 1 1  -> 2
                          1 1 0  -> 2
                            1 0 0  -> 1

* "CAUSAL" padding=(2, 0)
                pad  |       |pad
    paddings:     0 0|0 0 1 1|
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 1  -> 1
                        0 1 1  -> 2

In the window=3 and stride=2 case, this function creates outputs as follows:
* "SAME" padding=(2, 1)
                pad  |             |pad
    paddings:     0 0|0 * 0 * 1 * 1|0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 1  -> 1
                            0 1 0  -> 1
                              1 0 1  -> 2
                                0 1 0  -> 1

* "VALID" padding=(2, 2)
                pad  |             |pad
    paddings:     0 0|0 * 0 * 1 * 1|0 0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 1  -> 1
                            0 1 0  -> 1
                              1 0 1  -> 2
                                0 1 0  -> 1
                                  1 0 0  -> 1

* "CAUSAL" padding=(2, 1)
                pad  |             |pad
    paddings:     0 0|0 * 0 * 1 * 1|0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 1  -> 1
                            0 1 0  -> 1
                              1 0 1  -> 2
                                0 1 0  -> 1

In the window=3 and stride=3 case, this function creates outputs as follows:
* "SAME", "VALID" and "CAUSAL" padding=(2, 2)
                pad  |                   |pad
    paddings:     0 0|0 * * 0 * * 1 * * 1|0 0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 0  -> 0
                            0 0 0  -> 0
                              0 0 1  -> 1
                                0 1 0  -> 1
                                  1 0 0  -> 1
                                    0 0 1  -> 1
                                      0 1 0  -> 1
                                        1 0 0  -> 1

In the window=3 and stride=4 case, this function creates outputs as follows:
* "SAME", "VALID" and "CAUSAL" padding=(2, 3)
                pad  |                         |pad
    paddings:     0 0|0 * * * 0 * * * 1 * * * 1|0 0 0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 0  -> 0
                            0 0 0  -> 0
                              0 0 0  -> 0
                                0 0 0  -> 0
                                  0 0 1  -> 1
                                    0 1 0  -> 1
                                      1 0 0  -> 1
                                        0 0 0  -> 0
                                          0 0 1  -> 1
                                            0 1 0  -> 1
                                              1 0 0  -> 1
                                                0 0 0  -> 0
    Here is how to compute output_size, given the above example,
      1.          |_|  -(window-1)
      2.              |_______________________|  (input_size-1)*stride + 1
      3.          |_|                           |___|  + pad_total

    So, output_size = -(window-1) + (input_size-1)*stride + 1 + pad_total
                    = input_size*stride - (window+stride-2) + pad_total
                    = input_size*stride  <- "SAME", "CAUSAL"
                    = input_size*stride + max(window-stride, 0)  <- "VALID"

OTHO, when dilation > 1, dilate_window = (window - 1) * dilation + 1.
For example, when window=3 and dilation=2, dilate_window=5.

In the stride=2 case, this function creates outputs as follows:
* "SAME" padding=(3, 2)
                pad    |             |pad
    paddings:     0 0 0|0 * 0 * 1 * 1|0 0
                  0 * 0 * 0  -> 0
                    0 * 0 * 0  -> 0
                      0 * 0 * 0  -> 0
                        0 * 0 * 1  -> 1
                          0 * 0 * 0  -> 0
                            0 * 1 * 1  -> 2
                              0 * 0 * 0  -> 0
                                1 * 1 * 0  -> 2

* "VALID" padding=(4, 4)
                pad      |             |pad
    paddings:     0 0 0 0|0 * 0 * 1 * 1|0 0 0 0
                  0 * 0 * 0  -> 0
                    0 * 0 * 0  -> 0
                      0 * 0 * 0  -> 0
                        0 * 0 * 0  -> 0
                          0 * 0 * 1  -> 1
                            0 * 0 * 0  -> 0
                              0 * 1 * 1  -> 2
                                0 * 0 * 0  -> 0
                                  1 * 1 * 0  -> 2
                                    0 * 0 * 0  -> 0
                                      1 * 0 * 0  -> 1

* "CAUSAL" padding=(4, 1)
                pad      |             |pad
    paddings:     0 0 0 0|0 * 0 * 1 * 1|0
                  0 * 0 * 0  -> 0
                    0 * 0 * 0  -> 0
                      0 * 0 * 0  -> 0
                        0 * 0 * 0  -> 0
                          0 * 0 * 1  -> 1
                            0 * 0 * 0  -> 0
                              0 * 1 * 1  -> 2
                                0 * 0 * 0  -> 0
  • Loading branch information
ds-hwang committed Nov 23, 2024
1 parent d0edead commit 7ab7e90
Show file tree
Hide file tree
Showing 6 changed files with 1,683 additions and 206 deletions.
14 changes: 11 additions & 3 deletions axlearn/common/conformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from axlearn.common.config import REQUIRED, InstantiableConfig, Required, config_class
from axlearn.common.layers import (
BatchNorm,
DepthwiseConv1D,
Conv1D,
Dropout,
GroupNorm,
LayerNorm,
Expand Down Expand Up @@ -72,7 +72,7 @@ class Config(BaseLayer.Config):
linear1_norm: LayerNorm.Config = LayerNorm.default_config()
linear1_activation: tuple[str, str] = ("linear", "nn.sigmoid")
linear1: Linear.Config = Linear.default_config().set(bias=True)
conv: DepthwiseConv1D.Config = DepthwiseConv1D.default_config().set(
conv: Conv1D.Config = Conv1D.default_config().set(
# See Table 2 and 7.
window=32,
bias=False,
Expand All @@ -96,7 +96,15 @@ def __init__(self, cfg: Config, *, parent: Module):
cfg.linear1.set(input_dim=cfg.input_dim, output_dim=cfg.input_dim),
)

self._add_child("conv", cfg.conv.set(input_dim=cfg.input_dim))
# Setup Depthwise Convolution (3 dims are same).
self._add_child(
"conv",
cfg.conv.set(
input_dim=cfg.input_dim,
output_dim=cfg.input_dim,
num_input_dim_groups=cfg.input_dim,
),
)
self._add_child("conv_norm", cfg.conv_norm.set(input_dim=cfg.input_dim))
self._add_child(
"linear2",
Expand Down
Loading

0 comments on commit 7ab7e90

Please sign in to comment.