From 791d4449c49f3b4cd73487176f45d43f951adfd6 Mon Sep 17 00:00:00 2001 From: Tracin <953719031@qq.com> Date: Fri, 22 Jul 2022 17:04:10 +0800 Subject: [PATCH] [Update] Conv3d && Fix bugs. (#146) * [Update] Conv3d && Fix bugs. * [REQ] Update torch version to 1.10.0. * [REQ] Update torchvision version to 0.11.1. * [Update] Update observer. * [Test] Update test. Co-authored-by: zhangqi3 Co-authored-by: fanyunqian --- ...on-package-conda.yml => lint-and-test.yml} | 7 +++- mqbench/custom_quantizer/model_quantizer.py | 4 -- .../custom_quantizer/onnx_qnn_quantizer.py | 7 ++-- .../custom_quantizer/tensorrt_quantizer.py | 3 -- mqbench/deploy/deploy_tengine.py | 12 ++++-- mqbench/fuser_method_mappings.py | 3 +- mqbench/fusion_method.py | 34 ++++++++++++++--- mqbench/observer.py | 37 +++++-------------- requirements.txt | 5 +-- test/model/test_model.py | 7 ++-- test/version.py | 4 +- 11 files changed, 64 insertions(+), 59 deletions(-) rename .github/workflows/{python-package-conda.yml => lint-and-test.yml} (84%) diff --git a/.github/workflows/python-package-conda.yml b/.github/workflows/lint-and-test.yml similarity index 84% rename from .github/workflows/python-package-conda.yml rename to .github/workflows/lint-and-test.yml index db18b1b..4b16456 100644 --- a/.github/workflows/python-package-conda.yml +++ b/.github/workflows/lint-and-test.yml @@ -1,9 +1,9 @@ -name: Lint and test. +name: Lint and test on: [push] jobs: - build-linux: + Lint-and-test: runs-on: ubuntu-latest strategy: max-parallel: 5 @@ -22,6 +22,9 @@ jobs: run: | conda install flake8 flake8 . + - name: Install onnxruntime and onnxsim + run: + pip install onnxruntime onnx-simplifier - name: Install Protobuf run: conda install protobuf=3.20.1 diff --git a/mqbench/custom_quantizer/model_quantizer.py b/mqbench/custom_quantizer/model_quantizer.py index 0c0434f..a9c7829 100644 --- a/mqbench/custom_quantizer/model_quantizer.py +++ b/mqbench/custom_quantizer/model_quantizer.py @@ -263,10 +263,6 @@ def _convert(self, module, mapping=None, inplace=False, scope=''): if not isinstance(mod, _FusedModule): self._convert(mod, mapping, True, new_scope) reassign[name] = swap_module(mod, mapping, {}) - if isinstance(mod, torch.nn.ConvTranspose2d): - if hasattr(reassign[name], "weight_fake_quant") and reassign[name].weight_fake_quant.ch_axis != -1: - reassign[name].weight_fake_quant.ch_axis = 1 - reassign[name].weight_fake_quant.activation_post_process.ch_axis = 1 for key, value in reassign.items(): module._modules[key] = value diff --git a/mqbench/custom_quantizer/onnx_qnn_quantizer.py b/mqbench/custom_quantizer/onnx_qnn_quantizer.py index 7c4ebf4..c4469e4 100644 --- a/mqbench/custom_quantizer/onnx_qnn_quantizer.py +++ b/mqbench/custom_quantizer/onnx_qnn_quantizer.py @@ -7,8 +7,8 @@ from torch.quantization.utils import get_combined_dict -import mqbench.nn as qnn -import mqbench.nn.intrinsic as qnni +import mqbench.nn as qnn +import mqbench.nn.intrinsic as qnni from mqbench.utils.registry import register_model_quantizer from mqbench.prepare_by_platform import BackendType from mqbench.custom_quantizer import ModelQuantizer @@ -57,7 +57,6 @@ def _qat_swap_modules(self, root: GraphModule, additional_qat_module_mapping: Di get_default_qat_module_mappings(), additional_qat_module_mapping) # There is no QLinearFC in ONNX for now. 
del(all_mappings[torch.nn.modules.linear.Linear]) - del(all_mappings[torch.nn.modules.linear._LinearWithBias]) del(all_mappings[torch.nn.intrinsic.modules.fused.LinearReLU]) del(all_mappings[qnni.modules.fused.LinearBn1d]) root = self._convert(root, all_mappings, inplace=True) @@ -97,4 +96,4 @@ def implicit_merge_patterns(self) -> list: # In reversed order! return [ (torch.nn.ReLU, operator.add) - ] \ No newline at end of file + ] diff --git a/mqbench/custom_quantizer/tensorrt_quantizer.py b/mqbench/custom_quantizer/tensorrt_quantizer.py index 1e3e0e1..3dcb5b7 100644 --- a/mqbench/custom_quantizer/tensorrt_quantizer.py +++ b/mqbench/custom_quantizer/tensorrt_quantizer.py @@ -128,9 +128,6 @@ def _find_act_quants(self, model: GraphModule) -> List: if node.op == "call_function" and node.target == operator.add and \ self._is_skiped_add(node, modules, input_node_list): continue - if node.op == "call_function" and node.target == operator.add: - import pdb - pdb.set_trace() for _node in input_node_list: if self._is_implicit_merge(modules, (node, _node)): logger.info("Implicit merge: {} + {}".format(_node.name, node.name)) diff --git a/mqbench/deploy/deploy_tengine.py b/mqbench/deploy/deploy_tengine.py index eeda13f..ab63ef2 100644 --- a/mqbench/deploy/deploy_tengine.py +++ b/mqbench/deploy/deploy_tengine.py @@ -1,10 +1,6 @@ import os from collections import OrderedDict -import onnx -from onnx import numpy_helper -from onnxsim import simplify - from ..utils.logger import logger from .deploy_linear import ( LinearQuantizer_process, @@ -20,6 +16,14 @@ get_constant_inputs ) +import onnx +from onnx import numpy_helper +try: + from onnxsim import simplify +except ModuleNotFoundError: + logger.warn('onnxsim not found, if you want to use deploy_tengine, please install it.') + + class Tengine_process(LinearQuantizer_process): diff --git a/mqbench/fuser_method_mappings.py b/mqbench/fuser_method_mappings.py index 6aef34c..eb1a56a 100644 --- a/mqbench/fuser_method_mappings.py +++ b/mqbench/fuser_method_mappings.py @@ -33,6 +33,7 @@ def __init__(self, quantizer: QuantizerCls, node: Node): self.conv_node = node self.conv = quantizer.modules[self.conv_node.target] + def fuse_linear_bn(linear, bn): r"""Given the linear and bn modules, fuses them and returns the fused module @@ -83,7 +84,6 @@ def fuse_deconv_bn_relu(deconv, bn, relu): return qnni.ConvTransposeReLU2d(fuse_deconv_bn_eval(deconv, bn), relu) - def fuse_conv_freezebn(conv, bn): assert(bn.training is False), "Freezebn must be eval." @@ -100,6 +100,7 @@ def fuse_conv_freezebn(conv, bn): else: return nn.utils.fuse_conv_bn_eval(conv, bn) + def fuse_conv_freezebn_relu(conv, bn, relu): assert(conv.training == relu.training and bn.training is False), "Conv and relu both must be in the same mode (train or eval) and bn must be eval." 
     fused_module : Optional[Type[nn.Sequential]] = None
diff --git a/mqbench/fusion_method.py b/mqbench/fusion_method.py
index 7c607de..2bbf693 100644
--- a/mqbench/fusion_method.py
+++ b/mqbench/fusion_method.py
@@ -42,20 +42,37 @@ def convert_qnniqat_linearbn(model, fused_node):
 
 @register_convert_function(qnniqat.ConvFreezebn2d)
 @register_convert_function(nniqat.ConvBn2d)
+@register_convert_function(nniqat.ConvBn3d)
 def convert_nniqat_convbn(model, fused_node):
+    """nniqat.ConvBn2d/3d ----> nn.Conv2d/3d ----> nn.qat.Conv2d/3d
+    """
+    fused_module_class_map = {
+        qnniqat.ConvFreezebn2d: torch.nn.Conv2d,
+        qnniqat.ConvFreezebnReLU2d: torch.nn.Conv2d,
+        nniqat.ConvBn2d: torch.nn.Conv2d,
+        nniqat.ConvBnReLU2d: torch.nn.Conv2d,
+        nniqat.ConvBn3d: torch.nn.Conv3d,
+        nniqat.ConvBnReLU3d: torch.nn.Conv3d,
+    }
+    fused_qat_module_class_map = {
+        torch.nn.Conv2d: torch.nn.qat.Conv2d,
+        torch.nn.Conv3d: torch.nn.qat.Conv3d,
+    }
     modules = dict(model.named_modules())
     fused_module = modules[fused_node.target]
     # Create a Conv2d from FusedModule.
-    conv = torch.nn.Conv2d(fused_module.in_channels, fused_module.out_channels, fused_module.kernel_size,
-                           fused_module.stride, fused_module.padding, fused_module.dilation,
-                           fused_module.groups, fused_module.bias is not None, fused_module.padding_mode)
+    conv = fused_module_class_map[type(fused_module)](fused_module.in_channels, fused_module.out_channels,
+                                                      fused_module.kernel_size, fused_module.stride,
+                                                      fused_module.padding, fused_module.dilation,
+                                                      fused_module.groups, fused_module.bias is not None,
+                                                      fused_module.padding_mode)
     conv.weight = fused_module.weight
     if fused_module.bias is not None:
         conv.bias = fused_module.bias
     fused_conv = fuse_conv_bn_eval(conv.eval(), fused_module.bn)
     # We need nn.qat.conv here to export weight quantize node.
     fused_conv.qconfig = fused_module.qconfig
-    fused_conv = torch.nn.qat.Conv2d.from_float(fused_conv)
+    fused_conv = fused_qat_module_class_map[type(conv)].from_float(fused_conv)
     # Attach weight fake quantize params.
     fused_conv.weight_fake_quant = fused_module.weight_fake_quant
     conv_parent_name, conv_name = _parent_name(fused_node.target)
@@ -64,7 +81,8 @@ def convert_nniqat_convbn(model, fused_node):
 
 @register_convert_function(qnniqat.ConvFreezebnReLU2d)
 @register_convert_function(nniqat.ConvBnReLU2d)
-def convert_nniqat_convbnrelu(model, fused_node): 
+@register_convert_function(nniqat.ConvBnReLU3d)
+def convert_nniqat_convbnrelu(model, fused_node):
     convert_nniqat_convbn(model, fused_node)
     modules = dict(model.named_modules())
     fused_module = modules[fused_node.target]
@@ -196,6 +214,9 @@ def convert_qnniqat_deconvbnrelu(model, fused_node):
 
 @register_convert_function(qnniqat.ConvBn2d)
 def convert_qnniqat_convbn(model, fused_node):
+    """mqbench.nn.intrinsic.qat modules also fake-quantize the bias;
+    that is what distinguishes them from the torch.nn.intrinsic.qat modules.
+    """
     modules = dict(model.named_modules())
     fused_module = modules[fused_node.target]
     # Create a Conv2d from FusedModule.
@@ -222,6 +243,9 @@ def convert_qnniqat_convbn(model, fused_node):
 
 @register_convert_function(qnniqat.ConvBnReLU2d)
 def convert_qnniqat_convbnrelu(model, fused_node):
+    """mqbench.nn.intrinsic.qat modules also fake-quantize the bias;
+    that is what distinguishes them from the torch.nn.intrinsic.qat modules.
+    """
     convert_qnniqat_convbn(model, fused_node)
     modules = dict(model.named_modules())
     fused_module = modules[fused_node.target]
diff --git a/mqbench/observer.py b/mqbench/observer.py
index eaf4ddf..dc1b2b0 100644
--- a/mqbench/observer.py
+++ b/mqbench/observer.py
@@ -1,7 +1,7 @@
 import math
 from functools import partial
 from typing import Tuple
-from copy import deepcopy
+
 import torch
 from torch.quantization.observer import _ObserverBase
 
@@ -28,12 +28,14 @@ class ObserverBase(_ObserverBase):
     def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
                  reduce_range=False, quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False,
                  factory_kwargs=None):
-        factory_kwargs = deepcopy(factory_kwargs)
-        self.not_calc_quant_min_max = factory_kwargs.pop('not_calc_quant_min_max', False) if isinstance(factory_kwargs, dict) else False
+        # Since torch 1.10, calculate_qmin_qmax is imported from utils rather than being an observer method,
+        # so clamp an over-wide custom range to int8 for the parent __init__ and restore the real range below.
+        stored_min, stored_max = quant_min, quant_max
+        if quant_max is not None and quant_min is not None and (quant_max - quant_min + 1 > 256):
+            quant_min, quant_max = -128, 127
         super(ObserverBase, self).__init__(dtype, qscheme, reduce_range, quant_min, quant_max)
-        # for compatibility with 1.10, prevent the value of self.quant_min,self.quant_max being modified
-        self.quant_min = quant_min
-        self.quant_max = quant_max
+        self.quant_min = stored_min
+        self.quant_max = stored_max
         self.quant_min, self.quant_max = self._calculate_qmin_qmax()
         self.ch_axis = ch_axis
         self.pot_scale = pot_scale
@@ -79,28 +81,7 @@ def _calculate_qmin_qmax(self) -> Tuple[int, int]:
         observer datatype and if range is reduced.
         """
         if self.has_customized_qrange:
-            # This initialization here is to be resolve TorchScript compilation issues and allow
-            # using of refinement to decouple initial_qmin and initial_qmax from quantization range.
-            # The actual values of initial_qmin and initial_qmax will be reset below.
-            initial_quant_min, initial_quant_max = 0, 255
-            # The following assignment of self.qmin and self.qmax to the local variables and the if check refine the
-            # attribute from Optional valid integers for use, based on TorchScript's requirements.
-            custom_quant_min, custom_quant_max = self.quant_min, self.quant_max
-            if custom_quant_min is not None and custom_quant_max is not None:
-                initial_quant_min, initial_quant_max = (
-                    custom_quant_min,
-                    custom_quant_max,
-                )
-
-            qrange_len = initial_quant_max - initial_quant_min + 1
-            if is_symmetric_quant(self.qscheme):
-                quant_min, quant_max = -qrange_len // 2, qrange_len // 2 - 1
-            else:
-                quant_min, quant_max = 0, qrange_len - 1
-            if self.reduce_range:
-                quant_min, quant_max = quant_min // 2, quant_max // 2
-            if self.not_calc_quant_min_max:
-                quant_min, quant_max = self.quant_min, self.quant_max
+            quant_min, quant_max = self.quant_min, self.quant_max
         else:
             # Fallback onto default 8-bit qmin and qmax calculation if dynamic range is not used.
if self.dtype == torch.qint8: diff --git a/requirements.txt b/requirements.txt index 7a2a16c..7d0d4b3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -torch==1.8.1 -torchvision==0.9.1 -onnx-simplifier +torch==1.10.0 +torchvision==0.11.1 onnx diff --git a/test/model/test_model.py b/test/model/test_model.py index 8c6f260..392be45 100644 --- a/test/model/test_model.py +++ b/test/model/test_model.py @@ -11,11 +11,12 @@ class TestQuantizeModel(unittest.TestCase): def test_model_ppl(self): - exclude_list = ['googlenet', 'deeplabv3_mobilenet_v3_large', 'inception_v3', 'lraspp_mobilenet_v3_large', - 'mobilenet_v3_large', 'mobilenet_v3_small'] + test_model_list = ['alexnet', 'deeplabv3_resnet50', 'densenet121', 'fcn_resnet50', 'mnasnet0_5', + 'mobilenet_v2', 'resnet18', 'resnext50_32x4d', 'shufflenet_v2_x0_5', 'squeezenet1_0', + 'vgg11', 'vgg11_bn', 'wide_resnet50_2', 'regnet_x_400mf'] entrypoints = torch.hub.list(GITHUB_RES, force_reload=False) for entrypoint in entrypoints: - if entrypoint in exclude_list: + if entrypoint not in test_model_list: continue logger.info(f'testing {entrypoint}') if 'deeplab' in entrypoint or 'fcn' in entrypoint: diff --git a/test/version.py b/test/version.py index 25254b6..052cff8 100644 --- a/test/version.py +++ b/test/version.py @@ -1,2 +1,2 @@ -TORCHVISION_VERSION = 'v0.9.0' -GITHUB_RES = 'pytorch/vision:{}'.format(TORCHVISION_VERSION) \ No newline at end of file +TORCHVISION_VERSION = 'v0.11.1' +GITHUB_RES = 'pytorch/vision:{}'.format(TORCHVISION_VERSION)
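--
For reference, a minimal sketch of what the new Conv3d branch of convert_nniqat_convbn does end to end, written
against stock torch 1.10 APIs. The channel sizes and the fbgemm qconfig are illustrative assumptions, not taken
from the patch; the steps mirror the code added above.

    import torch
    import torch.nn.intrinsic.qat as nniqat
    from torch.nn.utils.fusion import fuse_conv_bn_eval
    from torch.quantization import get_default_qat_qconfig

    qconfig = get_default_qat_qconfig('fbgemm')
    # A fused module of the kind QAT prepare produces for Conv3d + BatchNorm3d.
    fused = nniqat.ConvBn3d(3, 8, 3, qconfig=qconfig).eval()

    # Rebuild a plain Conv3d carrying the trained weight, then fold the BN statistics into it.
    conv = torch.nn.Conv3d(3, 8, 3, bias=fused.bias is not None)
    conv.weight = fused.weight
    if fused.bias is not None:
        conv.bias = fused.bias
    deploy_conv = fuse_conv_bn_eval(conv.eval(), fused.bn)

    # Wrap it as nn.qat.Conv3d so the exported graph keeps a weight fake-quantize node,
    # and reuse the fake-quantize parameters learned during QAT.
    deploy_conv.qconfig = fused.qconfig
    deploy_conv = torch.nn.qat.Conv3d.from_float(deploy_conv)
    deploy_conv.weight_fake_quant = fused.weight_fake_quant
    print(deploy_conv(torch.randn(1, 3, 8, 8, 8)).shape)  # torch.Size([1, 8, 6, 6, 6])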