From 78a9ad42516fae2680e9ef5bb15fefa5af8d7cd8 Mon Sep 17 00:00:00 2001
From: smilesun
Date: Tue, 3 Dec 2024 13:11:16 +0100
Subject: [PATCH 1/5] lr-scheduler in trainerBasic

---
 domainlab/algos/trainers/a_trainer.py   | 8 ++++++--
 domainlab/algos/trainers/train_basic.py | 2 ++
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py
index a51c34c14..b8d1de63b 100644
--- a/domainlab/algos/trainers/a_trainer.py
+++ b/domainlab/algos/trainers/a_trainer.py
@@ -5,6 +5,7 @@
 import torch
 from torch import optim
+from torch.optim.lr_scheduler import CosineAnnealingLR
 
 from domainlab.compos.pcr.p_chain_handler import AbstractChainNodeHandler
 
@@ -16,6 +17,7 @@ def mk_opt(model, aconf):
     if model._decoratee is None:
         class_opt = getattr(optim, aconf.opt)
         optimizer = class_opt(model.parameters(), lr=aconf.lr)
+        scheduler = CosineAnnealingLR(optimizer, T_max=aconf.epos)
     else:
         var1 = model.parameters()
         var2 = model._decoratee.parameters()
@@ -27,7 +29,7 @@ def mk_opt(model, aconf):
         #     {'params': model._decoratee.parameters()}
         # ], lr=aconf.lr)
         optimizer = optim.Adam(list_par, lr=aconf.lr)
-    return optimizer
+    return optimizer, scheduler
 
 
 class AbstractTrainer(AbstractChainNodeHandler, metaclass=abc.ABCMeta):
@@ -94,6 +96,8 @@ def __init__(self, successor_node=None, extend=None):
         self.list_reg_over_task_ratio = None
         # MIRO
         self.input_tensor_shape = None
+        # LR-scheduler
+        self.lr_scheduler = None
 
     @property
     def model(self):
@@ -168,7 +172,7 @@ def reset(self):
         """
         make a new optimizer to clear internal state
         """
-        self.optimizer = mk_opt(self.model, self.aconf)
+        self.optimizer, self.lr_scheduler = mk_opt(self.model, self.aconf)
 
     @abc.abstractmethod
     def tr_epoch(self, epoch):
diff --git a/domainlab/algos/trainers/train_basic.py b/domainlab/algos/trainers/train_basic.py
index 179848467..10ac3b06f 100644
--- a/domainlab/algos/trainers/train_basic.py
+++ b/domainlab/algos/trainers/train_basic.py
@@ -82,6 +82,8 @@ def tr_batch(self, tensor_x, tensor_y, tensor_d, others, ind_batch, epoch):
         loss = self.cal_loss(tensor_x, tensor_y, tensor_d, others)
         loss.backward()
         self.optimizer.step()
+        if self.lr_scheduler:
+            self.lr_scheduler.step()
         self.after_batch(epoch, ind_batch)
         self.counter_batch += 1

From 1d2dc78e4e4d646e74d34f0f287222fd26411ff3 Mon Sep 17 00:00:00 2001
From: smilesun
Date: Wed, 4 Dec 2024 17:46:34 +0100
Subject: [PATCH 2/5] scheduler always exists

---
 domainlab/algos/trainers/a_trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py
index b8d1de63b..2de178320 100644
--- a/domainlab/algos/trainers/a_trainer.py
+++ b/domainlab/algos/trainers/a_trainer.py
@@ -14,10 +14,10 @@ def mk_opt(model, aconf):
     """
     create optimizer
     """
+    scheduler = CosineAnnealingLR(optimizer, T_max=aconf.epos)
     if model._decoratee is None:
         class_opt = getattr(optim, aconf.opt)
         optimizer = class_opt(model.parameters(), lr=aconf.lr)
-        scheduler = CosineAnnealingLR(optimizer, T_max=aconf.epos)
     else:
         var1 = model.parameters()
         var2 = model._decoratee.parameters()

From 41bc0e70e9fdafc7133658359d689fe356e20a55 Mon Sep 17 00:00:00 2001
From: smilesun
Date: Wed, 4 Dec 2024 17:55:16 +0100
Subject: [PATCH 3/5] lr scheduler via cmd arguments

---
 domainlab/algos/trainers/a_trainer.py | 8 ++++++--
 domainlab/arg_parser.py               | 7 +++++++
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py
index 2de178320..051cc1e6f 100644
--- a/domainlab/algos/trainers/a_trainer.py
+++ b/domainlab/algos/trainers/a_trainer.py
@@ -5,7 +5,7 @@
 import torch
 from torch import optim
-from torch.optim.lr_scheduler import CosineAnnealingLR
+from torch.optim import lr_scheduler
 
 from domainlab.compos.pcr.p_chain_handler import AbstractChainNodeHandler
 
@@ -14,7 +14,6 @@ def mk_opt(model, aconf):
     """
     create optimizer
     """
-    scheduler = CosineAnnealingLR(optimizer, T_max=aconf.epos)
     if model._decoratee is None:
         class_opt = getattr(optim, aconf.opt)
         optimizer = class_opt(model.parameters(), lr=aconf.lr)
@@ -29,6 +28,11 @@ def mk_opt(model, aconf):
         #     {'params': model._decoratee.parameters()}
         # ], lr=aconf.lr)
         optimizer = optim.Adam(list_par, lr=aconf.lr)
+    if aconf.lr_scheduler is not None:
+        class_lr_scheduler = getattr(lr_scheduler, aconf.lr_scheduler)
+        scheduler = class_lr_scheduler(optimizer, T_max=aconf.epos)
+    else:
+        scheduler = None
     return optimizer, scheduler
 
 
diff --git a/domainlab/arg_parser.py b/domainlab/arg_parser.py
index 046810a66..47272d7fb 100644
--- a/domainlab/arg_parser.py
+++ b/domainlab/arg_parser.py
@@ -264,6 +264,13 @@ def mk_parser_main():
         help="name of pytorch optimizer",
     )
 
+    parser.add_argument(
+        "--lr_scheduler",
+        type=str,
+        default="CosineAnnealingLR",
+        help="name of pytorch learning rate scheduler",
+    )
+
     parser.add_argument(
         "--param_idx",
         type=bool,

From de004b5c1334b3060013c66991179e7c415a5711 Mon Sep 17 00:00:00 2001
From: smilesun
Date: Wed, 4 Dec 2024 19:35:01 +0100
Subject: [PATCH 4/5] unit test

---
 tests/test_lr_scheduler.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)
 create mode 100644 tests/test_lr_scheduler.py

diff --git a/tests/test_lr_scheduler.py b/tests/test_lr_scheduler.py
new file mode 100644
index 000000000..7bb7ab92f
--- /dev/null
+++ b/tests/test_lr_scheduler.py
@@ -0,0 +1,14 @@
+
+"""
+unit and end-to-end test for lr scheduler
+"""
+from tests.utils_test import utils_test_algo
+
+
+def test_lr_scheduler():
+    """
+    train
+    """
+    args = "--te_d=2 --tr_d 0 1 --task=mnistcolor10 --debug --bs=100 --model=erm \
+        --nname=conv_bn_pool_2 --no_dump --lr_scheduler=CosineAnnealingLR"
+    utils_test_algo(args)

From 06257726b48178d84188c690b99ab466321c2d4b Mon Sep 17 00:00:00 2001
From: Xudong Sun
Date: Wed, 4 Dec 2024 19:39:40 +0100
Subject: [PATCH 5/5] Update arg_parser.py

---
 domainlab/arg_parser.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/domainlab/arg_parser.py b/domainlab/arg_parser.py
index 47272d7fb..bb7bda2b4 100644
--- a/domainlab/arg_parser.py
+++ b/domainlab/arg_parser.py
@@ -267,7 +267,7 @@ def mk_parser_main():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="CosineAnnealingLR",
+        default=None,
         help="name of pytorch learning rate scheduler",
     )
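
Note (not part of the patches): a minimal, self-contained sketch of the pattern this series introduces, using only stock PyTorch. The optimizer and scheduler classes are looked up by name with getattr, the scheduler is optional, and it is stepped once per batch right after optimizer.step(), mirroring mk_opt and TrainerBasic.tr_batch above. The helper make_optimizer_and_scheduler and the toy nn.Linear model are illustrative only, not DomainLab code; also note that T_max is a CosineAnnealingLR argument, so scheduler names with different constructor signatures would need extra handling.

import torch
from torch import nn, optim
from torch.optim import lr_scheduler


def make_optimizer_and_scheduler(model, opt_name="Adam", lr=1e-3,
                                 scheduler_name=None, epos=10):
    """Return (optimizer, scheduler); scheduler is None when no name is given."""
    class_opt = getattr(optim, opt_name)              # e.g. optim.Adam, optim.SGD
    optimizer = class_opt(model.parameters(), lr=lr)
    if scheduler_name is not None:
        class_scheduler = getattr(lr_scheduler, scheduler_name)
        # T_max fits CosineAnnealingLR; other schedulers expect other kwargs
        scheduler = class_scheduler(optimizer, T_max=epos)
    else:
        scheduler = None
    return optimizer, scheduler


model = nn.Linear(4, 2)
optimizer, scheduler = make_optimizer_and_scheduler(
    model, scheduler_name="CosineAnnealingLR", epos=5)

for _ in range(5):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()
    optimizer.step()
    if scheduler:                 # same guard as in tr_batch
        scheduler.step()          # cosine-anneals the lr over T_max step() calls
    print(optimizer.param_groups[0]["lr"])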