diff --git a/README.md b/README.md
index 7ae78d4..a620bf2 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@ This is a work-in-progress general purpose optimization library for pytorch. We
 Most optimizers are modular, meaning you can chain them like this:
 
 ```py
-optimizer = torchzero.optim.ModularOptimizer(model.parameters(), [*list of modules*])`
+optimizer = torchzero.optim.Modular(model.parameters(), [*list of modules*])
 ```
 
 For example you might use `[ClipNorm(4), LR(1e-3), NesterovMomentum(0.9)]` for standard SGD with gradient clipping and nesterov momentum. Move `ClipNorm` to the end to clip the update instead of the gradients. If you don't have access to gradients, add a `RandomizedFDM()` at the beginning to approximate them via randomized finite differences.
diff --git a/src/torchzero/modules/adaptive/adaptive.py b/src/torchzero/modules/adaptive/adaptive.py
index 9f34a08..4717812 100644
--- a/src/torchzero/modules/adaptive/adaptive.py
+++ b/src/torchzero/modules/adaptive/adaptive.py
@@ -53,13 +53,7 @@ def _update(self, state, ascent):
         ascent *= fmask
 
         if self.mode == 'grad':
-            print(f'before {ascent = }')
-            print(f'{grad = }')
-            print(f'{mask = }')
-            print(f'{grad * mask.logical_not() = }')
             ascent += grad * mask.logical_not_()
-            print(f'after {ascent = }')
-            print()
 
         return ascent
 
diff --git a/src/torchzero/modules/meta/grafting.py b/src/torchzero/modules/meta/grafting.py
index cfbfba3..748e673 100644
--- a/src/torchzero/modules/meta/grafting.py
+++ b/src/torchzero/modules/meta/grafting.py
@@ -125,6 +125,8 @@ class IntermoduleCautious(OptimizerModule):
             main module or sequence of modules to chain, which update will be used with a consistency mask applied.
         compare_module (OptimizerModule | Iterable[OptimizerModule]): module or sequence of modules to chain,
             which update will be used to compute a consistency mask.
+            Can also be set to `ascent` to compare to the update that is passed to `main_module`, or `grad` to compare
+            to the gradients.
         normalize (bool, optional): renormalize update after masking.
             only has effect when mode is 'zero'. Defaults to False.
 
@@ -143,18 +145,21 @@ class IntermoduleCautious(OptimizerModule):
     def __init__(
         self,
         main_module: OptimizerModule | Iterable[OptimizerModule],
-        compare_module: OptimizerModule | Iterable[OptimizerModule],
+        compare_module: OptimizerModule | Iterable[OptimizerModule] | Literal['ascent', 'grad'],
         normalize=False,
         eps=1e-6,
-        mode: Literal["zero", "grad", "negate", "compare_module"] = "zero",
+        mode: Literal["zero", "grad", "backtrack", "compare_module"] = "zero",
     ):
         super().__init__({})
 
         self._set_child_('main', Chain(main_module, ReturnAscent()))
-        self._set_child_('compare', Chain(compare_module, ReturnAscent()))
+        if isinstance(compare_module, str): self.compare_mode = compare_module
+        else:
+            self._set_child_('compare', Chain(compare_module, ReturnAscent()))
+            self.compare_mode = 'module'
         self.eps = eps
         self.normalize = normalize
-        self.mode: Literal["zero", "grad", "negate", "compare_module"] = mode
+        self.mode: Literal["zero", "grad", "backtrack", "compare_module"] = mode
 
     @torch.no_grad
     def step(self, state):
@@ -163,13 +168,18 @@ def step(self, state):
         ascent: TensorList = self.children['main'].step(state_copy) # type:ignore
 
         state.update_attrs_(state_copy)
-        compare: TensorList = self.children['compare'].step(state) # type:ignore
+        if self.compare_mode == 'module': compare: TensorList = self.children['compare'].step(state) # type:ignore
+        else:
+            params = self.get_params()
+            if self.compare_mode == 'ascent': compare: TensorList = state.maybe_use_grad_(params)
+            elif self.compare_mode == 'grad': compare: TensorList = state.maybe_compute_grad_(params)
+            else: raise ValueError(f'Invalid compare_module: {self.compare_mode}')
 
         # mask will be > 0 for parameters where both signs are the same
         mask = (ascent * compare) > 0
 
-        if self.mode == 'negate':
-            ascent -= ascent.mul(2).mul_(mask)
+        if self.mode == 'backtrack':
+            ascent -= ascent.mul(2).mul_(mask.logical_not_())
 
         else:
             # normalize if mode is `zero`
@@ -191,3 +201,4 @@ def step(self, state):
 
         state.ascent = ascent
         return self._update_params_or_step_with_next(state, params)
+
diff --git a/src/torchzero/optim/first_order/adamw.py b/src/torchzero/optim/first_order/adamw.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/torchzero/optim/first_order/cautious.py b/src/torchzero/optim/first_order/cautious.py
index 4a56dcc..c852c30 100644
--- a/src/torchzero/optim/first_order/cautious.py
+++ b/src/torchzero/optim/first_order/cautious.py
@@ -18,14 +18,13 @@ def __init__(
         eps: float = 1e-8,
         amsgrad=False,
         c_eps = 1e-6,
-        normalize = True,
+        normalize = False,
         mode: typing.Literal['zero', 'grad', 'backtrack'] = 'zero'
     ):
         modules: list[OptimizerModule] = [
-            Adam(lr = 1 if mode == 'grad' else lr, beta1 = beta1, beta2 = beta2, eps = eps, amsgrad = amsgrad),
+            Adam(lr = lr, beta1 = beta1, beta2 = beta2, eps = eps, amsgrad = amsgrad),
             Cautious(normalize = normalize, eps = c_eps, mode = mode),
         ]
-        if mode == 'grad': modules.append(LR(lr))
         super().__init__(params, modules)
 
 
@@ -44,10 +43,9 @@ def __init__(
         mode: typing.Literal['zero', 'grad', 'backtrack'] = 'zero'
     ):
         modules: list[OptimizerModule] = [
-            SGD(lr = 1 if mode == 'grad' else lr, momentum = momentum, dampening = dampening, weight_decay = weight_decay, nesterov = nesterov),
+            SGD(lr = lr, momentum = momentum, dampening = dampening, weight_decay = weight_decay, nesterov = nesterov),
             Cautious(normalize = normalize, eps = c_eps, mode = mode),
         ]
-        if mode == 'grad': modules.append(LR(lr))
         super().__init__(params, modules)
 
 
diff --git a/src/torchzero/optim/first_order/rprop.py b/src/torchzero/optim/first_order/rprop.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_against_reference.py b/tests/test_against_reference.py
index 091ad56..f649077 100644
--- a/tests/test_against_reference.py
+++ b/tests/test_against_reference.py
@@ -136,3 +136,17 @@ def test_adagrad(lr, lr_decay, initial_accumulator_value, eps):
         lambda p: tz.optim.Modular(p, tz.m.Adagrad(lr, lr_decay, initial_accumulator_value, eps)),
         lambda p: torch.optim.Adagrad(p, lr, lr_decay, initial_accumulator_value = initial_accumulator_value, eps = eps), # type:ignore
     )
+
+@pytest.mark.parametrize('lr', [1e-1])
+@pytest.mark.parametrize('compare', ['ascent', 'grad', tz.m.Mul(1)])
+@pytest.mark.parametrize('normalize', [True, False])
+@pytest.mark.parametrize('mode', ['zero', 'grad', 'backtrack'])
+@pytest.mark.parametrize('modular', [True, False])
+def test_cautious_vs_intermodule(lr, compare, normalize, mode, modular):
+    """tests IntermoduleCautious"""
+    if modular: opt1 = lambda p: tz.optim.Modular(p, tz.m.Adam(lr), tz.m.Cautious(normalize=normalize, mode=mode))
+    else: opt1 = lambda p: tz.optim.CautiousAdam(p, lr, normalize=normalize, mode=mode)
+    _test_against_reference(
+        opt1,
+        lambda p: tz.optim.Modular(p, tz.m.IntermoduleCautious(tz.m.Adam(lr), compare, normalize=normalize, mode=mode)), # type:ignore
+    )
\ No newline at end of file
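The README hunk above renames the entry point from `ModularOptimizer` to `Modular`. Here is a minimal usage sketch of the chain the README describes; the `tz.m.ClipNorm` / `tz.m.LR` / `tz.m.NesterovMomentum` paths and the closure-free `step()` call are assumptions based on the README text and the test file, not something this diff confirms.

```py
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)

# gradient norm clipping -> fixed learning rate -> nesterov momentum,
# chained as in the README example
optimizer = tz.optim.Modular(
    model.parameters(),
    [tz.m.ClipNorm(4), tz.m.LR(1e-3), tz.m.NesterovMomentum(0.9)],
)

loss = model(torch.randn(8, 10)).square().mean()
loss.backward()
optimizer.step()  # assumed to work like a standard torch.optim step
```

Moving `ClipNorm(4)` to the end of the list would clip the final update instead of the raw gradients, as the README notes.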
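For reference, the consistency-mask arithmetic used by `Cautious` and `IntermoduleCautious` (including the corrected `backtrack` branch above, which now flips only the disagreeing entries) can be reproduced with plain tensors. This is an illustration of the math in the diff, not library code:

```py
import torch

ascent = torch.tensor([0.5, -0.2,  0.3])  # proposed update
grad   = torch.tensor([1.0,  0.4, -0.1])  # reference direction, e.g. the raw gradients

mask = (ascent * grad) > 0                # True where both have the same sign

zeroed    = ascent * mask                                   # mode='zero': drop disagreeing entries
backtrack = ascent - ascent.mul(2).mul(mask.logical_not())  # mode='backtrack': flip their sign
grad_mode = zeroed + grad * mask.logical_not()              # mode='grad': fall back to the gradient there
```

Before this change, the old `negate` branch multiplied by `mask` instead of `mask.logical_not()`, which flipped the agreeing entries rather than the disagreeing ones; the rename to `backtrack` comes with that fix.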
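The new `test_cautious_vs_intermodule` checks that `IntermoduleCautious` with a string `compare_module` matches the plain `Cautious` post-module. A sketch of the two constructions it compares, with names taken from the test file (treat the exact call signatures as assumptions):

```py
import torch
import torchzero as tz

params = [torch.nn.Parameter(torch.randn(5))]

# cautious Adam written as Adam followed by the Cautious post-module
opt_a = tz.optim.Modular(params, tz.m.Adam(1e-1), tz.m.Cautious(mode='backtrack'))

# the same update rule via IntermoduleCautious: compare_module='grad' builds the
# consistency mask from the raw gradients instead of a second module chain
opt_b = tz.optim.Modular(
    params,
    tz.m.IntermoduleCautious(tz.m.Adam(1e-1), 'grad', mode='backtrack'),
)
```

The test also parametrizes `compare_module` with `tz.m.Mul(1)`, which should behave the same way, since multiplying the gradient by one leaves the comparison direction unchanged.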