Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PT] add metatype for torch2 #3107

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@


from collections import defaultdict
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union, cast

import networkx as nx # type: ignore
import torch
from torch import nn

import nncf
import nncf.torch.graph.operator_metatypes as om
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.layer_attributes import Dtype
Expand Down Expand Up @@ -76,6 +77,20 @@ def get_dtype(dtype: torch.dtype) -> Dtype:
return Dtype.INTEGER


def get_meta_type(node_type: str, meta: Union[ConstMeta, FunctionMeta, InOutMeta]) -> om.PTOperatorMetatype:
"""
Converts the node type and metadata into a PTOperatorMetatype object.
:param node_type: The type of the node.
:param meta: The metadata associated with the node.
:return: The PTOperatorMetatype object.
"""
node_metatype = cast(om.PTOperatorMetatype, om.PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type))
node_sub_meta_type: Optional[om.PTOperatorMetatype] = None
if node_metatype.get_subtypes() and isinstance(meta, FunctionMeta):
node_sub_meta_type = node_metatype.determine_subtype(function_args=meta.args, functions_kwargs=meta.kwargs)
return node_sub_meta_type or node_metatype


def convert_to_nncf_graph(nx_graph: nx.MultiDiGraph) -> NNCFGraph:
"""
Converts a graph to an NNCFGraph.
Expand All @@ -88,12 +103,14 @@ def convert_to_nncf_graph(nx_graph: nx.MultiDiGraph) -> NNCFGraph:
map_nx_node_to_nncf_node: Dict[int, NNCFNode] = {}
for node, data in nx_graph.nodes(data=True):
meta: Union[ConstMeta, FunctionMeta, InOutMeta] = data["meta"]
node_name = get_name_of_node(meta)
node_type = get_node_type(data["type"], meta)
node_metatype = None # TODO(AlexanderDokuchaev): add node_metatype
meta_type = get_meta_type(node_type, meta)

nncf_node = nncf_graph.add_nncf_node(
node_name=get_name_of_node(meta),
node_name=node_name,
node_type=node_type,
node_metatype=node_metatype, # type: ignore[arg-type]
node_metatype=meta_type, # type: ignore[arg-type]
)
map_nx_node_to_nncf_node[node] = nncf_node

Expand Down
4 changes: 3 additions & 1 deletion nncf/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Base subpackage for NNCF PyTorch functionality.
"""

import os
from nncf import nncf_logger
from nncf.common.logging.logger import warn_bkc_version_mismatch

Expand Down Expand Up @@ -76,4 +77,5 @@
if torch.__version__ >= "2.5.0":
from torch._dynamo.polyfills import loader

patch_torch_operators()
if os.getenv("NNCF_EXPERIMENTAL_TORCH_TRACING") is None:
patch_torch_operators()
7 changes: 7 additions & 0 deletions nncf/torch/graph/operator_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,13 @@ class PTDepthwiseConvOperatorSubtype(PTOperatorSubtype):
def matches(
cls, layer_attributes: Optional[BaseLayerAttributes] = None, function_args=None, functions_kwargs=None
) -> bool:
if layer_attributes is None and function_args is not None and functions_kwargs is not None:
# Used for torch2
weight_meta = functions_kwargs.get("weight", function_args[0])
in_channels = weight_meta.shape[1]
groups = functions_kwargs.get("groups", function_args[6] if len(function_args) > 6 else 1)
return in_channels > 1 and groups == in_channels

if _is_called_inside_nncf_module(functions_kwargs):
return False
if not isinstance(layer_attributes, ConvolutionLayerAttributes):
Expand Down
3 changes: 3 additions & 0 deletions tests/torch2/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ def pytest_configure(config: Config) -> None:
if nncf_debug:
set_log_level(logging.DEBUG)

# Disable patching of torch functions
os.environ["NNCF_EXPERIMENTAL_TORCH_TRACING"] = "1"


