# convert_to_onnx.py
# Forked from urchade/GLiNER.
import os
import argparse
import numpy as np
from gliner import GLiNER
import torch
from onnxruntime.quantization import quantize_dynamic, QuantType
def str_to_bool(value):
    """Parse a command-line string into a bool.

    ``argparse`` with ``type=bool`` calls ``bool(value)`` on the raw string,
    so every non-empty string -- including ``"False"`` and ``"0"`` -- becomes
    ``True``. This converter interprets the common spellings explicitly.

    Args:
        value: The raw argument string. A value that is already a ``bool``
            is returned unchanged.

    Returns:
        ``True`` for ``true/t/yes/y/1/on`` and ``False`` for
        ``false/f/no/n/0/off`` (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling,
            so argparse reports a clean usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1", "on"}:
        return True
    if normalized in {"false", "f", "no", "n", "0", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Parse the converter's command-line options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``model_path``, ``save_path``,
        ``quantize`` (bool) and ``opset_version`` (int).
    """
    parser = argparse.ArgumentParser(description="Export a GLiNER model to ONNX.")
    parser.add_argument('--model_path', type=str, default="model/",
                        help="Path passed to GLiNER.from_pretrained.")
    parser.add_argument('--save_path', type=str, default='model/',
                        help="Directory where model.onnx (and, when quantizing, "
                             "model_quantized.onnx) are written.")
    # nargs='?' + const=True lets a bare `--quantize` mean True while
    # `--quantize False` / `--quantize 0` now really disable quantization.
    parser.add_argument('--quantize', type=str_to_bool, nargs='?', const=True,
                        default=True,
                        help="Also write a dynamically quantized copy "
                             "(true/false, yes/no, 1/0; default: true).")
    parser.add_argument('--opset_version', type=int, default=14,
                        help="ONNX opset version used for export (default: 14).")
    return parser.parse_args(argv)


def build_export_spec(inputs, span_mode):
    """Select the model inputs, their ONNX names and dynamic axes for export.

    Args:
        inputs: Mapping returned by ``gliner_model.prepare_model_inputs``.
            It must contain ``input_ids``, ``attention_mask``, ``words_mask``
            and ``text_lengths``, plus ``span_idx`` and ``span_mask`` unless
            ``span_mode`` is ``'token_level'``.
        span_mode: Value of ``gliner_model.config.span_mode``.

    Returns:
        ``(all_inputs, input_names, dynamic_axes)`` for ``torch.onnx.export``.
        ``all_inputs`` is ordered exactly like ``input_names``.
    """
    input_names = ['input_ids', 'attention_mask', 'words_mask', 'text_lengths']
    dynamic_axes = {
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "attention_mask": {0: "batch_size", 1: "sequence_length"},
        "words_mask": {0: "batch_size", 1: "sequence_length"},
        "text_lengths": {0: "batch_size", 1: "value"},
    }
    if span_mode == 'token_level':
        # Axis labels kept from the original script; token-level logits are
        # labelled with a leading "position" axis.
        # NOTE(review): confirm against the token-level head's output shape.
        dynamic_axes["logits"] = {0: "position", 1: "batch_size",
                                  2: "sequence_length", 3: "num_classes"}
    else:
        # Span-based modes additionally take the span indices and their mask.
        input_names += ['span_idx', 'span_mask']
        dynamic_axes["span_idx"] = {0: "batch_size", 1: "num_spans", 2: "idx"}
        dynamic_axes["span_mask"] = {0: "batch_size", 1: "num_spans"}
        dynamic_axes["logits"] = {0: "batch_size", 1: "sequence_length",
                                  2: "num_spans", 3: "num_classes"}
    # Deriving the tuple from input_names keeps positional inputs and names in sync.
    all_inputs = tuple(inputs[name] for name in input_names)
    return all_inputs, input_names, dynamic_axes


def main(argv=None):
    """Export a GLiNER checkpoint to ONNX and optionally quantize it.

    Writes ``model.onnx`` into ``--save_path``. When ``--quantize`` is true,
    it also writes ``model_quantized.onnx`` with QUInt8 weights via
    ``onnxruntime.quantization.quantize_dynamic``.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = parse_args(argv)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.save_path, exist_ok=True)
    onnx_save_path = os.path.join(args.save_path, "model.onnx")

    print("Loading a model...")
    gliner_model = GLiNER.from_pretrained(args.model_path, load_tokenizer=True)

    # One sample text/label set is used to trace the graph; the dynamic axes
    # declared in build_export_spec mark which dimensions may vary afterwards.
    text = "ONNX is an open-source format designed to enable the interoperability of AI models across various frameworks and tools."
    labels = ['format', 'model', 'tool', 'cat']
    inputs, _ = gliner_model.prepare_model_inputs([text], labels)

    all_inputs, input_names, dynamic_axes = build_export_spec(
        inputs, gliner_model.config.span_mode)

    print('Converting the model...')
    torch.onnx.export(
        gliner_model.model,
        all_inputs,
        f=onnx_save_path,
        input_names=input_names,
        output_names=["logits"],
        dynamic_axes=dynamic_axes,
        opset_version=args.opset_version,
    )

    if args.quantize:
        quantized_save_path = os.path.join(args.save_path, "model_quantized.onnx")
        print("Quantizing the model...")
        quantize_dynamic(
            onnx_save_path,                # input model
            quantized_save_path,           # output model
            weight_type=QuantType.QUInt8,  # quantize weights to unsigned 8-bit integers
        )
    print("Done!")


if __name__ == "__main__":
    main()