From 41bc0e70e9fdafc7133658359d689fe356e20a55 Mon Sep 17 00:00:00 2001
From: smilesun
Date: Wed, 4 Dec 2024 17:55:16 +0100
Subject: [PATCH] lr scheduler via cmd arguments

---
 domainlab/algos/trainers/a_trainer.py | 8 ++++++--
 domainlab/arg_parser.py               | 7 +++++++
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py
index 2de178320..051cc1e6f 100644
--- a/domainlab/algos/trainers/a_trainer.py
+++ b/domainlab/algos/trainers/a_trainer.py
@@ -5,7 +5,7 @@
 
 import torch
 from torch import optim
-from torch.optim.lr_scheduler import CosineAnnealingLR
+from torch.optim import lr_scheduler
 
 from domainlab.compos.pcr.p_chain_handler import AbstractChainNodeHandler
 
@@ -14,7 +14,6 @@ def mk_opt(model, aconf):
     """
     create optimizer
     """
-    scheduler = CosineAnnealingLR(optimizer, T_max=aconf.epos)
     if model._decoratee is None:
         class_opt = getattr(optim, aconf.opt)
         optimizer = class_opt(model.parameters(), lr=aconf.lr)
@@ -29,6 +28,11 @@ def mk_opt(model, aconf):
         #     {'params': model._decoratee.parameters()}
         # ], lr=aconf.lr)
         optimizer = optim.Adam(list_par, lr=aconf.lr)
+    if aconf.lr_scheduler is not None:
+        class_lr_scheduler = getattr(lr_scheduler, aconf.lr_scheduler)
+        scheduler = class_lr_scheduler(optimizer, T_max=aconf.epos)
+    else:
+        scheduler = None
     return optimizer, scheduler
 
 
diff --git a/domainlab/arg_parser.py b/domainlab/arg_parser.py
index 046810a66..47272d7fb 100644
--- a/domainlab/arg_parser.py
+++ b/domainlab/arg_parser.py
@@ -264,6 +264,13 @@ def mk_parser_main():
         help="name of pytorch optimizer",
     )
 
+    parser.add_argument(
+        "--lr_scheduler",
+        type=str,
+        default="CosineAnnealingLR",
+        help="name of pytorch learning rate scheduler",
+    )
+
     parser.add_argument(
         "--param_idx",
         type=bool,
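
For illustration, below is a minimal standalone sketch of the by-name scheduler lookup that the reworked mk_opt performs. The helper name build_optimizer_and_scheduler, the toy model, and the default values are assumptions made for this example and are not domainlab code; note also that T_max is a CosineAnnealingLR-specific keyword, so a different scheduler name passed via --lr_scheduler would need different constructor arguments.

# Hypothetical standalone sketch of the by-name scheduler lookup added in mk_opt;
# helper names and default values here are illustrative, not domainlab code.
import torch
from torch import optim
from torch.optim import lr_scheduler


def build_optimizer_and_scheduler(model, opt_name="Adam", lr=1e-3,
                                  scheduler_name="CosineAnnealingLR", epos=100):
    """Resolve optimizer and scheduler classes from their string names."""
    optimizer = getattr(optim, opt_name)(model.parameters(), lr=lr)
    if scheduler_name is not None:
        # T_max is a CosineAnnealingLR argument; other schedulers (e.g. StepLR)
        # expect different keyword arguments, mirroring the patch's assumption.
        scheduler = getattr(lr_scheduler, scheduler_name)(optimizer, T_max=epos)
    else:
        scheduler = None
    return optimizer, scheduler


if __name__ == "__main__":
    net = torch.nn.Linear(4, 2)            # toy model standing in for a domainlab model
    optimizer, scheduler = build_optimizer_and_scheduler(net, epos=3)
    for _ in range(3):                     # one scheduler step per epoch
        optimizer.step()
        scheduler.step()
    print(scheduler.get_last_lr())         # learning rate after cosine annealing

With the new flag, the scheduler class is chosen on the command line, e.g. passing --lr_scheduler CosineAnnealingLR (also the default); the entry-point script itself is not part of this patch.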