Revert "[dynamo] Forward OptimizedModule.__setattr__ to the wrapped m…
Browse files Browse the repository at this point in the history
…odule (pytorch#122098)"

This reverts commit b6982bf.

Reverted pytorch#122098 on behalf of https://github.com/atalman due to Failing internally ([comment](pytorch#122098 (comment)))
pytorchmergebot committed Mar 26, 2024
1 parent 537cd66 commit f631586
Showing 4 changed files with 37 additions and 52 deletions.
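
For context: the change being reverted made `OptimizedModule.__setattr__` forward attribute writes to the wrapped module (everything except a small whitelist of wrapper-internal attributes), so setting an attribute on a compiled module behaved the same as setting it on the original module. With the revert, writes land on the `OptimizedModule` wrapper again, while reads still fall through to the wrapped module via `__getattr__`. The sketch below illustrates the behavioral difference; the toy module and the `eager` backend are illustrative assumptions, not part of this commit.

```python
import torch


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)
        self.x = None  # plain (non-parameter) attribute

    def forward(self, inp):
        return self.linear(inp)


model = Toy()
compiled = torch.compile(model, backend="eager")

# Write a plain attribute through the compiled wrapper.
compiled.x = torch.zeros(2)

# With pytorch#122098 applied, the write was forwarded to `model`, so
# `model.x is compiled.x` was True (this is what the removed
# test_setattr_on_compiled_module exercised). After this revert, the value is
# stored on the OptimizedModule wrapper instead, so `model.x` stays None.
print(model.x is compiled.x)
```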
31 changes: 0 additions & 31 deletions test/dynamo/test_modules.py
@@ -1,7 +1,6 @@
 # Owner(s): ["module: dynamo"]

 import collections
-import copy
 import itertools
 import traceback
 import types
@@ -2378,36 +2377,6 @@ def generate(x, c):
         generate(torch.randn(10, 10), 0)
         self.assertEqual(cnt.frame_count, 3)

-    def test_setattr_on_compiled_module(self):
-        # https://github.com/pytorch/pytorch/issues/114844
-
-        class ReplayMutation(torch.nn.Module):
-            def __init__(self, inp_size, out_size, inner_size):
-                super().__init__()
-                self.Linear1 = torch.nn.Linear(inp_size, inner_size)
-                self.Linear2 = torch.nn.Linear(inner_size, out_size)
-                self.x = None
-
-            def forward(self, inp):
-                res = self.Linear1(inp)
-                self.x = res
-                return self.Linear2(res)
-
-        N, D_in, H, D_out, inner = 2, 2, 2, 2, 4
-        model = ReplayMutation(D_in, H, inner)
-        model2 = copy.deepcopy(model)
-        input = torch.ones(N, D_in)
-
-        # Keep some intermediate value in model.x
-        model.x = torch.tensor([[100, 100, 100, 100], [200, 200, 200, 200]])
-        model(input)
-
-        compiled_model = torch.compile(model2, backend="eager")
-        compiled_model.x = torch.tensor([[100, 100, 100, 100], [200, 200, 200, 200]])
-        compiled_model(input)
-
-        self.assertEqual(model.x, compiled_model.x)
-

 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
29 changes: 27 additions & 2 deletions torch/_dynamo/debug_utils.py
@@ -310,6 +310,8 @@ def run_fwd_maybe_bwd(gm, args, only_fwd=False, disable_clone=False):
     When disable_clone is True, we will use args as-is without cloning.
     This is higher fidelity but we may destroy the args in the process.
     """
+    from torch._functorch.aot_autograd import make_boxed_func
+
     from .testing import collect_results, reduce_to_scalar_loss, requires_bwd_pass

     gm = copy.deepcopy(gm)
@@ -319,9 +321,19 @@ def run_fwd_maybe_bwd(gm, args, only_fwd=False, disable_clone=False):
     if hasattr(gm, "zero_grad"):
         gm.zero_grad(True)

-    # TorchInductor returned callable expects lists. So, may need a boxed calling convention.
-    out = gm(args) if hasattr(gm, "_boxed_call") else gm(*args)
+    # TorchInductor returned callable expects lists. So, boxing the call.
+    orig_named_parameters = getattr(gm, "named_parameters", None)
+    orig_named_buffers = getattr(gm, "named_buffers", None)
+    if not hasattr(gm, "_boxed_call") and (
+        orig_named_parameters is not None or orig_named_buffers is not None
+    ):
+        gm = make_boxed_func(gm)
+        if orig_named_parameters is not None:
+            gm.named_parameters = orig_named_parameters
+        if orig_named_buffers is not None:
+            gm.named_buffers = orig_named_buffers

+    out = gm(args)
     if only_fwd:
         return out
     if requires_bwd_pass(out):
@@ -347,8 +359,21 @@ def same_two_models(
     is mostly useful for the minifier (which wants to avoid quantizing floating point
     error into integer/boolean error)
     """
+    from .eval_frame import OptimizedModule
+    from .testing import (
+        named_buffers_for_optimized_module,
+        named_parameters_for_optimized_module,
+    )
     from .utils import same

+    if isinstance(gm, OptimizedModule):
+        gm.named_parameters = named_parameters_for_optimized_module(gm)
+        gm.named_buffers = named_buffers_for_optimized_module(gm)
+
+    if isinstance(opt_gm, OptimizedModule):
+        opt_gm.named_parameters = named_parameters_for_optimized_module(opt_gm)
+        opt_gm.named_buffers = named_buffers_for_optimized_module(opt_gm)
+
     ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd)

     fp64_ref = None
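
The restored `run_fwd_maybe_bwd` path depends on the boxed calling convention used by TorchInductor-compiled callables: a boxed callable takes one list of arguments instead of positional arguments and marks itself with a `_boxed_call` attribute. Wrapping `gm` with `make_boxed_func` produces a new function object, which is why the original `named_parameters`/`named_buffers` are copied back onto it. Below is a rough sketch of the convention, assumed to mirror `torch._functorch.aot_autograd.make_boxed_func` rather than reproduce it exactly:

```python
from typing import Any, Callable, List


def make_boxed_func_sketch(f: Callable[..., Any]) -> Callable[[List[Any]], Any]:
    """Illustrative stand-in for make_boxed_func: box positional args into a list."""

    def boxed(args: List[Any]) -> Any:
        # The boxed convention: one list in, unpacked into the real call.
        return f(*args)

    boxed._boxed_call = True  # type: ignore[attr-defined]
    return boxed


def add(a: int, b: int) -> int:
    return a + b


boxed_add = make_boxed_func_sketch(add)
assert getattr(boxed_add, "_boxed_call", False)
assert boxed_add([1, 2]) == 3  # called with a list, matching `out = gm(args)`
```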
19 changes: 0 additions & 19 deletions torch/_dynamo/eval_frame.py
@@ -117,16 +117,6 @@ class OptimizedModule(torch.nn.Module):
     _torchdynamo_orig_callable: Callable[..., Any]
     get_compiler_config: Callable[[], Any]

-    _opt_mod_attributes = {
-        "_orig_mod",
-        "dynamo_ctx",
-        "_torchdynamo_orig_callable",
-        "get_compiler_config",
-        "forward",
-        "_forward",
-        "__dict__",
-    }
-
     def __init__(self, mod: torch.nn.Module, dynamo_ctx):
         super().__init__()
         # Installs the params/buffer
@@ -166,15 +156,6 @@ def __getattr__(self, name):
             return self._modules["_orig_mod"]
         return getattr(self._orig_mod, name)

-    def __setattr__(self, name, val):
-        # Allow patching over class attributes
-        if hasattr(type(self), name):
-            return super().__setattr__(name, val)
-
-        if name in OptimizedModule._opt_mod_attributes:
-            return super().__setattr__(name, val)
-        return setattr(self._orig_mod, name, val)
-
     def _call_lazy_check(self, *args, **kwargs):
         if hasattr(self._orig_mod, "_initialize_hook"):
             # In the case of a lazy module, we want to run
10 changes: 10 additions & 0 deletions torch/_dynamo/testing.py
@@ -43,6 +43,16 @@ def clone_me(x):
     return x.detach().clone().requires_grad_(x.requires_grad)


+def named_parameters_for_optimized_module(mod):
+    assert isinstance(mod, eval_frame.OptimizedModule)
+    return mod._orig_mod.named_parameters
+
+
+def named_buffers_for_optimized_module(mod):
+    assert isinstance(mod, eval_frame.OptimizedModule)
+    return mod._orig_mod.named_buffers
+
+
 def remove_optimized_module_prefix(name) -> str:
     return re.sub(r"^_orig_mod[.]", "", name)

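The two helpers restored in `testing.py` let `same_two_models` iterate parameters and buffers of the underlying module rather than the `OptimizedModule` wrapper (whose names would carry an `_orig_mod.` prefix). A small usage sketch; the toy `Linear` module and the `eager` backend are assumptions for illustration:

```python
import torch
from torch._dynamo.testing import (
    named_buffers_for_optimized_module,
    named_parameters_for_optimized_module,
)

model = torch.nn.Linear(4, 4)
opt_model = torch.compile(model, backend="eager")

# Patch the wrapper the same way same_two_models now does, so parameter and
# buffer iteration resolves against the wrapped module.
opt_model.named_parameters = named_parameters_for_optimized_module(opt_model)
opt_model.named_buffers = named_buffers_for_optimized_module(opt_model)

for name, param in opt_model.named_parameters():
    # Names come from the wrapped module, without the "_orig_mod." prefix.
    print(name, tuple(param.shape))
```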
