added forward gradients
inikishev committed Dec 25, 2024
1 parent e8579a8 commit c0b0e19
Showing 2 changed files with 77 additions and 14 deletions.
63 changes: 63 additions & 0 deletions src/torchzero/optim/first_order/forward_gradient.py
@@ -0,0 +1,63 @@
import torch
from ..modular import Modular
from ...core import OptimizationState, OptimizerModule
from ...tensorlist import Distributions

class ForwardGradientsModular(Modular):
    """EXPERIMENTAL (WILL TEST TOMORROW).

    Evaluates the Jacobian-vector product with a random vector using forward-mode autodiff (torch.func.jvp),
    which gives the true directional derivative in the direction of that vector.

    This requires the closure to be rewritten with functional_call:

    .. code:: python

        def closure(params):
            preds = torch.func.functional_call(model, params, (inputs, ))
            loss = loss_fn(preds, targets)
            return loss

    This is (temporarily) a subclass of Modular, so modules have to be passed to it. For example:

    .. code:: python

        import torchzero as tz
        opt = ForwardGradientsModular(model, tz.m.LR(1e-2))

    Args:
        model: the model whose parameters will be optimized (a torch.nn.Module).
        modules: list of OptimizerModules.
        distribution (Distributions, optional): distribution of the random direction vector. Defaults to 'normal'.

    Reference:
        Baydin, A. G., Pearlmutter, B. A., Syme, D., Wood, F., & Torr, P. (2022).
        Gradients without backpropagation. arXiv preprint arXiv:2202.08587.
        https://arxiv.org/abs/2202.08587
    """
    def __init__(self, model: torch.nn.Module, *modules: OptimizerModule, distribution: Distributions = 'normal'):
        if not isinstance(model, torch.nn.Module): raise TypeError("model must be a torch.nn.Module")
        super().__init__(model, *modules)
        self.distribution: Distributions = distribution

    @torch.no_grad
    def step(self, closure): # type:ignore # pylint:disable=signature-differs
        assert self.model is not None
        # names of the trainable parameters, in the same order as self.get_params()
        keys = [k for k, v in self.model.named_parameters() if v.requires_grad]

        def list_closure(*list_params):
            # rebuild the name -> tensor mapping that the functional_call closure expects
            dict_params = dict(zip(keys, list_params))
            return closure(dict_params)

        params = self.get_params()
        # random direction vector for the directional derivative
        vec = params.sample_like(1, distribution=self.distribution)

        def forward_grad_closure(backward=True):
            if backward:
                # a single forward pass returns the loss and the directional derivative (grad . vec)
                loss, directional_derivative = torch.func.jvp(list_closure, primals=tuple(params), tangents=tuple(vec)) # type:ignore
                # forward gradient estimate: the direction scaled by the directional derivative
                ascent = vec * directional_derivative
                params.set_grad_(ascent)
            else:
                loss = list_closure(*params)
            return loss

        state = OptimizationState(forward_grad_closure, self.model)
        return self.chain.step(state)
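
For reference, the estimate computed in forward_grad_closure can be reproduced directly with torch.func.jvp. The sketch below is only illustrative: the toy linear model, the mse_loss objective and the synthetic inputs/targets are assumptions made for the example, not part of torchzero.

import torch

# toy setup for the example only (not part of torchzero)
model = torch.nn.Linear(4, 1)
inputs, targets = torch.randn(8, 4), torch.randn(8, 1)

param_dict = dict(model.named_parameters())
names = list(param_dict)

def closure(*tensors):
    # functional_call expects a name -> tensor mapping
    params = dict(zip(names, tensors))
    preds = torch.func.functional_call(model, params, (inputs,))
    return torch.nn.functional.mse_loss(preds, targets)

# random direction v, one tangent tensor per parameter
tangents = tuple(torch.randn_like(p) for p in param_dict.values())

# a single forward pass yields the loss and the directional derivative (grad(loss) . v)
loss, dir_deriv = torch.func.jvp(closure, tuple(param_dict.values()), tangents)

# forward gradient estimate from Baydin et al. (2022): g_hat = (grad(loss) . v) * v
grad_estimate = [dir_deriv * v for v in tangents]

For a standard normal (or Rademacher) direction v this product is an unbiased estimate of the gradient, and ForwardGradientsModular writes it into the parameters via set_grad_ so that downstream modules such as tz.m.LR can consume it like an ordinary gradient.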
28 changes: 14 additions & 14 deletions src/torchzero/tensorlist.py
@@ -335,31 +335,31 @@ def randn_like(self, **kwargs: Unpack[_NewTensorKwargs]): return self.__class__(
 
     def randint_like(self, low: "Scalar | ScalarSequence", high: "Scalar | ScalarSequence", **kwargs: Unpack[_NewTensorKwargs]):
         return self.zipmap_args(torch.randint_like, low, high, **kwargs)
-    def uniform_like(self, low: "Scalar | ScalarSequence" = 0, high: "Scalar | ScalarSequence" = 1, **kwargs: Unpack[_NewTensorKwargs]):
+    def uniform_like(self, low: "Scalar | ScalarSequence" = 0, high: "Scalar | ScalarSequence" = 1, generator=None, **kwargs: Unpack[_NewTensorKwargs]):
         res = self.empty_like(**kwargs)
-        res.uniform_(low, high)
+        res.uniform_(low, high, generator=generator)
         return res
     def sphere_like(self, radius: "Scalar | ScalarSequence", **kwargs: Unpack[_NewTensorKwargs]) -> Self:
         r = self.randn_like(**kwargs)
         return (r * radius) / r.total_vector_norm() # type:ignore
-    def bernoulli(self):
-        return self.__class__(torch.bernoulli(i) for i in self)
-    def bernoulli_like(self, p: "Scalar | ScalarSequence" = 0.5):
+    def bernoulli(self, generator = None):
+        return self.__class__(torch.bernoulli(i, generator=generator) for i in self)
+    def bernoulli_like(self, p: "Scalar | ScalarSequence" = 0.5, generator = None, **kwargs: Unpack[_NewTensorKwargs]):
         """p is probability of a 1, other values will be 0."""
-        return self.__class__(torch.bernoulli(i) for i in self.full_like(p))
-    def rademacher_like(self, p: "Scalar | ScalarSequence" = 0.5):
+        return self.__class__(torch.bernoulli(i, generator = generator) for i in self.full_like(p, **kwargs))
+    def rademacher_like(self, p: "Scalar | ScalarSequence" = 0.5, generator = None, **kwargs: Unpack[_NewTensorKwargs]):
         """p is probability of a 1, other values will be -1."""
-        return self.bernoulli_like(p) * 2 - 1
+        return self.bernoulli_like(p, generator=generator, **kwargs) * 2 - 1
 
-    def sample_like(self, eps: "Scalar | ScalarSequence" = 1, distribution: Distributions = 'normal'):
+    def sample_like(self, eps: "Scalar | ScalarSequence" = 1, distribution: Distributions = 'normal', generator=None, **kwargs: Unpack[_NewTensorKwargs]):
         """Sample around 0."""
-        if distribution == 'normal': return self.randn_like() * eps
+        if distribution == 'normal': return self.randn_like(**kwargs) * eps # TODO: generator
         if distribution == 'uniform':
             if isinstance(eps, (list,tuple)):
-                return self.uniform_like([-i/2 for i in eps], [i/2 for i in eps]) # type:ignore
-            return self.uniform_like(-eps/2, eps/2)
-        if distribution == 'sphere': return self.sphere_like(eps)
-        if distribution == 'rademacher': return self.rademacher_like() * eps
+                return self.uniform_like([-i/2 for i in eps], [i/2 for i in eps], generator=generator, **kwargs) # type:ignore
+            return self.uniform_like(-eps/2, eps/2, generator=generator, **kwargs)
+        if distribution == 'sphere': return self.sphere_like(eps, **kwargs)
+        if distribution == 'rademacher': return self.rademacher_like(generator=generator, **kwargs) * eps
         raise ValueError(f'Unknown distribution {distribution}')
 
     def eq(self, other: STOrSTSequence): return self.zipmap(torch.eq, other)
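
The new generator arguments make these draws reproducible. A brief usage sketch, assuming the tensor-list class defined in src/torchzero/tensorlist.py is exported as TensorList and can be constructed from an iterable of tensors (as the self.__class__(...) calls above suggest):

import torch
from torchzero.tensorlist import TensorList  # assumed export name

params = TensorList([torch.zeros(3), torch.zeros(2, 2)])

# identically seeded generators should produce identical samples
v1 = params.sample_like(1, distribution='rademacher', generator=torch.Generator().manual_seed(0))
v2 = params.sample_like(1, distribution='rademacher', generator=torch.Generator().manual_seed(0))
assert all(torch.equal(a, b) for a, b in zip(v1, v2))

Note that the 'normal' branch of sample_like still draws through randn_like without a generator (hence the TODO above), so the 'normal' and 'sphere' paths are not yet seedable this way.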
