diff --git a/test/test_autograd.py b/test/test_autograd.py
index 445a35ace44e2..043ac4e501141 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -1718,6 +1718,24 @@ def test_retain_grad_inplace(self):
         a.sum().backward()
         self.assertEqual(a.grad, torch.tensor([1.]))
 
+        # When in-place over view is done, the retains_grad hooks should be
+        # moved from base's original grad_fn to the copyslices node.
+        x = torch.tensor([1.], requires_grad=True).clone()
+        x.retain_grad()
+        x_view = x[:]
+        x_view *= 2
+        x *= 2
+        x.sum().backward()
+        # The grad is 1, not 4, because we are computing grad wrt the latest
+        # version of x.
+        self.assertEqual(x.grad, torch.tensor([1.]))
+
+        # If the base did not originally require grad, there should be no hook
+        # to move. Make sure this case runs without error.
+        x = torch.zeros(4)
+        y = x.view(2, 2)
+        y.add_(torch.randn(2, 2, requires_grad=True))
+
     def test_retains_grad_inplace_multiple_outputs(self):
         class DoubleMul(Function):
             @staticmethod
diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp
index 81c0f19289943..821eea07c4b7d 100644
--- a/torch/csrc/autograd/variable.cpp
+++ b/torch/csrc/autograd/variable.cpp
@@ -239,6 +239,11 @@ void rebase_history(const Variable& self, Edge gradient_edge) {
         at::TensorGeometry(self),
         view_info.view_fn_,
         std::move(gradient_edge.function));
+    if (self.requires_grad()) {
+      // If self did not previously require grad, there are no hooks to move
+      torch::autograd::impl::update_tensor_hooks_on_new_gradfn(
+          view_info.base_, view_info.base_.grad_fn(), copy_slices);
+    }
     set_gradient_edge(view_info.base_, {std::move(copy_slices), 0});
     self.grad_fn(); // trigger an update to the view's grad_fn
     return;
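
For context, the variable.cpp hunk makes rebase_history move the base tensor's retains_grad hooks from its previous grad_fn onto the newly created CopySlices node when an in-place op on a view rewrites the base's history; the requires_grad() guard skips the move when self did not previously require grad and therefore has no hooks. Below is a minimal standalone Python sketch of the user-visible behavior the new test exercises; the tensor values are illustrative and not part of the diff.

import torch

# Base tensor whose history gets rebased by an in-place op on its view.
base = torch.tensor([1., 2.], requires_grad=True).clone()
base.retain_grad()     # registers a retains_grad hook on base's current grad_fn
view = base[:]         # view shares storage with base
view *= 3              # rebases base's grad_fn to CopySlices; the hook must move with it
base.sum().backward()
# The grad is computed wrt the latest version of base, so it is all ones.
print(base.grad)       # tensor([1., 1.])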