Skip to content

Commit

Permalink
Improve weight compression time and memory for BF16 OV models (#2949)
Browse files Browse the repository at this point in the history
### Changes

Avoid retrieving weight data if not necessary.

For llama-3-8b with BF16 weights, this gives roughly a 1.5x improvement in
compression time and a 1.26x improvement in peak memory. Please see the figures below.
| Before | After |
|--------|-------|
| ![system_memory_usage_from-zero](https://github.com/user-attachments/assets/910e5525-c1c0-4f18-a379-2ae87a2648ba) | ![system_memory_usage_from-zero](https://github.com/user-attachments/assets/589c0e46-acd0-4a33-91ed-843e631d5a8d) |

### Reason for changes

For the BF16 dtype, fetching constant data loads it into memory,
resulting in additional time and memory overhead.
  • Loading branch information
nikita-savelyevv authored Sep 5, 2024
1 parent 187dda2 commit 6f1b2dd
Show file tree
Hide file tree
Showing 9 changed files with 102 additions and 41 deletions.
1 change: 1 addition & 0 deletions nncf/openvino/graph/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def create_nncf_graph(model: ov.Model) -> NNCFGraph:
const_attrs[const_port_id] = {
"name": const_node.get_friendly_name(),
"shape": tuple(const_node.get_output_shape(0)),
"dtype": const_node.output(0).get_element_type().get_type_name(),
}

if metatype == OVMatMulMetatype:
Expand Down
13 changes: 8 additions & 5 deletions nncf/quantization/algorithms/weight_compression/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import operator
from collections import defaultdict
from functools import reduce
from typing import Dict, List, Optional, OrderedDict, Tuple, TypeVar

import nncf
Expand Down Expand Up @@ -326,14 +327,16 @@ def apply(
is_last_layer_shared = True
continue

weight = self._backend_entity.get_weight(node, weight_port_id, model, graph)
if weight.dtype not in [
weight_dtype = self._backend_entity.get_weight_dtype(node, weight_port_id, model, graph)
if weight_dtype not in [
TensorDataType.float16,
TensorDataType.bfloat16,
TensorDataType.float32,
TensorDataType.float64,
]:
continue
weight_shape = self._backend_entity.get_weight_shape(node, weight_port_id, graph)
weight_size = reduce(operator.mul, weight_shape, 1)
reduction_axes = self._backend_entity.get_reduction_axes(node, weight_port_id, graph)
if (
self._group_size != -1
Expand All @@ -348,12 +351,12 @@ def apply(
# MatMul ops can't have multiple reduction axes.
nncf_logger.warning(
f"Weight compression expects a single reduction axis, but {len(reduction_axes)} given. "
f"Weight shape: {weight.shape}, reduction axes: {reduction_axes}, "
f"Weight shape: {weight_shape}, reduction axes: {reduction_axes}, "
f"node name: {node.node_name}. The node will be asymmetrically quantized to 8 bits."
)

weight_params = WeightCompressionParameters(
weight_name, node, weight_port_id, weight.size, reduction_axes
weight_name, node, weight_port_id, weight_size, reduction_axes
)
all_weight_params.append(weight_params)
weight_names.add(weight_name)
Expand Down
27 changes: 27 additions & 0 deletions nncf/quantization/algorithms/weight_compression/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
from nncf.tensor import Tensor
from nncf.tensor import TensorDataType

TModel = TypeVar("TModel")

Expand Down Expand Up @@ -93,6 +94,32 @@ def get_weight(self, node_with_weight: NNCFNode, weight_port_id: int, model: TMo
:return: The weight tensor.
"""

@abstractmethod
def get_weight_dtype(
    self, node_with_weight: NNCFNode, weight_port_id: int, model: TModel, graph: NNCFGraph
) -> TensorDataType:
    """
    Returns a weight data type associated with the given node on the given port id.

    Abstract: every backend supplies its own implementation.

    :param node_with_weight: The node with weight.
    :param weight_port_id: The weight port id for given node with weight.
    :param model: The model.
    :param graph: The model graph associated with the model.
    :return: The weight data type.
    """

@staticmethod
@abstractmethod
def get_weight_shape(node_with_weight: NNCFNode, weight_port_id: int, graph: NNCFGraph) -> Tuple:
    """
    Returns a weight shape associated with the given node on the given port id.

    Abstract static method: every backend supplies its own implementation.
    Unlike the weight getter, it takes no model argument.

    :param node_with_weight: The node with weight.
    :param weight_port_id: The weight port id for given node with weight.
    :param graph: The model graph associated with the model.
    :return: The weight shape.
    """

@abstractmethod
def set_weight(
self, node_with_weight: NNCFNode, weight_port_id: int, model: TModel, graph: NNCFGraph, weight: Tensor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,26 @@ def get_weight(self, node_with_weight: NNCFNode, weight_port_id: int, model: ov.
weight_tensor = get_const_value(weight_node)
return Tensor(weight_tensor)

def get_weight_dtype(
    self, node_with_weight: NNCFNode, weight_port_id: int, model: ov.Model, graph: NNCFGraph
) -> TensorDataType:
    """
    Returns the weight data type recorded in the node's constant attributes.

    The OpenVINO element type name stored under the "dtype" key of the constant
    attributes for the given port is translated to a TensorDataType through a
    fixed lookup table. Neither `model` nor `graph` is consulted.

    :param node_with_weight: The node with weight.
    :param weight_port_id: The weight port id for given node with weight.
    :param model: The model (unused here).
    :param graph: The model graph associated with the model (unused here).
    :return: The weight data type, or None when the stored type name has no
        entry in the lookup table.
    """
    port_attributes = node_with_weight.layer_attributes.constant_attributes[weight_port_id]
    element_type_name = port_attributes["dtype"]
    # OpenVINO element type name -> NNCF tensor data type.
    known_dtypes = {
        "f16": TensorDataType.float16,
        "bf16": TensorDataType.bfloat16,
        "f32": TensorDataType.float32,
        "f64": TensorDataType.float64,
        "i8": TensorDataType.int8,
        "i32": TensorDataType.int32,
        "i64": TensorDataType.int64,
        "u8": TensorDataType.uint8,
    }
    # Unmapped names deliberately resolve to None rather than raising.
    return known_dtypes.get(element_type_name)

@staticmethod
def get_weight_shape(node_with_weight: NNCFNode, weight_port_id: int, graph: NNCFGraph) -> Tuple:
    """
    Returns the weight shape recorded in the node's constant attributes.

    :param node_with_weight: The node with weight.
    :param weight_port_id: The weight port id for given node with weight.
    :param graph: The model graph associated with the model (unused here).
    :return: The value stored under the "shape" key for the given port.
    """
    port_attributes = node_with_weight.layer_attributes.constant_attributes[weight_port_id]
    return port_attributes["shape"]

def set_weight(
self, node_with_weight: NNCFNode, weight_port_id: int, model: ov.Model, graph: NNCFGraph, weight: Tensor
):
Expand Down
10 changes: 10 additions & 0 deletions nncf/quantization/algorithms/weight_compression/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,16 @@ def get_weight(

return Tensor(weight)

def get_weight_dtype(
    self, node_with_weight: NNCFNode, weight_port_id: int, model: torch.nn.Module, graph: NNCFGraph
) -> TensorDataType:
    """
    Returns the data type of the weight on the given port of the given node.

    The weight tensor is obtained through `get_weight` and its `dtype` is reported.

    :param node_with_weight: The node with weight.
    :param weight_port_id: The weight port id for given node with weight.
    :param model: The model.
    :param graph: The model graph associated with the model.
    :return: The weight data type.
    """
    weight_tensor = self.get_weight(node_with_weight, weight_port_id, model, graph)
    return weight_tensor.dtype

@staticmethod
def get_weight_shape(node_with_weight: NNCFNode, weight_port_id: int, graph: NNCFGraph) -> Tuple:
    """
    Returns the weight shape taken from the constant node feeding the given port.

    :param node_with_weight: The node with weight.
    :param weight_port_id: The weight port id for given node with weight.
    :param graph: The model graph associated with the model.
    :return: The constant node's layer-attribute shape, converted to a tuple.
    """
    const_node = get_const_node(node_with_weight, weight_port_id, graph)
    shape = const_node.layer_attributes.shape
    return tuple(shape)

def set_weight(
self, node_with_weight: NNCFNode, weight_port_id: int, model: torch.nn.Module, graph: NNCFGraph, weight: Tensor
):
Expand Down
6 changes: 3 additions & 3 deletions tests/openvino/native/quantization/test_min_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def create_target_point(self, target_point_type: TargetType, name: str, port_id:

class TestOVGetTargetPointShape(TemplateTestGetTargetPointShape, TestOVMinMaxAlgorithm):
    def get_nncf_graph(self, weight_port_id: int, weight_shape: Tuple[int]) -> NNCFGraph:
        """
        Build a test graph around one convolution whose weight sits on the given port.

        :param weight_port_id: Port id under which the dummy constant is registered.
        :param weight_shape: Shape recorded for the dummy constant.
        :return: The NNCF graph of the constructed test graph.
        """
        constant_attributes = {weight_port_id: {"name": "dummy", "shape": weight_shape, "dtype": "f32"}}
        graph_to_test = NNCFGraphToTest(OVConvolutionMetatype, OVLayerAttributes(constant_attributes))
        return graph_to_test.nncf_graph


Expand All @@ -58,7 +58,7 @@ def matmul_metatype(self):

@staticmethod
def get_conv_node_attrs(weight_port_id: int, weight_shape: Tuple[int]) -> OVLayerAttributes:
    """
    Create layer attributes for a convolution with a dummy f32 weight constant.

    :param weight_port_id: Port id under which the dummy constant is registered.
    :param weight_shape: Shape recorded for the dummy constant.
    :return: Layer attributes holding the constant description and two empty dicts.
    """
    weight_description = {"name": "dummy", "shape": weight_shape, "dtype": "f32"}
    return OVLayerAttributes({weight_port_id: weight_description}, {}, {})

@staticmethod
Expand All @@ -69,7 +69,7 @@ def get_depthwiseconv_node_attrs(weight_port_id: int, weight_shape: Tuple[int])
def get_matmul_node_attrs(
    weight_port_id: int, transpose_weight: Tuple[int], weight_shape: Tuple[int]
) -> OVLayerAttributes:
    """
    Create layer attributes for a MatMul with a dummy f32 weight constant.

    :param weight_port_id: Port id under which the dummy constant is registered.
    :param transpose_weight: Value stored under the "transpose" key of the constant description.
    :param weight_shape: Shape recorded for the dummy constant.
    :return: Layer attributes holding the constant description and two empty dicts.
    """
    # Key order matches the original: the base description first, "transpose" last.
    weight_description = {"name": "dummy", "shape": weight_shape, "dtype": "f32", "transpose": transpose_weight}
    return OVLayerAttributes({weight_port_id: weight_description}, {}, {})

Expand Down
4 changes: 2 additions & 2 deletions tests/openvino/native/quantization/test_quantizer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def get_algo_backend(self):

@pytest.fixture
def single_conv_nncf_graph(self) -> NNCFGraphToTest:
    """
    Fixture: a test graph with one convolution whose dummy f32 weight of shape (4, 4, 4, 4) is on port 0.
    """
    constant_attributes = {0: {"name": "dummy", "shape": (4, 4, 4, 4), "dtype": "f32"}}
    return NNCFGraphToTest(OVConvolutionMetatype, OVLayerAttributes(constant_attributes))

@pytest.fixture
Expand All @@ -37,5 +37,5 @@ def depthwise_conv_nncf_graph(self):

@pytest.fixture
def conv_sum_aggregation_nncf_graph(self) -> NNCFGraphToTestSumAggregation:
    """
    Fixture: a convolution + sum test graph; the convolution has a dummy f32 weight of shape (4, 4, 4, 4) on port 0.
    """
    constant_attributes = {0: {"name": "dummy", "shape": (4, 4, 4, 4), "dtype": "f32"}}
    conv_attributes = OVLayerAttributes(constant_attributes)
    return NNCFGraphToTestSumAggregation(OVConvolutionMetatype, OVSumMetatype, conv_attributes)
32 changes: 16 additions & 16 deletions tests/openvino/native/test_layer_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ class LayerAttributesTestCase:
(1, 3, 3, 3),
0,
OVLayerAttributes(
{1: {"name": "Const", "shape": (4, 3, 2, 1)}},
{1: {"name": "Const", "shape": (4, 3, 2, 1), "dtype": "f32"}},
ConvolutionLayerAttributes(
weight_requires_grad=False,
in_channels=3,
Expand All @@ -193,7 +193,7 @@ class LayerAttributesTestCase:
(1, 3, 3, 3),
0,
OVLayerAttributes(
{1: {"name": "Const", "shape": (4, 3, 1, 1)}},
{1: {"name": "Const", "shape": (4, 3, 1, 1), "dtype": "f32"}},
ConvolutionLayerAttributes(
weight_requires_grad=False,
in_channels=3,
Expand All @@ -220,7 +220,7 @@ class LayerAttributesTestCase:
(1, 3, 3, 3),
0,
OVLayerAttributes(
{1: {"name": "Const", "shape": (3, 3, 1, 1, 1)}},
{1: {"name": "Const", "shape": (3, 3, 1, 1, 1), "dtype": "f32"}},
ConvolutionLayerAttributes(
weight_requires_grad=False,
in_channels=1,
Expand Down Expand Up @@ -248,7 +248,7 @@ class LayerAttributesTestCase:
(1, 10, 3, 3),
0,
OVLayerAttributes(
{1: {"name": "Const", "shape": (5, 10, 2, 1, 1)}},
{1: {"name": "Const", "shape": (5, 10, 2, 1, 1), "dtype": "f32"}},
ConvolutionLayerAttributes(
weight_requires_grad=False,
in_channels=2,
Expand Down Expand Up @@ -276,7 +276,7 @@ class LayerAttributesTestCase:
(1, 3, 3, 3),
0,
OVLayerAttributes(
{1: {"name": "Const", "shape": (3, 4, 2, 1)}},
{1: {"name": "Const", "shape": (3, 4, 2, 1), "dtype": "f32"}},
ConvolutionLayerAttributes(
weight_requires_grad=False,
in_channels=3,
Expand All @@ -303,7 +303,7 @@ class LayerAttributesTestCase:
(1, 3, 3, 3),
0,
OVLayerAttributes(
{1: {"name": "Const", "shape": (3, 1, 3, 1, 1)}},
{1: {"name": "Const", "shape": (3, 1, 3, 1, 1), "dtype": "f32"}},
ConvolutionLayerAttributes(
weight_requires_grad=False,
in_channels=1,
Expand Down Expand Up @@ -335,7 +335,7 @@ class LayerAttributesTestCase:
(1, 3, 4),
0,
OVLayerAttributes(
{1: {"name": "Const", "shape": (1, 4), "transpose": True}},
{1: {"name": "Const", "shape": (1, 4), "dtype": "f32", "transpose": True}},
LinearLayerAttributes(
weight_requires_grad=False,
in_features=4,
Expand All @@ -352,7 +352,7 @@ class LayerAttributesTestCase:
(1, 3, 4),
0,
OVLayerAttributes(
{1: {"name": "Const", "shape": (3, 1), "transpose": False}},
{1: {"name": "Const", "shape": (3, 1), "dtype": "f32", "transpose": False}},
LinearLayerAttributes(
weight_requires_grad=False,
in_features=3,
Expand All @@ -369,7 +369,7 @@ class LayerAttributesTestCase:
(1, 3, 4),
1,
OVLayerAttributes(
{0: {"name": "Const", "shape": (3, 1), "transpose": True}},
{0: {"name": "Const", "shape": (3, 1), "dtype": "f32", "transpose": True}},
LinearLayerAttributes(
weight_requires_grad=False,
in_features=3,
Expand All @@ -386,7 +386,7 @@ class LayerAttributesTestCase:
(1, 3, 4),
1,
OVLayerAttributes(
{0: {"name": "Const", "shape": (1, 4), "transpose": False}},
{0: {"name": "Const", "shape": (1, 4), "dtype": "f32", "transpose": False}},
LinearLayerAttributes(
weight_requires_grad=False,
in_features=4,
Expand All @@ -403,7 +403,7 @@ class LayerAttributesTestCase:
(1, 3, 4),
0,
OVLayerAttributes(
{1: {"name": "Const", "shape": (4,), "transpose": False}},
{1: {"name": "Const", "shape": (4,), "dtype": "f32", "transpose": False}},
LinearLayerAttributes(
weight_requires_grad=False,
in_features=4,
Expand All @@ -425,7 +425,7 @@ class LayerAttributesTestCase:
(1, 3, 4, 5),
0,
OVLayerAttributes(
{1: {"name": "Const", "shape": (1, 1, 1, 1)}},
{1: {"name": "Const", "shape": (1, 1, 1, 1), "dtype": "f32"}},
GenericWeightedLayerAttributes(False, weight_shape=(1, 1, 1, 1)),
{},
),
Expand All @@ -438,10 +438,10 @@ class LayerAttributesTestCase:
0,
OVLayerAttributes(
{
1: {"name": "hs", "shape": (2, 1, 4)},
2: {"name": "cs", "shape": (2, 1, 4)},
4: {"name": "w", "shape": (1, 16, 4)},
5: {"name": "r", "shape": (1, 16, 4)},
1: {"name": "hs", "shape": (2, 1, 4), "dtype": "f32"},
2: {"name": "cs", "shape": (2, 1, 4), "dtype": "f32"},
4: {"name": "w", "shape": (1, 16, 4), "dtype": "f32"},
5: {"name": "r", "shape": (1, 16, 4), "dtype": "f32"},
},
None,
{},
Expand Down
30 changes: 15 additions & 15 deletions tests/openvino/native/test_node_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,29 +77,29 @@ def test_is_node_with_bias(model_to_create, is_with_bias, node_name):


@pytest.mark.parametrize(
"weights_port_id, transpose, shape, expected_channel_axes",
"weights_port_id, transpose, shape, dtype, expected_channel_axes",
[
(0, False, (1,), []),
(0, True, (1,), []),
(1, False, (1,), []),
(1, True, (1,), []),
(0, False, (1, 1), [0]),
(0, True, (1, 1), [1]),
(1, False, (1, 1), [1]),
(1, True, (1, 1), [0]),
(0, False, (1, 1, 1, 1), [0, 1, 2]),
(0, True, (1, 1, 1, 1), [0, 1, 3]),
(1, False, (1, 1, 1, 1), [0, 1, 3]),
(1, True, (1, 1, 1, 1), [0, 1, 2]),
(0, False, (1,), "f32", []),
(0, True, (1,), "f32", []),
(1, False, (1,), "f32", []),
(1, True, (1,), "f32", []),
(0, False, (1, 1), "f32", [0]),
(0, True, (1, 1), "f32", [1]),
(1, False, (1, 1), "f32", [1]),
(1, True, (1, 1), "f32", [0]),
(0, False, (1, 1, 1, 1), "f32", [0, 1, 2]),
(0, True, (1, 1, 1, 1), "f32", [0, 1, 3]),
(1, False, (1, 1, 1, 1), "f32", [0, 1, 3]),
(1, True, (1, 1, 1, 1), "f32", [0, 1, 2]),
],
)
def test_get_weight_channel_axes_for_matmul(weights_port_id, transpose, shape, expected_channel_axes):
def test_get_weight_channel_axes_for_matmul(weights_port_id, transpose, shape, dtype, expected_channel_axes):
input_1 = opset.parameter([1, 1], name="Input", dtype=np.float32)
constant_1 = opset.constant(np.ones(shape).astype(np.float32))
inputs_ = (input_1, constant_1) if weights_port_id == 1 else (constant_1, input_1)
matmul_1 = opset.matmul(*inputs_, transpose_a=transpose, transpose_b=transpose, name="MatMul")

constant_attrs = {weights_port_id: {"transpose": transpose, "shape": shape}}
constant_attrs = {weights_port_id: {"transpose": transpose, "shape": shape, "dtype": dtype}}
attributes = {
NNCFNode.ID_NODE_ATTR: 0,
NNCFNode.NODE_NAME_ATTR: "test",
Expand Down

0 comments on commit 6f1b2dd

Please sign in to comment.