Add support for --op_block_list in quantization script (#1036)
* Add support for op_block_list

* Remove arg

* Set default to none

* Minor code suggestions

* whoops - actually apply suggestions

---------

Co-authored-by: Joshua Lochner <[email protected]>
pdufour and xenova authored Nov 25, 2024
1 parent c9f12e5 commit 5272b12
Showing 2 changed files with 27 additions and 2 deletions.
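
Before the diffs, a note on the new flag's semantics: because the argument is declared with nargs="+", --op_block_list accepts one or more operator names on the command line and defaults to None, which leaves both quantization paths unchanged. The two modes treat the list differently; the following minimal Python sketch mirrors the changes below rather than quoting them (the operator names and stand-in sets are illustrative, not taken from the script):

# Sketch of the block-list semantics introduced by this commit (illustrative values only).
# Q8 path: listed ops are REMOVED from the set of op types to quantize.
# FP16 path: listed ops are ADDED to the converter's default block list.

default_q8_ops = {"MatMul", "Gemm", "Conv"}       # stand-in for IntegerOpsRegistry keys
fp16_default_block = {"Resize", "Upsample"}       # stand-in for float16.DEFAULT_OP_BLOCK_LIST
op_block_list = ["MatMul", "LayerNormalization"]  # what --op_block_list would supply

# 8-bit quantization: exclude the listed ops from quantization.
op_types_to_quantize = set(default_q8_ops)
op_types_to_quantize.difference_update(op_block_list)
assert op_types_to_quantize == {"Gemm", "Conv"}

# float16 conversion: keep the listed ops in float32 on top of the defaults.
blocked_ops = set(fp16_default_block)
blocked_ops.update(op_block_list)
assert blocked_ops == {"Resize", "Upsample", "MatMul", "LayerNormalization"}
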
1 change: 1 addition & 0 deletions .gitignore
@@ -2,6 +2,7 @@ __pycache__
.vscode
node_modules
.cache
.DS_STORE

# Do not track build artifacts/generated files
/dist
28 changes: 26 additions & 2 deletions scripts/quantize.py
@@ -1,7 +1,7 @@
from enum import Enum

from tqdm import tqdm
from typing import Set
from typing import Set, List, Optional
import onnx
import os

@@ -110,6 +110,16 @@ class QuantizationArguments:
},
)

op_block_list: List[str] = field(
default=None,
metadata={
"help": "List of operators to exclude from quantization."
"Can be any standard ONNX operator (see https://onnx.ai/onnx/operators/)"
"or your custom implemented operators.",
"nargs": "+",
},
)


def get_operators(model: onnx.ModelProto) -> Set[str]:
operators = set()
@@ -131,6 +141,7 @@ def quantize_q8(
per_channel: bool,
reduce_range: bool,
weight_type: QuantType,
op_block_list: Optional[List[str]]
):
"""
Quantize the weights of the model from float32 to int8/uint8
@@ -140,6 +151,10 @@ def quantize_q8(
it is faster on most CPU architectures
"""

op_types_to_quantize = set(IntegerOpsRegistry.keys())
if op_block_list is not None:
op_types_to_quantize.difference_update(op_block_list)

quantizer = ONNXQuantizer(
model,
per_channel,
@@ -151,7 +166,7 @@
tensors_range=None,
nodes_to_quantize=[],
nodes_to_exclude=[],
op_types_to_quantize=list(IntegerOpsRegistry.keys()),
op_types_to_quantize=op_types_to_quantize,
extra_options=dict(
EnableSubgraph=True,
MatMulConstBOnly=True,
@@ -165,6 +180,7 @@ def quantize_fp16(
def quantize_fp16(
model: onnx.ModelProto,
save_path: str,
op_block_list: Optional[List[str]]
):
"""
Quantize the weights of the model from float32 to float16
@@ -174,10 +190,15 @@
# ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB: 2338583841
disable_shape_infer = model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF

blocked_ops = set(float16.DEFAULT_OP_BLOCK_LIST)
if op_block_list is not None:
blocked_ops.update(op_block_list)

model_fp16 = float16.convert_float_to_float16(
model,
keep_io_types=True,
disable_shape_infer=disable_shape_infer,
op_block_list=blocked_ops,
)
graph = gs.import_onnx(model_fp16)
graph.toposort()
@@ -271,6 +292,7 @@ def quantize(input_folder, output_folder, quantization_args: QuantizationArguments):
quantize_fp16(
model,
save_path,
quantization_args.op_block_list
)

elif mode in (QuantMode.Q4, QuantMode.Q4F16):
@@ -287,6 +309,7 @@ def quantize(input_folder, output_folder, quantization_args: QuantizationArguments):
quantize_fp16(
q4_model,
save_path,
quantization_args.op_block_list,
)

elif mode == QuantMode.BNB4:
@@ -331,6 +354,7 @@ def quantize(input_folder, output_folder, quantization_args: QuantizationArguments):
per_channel=quantization_args.per_channel,
reduce_range=quantization_args.reduce_range,
weight_type=weight_type,
op_block_list=quantization_args.op_block_list,
)
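
For completeness, one way the updated quantize_fp16 might be called directly. This is a hypothetical sketch: the import path scripts.quantize, the model path, and the chosen operator names are assumptions, while get_operators is the helper defined earlier in the file.

# Hypothetical usage sketch; paths and operator choices are placeholders, not from this commit.
import onnx
from scripts.quantize import get_operators, quantize_fp16  # assumed import path

model = onnx.load("model.onnx")

# Check which operators the model actually contains before deciding what to keep in float32.
present_ops = get_operators(model)
sensitive_ops = {"SimplifiedLayerNormalization", "Softmax"}  # illustrative choices

quantize_fp16(
    model,
    "model_fp16.onnx",
    op_block_list=sorted(sensitive_ops & present_ops) or None,
)
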


