diff --git a/torch2onnx.py b/torch2onnx.py index 377289d..58cfefc 100644 --- a/torch2onnx.py +++ b/torch2onnx.py @@ -25,7 +25,7 @@ def main(): namekey = k[7:] # 去掉module前缀 new_state_dict[namekey] = v - net.load_state_dict(new_state_dict) + net.load_state_dict(new_state_dict, strict=False) dummy_input = torch.randn(1, 3, 320,800, device='cpu') torch.onnx.export(net, dummy_input, 'tusimple_r18.onnx',