Avoid squeeze for seq_len equal to one. (#2821)
### Changes

Changed the squeeze logic for statistics postprocessing. Some models have
internal processing that leads to a variable seq_len, and in that case the
collected statistics end up with a different number of dimensions when one of
them has seq_len equal to one.
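
As an illustration of the failure mode and of the fix applied below, here is a minimal NumPy sketch; the shapes, sample counts, and the helper name are hypothetical and only mirror the `_get_dynamic_shape` logic added in this commit:

```python
import numpy as np

# Per-sample activations with a varying seq_len; hidden_dim is fixed at 8.
samples = [np.ones((1, seq_len, 8)) for seq_len in (1, 5, 100)]

# Previous behaviour: squeeze() drops every axis of size 1, so the sample with
# seq_len == 1 loses its sequence axis and ends up with a different rank.
squeezed = [s.squeeze() for s in samples]
print([s.shape for s in squeezed])  # [(8,), (5, 8), (100, 8)] -> inconsistent ranks

# New behaviour: derive one common shape, marking axes that vary as -1 and
# dropping only axes that are 1 for all samples, then reshape every sample to it.
def get_common_shape(tensors):
    res = list(tensors[0].shape)
    for t in tensors:
        for j, dim in enumerate(t.shape):
            if dim != res[j]:
                res[j] = -1
    return [d for d in res if d != 1]

d_shape = get_common_shape(samples)  # [-1, 8]
reshaped = [s.reshape(d_shape) for s in samples]
print([s.shape for s in reshaped])  # [(1, 8), (5, 8), (100, 8)] -> consistent
```

With the old squeeze, the seq_len == 1 statistic becomes one-dimensional while the others stay two-dimensional, which is the rank mismatch behind the AWQ/SE failure described below; the reshape keeps every sample two-dimensional.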

### Reason for changes

A bug in AWQ/SE for mixtral-8x7b-v0.1.

### Related tickets



### Tests


tests/openvino/native/quantization/test_weights_compression.py::test_one_dimentional_samples
andreyanufr authored Sep 9, 2024
1 parent 6f1b2dd commit 67d6386
Showing 3 changed files with 44 additions and 2 deletions.
29 changes: 28 additions & 1 deletion nncf/quantization/algorithms/weight_compression/algorithm.py
@@ -506,6 +506,31 @@ def input_filter_func(point):
        self._fp_inputs[input_id] = input_fp
        return self._fp_inputs[input_id]

    @staticmethod
    def _get_dynamic_shape(x: List[Tensor]) -> List[int]:
        """
        Computes a common shape for a set of tensors.
        For example, returns [-1, 10] for tensors with shapes [[1, 10], [5, 10], [100, 10]]
        and [-1, 10] for tensors with shapes [[1, 1, 10], [1, 1, 10], [1, 100, 10]].

        :param x: List of tensors.
        :return: The resulting shape, where each dimension that varies across the tensors is set to -1
            and static dimensions of size 1 are dropped.
        """
        if len(x) == 0:
            return []
        res = list(x[0].shape)
        sz = len(res)

        # Mark every axis whose size differs between the tensors as dynamic (-1).
        for i in x:
            i_shape = i.shape
            for j in range(sz):
                if i_shape[j] != res[j]:
                    res[j] = -1
        # Drop static axes of size 1 so that a seq_len of one does not change the rank.
        res = [i for i in res if i != 1]

        return res

    def _get_activations(
        self, dataset: Dataset, subset_size: int, nodes_to_compress: List[NNCFNode], graph: NNCFGraph, model: TModel
    ) -> Dict[str, List[Tensor]]:
@@ -561,7 +586,9 @@ def _get_activations(
        for node_name, output_id in _collected_stat_inputs_map.items():
            act_node_name, output_port_id = output_id
            x_fp = self._get_fp_inputs(statistic_container, node_name=act_node_name, port_id=output_port_id)
            x_fp = [i.squeeze() for i in x_fp]  # List[tensor(seq_length, hidden_dim)]

            d_shape = self._get_dynamic_shape(x_fp)
            x_fp = [i.reshape(d_shape) for i in x_fp]  # List[tensor(seq_length, hidden_dim)]
            activations[node_name] = x_fp

            for shared_node_name in act_vs_shared_node_names_mapping[act_node_name]:
2 changes: 1 addition & 1 deletion tests/openvino/native/models.py
@@ -950,7 +950,7 @@ def get_weights(weights_data, is_int8, name):
            return (qw - zp) * scale

    def _create_ov_model(self, is_int8=False):
        input_node = opset.parameter([8, 8], name="Input_1")
        input_node = opset.parameter([-1, 8], name="Input_1")

        weights_data1 = np.arange(0, 64).reshape(8, 8)
        weights_data1[:] = 2.0
15 changes: 15 additions & 0 deletions tests/openvino/native/quantization/test_weights_compression.py
@@ -1018,6 +1018,21 @@ def test_call_max_var_criterion_with_dataset_gptq_neg_group_size(mode):
            assert op.get_shape() == [sz, 1]


@pytest.mark.parametrize("mode", INT4_MODES)
def test_one_dimentional_samples(mode):
    model = AWQMatmulModel().ov_model
    sz = 8
    n_samples = 10
    dataset = Dataset([np.ones([i + 1, sz]) for i in range(n_samples)])

    compressed_model = compress_weights(model, mode=mode, ratio=1.0, group_size=-1, dataset=dataset, awq=True)

    for op in compressed_model.get_ordered_ops():
        op_name = op.get_friendly_name()
        if op.get_type_name() == "Constant" and ("/zero_point" in op_name or "/scale" in op_name):
            assert op.get_shape() == [sz, 1]


def get_shape_for_second_input(op_with_weights: ov.Node) -> List[int]:
    return list(op_with_weights.inputs()[1].get_shape())

