[Dynamo] Handle guard_size_oblivious in user code (pytorch#120379)
Fixes pytorch#120083

Signed-off-by: Edward Z. Yang <[email protected]>

Pull Request resolved: pytorch#120379
Approved by: https://github.com/yanboliang
ezyang authored and pytorchmergebot committed Feb 23, 2024
1 parent a5548c6 commit edf1c4e
Showing 3 changed files with 28 additions and 0 deletions.
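
For context: guard_size_oblivious evaluates a symbolic boolean under size-oblivious semantics, i.e. when the expression involves an unbacked (data-dependent) size, the size is assumed not to be 0 or 1 instead of raising a data-dependent guard error. A minimal sketch of the user-code pattern this commit enables, mirroring the new test below (backend and config flag as in that test):

import torch
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

torch._dynamo.config.capture_scalar_outputs = True  # needed so x.item() is traced

@torch.compile(backend="eager", fullgraph=True)
def fn(x):
    y = torch.zeros(x.item())                  # y.size(0) is an unbacked SymInt
    if guard_size_oblivious(y.size(0) == 0):   # resolved at trace time, no guard
        return y + 1
    return y

print(fn(torch.tensor([3])))                   # tensor([0., 0., 0.])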
13 changes: 13 additions & 0 deletions test/dynamo/test_misc.py
@@ -60,6 +60,7 @@
     constrain_unify,
     ConstraintViolationError,
     expect_true,
+    guard_size_oblivious,
     ShapeEnv,
 )
 from torch.nn import functional as F
@@ -9515,6 +9516,18 @@ def fn(x):
         c2 = _debug_get_cache_entry_list(fn.__code__)
         self.assertEqual(len(c2), 0)
 
+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
+    def test_guard_size_oblivious(self):
+        # This code, in fact, does NOT work in eager
+        @torch.compile(backend="eager", fullgraph=True)
+        def fn(x):
+            y = torch.zeros(x.item())
+            if guard_size_oblivious(y.size(0) == 0):
+                assert False
+            return y
+
+        self.assertEqual(fn(torch.tensor([0])), torch.zeros(0))
+
     @unittest.skipIf(not TEST_CUDA, "requires cuda")
     def test_module_free(self):
         """Test that CUDA memory is freed when a model goes out of scope"""
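
Why the test comment says this code does NOT work in eager: outside of compilation, y.size(0) == 0 is a plain Python bool, and guard_size_oblivious passes plain bools through, so for an input of tensor([0]) the branch is taken and the assert fires. Under torch.compile, the size-oblivious evaluation treats the unbacked size as nonzero, the branch is dropped at trace time, and the function returns torch.zeros(0). An eager repro of the failure, as a sketch:

import torch
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

def fn(x):
    y = torch.zeros(x.item())
    if guard_size_oblivious(y.size(0) == 0):  # plain bool True in eager
        assert False
    return y

fn(torch.tensor([0]))  # AssertionError: the size really is 0 in eager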
1 change: 1 addition & 0 deletions torch/_dynamo/trace_rules.py
@@ -189,6 +189,7 @@
     "torch._C._functorch._remove_batch_dim": TorchInGraphFunctionVariable,
     "torch._C._functorch.is_batchedtensor": TorchInGraphFunctionVariable,
     "torch._dynamo.mark_static": UserFunctionVariable,
+    "torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable,
 }

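This manual trace-rules table maps a function's fully qualified name to the VariableTracker class Dynamo builds when user code calls it; TorchInGraphFunctionVariable routes the call to the special handling added in torch.py below instead of inlining the function's Python body. The key is just the dotted path of the function; a small sketch showing how it lines up:

import torch
fn = torch.fx.experimental.symbolic_shapes.guard_size_oblivious
print(f"{fn.__module__}.{fn.__qualname__}")
# torch.fx.experimental.symbolic_shapes.guard_size_oblivious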
14 changes: 14 additions & 0 deletions torch/_dynamo/variables/torch.py
@@ -569,6 +569,20 @@ def fn_with_prim_types(x):
             raise unimplemented(
                 "torch.nn.functional.one_hot with data-dependent output shape"
             )
+        elif (
+            self.value is torch.fx.experimental.symbolic_shapes.guard_size_oblivious
+            and len(args) == 1
+            and isinstance(args[0], SymNodeVariable)
+        ):
+            # TODO: this probably should be folded somewhere else but I'm not
+            # sure where
+            # TODO: some of the other symbolic_shapes special tools can also
+            # get this treatment too
+            (cond,) = args
+            return variables.ConstantVariable.create(
+                torch.fx.experimental.symbolic_shapes.guard_size_oblivious(cond.sym_num)
+            )
+
         else:
             any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args)
             all_ints_or_floats = all(
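
The new branch constant-folds the call during tracing: cond.sym_num is the underlying SymBool, the real guard_size_oblivious is evaluated on it once at trace time, and the answer is embedded as a ConstantVariable, so no node for the call appears in the traced graph. The fold can be reproduced standalone at the ShapeEnv level; a sketch, assuming create_unbacked_symint and torch._check_is_size behave as in this era of the codebase:

import torch
from torch.fx.experimental.symbolic_shapes import ShapeEnv, guard_size_oblivious

shape_env = ShapeEnv()
u = shape_env.create_unbacked_symint()  # unbacked SymInt with no hint
torch._check_is_size(u)                 # mark the symbol as size-like
print(guard_size_oblivious(u == 0))     # False: size-oblivious assumes u is not 0 or 1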
