[Update] Conv3d && Fix bugs. (#146)
* [Update] Conv3d && Fix bugs.

* [REQ] Update torch version to 1.10.0.

* [REQ] Update torchvision version to 0.11.1.

* [Update] Update observer.

* [Test] Update test.

Co-authored-by: zhangqi3 <[email protected]>
Co-authored-by: fanyunqian <[email protected]>
3 people authored Jul 22, 2022
1 parent 55c304b commit 791d444
Showing 11 changed files with 64 additions and 59 deletions.
@@ -1,9 +1,9 @@
name: Lint and test.
name: Lint and test

on: [push]

jobs:
build-linux:
Lint-and-test:
runs-on: ubuntu-latest
strategy:
max-parallel: 5
@@ -22,6 +22,9 @@ jobs:
run: |
conda install flake8
flake8 .
- name: Install onnxruntime and onnxsim
run:
pip install onnxruntime onnx-simplifier
- name: Install Protobuf
run:
conda install protobuf=3.20.1
4 changes: 0 additions & 4 deletions mqbench/custom_quantizer/model_quantizer.py
@@ -263,10 +263,6 @@ def _convert(self, module, mapping=None, inplace=False, scope=''):
if not isinstance(mod, _FusedModule):
self._convert(mod, mapping, True, new_scope)
reassign[name] = swap_module(mod, mapping, {})
if isinstance(mod, torch.nn.ConvTranspose2d):
if hasattr(reassign[name], "weight_fake_quant") and reassign[name].weight_fake_quant.ch_axis != -1:
reassign[name].weight_fake_quant.ch_axis = 1
reassign[name].weight_fake_quant.activation_post_process.ch_axis = 1
for key, value in reassign.items():
module._modules[key] = value

7 changes: 3 additions & 4 deletions mqbench/custom_quantizer/onnx_qnn_quantizer.py
@@ -7,8 +7,8 @@
from torch.quantization.utils import get_combined_dict


import mqbench.nn as qnn
import mqbench.nn.intrinsic as qnni
import mqbench.nn as qnn
import mqbench.nn.intrinsic as qnni
from mqbench.utils.registry import register_model_quantizer
from mqbench.prepare_by_platform import BackendType
from mqbench.custom_quantizer import ModelQuantizer
@@ -57,7 +57,6 @@ def _qat_swap_modules(self, root: GraphModule, additional_qat_module_mapping: Di
get_default_qat_module_mappings(), additional_qat_module_mapping)
# There is no QLinearFC in ONNX for now.
del(all_mappings[torch.nn.modules.linear.Linear])
del(all_mappings[torch.nn.modules.linear._LinearWithBias])
del(all_mappings[torch.nn.intrinsic.modules.fused.LinearReLU])
del(all_mappings[qnni.modules.fused.LinearBn1d])
root = self._convert(root, all_mappings, inplace=True)
@@ -97,4 +96,4 @@ def implicit_merge_patterns(self) -> list:
# In reversed order!
return [
(torch.nn.ReLU, operator.add)
]
]
3 changes: 0 additions & 3 deletions mqbench/custom_quantizer/tensorrt_quantizer.py
@@ -128,9 +128,6 @@ def _find_act_quants(self, model: GraphModule) -> List:
if node.op == "call_function" and node.target == operator.add and \
self._is_skiped_add(node, modules, input_node_list):
continue
if node.op == "call_function" and node.target == operator.add:
import pdb
pdb.set_trace()
for _node in input_node_list:
if self._is_implicit_merge(modules, (node, _node)):
logger.info("Implicit merge: {} + {}".format(_node.name, node.name))
12 changes: 8 additions & 4 deletions mqbench/deploy/deploy_tengine.py
@@ -1,10 +1,6 @@
import os
from collections import OrderedDict

import onnx
from onnx import numpy_helper
from onnxsim import simplify

from ..utils.logger import logger
from .deploy_linear import (
LinearQuantizer_process,
@@ -20,6 +16,14 @@
get_constant_inputs
)

import onnx
from onnx import numpy_helper
try:
from onnxsim import simplify
except ModuleNotFoundError:
logger.warn('onnxsim not found, if you want to use deploy_tengine, please install it.')



class Tengine_process(LinearQuantizer_process):

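The guarded import above only keeps module import from failing when onnx-simplifier is missing; any code path that actually calls simplify still needs the package. A minimal sketch of the same pattern (not MQBench code; maybe_simplify is a hypothetical helper):

    # Import-guard sketch: module import succeeds without onnx-simplifier installed.
    try:
        from onnxsim import simplify
    except ModuleNotFoundError:
        simplify = None  # assumption: callers check for None before simplifying

    def maybe_simplify(onnx_model):
        # Fall back to the unsimplified graph when onnx-simplifier is unavailable.
        if simplify is None:
            return onnx_model
        simplified, ok = simplify(onnx_model)
        return simplified if ok else onnx_model
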
3 changes: 2 additions & 1 deletion mqbench/fuser_method_mappings.py
@@ -33,6 +33,7 @@ def __init__(self, quantizer: QuantizerCls, node: Node):
self.conv_node = node
self.conv = quantizer.modules[self.conv_node.target]


def fuse_linear_bn(linear, bn):
r"""Given the linear and bn modules, fuses them and returns the fused module
@@ -83,7 +84,6 @@ def fuse_deconv_bn_relu(deconv, bn, relu):
return qnni.ConvTransposeReLU2d(fuse_deconv_bn_eval(deconv, bn), relu)



def fuse_conv_freezebn(conv, bn):
assert(bn.training is False), "Freezebn must be eval."

@@ -100,6 +100,7 @@ def fuse_conv_freezebn(conv, bn):
else:
return nn.utils.fuse_conv_bn_eval(conv, bn)


def fuse_conv_freezebn_relu(conv, bn, relu):
assert(conv.training == relu.training and bn.training is False), "Conv and relu both must be in the same mode (train or eval) and bn must be eval."
fused_module : Optional[Type[nn.Sequential]] = None
34 changes: 29 additions & 5 deletions mqbench/fusion_method.py
@@ -42,20 +42,37 @@ def convert_qnniqat_linearbn(model, fused_node):

@register_convert_function(qnniqat.ConvFreezebn2d)
@register_convert_function(nniqat.ConvBn2d)
@register_convert_function(nniqat.ConvBn3d)
def convert_nniqat_convbn(model, fused_node):
"""nniqat.ConvBn2d ----> nn.Conv2d ----> nniqat.Conv2d
"""
fused_module_class_map = {
qnniqat.ConvFreezebn2d: torch.nn.Conv2d,
qnniqat.ConvFreezebnReLU2d: torch.nn.Conv2d,
nniqat.ConvBn2d: torch.nn.Conv2d,
nniqat.ConvBnReLU2d: torch.nn.Conv2d,
nniqat.ConvBn3d: torch.nn.Conv3d,
nniqat.ConvBnReLU3d: torch.nn.Conv3d,
}
fused_qat_module_class_map = {
torch.nn.Conv2d: torch.nn.qat.Conv2d,
torch.nn.Conv3d: torch.nn.qat.Conv3d,
}
modules = dict(model.named_modules())
fused_module = modules[fused_node.target]
# Create a Conv2d from FusedModule.
conv = torch.nn.Conv2d(fused_module.in_channels, fused_module.out_channels, fused_module.kernel_size,
fused_module.stride, fused_module.padding, fused_module.dilation,
fused_module.groups, fused_module.bias is not None, fused_module.padding_mode)
conv = fused_module_class_map[type(fused_module)](fused_module.in_channels, fused_module.out_channels,
fused_module.kernel_size, fused_module.stride,
fused_module.padding, fused_module.dilation,
fused_module.groups, fused_module.bias is not None,
fused_module.padding_mode)
conv.weight = fused_module.weight
if fused_module.bias is not None:
conv.bias = fused_module.bias
fused_conv = fuse_conv_bn_eval(conv.eval(), fused_module.bn)
# We need nn.qat.conv here to export weight quantize node.
fused_conv.qconfig = fused_module.qconfig
fused_conv = torch.nn.qat.Conv2d.from_float(fused_conv)
fused_conv = fused_qat_module_class_map[type(conv)].from_float(fused_conv)
# Attach weight fake quantize params.
fused_conv.weight_fake_quant = fused_module.weight_fake_quant
conv_parent_name, conv_name = _parent_name(fused_node.target)
@@ -64,7 +81,8 @@ def convert_nniqat_convbn(model, fused_node):

@register_convert_function(qnniqat.ConvFreezebnReLU2d)
@register_convert_function(nniqat.ConvBnReLU2d)
def convert_nniqat_convbnrelu(model, fused_node):
@register_convert_function(nniqat.ConvBnReLU3d)
def convert_nniqat_convbnrelu(model, fused_node):
convert_nniqat_convbn(model, fused_node)
modules = dict(model.named_modules())
fused_module = modules[fused_node.target]
@@ -196,6 +214,9 @@ def convert_qnniqat_deconvbnrelu(model, fused_node):

@register_convert_function(qnniqat.ConvBn2d)
def convert_qnniqat_convbn(model, fused_node):
"""mqbench.nn.intrinsic.qat module add bias quant.
That is the difference between torch.nn.intrinsic.qat module.
"""
modules = dict(model.named_modules())
fused_module = modules[fused_node.target]
# Create a Conv2d from FusedModule.
@@ -222,6 +243,9 @@ def convert_qnniqat_convbn(model, fused_node):

@register_convert_function(qnniqat.ConvBnReLU2d)
def convert_qnniqat_convbnrelu(model, fused_node):
"""mqbench.nn.intrinsic.qat module add bias quant.
That is the difference between torch.nn.intrinsic.qat module.
"""
convert_qnniqat_convbn(model, fused_node)
modules = dict(model.named_modules())
fused_module = modules[fused_node.target]
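
With ConvBn3d and ConvBnReLU3d now routed through the same convert functions as their 2-D counterparts, a 3-D conv-bn-relu stack can in principle be prepared end to end. A hedged sketch (the backend choice is illustrative and 3-D fusion coverage is assumed, not verified here):

    import torch
    import torch.nn as nn
    from mqbench.prepare_by_platform import prepare_by_platform, BackendType

    # Tiny 3-D conv-bn-relu stack; ConvBn3d / ConvBnReLU3d handling is what this commit adds.
    model = nn.Sequential(
        nn.Conv3d(3, 8, kernel_size=3, padding=1),
        nn.BatchNorm3d(8),
        nn.ReLU(),
    ).train()

    qat_model = prepare_by_platform(model, BackendType.Tensorrt)
    qat_model(torch.randn(1, 3, 4, 16, 16))  # smoke test with a small 3-D input
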
37 changes: 9 additions & 28 deletions mqbench/observer.py
@@ -1,7 +1,7 @@
import math
from functools import partial
from typing import Tuple
from copy import deepcopy

import torch
from torch.quantization.observer import _ObserverBase

@@ -28,12 +28,14 @@ class ObserverBase(_ObserverBase):
def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
reduce_range=False, quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False,
factory_kwargs=None):
factory_kwargs = deepcopy(factory_kwargs)
self.not_calc_quant_min_max = factory_kwargs.pop('not_calc_quant_min_max', False) if isinstance(factory_kwargs, dict) else False
# Since torch 1.10, function calculate_qmin_qmax is not a member function of observer,
# but import from utils. It is hard to control. We use try...except here.
stored_min, sotred_max = quant_min, quant_max
if quant_max is not None and quant_min is not None and (quant_max - quant_min + 1 > 256):
quant_min, quant_max = -128, 127
super(ObserverBase, self).__init__(dtype, qscheme, reduce_range, quant_min, quant_max)
# for compatibility with 1.10, prevent the value of self.quant_min,self.quant_max being modified
self.quant_min = quant_min
self.quant_max = quant_max
self.quant_min = stored_min
self.quant_max = sotred_max
self.quant_min, self.quant_max = self._calculate_qmin_qmax()
self.ch_axis = ch_axis
self.pot_scale = pot_scale
@@ -79,28 +81,7 @@ def _calculate_qmin_qmax(self) -> Tuple[int, int]:
observer datatype and if range is reduced.
"""
if self.has_customized_qrange:
# This initialization here is to be resolve TorchScript compilation issues and allow
# using of refinement to decouple initial_qmin and initial_qmax from quantization range.
# The actual values of initial_qmin and initial_qmax will be reset below.
initial_quant_min, initial_quant_max = 0, 255
# The following assignment of self.qmin and self.qmax to the local variables and the if check refine the
# attribute from Optional valid integers for use, based on TorchScript's requirements.
custom_quant_min, custom_quant_max = self.quant_min, self.quant_max
if custom_quant_min is not None and custom_quant_max is not None:
initial_quant_min, initial_quant_max = (
custom_quant_min,
custom_quant_max,
)

qrange_len = initial_quant_max - initial_quant_min + 1
if is_symmetric_quant(self.qscheme):
quant_min, quant_max = -qrange_len // 2, qrange_len // 2 - 1
else:
quant_min, quant_max = 0, qrange_len - 1
if self.reduce_range:
quant_min, quant_max = quant_min // 2, quant_max // 2
if self.not_calc_quant_min_max:
quant_min, quant_max = self.quant_min, self.quant_max
quant_min, quant_max = self.quant_min, self.quant_max
else:
# Fallback onto default 8-bit qmin and qmax calculation if dynamic range is not used.
if self.dtype == torch.qint8:
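
The __init__ change above stores the caller's quant_min/quant_max, clamps a wider-than-8-bit range to (-128, 127) before calling the torch 1.10 base class, and then restores the stored values so the base class cannot overwrite them. A minimal standalone sketch of that clamp-then-restore pattern, with a stand-in for the base-class range handling (not MQBench code):

    class _Base:
        def __init__(self, quant_min, quant_max):
            # Stand-in for the stricter range handling in torch 1.10's _ObserverBase.
            assert quant_max - quant_min + 1 <= 256, "custom range too wide for the base class"
            self.quant_min, self.quant_max = quant_min, quant_max

    class WideRangeObserver(_Base):
        def __init__(self, quant_min, quant_max):
            stored_min, stored_max = quant_min, quant_max
            if quant_max - quant_min + 1 > 256:
                quant_min, quant_max = -128, 127  # satisfy the base-class check
            super().__init__(quant_min, quant_max)
            # Restore the range the caller actually asked for.
            self.quant_min, self.quant_max = stored_min, stored_max
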
5 changes: 2 additions & 3 deletions requirements.txt
@@ -1,4 +1,3 @@
torch==1.8.1
torchvision==0.9.1
onnx-simplifier
torch==1.10.0
torchvision==0.11.1
onnx
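
To reproduce the environment these pins and the workflow above imply, something like pip install torch==1.10.0 torchvision==0.11.1 onnx onnxruntime onnx-simplifier (plus conda install protobuf=3.20.1) should do; onnxruntime, onnx-simplifier, and protobuf come from the CI steps rather than requirements.txt, so treat the exact combination as an assumption.
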
7 changes: 4 additions & 3 deletions test/model/test_model.py
@@ -11,11 +11,12 @@

class TestQuantizeModel(unittest.TestCase):
def test_model_ppl(self):
exclude_list = ['googlenet', 'deeplabv3_mobilenet_v3_large', 'inception_v3', 'lraspp_mobilenet_v3_large',
'mobilenet_v3_large', 'mobilenet_v3_small']
test_model_list = ['alexnet', 'deeplabv3_resnet50', 'densenet121', 'fcn_resnet50', 'mnasnet0_5',
'mobilenet_v2', 'resnet18', 'resnext50_32x4d', 'shufflenet_v2_x0_5', 'squeezenet1_0',
'vgg11', 'vgg11_bn', 'wide_resnet50_2', 'regnet_x_400mf']
entrypoints = torch.hub.list(GITHUB_RES, force_reload=False)
for entrypoint in entrypoints:
if entrypoint in exclude_list:
if entrypoint not in test_model_list:
continue
logger.info(f'testing {entrypoint}')
if 'deeplab' in entrypoint or 'fcn' in entrypoint:
4 changes: 2 additions & 2 deletions test/version.py
@@ -1,2 +1,2 @@
TORCHVISION_VERSION = 'v0.9.0'
GITHUB_RES = 'pytorch/vision:{}'.format(TORCHVISION_VERSION)
TORCHVISION_VERSION = 'v0.11.1'
GITHUB_RES = 'pytorch/vision:{}'.format(TORCHVISION_VERSION)
