~ in PyTorch ~
LookSAM is an accelerated SAM algorithm. Instead of computing the inner gradient ascent at every step, LookSAM computes it only periodically (every k-th step) and reuses the resulting direction that promotes convergence to flat regions.
This is an unofficial repository for Towards Efficient and Scalable Sharpness-Aware Minimization. Currently only the algorithm without layer-wise adaptive learning rates is implemented (the layer-wise variant will be added soon...).
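For intuition, the core reuse rule from the paper can be sketched in a few lines. This is a minimal, hypothetical illustration (the helper `looksam_combine` and the flattened gradient tensors are assumptions made for the sketch, not this repository's API):

```python
import torch

def looksam_combine(g, g_s=None, g_v=None, alpha=0.7):
    """Hypothetical sketch of the LookSAM gradient reuse rule.

    g   : gradient at the current weights (1-D, flattened)
    g_s : gradient at the perturbed weights (computed only every k-th step)
    g_v : cached component of g_s orthogonal to g (reused in between)
    """
    if g_s is not None:
        # k-th step: remove the projection of g_s onto g and cache the
        # remainder, which points toward flatter regions.
        g_v = g_s - (torch.dot(g_s, g) / g.norm().pow(2)) * g
        return g_s, g_v
    # Intermediate steps: reuse the cached direction, rescaled to the
    # magnitude of the current gradient.
    return g + alpha * (g.norm() / g_v.norm()) * g_v, g_v
```

On the k-th step the caller passes both `g` and `g_s`; on the other k - 1 steps it passes `g` together with the cached `g_v`.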
In the rewritten `step` method you can pass several arguments (a rough sketch of how the flags are used follows the list):

- `t` is the train index, i.e. the index of the current batch;
- `samples` are the input data;
- `targets` are the ground-truth data;
- `zero_sam_grad` is a boolean flag to zero the gradients under the SAM condition (first step) (see the discussion here);
- `zero_grad` is a boolean flag to zero the gradients after the second step.
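As a rough mental model of what the two boolean flags govern (a simplified, hypothetical sketch, not the actual `step` implementation in this repository):

```python
def step_sketch(t, samples, targets, model, criterion, base_optimizer,
                k=10, zero_sam_grad=True, zero_grad=True):
    """Hypothetical control flow around the two boolean flags."""
    if t % k == 0:
        # SAM (first) step: the already-computed gradients define the ascent
        # perturbation; the weights are moved to the perturbed point (omitted
        # here), and the stale gradients may be discarded before
        # backpropagating at the perturbed weights.
        if zero_sam_grad:
            model.zero_grad()
        criterion(model(samples), targets).backward()
        # ...the original weights are then restored and the flat-region
        # direction is cached for the next k - 1 iterations...
    # Second step: apply the update with the wrapped base optimizer.
    base_optimizer.step()
    if zero_grad:
        # Start the next iteration with clean gradients.
        model.zero_grad()
```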
The unofficial SAM repo is my inspiration :)
```python
from looksam import LookSAM

model = YourModel()
criterion = YourCriterion()
base_optimizer = YourBaseOptimizer
loader = YourLoader()

optimizer = LookSAM(
    k=10,
    alpha=0.7,
    model=model,
    base_optimizer=base_optimizer,
    rho=0.1,
    **kwargs
)

...

model.train()

for train_index, (samples, targets) in enumerate(loader):
    ...
    loss = criterion(model(samples), targets)
    loss.backward()
    optimizer.step(
        t=train_index,
        samples=samples,
        targets=targets,
        zero_sam_grad=True,
        zero_grad=True
    )
    ...
```