From 5272b124ad85120e7447a293ecae03c1cf154a74 Mon Sep 17 00:00:00 2001
From: Paul Dufour
Date: Mon, 25 Nov 2024 21:48:55 +0000
Subject: [PATCH] Add support for `--op_block_list` in quantization script
 (#1036)

* Add support for op_block_list

* Remove arg

* Set default to none

* Minor code suggestions

* whoops - actually apply suggestions

---------

Co-authored-by: Joshua Lochner
---
 .gitignore          |  1 +
 scripts/quantize.py | 28 ++++++++++++++++++++++++++--
 2 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/.gitignore b/.gitignore
index f3e6fafbd..cb9d10c0d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,6 +2,7 @@ __pycache__
 .vscode
 node_modules
 .cache
+.DS_STORE
 
 # Do not track build artifacts/generated files
 /dist
diff --git a/scripts/quantize.py b/scripts/quantize.py
index 1ace2d353..554a064d1 100644
--- a/scripts/quantize.py
+++ b/scripts/quantize.py
@@ -1,7 +1,7 @@
 from enum import Enum
 
 from tqdm import tqdm
-from typing import Set
+from typing import Set, List, Optional
 import onnx
 import os
 
@@ -110,6 +110,16 @@ class QuantizationArguments:
         },
     )
 
+    op_block_list: List[str] = field(
+        default=None,
+        metadata={
+            "help": "List of operators to exclude from quantization."
+            "Can be any standard ONNX operator (see https://onnx.ai/onnx/operators/)"
+            "or your custom implemented operators.",
+            "nargs": "+",
+        },
+    )
+
 
 def get_operators(model: onnx.ModelProto) -> Set[str]:
     operators = set()
@@ -131,6 +141,7 @@ def quantize_q8(
     per_channel: bool,
     reduce_range: bool,
     weight_type: QuantType,
+    op_block_list: Optional[List[str]]
 ):
     """
     Quantize the weights of the model from float32 to int8/uint8
@@ -140,6 +151,10 @@ def quantize_q8(
     it is faster on most CPU architectures
     """
 
+    op_types_to_quantize = set(IntegerOpsRegistry.keys())
+    if op_block_list is not None:
+        op_types_to_quantize.difference_update(op_block_list)
+
     quantizer = ONNXQuantizer(
         model,
         per_channel,
@@ -151,7 +166,7 @@ def quantize_q8(
         tensors_range=None,
         nodes_to_quantize=[],
         nodes_to_exclude=[],
-        op_types_to_quantize=list(IntegerOpsRegistry.keys()),
+        op_types_to_quantize=op_types_to_quantize,
         extra_options=dict(
             EnableSubgraph=True,
             MatMulConstBOnly=True,
@@ -165,6 +180,7 @@ def quantize_q8(
 def quantize_fp16(
     model: onnx.ModelProto,
     save_path: str,
+    op_block_list: Optional[List[str]]
 ):
     """
     Quantize the weights of the model from float32 to float16
@@ -174,10 +190,15 @@ def quantize_fp16(
     # ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB: 2338583841
     disable_shape_infer = model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF
 
+    blocked_ops = set(float16.DEFAULT_OP_BLOCK_LIST)
+    if op_block_list is not None:
+        blocked_ops.update(op_block_list)
+
     model_fp16 = float16.convert_float_to_float16(
         model,
         keep_io_types=True,
         disable_shape_infer=disable_shape_infer,
+        op_block_list=blocked_ops,
     )
     graph = gs.import_onnx(model_fp16)
     graph.toposort()
@@ -271,6 +292,7 @@ def quantize(input_folder, output_folder, quantization_args: QuantizationArgumen
             quantize_fp16(
                 model,
                 save_path,
+                quantization_args.op_block_list
             )
 
         elif mode in (QuantMode.Q4, QuantMode.Q4F16):
@@ -287,6 +309,7 @@ def quantize(input_folder, output_folder, quantization_args: QuantizationArgumen
             quantize_fp16(
                 q4_model,
                 save_path,
+                quantization_args.op_block_list,
             )
 
         elif mode == QuantMode.BNB4:
@@ -331,6 +354,7 @@ def quantize(input_folder, output_folder, quantization_args: QuantizationArgumen
                 per_channel=quantization_args.per_channel,
                 reduce_range=quantization_args.reduce_range,
                 weight_type=weight_type,
+                op_block_list=quantization_args.op_block_list,
             )
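
Usage note (illustrative sketch, not part of the patch): operators named via `--op_block_list` are excluded from quantization. On the fp16 path they are merged into `float16.DEFAULT_OP_BLOCK_LIST` so the matching nodes stay in float32; on the 8-bit path they are removed from `IntegerOpsRegistry` so those node types are left unquantized. The snippet below assumes the `float16` helper is `onnxconverter_common.float16`; the model paths and operator names are placeholders.

    # Illustrative only: the effect of --op_block_list on the fp16 conversion path.
    # "model.onnx", "model_fp16.onnx" and the operator names are placeholders.
    import onnx
    from onnxconverter_common import float16  # assumed source of DEFAULT_OP_BLOCK_LIST

    model = onnx.load("model.onnx")

    op_block_list = ["LayerNormalization", "Gelu"]    # e.g. --op_block_list LayerNormalization Gelu
    blocked_ops = set(float16.DEFAULT_OP_BLOCK_LIST)  # ops the converter already skips
    blocked_ops.update(op_block_list)                 # plus the user-specified ops

    # Same call shape as the patched quantize_fp16: blocked ops keep float32 weights.
    model_fp16 = float16.convert_float_to_float16(
        model,
        keep_io_types=True,
        op_block_list=blocked_ops,
    )
    onnx.save(model_fp16, "model_fp16.onnx")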