Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP [TorchFX] Bias fusing is removed from default transformations #3027

Draft
wants to merge 1 commit into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions nncf/experimental/torch/fx/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,14 @@ def _apply_model_extraction(
def remap_fn(node: torch.fx.Node):
return value_remap.get(node) # noqa F821

visited_outputs_names = []
for node in model.graph.nodes:
if node.name not in visited or node.op == "output":
if node.name not in visited:
continue
if node.op == "output":
visited_outputs_names.append(node.name)
continue
value_remap[node] = extracted_graph.node_copy(node, remap_fn)
del value_remap

for input_name in transformation.input_node_names:
node_with_input = get_graph_node_by_name(extracted_graph, input_name)
Expand All @@ -149,7 +152,29 @@ def remap_fn(node: torch.fx.Node):
args[0] = graph_input
node_with_input.args = tuple(args)

nodes_with_output = [get_graph_node_by_name(extracted_graph, name) for name in transformation.output_node_names]
# Merge new output with the original output in case
# the original output is requested in the extracted graph.
nodes_with_output = []
for name in transformation.output_node_names:
nodes_with_output.append(
name if name in visited_outputs_names else get_graph_node_by_name(extracted_graph, name)
)

for idx, node in enumerate(nodes_with_output):
if isinstance(node, torch.fx.Node):
continue
output_node = get_graph_node_by_name(model.graph, node)
args = output_node.args[0]
if isinstance(args, torch.fx.Node):
args = value_remap[args]
else:
args = [value_remap[n] for n in args]
# Unpack target output args in case
# only one arg is presented.
if len(args) == 1:
args = args[0]
nodes_with_output[idx] = args

last_node = list(extracted_graph.nodes)[-1]
with extracted_graph.inserting_after(last_node):
graph_output_name = "output"
Expand Down
15 changes: 7 additions & 8 deletions nncf/experimental/torch/fx/node_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,10 @@ def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
:return: True if the node has a bias, False otherwise.
"""
# Assumes that all biases were unfused
if node.metatype in FX_OPERATORS_WITH_BIAS_METATYPES:
next_nodes = nncf_graph.get_next_nodes(node)
if len(next_nodes) != 1:
return False
return next_nodes[0].metatype in (om.PTAddMetatype,)
if node.metatype not in FX_OPERATORS_WITH_BIAS_METATYPES or len(nncf_graph.get_input_edges(node)) != 3:
return False
const_node = nncf_graph.get_input_edge_by_port_id(node, 2).from_node
return const_node.metatype is om.PTConstNoopMetatype


def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: torch.fx.GraphModule) -> Tensor:
Expand All @@ -82,7 +81,7 @@ def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: torch.fx.GraphM
:param model: Target GraphModule.
:return: Bias value of the given node.
"""
bias_node = nncf_graph.get_next_nodes(node)[0]
bias_node = nncf_graph.get_input_edge_by_port_id(node, 2).from_node
# TODO(dlyakhov): make a node_name_vs_node map to speed up the process
graph_bias_node = get_graph_node_by_name(model.graph, bias_node.node_name)
return Tensor(get_tensor_constant_from_node(graph_bias_node.all_input_nodes[1], model))
graph_bias_const = get_graph_node_by_name(model.graph, bias_node.node_name)
return Tensor(get_tensor_constant_from_node(graph_bias_const, model))
211 changes: 1 addition & 210 deletions nncf/experimental/torch/fx/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,33 +160,6 @@ def leaf_module_insertion_transformation(model: torch.fx.GraphModule):
return leaf_module_insertion_transformation


def bias_update_transformation_builder(node: NNCFNode, value: torch.Tensor, input_port_id: int) -> TransformationFNType:
"""
Return transformation which updates constant of the given node with bias to the given value.
:param node: Node with bias which requires bias constant update.
:param value: New value to use as the bias constant.
:param input_port_id: Input port id to get constant node from.
:return: Transformation which updates constant of the given node with bias to the given value.
"""

def bias_update_transformation(model: torch.fx.GraphModule):
graph = model.graph
target_node_name = node.node_name
graph_node = get_graph_node_by_name(graph, target_node_name)
add_nodes = []
for user in graph_node.users:
if _is_add(user):
add_nodes.append(user)
if len(add_nodes) != 1:
raise nncf.InternalError(f"Node {graph_node.name} has {len(add_nodes)} outputs with adds, 1 expected")

bias_node = add_nodes[0]
constant_update_fn(model, bias_node, value, input_port_id=input_port_id)

return bias_update_transformation


def constant_update_transformation_builder(
node: NNCFNode, value: torch.Tensor, input_port_id: int = 1
) -> TransformationFNType:
Expand Down Expand Up @@ -794,8 +767,6 @@ def apply_quantization_transformations(model: torch.fx.GraphModule) -> None:
# are being fused
fold_constant_except_qdq(model)
fuse_conv_bn(model)
separate_conv_and_bias(model)
separate_linear_and_bias(model)


def fold_constant_except_qdq(model: torch.fx.GraphModule):
Expand All @@ -817,8 +788,7 @@ def revert_quantization_transformations(model: torch.fx.GraphModule) -> None:
:param model: Model to revert transformations from.
"""
merge_conv_and_bias(model)
merge_linear_and_bias(model)
pass


def _is_linear(n: torch.fx.Node) -> bool:
Expand All @@ -840,182 +810,3 @@ def _is_conv(n: torch.fx.Node):
torch.ops.aten.conv2d.default,
torch.ops.aten.conv_transpose2d.input,
)


def _is_add(n: torch.fx.Node):
"""
Return whether the node refers to an aten add op.
"""
return n.op == "call_function" and n.target in (
torch.ops.aten.add_.Tensor,
torch.ops.aten.add.Tensor,
)


def separate_linear_and_bias(model: torch.fx.GraphModule):
"""
Separates one joined linear+bias node to two nodes: conv and bias.
Needed as nncf does not expect joined conv
:param model: Target model.
"""
add_node_target = torch.ops.aten.add.Tensor
for n in model.graph.nodes:
if not _is_linear(n):
continue
# This check also makes sure to ignore linear nodes which might already
# have quantization applied to the weights.
if len(n.args) < 3 or n.args[2] is None or n.args[1].op != "get_attr":
continue
linear_node = n
linear_bias_node = linear_node.args[2]
while linear_bias_node.op != "get_attr":
# Assume zero argument is on a path to the constant
linear_bias_node = linear_bias_node.args[0]
linear_bias_value = get_tensor_constant_from_node(linear_bias_node, model)
args = list(n.args)
args[2] = None
linear_node.args = tuple(args)
with model.graph.inserting_after(linear_node):
new_linear_bias_node = create_getattr_from_value(
model,
model.graph,
linear_bias_node.name + "_",
linear_bias_value,
)
with model.graph.inserting_after(new_linear_bias_node):
add_node = model.graph.create_node(
"call_function", add_node_target, (linear_node, new_linear_bias_node), {}
)
for user in list(linear_node.users):
if user is add_node:
continue
user.replace_input_with(linear_node, add_node)
if "val" in linear_node.meta:
add_node.meta["val"] = linear_node.meta["val"]
model.graph.eliminate_dead_code()
model.recompile()


def separate_conv_and_bias(model: torch.fx.GraphModule):
"""
Separates one joined conv+bias node to two nodes: conv and bias.
Needed as nncf does not expect joined conv
:param model: Target model.
"""
add_node_target = torch.ops.aten.add_.Tensor
for n in model.graph.nodes:
if not _is_conv(n):
continue
# This check also makes sure to ignore convolution nodes which might
# already have quantization applied to the weights.
if len(n.args) < 3 or n.args[2] is None or n.args[1].op != "get_attr":
continue
conv_node = n
dims = len(get_tensor_constant_from_node(conv_node.args[1], model).shape)
conv_bias_node = conv_node.args[2]
conv_bias_value = get_tensor_constant_from_node(conv_bias_node, model)
args = list(n.args)
args[2] = None
conv_node.args = tuple(args)
with model.graph.inserting_after(conv_node):
new_conv_bias_node = create_getattr_from_value(
model, model.graph, conv_bias_node.name + "_", conv_bias_value.reshape((1, -1) + (1,) * (dims - 2))
)
with model.graph.inserting_after(new_conv_bias_node):
add_node = model.graph.create_node("call_function", add_node_target, (conv_node, new_conv_bias_node), {})
for user in list(conv_node.users):
if user is add_node:
continue
user.replace_input_with(conv_node, add_node)

if "val" in conv_node.meta:
add_node.meta["val"] = conv_node.meta["val"]
model.graph.eliminate_dead_code()
model.recompile()


def merge_conv_and_bias(model: torch.fx.GraphModule):
"""
Merges two separate conv and bias nodes to a one node: conv+bias.
Needed as nncf does not expect joined conv
:param model: Target model.
"""
_merge_node_and_bias(model, _is_conv)


def merge_linear_and_bias(model: torch.fx.GraphModule):
"""
Merges two separate linear and bias nodes to a one node: linear+bias.
:param model: Target model.
"""
_merge_node_and_bias(model, _is_linear)


def _get_connected_nodes(graph: torch.fx.Graph) -> List[torch.fx.Node]:
"""
Returns the List of nodes which are directly or indirectly connected
to the output node.
:param graph: The torch FX graph to get nodes from.
"""
output_nodes = [node for node in graph.nodes if node.op == "output"]
assert len(output_nodes) == 1
output_node = output_nodes[0]
connected_nodes = set() # Every node is unique in the graph
nodes_to_visit = [output_node]
while nodes_to_visit:
current_node = nodes_to_visit.pop()
if current_node in connected_nodes:
continue
connected_nodes.add(current_node)
nodes_to_visit.extend(current_node.all_input_nodes)
return list(connected_nodes)


def _merge_node_and_bias(model: torch.fx.GraphModule, is_target_node: Callable[[torch.fx.Node], bool]):
"""
Merges two separate node and bias node to a one node: node+bias.
Check which node should be merged by the given `is_target_node` predicate.
:param model: Target model.
:param is_target_node: Predicate to specify nodes which should be merged with the bias
"""
add_node_targets = (torch.ops.aten.add.Tensor, torch.ops.aten.add_.Tensor)
for n in model.graph.nodes:
if not is_target_node(n):
continue
if len(n.args) > 2 and n.args[2] is not None:
continue
bias_node = next(iter(n.users))
if len(n.users) > 1 or bias_node.target not in add_node_targets:
continue
conv_node = n
const_node = None
for node in bias_node.all_input_nodes:
if node is not conv_node:
const_node = node
break
assert const_node is not None
bias_value = get_tensor_constant_from_node(const_node, model).squeeze()
with model.graph.inserting_before(conv_node):
new_bias_node = create_getattr_from_value(model, model.graph, const_node.name + "_", bias_value)
args = list(conv_node.args)
args[2] = new_bias_node
conv_node.args = tuple(args)
for user in list(bias_node.users):
user.replace_input_with(bias_node, conv_node)

# Remove nodes which are not connected to output. This removes dead nodes and dead subgraphs in the model graph.
nodes_connected_to_output = _get_connected_nodes(model.graph)
is_impure = lambda node: node in nodes_connected_to_output

for node in reversed(model.graph.nodes):
if not is_impure(node) and len(node.users) == 0:
model.graph.erase_node(node)

model.graph.eliminate_dead_code()
model.recompile()
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from nncf.experimental.torch.fx.node_utils import get_bias_value
from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name
from nncf.experimental.torch.fx.node_utils import is_node_with_bias
from nncf.experimental.torch.fx.transformations import bias_update_transformation_builder
from nncf.experimental.torch.fx.transformations import constant_update_transformation_builder
from nncf.experimental.torch.fx.transformations import output_insertion_transformation_builder
from nncf.quantization.algorithms.bias_correction.backend import BiasCorrectionAlgoBackend
from nncf.tensor import Tensor
Expand All @@ -45,7 +45,9 @@ def target_point(target_type: TargetType, target_node_name: str, port_id: int) -
def create_bias_correction_command(
node: NNCFNode, bias_value: Tensor, nncf_graph: NNCFGraph
) -> FXApplyTransformationCommand:
return FXApplyTransformationCommand(bias_update_transformation_builder(node, bias_value.data, input_port_id=1))
return FXApplyTransformationCommand(
constant_update_transformation_builder(node, bias_value.data, input_port_id=2)
)

@staticmethod
def model_extraction_command(
Expand Down Expand Up @@ -90,6 +92,10 @@ def get_input_name(model: torch.fx.GraphModule, node_name: str, input_port_id: i
@staticmethod
def get_output_name(model: torch.fx.GraphModule, node_name: str, output_port_id: int) -> int:
graph_node = get_graph_node_by_name(model.graph, node_name)
if graph_node.op == "output":
# Original node output is kept as the first
# output tensor, thus returns 0.
return 0
nodes = list(graph_node.users)
while nodes:
node = nodes.pop()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from nncf.experimental.torch.fx.model_utils import get_target_point
from nncf.experimental.torch.fx.node_utils import get_bias_value
from nncf.experimental.torch.fx.node_utils import is_node_with_bias
from nncf.experimental.torch.fx.transformations import bias_update_transformation_builder
from nncf.experimental.torch.fx.transformations import constant_update_transformation_builder
from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend
from nncf.tensor import Tensor
from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
Expand All @@ -41,7 +41,9 @@ def target_point(target_type: TargetType, target_node_name: str, port_id: int) -
def create_bias_correction_command(
node: NNCFNode, bias_value: Tensor, nncf_graph: NNCFGraph
) -> FXApplyTransformationCommand:
return FXApplyTransformationCommand(bias_update_transformation_builder(node, bias_value.data, input_port_id=1))
return FXApplyTransformationCommand(
constant_update_transformation_builder(node, bias_value.data, input_port_id=2)
)

@staticmethod
def model_extraction_command(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
strict digraph {
"0 _param_constant0" [id=0, type=get_attr];
"1 add" [id=1, type=add];
"2 conv2d_input" [id=2, type=input];
"3 conv2d" [id=3, type=conv2d];
"4 output" [id=4, type=output];
"0 _param_constant0" -> "1 add" [label="(1, 1, 1, 1)", style=solid];
"1 add" -> "3 conv2d" [label="(1, 1, 1, 1)", style=solid];
"2 conv2d_input" -> "3 conv2d" [label=None, style=solid];
"3 conv2d" -> "4 output" [label="(1, 1, 3, 3)", style=solid];
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
strict digraph {
"0 _param_constant0" [id=0, type=get_attr];
"1 add" [id=1, type=add];
"2 conv2d_input" [id=2, type=input];
"3 conv2d" [id=3, type=conv2d];
"4 output" [id=4, type=output];
"0 _param_constant0" -> "1 add" [label="(1, 1, 1, 1)", style=solid];
"1 add" -> "3 conv2d" [label="(1, 1, 1, 1)", style=solid];
"2 conv2d_input" -> "3 conv2d" [label=None, style=solid];
"3 conv2d" -> "4 output" [label="(1, 1, 3, 3)", style=solid];
}
Loading
Loading