
Commit

additions to intermodule cautious
inikishev committed Dec 21, 2024
1 parent daa89d8 commit 77947f7
Showing 7 changed files with 36 additions and 19 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -3,7 +3,7 @@ This is a work-in-progress general purpose optimization library for pytorch. We

Most optimizers are modular, meaning you can chain them like this:
```py
optimizer = torchzero.optim.ModularOptimizer(model.parameters(), [*list of modules*])`
optimizer = torchzero.optim.Modular(model.parameters(), [*list of modules*])`
```
For example you might use `[ClipNorm(4), LR(1e-3), NesterovMomentum(0.9)]` for standard SGD with gradient clipping and nesterov momentum. Move `ClipNorm` to the end to clip the update instead of the gradients. If you don't have access to gradients, add a `RandomizedFDM()` at the beginning to approximate them via randomized finite differences.
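A fuller sketch of that setup, assuming the `tz.m` module namespace used in this repo's tests and the standard `step()`/`zero_grad()` calls of `torch.optim`:
```py
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)

# SGD with gradient clipping and nesterov momentum, as described above.
optimizer = tz.optim.Modular(
    model.parameters(),
    [tz.m.ClipNorm(4), tz.m.LR(1e-3), tz.m.NesterovMomentum(0.9)],
)

x, y = torch.randn(16, 10), torch.randn(16, 1)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()       # the module chain transforms the gradients into the update
optimizer.zero_grad()
```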

6 changes: 0 additions & 6 deletions src/torchzero/modules/adaptive/adaptive.py
@@ -53,13 +53,7 @@ def _update(self, state, ascent):
ascent *= fmask

if self.mode == 'grad':
print(f'before {ascent = }')
print(f'{grad = }')
print(f'{mask = }')
print(f'{grad * mask.logical_not() = }')
ascent += grad * mask.logical_not_()
print(f'after {ascent = }')
print()

return ascent
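With the debug prints gone, the remaining masking logic behaves like this single-tensor sketch (plain `torch` tensors stand in for the module's tensor lists; `fmask` here is just the boolean mask as floats, omitting any renormalization the module may apply):
```py
import torch

ascent = torch.tensor([ 0.5, -0.2,  0.3])   # proposed update (e.g. an Adam direction)
grad   = torch.tensor([ 0.4,  0.1, -0.6])   # raw gradient

mask  = (ascent * grad) > 0                  # True where update and gradient agree in sign
fmask = mask.float()

ascent = ascent * fmask                      # 'zero' mode: drop entries with inconsistent signs
# 'grad' mode additionally falls back to the raw gradient where the signs disagreed:
ascent = ascent + grad * mask.logical_not()
print(ascent)                                # tensor([ 0.5000,  0.1000, -0.6000])
```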

25 changes: 18 additions & 7 deletions src/torchzero/modules/meta/grafting.py
@@ -125,6 +125,8 @@ class IntermoduleCautious(OptimizerModule):
main module or sequence of modules to chain, whose update will be used with the consistency mask applied.
compare_module (OptimizerModule | Iterable[OptimizerModule] | Literal['ascent', 'grad']):
module or sequence of modules to chain, whose update will be used to compute the consistency mask.
Can also be set to `'ascent'` to compare against the update that is passed to `main_module`, or `'grad'`
to compare against the gradients.
normalize (bool, optional):
renormalize the update after masking.
Only has an effect when mode is 'zero'. Defaults to False.
@@ -143,18 +145,21 @@ class IntermoduleCautious(OptimizerModule):
def __init__(
self,
main_module: OptimizerModule | Iterable[OptimizerModule],
compare_module: OptimizerModule | Iterable[OptimizerModule],
compare_module: OptimizerModule | Iterable[OptimizerModule] | Literal['ascent', 'grad'],
normalize=False,
eps=1e-6,
mode: Literal["zero", "grad", "negate", "compare_module"] = "zero",
mode: Literal["zero", "grad", "backtrack", "compare_module"] = "zero",
):
super().__init__({})

self._set_child_('main', Chain(main_module, ReturnAscent()))
self._set_child_('compare', Chain(compare_module, ReturnAscent()))
if isinstance(compare_module, str): self.compare_mode = compare_module
else:
self._set_child_('compare', Chain(compare_module, ReturnAscent()))
self.compare_mode = 'module'
self.eps = eps
self.normalize = normalize
self.mode: Literal["zero", "grad", "negate", "compare_module"] = mode
self.mode: Literal["zero", "grad", "backtrack", "compare_module"] = mode

@torch.no_grad
def step(self, state):
@@ -163,13 +168,18 @@ def step(self, state):
ascent: TensorList = self.children['main'].step(state_copy) # type:ignore
state.update_attrs_(state_copy)

compare: TensorList = self.children['compare'].step(state) # type:ignore
if self.compare_mode == 'module': compare: TensorList = self.children['compare'].step(state) # type:ignore
else:
params = self.get_params()
if self.compare_mode == 'ascent': compare: TensorList = state.maybe_use_grad_(params)
elif self.compare_mode == 'grad': compare: TensorList = state.maybe_compute_grad_(params)
else: raise ValueError(f'Invalid compare_module: {self.compare_mode}')

# mask will be > 0 for parameters where both signs are the same
mask = (ascent * compare) > 0

if self.mode == 'negate':
ascent -= ascent.mul(2).mul_(mask)
if self.mode == 'backtrack':
ascent -= ascent.mul(2).mul_(mask.logical_not_())

else:
# normalize if mode is `zero`
@@ -191,3 +201,4 @@

state.ascent = ascent
return self._update_params_or_step_with_next(state, params)
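Taken together, the new `compare_module` options can be used roughly as in the sketch below; the `tz.m` names and the single-module `Modular` call follow the test added in this commit, while the plain `step()` call and the hyperparameters are assumptions:
```py
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)

# Mask the Adam update wherever its sign disagrees with the raw gradient.
# compare_module='ascent' would instead compare against the update entering
# `main_module`, and another module (e.g. tz.m.Mul(1), as in the test below)
# can be passed to compare against a second module chain's update.
opt = tz.optim.Modular(
    model.parameters(),
    tz.m.IntermoduleCautious(tz.m.Adam(1e-2), compare_module='grad', mode='zero'),
)

loss = model(torch.randn(8, 10)).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()
```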

Empty file.
8 changes: 3 additions & 5 deletions src/torchzero/optim/first_order/cautious.py
@@ -18,14 +18,13 @@ def __init__(
eps: float = 1e-8,
amsgrad=False,
c_eps = 1e-6,
normalize = True,
normalize = False,
mode: typing.Literal['zero', 'grad', 'backtrack'] = 'zero'
):
modules: list[OptimizerModule] = [
Adam(lr = 1 if mode == 'grad' else lr, beta1 = beta1, beta2 = beta2, eps = eps, amsgrad = amsgrad),
Adam(lr = lr, beta1 = beta1, beta2 = beta2, eps = eps, amsgrad = amsgrad),
Cautious(normalize = normalize, eps = c_eps, mode = mode),
]
if mode == 'grad': modules.append(LR(lr))

super().__init__(params, modules)

@@ -44,10 +43,9 @@ def __init__(
mode: typing.Literal['zero', 'grad', 'backtrack'] = 'zero'
):
modules: list[OptimizerModule] = [
SGD(lr = 1 if mode == 'grad' else lr, momentum = momentum, dampening = dampening, weight_decay = weight_decay, nesterov = nesterov),
SGD(lr = lr, momentum = momentum, dampening = dampening, weight_decay = weight_decay, nesterov = nesterov),
Cautious(normalize = normalize, eps = c_eps, mode = mode),
]
if mode == 'grad': modules.append(LR(lr))

super().__init__(params, modules)
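After this change the presets apply the learning rate inside `Adam`/`SGD` for every mode, so a preset and its explicit module chain are intended to line up, e.g. (a sketch; `tz` is `torchzero` and the learning rate is arbitrary):
```py
import torch
import torchzero as tz

model = torch.nn.Linear(4, 1)

# Preset:
opt_a = tz.optim.CautiousAdam(model.parameters(), lr=1e-3, mode='grad')

# Equivalent explicit chain: no trailing LR module is appended anymore,
# because the lr now lives inside Adam even when mode == 'grad'.
opt_b = tz.optim.Modular(model.parameters(), tz.m.Adam(1e-3), tz.m.Cautious(mode='grad'))
```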

Empty file.
14 changes: 14 additions & 0 deletions tests/test_against_reference.py
@@ -136,3 +136,17 @@ def test_adagrad(lr, lr_decay, initial_accumulator_value, eps):
lambda p: tz.optim.Modular(p, tz.m.Adagrad(lr, lr_decay, initial_accumulator_value, eps)),
lambda p: torch.optim.Adagrad(p, lr, lr_decay, initial_accumulator_value = initial_accumulator_value, eps = eps), # type:ignore
)

@pytest.mark.parametrize('lr', [1e-1])
@pytest.mark.parametrize('compare', ['ascent', 'grad', tz.m.Mul(1)])
@pytest.mark.parametrize('normalize', [True, False])
@pytest.mark.parametrize('mode', ['zero', 'grad', 'backtrack'])
@pytest.mark.parametrize('modular', [True, False])
def test_cautious_vs_intermodule(lr, compare, normalize, mode, modular):
"""tests IntermoduleCautious"""
if modular: opt1 = lambda p: tz.optim.Modular(p, tz.m.Adam(lr), tz.m.Cautious(normalize=normalize, mode=mode))
else: opt1 = lambda p: tz.optim.CautiousAdam(p, lr, normalize=normalize, mode=mode)
_test_against_reference(
opt1,
lambda p: tz.optim.Modular(p, tz.m.IntermoduleCautious(tz.m.Adam(lr), compare, normalize=normalize, mode=mode)), # type:ignore
)
