
Fails to convert ConvTranspose2D with specific padding #316

Open
alenhardt opened this issue Oct 24, 2024 · 4 comments

Comments

@alenhardt

alenhardt commented Oct 24, 2024

Description of the bug:

python 3.10
ai_edge_torch.__version__:  0.3.0.dev20241024
torch_xla.__version__:  2.4.0
tensorflow.__version__:  2.17.0

When converting my model with ai_edge_torch.convert, I get the following error:

loc(callsite(callsite(callsite("__main__.CulpritGraphModule;" at fused["XlaCallModule:", "XlaCallModule@__inference_inner_36933"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_36941"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"])): error: 'vhlo.convolution_v1' op is not part of the vhlo support yet.
error: failed while converting: 'main': 

I tracked the error down to a minimal snippet for reproduction.
The cause seems to be the padding argument of nn.ConvTranspose2d, which fails to export if the first padding dimension is set to 0.

Here is the code to reproduce the issue:

import torch.nn as nn
import torch
import ai_edge_torch

# WORKS: Padding (1,2)
class Works(torch.nn.Module):
    def __init__(self):
        super(Works, self).__init__()
        self.deconv = nn.ConvTranspose2d(16, 16, (1,5), stride=(1,2), padding=(1,2), groups=2)
    
    def forward(self, x):
        x = self.deconv(x)
        return x
# use (1, 16, 3, 33,) input dim instead of (1, 16, 1, 33,)
_edge_model = ai_edge_torch.convert(Works().eval(), (torch.randn((1, 16, 3, 33,), dtype=torch.float32),))


# WORKS: No Padding, groups to 1
class Works2(torch.nn.Module):
    def __init__(self):
        super(Works2, self).__init__()
        self.deconv = nn.ConvTranspose2d(16, 16, (1,5), stride=(1,2), groups=1)
    
    def forward(self, x):
        x = self.deconv(x)
        return x

_edge_model = ai_edge_torch.convert(Works2().eval(), (torch.randn((1, 16, 1, 33,), dtype=torch.float32),))

# FAILS: Padding (0, ...)
class Fails(torch.nn.Module):
    def __init__(self):
        super(Fails, self).__init__()
        self.deconv = nn.ConvTranspose2d(16, 16, (1,5), stride=(1,2), padding=(0,2), groups=2)
    
    def forward(self, x):
        x = self.deconv(x)
        return x

m = Fails().eval()
y = m(torch.randn((1, 16, 1, 33,)))
print(y.shape)
_edge_model = ai_edge_torch.convert(Fails().eval(), (torch.randn((1, 16, 1, 33,), dtype=torch.float32),))
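
As a sanity check, the shape printed above can be predicted with the transposed-convolution output-size formula from the PyTorch ConvTranspose2d docs; here is a minimal plain-Python sketch (no torch required):

```python
# Output size of a transposed convolution, per the PyTorch ConvTranspose2d docs:
# out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1
def convtranspose_out(size, kernel, stride=1, padding=0, dilation=1, output_padding=0):
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

# The Fails model: kernel (1, 5), stride (1, 2), padding (0, 2), input (1, 16, 1, 33)
h = convtranspose_out(1, 1, stride=1, padding=0)   # height -> 1
w = convtranspose_out(33, 5, stride=2, padding=2)  # width  -> 65
print(h, w)  # 1 65
```

So the model itself is valid and produces a (1, 16, 1, 65) output; only the conversion fails.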

Actual vs expected behavior:

No response

Any other information you'd like to share?

Here is the output of find_culprits when I tried to convert the original model:

import torch
from torch import device
import ai_edge_torch

class CulpritGraphModule(torch.nn.Module):
    def forward(self, arg0_1: "f32[16, 8, 1, 5]", arg1_1: "f32[16]", arg2_1: "f32[1, 16, 1, 33]"):
        # File: ... in forward, code: return self.act(self.bn(self.conv(x)))
        convolution: "f32[1, 16, 1, 65]" = torch.ops.aten.convolution.default(arg2_1, arg0_1, arg1_1, [1, 2], [0, 2], [1, 1], True, [0, 0], 2);  arg2_1 = arg0_1 = arg1_1 = None
        return (convolution,)

_args = (
    torch.randn((16, 8, 1, 5,), dtype=torch.float32),
    torch.randn((16,), dtype=torch.float32),
    torch.randn((1, 16, 1, 33,), dtype=torch.float32),
)

_edge_model = ai_edge_torch.convert(CulpritGraphModule().eval(), _args)
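
For readers unfamiliar with the positional arguments of torch.ops.aten.convolution.default, here is the call from the culprit graph annotated in plain Python (argument names taken from the ATen convolution schema; this is just an annotation, not conversion code):

```python
# Positional arguments of the aten.convolution.default call in the culprit
# graph above, paired with their names from the ATen convolution schema.
names = ["stride", "padding", "dilation", "transposed", "output_padding", "groups"]
values = [[1, 2], [0, 2], [1, 1], True, [0, 0], 2]
args = dict(zip(names, values))
for name in names:
    print(f"{name} = {args[name]}")
# Note the combination at issue: padding[0] == 0 together with stride[1] > 1.
```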
alenhardt added the type:bug label Oct 24, 2024
@alenhardt
Author

alenhardt commented Oct 24, 2024

One more note: it's not just the padding; it only fails if the stride in either dimension is > 1.
So this works as well:

class WorksToo(torch.nn.Module):
    def __init__(self):
        super(WorksToo, self).__init__()
        # Setting stride=(1,1) instead of (1,2)
        self.deconv = nn.ConvTranspose2d(16, 16, (1,5), stride=(1,1), padding=(0,2), groups=2)
    
    def forward(self, x):
        x = self.deconv(x)
        return x

pkgoogle self-assigned this Oct 24, 2024
@pkgoogle
Contributor

I'm actually seeing a failure on the first one (Works) using the latest code, at commit 4ecdbee.

output:

...
loc(callsite(callsite(callsite("__main__.Works/torch.nn.modules.conv.ConvTranspose2d_deconv;" at fused["XlaCallModule:", "XlaCallModule@__inference_inner_15"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_25"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"])): error: failed to legalize operation 'mhlo.convolution' that was explicitly marked illegal
error: mhlo to TFLite legalization failed.
Traceback (most recent call last):
  File "/xxxxxx/issues/ai-edge-torch/316/test.py", line 16, in <module>
    _edge_model = ai_edge_torch.convert(Works().eval(), (torch.randn((1, 16, 3, 33,), dtype=torch.float32),))
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xxxxxx/git/ai-edge-torch/ai_edge_torch/_convert/converter.py", line 254, in convert
    return Converter().convert(
           ^^^^^^^^^^^^^^^^^^^^
  File "/xxxxxxxx/git/ai-edge-torch/ai_edge_torch/_convert/converter.py", line 169, in convert
    return conversion.convert_signatures(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xxxxxxxx/git/ai-edge-torch/ai_edge_torch/_convert/conversion.py", line 138, in convert_signatures
    tflite_model = lowertools.exported_programs_to_tflite(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xxxxxxxx/git/ai-edge-torch/ai_edge_torch/lowertools/_shim.py", line 75, in exported_programs_to_tflite
    return utils.merged_bundle_to_tfl_model(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xxxxxxxx/git/ai-edge-torch/ai_edge_torch/lowertools/torch_xla_utils.py", line 274, in merged_bundle_to_tfl_model
    tflite_model = converter.convert()
                   ^^^^^^^^^^^^^^^^^^^
  File "/xxxxxxxx/envs/aet_head/lib/python3.11/site-packages/tensorflow/lite/python/lite.py", line 1237, in wrapper
    return self._convert_and_export_metrics(convert_func, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xxxxxxxx/envs/aet_head/lib/python3.11/site-packages/tensorflow/lite/python/lite.py", line 1189, in _convert_and_export_metrics
    result = convert_func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xxxxxxxx/envs/aet_head/lib/python3.11/site-packages/tensorflow/lite/python/lite.py", line 1571, in convert
    return self._convert_from_saved_model(graph_def)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xxxxxxxx/envs/aet_head/lib/python3.11/site-packages/tensorflow/lite/python/lite.py", line 1429, in _convert_from_saved_model
    result = _convert_saved_model(**converter_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xxxxxxxx/envs/aet_head/lib/python3.11/site-packages/tensorflow/lite/python/convert_phase.py", line 212, in wrapper
    raise converter_error from None  # Re-throws the exception.
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/xxxxxxxx/envs/aet_head/lib/python3.11/site-packages/tensorflow/lite/python/convert_phase.py", line 205, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/xxxxxxxx/envs/aet_head/lib/python3.11/site-packages/tensorflow/lite/python/convert.py", line 888, in convert_saved_model
    data = convert(
           ^^^^^^^^
  File "/xxxxxxxx/envs/aet_head/lib/python3.11/site-packages/tensorflow/lite/python/convert.py", line 348, in convert
    raise converter_error
tensorflow.lite.python.convert_phase.ConverterError: <unknown>:0: error: loc(callsite(callsite(callsite("__main__.Works/torch.nn.modules.conv.ConvTranspose2d_deconv;" at fused["XlaCallModule:", "XlaCallModule@__inference_inner_15"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_25"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"])): failed to legalize operation 'mhlo.convolution' that was explicitly marked illegal
<unknown>:0: note: loc(fused["StatefulPartitionedCall:", "StatefulPartitionedCall"]): called from
<unknown>:0: error: mhlo to TFLite legalization failed.

@alenhardt which version/commit of AI-Edge-Torch are you using?

@alenhardt
Author

alenhardt commented Oct 24, 2024

Updated the issue with version info.

python: 3.10
ai_edge_torch.__version__:  0.3.0.dev20241024
torch_xla.__version__:  2.4.0
tensorflow.__version__:  2.17.0

@gudgud96

gudgud96 commented Nov 4, 2024

Same observation on ConvTranspose: I find that even with padding = 0, if stride > 1 and dilation > 1 together (either one > 1 while keeping the other = 1 works), the same error (mhlo to TFLite legalization failed.) is thrown.

My setup is identical to @alenhardt's.

Sample code:

import torch
from torch import nn


class SimpleUNetLike(nn.Module):
    def __init__(self):
        super().__init__()
        self.convt1 = nn.ConvTranspose1d(16, 1, 3, padding=0, stride=2, dilation=2)
    
    def forward(self, x):
        x = self.convt1(x)
        print("convt1", x.shape)
        return x


if __name__ == "__main__":
    model = SimpleUNetLike()
    model.eval()
    x = torch.randn(1, 16, 510)
    y = model(x)

    import ai_edge_torch
    edge_model = ai_edge_torch.convert(model, (x,))
    edge_model.export("simple_model.tflite")

@pkgoogle I found some code at https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.cc#L90
which has a set of if-else statements determining whether a given ConvTranspose argument combination is legal. Could we get someone from TFLite to take a look / explain the limitations of the ConvTranspose op?
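
Pulling the repros in this thread together, the pattern of failures can be summarized as a small predicate. This is only a guess inferred from the examples above; the function and its rule are hypothetical, not the actual TFLite legality check in conv.cc:

```python
# Hypothetical summary of the failing pattern observed in this thread,
# NOT the real TFLite legalization rule: conversion appears to fail when
# any stride > 1 and either some padding dimension is 0 (with groups > 1)
# or some dilation > 1.
def seems_to_fail(stride, padding, dilation=(1, 1), groups=1):
    any_stride_gt1 = any(s > 1 for s in stride)
    zero_pad = any(p == 0 for p in padding)
    dilated = any(d > 1 for d in dilation)
    return any_stride_gt1 and ((zero_pad and groups > 1) or dilated)

print(seems_to_fail((1, 2), (0, 2), groups=2))        # Fails repro
print(seems_to_fail((1, 2), (1, 2), groups=2))        # Works repro
print(seems_to_fail((1, 1), (0, 2), groups=2))        # WorksToo repro
print(seems_to_fail((2,), (0,), dilation=(2,)))       # ConvTranspose1d repro
```

Every repro in the thread is consistent with this rule, which may help narrow down which branch of the legality check is being hit.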
