Skip to content

Commit

Permalink
[inductor] Do variance calculation in opmath type (pytorch#115181)
Browse files Browse the repository at this point in the history
Fixes pytorch#114903

Previously large split variance reductions stored the intermediates as float16
precision, which may lead to overflow as the intermediate result is
unnormalized.

In pytorch#114903 we see two different `num_split` decisions made based on the
hardware capabilities, one of which has large enough intermediates to cause
overflows.

Pull Request resolved: pytorch#115181
Approved by: https://github.com/shunting314
  • Loading branch information
peterbell10 authored and pytorchmergebot committed Dec 13, 2023
1 parent 95de4f5 commit 42390a0
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 14 deletions.
13 changes: 13 additions & 0 deletions test/inductor/test_torchinductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,6 +1102,19 @@ def fn(a):
self.common(fn, ((torch.rand((10, 3, 352, 352), dtype=torch.float32),)))
self.common(fn, ((torch.rand((14923), dtype=torch.float32),)))

def test_multilayer_var_lowp(self):
if self.device == "cpu" and IS_MACOS and not IS_X86:
atol, rtol = 1e-5, 5e-3
else:
atol, rtol = None, None

def fn(a):
return torch.var(a)

run_test = functools.partial(self.common, atol=atol, rtol=rtol)
run_test(fn, (torch.rand((16, 16, 352, 352), dtype=torch.float16),))
run_test(fn, (torch.rand((14923), dtype=torch.float16),))

def test_embedding_bag_byte_unpack(self):
if self.device != "cpu":
raise unittest.SkipTest("No CUDA implementation (it returns empty)")
Expand Down
39 changes: 25 additions & 14 deletions torch/_inductor/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
dtype_to_type,
elementwise_dtypes,
ELEMENTWISE_TYPE_PROMOTION_KIND,
get_computation_dtype,
is_boolean_dtype,
is_float_dtype,
is_integer_dtype,
Expand Down Expand Up @@ -4572,7 +4573,7 @@ def var_mean_sum_(x, axis, correction, keepdim, return_mean):
denom = ExpandView.create(denom, list(sum_result.get_size()))
x_var = div(sum_result, denom)
if not return_mean:
return x_var
return (x_var,)

x_mean = x_mean if keepdim else squeeze(x_mean, axis)
return x_var, x_mean
Expand Down Expand Up @@ -4634,29 +4635,39 @@ def scale_fn(data):
if return_mean:
mean.realize()
return var, mean
return var
return (var,)


def var_mean_helper_(x, *, axis, correction, keepdim, return_mean):
out_dtype = x.get_dtype()
compute_dtype = get_computation_dtype(out_dtype)
x = to_dtype(x, compute_dtype, copy=False)
kwargs = dict(
x=x,
axis=axis,
correction=correction,
keepdim=keepdim,
return_mean=return_mean,
)
output = (
var_mean_sum_(**kwargs)
if use_two_step_variance(x, axis=axis, keepdim=keepdim)
else var_mean_welford_(**kwargs)
)
output = tuple(to_dtype(x, out_dtype, copy=False) for x in output)
return output[0] if not return_mean else output


@register_lowering([aten.var, prims.var])
def var_(x, axis=None, *, correction=None, keepdim=False):
if use_two_step_variance(x, axis=axis, keepdim=keepdim):
return var_mean_sum_(
x, axis=axis, correction=correction, keepdim=keepdim, return_mean=False
)

return var_mean_welford_(
return var_mean_helper_(
x, axis=axis, correction=correction, keepdim=keepdim, return_mean=False
)


@register_lowering(aten.var_mean)
def var_mean(x, axis=None, *, correction=None, keepdim=False):
if use_two_step_variance(x, axis=axis, keepdim=keepdim):
return var_mean_sum_(
x, axis=axis, correction=correction, keepdim=keepdim, return_mean=True
)

return var_mean_welford_(
return var_mean_helper_(
x, axis=axis, correction=correction, keepdim=keepdim, return_mean=True
)

Expand Down

0 comments on commit 42390a0

Please sign in to comment.