
This is an unofficial implementation of the paper "Towards Efficient and Scalable Sharpness-Aware Minimization".


rollovd/LookSAM


LookSAM Optimizer

Towards Efficient and Scalable Sharpness-Aware Minimization

~ in PyTorch ~

LookSAM is an accelerated SAM algorithm. Instead of computing the inner gradient ascent at every step, LookSAM computes it only periodically and, in between, reuses the gradient direction that steers the weights toward flat regions.
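The reuse rule can be sketched as follows. This is a minimal, framework-free illustration based on our reading of the paper, not the code in looksam.py: gradients are plain Python lists, the names orthogonal_component, reused_sam_gradient and LookSAMDirection are hypothetical, and compute_sam_grad stands in for the extra forward/backward pass at the perturbed weights. Every k steps the full SAM gradient is computed and its component orthogonal to the plain gradient g is stored; on the other steps that stored direction is added back to g, rescaled by alpha * |g| / |g_v|.

```python
import math
from typing import Callable, List, Optional

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def norm(a: Vector) -> float:
    return math.sqrt(dot(a, a))


def orthogonal_component(g: Vector, g_sam: Vector) -> Vector:
    """Remove from g_sam its projection onto the plain gradient g."""
    scale = dot(g, g_sam) / (dot(g, g) + 1e-12)
    return [s - scale * x for s, x in zip(g_sam, g)]


def reused_sam_gradient(g: Vector, g_v: Vector, alpha: float) -> Vector:
    """Approximate the SAM gradient as g + alpha * (|g| / |g_v|) * g_v."""
    ratio = alpha * norm(g) / (norm(g_v) + 1e-12)
    return [x + ratio * v for x, v in zip(g, g_v)]


class LookSAMDirection:
    """Hypothetical helper that picks the update direction for step t."""

    def __init__(self, k: int = 10, alpha: float = 0.7) -> None:
        self.k = k
        self.alpha = alpha
        self.g_v: Optional[Vector] = None  # stored orthogonal direction

    def direction(self, t: int, g: Vector,
                  compute_sam_grad: Callable[[], Vector]) -> Vector:
        if t % self.k == 0 or self.g_v is None:
            # Expensive step: full SAM gradient (extra forward/backward pass).
            g_sam = compute_sam_grad()
            self.g_v = orthogonal_component(g, g_sam)
            return g_sam
        # Cheap step: reuse the stored direction, rescaled to the current |g|.
        return reused_sam_gradient(g, self.g_v, self.alpha)
```

With k=3, compute_sam_grad is called only at t=0 and t=3 over six steps; the other four steps cost a single ordinary gradient.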

Currently, only the base algorithm is implemented, without layer-wise adaptive rates (support for them is planned).

The rewritten step method accepts the following arguments:

  1. t: the training index (train_index), i.e. the index of the current batch;
  2. samples: the input data;
  3. targets: the ground-truth labels for the input data;
  4. zero_sam_grad: a boolean; whether to zero the gradients during the SAM (first) step (see the discussion here);
  5. zero_grad: a boolean; whether to zero the gradients after the second step.
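The role of t can be illustrated by counting gradient passes. The sketch below assumes (this is not taken from looksam.py) that a full SAM step, with two forward/backward passes, runs whenever t is a multiple of k, and that every other step needs a single pass; needs_sam_pass and gradient_passes are hypothetical names.

```python
def needs_sam_pass(t: int, k: int) -> bool:
    """Assumed schedule: the SAM ascent is recomputed when t is a multiple of k."""
    return t % k == 0


def gradient_passes(num_steps: int, k: int) -> int:
    """Forward/backward passes over num_steps batches (2 on SAM steps, 1 otherwise)."""
    return sum(2 if needs_sam_pass(t, k) else 1 for t in range(num_steps))


if __name__ == "__main__":
    for k in (1, 5, 10):
        # prints 200, 120 and 110 passes respectively
        print(f"k={k}: {gradient_passes(100, k)} passes for 100 batches")
```

With k=1 every step is a SAM step (200 passes, as in plain SAM); with k=10 the cost drops to 110 passes, close to the 100 passes of the base optimizer alone.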

The unofficial SAM repository was my inspiration :)

Usage

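from looksam import LookSAM


model = YourModel()
criterion = YourCriterion()
base_optimizer = YourBaseOptimizer  # the optimizer class itself (e.g. torch.optim.SGD), not an instance
loader = YourLoader()

optimizer = LookSAM(
    k=10,  # recompute the SAM gradient every k steps
    alpha=0.7,  # scaling factor for the reused direction
    model=model,
    base_optimizer=base_optimizer,
    rho=0.1,  # radius of the SAM neighborhood
    **kwargs  # further settings (e.g. lr), presumably forwarded to the base optimizer
)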

...

model.train()

for train_index, (samples, targets) in enumerate(loader):
    ...

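    loss = criterion(model(samples), targets)
    loss.backward()  # plain gradient at the current weights
    optimizer.step(
        t=train_index,  # decides when the SAM gradient is recomputed
        samples=samples,  # samples/targets let the optimizer recompute the loss
        targets=targets,
        zero_sam_grad=True,
        zero_grad=True
    )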
    ...
