[Feature] Implement APOLLO optimizer #312

Merged: 11 commits, Dec 15, 2024
3 changes: 2 additions & 1 deletion README.md
@@ -10,7 +10,7 @@

## The reasons why you use `pytorch-optimizer`.

* Wide range of supported optimizers. Currently, **83 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Wide range of supported optimizers. Currently, **84 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Including many variants such as `Cautious`, `AdamD`, `Gradient Centralization`
* Easy to use, clean, and tested code
* Active maintenance
@@ -192,6 +192,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
| MicroAdam | *Accurate Adaptive Optimization with Low Space Overhead and Provable Convergence* | [github](https://github.com/IST-DASLab/MicroAdam) | <https://arxiv.org/abs/2405.15593> | [cite](https://github.com/IST-DASLab/MicroAdam?tab=readme-ov-file#citing) |
| Muon | *MomentUm Orthogonalized by Newton-schulz* | [github](https://github.com/KellerJordan/Muon) | <https://x.com/kellerjordan0/status/1842300916864844014> | [cite](https://github.com/KellerJordan/Muon) |
| LaProp | *Separating Momentum and Adaptivity in Adam* | [github](https://github.com/Z-T-WANG/LaProp-Optimizer) | <https://arxiv.org/abs/2002.04839> | [cite](https://github.com/Z-T-WANG/LaProp-Optimizer?tab=readme-ov-file#citation) |
| APOLLO | *SGD-like Memory, AdamW-level Performance* | [github](https://github.com/zhuhanqing/APOLLO) | <https://arxiv.org/abs/2412.05270> | [cite](https://github.com/zhuhanqing/APOLLO?tab=readme-ov-file#-citation) |
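As a quick, hedged sketch of how the new entry surfaces through the public API (whether the wildcard filter matches `'apollo*'` case-insensitively is an assumption based on the `adam*` / `ranger*` example above):

```python
from pytorch_optimizer import get_supported_optimizers

# assumption: the wildcard filter is case-insensitive, so 'apollo*'
# should surface both the new APOLLO and the renamed ApolloDQN
print(get_supported_optimizers(['apollo*']))
```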

## Supported LR Scheduler

3 changes: 3 additions & 0 deletions docs/changelogs/v3.3.1.md
@@ -4,6 +4,9 @@

* Support `Cautious` variant to `AdaShift` optimizer. (#310)
* Save the state of the `Lookahead` optimizer too. (#310)
* Implement `APOLLO` optimizer. (#311, #312)
* [SGD-like Memory, AdamW-level Performance](https://arxiv.org/abs/2412.05270)
* Rename the `Apollo` optimizer (`An Adaptive Parameter-wise Diagonal Quasi-Newton Method for Nonconvex Stochastic Optimization`) to `ApolloDQN` so that its name does not clash with the new `APOLLO` optimizer, as sketched below. (#312)
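A minimal import sketch of what the rename means in practice (class names are taken from this PR; the tiny parameter and the learning rate are illustrative only):

```python
import torch

from pytorch_optimizer import APOLLO, ApolloDQN  # `Apollo` is now `ApolloDQN`

param = torch.nn.Parameter(torch.randn(2, 2))

quasi_newton = ApolloDQN([param], lr=1e-2)  # the former `Apollo` optimizer
low_memory = APOLLO([param], lr=1e-2)       # the new APOLLO optimizer added by this PR
```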

### Bug

3 changes: 2 additions & 1 deletion docs/index.md
@@ -10,7 +10,7 @@

## The reasons why you use `pytorch-optimizer`.

* Wide range of supported optimizers. Currently, **83 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Wide range of supported optimizers. Currently, **84 optimizers (+ `bitsandbytes`, `qgalore`, `torchao`)**, **16 lr schedulers**, and **13 loss functions** are supported!
* Including many variants such as `Cautious`, `AdamD`, `Gradient Centralization`
* Easy to use, clean, and tested code
* Active maintenance
@@ -192,6 +192,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
| MicroAdam | *Accurate Adaptive Optimization with Low Space Overhead and Provable Convergence* | [github](https://github.com/IST-DASLab/MicroAdam) | <https://arxiv.org/abs/2405.15593> | [cite](https://github.com/IST-DASLab/MicroAdam?tab=readme-ov-file#citing) |
| Muon | *MomentUm Orthogonalized by Newton-schulz* | [github](https://github.com/KellerJordan/Muon) | <https://x.com/kellerjordan0/status/1842300916864844014> | [cite](https://github.com/KellerJordan/Muon) |
| LaProp | *Separating Momentum and Adaptivity in Adam* | [github](https://github.com/Z-T-WANG/LaProp-Optimizer) | <https://arxiv.org/abs/2002.04839> | [cite](https://github.com/Z-T-WANG/LaProp-Optimizer?tab=readme-ov-file#citation) |
| APOLLO | *SGD-like Memory, AdamW-level Performance* | [github](https://github.com/zhuhanqing/APOLLO) | <https://arxiv.org/abs/2412.05270> | [cite](https://github.com/zhuhanqing/APOLLO?tab=readme-ov-file#-citation) |

## Supported LR Scheduler

6 changes: 5 additions & 1 deletion docs/optimizer.md
@@ -116,7 +116,11 @@
:docstring:
:members:

::: pytorch_optimizer.Apollo
::: pytorch_optimizer.APOLLO
:docstring:
:members:

::: pytorch_optimizer.ApolloDQN
:docstring:
:members:

16 changes: 8 additions & 8 deletions pyproject.toml
@@ -13,14 +13,14 @@ keywords = [
"pytorch", "deep-learning", "optimizer", "lr scheduler", "A2Grad", "ASGD", "AccSGD", "AdaBelief", "AdaBound",
"AdaDelta", "AdaFactor", "AdaMax", "AdamG", "AdaMod", "AdaNorm", "AdaPNM", "AdaSmooth", "AdEMAMix", "ADOPT",
"AdaHessian", "Adai", "Adalite", "AdaLomo", "AdamMini", "AdamP", "AdamS", "Adan", "AggMo", "Aida", "AliG", "Amos",
"Apollo", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD", "DAdaptLion",
"DeMo", "DiffGrad", "FAdam", "Fromage", "FTRL", "GaLore", "Gravity", "GrokFast", "GSAM", "Kate", "Lamb", "LaProp",
"LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Muno", "Nero", "NovoGrad", "PAdam", "PCGrad", "PID",
"PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM", "ScheduleFreeSGD",
"ScheduleFreeAdamW", "ScheduleFreeRAdam", "SGDP", "Shampoo", "ScalableShampoo", "SGDW", "SignSGD", "SM3", "SOAP",
"SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "TRAC", "WSAM", "Yogi", "BCE", "BCEFocal", "Focal",
"FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky", "LovaszHinge",
"bitsandbytes", "WSD", "QGaLore",
"Apollo", "APOLLO", "AvaGrad", "bSAM", "CAME", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptSGD",
"DAdaptLion", "DeMo", "DiffGrad", "FAdam", "Fromage", "FTRL", "GaLore", "Gravity", "GrokFast", "GSAM", "Kate",
"Lamb", "LaProp", "LARS", "Lion", "LOMO", "Lookahead", "MADGRAD", "MSVAG", "Muno", "Nero", "NovoGrad", "PAdam",
"PCGrad", "PID", "PNM", "Prodigy", "QHAdam", "QHM", "RAdam", "Ranger", "Ranger21", "RotoGrad", "SAM",
"ScheduleFreeSGD", "ScheduleFreeAdamW", "ScheduleFreeRAdam", "SGDP", "Shampoo", "ScalableShampoo", "SGDW",
"SignSGD", "SM3", "SOAP", "SopihaH", "SRMM", "StableAdamW", "SWATS", "Tiger", "TRAC", "WSAM", "Yogi", "BCE",
"BCEFocal", "Focal", "FocalCosine", "SoftF1", "Dice", "LDAM", "Jaccard", "Bi-Tempered", "Tversky", "FocalTversky",
"LovaszHinge", "bitsandbytes", "WSD", "QGaLore",
]
classifiers = [
"License :: OSI Approved :: Apache Software License",
3 changes: 2 additions & 1 deletion pytorch_optimizer/__init__.py
@@ -42,6 +42,7 @@
)
from pytorch_optimizer.optimizer import (
ADOPT,
APOLLO,
ASGD,
BSAM,
CAME,
@@ -90,7 +91,7 @@
Aida,
AliG,
Amos,
Apollo,
ApolloDQN,
AvaGrad,
DAdaptAdaGrad,
DAdaptAdam,
5 changes: 3 additions & 2 deletions pytorch_optimizer/optimizer/__init__.py
@@ -34,7 +34,7 @@
from pytorch_optimizer.optimizer.aida import Aida
from pytorch_optimizer.optimizer.alig import AliG
from pytorch_optimizer.optimizer.amos import Amos
from pytorch_optimizer.optimizer.apollo import Apollo
from pytorch_optimizer.optimizer.apollo import APOLLO, ApolloDQN
from pytorch_optimizer.optimizer.avagrad import AvaGrad
from pytorch_optimizer.optimizer.came import CAME
from pytorch_optimizer.optimizer.dadapt import DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptLion, DAdaptSGD
@@ -228,7 +228,8 @@ def load_optimizer(optimizer: str) -> OPTIMIZER:
DAdaptAdan,
AdamS,
AdaFactor,
Apollo,
ApolloDQN,
APOLLO,
SWATS,
NovoGrad,
Lion,
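Since both classes are now registered for `load_optimizer`, they can presumably also be resolved by name; the lowercase keys below are an assumption about the registry, not something this diff confirms:

```python
from pytorch_optimizer import load_optimizer

# assumed registry keys: lowercased class names
apollo_cls = load_optimizer('apollo')         # expected to resolve to APOLLO
apollo_dqn_cls = load_optimizer('apollodqn')  # expected to resolve to ApolloDQN
```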
168 changes: 162 additions & 6 deletions pytorch_optimizer/optimizer/apollo.py
@@ -1,14 +1,18 @@
from typing import Optional
import math
from typing import Literal, Optional

import numpy as np
import torch

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
from pytorch_optimizer.optimizer.galore_utils import GaLoreProjector

SCALE_TYPE = Literal['channel', 'tensor']
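# 'channel' applies per-channel gradient scaling; 'tensor' applies a single tensor-wide scaling factor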

class Apollo(BaseOptimizer):

class ApolloDQN(BaseOptimizer):
r"""An Adaptive Parameter-wise Diagonal Quasi-Newton Method for Nonconvex Stochastic Optimization.

:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
@@ -25,8 +29,8 @@ class Apollo(BaseOptimizer):
def __init__(
self,
params: PARAMETERS,
lr: float = 1e-3,
init_lr: Optional[float] = None,
lr: float = 1e-2,
init_lr: Optional[float] = 1e-5,
beta: float = 0.9,
rebound: str = 'constant',
weight_decay: float = 0.0,
@@ -58,7 +62,7 @@ def __init__(
super().__init__(params, defaults)

def __str__(self) -> str:
return 'Apollo'
return 'ApolloDQN'

@torch.no_grad()
def reset(self):
@@ -146,3 +150,155 @@ def step(self, closure: CLOSURE = None) -> LOSS:
p.add_(d_p, alpha=-current_lr)

return loss


class APOLLO(BaseOptimizer):
r"""SGD-like Memory, AdamW-level Performance.

:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
:param lr: float. learning rate.
:param betas: BETAS. coefficients used for computing running averages of the gradient and its square.
:param scale_type: SCALE_TYPE. granularity of the gradient scaling, either 'channel' or 'tensor'.
:param weight_decay: float. weight decay (L2 penalty).
:param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
:param fixed_decay: bool. fix weight decay.
:param correct_bias: bool. whether to apply Adam-style bias correction.
:param eps: float. term added to the denominator to improve numerical stability.
"""

def __init__(
self,
params: PARAMETERS,
lr: float = 1e-2,
betas: BETAS = (0.9, 0.999),
scale_type: SCALE_TYPE = 'tensor',
weight_decay: float = 0.0,
weight_decouple: bool = True,
fixed_decay: bool = False,
correct_bias: bool = True,
eps: float = 1e-6,
**kwargs,
):
self.validate_learning_rate(lr)
self.validate_betas(betas)
self.validate_non_negative(weight_decay, 'weight_decay')
self.validate_non_negative(eps, 'eps')

defaults: DEFAULTS = {
'lr': lr,
'betas': betas,
'scale_type': scale_type,
'weight_decay': weight_decay,
'weight_decouple': weight_decouple,
'fixed_decay': fixed_decay,
'correct_bias': correct_bias,
'eps': eps,
**kwargs,
}
super().__init__(params, defaults)

def __str__(self) -> str:
return 'APOLLO'

@torch.no_grad()
def reset(self):
for group in self.param_groups:
group['step'] = 0
for p in group['params']:
state = self.state[p]

state['exp_avg'] = torch.zeros_like(p)
state['exp_avg_sq'] = torch.zeros_like(p)

@torch.no_grad()
def step(self, closure: CLOSURE = None) -> LOSS:
loss: LOSS = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
if 'step' in group:
group['step'] += 1
else:
group['step'] = 1

beta1, beta2 = group['betas']

step_size: float = group['lr']
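# optionally fold Adam-style bias correction into the step size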
if group['correct_bias']:
bias_correction1: float = self.debias(beta1, group['step'])
bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))
step_size *= bias_correction2_sq / bias_correction1

for p in group['params']:
if p.grad is None:
continue

grad = p.grad
if grad.is_sparse:
raise NoSparseGradientError(str(self))

state = self.state[p]

if 'rank' in group and p.dim() > 1:
    if 'projector' not in state:
        state['projector'] = GaLoreProjector(
            rank=group['rank'],
            update_proj_gap=group['update_proj_gap'],
            scale=group['scale'],
            projection_type=group['projection_type'],
        )

    # low-rank mode: compress the gradient with a random GaLore projection so that
    # the Adam moments live in the projected space
    grad = state['projector'].project(grad, group['step'], from_random_matrix=True)

if 'exp_avg' not in state:
    # initialize the moments to match the (possibly projected) gradient shape
    state['exp_avg'] = torch.zeros_like(grad)
    state['exp_avg_sq'] = torch.zeros_like(grad)

exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

de_nom = exp_avg_sq.sqrt().add_(group['eps'])

norm_grad = exp_avg / de_nom
if 'rank' in group and p.dim() > 1:
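# APOLLO scaling factor: ratio of the Adam-normalized update norm to the raw (projected) gradient norm, per channel or per tensor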
if group['scale_type'] == 'channel':
norm_dim: int = 0 if norm_grad.shape[0] < norm_grad.shape[1] else 1
scaling_factor = torch.norm(norm_grad, dim=norm_dim) / (torch.norm(grad, dim=norm_dim) + 1e-8)
if norm_dim == 1:
scaling_factor = scaling_factor.unsqueeze(1)
else:
scaling_factor = torch.norm(norm_grad) / (torch.norm(grad) + 1e-8)

scaling_grad = grad * scaling_factor

scaling_grad_norm = torch.norm(scaling_grad)
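# norm-growth limiter: cap the scaled gradient norm at 1.01x its previous value to damp sudden spikes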
if 'scaling_grad' in state:
limiter = (
max(
scaling_grad_norm / (state['scaling_grad'] + 1e-8),
1.01,
)
/ 1.01
)

scaling_grad.div_(limiter)
scaling_grad_norm.div_(limiter)

state['scaling_grad'] = scaling_grad_norm

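# rescale by sqrt(scale) and project the update back to the full parameter space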
norm_grad = scaling_grad * np.sqrt(group['scale'])
norm_grad = state['projector'].project_back(norm_grad)

p.add_(norm_grad, alpha=-step_size)

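# weight decay is applied after the main update, decoupled from the gradient (AdamW-style) when weight_decouple is True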
self.apply_weight_decay(
p,
grad,
lr=step_size,
weight_decay=group['weight_decay'],
weight_decouple=group['weight_decouple'],
fixed_decay=group['fixed_decay'],
)

return loss
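For orientation, a minimal training sketch using the new class. The low-rank keys (`rank`, `update_proj_gap`, `scale`, `projection_type`) are forwarded through `**kwargs` into every param group and only take effect on parameters with more than one dimension; the concrete values below are illustrative assumptions, not defaults shipped by the library:

```python
import torch
from torch import nn

from pytorch_optimizer import APOLLO

model = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 128))

optimizer = APOLLO(
    model.parameters(),
    lr=1e-2,
    scale_type='channel',
    # assumed low-rank settings: read from the param group by step() only when
    # 'rank' is present; 1D biases fall back to the plain AdamW-like update
    rank=32,
    update_proj_gap=200,
    scale=1.0,
    projection_type='std',
)

x, target = torch.randn(8, 128), torch.randn(8, 128)

loss = nn.functional.mse_loss(model(x), target)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```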