onnx_convert.py
import io
import os

import onnx
import timm
import torch

# Imported for its side effect: models/gc_vit.py registers the GC ViT variants
# with timm, so timm.create_model("gc_vit_tiny") below can resolve the name.
from models.gc_vit import gc_vit_xtiny  # noqa: F401


def main():
    model_name = "gc_vit_tiny"
    resolution = 224
    # Build the model in export-friendly mode at the target input resolution.
    model = timm.create_model(
        model_name,
        resolution=resolution,
        exportable=True,
    )
    in_size = (1, 3, resolution, resolution)
    model = model.cuda()
    model.eval()
    # Dummy input used only to trace the graph during export.
    imgs = torch.randn(in_size, device="cuda", requires_grad=True)
    export_onnx(model, imgs, onnx_file_name=model_name + ".onnx", export_params=True)


def export_onnx(
    model: torch.nn.Module,
    sample_inputs,
    export_params: bool = False,
    opset_version: int = 13,
    result_dir: str = "",
    batch_first: bool = True,
    is_training: bool = False,
    onnx_file_name: str = "",
):
    """Export `model` to ONNX and save it under `result_dir/onnx_file_name`."""
    f = io.BytesIO()
    torch.onnx.export(
        model,
        # torch.onnx.export treats a tuple `args` as separate positional
        # inputs, so a model that takes a single tuple must have it wrapped
        # in an extra tuple (https://github.com/pytorch/pytorch/issues/11456).
        (sample_inputs,) if isinstance(sample_inputs, tuple) else sample_inputs,
        f,
        export_params=export_params,
        training=torch.onnx.TrainingMode.TRAINING
        if is_training
        else torch.onnx.TrainingMode.EVAL,
        do_constant_folding=True,
        opset_version=opset_version,
        input_names=["input"] if batch_first else None,
        output_names=["output"] if batch_first else None,
        # Mark the batch dimension as dynamic so the exported model accepts
        # arbitrary batch sizes.
        dynamic_axes={"input": [0], "output": [0]} if batch_first else None,
    )
    # Reload the serialized graph, run shape inference, and save to disk.
    onnx_model = onnx.load_model_from_string(f.getvalue())
    f.close()
    onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
    onnx.save(onnx_model, os.path.join(result_dir, onnx_file_name))
    return onnx_model


if __name__ == "__main__":
    main()
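

# A minimal post-export sanity check (a sketch, assuming the `onnxruntime`
# package is installed; it is not imported by this script). It validates the
# saved graph and runs it on a dummy batch; since the batch axis is exported
# as dynamic, any batch size should be accepted.
#
#   import numpy as np
#   import onnxruntime as ort
#
#   onnx.checker.check_model("gc_vit_tiny.onnx")
#   sess = ort.InferenceSession("gc_vit_tiny.onnx", providers=["CPUExecutionProvider"])
#   batch = np.random.randn(4, 3, 224, 224).astype(np.float32)
#   (logits,) = sess.run(["output"], {"input": batch})
#   print(logits.shape)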