Skip to content

Commit

Permalink
Convert SymInts to SymFloats with SymPy (pytorch#113683)
Browse files Browse the repository at this point in the history
  • Loading branch information
isuruf authored and pytorchmergebot committed Nov 20, 2023
1 parent 4182092 commit e4a88d9
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 10 deletions.
8 changes: 8 additions & 0 deletions test/dynamo/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1252,6 +1252,14 @@ def test_partials_lambda(x):
triple = functools.partial(multiply, y=3)
return triple(x)

def test_pow_int(self):
def fn(a, b):
return torch.pow(a, b)

x = torch.ones(2, 2)
opt_fn = torch.compile(fullgraph=True, backend="eager", dynamic=True)(fn)
self.assertEqual(opt_fn(x, 2), fn(x, 2))

def test_tensor_size_indexed_by_symint(self):
def fn(x, y):
index = x.shape[-1]
Expand Down
2 changes: 1 addition & 1 deletion test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def get_opoverloadpacket_from_dispatch(kernel):
def test_numpy_ref(self, device, dtype, op):
if (
TEST_WITH_TORCHINDUCTOR and
op.formatted_name == 'signal_windows_exponential' and
op.formatted_name in ('signal_windows_exponential', 'signal_windows_bartlett') and
dtype == torch.float64 and 'cuda' in device
): # noqa: E121
raise unittest.SkipTest("XXX: raises tensor-likes are not close.")
Expand Down
14 changes: 12 additions & 2 deletions torch/fx/experimental/sym_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,9 @@ def guard_int(self, file, line):
def guard_float(self, file, line):
# TODO: use the file/line for some useful diagnostic on why a
# guard occurred
r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
r = self.shape_env.evaluate_expr(
self.expr, self.hint, fx_node=self.fx_node, expect_rational=False
)
try:
return float(r)
except Exception:
Expand Down Expand Up @@ -628,6 +630,14 @@ def _sympy_abs(a):
return sympy.Abs(a)


def _sympy_sym_float(a):
# Cannot use sympy.Float(a) here, coz it expects python literals
# Multiply by 1.0 to cast to float. This is needed when the input
# is a SymInt which has the assumption that it is integer and
# SymPy will otherwise assume that return value cannot be a float.
return a * 1.0


magic_methods = {
**reflectable_magic_methods,
"sym_not": lambda a: ~a,
Expand All @@ -638,7 +648,7 @@ def _sympy_abs(a):
"le": _sympy_le,
"ge": _sympy_ge,
"floor": _sympy_floor,
"sym_float": lambda a: a, # Cannot use sympy.Float(a) here, coz it expects python literals
"sym_float": _sympy_sym_float,
"ceil": _sympy_ceil,
"neg": lambda a: -a,
"sym_min": _sympy_min,
Expand Down
17 changes: 10 additions & 7 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3029,7 +3029,8 @@ def get_shape_groups(self):

@_lru_cache
def _maybe_evaluate_static(
self, expr: "sympy.Expr", *, unbacked_only: bool = False, compute_hint: bool = False
self, expr: "sympy.Expr", *, unbacked_only: bool = False, compute_hint: bool = False,
expect_rational=True,
) -> "Optional[sympy.Expr]":
"""
Tries to evaluate expr without introducing guards
Expand Down Expand Up @@ -3121,10 +3122,10 @@ def replace(expr, repl):

# Check if the range can solve it statically
out = bound_sympy(new_expr, new_range_env)
_assert_bound_is_rational(new_expr, out)

if out.is_singleton():
return out.lower
if expect_rational:
_assert_bound_is_rational(new_expr, out)
if out.is_singleton():
return out.lower

return new_expr if unbacked_only else None

Expand Down Expand Up @@ -3450,7 +3451,8 @@ def _log_guard(self, prefix: str, g):

@lru_cache(256)
@record_shapeenv_event(save_tracked_fakes=True)
def evaluate_expr(self, orig_expr: "sympy.Expr", hint=None, fx_node=None):
def evaluate_expr(self, orig_expr: "sympy.Expr", hint=None, fx_node=None,
expect_rational=True):
"""
Given an expression, evaluates it, adding guards if necessary
"""
Expand Down Expand Up @@ -3510,7 +3512,8 @@ def evaluate_expr(self, orig_expr: "sympy.Expr", hint=None, fx_node=None):

expr = orig_expr

static_expr = self._maybe_evaluate_static(expr)
static_expr = self._maybe_evaluate_static(expr,
expect_rational=expect_rational)
if static_expr is not None:
self.log.debug("eval %s == %s [statically known]", orig_expr, static_expr)
# NB: don't test float as there may be precision issues
Expand Down

0 comments on commit e4a88d9

Please sign in to comment.