From 9b1afe0da23ef8a083f965402d9a24c6676c25fc Mon Sep 17 00:00:00 2001
From: Refik Can Malli
Date: Tue, 16 Feb 2021 21:34:27 +0100
Subject: [PATCH 1/5] Add online algorithm wrapper

---
 src/omniglot/main.py              |   2 +-
 src/omniglot/model.py             |  30 +++--
 src/omniglot/wrapper.py           | 188 ++++++++++++++++++++++++++++--
 src/warpgrad/warpgrad/__init__.py |   2 +-
 4 files changed, 202 insertions(+), 20 deletions(-)

diff --git a/src/omniglot/main.py b/src/omniglot/main.py
index 928bf99..e57bf8f 100644
--- a/src/omniglot/main.py
+++ b/src/omniglot/main.py
@@ -74,7 +74,7 @@
                     help='Turn off batch normalization')
 parser.add_argument('--meta_model', type=str, default='warp_leap',
-                    help='Meta-learner [warp_leap, leap, reptile,'
+                    help='Meta-learner [warp_leap, warp_online, leap, reptile,'
                          'maml, fomaml, ft, no]')
 parser.add_argument('--inner_opt', type=str, default='sgd',
                     help='Optimizer in inner (task) loop: SGD or Adam')
diff --git a/src/omniglot/model.py b/src/omniglot/model.py
index 4615a7b..deae8cd 100644
--- a/src/omniglot/model.py
+++ b/src/omniglot/model.py
@@ -3,7 +3,8 @@
 """
 import torch.nn as nn
 from wrapper import (WarpGradWrapper, LeapWrapper, MAMLWrapper, NoWrapper,
-                     FtWrapper, FOMAMLWrapper, ReptileWrapper)
+                     FtWrapper, FOMAMLWrapper, ReptileWrapper,
+                     WarpGradOnlineWrapper)
 
 NUM_CLASSES = 50
 ACT_FUNS = {
@@ -43,14 +44,25 @@ def get_model(args, criterion):
         print(model)
 
     if "warp" in args.meta_model.lower():
-        return WarpGradWrapper(
-            model,
-            args.inner_opt,
-            args.outer_opt,
-            args.inner_kwargs,
-            args.outer_kwargs,
-            args.meta_kwargs,
-            criterion)
+        # Use the online algorithm wrapper if requested.
+        if "online" in args.meta_model.lower():
+            return WarpGradOnlineWrapper(
+                model,
+                args.inner_opt,
+                args.outer_opt,
+                args.inner_kwargs,
+                args.outer_kwargs,
+                args.meta_kwargs,
+                criterion)
+        else:
+            return WarpGradWrapper(
+                model,
+                args.inner_opt,
+                args.outer_opt,
+                args.inner_kwargs,
+                args.outer_kwargs,
+                args.meta_kwargs,
+                criterion)
 
     if args.meta_model.lower() == 'leap':
         return LeapWrapper(
diff --git a/src/omniglot/wrapper.py b/src/omniglot/wrapper.py
index 6fb6af1..b88ea37 100644
--- a/src/omniglot/wrapper.py
+++ b/src/omniglot/wrapper.py
@@ -13,10 +13,11 @@
 from leap.utils import clone_state_dict
 from utils import Res, AggRes
 
+from warpgrad import SGD
+from warpgrad.utils import step, backward, unfreeze, freeze
 
 class BaseWrapper(object):
-
     """Generic training wrapper.
 
     Arguments:
@@ -123,7 +124,7 @@ def run_batches(self, batches, optimizer, train=False, meta_train=False):
             if not train:
                 continue
 
-            final = (n+1) == N
+            final = (n + 1) == N
             loss.backward()
 
             if meta_train:
@@ -139,8 +140,182 @@ def run_batches(self, batches, optimizer, train=False, meta_train=False):
 
         return res
 
 
-class WarpGradWrapper(BaseWrapper):
+class WarpGradOnlineWrapper(BaseWrapper):
+    """Wrapper around WarpGrad meta-learners using the online Algorithm 1.
+
+    Arguments:
+        model (nn.Module): classifier.
+        optimizer_cls: optimizer class.
+        meta_optimizer_cls: meta optimizer class.
+        optimizer_kwargs (dict): kwargs to pass to optimizer upon construction.
+        meta_optimizer_kwargs (dict): kwargs to pass to meta optimizer upon
+            construction.
+        meta_kwargs (dict): kwargs to pass to meta-learner upon construction.
+        criterion (func): loss criterion to use.
+    """
+
+    def __init__(self,
+                 model,
+                 optimizer_cls,
+                 meta_optimizer_cls,
+                 optimizer_kwargs,
+                 meta_optimizer_kwargs,
+                 meta_kwargs,
+                 criterion):
+
+        optimizer_parameters = warpgrad.OptimizerParameters(
+            trainable=meta_kwargs.pop('learn_opt', False),
+            default_lr=optimizer_kwargs['lr'],
+            default_momentum=optimizer_kwargs['momentum']
+            if 'momentum' in optimizer_kwargs else 0.)
+
+        # For now this is a dummy updater that does nothing in backward.
+        updater = warpgrad.SimpleUpdater(criterion, **meta_kwargs)
+
+        # We don't need a replay buffer for Algorithm 1.
+        model = warpgrad.Warp(model=model,
+                              adapt_modules=list(model.adapt_modules()),
+                              warp_modules=list(model.warp_modules()),
+                              updater=updater,
+                              buffer=None,
+                              optimizer_parameters=optimizer_parameters)
+
+        super(WarpGradOnlineWrapper, self).__init__(criterion,
+                                                    model,
+                                                    optimizer_cls,
+                                                    optimizer_kwargs)
+        self.meta_optimizer_cls = optim.SGD \
+            if meta_optimizer_cls.lower() == 'sgd' else optim.Adam
+        lra = meta_optimizer_kwargs.pop(
+            'lr_adapt', meta_optimizer_kwargs['lr'])
+        lri = meta_optimizer_kwargs.pop(
+            'lr_init', meta_optimizer_kwargs['lr'])
+        lrl = meta_optimizer_kwargs.pop(
+            'lr_lr', meta_optimizer_kwargs['lr'])
+        self.meta_optimizer = self.meta_optimizer_cls(
+            [{'params': self.model.init_parameters(), 'lr': lri},
+             {'params': self.model.warp_parameters(), 'lr': lra},
+             {'params': self.model.optimizer_parameters(), 'lr': lrl}],
+            **meta_optimizer_kwargs)
+
+        # This is the meta loss that we are going to accumulate.
+        self.meta_loss = 0
+
+    def _partial_meta_update(self, loss, final):
+        pass
+
+    def _final_meta_update(self):
+
+        def step_fn():
+            self.meta_optimizer.step()
+            self.meta_optimizer.zero_grad()
+
+        self.model.backward(step_fn, **self.optimizer_kwargs)
+
+    def run_tasks(self, tasks, meta_train):
+        """Train on a mini-batch of tasks and evaluate test performance.
+
+        Arguments:
+            tasks (list, torch.utils.data.DataLoader): list of task-specific
+                dataloaders.
+            meta_train (bool): whether this run is during meta-training.
+        """
+        results = []
+        self.meta_loss = 0
+        for task in tasks:
+            task.dataset.train()
+            trainres = self.run_task(task, train=True, meta_train=meta_train)
+            task.dataset.eval()
+            valres = self.run_task(task, train=False, meta_train=False)
+            results.append((trainres, valres))
+
+        results = AggRes(results)
+
+        # Meta gradient step
+        if meta_train:
+            # After collecting the meta loss over K steps on N tasks,
+            # we do the backward pass.
+            backward(self.meta_loss, self.model.meta_parameters(
+                include_init=False))
+            self._final_meta_update()
+
+        return results
+
+    def run_task(self, task, train, meta_train):
+        """Run model on a given task, first adapting and then evaluating."""
+        self.model.no_collect()
+
+        optimizer = None
+        if train:
+            # TODO: Discuss implementation and correct it.
+            # This line breaks gradient computation for now:
+            # the meta layers' requires_grad properties are set to False if
+            # we call init_adaptation.
+            # self.model.init_adaptation()
+            self.model.train()
+
+            optimizer = self.optimizer_cls(
+                self.model.optimizer_parameter_groups(),
+                **self.optimizer_kwargs)
+        else:
+            self.model.eval()
+
+        return self.run_batches(
+            task, optimizer, train=train, meta_train=meta_train)
+
+    def run_batches(self, batches, optimizer, train=False, meta_train=False):
+        """Iterate over task-specific batches.
+
+        Arguments:
+            batches (torch.utils.data.DataLoader): task-specific dataloaders.
+            optimizer (torch.optim): optimizer instance if training is True.
+            train (bool): whether to train on task.
+            meta_train (bool): whether to meta-train on task.
+        """
+        device = next(self.model.parameters()).device
+        self.model.no_collect()
+        res = Res()
+        N = len(batches)
+        for n, (input, target) in enumerate(batches):
+            inner_input = input.to(device, non_blocking=True)
+            inner_target = target.to(device, non_blocking=True)
+
+            # Evaluate model
+            prediction = self.model(inner_input)
+            loss = self.criterion(prediction, inner_target)
+
+            res.log(loss=loss.item(), pred=prediction, target=inner_target)
+
+            # TRAINING #
+            if not train:
+                continue
+
+            final = (n + 1) == N
+            loss.backward()
+
+            if meta_train:
+                opt = SGD(self.model.optimizer_parameter_groups(tensor=True))
+                opt.zero_grad()
+                outer_input, outer_target = next(iter(batches))
+                l_outer, (l_inner, a1, a2) = step(
+                    criterion=self.criterion,
+                    x_inner=inner_input, x_outer=outer_input,
+                    y_inner=inner_target, y_outer=outer_target,
+                    model=self.model,
+                    optimizer=opt, scorer=None)
+                self.meta_loss = self.meta_loss + l_outer
+                del l_inner, a1, a2
+
+            optimizer.step()
+            optimizer.zero_grad()
+            if final:
+                break
+        res.aggregate()
+        return res
+
+
+class WarpGradWrapper(BaseWrapper):
-
     """Wrapper around WarpGrad meta-learners.
 
     Arguments:
@@ -242,7 +417,6 @@ def run_task(self, task, train, meta_train):
 
 
 class LeapWrapper(BaseWrapper):
-
     """Wrapper around the Leap meta-learner.
 
     Arguments:
@@ -294,7 +468,6 @@ def run_task(self, task, train, meta_train):
 
 
 class MAMLWrapper(object):
-
     """Wrapper around the MAML meta-learner.
 
     Arguments:
@@ -358,7 +531,6 @@ def run_meta_batch(self, meta_batch, meta_train):
 
 
 class NoWrapper(BaseWrapper):
-
     """Wrapper for baseline without any meta-learning.
 
     Arguments:
@@ -367,6 +539,7 @@ class NoWrapper(BaseWrapper):
         optimizer_kwargs (dict): kwargs to pass to optimizer upon construction.
         criterion (func): loss criterion to use.
     """
+
     def __init__(self, model, optimizer_cls, optimizer_kwargs, criterion):
         super(NoWrapper, self).__init__(criterion,
                                         model,
@@ -390,7 +563,6 @@ def _final_meta_update(self):
 
 
 class _FOWrapper(BaseWrapper):
-
     """Base wrapper for First-order MAML and Reptile.
 
     Arguments:
@@ -476,7 +648,6 @@ def _final_meta_update(self):
 
 
 class ReptileWrapper(_FOWrapper):
-
     """Wrapper for Reptile.
 
     Arguments:
@@ -515,7 +686,6 @@ def __init__(self, *args, **kwargs):
 
 
 class FtWrapper(BaseWrapper):
-
     """Wrapper for Multi-headed finetuning.
 
     This wrapper differs from others in that it blends batches from all tasks
diff --git a/src/warpgrad/warpgrad/__init__.py b/src/warpgrad/warpgrad/__init__.py
index 1594f57..f73534a 100644
--- a/src/warpgrad/warpgrad/__init__.py
+++ b/src/warpgrad/warpgrad/__init__.py
@@ -1,3 +1,3 @@
 from .warpgrad import Warp, OptimizerParameters, ReplayBuffer
-from .updaters import DualUpdater
+from .updaters import DualUpdater, SimpleUpdater
 from .optim import SGD, Adam

From a5157c221cc510d85a64d039cc3b43f361c87eee Mon Sep 17 00:00:00 2001
From: Refik Can Malli
Date: Tue, 16 Feb 2021 21:34:46 +0100
Subject: [PATCH 2/5] Add simple updater as placeholder

---
 src/warpgrad/warpgrad/updaters.py | 47 ++++++++++++++++++++++++++++---
 1 file changed, 43 insertions(+), 4 deletions(-)

diff --git a/src/warpgrad/warpgrad/updaters.py b/src/warpgrad/warpgrad/updaters.py
index e99fff5..020ae0e 100644
--- a/src/warpgrad/warpgrad/updaters.py
+++ b/src/warpgrad/warpgrad/updaters.py
@@ -22,8 +22,46 @@
     state_dict_to_par_list)
 
 
-class DualUpdater:
+class SimpleUpdater:
+    """Placeholder updater; its backward pass is currently a no-op.
+    """
+
+    def __init__(self, criterion, init_objective=0,
+                 epochs=1, bsz=1, norm=True, approx=False):
+        """Initialize a dummy updater.
+
+        Arguments:
+            criterion (function): task loss criterion.
+            init_objective (int): type of objective for initialization
+                (optional).
+            epochs (int): number of times to iterate over buffer (default=1).
+            bsz (int): task parameter batch size between updates (default=1).
+            norm (bool): use the norm in the Leap objective (d1)
+                (default=True).
+            approx (bool): use approximate (Hessian-free) meta-objective.
+        """
+        self.init_objective = init_objective
+        self.criterion = criterion
+        self.epochs = epochs
+        self.approx = approx
+        self.norm = norm
+        self.bsz = bsz
+
+    def backward(self, model, step_fn, **opt_kwargs):
+        """Do nothing for now (placeholder for the meta gradient step).
+
+        Arguments:
+            model (Warp): warped model to backprop through.
+            step_fn (function): step function for the meta gradient.
+            **opt_kwargs (kwargs): optional arguments to inner optimizer.
+        """
+
+        # init_objective = INIT_OBJECTIVES[self.init_objective]
+        # init_objective(model.named_init_parameters(suffix=None),
+        #             params, self.norm, self.bsz, step_fn)
+
+
+class DualUpdater:
-
     """Implements the WarpGrad meta-objective.
 
     This updater applies the WarpGrad meta-objective to warp-parameters and
@@ -73,10 +111,11 @@ def backward(self, model, step_fn, **opt_kwargs):
         warp_objective(model, self.criterion, params, optimizer_buffers,
                        data, step_fn, opt_kwargs, self.epochs, self.bsz,
                        self.approx)
-        init_objective= INIT_OBJECTIVES[self.init_objective]
+        init_objective = INIT_OBJECTIVES[self.init_objective]
         init_objective(model.named_init_parameters(suffix=None),
                        params, self.norm, self.bsz, step_fn)
 
+
 def warp_on_same_loss(model, criterion, trj, brj, tds, step_fn,
                       opt_kwargs, epochs, bsz, approx):
     """WarpGrad uses same objective in first and second step."""
@@ -124,7 +163,7 @@ def _step(batch):
 
     if bsz > 0:
         for i in range(0, len(datapoints), bsz):
-            _step(datapoints[i:i+bsz])
+            _step(datapoints[i:i + bsz])
     else:
         _step(datapoints)
 
@@ -147,7 +186,7 @@ def simplified_leap(named_init, trj, norm, bsz, step_fn):
         joblib.delayed(line_seg_len)(
             trj[t][i], trj[t][i + 1], par_names, norm, device)
         for t in trj
-        for i in range(0, len(trj[t])-1)
+        for i in range(0, len(trj[t]) - 1)
     )
 
     for i, a in zip(init, zip(*adds)):

From 793564b1ecd389cf31e04d24e3569f8c6667a4e9 Mon Sep 17 00:00:00 2001
From: Refik Can Malli
Date: Tue, 16 Feb 2021 21:35:04 +0100
Subject: [PATCH 3/5] Add condition for checking buffer

---
 src/warpgrad/warpgrad/warpgrad.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/warpgrad/warpgrad/warpgrad.py b/src/warpgrad/warpgrad/warpgrad.py
index 19d5d97..ab86452 100644
--- a/src/warpgrad/warpgrad/warpgrad.py
+++ b/src/warpgrad/warpgrad/warpgrad.py
@@ -467,7 +467,8 @@ def init_adaptation(self, reset_adapt_parameters=None):
 
     def clear(self):
         """Clears parameter trajectory buffer."""
-        self.buffer.clear()
+        if self.buffer is not None:
+            self.buffer.clear()
 
     def collect(self):
         """Switch on task parameter collection in buffer."""

From 0a3e9021511b3bece6854df2a21515d6403c1474 Mon Sep 17 00:00:00 2001
From: Refik Can Malli
Date: Wed, 17 Feb 2021 10:50:10 +0100
Subject: [PATCH 4/5] Update updaters.py

---
 src/warpgrad/warpgrad/updaters.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/warpgrad/warpgrad/updaters.py b/src/warpgrad/warpgrad/updaters.py
index 020ae0e..94cf599 100644
--- a/src/warpgrad/warpgrad/updaters.py
+++ b/src/warpgrad/warpgrad/updaters.py
@@ -59,7 +59,7 @@ def backward(self, model, step_fn, **opt_kwargs):
 
         # init_objective = INIT_OBJECTIVES[self.init_objective]
         # init_objective(model.named_init_parameters(suffix=None),
 #             params, self.norm, self.bsz, step_fn)
-
+        pass

From bb23226888f7c4394cc35ffdf690b9983b05e901 Mon Sep 17 00:00:00 2001
From: Refik Can Malli
Date: Fri, 5 Mar 2021 20:04:05 +0100
Subject: [PATCH 5/5] Add handling of initialization state

---
 src/omniglot/model.py   |  2 +-
 src/omniglot/wrapper.py | 15 ++++++++++-----
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/src/omniglot/model.py b/src/omniglot/model.py
index deae8cd..54242d2 100644
--- a/src/omniglot/model.py
+++ b/src/omniglot/model.py
@@ -1,6 +1,7 @@
 """Base Omniglot models. Based on original implementation:
 https://github.com/amzn/metalearn-leap
 """
+import torch
 import torch.nn as nn
 from wrapper import (WarpGradWrapper, LeapWrapper, MAMLWrapper, NoWrapper,
                      FtWrapper, FOMAMLWrapper, ReptileWrapper,
                      WarpGradOnlineWrapper)
@@ -491,7 +492,6 @@ def init_adaptation(self):
         """Reset stats for new task"""
         # Reset head if multi-headed, otherwise null-op
         self.head.reset_parameters()
-
         # Reset BN running stats
         for m in self.modules():
             if hasattr(m, 'reset_running_stats'):
diff --git a/src/omniglot/wrapper.py b/src/omniglot/wrapper.py
index b88ea37..f5af98c 100644
--- a/src/omniglot/wrapper.py
+++ b/src/omniglot/wrapper.py
@@ -14,7 +14,7 @@
 from utils import Res, AggRes
 
 from warpgrad import SGD
-from warpgrad.utils import step, backward, unfreeze, freeze
+from warpgrad.utils import step, backward, unfreeze, freeze, copy
 
 
 class BaseWrapper(object):
@@ -236,23 +236,26 @@ def run_tasks(self, tasks, meta_train):
         if meta_train:
             # After collecting the meta loss over K steps on N tasks,
             # we do the backward pass.
-            backward(self.meta_loss, self.model.meta_parameters(
-                include_init=False))
+            meta_parameters = self.model.meta_parameters(include_init=False)
+            unfreeze(meta_parameters)
+            backward(self.meta_loss, meta_parameters)
             self._final_meta_update()
+            freeze(meta_parameters)
 
         return results
 
     def run_task(self, task, train, meta_train):
         """Run model on a given task, first adapting and then evaluating."""
         self.model.no_collect()
-
         optimizer = None
         if train:
             # TODO: Discuss implementation and correct it.
             # This line breaks gradient computation for now:
             # the meta layers' requires_grad properties are set to False if
             # we call init_adaptation.
-            # self.model.init_adaptation()
+            copy(self.model.adapt_state(), self.model.init_state())
+            freeze(self.model.meta_parameters())
+            unfreeze(self.model.adapt_parameters())
             self.model.train()
 
             optimizer = self.optimizer_cls(
@@ -295,6 +298,7 @@ def run_batches(self, batches, optimizer, train=False, meta_train=False):
             loss.backward()
 
             if meta_train:
+                unfreeze(self.model.meta_parameters())
                 opt = SGD(self.model.optimizer_parameter_groups(tensor=True))
                 opt.zero_grad()
                 outer_input, outer_target = next(iter(batches))
@@ -305,6 +309,7 @@ def run_batches(self, batches, optimizer, train=False, meta_train=False):
                     model=self.model,
                     optimizer=opt, scorer=None)
                 self.meta_loss = self.meta_loss + l_outer
+                freeze(self.model.meta_parameters())
                 del l_inner, a1, a2
 
             optimizer.step()
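
Reviewer note: the accumulate-then-step pattern that WarpGradOnlineWrapper implements
in these patches can be summarized in plain PyTorch. The sketch below is illustrative
only: the adapt/warp module split, sample_batch, and inner_lr are hypothetical
stand-ins, not the WarpGrad API. It backpropagates each outer loss as soon as it is
computed; since gradients are additive, this matches summing the outer losses into
self.meta_loss and calling backward once at the end of run_tasks().

    # Sketch with assumed names (not the WarpGrad API): inner SGD on task
    # parameters while the meta loss is accumulated on held-out batches,
    # followed by a single meta-update, mirroring run_tasks() above.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    adapt_net = nn.Linear(10, 5)   # task-adapted ("inner") parameters
    warp_net = nn.Linear(5, 5)     # shared meta ("warp") parameters
    meta_opt = torch.optim.Adam(warp_net.parameters(), lr=1e-3)
    inner_lr = 0.1                 # illustrative inner-loop learning rate

    def forward(x):
        return warp_net(F.relu(adapt_net(x)))

    def sample_batch():
        # Stand-in for a task-specific dataloader.
        return torch.randn(8, 10), torch.randint(0, 5, (8,))

    meta_opt.zero_grad()
    meta_loss = 0.0
    for _ in range(4):                                # N tasks per meta-batch
        adapt_net.reset_parameters()                  # fresh task parameters
        for _ in range(3):                            # K inner steps per task
            x, y = sample_batch()
            inner_loss = F.cross_entropy(forward(x), y)
            # Inner step: update only the adapt parameters, leaving the
            # warp parameters untouched (the role of freeze/unfreeze above).
            grads = torch.autograd.grad(inner_loss, list(adapt_net.parameters()))
            with torch.no_grad():
                for p, g in zip(adapt_net.parameters(), grads):
                    p -= inner_lr * g
            # Outer contribution: meta loss on a held-out batch, accumulated
            # into the warp parameters' gradients.
            x_out, y_out = sample_batch()
            outer_loss = F.cross_entropy(forward(x_out), y_out)
            outer_loss.backward()
            meta_loss += outer_loss.item()

    meta_opt.step()                                   # one meta-update at the end
    print(f"accumulated meta loss: {meta_loss:.3f}")

In the patches, freeze/unfreeze restrict each step to the right parameter set, and
backward(self.meta_loss, ...) plus _final_meta_update() apply the accumulated meta
gradient in one step after all tasks have been processed.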