Implement ConvXDTranspose
This PR implements a unified transposed convolution covering 1D/2D/3D inputs,
SAME/VALID/CAUSAL and arbitrary explicit padding, and arbitrary window, stride,
and dilation.

SAME and VALID are equivalent to jax.lax.conv_transpose(). CAUSAL is defined in
this PR.

Each Literal padding ("SAME", "VALID", "CAUSAL") follows the formulas below, where w=window and s=stride:
* SAME: padding=(min(window-1, ceil((w+s-2)/2)), max(stride-1, floor((w+s-2)/2)))
    pad_total = window+stride-2
    when stride > window -> (window-1, stride-1)
* VALID: padding=(window-1, max(stride-1, window-1))
    pad_total = window+stride-2 + max(window-stride, 0)
    when stride > window -> (window-1, stride-1)
* CAUSAL: padding=(window-1, stride-1)
    pad_total = window+stride-2
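The three formulas above can be written out as a short pure-Python sketch for one spatial axis. The helper name `explicit_padding` is hypothetical (it is not taken from the PR); only the arithmetic comes from the formulas in this message.

```python
def explicit_padding(window: int, stride: int, mode: str) -> tuple[int, int]:
    """Returns (left, right) padding for one axis, following the formulas above.

    Hypothetical helper for illustration; not the PR's actual function.
    """
    total = window + stride - 2
    if mode == "SAME":
        left = min(window - 1, (total + 1) // 2)  # ceil(total / 2)
        right = max(stride - 1, total // 2)       # floor(total / 2)
    elif mode == "VALID":
        left = window - 1
        right = max(stride - 1, window - 1)
    elif mode == "CAUSAL":
        left = window - 1
        right = stride - 1
    else:
        raise ValueError(f"Unknown padding mode: {mode}")
    return left, right


# When stride > window, all three modes collapse to (window-1, stride-1).
for mode in ("SAME", "VALID", "CAUSAL"):
    print(mode, explicit_padding(3, 4, mode))
```

For SAME and CAUSAL the two entries always sum to window+stride-2; for VALID they sum to window+stride-2 + max(window-stride, 0).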

Note: output_size = input_size*stride - (window+stride-2) + pad_total
                  = input_size*stride  <- "SAME", "CAUSAL"
                  = input_size*stride + max(window-stride, 0)  <- "VALID"

Note: In the above equations, `window` can be replaced with `dilate_window` when dilation > 1,
    where dilate_window = (window - 1) * dilation + 1. See conv_dilate_window().
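As a small illustration of the substitution described in this note, the sketch below computes the effective window and plugs it into the CAUSAL formula. The function name `dilated_window` is made up for this example and should not be read as the signature of conv_dilate_window().

```python
def dilated_window(window: int, dilation: int) -> int:
    """Effective window size once gaps of (dilation - 1) are inserted between taps."""
    return (window - 1) * dilation + 1


# window=3 with dilation=2 spans 5 input positions.
effective = dilated_window(3, 2)

# Substituting the effective window into the CAUSAL formula (window-1, stride-1):
stride = 2
causal_padding = (effective - 1, stride - 1)
print(effective, causal_padding)
```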

The following illustration demonstrates how Conv Transpose operates, assuming all kernel values
are set to 1 for simplicity in showcasing output values.

In the window=3 and stride=1 case, this function creates outputs as follows:
* "SAME" padding=(1, 1)
                pad|       |pad
    paddings:     0|0 0 1 1|0
                  0 0 0  -> 0
                    0 0 1  -> 1
                      0 1 1  -> 2
                        1 1 0  -> 2

* "VALID" padding=(2, 2)
                pad  |       |pad
    paddings:     0 0|0 0 1 1|0 0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 1  -> 1
                        0 1 1  -> 2
                          1 1 0  -> 2
                            1 0 0  -> 1

* "CAUSAL" padding=(2, 0)
                pad  |       |pad
    paddings:     0 0|0 0 1 1|
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 1  -> 1
                        0 1 1  -> 2
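The diagrams above can be reproduced with a small pure-Python sketch. It assumes, as the illustration does, a 1D input and a kernel of all ones; `conv_transpose_ones` is a hypothetical helper written for this note, not a function from the PR.

```python
def conv_transpose_ones(inputs, window, stride, padding):
    """Simulates a 1D transposed convolution with an all-ones kernel.

    Steps: insert (stride - 1) zeros between neighbouring inputs, pad both
    ends with zeros, then slide a window of ones with step 1.
    """
    spread = []
    for i, x in enumerate(inputs):
        if i > 0:
            spread.extend([0] * (stride - 1))  # the "*" slots in the diagrams
        spread.append(x)
    left, right = padding
    padded = [0] * left + spread + [0] * right
    return [sum(padded[i:i + window]) for i in range(len(padded) - window + 1)]


x = [0, 0, 1, 1]
print(conv_transpose_ones(x, window=3, stride=1, padding=(1, 1)))  # "SAME"
print(conv_transpose_ones(x, window=3, stride=1, padding=(2, 2)))  # "VALID"
print(conv_transpose_ones(x, window=3, stride=1, padding=(2, 0)))  # "CAUSAL"
```

Each printed list should match the `->` column of the corresponding diagram, read top to bottom.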

In the window=3 and stride=2 case, this function creates outputs as follows:
* "SAME" padding=(2, 1)
                pad  |             |pad
    paddings:     0 0|0 * 0 * 1 * 1|0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 1  -> 1
                            0 1 0  -> 1
                              1 0 1  -> 2
                                0 1 0  -> 1

* "VALID" padding=(2, 2)
                pad  |             |pad
    paddings:     0 0|0 * 0 * 1 * 1|0 0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 1  -> 1
                            0 1 0  -> 1
                              1 0 1  -> 2
                                0 1 0  -> 1
                                  1 0 0  -> 1

* "CAUSAL" padding=(2, 1)
                pad  |             |pad
    paddings:     0 0|0 * 0 * 1 * 1|0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 1  -> 1
                            0 1 0  -> 1
                              1 0 1  -> 2
                                0 1 0  -> 1

In the window=3 and stride=3 case, this function creates outputs as follows:
* "SAME", "VALID" and "CAUSAL" padding=(2, 2)
                pad  |                   |pad
    paddings:     0 0|0 * * 0 * * 1 * * 1|0 0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 0  -> 0
                            0 0 0  -> 0
                              0 0 1  -> 1
                                0 1 0  -> 1
                                  1 0 0  -> 1
                                    0 0 1  -> 1
                                      0 1 0  -> 1
                                        1 0 0  -> 1

In the window=3 and stride=4 case, this function creates outputs as follows:
* "SAME", "VALID" and "CAUSAL" padding=(2, 3)
                pad  |                         |pad
    paddings:     0 0|0 * * * 0 * * * 1 * * * 1|0 0 0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 0  -> 0
                            0 0 0  -> 0
                              0 0 0  -> 0
                                0 0 0  -> 0
                                  0 0 1  -> 1
                                    0 1 0  -> 1
                                      1 0 0  -> 1
                                        0 0 0  -> 0
                                          0 0 1  -> 1
                                            0 1 0  -> 1
                                              1 0 0  -> 1
                                                0 0 0  -> 0
    Here is how to compute output_size for the above example:
      1.          |_|  -(window-1)
      2.              |_______________________|  (input_size-1)*stride + 1
      3.          |_|                           |___|  + pad_total

    So, output_size = -(window-1) + (input_size-1)*stride + 1 + pad_total
                    = input_size*stride - (window+stride-2) + pad_total
                    = input_size*stride  <- "SAME", "CAUSAL"
                    = input_size*stride + max(window-stride, 0)  <- "VALID"
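The three-step count above can be checked numerically. The sketch below is only a sanity check under the formulas stated in this message; the function names are hypothetical.

```python
def pad_total(window: int, stride: int, mode: str) -> int:
    """Total padding per the formulas in this message."""
    base = window + stride - 2
    if mode == "VALID":
        return base + max(window - stride, 0)
    return base  # "SAME" and "CAUSAL"


def output_size_by_steps(input_size: int, window: int, stride: int, total: int) -> int:
    """Follows steps 1-3 of the diagram rather than the closed form."""
    span = (input_size - 1) * stride + 1  # step 2: inputs spread out by stride
    padded = span + total                 # step 3: add left and right padding
    return padded - (window - 1)          # step 1: number of window positions


# The window=3, stride=4 example above: 4 inputs give 16 outputs.
print(output_size_by_steps(4, 3, 4, pad_total(3, 4, "SAME")))
```

Running the same function with `pad_total(3, 2, "VALID")` gives 9, which is input_size*stride + max(window-stride, 0) = 8 + 1.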

On the other hand, when dilation > 1, dilate_window = (window - 1) * dilation + 1.
For example, when window=3 and dilation=2, dilate_window=5.

In the window=3, dilation=2, and stride=2 case, this function creates outputs as follows:
* "SAME" padding=(3, 2)
                pad    |             |pad
    paddings:     0 0 0|0 * 0 * 1 * 1|0 0
                  0 * 0 * 0  -> 0
                    0 * 0 * 0  -> 0
                      0 * 0 * 0  -> 0
                        0 * 0 * 1  -> 1
                          0 * 0 * 0  -> 0
                            0 * 1 * 1  -> 2
                              0 * 0 * 0  -> 0
                                1 * 1 * 0  -> 2

* "VALID" padding=(4, 4)
                pad      |             |pad
    paddings:     0 0 0 0|0 * 0 * 1 * 1|0 0 0 0
                  0 * 0 * 0  -> 0
                    0 * 0 * 0  -> 0
                      0 * 0 * 0  -> 0
                        0 * 0 * 0  -> 0
                          0 * 0 * 1  -> 1
                            0 * 0 * 0  -> 0
                              0 * 1 * 1  -> 2
                                0 * 0 * 0  -> 0
                                  1 * 1 * 0  -> 2
                                    0 * 0 * 0  -> 0
                                      1 * 0 * 0  -> 1

* "CAUSAL" padding=(4, 1)
                pad      |             |pad
    paddings:     0 0 0 0|0 * 0 * 1 * 1|0
                  0 * 0 * 0  -> 0
                    0 * 0 * 0  -> 0
                      0 * 0 * 0  -> 0
                        0 * 0 * 0  -> 0
                          0 * 0 * 1  -> 1
                            0 * 0 * 0  -> 0
                              0 * 1 * 1  -> 2
                                0 * 0 * 0  -> 0
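The dilated case can be simulated the same way: the kernel taps are simply spaced `dilation` positions apart. This is a pure-Python sketch assuming a 1D input and an all-ones kernel, as in the illustration; `dilated_conv_transpose_ones` is a hypothetical name, not part of the PR.

```python
def dilated_conv_transpose_ones(inputs, window, stride, dilation, padding):
    """1D transposed convolution with an all-ones kernel and dilated taps."""
    spread = []
    for i, x in enumerate(inputs):
        if i > 0:
            spread.extend([0] * (stride - 1))  # zeros inserted by the stride
        spread.append(x)
    left, right = padding
    padded = [0] * left + spread + [0] * right
    span = (window - 1) * dilation + 1  # dilate_window
    return [
        sum(padded[i + k * dilation] for k in range(window))
        for i in range(len(padded) - span + 1)
    ]


x = [0, 0, 1, 1]
# window=3, dilation=2 (dilate_window=5), stride=2
print(dilated_conv_transpose_ones(x, 3, 2, 2, padding=(3, 2)))  # "SAME"
print(dilated_conv_transpose_ones(x, 3, 2, 2, padding=(4, 4)))  # "VALID"
```

The "SAME" call yields 8 outputs (input_size*stride) and the "VALID" call yields 11 (input_size*stride + max(dilate_window-stride, 0)).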
ds-hwang committed Nov 21, 2024
1 parent 1af2ba8 commit 250287a
Showing 12 changed files with 1,875 additions and 444 deletions.
1,294 changes: 1,049 additions & 245 deletions axlearn/common/layers.py

987 changes: 788 additions & 199 deletions axlearn/common/layers_test.py

@@ -410,6 +410,7 @@ model.encoder.context.context.layer.ff_start.stochastic_depth.mode: 'row'
model.encoder.context.context.layer.ff_start.structure: 'prenorm'
model.encoder.context.context.layer.klass: 'axlearn.common.conformer.ConformerLayer'
model.encoder.context.context.layer.lconv.conv.bias: False
model.encoder.context.context.layer.lconv.conv.dilation: 1
model.encoder.context.context.layer.lconv.conv.klass: 'axlearn.common.layers.DepthwiseConv1D'
model.encoder.context.context.layer.lconv.conv.padding: 'SAME'
model.encoder.context.context.layer.lconv.conv.param_partition_spec[0]: None
@@ -524,6 +525,8 @@ model.encoder.feature.klass: 'axlearn.audio.encoder_asr.SpeechFeatureLayer'
model.encoder.feature.output_dim: 512
model.encoder.feature.subsampler.activation: 'nn.relu'
model.encoder.feature.subsampler.conv.bias: True
model.encoder.feature.subsampler.conv.dilation[0]: 1
model.encoder.feature.subsampler.conv.dilation[1]: 1
model.encoder.feature.subsampler.conv.klass: 'axlearn.common.layers.Conv2DWith1DPadding'
model.encoder.feature.subsampler.conv.num_input_dim_groups: 1
model.encoder.feature.subsampler.conv.padding[0][0]: 1

@@ -116,6 +116,7 @@ model.encoder.context.context.layer.ff_start.stochastic_depth.mode: 'row'
model.encoder.context.context.layer.ff_start.structure: 'prenorm'
model.encoder.context.context.layer.klass: 'axlearn.common.conformer.ConformerLayer'
model.encoder.context.context.layer.lconv.conv.bias: False
model.encoder.context.context.layer.lconv.conv.dilation: 1
model.encoder.context.context.layer.lconv.conv.klass: 'axlearn.common.layers.DepthwiseConv1D'
model.encoder.context.context.layer.lconv.conv.padding: 'SAME'
model.encoder.context.context.layer.lconv.conv.param_partition_spec[0]: None
@@ -226,6 +227,8 @@ model.encoder.feature.klass: 'axlearn.audio.encoder_asr.SpeechFeatureLayer'
model.encoder.feature.output_dim: 4
model.encoder.feature.subsampler.activation: 'nn.relu'
model.encoder.feature.subsampler.conv.bias: True
model.encoder.feature.subsampler.conv.dilation[0]: 1
model.encoder.feature.subsampler.conv.dilation[1]: 1
model.encoder.feature.subsampler.conv.klass: 'axlearn.common.layers.Conv2DWith1DPadding'
model.encoder.feature.subsampler.conv.num_input_dim_groups: 1
model.encoder.feature.subsampler.conv.padding[0][0]: 1

@@ -124,6 +124,8 @@ model.backbone.param_init.init_by_param_name['.*weight$'].scale: 1.4142135623730
model.backbone.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
model.backbone.stage.block.activation: 'nn.relu'
model.backbone.stage.block.conv.bias: False
model.backbone.stage.block.conv.dilation[0]: 1
model.backbone.stage.block.conv.dilation[1]: 1
model.backbone.stage.block.conv.klass: 'axlearn.common.layers.Conv2D'
model.backbone.stage.block.conv.num_input_dim_groups: 1
model.backbone.stage.block.conv.padding[0][0]: 1
@@ -169,6 +171,8 @@ model.backbone.stage.klass: 'axlearn.vision.resnet.ResNetStage'
model.backbone.stage.stride: 1
model.backbone.stem.activation: 'nn.relu'
model.backbone.stem.conv.bias: False
model.backbone.stem.conv.dilation[0]: 1
model.backbone.stem.conv.dilation[1]: 1
model.backbone.stem.conv.klass: 'axlearn.common.layers.Conv2D'
model.backbone.stem.conv.num_input_dim_groups: 1
model.backbone.stem.conv.padding[0][0]: 3

@@ -124,6 +124,8 @@ model.backbone.param_init.init_by_param_name['.*weight$'].scale: 1.4142135623730
model.backbone.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
model.backbone.stage.block.activation: 'nn.relu'
model.backbone.stage.block.conv.bias: False
model.backbone.stage.block.conv.dilation[0]: 1
model.backbone.stage.block.conv.dilation[1]: 1
model.backbone.stage.block.conv.klass: 'axlearn.common.layers.Conv2D'
model.backbone.stage.block.conv.num_input_dim_groups: 1
model.backbone.stage.block.conv.padding[0][0]: 1
@@ -169,6 +171,8 @@ model.backbone.stage.klass: 'axlearn.vision.resnet.ResNetStage'
model.backbone.stage.stride: 1
model.backbone.stem.activation: 'nn.relu'
model.backbone.stem.conv.bias: False
model.backbone.stem.conv.dilation[0]: 1
model.backbone.stem.conv.dilation[1]: 1
model.backbone.stem.conv.klass: 'axlearn.common.layers.Conv2D'
model.backbone.stem.conv.num_input_dim_groups: 1
model.backbone.stem.conv.padding[0][0]: 3

@@ -124,6 +124,8 @@ model.backbone.param_init.init_by_param_name['.*weight$'].scale: 1.4142135623730
model.backbone.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
model.backbone.stage.block.activation: 'nn.relu'
model.backbone.stage.block.conv.bias: False
model.backbone.stage.block.conv.dilation[0]: 1
model.backbone.stage.block.conv.dilation[1]: 1
model.backbone.stage.block.conv.klass: 'axlearn.common.layers.Conv2D'
model.backbone.stage.block.conv.num_input_dim_groups: 1
model.backbone.stage.block.conv.padding[0][0]: 1
@@ -169,6 +171,8 @@ model.backbone.stage.klass: 'axlearn.vision.resnet.ResNetStage'
model.backbone.stage.stride: 1
model.backbone.stem.activation: 'nn.relu'
model.backbone.stem.conv.bias: False
model.backbone.stem.conv.dilation[0]: 1
model.backbone.stem.conv.dilation[1]: 1
model.backbone.stem.conv.klass: 'axlearn.common.layers.Conv2D'
model.backbone.stem.conv.num_input_dim_groups: 1
model.backbone.stem.conv.padding[0][0]: 3

@@ -124,6 +124,8 @@ model.backbone.param_init.init_by_param_name['.*weight$'].scale: 1.4142135623730
model.backbone.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
model.backbone.stage.block.activation: 'nn.relu'
model.backbone.stage.block.conv.bias: False
model.backbone.stage.block.conv.dilation[0]: 1
model.backbone.stage.block.conv.dilation[1]: 1
model.backbone.stage.block.conv.klass: 'axlearn.common.layers.Conv2D'
model.backbone.stage.block.conv.num_input_dim_groups: 1
model.backbone.stage.block.conv.padding[0][0]: 1
@@ -169,6 +171,8 @@ model.backbone.stage.klass: 'axlearn.vision.resnet.ResNetStage'
model.backbone.stage.stride: 1
model.backbone.stem.activation: 'nn.relu'
model.backbone.stem.conv.bias: False
model.backbone.stem.conv.dilation[0]: 1
model.backbone.stem.conv.dilation[1]: 1
model.backbone.stem.conv.klass: 'axlearn.common.layers.Conv2D'
model.backbone.stem.conv.num_input_dim_groups: 1
model.backbone.stem.conv.padding[0][0]: 3

@@ -125,6 +125,8 @@ model.backbone.param_init.init_by_param_name['.*weight$'].scale: 1.4142135623730
model.backbone.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
model.backbone.stage.block.activation: 'nn.relu'
model.backbone.stage.block.conv.bias: False
model.backbone.stage.block.conv.dilation[0]: 1
model.backbone.stage.block.conv.dilation[1]: 1
model.backbone.stage.block.conv.klass: 'axlearn.common.layers.Conv2D'
model.backbone.stage.block.conv.num_input_dim_groups: 1
model.backbone.stage.block.conv.padding[0][0]: 1
@@ -170,6 +172,8 @@ model.backbone.stage.klass: 'axlearn.vision.resnet.ResNetStage'
model.backbone.stage.stride: 1
model.backbone.stem.activation: 'nn.relu'
model.backbone.stem.conv.bias: False
model.backbone.stem.conv.dilation[0]: 1
model.backbone.stem.conv.dilation[1]: 1
model.backbone.stem.conv.klass: 'axlearn.common.layers.Conv2D'
model.backbone.stem.conv.num_input_dim_groups: 1
model.backbone.stem.conv.padding[0][0]: 3

@@ -124,6 +124,8 @@ model.backbone.param_init.init_by_param_name['.*weight$'].scale: 1.4142135623730
model.backbone.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
model.backbone.stage.block.activation: 'nn.relu'
model.backbone.stage.block.conv.bias: False
model.backbone.stage.block.conv.dilation[0]: 1
model.backbone.stage.block.conv.dilation[1]: 1
model.backbone.stage.block.conv.klass: 'axlearn.common.layers.Conv2D'
model.backbone.stage.block.conv.num_input_dim_groups: 1
model.backbone.stage.block.conv.padding[0][0]: 1
@@ -169,6 +171,8 @@ model.backbone.stage.klass: 'axlearn.vision.resnet.ResNetStage'
model.backbone.stage.stride: 1
model.backbone.stem.activation: 'nn.relu'
model.backbone.stem.conv.bias: False
model.backbone.stem.conv.dilation[0]: 1
model.backbone.stem.conv.dilation[1]: 1
model.backbone.stem.conv.klass: 'axlearn.common.layers.Conv2D'
model.backbone.stem.conv.num_input_dim_groups: 1
model.backbone.stem.conv.padding[0][0]: 3

@@ -120,6 +120,8 @@ model.backbone.param_init.init_by_param_name['.*weight$'].scale: 1.4142135623730
model.backbone.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
model.backbone.stage.block.activation: 'nn.relu'
model.backbone.stage.block.conv.bias: False
model.backbone.stage.block.conv.dilation[0]: 1
model.backbone.stage.block.conv.dilation[1]: 1
model.backbone.stage.block.conv.klass: 'axlearn.common.layers.Conv2D'
model.backbone.stage.block.conv.num_input_dim_groups: 1
model.backbone.stage.block.conv.padding[0][0]: 1
@@ -165,6 +167,8 @@ model.backbone.stage.klass: 'axlearn.vision.resnet.ResNetStage'
model.backbone.stage.stride: 1
model.backbone.stem.activation: 'nn.relu'
model.backbone.stem.conv.bias: False
model.backbone.stem.conv.dilation[0]: 1
model.backbone.stem.conv.dilation[1]: 1
model.backbone.stem.conv.klass: 'axlearn.common.layers.Conv2D'
model.backbone.stem.conv.num_input_dim_groups: 1
model.backbone.stem.conv.padding[0][0]: 3

@@ -121,6 +121,8 @@ model.backbone.param_init.init_by_param_name['.*weight$'].scale: 1.4142135623730
model.backbone.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
model.backbone.stage.block.activation: 'nn.relu'
model.backbone.stage.block.conv.bias: False
model.backbone.stage.block.conv.dilation[0]: 1
model.backbone.stage.block.conv.dilation[1]: 1
model.backbone.stage.block.conv.klass: 'axlearn.common.layers.Conv2D'
model.backbone.stage.block.conv.num_input_dim_groups: 1
model.backbone.stage.block.conv.padding[0][0]: 1
@@ -166,6 +168,8 @@ model.backbone.stage.klass: 'axlearn.vision.resnet.ResNetStage'
model.backbone.stage.stride: 1
model.backbone.stem.activation: 'nn.relu'
model.backbone.stem.conv.bias: False
model.backbone.stem.conv.dilation[0]: 1
model.backbone.stem.conv.dilation[1]: 1
model.backbone.stem.conv.klass: 'axlearn.common.layers.Conv2D'
model.backbone.stem.conv.num_input_dim_groups: 1
model.backbone.stem.conv.padding[0][0]: 3