add option for decoupled weight decay
lucidrains committed Jun 15, 2024
1 parent 50edc8a commit 05258c9
Showing 4 changed files with 35 additions and 8 deletions.
README.md: 7 additions, 0 deletions
@@ -98,3 +98,10 @@ opt = Lion(
     year = {2019}
 }
 ```
+
+```bibtex
+@misc{Schaipp2024,
+    author = {Fabian Schaipp},
+    url = {https://fabian-sp.github.io/posts/2024/02/decoupling/}
+}
+```
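
Below is a minimal usage sketch of the option this commit introduces. It is not taken from the README: the model and hyperparameter values are illustrative, and it assumes the package's usual `from lion_pytorch import Lion` import.

```python
import torch
from lion_pytorch import Lion

model = torch.nn.Linear(10, 2)

# with decoupled_weight_decay = True, `weight_decay` is interpreted relative to the
# initial learning rate: the optimizer divides it by the initial lr before applying
# the step-weight decay (see the Schaipp (2024) reference added above)
opt = Lion(
    model.parameters(),
    lr = 1e-4,
    betas = (0.9, 0.99),
    weight_decay = 1e-2,
    decoupled_weight_decay = True
)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
opt.step()
opt.zero_grad()
```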
lion_pytorch/foreach.py: 13 additions, 2 deletions
@@ -17,12 +17,16 @@ def __init__(
         params,
         lr: float = 1e-4,
         betas: Tuple[float, float] = (0.9, 0.99),
-        weight_decay: float = 0.0
+        weight_decay: float = 0.0,
+        decoupled_weight_decay: bool = False
     ):
         assert lr > 0.
         assert all([0. <= beta <= 1. for beta in betas])
         assert all([hasattr(torch, attr) for attr in ('_foreach_mul_', '_foreach_add_', '_foreach_sign_', '_foreach_lerp_')]), 'this version of torch does not have the prerequisite foreach functions'
 
+        self._init_lr = lr
+        self.decoupled_wd = decoupled_weight_decay
+
         defaults = dict(
             lr = lr,
             betas = betas,
@@ -44,7 +48,14 @@ def step(

         for group in self.param_groups:
 
-            lr, wd, beta1, beta2 = group['lr'], group['weight_decay'], *group['betas']
+            lr, wd, beta1, beta2, decoupled_wd, init_lr = group['lr'], group['weight_decay'], *group['betas'], self.decoupled_wd, self._init_lr
+
+            # maybe decoupled weight decay
+
+            if decoupled_wd:
+                wd /= init_lr
 
             # accumulate List[Tensor] for foreach inplace updates
 
             params = []
             grads = []
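
For intuition about the `wd /= init_lr` line above (this sketch is not part of the commit): Lion applies step-weight decay by multiplying parameters with `1 - lr * wd`, so by default the effective decay shrinks and grows with the absolute learning rate. Dividing `wd` by the initial learning rate turns the multiplier into `1 - (lr / init_lr) * wd`, so `weight_decay` is exactly the decay applied at the initial learning rate, and a learning-rate schedule only rescales it relative to that starting point, which is the decoupling argued for in the Schaipp post cited in the README.

```python
# Illustrative numbers only: per-step decay multiplier with and without the new option.
init_lr = 1e-4
wd = 0.1  # weight decay as passed by the user

for lr in (1e-4, 5e-5, 1e-5):              # e.g. values along a decaying lr schedule
    coupled   = 1. - lr * wd                # default: decay tied to the absolute lr
    decoupled = 1. - lr * (wd / init_lr)    # new option: equals 1 - wd when lr == init_lr
    print(f"lr={lr:.0e}  default={coupled:.6f}  decoupled={decoupled:.6f}")
```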
lion_pytorch/lion_pytorch.py: 14 additions, 5 deletions
@@ -14,16 +14,16 @@ def exists(val):
 def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
     # stepweight decay
 
-    p.data.mul_(1 - lr * wd)
+    p.data.mul_(1. - lr * wd)
 
     # weight update
 
-    update = exp_avg.clone().mul_(beta1).add(grad, alpha = 1 - beta1).sign_()
+    update = exp_avg.clone().mul_(beta1).add(grad, alpha = 1. - beta1).sign_()
     p.add_(update, alpha = -lr)
 
     # decay the momentum running average coefficient
 
-    exp_avg.mul_(beta2).add_(grad, alpha = 1 - beta2)
+    exp_avg.mul_(beta2).add_(grad, alpha = 1. - beta2)
 
 # class
 
@@ -34,11 +34,15 @@ def __init__(
         lr: float = 1e-4,
         betas: Tuple[float, float] = (0.9, 0.99),
         weight_decay: float = 0.0,
-        use_triton: bool = False
+        use_triton: bool = False,
+        decoupled_weight_decay: bool = False,
     ):
         assert lr > 0.
         assert all([0. <= beta <= 1. for beta in betas])
 
+        self._init_lr = lr
+        self.decoupled_wd = decoupled_weight_decay
+
         defaults = dict(
             lr = lr,
             betas = betas,
@@ -67,7 +71,12 @@ def step(
         for group in self.param_groups:
             for p in filter(lambda p: exists(p.grad), group['params']):
 
-                grad, lr, wd, beta1, beta2, state = p.grad, group['lr'], group['weight_decay'], *group['betas'], self.state[p]
+                grad, lr, wd, beta1, beta2, state, decoupled_wd, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], self.state[p], self.decoupled_wd, self._init_lr
+
+                # maybe decoupled weight decay
+
+                if decoupled_wd:
+                    wd /= init_lr
 
                 # init state - exponential moving average of gradient values
 
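
For reference, `update_fn` above implements the standard Lion step. In symbols (this summary is mine, not part of the diff), with parameters θ, gradient g, momentum m, learning rate η and weight decay λ, where λ is replaced by λ/η₀ (η₀ the initial learning rate) when `decoupled_weight_decay = True`:

```latex
% Lion update as performed by update_fn (reference only)
\begin{aligned}
\theta_t &\leftarrow \theta_{t-1}\,(1 - \eta\,\lambda)                          && \text{stepweight decay} \\
c_t      &= \operatorname{sign}\!\bigl(\beta_1\, m_{t-1} + (1 - \beta_1)\, g_t\bigr) && \text{signed interpolated update} \\
\theta_t &\leftarrow \theta_t - \eta\, c_t                                       && \text{weight update} \\
m_t      &= \beta_2\, m_{t-1} + (1 - \beta_2)\, g_t                              && \text{momentum EMA}
\end{aligned}
```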
setup.py: 1 addition, 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'lion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.2.0',
+  version = '0.2.1',
   license='MIT',
   description = 'Lion Optimizer - Pytorch',
   author = 'Phil Wang',
