-
Notifications
You must be signed in to change notification settings - Fork 1
/
pytorch2onnx.py
165 lines (140 loc) · 5.8 KB
/
pytorch2onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import warnings
import numpy as np
import torch
from mmpose.apis import init_pose_model
try:
import onnx
import onnxruntime as rt
except ImportError as e:
raise ImportError(f'Please install onnx and onnxruntime first. {e}')
try:
from mmcv.onnx.symbolic import register_extra_symbolics
except ModuleNotFoundError:
raise NotImplementedError('please update mmcv to version>=1.0.4')
class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
    """Batch normalization layer that accepts input of any rank.

    ``BatchNorm1d/2d/3d`` differ only in the input-rank check performed by
    ``_check_input_dim``; the normalization itself is identical. Disabling
    the check lets one class stand in for a ``SyncBatchNorm`` regardless of
    whether the surrounding network is 1D, 2D or 3D.
    """

    def _check_input_dim(self, input):
        # Intentionally a no-op: SyncBatchNorm carries no rank information,
        # so there is nothing meaningful to validate against.
        return


def _convert_batchnorm(module):
    """Recursively replace every ``SyncBatchNorm`` with a plain batch norm.

    The replacement layer is rank-agnostic, so the converted model works for
    4D image input as well as 5D video input. (A ``BatchNorm3d`` replacement
    would reject the 4D input that 2D pose models receive.)

    Args:
        module (:obj:`nn.Module`): The module (or whole model) to convert.
            Child modules are replaced in place.

    Returns:
        :obj:`nn.Module`: ``module`` itself when it is not a
        ``SyncBatchNorm``, otherwise the newly created replacement layer.
    """
    module_output = module
    if isinstance(module, torch.nn.SyncBatchNorm):
        module_output = _BatchNormXd(module.num_features, module.eps,
                                     module.momentum, module.affine,
                                     module.track_running_stats)
        if module.affine:
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
            # keep requires_grad unchanged
            module_output.weight.requires_grad = module.weight.requires_grad
            module_output.bias.requires_grad = module.bias.requires_grad
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        # A freshly built layer starts in training mode; carry over the
        # source layer's mode so an eval-mode model stays in eval mode.
        module_output.training = module.training
    # Recurse so that nested SyncBatchNorm layers are converted as well.
    for name, child in module.named_children():
        module_output.add_module(name, _convert_batchnorm(child))
    del module
    return module_output
def pytorch2onnx(model,
                 input_shape,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False):
    """Export a PyTorch model to an ONNX file and optionally validate it.

    The model is moved to CPU, switched to eval mode and traced with a
    random tensor of shape ``input_shape``. With ``verify`` enabled the
    written file is checked by the ONNX checker and its ONNX Runtime output
    is compared against the PyTorch output for the same random tensor.

    Args:
        model (:obj:`nn.Module`): The pytorch model to be exported.
        input_shape (tuple[int]): Shape of the random tensor fed to the
            model.
        opset_version (int): ONNX opset version to export with. Default: 11.
        show (bool): Whether to print the exported graph (passed to the
            exporter as ``verbose``). Default: False.
        output_file (str): Path of the ONNX file to write.
            Default: 'tmp.onnx'.
        verify (bool): Whether to validate the exported file and compare
            its outputs with PyTorch. Default: False.

    Raises:
        AssertionError: If verification is enabled and the model output is
            not a tensor or list/tuple, the graph does not have exactly one
            non-initializer input, or the outputs differ.
    """
    model.cpu().eval()
    dummy_input = torch.randn(input_shape)

    register_extra_symbolics(opset_version)
    export_options = dict(
        export_params=True,
        keep_initializers_as_inputs=True,
        verbose=show,
        opset_version=opset_version)
    torch.onnx.export(model, dummy_input, output_file, **export_options)
    print(f'Successfully exported ONNX model: {output_file}')

    if not verify:
        return

    # Structural validation of the file that was just written.
    exported = onnx.load(output_file)
    onnx.checker.check_model(exported)

    # Reference outputs from PyTorch, normalised to a list.
    torch_outputs = model(dummy_input)
    if not isinstance(torch_outputs, (list, tuple)):
        assert isinstance(torch_outputs, torch.Tensor)
        torch_outputs = [torch_outputs]

    # Initializers are also listed as graph inputs (see
    # keep_initializers_as_inputs above); whatever remains after removing
    # them is the real data input.
    graph = exported.graph
    graph_input_names = {node.name for node in graph.input}
    initializer_names = {node.name for node in graph.initializer}
    feed_names = list(graph_input_names - initializer_names)
    assert len(feed_names) == 1

    session = rt.InferenceSession(output_file)
    feed = {feed_names[0]: dummy_input.detach().numpy()}
    onnx_outputs = session.run(None, feed)

    # Element-wise comparison of every output pair.
    assert len(torch_outputs) == len(onnx_outputs)
    for torch_out, onnx_out in zip(torch_outputs, onnx_outputs):
        assert np.allclose(
            torch_out.detach().cpu(), onnx_out, atol=1.e-5
        ), 'The outputs are different between Pytorch and ONNX'
    print('The numerical values are same between Pytorch and ONNX')
def parse_args():
    """Parse the command-line options of the conversion script.

    Returns:
        argparse.Namespace: Holds ``config``, ``checkpoint``, ``show``,
        ``output_file``, ``opset_version``, ``verify`` and ``shape``.
    """
    cli = argparse.ArgumentParser(description='Convert MMPose models to ONNX')

    # (flags, keyword options) pairs, registered in declaration order.
    argument_table = (
        (('config', ), dict(help='test config file path')),
        (('checkpoint', ), dict(help='checkpoint file')),
        (('--show', ), dict(action='store_true', help='show onnx graph')),
        (('--output-file', ), dict(type=str, default='tmp.onnx')),
        (('--opset-version', ), dict(type=int, default=11)),
        (('--verify', ),
         dict(
             action='store_true',
             help='verify the onnx model output against pytorch output')),
        (('--shape', ),
         dict(
             type=int,
             nargs='+',
             default=[1, 3, 256, 192],
             help='input size')),
    )
    for flags, options in argument_table:
        cli.add_argument(*flags, **options)

    return cli.parse_args()
if __name__ == '__main__':
    # Script entry point: parse options, load the model, export it to ONNX.
    args = parse_args()

    assert args.opset_version == 11, 'MMPose only supports opset 11 now'

    # ANSI escape sequences; the values follow the colorama package.
    ansi_bright, ansi_reset = '\x1b[1m', '\x1b[0m'
    ansi_red, ansi_blue = '\x1b[31m', '\x1b[34m'
    ansi_white_bg = '\x1b[107m'
    notice = ''.join((
        ansi_white_bg,
        ansi_bright,
        ansi_red,
        'DeprecationWarning: This tool will be deprecated in future. ',
        ansi_blue,
        'Welcome to use the unified model deployment toolbox ',
        'MMDeploy: https://github.com/open-mmlab/mmdeploy',
        ansi_reset,
    ))
    warnings.warn(notice)

    model = _convert_batchnorm(
        init_pose_model(args.config, args.checkpoint, device='cpu'))

    # onnx.export does not support kwargs, so the export goes through the
    # model's ``forward_dummy`` method instead of its regular ``forward``.
    if not hasattr(model, 'forward_dummy'):
        raise NotImplementedError(
            'Please implement the forward method for exporting.')
    model.forward = model.forward_dummy

    # Convert the model to an ONNX file.
    pytorch2onnx(
        model,
        args.shape,
        opset_version=args.opset_version,
        show=args.show,
        output_file=args.output_file,
        verify=args.verify)