From fa77c1a98e77502d3715aaf6234d4a9b01f0c164 Mon Sep 17 00:00:00 2001
From: Xuehai Pan
Date: Tue, 21 May 2024 14:32:52 +0000
Subject: [PATCH] chore(optim): wrap `torch.autograd.grad()` with `torch.enable_grad()` context

---
 torchopt/optim/func/base.py |  6 ++++--
 torchopt/optim/meta/base.py | 22 +++++++++-------------
 2 files changed, 13 insertions(+), 15 deletions(-)

diff --git a/torchopt/optim/func/base.py b/torchopt/optim/func/base.py
index 7bb27877..fede4c2c 100644
--- a/torchopt/optim/func/base.py
+++ b/torchopt/optim/func/base.py
@@ -87,8 +87,10 @@ def step(
         if inplace is None:
             inplace = self.inplace

-        # Step parameter only
-        grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)
+        with torch.enable_grad():
+            # Step parameters only
+            grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)
+
         updates, self.optim_state = self.impl.update(
             grads,
             self.optim_state,
diff --git a/torchopt/optim/meta/base.py b/torchopt/optim/meta/base.py
index 73ecdde7..1fa0b875 100644
--- a/torchopt/optim/meta/base.py
+++ b/torchopt/optim/meta/base.py
@@ -72,26 +72,22 @@ def step(self, loss: torch.Tensor) -> None:  # pylint: disable=too-many-locals
         ):
             flat_params: TupleOfTensors
             flat_params, container_treespec = pytree.tree_flatten_as_tuple(param_container)  # type: ignore[arg-type]
+
             if isinstance(state, UninitializedState):
                 state = self.impl.init(flat_params)
-            grads = torch.autograd.grad(
-                loss,
-                flat_params,
-                create_graph=True,
-                allow_unused=True,
-            )
-            updates, new_state = self.impl.update(
-                grads,
-                state,
-                params=flat_params,
-                inplace=False,
-            )
-            self.state_groups[i] = new_state
+
+            with torch.enable_grad():
+                grads = torch.autograd.grad(loss, flat_params, create_graph=True, allow_unused=True)
+
+            updates, new_state = self.impl.update(grads, state, params=flat_params, inplace=False)
+
             flat_new_params = apply_updates(flat_params, updates, inplace=False)
             new_params: ModuleTensorContainers = pytree.tree_unflatten(  # type: ignore[assignment]
                 container_treespec,
                 flat_new_params,
             )
+
+            self.state_groups[i] = new_state
             for container, new_param in zip(param_container, new_params):
                 container.update(new_param)
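
A note on the behavior this change relies on: `torch.enable_grad()` re-enables
gradient tracking locally even when the caller has disabled it, e.g. with
`torch.no_grad()`. Below is a minimal standalone sketch (not part of the diff)
of that behavior, assuming the goal is to keep the differentiable gradient
computation working regardless of the caller's grad mode:

    import torch

    x = torch.tensor(1.5, requires_grad=True)

    with torch.no_grad():  # caller disables grad tracking, e.g. in an evaluation loop
        y_no_grad = x * 2
        with torch.enable_grad():  # what the patch wraps around torch.autograd.grad()
            y_grad = x * 2

    print(y_no_grad.requires_grad)  # False: the multiply was not recorded
    print(y_grad.requires_grad)     # True: tracking is back on inside enable_grad()

Wrapping only the `torch.autograd.grad()` call keeps the rest of `step()`
unchanged while making the gradient computation independent of the surrounding
grad mode.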