
Commit 81abc2b
Revert "[quant][pt2e][bc-breaking] Remove fold_quantize flag (pytorch…
Browse files Browse the repository at this point in the history
…#118701)"

This reverts commit 482d952.

Reverted pytorch#118701 on behalf of https://github.com/facebook-github-bot because the diff was reverted internally ([comment](pytorch#118701 (comment)))
pytorchmergebot committed Feb 7, 2024
1 parent a6e16fe commit 81abc2b
Showing 11 changed files with 35 additions and 25 deletions.
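For readers skimming the summary: the practical effect of this revert is that `convert_pt2e` regains its `fold_quantize` keyword argument (defaulting to `False`), and every touched test now passes it explicitly. A condensed sketch of the call-site change that repeats through the hunks below:

```python
# Call sites before this revert (pytorch#118701 had removed the flag):
m = convert_pt2e(m)

# Call sites restored by this revert (flag re-added; tests opt in explicitly):
m = convert_pt2e(m, fold_quantize=True)
```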
2 changes: 1 addition & 1 deletion test/inductor/test_mkldnn_pattern_matcher.py
@@ -106,7 +106,7 @@ def _generate_qdq_quantized_model(self, mod, inputs, is_qat=False):
else prepare_pt2e(export_model, quantizer)
)
prepare_model(*inputs)
- convert_model = convert_pt2e(prepare_model)
+ convert_model = convert_pt2e(prepare_model, fold_quantize=True)
torch.ao.quantization.move_exported_model_to_eval(convert_model)
return convert_model

2 changes: 1 addition & 1 deletion test/quantization/pt2e/test_duplicate_dq.py
@@ -110,7 +110,7 @@ def _test_duplicate_dq(
m = prepare_pt2e(m, quantizer)
# Calibrate
m(*example_inputs)
- m = convert_pt2e(m)
+ m = convert_pt2e(m, fold_quantize=True)

pt2_quant_output = m(*example_inputs)
for n in m.graph.nodes:
@@ -94,6 +94,6 @@ def test_quantize_pt2e_preserve_handle(self):
debug_handle_map = _extract_conv2d_pattern_debug_handle_map(m)
self.assertEqual(debug_handle_map, debug_handle_map_ref)
m(*example_inputs)
- m = convert_pt2e(m)
+ m = convert_pt2e(m, fold_quantize=True)
debug_handle_map = _extract_conv2d_pattern_debug_handle_map(m)
self.assertEqual(debug_handle_map, debug_handle_map_ref)
2 changes: 1 addition & 1 deletion test/quantization/pt2e/test_metadata_porting.py
@@ -110,7 +110,7 @@ def _test_metadata_porting(
m = prepare_pt2e(m, quantizer)
# Calibrate
m(*example_inputs)
- m = convert_pt2e(m)
+ m = convert_pt2e(m, fold_quantize=True)

pt2_quant_output = m(*example_inputs)
recorded_node_tags = {}
12 changes: 6 additions & 6 deletions test/quantization/pt2e/test_quantize_pt2e.py
@@ -676,7 +676,7 @@ def validate(self, model: torch.fx.GraphModule) -> None:
assert conv_output_obs[0] == conv_output_obs[1]

m(*example_inputs)
- m = convert_pt2e(m)
+ m = convert_pt2e(m, fold_quantize=True)

node_occurrence = {
# two for input of the first conv, one for output for the first conv
@@ -739,7 +739,7 @@ def _test_transitive_sharing_with_cat_helper(self, quantizer):
assert conv_output_obs[0] == conv_output_obs[1]

m(*example_inputs)
- m = convert_pt2e(m)
+ m = convert_pt2e(m, fold_quantize=True)

node_occurrence = {
# two for input of the first conv, one for output for the first conv
@@ -1202,7 +1202,7 @@ def forward(self, x):

m = prepare_pt2e(m, quantizer)
m(*example_inputs)
- m = convert_pt2e(m)
+ m = convert_pt2e(m, fold_quantize=True)

for n in m.graph.nodes:
if n.op == "get_attr" and "frozen_param" in n.target:
@@ -1619,7 +1619,7 @@ def test_disallow_eval_train(self):
m.train()

# After convert: still not OK
- m = convert_pt2e(m)
+ m = convert_pt2e(m, fold_quantize=True)
with self.assertRaises(NotImplementedError):
m.eval()
with self.assertRaises(NotImplementedError):
@@ -1706,12 +1706,12 @@ def test_reentrant(self):
m.conv_bn_relu = capture_pre_autograd_graph(m.conv_bn_relu, example_inputs)
m.conv_bn_relu = prepare_qat_pt2e(m.conv_bn_relu, quantizer)
m(*example_inputs)
- m.conv_bn_relu = convert_pt2e(m.conv_bn_relu)
+ m.conv_bn_relu = convert_pt2e(m.conv_bn_relu, fold_quantize=True)

quantizer = XNNPACKQuantizer().set_module_type(torch.nn.Linear, get_symmetric_quantization_config(is_per_channel=False))
m = capture_pre_autograd_graph(m, example_inputs)
m = prepare_pt2e(m, quantizer)
- m = convert_pt2e(m)
+ m = convert_pt2e(m, fold_quantize=True)

node_occurrence = {
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 4,
12 changes: 6 additions & 6 deletions test/quantization/pt2e/test_quantize_pt2e_qat.py
@@ -161,7 +161,7 @@ def _verify_symmetric_xnnpack_qat_numerics_helper(
if verify_convert:
# We don't want to impose any ordering requirements between move_exported_model_to_eval and convert_pt2e
torch.ao.quantization.move_exported_model_to_eval(model_pt2e)
- model_pt2e = convert_pt2e(model_pt2e)
+ model_pt2e = convert_pt2e(model_pt2e, fold_quantize=True)
quant_result_pt2e = model_pt2e(*example_inputs)
model_fx.eval()
model_fx = _convert_to_reference_decomposed_fx(
@@ -631,7 +631,7 @@ def forward(self, x):
m = capture_pre_autograd_graph(m, example_inputs)
m = prepare_qat_pt2e(m, quantizer)
m(*example_inputs)
- m = convert_pt2e(m)
+ m = convert_pt2e(m, fold_quantize=True)

# Extract the conv and relu nodes (bn was folded into conv)
first_conv, first_relu, second_conv, second_relu = None, None, None, None
@@ -690,7 +690,7 @@ def test_qat_conv_bn_bias_derived_qspec(self):
quantizer = ConvBnDerivedBiasQuantizer()
m = prepare_qat_pt2e(m, quantizer)
m(*example_inputs)
- m = convert_pt2e(m)
+ m = convert_pt2e(m, fold_quantize=True)
m(*example_inputs)

# Assert that both weight and bias are quantized
@@ -737,7 +737,7 @@ def test_qat_per_channel_weight_custom_dtype(self):
quantizer = ConvBnInt32WeightQuantizer()
m = prepare_qat_pt2e(m, quantizer)
m(*example_inputs)
- m = convert_pt2e(m)
+ m = convert_pt2e(m, fold_quantize=True)
m(*example_inputs)

# Assert that conv weight is quantized per channel
@@ -972,7 +972,7 @@ def _convert_qat_linears(self, model):
for name, child in model.named_children():
if isinstance(child, torch.fx.GraphModule):
torch.ao.quantization.move_exported_model_to_eval(child)
- converted_child = convert_pt2e(child)
+ converted_child = convert_pt2e(child, fold_quantize=True)
setattr(model, name, converted_child)
else:
self._convert_qat_linears(child)
@@ -999,7 +999,7 @@ def test_mixing_qat_ptq(self):
quantizer.set_global(quantization_config)
model_pt2e = prepare_pt2e(model_pt2e, quantizer)
after_prepare_result_pt2e = model_pt2e(*example_inputs)
- model_pt2e = convert_pt2e(model_pt2e)
+ model_pt2e = convert_pt2e(model_pt2e, fold_quantize=True)
quant_result_pt2e = model_pt2e(*example_inputs)

exported_model = torch.export.export(model_pt2e, example_inputs)
8 changes: 6 additions & 2 deletions test/quantization/pt2e/test_representation.py
@@ -42,7 +42,9 @@ def _test_representation(
model = prepare_pt2e(model, quantizer)
# Calibrate
model(*example_inputs)
- model = convert_pt2e(model, use_reference_representation=True)
+ model = convert_pt2e(
+     model, use_reference_representation=True, fold_quantize=True
+ )
self.checkGraphModuleNodes(model, expected_node_occurrence=ref_node_occurrence)
# make sure it runs
pt2e_quant_output = model(*example_inputs)
@@ -52,7 +54,9 @@ def _test_representation(
model_copy = prepare_pt2e(model_copy, quantizer)
# Calibrate
model_copy(*example_inputs)
- model_copy = convert_pt2e(model_copy, use_reference_representation=False)
+ model_copy = convert_pt2e(
+     model_copy, use_reference_representation=False, fold_quantize=True
+ )
self.checkGraphModuleNodes(
model_copy, expected_node_occurrence=non_ref_node_occurrence
)
2 changes: 1 addition & 1 deletion test/quantization/pt2e/test_x86inductor_quantizer.py
@@ -326,7 +326,7 @@ def _test_quantizer(
# Calibrate
m(*example_inputs)
prepare_model = copy.deepcopy(m)
- m = convert_pt2e(m)
+ m = convert_pt2e(m, fold_quantize=True)
convert_model = copy.deepcopy(m)
pt2_quant_output = m(*example_inputs)
node_occurrence = {
8 changes: 4 additions & 4 deletions test/quantization/pt2e/test_xnnpack_quantizer.py
@@ -472,7 +472,7 @@ def test_propagate_annotation(self):
output_act = getattr(m, next(iter(n.users)).target)
self.assertIs(input_act, output_act)

- m = convert_pt2e(m)
+ m = convert_pt2e(m, fold_quantize=True)
node_occurrence = {
# input and output are using quantize_per_tensor and weight is using quantize_per_channel
ns.call_function(
@@ -723,7 +723,7 @@ def forward(self, input_tensor, hidden_tensor):
quantizer.set_global(quantization_config)
model_graph = prepare_pt2e(model_graph, quantizer)
model_graph(*example_inputs)
- model_graph = convert_pt2e(model_graph)
+ model_graph = convert_pt2e(model_graph, fold_quantize=True)
self.assertEqual(model_fx(*example_inputs), model_graph(*example_inputs))

def test_linear_gru(self):
@@ -787,7 +787,7 @@ def forward(self, input_tensor, hidden_tensor):
quantizer.set_global(quantization_config)
model_graph = prepare_pt2e(model_graph, quantizer)
model_graph(*example_inputs)
- model_graph = convert_pt2e(model_graph)
+ model_graph = convert_pt2e(model_graph, fold_quantize=True)
self.assertEqual(model_fx(*example_inputs), model_graph(*example_inputs))

def test_add_and_inplace_add(self):
@@ -968,7 +968,7 @@ def test_resnet18(self):
id(m.activation_post_process_3), id(m.activation_post_process_2)
)
after_prepare_result = m(*example_inputs)
- m = convert_pt2e(m)
+ m = convert_pt2e(m, fold_quantize=True)

after_quant_result = m(*example_inputs)

6 changes: 6 additions & 0 deletions torch/ao/quantization/quantize_pt2e.py
@@ -201,12 +201,18 @@ def _quant_node_constraint(n: Node) -> bool:
def convert_pt2e(
model: GraphModule,
use_reference_representation: bool = False,
+ fold_quantize: bool = False,
) -> GraphModule:
"""Convert a calibrated/trained model to a quantized model
Args:
* `model` (torch.fx.GraphModule): calibrated/trained model
* `use_reference_representation` (bool): boolean flag to indicate whether to produce reference representation or not
+ * `fold_quantize` (bool): boolean flag to indicate whether to fold the quantize op or not
+ Note: please set `fold_quantize` to True whenever you can; we'll deprecate this flag and
+ make True the default option in the future. To make sure the change doesn't break BC for you,
+ it's better to set the flag to True now.
Returns:
quantized model, either in q/dq representation or reference representation
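For orientation, here is a minimal end-to-end PT2E quantization flow in the style of the tests touched above, passing `fold_quantize=True` explicitly as the restored docstring recommends. This is a sketch, not code from the commit: `prepare_pt2e`/`convert_pt2e` live in `torch.ao.quantization.quantize_pt2e` per the hunk above, while the import paths for `capture_pre_autograd_graph` and the XNNPACK quantizer are assumed for the PyTorch version this commit targets.

```python
import torch

# Assumed import paths for this PyTorch release; only quantize_pt2e is
# confirmed by the diff above.
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


example_inputs = (torch.randn(1, 8),)

# Export the eager model into the GraphModule form that prepare_pt2e expects.
m = capture_pre_autograd_graph(M().eval(), example_inputs)

# Annotate the graph with observers according to the quantizer's config.
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config())
m = prepare_pt2e(m, quantizer)

# Calibrate with representative inputs.
m(*example_inputs)

# Convert to the quantized q/dq representation. fold_quantize=True folds the
# weight quantize ops into frozen quantized parameters (the "frozen_param"
# get_attr nodes checked in test_quantize_pt2e.py above).
m = convert_pt2e(m, fold_quantize=True)
torch.ao.quantization.move_exported_model_to_eval(m)
```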
4 changes: 2 additions & 2 deletions torch/testing/_internal/common_quantization.py
@@ -1181,7 +1181,7 @@ def _test_quantizer(
m = prepare_pt2e(m, quantizer)
# Calibrate
m(*example_inputs)
- m = convert_pt2e(m)
+ m = convert_pt2e(m, fold_quantize=True)

pt2_quant_output = m(*example_inputs)
ns = NodeSpec
@@ -1228,7 +1228,7 @@ def _quantize(self, m, quantizer, example_inputs):
)
m = prepare_pt2e(m, quantizer)
m(*example_inputs)
- m = convert_pt2e(m)
+ m = convert_pt2e(m, fold_quantize=True)
return m

def _get_pt2e_quantized_linear(self, is_per_channel=False) -> torch.fx.GraphModule:
