Skip to content

Commit

Permalink
Merge pull request #897 from marrlab/lr_scheduler
Browse files Browse the repository at this point in the history
lr-scheduler in trainerBasic
  • Loading branch information
smilesun authored Dec 5, 2024
2 parents a715cfe + 0625772 commit 747476c
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 2 deletions.
12 changes: 10 additions & 2 deletions domainlab/algos/trainers/a_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import torch
from torch import optim
from torch.optim import lr_scheduler

from domainlab.compos.pcr.p_chain_handler import AbstractChainNodeHandler

Expand All @@ -27,7 +28,12 @@ def mk_opt(model, aconf):
# {'params': model._decoratee.parameters()}
# ], lr=aconf.lr)
optimizer = optim.Adam(list_par, lr=aconf.lr)
return optimizer
if aconf.lr_scheduler is not None:
class_lr_scheduler = getattr(lr_scheduler, aconf.lr_scheduler)
scheduler = class_lr_scheduler(optimizer, T_max=aconf.epos)
else:
scheduler = None
return optimizer, scheduler


class AbstractTrainer(AbstractChainNodeHandler, metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -94,6 +100,8 @@ def __init__(self, successor_node=None, extend=None):
self.list_reg_over_task_ratio = None
# MIRO
self.input_tensor_shape = None
# LR-scheduler
self.lr_scheduler = None

@property
def model(self):
Expand Down Expand Up @@ -168,7 +176,7 @@ def reset(self):
"""
make a new optimizer to clear internal state
"""
self.optimizer = mk_opt(self.model, self.aconf)
self.optimizer, self.lr_scheduler = mk_opt(self.model, self.aconf)

@abc.abstractmethod
def tr_epoch(self, epoch):
Expand Down
2 changes: 2 additions & 0 deletions domainlab/algos/trainers/train_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def tr_batch(self, tensor_x, tensor_y, tensor_d, others, ind_batch, epoch):
loss = self.cal_loss(tensor_x, tensor_y, tensor_d, others)
loss.backward()
self.optimizer.step()
if self.lr_scheduler:
self.lr_scheduler.step()
self.after_batch(epoch, ind_batch)
self.counter_batch += 1

Expand Down
7 changes: 7 additions & 0 deletions domainlab/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,13 @@ def mk_parser_main():
help="name of pytorch optimizer",
)

parser.add_argument(
"--lr_scheduler",
type=str,
default=None,
help="name of pytorch learning rate scheduler",
)

parser.add_argument(
"--param_idx",
type=bool,
Expand Down
14 changes: 14 additions & 0 deletions tests/test_lr_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@

"""
unit and end-end test for lr scheduler
"""
from tests.utils_test import utils_test_algo


def test_lr_scheduler():
"""
train
"""
args = "--te_d=2 --tr_d 0 1 --task=mnistcolor10 --debug --bs=100 --model=erm \
--nname=conv_bn_pool_2 --no_dump --lr_scheduler=CosineAnnealingLR"
utils_test_algo(args)

0 comments on commit 747476c

Please sign in to comment.