diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py
index a51c34c14..051cc1e6f 100644
--- a/domainlab/algos/trainers/a_trainer.py
+++ b/domainlab/algos/trainers/a_trainer.py
@@ -5,6 +5,7 @@
 
 import torch
 from torch import optim
+from torch.optim import lr_scheduler
 
 from domainlab.compos.pcr.p_chain_handler import AbstractChainNodeHandler
 
@@ -27,7 +28,12 @@ def mk_opt(model, aconf):
     #     {'params': model._decoratee.parameters()}
     #     ], lr=aconf.lr)
     optimizer = optim.Adam(list_par, lr=aconf.lr)
-    return optimizer
+    if aconf.lr_scheduler is not None:
+        class_lr_scheduler = getattr(lr_scheduler, aconf.lr_scheduler)
+        scheduler = class_lr_scheduler(optimizer, T_max=aconf.epos)
+    else:
+        scheduler = None
+    return optimizer, scheduler
 
 
 class AbstractTrainer(AbstractChainNodeHandler, metaclass=abc.ABCMeta):
@@ -94,6 +100,8 @@ def __init__(self, successor_node=None, extend=None):
         self.list_reg_over_task_ratio = None
         # MIRO
         self.input_tensor_shape = None
+        # LR-scheduler
+        self.lr_scheduler = None
 
     @property
     def model(self):
@@ -168,7 +176,7 @@ def reset(self):
         """
        make a new optimizer to clear internal state
         """
-        self.optimizer = mk_opt(self.model, self.aconf)
+        self.optimizer, self.lr_scheduler = mk_opt(self.model, self.aconf)
 
     @abc.abstractmethod
     def tr_epoch(self, epoch):
diff --git a/domainlab/algos/trainers/train_basic.py b/domainlab/algos/trainers/train_basic.py
index 179848467..10ac3b06f 100644
--- a/domainlab/algos/trainers/train_basic.py
+++ b/domainlab/algos/trainers/train_basic.py
@@ -82,6 +82,8 @@ def tr_batch(self, tensor_x, tensor_y, tensor_d, others, ind_batch, epoch):
         loss = self.cal_loss(tensor_x, tensor_y, tensor_d, others)
         loss.backward()
         self.optimizer.step()
+        if self.lr_scheduler:
+            self.lr_scheduler.step()
         self.after_batch(epoch, ind_batch)
         self.counter_batch += 1
 
diff --git a/domainlab/arg_parser.py b/domainlab/arg_parser.py
index 046810a66..bb7bda2b4 100644
--- a/domainlab/arg_parser.py
+++ b/domainlab/arg_parser.py
@@ -264,6 +264,13 @@ def mk_parser_main():
         help="name of pytorch optimizer",
     )
 
+    parser.add_argument(
+        "--lr_scheduler",
+        type=str,
+        default=None,
+        help="name of pytorch learning rate scheduler",
+    )
+
     parser.add_argument(
         "--param_idx",
         type=bool,
diff --git a/tests/test_lr_scheduler.py b/tests/test_lr_scheduler.py
new file mode 100644
index 000000000..7bb7ab92f
--- /dev/null
+++ b/tests/test_lr_scheduler.py
@@ -0,0 +1,14 @@
+
+"""
+unit and end-end test for lr scheduler
+"""
+from tests.utils_test import utils_test_algo
+
+
+def test_lr_scheduler():
+    """
+    train
+    """
+    args = "--te_d=2 --tr_d 0 1 --task=mnistcolor10 --debug --bs=100 --model=erm \
+    --nname=conv_bn_pool_2 --no_dump --lr_scheduler=CosineAnnealingLR"
+    utils_test_algo(args)
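Note (not part of the patch): a minimal, self-contained sketch of the optimizer/scheduler wiring this diff introduces. `Conf` and the toy model below are hypothetical stand-ins, not DomainLab APIs; only the `lr`, `lr_scheduler`, and `epos` fields mirror the parsed-argument attributes that `mk_opt` reads.

    import torch
    from torch import nn, optim
    from torch.optim import lr_scheduler

    class Conf:
        """hypothetical stand-in for the parsed argparse namespace"""
        lr = 1e-3
        lr_scheduler = "CosineAnnealingLR"  # value passed via --lr_scheduler
        epos = 10                           # number of epochs, used as T_max

    def mk_opt(model, aconf):
        """mirror of the patched mk_opt: resolve the scheduler class by name"""
        optimizer = optim.Adam(model.parameters(), lr=aconf.lr)
        if aconf.lr_scheduler is not None:
            # getattr picks e.g. torch.optim.lr_scheduler.CosineAnnealingLR;
            # note that T_max is a CosineAnnealingLR argument, other schedulers
            # expect different constructor arguments.
            class_lr_scheduler = getattr(lr_scheduler, aconf.lr_scheduler)
            scheduler = class_lr_scheduler(optimizer, T_max=aconf.epos)
        else:
            scheduler = None
        return optimizer, scheduler

    model = nn.Linear(4, 2)
    optimizer, scheduler = mk_opt(model, Conf())
    for _ in range(Conf.epos):
        optimizer.zero_grad()
        loss = model(torch.randn(8, 4)).sum()
        loss.backward()
        optimizer.step()
        if scheduler:
            scheduler.step()  # stepped after each optimizer step, as in tr_batch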