@pytest.fixture
def regen_ref_data(request: FixtureRequest):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
strict digraph {
"0 x" [id=0, type=nncf_model_input];
"1 conv.weight" [id=1, type=nncf_model_const];
"2 conv.bias" [id=2, type=nncf_model_const];
"3 conv/conv2d/0" [id=3, type=conv2d];
"4 __nncf_hooks.post_hooks.conv/conv2d/0__0.0.w" [id=4, type=nncf_model_const];
"5 conv/post_hook__conv-conv2d-0__0[0]/add/0" [id=5, type=add];
"6 /relu/0" [id=6, type=relu];
"7 output" [id=7, type=nncf_model_output];
"0 x" -> "3 conv/conv2d/0" [label="(1, 1, 3, 3)", style=solid];
"1 conv.weight" -> "3 conv/conv2d/0" [label="(1, 1, 1, 1)", style=solid];
"2 conv.bias" -> "3 conv/conv2d/0" [label="(1,)", style=solid];
"3 conv/conv2d/0" -> "5 conv/post_hook__conv-conv2d-0__0[0]/add/0" [label="(1, 1, 3, 3)", style=solid];
"4 __nncf_hooks.post_hooks.conv/conv2d/0__0.0.w" -> "5 conv/post_hook__conv-conv2d-0__0[0]/add/0" [label="(1,)", style=solid];
"5 conv/post_hook__conv-conv2d-0__0[0]/add/0" -> "6 /relu/0" [label="(1, 1, 3, 3)", style=solid];
"6 /relu/0" -> "7 output" [label="(1, 1, 3, 3)", style=solid];
x [id=0, metatype=PTInputNoopMetatype, type=nncf_model_input];
"conv.weight" [id=1, metatype=PTConstNoopMetatype, type=nncf_model_const];
"conv.bias" [id=2, metatype=PTConstNoopMetatype, type=nncf_model_const];
"conv/conv2d/0" [id=3, metatype=PTConv2dMetatype, type=conv2d];
"__nncf_hooks.post_hooks.conv/conv2d/0__0.0.w" [id=4, metatype=PTConstNoopMetatype, type=nncf_model_const];
"conv/post_hook__conv-conv2d-0__0[0]/add/0" [id=5, metatype=PTAddMetatype, type=add];
"/relu/0" [id=6, metatype=PTRELUMetatype, type=relu];
output [id=7, metatype=PTOutputNoopMetatype, type=nncf_model_output];
x -> "conv/conv2d/0" [dtype=float, shape="(1, 1, 3, 3)"];
"conv.weight" -> "conv/conv2d/0" [dtype=float, shape="(1, 1, 1, 1)"];
"conv.bias" -> "conv/conv2d/0" [dtype=float, shape="(1,)"];
"conv/conv2d/0" -> "conv/post_hook__conv-conv2d-0__0[0]/add/0" [dtype=float, shape="(1, 1, 3, 3)"];
"__nncf_hooks.post_hooks.conv/conv2d/0__0.0.w" -> "conv/post_hook__conv-conv2d-0__0[0]/add/0" [dtype=float, shape="(1,)"];
"conv/post_hook__conv-conv2d-0__0[0]/add/0" -> "/relu/0" [dtype=float, shape="(1, 1, 3, 3)"];
"/relu/0" -> output [dtype=float, shape="(1, 1, 3, 3)"];
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
strict digraph {
"0 x" [id=0, type=nncf_model_input];
"1 /add/0" [id=1, type=add];
"2 output" [id=2, type=nncf_model_output];
"0 x" -> "1 /add/0" [label="parallel_input_port_ids [1], shape (1, 1)", style=solid];
"1 /add/0" -> "2 output" [label="(1, 1)", style=solid];
x [id=0, metatype=PTInputNoopMetatype, type=nncf_model_input];
"/add/0" [id=1, metatype=PTAddMetatype, type=add];
output [id=2, metatype=PTOutputNoopMetatype, type=nncf_model_output];
x -> "/add/0" [dtype=float, parallel_input_port_ids="[1]", shape="(1, 1)"];
"/add/0" -> output [dtype=float, shape="(1, 1)"];
}
Loading
Loading