Revert "Proper view support for jagged layout NestedTensor (pytorch#1…
Browse files Browse the repository at this point in the history
…13279)"

This reverts commit 5855c49.

Reverted pytorch#113279 on behalf of https://github.com/jbschlosser due to Need to fix BC thing ([comment](pytorch#113279 (comment)))
pytorchmergebot committed Mar 21, 2024
1 parent 12e7602 commit 224beec
Showing 17 changed files with 245 additions and 533 deletions.
23 changes: 0 additions & 23 deletions aten/src/ATen/FunctionalInverses.cpp
@@ -303,29 +303,6 @@ Tensor FunctionalInverses::_nested_view_from_buffer_inverse(const Tensor& base,
return Tensor();
}

Tensor FunctionalInverses::_nested_view_from_jagged_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, const Tensor& offsets, const Tensor& dummy, const std::optional<Tensor>& lengths, int64_t ragged_idx) {
  auto values = at::_nested_get_values(mutated_view);
  if (inverse_return_mode != InverseReturnMode::NeverView) {
    return values;
  } else {
    return values.clone(/*memory_format=*/at::MemoryFormat::Contiguous);
  }
}

Tensor FunctionalInverses::_nested_get_values_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) {
  auto offsets = at::_nested_get_offsets(base);
  auto lengths = at::_nested_get_lengths(base);
  auto ragged_idx = at::_nested_get_ragged_idx(base);
  auto dummy = at::_nested_get_jagged_dummy(base);
  auto nt = at::_nested_view_from_jagged(mutated_view, offsets, dummy, lengths, ragged_idx);

  if (inverse_return_mode != InverseReturnMode::NeverView) {
    return nt;
  } else {
    return nt.clone(/*memory_format=*/at::MemoryFormat::Contiguous);
  }
}

Tensor FunctionalInverses::unsqueeze_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t dim) {
  if (inverse_return_mode != InverseReturnMode::NeverView) {
    return at::squeeze(mutated_view, dim);
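As background on the removed inverses: functionalization records every view op and needs an inverse that maps a mutated view back onto its base. For the jagged layout, the inverse of _nested_view_from_jagged recovers the underlying values buffer (cloning when a real view must not be returned), and the inverse of _nested_get_values rebuilds the jagged view from the base's offsets, lengths, and ragged_idx. A minimal Python sketch of that round trip, using plain tensors and hypothetical helper names rather than the removed aten ops:

import torch

# Hypothetical stand-in for _nested_view_from_jagged: model the jagged "view" as
# per-row slices of a shared values buffer, delimited by offsets.
def nested_view_from_jagged_sketch(values, offsets):
    return [values[offsets[i]:offsets[i + 1]] for i in range(len(offsets) - 1)]

# Hypothetical stand-in for _nested_view_from_jagged_inverse: recover the values
# buffer, cloning only when the caller must not receive a view (NeverView mode).
def jagged_view_inverse_sketch(rows, return_view=True):
    values = torch.cat(rows)  # stand-in for at::_nested_get_values
    return values if return_view else values.clone()

values = torch.randn(10, 5)
offsets = torch.tensor([0, 3, 6, 10])
rows = nested_view_from_jagged_sketch(values, offsets)
recovered = jagged_view_inverse_sketch(rows)
assert recovered.shape == values.shape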
46 changes: 0 additions & 46 deletions aten/src/ATen/native/native_functions.yaml
@@ -6158,52 +6158,6 @@
CompositeExplicitAutogradNonFunctional: _nested_view_from_buffer_copy
autogen: _nested_view_from_buffer_copy.out

- func: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1) -> Tensor(a)
  variants: function
  device_check: NoCheck
  dispatch: {}

- func: _nested_view_from_jagged_copy(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1) -> Tensor
  variants: function
  device_check: NoCheck
  tags: view_copy
  dispatch:
    CompositeExplicitAutogradNonFunctional: _nested_view_from_jagged_copy
  autogen: _nested_view_from_jagged_copy.out

- func: _nested_get_values(Tensor(a) self) -> Tensor(a)
  variants: function
  device_check: NoCheck
  dispatch: {}

- func: _nested_get_values_copy(Tensor self) -> Tensor
  variants: function
  device_check: NoCheck
  tags: view_copy
  dispatch:
    CompositeExplicitAutogradNonFunctional: _nested_get_values_copy
  autogen: _nested_get_values_copy.out

- func: _nested_get_offsets(Tensor self) -> Tensor
  variants: function
  device_check: NoCheck
  dispatch: {}

# returns undefined Tensor if no lengths present
- func: _nested_get_lengths(Tensor self) -> Tensor
  variants: function
  device_check: NoCheck
  dispatch: {}

- func: _nested_get_ragged_idx(Tensor self) -> int
  variants: function
  device_check: NoCheck
  dispatch: {}

- func: _nested_get_jagged_dummy(Tensor any) -> Tensor
  category_override: dummy
  dispatch: {}

- func: _trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1) -> Tensor
  dispatch:
    # calls unsqueeze
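As context for the schema entries removed above: each jagged view op was paired with a *_copy variant tagged view_copy (returning an independent tensor rather than an alias) plus an autogenerated .out overload. The same pairing exists for ordinary dense view ops; a minimal illustration, assuming only the documented aliasing behavior of view versus view_copy:

import torch

x = torch.arange(6).reshape(2, 3)

# The view op returns an alias that shares storage with its input...
alias = x.view(3, 2)
assert alias.data_ptr() == x.data_ptr()

# ...while the view_copy variant (tagged view_copy in native_functions.yaml)
# returns a tensor with its own storage and the same contents.
copy = torch.ops.aten.view_copy.default(x, [3, 2])
assert copy.data_ptr() != x.data_ptr()
assert torch.equal(copy, alias)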
186 changes: 78 additions & 108 deletions test/dynamo/test_subclasses.py
@@ -19,10 +19,10 @@
StatelessSymbolicContext,
)
from torch.nested._internal.nested_tensor import (
buffer_from_jagged,
jagged_from_list,
jagged_from_tensor_and_lengths,
nested_view_from_values_offsets,
NestedTensor,
ViewBufferFromNested,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
@@ -1273,20 +1273,19 @@ def _test_autograd(self, backend):
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64)
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64)
c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64)
nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
# TODO: Switch to public API when it exists
nt2, _ = jagged_from_list([a, b, c], nt.offsets())
nt, offsets = jagged_from_list([a, b, c], None)
nt2, _ = jagged_from_list([a, b, c], offsets)

def fn1(nt1, nt2):
return (nt1 + nt2).sin().cos()

compiled_f = torch.compile(fn1, fullgraph=True, backend=backend, dynamic=True)
out = compiled_f(nt, nt2)
out_buffer = out.values()
out_buffer = ViewBufferFromNested.apply(out)
ga, gb, gc = torch.autograd.grad(out_buffer.sum(), (a, b, c))

out_ref = fn1(nt, nt2)
out_buffer_ref = out_ref.values()
out_buffer_ref = ViewBufferFromNested.apply(out_ref)
ga_ref, gb_ref, gc_ref = torch.autograd.grad(out_buffer_ref.sum(), (a, b, c))

self.assertTrue(torch.allclose(ga, ga_ref))
@@ -1326,10 +1325,10 @@ def fn(x, y):
ret = fn_c(nt, y)[0]
ref = fn(nt_copy, y_copy)[0]

self.assertEqual(ret.values(), ref.values())
self.assertEqual(buffer_from_jagged(ret), buffer_from_jagged(ref))

ret.values().sum().backward()
ref.values().sum().backward()
buffer_from_jagged(ret).sum().backward()
buffer_from_jagged(ref).sum().backward()
for ref_v, res_v in zip(values_copy, values):
self.assertEqual(ref_v.grad, res_v.grad)

@@ -1362,112 +1361,83 @@ def fn(x):
self._check_recompiles(fn, (nt,), (nt3,), True)

def _get_views(self):
# Test all cases with both an NT base and a dense base
# Subclass -> Subclass
# Dense -> Subclass
for base_is_nt in [False, True]:
# There are three cases to consider here based on the logic in
# meta_utils.py
#
# (1) basic case:
# view is not a leaf and has the same requires grad as its basic case
x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)
x = x.clone() if base_is_nt else x
self.assertEqual(x.is_leaf, False)
yield x.unsqueeze(-1)

# (2) leaf view case:
# the view has to be a leaf (w/ requires_grad True or requires_grad False)
# base w/ requires_grad True or requires_grad False
for requires_grad_1, requires_grad_2 in itertools.product(
[True, False], repeat=2
):
x, _ = self._get_jagged_tensor(
((2, 3, 4), 3), None, requires_grad=requires_grad_1
)
x = x.clone() if base_is_nt else x
with torch.no_grad():
x_view = x.unsqueeze(-1)
# The issue is this doesn't quite work
x_view.requires_grad_(requires_grad_2)
yield x_view

# (3) obscure case:
# view is not a leaf (implies requires_grad True)
# base w/ requires_grad False)
x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=False)
x = x.clone() if base_is_nt else x
# intermediate leaf view
# There are three cases to consider here based on the logic in
# meta_utils.py
#
# (1) basic case:
# view is not a leaf and has the same requires grad as its basic case
x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)
self.assertEqual(x.is_leaf, False)
yield x.unsqueeze(-1)

# (2) leaf view case:
# the view has to be a leaf (w/ requires_grad True or requires_grad False)
# base w/ requires_grad True or requires_grad False
for requires_grad_1, requires_grad_2 in itertools.product(
[True, False], repeat=2
):
x, _ = self._get_jagged_tensor(
((2, 3, 4), 3), None, requires_grad=requires_grad_1
)
with torch.no_grad():
x_view = x.unsqueeze(-1)
x_view.requires_grad_(True)
x_view_view = x_view.unsqueeze(-1)
yield x_view_view

# Subclass -> Dense
x = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone()
yield x.values()

# Dense -> Subclass -> Dense -> Subclass
values = torch.randn(10, 5)
offsets = torch.tensor([0, 3, 6, 10])
offsets2 = offsets.clone().detach()
yield nested_view_from_values_offsets(
nested_view_from_values_offsets(values, offsets).values(), offsets
)

def _input_view_test(self, nt_view):
def fn(x):
return x.sin()
# The issue is this doesn't quite work
x_view.requires_grad_(requires_grad_2)
yield x_view

# (3) obscure case:
# view is not a leaf (implies requires_grad True)
# base w/ requires_grad False)
x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=False)
# intermediate leaf view
with torch.no_grad():
x_view = x.unsqueeze(-1)
x_view.requires_grad_(True)
x_view_view = x_view.unsqueeze(-1)
yield x_view_view

out_ref = fn(nt_view)
torch._dynamo.reset()
compile_fn = torch.compile(
fn, fullgraph=True, backend="aot_eager", dynamic=True
)
out = compile_fn(nt_view)
def test_inputs_to_compiled_fn_are_views(self):
for nt_view in self._get_views():

# Check metadata and values are correct
self.assertTrue(out.size() == out_ref.size())
self.assertTrue(out.stride() == out_ref.stride())
if out.is_nested:
self.assertTrue(torch.allclose(out.values(), out_ref.values()))
else:
self.assertTrue(torch.allclose(out, out_ref))
def fn(x):
return x.sin()

# Check that no upper/lower bound guards are incurred
def backend(gm, args):
context = torch._guards.TracingContext.get()
guards = [str(g.expr) for g in context.fake_mode.shape_env.guards]
out_ref = fn(nt_view)
torch._dynamo.reset()
compile_fn = torch.compile(
fn, fullgraph=True, backend="aot_eager", dynamic=True
)
out = compile_fn(nt_view)

# varies based on the type of view
guard_str = "\n".join(guards)
if isinstance(nt_view._base, NestedTensor):
self.assertExpectedInline(guard_str, """Eq(s3 - 1, s0)""")
else:
self.assertExpectedInline(guard_str, """""")
return gm
# Check metadata and values are correct
self.assertTrue(out.size() == out_ref.size())
self.assertTrue(out.stride() == out_ref.stride())
self.assertTrue(torch.allclose(out.values(), out_ref.values()))

torch._dynamo.reset()
compile_fn = torch.compile(fn, fullgraph=True, backend=backend, dynamic=True)
out = compile_fn(nt_view)
# Check that no upper/lower bound guards are incurred
def backend(gm, args):
context = torch._guards.TracingContext.get()
guards = [str(g.expr) for g in context.fake_mode.shape_env.guards]
ranges = [
f"{s}: [{vr.lower}, {vr.upper}]"
for s, vr in context.fake_mode.shape_env.var_to_range.items()
]
self.assertExpectedInline("\n".join(guards), """Eq(s3 - 1, s0)""")
self.assertExpectedInline(
"\n".join(ranges),
"""\
s0: [2, 9223372036854775805]
s2: [2, 9223372036854775806]
s3: [3, 9223372036854775806]
s5: [2, 9223372036854775806]""",
)
return gm

def test_inputs_to_compiled_fn_are_views(self):
for nt_view in self._get_views():
self._input_view_test(nt_view)

# NJT1 -> Dense -> NJT2 -> Dense view
# During view replay, the Dense -> NJT2 part will construct an intermediate,
# symbolically-sized NJT that is immediately deconstructed to return the final dense
# view. To construct this intermediate properly, we need the associated nested int
# to be symbolic. This view is expected to fail compilation until symbolic nested ints
# are cached onto fake offsets to solve this problem.
@unittest.expectedFailure
def test_subclass_dense_subclass_dense_view(self):
x = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone()
offsets2 = x.offsets().clone().detach()
nt_view = nested_view_from_values_offsets(x.values(), offsets2).values()
self._input_view_test(nt_view)
torch._dynamo.reset()
compile_fn = torch.compile(
fn, fullgraph=True, backend=backend, dynamic=True
)
out = compile_fn(nt_view)


if __name__ == "__main__":
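For readers following the three view cases enumerated in _get_views above: case (2) builds a leaf view by taking the view under no_grad and then enabling requires_grad on it. A minimal dense-tensor sketch of that construction; the assertions reflect standard autograd leaf semantics rather than anything NestedTensor-specific:

import torch

base = torch.randn(2, 3, requires_grad=True)

# Creating the view under no_grad keeps it out of the autograd graph, so it has
# no grad_fn and can become a leaf with its own requires_grad setting.
with torch.no_grad():
    leaf_view = base.unsqueeze(-1)
leaf_view.requires_grad_(True)

assert leaf_view.grad_fn is None
assert leaf_view.is_leaf and leaf_view.requires_grad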
10 changes: 0 additions & 10 deletions test/expect/HasDecompTest.test_has_decomposition.expect
@@ -437,13 +437,6 @@ aten::_nested_from_padded
aten::_nested_from_padded.out
aten::_nested_from_padded_and_nested_example
aten::_nested_from_padded_and_nested_example.out
aten::_nested_get_jagged_dummy
aten::_nested_get_lengths
aten::_nested_get_offsets
aten::_nested_get_ragged_idx
aten::_nested_get_values
aten::_nested_get_values_copy
aten::_nested_get_values_copy.out
aten::_nested_select_backward
aten::_nested_sum_backward
aten::_nested_tensor_from_mask
@@ -461,9 +454,6 @@ aten::_nested_tensor_strides.out
aten::_nested_view_from_buffer
aten::_nested_view_from_buffer_copy
aten::_nested_view_from_buffer_copy.out
aten::_nested_view_from_jagged
aten::_nested_view_from_jagged_copy
aten::_nested_view_from_jagged_copy.out
aten::_new_zeros_with_same_feature_meta
aten::_new_zeros_with_same_feature_meta.out
aten::_nnpack_spatial_convolution
34 changes: 2 additions & 32 deletions test/test_autograd.py
@@ -8822,7 +8822,7 @@ def _assert_match_metadata(a, b):
self.assertEqual(a.device, b.device)
self.assertEqual(a.dtype, b.dtype)

def _test_fn(fn, inp, *args, use_unsafe_view_func=False):
def _test_fn(fn, inp, *args):
outs = fn(inp, *args)
# handle functions that return multiple views (e.g. split)
if isinstance(outs, torch.Tensor):
@@ -8835,10 +8835,7 @@ def _test_fn(fn, inp, *args, use_unsafe_view_func=False):
# forward view_func
new_inp = inp.clone()
_assert_match_metadata(new_inp, inp)
if use_unsafe_view_func:
new_out = out._view_func_unsafe(new_inp)
else:
new_out = out._view_func(new_inp)
new_out = out._view_func(new_inp)
_assert_match_metadata(new_out, out)
self.assertEqual(new_out, out)

@@ -8904,33 +8901,6 @@ def chain_with_only_current_view_func(x):

_test_fn(chain_with_only_current_view_func, torch.randn(2, 3, 4))

# TODO: Move this somewhere else
# test NT views
from torch.nested._internal.nested_tensor import nested_view_from_values_offsets

values = torch.randn(10, 5)
offsets = torch.tensor([0, 3, 6, 10])
_test_fn(nested_view_from_values_offsets, values, offsets)

nt = nested_view_from_values_offsets(values, offsets).clone().detach()
_test_fn(torch.ops.aten._nested_get_values.default, nt, use_unsafe_view_func=True)

def chain_nt_to_dense_back_and_forth(nt):
# NJT1 -> dense -> NJT2 -> dense
offsets2 = nt.offsets().clone().detach()
return nested_view_from_values_offsets(nt.values(), offsets2).values()

_test_fn(chain_nt_to_dense_back_and_forth, nt, use_unsafe_view_func=True)

def chain_dense_to_nt_back_and_forth(values, offsets):
offsets2 = offsets.clone().detach()
# dense -> NJT1 -> dense -> NJT2
return nested_view_from_values_offsets(
nested_view_from_values_offsets(values, offsets).values(),
offsets2)

_test_fn(chain_dense_to_nt_back_and_forth, values, offsets, use_unsafe_view_func=True)

def test_view_func_replay_with_modified_state(self):
with torch.autograd._force_original_view_tracking(True):
base = torch.randn(3, 4, 5)
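The _test_fn helper above replays a view's recorded view function onto a fresh base via Tensor._view_func. A minimal dense-tensor sketch of that replay; _view_func is a private API, and _force_original_view_tracking (used the same way in test_view_func_replay_with_modified_state above) ensures the view function is recorded:

import torch

with torch.autograd._force_original_view_tracking(True):
    base = torch.randn(2, 3, 4)
    view = base.transpose(0, 1)

# Replaying the recorded view func applies the same view op to a new base.
new_base = base.clone()
replayed = view._view_func(new_base)
assert replayed.shape == view.shape
assert replayed.stride() == view.stride()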
