Add utility to track the Hessian throughout training (WIP) #10

Open · wants to merge 4 commits into base: master
File renamed without changes.
216 changes: 216 additions & 0 deletions example/tracking_example.py
@@ -0,0 +1,216 @@
"""
Main file to orchestrate model training. This script tracks the top Hessian
eigenvalues during training and logs them with `track`.
"""

import os
import time

import torch
import track

import skeletor
from skeletor.datasets import build_dataset, num_classes
from skeletor.models import build_model
from skeletor.optimizers import build_optimizer
from skeletor.utils import AverageMeter, accuracy, progress_bar

from hessian_eigenthings import HessianTracker

def add_train_args(parser):
# Main arguments go here
parser.add_argument('--arch', default='ResNet18')
parser.add_argument('--dataset', default='cifar10')
parser.add_argument('--lr', default=.1, type=float)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--eval_batch_size', default=100, type=int)
parser.add_argument('--epochs', default=200, type=int)
parser.add_argument('--schedule', type=int, nargs='+', default=[150, 190],
help='Decrease learning rate at these epochs.')
parser.add_argument('--gamma', type=float, default=0.1,
help='LR is multiplied by gamma on schedule.')
parser.add_argument('--momentum', default=.9, type=float,
help='SGD momentum')
parser.add_argument('--weight_decay', default=5e-4, type=float,
help='SGD weight decay')
parser.add_argument('--cuda', action='store_true',
help='if true, use GPU!')
parser.add_argument('--num_eigenthings', default=1, type=int,
help='number of eigenvalues to track')


def adjust_learning_rate(epoch, optimizer, lr, schedule, decay):
if epoch in schedule:
new_lr = lr * decay
for param_group in optimizer.param_groups:
param_group['lr'] = new_lr
else:
new_lr = lr
return new_lr
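
# With the defaults above (lr=0.1, schedule=[150, 190], gamma=0.1), the
# learning rate decays to 0.01 at epoch 150 and to 0.001 at epoch 190,
# because do_training feeds the returned value back in as the next `lr`.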


def train(trainloader, model, criterion, optimizer, epoch, cuda=False):
# switch to train mode
model.train()

batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()
end = time.time()

for batch_idx, (inputs, targets) in enumerate(trainloader):
# measure data loading time
data_time.update(time.time() - end)

if cuda:
            inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)

# compute output
outputs = model(inputs)
loss = criterion(outputs, targets)

# measure accuracy and record loss
prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update(prec1.item(), inputs.size(0))
top5.update(prec5.item(), inputs.size(0))

# compute gradient and do SGD step
optimizer.zero_grad()
loss.backward()
optimizer.step()

# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()

# plot progress
progress_str = 'Loss: %.3f | Acc: %.3f%% (%d/%d)'\
% (losses.avg, top1.avg, top1.sum, top1.count)
progress_bar(batch_idx, len(trainloader), progress_str)

iteration = epoch * len(trainloader) + batch_idx
track.metric(iteration=iteration, epoch=epoch,
avg_train_loss=losses.avg,
avg_train_acc=top1.avg,
cur_train_loss=loss.item(),
cur_train_acc=prec1.item())
return (losses.avg, top1.avg)


def test(testloader, model, criterion, epoch, cuda=False):
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()

# switch to evaluate mode
model.eval()

end = time.time()
with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(testloader):
# measure data loading time
data_time.update(time.time() - end)

if cuda:
inputs, targets = inputs.cuda(), targets.cuda()

# compute output
outputs = model(inputs)
loss = criterion(outputs, targets)

# measure accuracy and record loss
prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update(prec1.item(), inputs.size(0))
top5.update(prec5.item(), inputs.size(0))

# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()

# plot progress
progress_str = 'Loss: %.3f | Acc: %.3f%% (%d/%d)'\
% (losses.avg, top1.avg, top1.sum, top1.count)
progress_bar(batch_idx, len(testloader), progress_str)
track.metric(iteration=0, epoch=epoch,
avg_test_loss=losses.avg,
avg_test_acc=top1.avg)
return (losses.avg, top1.avg)


def do_training(args):
trainloader, testloader = build_dataset(args.dataset,
dataroot=args.dataroot,
batch_size=args.batch_size,
eval_batch_size=args.eval_batch_size,
num_workers=2)
model = build_model(args.arch, num_classes=num_classes(args.dataset))
if args.cuda:
model = torch.nn.DataParallel(model).cuda()

# Calculate total number of model parameters
num_params = sum(p.numel() for p in model.parameters())
track.metric(iteration=0, num_params=num_params)

optimizer = build_optimizer('SGD', params=model.parameters(), lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay)

criterion = torch.nn.CrossEntropyLoss()

# Create the hessian tracker!
hessian_tracker = HessianTracker(model, trainloader, criterion,
num_eigenthings=args.num_eigenthings,
momentum=0.9,
use_gpu=args.cuda)

best_acc = 0.0
for epoch in range(args.epochs):
track.debug("Starting epoch %d" % epoch)
args.lr = adjust_learning_rate(epoch, optimizer, args.lr, args.schedule,
args.gamma)
train_loss, train_acc = train(trainloader, model, criterion,
optimizer, epoch, args.cuda)
test_loss, test_acc = test(testloader, model, criterion, epoch,
args.cuda)
track.debug('Finished epoch %d... | train loss %.3f | train acc %.3f '
'| test loss %.3f | test acc %.3f'
% (epoch, train_loss, train_acc, test_loss, test_acc))

# Save model
model_fname = os.path.join(track.trial_dir(),
"model{}.ckpt".format(epoch))
torch.save(model, model_fname)
if test_acc > best_acc:
best_acc = test_acc
best_fname = os.path.join(track.trial_dir(), "best.ckpt")
track.debug("New best score! Saving model")
torch.save(model, best_fname)

# Compute the hessian spectrum
hessian_tracker.step()
eigenvals, _ = hessian_tracker.get_eigenthings()
track.debug('Top eigenvalue: %.3f' % float(max(eigenvals)))
track.metric(iteration=0, epoch=epoch,
eigenvals=eigenvals)


def postprocess(proj):
df = skeletor.proc.df_from_proj(proj)
if 'avg_test_acc' in df.columns:
        best_trial = df.loc[df['avg_test_acc'].idxmax()]
print("Trial with top accuracy:")
print(best_trial)


if __name__ == '__main__':
skeletor.supply_args(add_train_args)
skeletor.supply_postprocess(postprocess, save_proj=True)
skeletor.execute(do_training)
4 changes: 3 additions & 1 deletion hessian_eigenthings/__init__.py
@@ -1,7 +1,9 @@
""" Top-level module for hessian eigenvec computation """
from . import power_iter
from .hvp_operator import HVPOperator, compute_hessian_eigenthings
from .tracking import HessianTracker

__all__ = ['power_iter', 'HVPOperator', 'compute_hessian_eigenthings']
__all__ = ['power_iter', 'HVPOperator', 'compute_hessian_eigenthings',
'HessianTracker']

name = 'hessian_eigenthings'
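With this change, the tracker is importable from the package top level, e.g. `from hessian_eigenthings import HessianTracker`, which is how the example script above uses it.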
9 changes: 7 additions & 2 deletions hessian_eigenthings/power_iter.py
@@ -70,15 +70,21 @@ def _new_op_fn(x, op=current_op, val=eigenval, vec=eigenvec):


 def power_iteration(operator, steps=20, error_threshold=1e-4,
-                    momentum=0.0, use_gpu=True):
+                    momentum=0.0, init_vec=None, use_gpu=True):
     """
     Compute the dominant eigenvalue/eigenvector of a matrix.
     operator: linear Operator giving us matrix-vector product access
     steps: number of update steps to take
+    init_vec: vector to start iterating from; a random vector if None
     returns: (principal eigenvalue, principal eigenvector) pair
     """
     vector_size = operator.size  # input dimension of operator
-    vec = torch.rand(vector_size)
+
+    if init_vec is None:
+        vec = torch.rand(vector_size)
+    else:
+        vec = init_vec
+
     if use_gpu:
         vec = vec.cuda()
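
The warm start this hunk enables is easy to check in isolation: power iteration seeded with a previous eigenvector estimate re-converges in far fewer steps than one seeded randomly, which is what `HessianTracker.step` exploits below. A minimal sketch on a toy symmetric matrix (the names here are illustrative, not part of this diff):

import torch

def toy_power_iteration(A, steps, init_vec=None):
    # Same scheme as power_iteration above, minus momentum and early stopping.
    vec = torch.rand(A.size(0)) if init_vec is None else init_vec.clone()
    for _ in range(steps):
        vec = A @ vec
        vec = vec / vec.norm()
    return (vec @ (A @ vec)).item(), vec

A = torch.diag(torch.tensor([3.0, 1.0, 0.5]))
val, vec = toy_power_iteration(A, steps=50)              # cold start: many steps
A = A + 0.01 * torch.eye(3)                              # "training" nudges the Hessian
val2, _ = toy_power_iteration(A, steps=5, init_vec=vec)  # warm start: few steps suffice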

95 changes: 95 additions & 0 deletions hessian_eigenthings/tracking.py
@@ -0,0 +1,95 @@
"""
This module enables tracking the Hessian throughout training by incrementally
updating eigenvalue/eigenvector estimates as the model progresses.
"""
from .hvp_operator import HVPOperator
from .power_iter import LambdaOperator, power_iteration, deflated_power_iteration


class HessianTracker:
"""
    This class incrementally tracks the top `num_eigenthings`
    eigenvalue/eigenvector pairs of the Hessian of the given model's loss,
    using accelerated stochastic power iteration with deflation.

    model: PyTorch model whose loss Hessian we want to track
    dataloader: PyTorch dataloader that lets us compute gradients over the data
    loss: objective function whose Hessian we compute
    num_eigenthings: number of eigenvalue/eigenvector pairs to track
    power_iter_steps: default number of power iteration steps per deflation step
    power_iter_err_threshold: error tolerance for early stopping in power iteration
    momentum: acceleration term for accelerated stochastic power iteration
    max_samples: max number of samples we can compute the gradient of at once
    use_gpu: whether to use CUDA
"""
def __init__(self, model, dataloader, loss,
num_eigenthings=10,
power_iter_steps=20,
power_iter_err_threshold=1e-4,
momentum=0.0,
max_samples=512,
use_gpu=True):
self.num_eigenthings = num_eigenthings
self.hvp_operator = HVPOperator(model, dataloader, loss,
use_gpu=use_gpu, max_samples=max_samples)

# This function computes the initial eigenthing estimates
def _deflated_power_fn(op):
return deflated_power_iteration(op,
num_eigenthings,
power_iter_steps,
power_iter_err_threshold,
momentum,
use_gpu)
self.deflated_power_fn = _deflated_power_fn

# This function will update a single vector using the deflated op
def _power_iter_fn(op, prev, steps):
return power_iteration(op,
steps,
power_iter_err_threshold,
momentum,
prev,
use_gpu)
self.power_iter_fn = _power_iter_fn
self.power_iter_steps = power_iter_steps

        # No estimates yet; they are initialized on the first call to step()
self.eigenvecs = None
self.eigenvals = None

def step(self, power_iter_steps=None):
"""
Perform power iteration, starting from the initial eigen estimates
we accrued from the previous steps.
"""
# Take the first estimate if we need to.
if self.eigenvals is None:
self.eigenvals, self.eigenvecs = self.deflated_power_fn(self.hvp_operator)
return

# Allow a variable number of update steps during training.
if power_iter_steps is None:
power_iter_steps = self.power_iter_steps

# Update existing estimates, one at a time.
def _deflate(x, val, vec):
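            # Computes (val * vec vec^T) x, the rank-one term that Hotelling
            # deflation subtracts from the operator.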
return val * vec.dot(x) * vec

current_op = self.hvp_operator
for i in range(self.num_eigenthings):
prev_eigenvec = self.eigenvecs[i]
# Use the previous eigenvec estimate as the starting point.
new_eigenval, new_eigenvec = self.power_iter_fn(current_op, prev_eigenvec,
power_iter_steps)

# Deflate the HVP operator using this new estimate.
def _new_op_fn(x, op=current_op, val=new_eigenval, vec=new_eigenvec):
return op.apply(x) - _deflate(x, val, vec)
current_op = LambdaOperator(_new_op_fn, self.hvp_operator.size)
self.eigenvals[i] = new_eigenval
self.eigenvecs[i] = new_eigenvec

def get_eigenthings(self):
""" Get current estimate of the eigenthings """
return self.eigenvals, self.eigenvecs
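
A minimal usage sketch for the tracker (assuming a `model`, `dataloader`, and `criterion` already exist; `train_one_epoch` and `num_epochs` are hypothetical stand-ins for a training loop):

tracker = HessianTracker(model, dataloader, criterion,
                         num_eigenthings=5, use_gpu=False)
for epoch in range(num_epochs):
    train_one_epoch(model, dataloader)                # update the weights
    tracker.step()                                    # refine warm-started estimates
    eigenvals, eigenvecs = tracker.get_eigenthings()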