diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py index 46b01b015de..6452f69fc1b 100644 --- a/nncf/quantization/algorithms/weight_compression/algorithm.py +++ b/nncf/quantization/algorithms/weight_compression/algorithm.py @@ -506,6 +506,31 @@ def input_filter_func(point): self._fp_inputs[input_id] = input_fp return self._fp_inputs[input_id] + @staticmethod + def _get_dynamic_shape(x: List[Tensor]): + """ + Compute common shape for set of tensors. + For example: return [-1, 10] for tensors with shapes [[1, 10], [5, 10], [100, 10]] + or [-1, 10] for tensors with shapes [[1, 1, 10], [1, 1, 10], [1, 100, 10]] + :param x: (List[Tensor]): Set of tensors. + + :return: resulting shape with -1 for dimension with dynamic axis, + common size for dimension with static axis if size > 1 else None. + """ + if len(x) == 0: + return [] + res = list(x[0].shape) + sz = len(res) + + for i in x: + i_shape = i.shape + for j in range(sz): + if i_shape[j] != res[j]: + res[j] = -1 + res = [i for i in res if i != 1] + + return res + def _get_activations( self, dataset: Dataset, subset_size: int, nodes_to_compress: List[NNCFNode], graph: NNCFGraph, model: TModel ) -> Dict[str, List[Tensor]]: @@ -561,7 +586,9 @@ def _get_activations( for node_name, output_id in _collected_stat_inputs_map.items(): act_node_name, output_port_id = output_id x_fp = self._get_fp_inputs(statistic_container, node_name=act_node_name, port_id=output_port_id) - x_fp = [i.squeeze() for i in x_fp] # List[tensor(seq_length, hidden_dim)] + + d_shape = self._get_dynamic_shape(x_fp) + x_fp = [i.reshape(d_shape) for i in x_fp] # List[tensor(seq_length, hidden_dim)] activations[node_name] = x_fp for shared_node_name in act_vs_shared_node_names_mapping[act_node_name]: diff --git a/tests/openvino/native/models.py b/tests/openvino/native/models.py index 77ec443274f..bc5b6e00d61 100644 --- a/tests/openvino/native/models.py +++ b/tests/openvino/native/models.py @@ -950,7 +950,7 @@ def get_weights(weights_data, is_int8, name): return (qw - zp) * scale def _create_ov_model(self, is_int8=False): - input_node = opset.parameter([8, 8], name="Input_1") + input_node = opset.parameter([-1, 8], name="Input_1") weights_data1 = np.arange(0, 64).reshape(8, 8) weights_data1[:] = 2.0 diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py index c51cf667ca2..c39626b2a8e 100644 --- a/tests/openvino/native/quantization/test_weights_compression.py +++ b/tests/openvino/native/quantization/test_weights_compression.py @@ -1018,6 +1018,21 @@ def test_call_max_var_criterion_with_dataset_gptq_neg_group_size(mode): assert op.get_shape() == [sz, 1] +@pytest.mark.parametrize("mode", INT4_MODES) +def test_one_dimentional_samples(mode): + model = AWQMatmulModel().ov_model + sz = 8 + n_samples = 10 + dataset = Dataset([np.ones([i + 1, sz]) for i in range(n_samples)]) + + compressed_model = compress_weights(model, mode=mode, ratio=1.0, group_size=-1, dataset=dataset, awq=True) + + for op in compressed_model.get_ordered_ops(): + op_name = op.get_friendly_name() + if op.get_type_name() == "Constant" and ("/zero_point" in op_name or "/scale" in op_name): + assert op.get_shape() == [sz, 1] + + def get_shape_for_second_input(op_with_weights: ov.Node) -> List[int]: return list(op_with_weights.inputs()[1].get_shape())