[PT FE] Improve 16bit patching #27693

Open · wants to merge 2 commits into base: releases/2024/5
83 changes: 32 additions & 51 deletions src/bindings/python/src/openvino/frontend/pytorch/patch_model.py
@@ -4,24 +4,15 @@
 # flake8: noqa
 # mypy: ignore-errors
 
+import functools
 import logging
 import torch
 from openvino.frontend.pytorch import ModuleExtension
 
 log = logging.getLogger(__name__)
 
 
-class no_jit_trace:
-    def __enter__(self):
-        self.state = torch._C._get_tracing_state()
-        torch._C._set_tracing_state(None)
-
-    def __exit__(self, *args):
-        torch._C._set_tracing_state(self.state)
-        self.state = None
-
-
-def patch_model(model, module_extensions, orig_forward_name, use_meta=False):
+def patch_model(model, module_extensions, orig_forward_name):
     def module_patcher(m, name):
         extension = None
         if m in module_extensions:
@@ -33,43 +24,31 @@ def module_patcher(m, name):

         if extension:
             log.debug("Patching module %s", m)
-            # The Trampoline class is instantiated for every module replacement, so we can use class members individually for each module.
+            # The Trampoline class is instantiated for every module replacement, so we can use
+            # class members individually for each module.
 
             class Trampoline(torch.autograd.Function):
                 # required to be saved in class
                 target_extension = extension
                 original_module = m
-                stashed_args = tuple()
-                stashed_kwargs = {}
 
                 @staticmethod
                 @torch.jit.ignore
-                def forward(*args, **kwargs):
-                    with no_jit_trace():
-                        # `module` is going to be passed to a user-defined function `evaluate`
-                        # `module` is patched: forward function was replaced, and we are actually in this patched function right in this code
-                        # if we pass `module` as-is to the user code below, and it happens to call forward it will lead to infinite recursion or fail
-                        # so we need to temporary patch the module back to the original forward and then return it back again
-                        # stash the current forward to be able to return it back
-                        patched_forward = m.forward
-                        # set original forward for the module
-                        m.forward = getattr(m, orig_forward_name)
-                        # call user code
-                        results = extension.evaluate(m, *Trampoline.stashed_args,
-                                                     **Trampoline.stashed_kwargs)
-                        m.forward = patched_forward  # return patched forward back
-                        return results
+                def forward(ctx, *args, **kwargs):
+                    # Temporarily restore the original forward function of `module` to avoid
+                    # recursion issues in `evaluate`, then revert it back.
+                    patched_forward = m.forward
+                    # set original forward for the module
+                    m.forward = getattr(m, orig_forward_name)
+                    # call user code
+                    results = extension.evaluate(m, *args, **kwargs)
+                    m.forward = patched_forward  # return patched forward back
+                    return results
 
             def new_forward(*args, **kwargs):
-                # use meta device to store args, to save memory
-                if use_meta:
-                    d = torch.device("meta")
-                    Trampoline.stashed_args = tuple(a.to(d) for a in args)
-                    Trampoline.stashed_kwargs = dict((k, v.to(d)) for k, v in kwargs.items())
-                else:
-                    Trampoline.stashed_args = args
-                    Trampoline.stashed_kwargs = kwargs
                 return extension.convert(m, Trampoline.apply, *args, **kwargs)
 
+            # make signature of new_forward same as of forward
+            new_forward = functools.wraps(m.forward)(new_forward)
             setattr(m, orig_forward_name, m.forward)
             m.forward = new_forward

@@ -106,36 +85,38 @@ def __make_16bit_traceable(model: torch.nn.Module):
     extensions = {
         torch.nn.Linear: ModuleExtension(
             torch.nn.Linear, "ov_ext::linear",
-            evaluate=lambda module, *args, **kwargs: torch.full(
-                list(args[0].shape[:-1]) + [module.out_features], 0.5, dtype=torch.float32),
             convert=lambda module, target_op, *args, **kwargs: target_op(args[0],
                                                                          module.weight,
-                                                                         module.bias)),
+                                                                         module.bias),
+            evaluate=lambda module, *args, **kwargs: torch.full(
+                list(args[0].shape[:-1]) + [module.out_features], 0.5, dtype=torch.float32)),
         torch.nn.Embedding: ModuleExtension(
             torch.nn.Embedding, "ov_ext::embedding",
-            evaluate=lambda module, *args, **kwargs: torch.full(
-                list(args[0].shape) + [module.embedding_dim], 0.5, dtype=torch.float32),
             convert=lambda module, target_op, *args, **kwargs: target_op(module.weight,
                                                                          args[0],
                                                                          module.padding_idx,
                                                                          module.scale_grad_by_freq,
-                                                                         module.sparse)),
+                                                                         module.sparse),
+            evaluate=lambda module, *args, **kwargs: torch.full(
+                list(args[1].shape) + [module.embedding_dim], 0.5, dtype=torch.float32)),
     }
     try:
         from transformers.pytorch_utils import Conv1D
         extensions[Conv1D] = ModuleExtension(
             Conv1D, "ov_ext::conv1d",
-            evaluate=lambda module, *args, **kwargs: torch.full(
-                list(args[0].shape[:-1]) + [module.nf], 0.5, dtype=torch.float32),
             convert=lambda module, target_op, *args, **kwargs: target_op(args[0],
                                                                          module.weight,
-                                                                         module.bias))
-    except:
+                                                                         module.bias),
+            evaluate=lambda module, *args, **kwargs: torch.full(
+                list(args[0].shape[:-1]) + [module.nf], 0.5, dtype=torch.float32))
+    except ImportError:
         pass
     patch_model(model, extensions,
-                "_openvino_module_extension_patch_orig_forward", use_meta=True)
+                "_openvino_module_extension_patch_orig_forward")
+    dtype_to_patch = [torch.float16, torch.bfloat16]
     for _, module in model.named_modules():
-        if module.__class__ not in extensions and (any(p.dtype in [torch.float16, torch.bfloat16] for p in module.parameters(False))
-                                                   or any(b.dtype in [torch.float16, torch.bfloat16] for b in module.buffers(False))):
+        if (module.__class__ not in extensions and
+                (any(p.dtype in dtype_to_patch for p in module.parameters(False))
+                 or any(b.dtype in dtype_to_patch for b in module.buffers(False)))):
             log.debug("Casting module %s to float32", module)
             module.float()
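A note on the reordered `ModuleExtension` stubs above: with the reworked `Trampoline.forward(ctx, *args, **kwargs)`, `evaluate` no longer sees stashed copies of the module's original forward arguments; it is called with exactly the positional arguments that `convert` passes to `target_op` (which is `Trampoline.apply`). That is why the `torch.nn.Embedding` stub now reads `list(args[1].shape)`: its `convert` forwards `module.weight` first and the indices tensor second. A minimal sketch of this contract, using a hypothetical `ScaleModule` and a made-up target op name `"ov_ext::scale"`:

```python
import torch
from openvino.frontend.pytorch import ModuleExtension

# Hypothetical module, used only to illustrate the evaluate/convert contract.
class ScaleModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = torch.nn.Parameter(torch.ones(8, dtype=torch.float16))

    def forward(self, x):
        return x * self.scale

ext = ModuleExtension(
    ScaleModule, "ov_ext::scale",
    # convert() decides what reaches the trampoline: input first, weight second.
    convert=lambda module, target_op, *args, **kwargs: target_op(args[0],
                                                                 module.scale),
    # evaluate() receives the same positional layout convert() produced, so
    # args[0] is the input tensor; the stub only needs to return a tensor of
    # the right shape and dtype for tracing to proceed.
    evaluate=lambda module, *args, **kwargs: torch.full(
        args[0].shape, 0.5, dtype=torch.float32))
```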
4 changes: 4 additions & 0 deletions tests/layer_tests/py_frontend_tests/test_torch_frontend.py
@@ -687,6 +687,7 @@ def test_patched_16bit_model_converts():
     from openvino.frontend.pytorch import patch_model
     from openvino import convert_model, compile_model
     import copy
+    import inspect
     from transformers.pytorch_utils import Conv1D
 
     class ModelWithLinear(torch.nn.Module):
@@ -716,6 +717,9 @@ def forward(self, x1, x2):
     model_fp16 = copy.deepcopy(model_ref).half()
 
     patch_model.__make_16bit_traceable(model_fp16)
+    # verify torch.nn.Linear signature after patching
+    signature = inspect.signature(model_fp16.branch1[0].forward).parameters
+    assert ["input"] == list(signature)
     # the approach with patching only works for node with no grad
     with torch.no_grad():
         converted_model = convert_model(model_fp16, example_input=example)
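The new `inspect.signature` assertion passes because `functools.wraps` copies the wrapped function's metadata onto the replacement forward, including the `__wrapped__` attribute that `inspect.signature` follows. A standalone sketch of that mechanism, independent of the patching code above:

```python
import functools
import inspect
import torch

linear = torch.nn.Linear(4, 4)

def new_forward(*args, **kwargs):
    # stand-in for the replacement forward installed by patch_model
    return linear(*args, **kwargs)

# copy metadata so the wrapper reports Linear.forward's original signature
new_forward = functools.wraps(linear.forward)(new_forward)

# inspect.signature follows __wrapped__ back to the original bound method
assert ["input"] == list(inspect.signature(new_forward).parameters)
```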