smallest eigenvalue mode for hessian regularization + scipyroot
inikishev committed Dec 22, 2024
1 parent b22ab54 commit bf8d2c9
Showing 5 changed files with 226 additions and 36 deletions.
13 changes: 10 additions & 3 deletions src/torchzero/modules/misc/basic.py
@@ -76,13 +76,20 @@ def __init__(self, value):
     def _update(self, state, ascent): return ascent.add_(self.value)
 
 class AddMagnitude(OptimizerModule):
-    """Add `value` multiplied by sign of the ascent, i.e. this adds `value` to the magnitude of the update."""
-    def __init__(self, value):
+    """Add `value` multiplied by the sign of the ascent, i.e. this adds `value` to the magnitude of the update.
+
+    Args:
+        value (float): value to add to the magnitude.
+        add_to_zero (bool, optional): if True, adds `value` to 0s. Otherwise zeros remain zero. Defaults to True.
+    """
+    def __init__(self, value, add_to_zero=True):
         super().__init__({})
         self.value = value
+        self.add_to_zero = add_to_zero
     @torch.no_grad()
     def _update(self, state, ascent):
-        return ascent.add_(ascent.clamp_magnitude(min=1).sign_().mul_(self.value))
+        if self.add_to_zero: return ascent.add_(ascent.clamp_magnitude(min=1).sign_().mul_(self.value))
+        return ascent.add_(ascent.sign_().mul_(self.value))
 
 class Mul(OptimizerModule):
     """Multiplies the update by `value`."""
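A plain-tensor sketch of the rule this hunk implements may help. torchzero's `ascent` is a TensorList with a `clamp_magnitude` helper, so `torch.where` stands in for it here; treat this as an illustration of the intended semantics, not library code.

```python
import torch

def add_magnitude(ascent: torch.Tensor, value: float, add_to_zero: bool = True) -> torch.Tensor:
    """Single-tensor sketch of AddMagnitude._update."""
    if add_to_zero:
        # zeros are treated as positive, so their magnitude also grows by `value`
        sign = torch.where(ascent == 0, torch.ones_like(ascent), ascent.sign())
        return ascent + sign * value
    # sign of 0 is 0, so zero entries stay exactly zero
    return ascent + ascent.sign() * value

u = torch.tensor([-0.5, 0.0, 2.0])
print(add_magnitude(u, 0.1))                     # tensor([-0.6000,  0.1000,  2.1000])
print(add_magnitude(u, 0.1, add_to_zero=False))  # tensor([-0.6000,  0.0000,  2.1000])
```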
25 changes: 16 additions & 9 deletions src/torchzero/modules/second_order/newton.py
@@ -1,4 +1,4 @@
-import typing as T
+from typing import Literal
 from collections import abc
 
 import torch
@@ -44,8 +44,15 @@ def _fallback_safe_diag(hessian:torch.Tensor, grad:torch.Tensor, lr = 1e-2):
     return grad.mul_(diag * lr), True
 
 
-LinearSystemSolvers = T.Literal['cholesky', 'lu', 'cholesky_lu', 'lstsq']
-FallbackLinearSystemSolvers = T.Literal['lstsq', 'safe_diag', 'gd']
+def regularize_hessian_(hessian: torch.Tensor, value: float | Literal['eig']):
+    """Regularize the hessian matrix in-place; 'eig' adds the magnitude of the smallest eigenvalue if it is negative."""
+    if value == 'eig':
+        hessian.add_(torch.eye(hessian.shape[0], device=hessian.device, dtype=hessian.dtype), alpha=max(-torch.linalg.eigvalsh(hessian).min().item(), 0.0)) # pylint:disable=not-callable
+    elif value != 0:
+        hessian.add_(torch.eye(hessian.shape[0], device=hessian.device, dtype=hessian.dtype), alpha=value)
+
+LinearSystemSolvers = Literal['cholesky', 'lu', 'cholesky_lu', 'lstsq']
+FallbackLinearSystemSolvers = Literal['lstsq', 'safe_diag', 'gd']
 
 LINEAR_SYSTEM_SOLVERS = {
     "cholesky": _cholesky_solve,
@@ -64,7 +71,8 @@ class ExactNewton(OptimizerModule):
     Args:
         tikhonov (float, optional):
             tikhonov regularization (constant value added to the diagonal of the hessian).
-            Also known as Levenberg-Marquardt regularization. Defaults to 0.
+            Also known as Levenberg-Marquardt regularization. Can be set to 'eig', in which case the magnitude
+            of the smallest eigenvalue is added to the diagonal if that eigenvalue is negative. Defaults to 0.
         solver (Solvers, optional):
             solver for Hx = g. Defaults to "cholesky_lu" (cholesky or LU if it fails).
         fallback (Solvers, optional):
@@ -90,17 +98,17 @@ class ExactNewton(OptimizerModule):
     """
     def __init__(
         self,
-        tikhonov: float = 0.0,
+        tikhonov: float | Literal['eig'] = 0.0,
         solver: LinearSystemSolvers = "cholesky_lu",
         fallback: FallbackLinearSystemSolvers = "safe_diag",
         validate=False,
         tol: float = 1,
         gd_lr = 1e-2,
         batched_hessian=True,
-        diag: T.Literal[False] = False,
+        diag: bool = False,
     ):
         super().__init__({})
-        self.tikhonov = tikhonov
+        self.tikhonov: float | Literal['eig'] = tikhonov
         self.batched_hessian = batched_hessian
 
         self.solver: abc.Callable = LINEAR_SYSTEM_SOLVERS[solver]
@@ -131,8 +139,7 @@ def step(self, state):
         numel = gvec.numel()
 
         # tikhonov regularization
-        if self.tikhonov != 0:
-            hessian += torch.eye(numel, device=hessian.device, dtype=hessian.dtype) * self.tikhonov
+        regularize_hessian_(hessian, self.tikhonov)
 
         # calculate newton step
         if self.diag:
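Taken together, the pieces in this file compute a regularized exact Newton step, solving (H + alpha*I) d = g. A self-contained sketch using the same solver order as the 'cholesky_lu' default; `newton_step` and the small epsilon are illustrative choices, not the module's API:

```python
import torch

def newton_step(hessian: torch.Tensor, grad: torch.Tensor,
                tikhonov: float | str = 0.0) -> torch.Tensor:
    H = hessian.clone()
    I = torch.eye(H.shape[0], device=H.device, dtype=H.dtype)
    if tikhonov == 'eig':
        lam_min = torch.linalg.eigvalsh(H).min().item()
        if lam_min < 0:
            H += (1e-8 - lam_min) * I  # epsilon keeps H strictly positive definite
    elif tikhonov != 0:
        H += tikhonov * I
    try:
        L = torch.linalg.cholesky(H)   # fast path: H is positive definite
        return torch.cholesky_solve(grad.unsqueeze(1), L).squeeze(1)
    except torch.linalg.LinAlgError:
        return torch.linalg.solve(H, grad)  # LU fallback
```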
6 changes: 3 additions & 3 deletions src/torchzero/optim/second_order/newton.py
@@ -1,4 +1,4 @@
-import typing as T
+from typing import Literal
 from collections import abc
 
 import torch
@@ -51,7 +51,7 @@ def __init__(
         self,
         params,
         lr: float = 1,
-        tikhonov: float = 0.,
+        tikhonov: float | Literal['eig'] = 0.0,
         solver: LinearSystemSolvers = "cholesky_lu",
         fallback: FallbackLinearSystemSolvers = "safe_diag",
         max_norm: float | None = None,
@@ -61,7 +61,7 @@ def __init__(
         line_search: LineSearches | None = None,
         batched_hessian = True,
 
-        diag: T.Literal[False] = False,
+        diag: bool = False,
     ):
         modules: list[OptimizerModule] = [
             _ExactNewton(
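For the user-facing side of this change: the Newton optimizer now accepts tikhonov='eig' in addition to a fixed float. A construction sketch; the class name `Newton` is assumed from this file's path, so check the module's exports before copying:

```python
import torch
from torchzero.optim.second_order.newton import Newton  # class name assumed

model = torch.nn.Linear(4, 1)
# 'eig' regularizes the Hessian adaptively instead of with a fixed constant
opt = Newton(model.parameters(), lr=1.0, tikhonov='eig')
```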
35 changes: 24 additions & 11 deletions src/torchzero/optim/wrappers/nevergrad.py
@@ -1,16 +1,17 @@
 import typing
 from collections import abc
 
+import nevergrad as ng
 import numpy as np
 import torch
 
-import nevergrad as ng
 from ...core import TensorListOptimizer
 
 
 def _ensure_float(x):
     if isinstance(x, torch.Tensor): return x.detach().cpu().item()
-    elif isinstance(x, np.ndarray): return x.item()
+    if isinstance(x, np.ndarray): return x.item()
     return float(x)
 
+
 class NevergradOptimizer(TensorListOptimizer):
@@ -44,25 +45,37 @@ def __init__(
         opt_cls: "type[ng.optimizers.base.Optimizer] | abc.Callable[..., ng.optimizers.base.Optimizer]",
         budget: int | None = None,
         mutable_sigma = False,
+        lb: float | None = None,
+        ub: float | None = None,
         use_init = True,
     ):
-        super().__init__(params, {})
+        defaults = dict(lb=lb, ub=ub, use_init=use_init, mutable_sigma=mutable_sigma)
+        super().__init__(params, defaults)
         self.opt_cls = opt_cls
         self.opt = None
         self.budget = budget
-        self.mutable_sigma = mutable_sigma
-        self.use_init = use_init
 
     @torch.no_grad
     def step(self, closure): # type:ignore # pylint:disable=signature-differs
         params = self.get_params()
         if self.opt is None:
-
-            if self.use_init:
-                parametrization = ng.p.Tuple(*(ng.p.Array(init = p.detach().cpu().numpy(), mutable_sigma=self.mutable_sigma) for p in params))
-            else:
-                parametrization = ng.p.Tuple(*(ng.p.Array(shape = p.shape, mutable_sigma=self.mutable_sigma) for p in params))
-
+            ng_params = []
+            for group in self.param_groups:
+                params = group['params']
+                mutable_sigma = group['mutable_sigma']
+                use_init = group['use_init']
+                lb = group['lb']
+                ub = group['ub']
+                for p in params:
+                    if p.requires_grad:
+                        if use_init:
+                            ng_params.append(
+                                ng.p.Array(init = p.detach().cpu().numpy(), lower=lb, upper=ub, mutable_sigma=mutable_sigma))
+                        else:
+                            ng_params.append(
+                                ng.p.Array(shape = p.shape, lower=lb, upper=ub, mutable_sigma=mutable_sigma))
+
+            parametrization = ng.p.Tuple(*ng_params)
             self.opt = self.opt_cls(parametrization, budget=self.budget)
 
         x: ng.p.Tuple = self.opt.ask() # type:ignore
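The practical upshot of this rewrite is that `lb`, `ub`, `use_init` and `mutable_sigma` become per-parameter-group settings instead of global attributes. A usage sketch, assuming the import path from this diff and a closure that simply returns the loss (the wrapper is gradient-free, so no backward call is needed):

```python
import nevergrad as ng
import torch
from torchzero.optim.wrappers.nevergrad import NevergradOptimizer

model = torch.nn.Linear(2, 1)
opt = NevergradOptimizer(
    [{'params': model.weight, 'lb': -1.0, 'ub': 1.0},  # box-constrained group
     {'params': model.bias}],                          # falls back to the defaults (no bounds)
    ng.optimizers.OnePlusOne,
    budget=500,
)

x = torch.randn(32, 2)  # fixed batch, so the objective is deterministic

def closure():
    return (model(x) ** 2).mean()

for _ in range(500):
    opt.step(closure)
```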