Fails to convert ConvTranspose2D with specific padding #316
One more note: it's not just the padding; it only fails if the stride in either dimension is > 1.

import torch
from torch import nn

class WorksToo(torch.nn.Module):
    def __init__(self):
        super(WorksToo, self).__init__()
        # Setting stride=(1,1) instead of (1,2)
        self.deconv = nn.ConvTranspose2d(16, 16, (1, 5), stride=(1, 1), padding=(0, 2), groups=2)

    def forward(self, x):
        x = self.deconv(x)
        return x
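For reference (my addition, not part of the original comment), PyTorch's ConvTranspose2d output-size formula shows how stride, padding, and kernel size interact; a quick eager-mode check with the WorksToo layer above, using the input shape from the traceback below:

import torch
from torch import nn

# out = (in - 1)*stride - 2*padding + dilation*(kernel - 1) + output_padding + 1
deconv = nn.ConvTranspose2d(16, 16, (1, 5), stride=(1, 1), padding=(0, 2), groups=2)
x = torch.randn(1, 16, 3, 33)
print(deconv(x).shape)  # torch.Size([1, 16, 3, 33]): width = (33-1)*1 - 2*2 + 1*(5-1) + 0 + 1 = 33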
I'm actually failing at the first one using the latest code, at commit 4ecdbee. Output:
...
loc(callsite(callsite(callsite("__main__.Works/torch.nn.modules.conv.ConvTranspose2d_deconv;" at fused["XlaCallModule:", "XlaCallModule@__inference_inner_15"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_25"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"])): error: failed to legalize operation 'mhlo.convolution' that was explicitly marked illegal
error: mhlo to TFLite legalization failed.
Traceback (most recent call last):
File "/xxxxxx/issues/ai-edge-torch/316/test.py", line 16, in <module>
_edge_model = ai_edge_torch.convert(Works().eval(), (torch.randn((1, 16, 3, 33,), dtype=torch.float32),))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/xxxxxx/git/ai-edge-torch/ai_edge_torch/_convert/converter.py", line 254, in convert
return Converter().convert(
^^^^^^^^^^^^^^^^^^^^
File "/xxxxxxxx/git/ai-edge-torch/ai_edge_torch/_convert/converter.py", line 169, in convert
return conversion.convert_signatures(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/xxxxxxxx/git/ai-edge-torch/ai_edge_torch/_convert/conversion.py", line 138, in convert_signatures
tflite_model = lowertools.exported_programs_to_tflite(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/xxxxxxxx/git/ai-edge-torch/ai_edge_torch/lowertools/_shim.py", line 75, in exported_programs_to_tflite
return utils.merged_bundle_to_tfl_model(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/xxxxxxxx/git/ai-edge-torch/ai_edge_torch/lowertools/torch_xla_utils.py", line 274, in merged_bundle_to_tfl_model
tflite_model = converter.convert()
^^^^^^^^^^^^^^^^^^^
File "/xxxxxxxx/envs/aet_head/lib/python3.11/site-packages/tensorflow/lite/python/lite.py", line 1237, in wrapper
return self._convert_and_export_metrics(convert_func, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/xxxxxxxx/envs/aet_head/lib/python3.11/site-packages/tensorflow/lite/python/lite.py", line 1189, in _convert_and_export_metrics
result = convert_func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/xxxxxxxx/envs/aet_head/lib/python3.11/site-packages/tensorflow/lite/python/lite.py", line 1571, in convert
return self._convert_from_saved_model(graph_def)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/xxxxxxxx/envs/aet_head/lib/python3.11/site-packages/tensorflow/lite/python/lite.py", line 1429, in _convert_from_saved_model
result = _convert_saved_model(**converter_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/xxxxxxxx/envs/aet_head/lib/python3.11/site-packages/tensorflow/lite/python/convert_phase.py", line 212, in wrapper
raise converter_error from None # Re-throws the exception.
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/xxxxxxxx/envs/aet_head/lib/python3.11/site-packages/tensorflow/lite/python/convert_phase.py", line 205, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/xxxxxxxx/envs/aet_head/lib/python3.11/site-packages/tensorflow/lite/python/convert.py", line 888, in convert_saved_model
data = convert(
^^^^^^^^
File "/xxxxxxxx/envs/aet_head/lib/python3.11/site-packages/tensorflow/lite/python/convert.py", line 348, in convert
raise converter_error
tensorflow.lite.python.convert_phase.ConverterError: <unknown>:0: error: loc(callsite(callsite(callsite("__main__.Works/torch.nn.modules.conv.ConvTranspose2d_deconv;" at fused["XlaCallModule:", "XlaCallModule@__inference_inner_15"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_25"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"])): failed to legalize operation 'mhlo.convolution' that was explicitly marked illegal
<unknown>:0: note: loc(fused["StatefulPartitionedCall:", "StatefulPartitionedCall"]): called from
<unknown>:0: error: mhlo to TFLite legalization failed.

@alenhardt which version/commit of AI-Edge-Torch are you using?
Updated the issue with version info.
Same observation on ConvTranspose: I find that even with padding = 0, the same error occurs when stride > 1 and dilation > 1 together (either one > 1 while keeping the other = 1 works). My setup is identical to @alenhardt's. Sample code:

import torch
from torch import nn

class SimpleUNetLike(nn.Module):
    def __init__(self):
        super().__init__()
        self.convt1 = nn.ConvTranspose1d(16, 1, 3, padding=0, stride=2, dilation=2)

    def forward(self, x):
        x = self.convt1(x)
        print("convt1", x.shape)
        return x

if __name__ == "__main__":
    model = SimpleUNetLike()
    model.eval()
    x = torch.randn(1, 16, 510)
    y = model(x)

    import ai_edge_torch
    edge_model = ai_edge_torch.convert(model, (x,))
    edge_model.export("simple_model.tflite")

@pkgoogle I found some code from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.cc#L90
Description of the bug:
When converting my model with ai_edge_torch.convert I get an error. I tracked down the error to a minimal code snippet for reproduction.
The source seems to be the padding argument for nn.ConvTranspose2d, which fails to export if the first padding dimension is set to 0.
Here is the code to reproduce the issue:
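The original snippet is not included in this extract; the following is a reconstruction assuming the class name and input shape from the traceback above (__main__.Works, input (1, 16, 3, 33)) and the layer arguments implied by the WorksToo comment (same layer but with stride=(1,2)):

import torch
from torch import nn
import ai_edge_torch

class Works(torch.nn.Module):
    def __init__(self):
        super(Works, self).__init__()
        # Assumed: same layer as WorksToo above but with stride=(1,2);
        # the first padding dimension is 0, the configuration reported to fail.
        self.deconv = nn.ConvTranspose2d(16, 16, (1, 5), stride=(1, 2), padding=(0, 2), groups=2)

    def forward(self, x):
        return self.deconv(x)

# Input shape taken from the traceback above.
edge_model = ai_edge_torch.convert(
    Works().eval(), (torch.randn((1, 16, 3, 33), dtype=torch.float32),)
)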
Actual vs expected behavior:
No response
Any other information you'd like to share?
Here is the output of find_culprits when I tried to convert the original model: