Skip to content

Commit

Permalink
Properly move retains_grad hook on in-place over view for base (pytor…
Browse files Browse the repository at this point in the history
…ch#117552)

Fixes pytorch#117366
Pull Request resolved: pytorch#117552
Approved by: https://github.com/albanD
  • Loading branch information
soulitzer authored and pytorchmergebot committed Jan 25, 2024
1 parent 9c1348f commit 5b819d9
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
18 changes: 18 additions & 0 deletions test/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1718,6 +1718,24 @@ def test_retain_grad_inplace(self):
a.sum().backward()
self.assertEqual(a.grad, torch.tensor([1.]))

# When in-place over view is done, the retains_grad hooks should be
# moved from base's original grad_fn to the copyslices node.
x = torch.tensor([1.], requires_grad=True).clone()
x.retain_grad()
x_view = x[:]
x_view *= 2
x *= 2
x.sum().backward()
# The grad is 1, not 4, because we are computing grad wrt the latest
# version of x.
self.assertEqual(a.grad, torch.tensor([1.]))

# If the base did not originally require grad, there should be no hook
# to move. Make sure this case runs without error.
x = torch.zeros(4)
y = x.view(2, 2)
y.add_(torch.randn(2, 2, requires_grad=True))

def test_retains_grad_inplace_multiple_outputs(self):
class DoubleMul(Function):
@staticmethod
Expand Down
5 changes: 5 additions & 0 deletions torch/csrc/autograd/variable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,11 @@ void rebase_history(const Variable& self, Edge gradient_edge) {
at::TensorGeometry(self),
view_info.view_fn_,
std::move(gradient_edge.function));
if (self.requires_grad()) {
// If self did not previously require grad, there are no hooks to move
torch::autograd::impl::update_tensor_hooks_on_new_gradfn(
view_info.base_, view_info.base_.grad_fn(), copy_slices);
}
set_gradient_edge(view_info.base_, {std::move(copy_slices), 0});
self.grad_fn(); // trigger an update to the view's grad_fn
return;
Expand Down

0 comments on commit 5b819d9

Please sign in to comment.