Skip to content

Commit

Permalink
lr scheduler via cmd arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
smilesun committed Dec 4, 2024
1 parent 1d2dc78 commit 41bc0e7
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
8 changes: 6 additions & 2 deletions domainlab/algos/trainers/a_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import torch
from torch import optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim import lr_scheduler

from domainlab.compos.pcr.p_chain_handler import AbstractChainNodeHandler

Expand All @@ -14,7 +14,6 @@ def mk_opt(model, aconf):
"""
create optimizer
"""
scheduler = CosineAnnealingLR(optimizer, T_max=aconf.epos)
if model._decoratee is None:
class_opt = getattr(optim, aconf.opt)
optimizer = class_opt(model.parameters(), lr=aconf.lr)
Expand All @@ -29,6 +28,11 @@ def mk_opt(model, aconf):
# {'params': model._decoratee.parameters()}
# ], lr=aconf.lr)
optimizer = optim.Adam(list_par, lr=aconf.lr)
if aconf.lr_scheduler is not None:
class_lr_scheduler = getattr(lr_scheduler, aconf.lr_scheduler)
scheduler = class_lr_scheduler(optimizer, T_max=aconf.epos)
else:
scheduler = None
return optimizer, scheduler


Expand Down
7 changes: 7 additions & 0 deletions domainlab/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,13 @@ def mk_parser_main():
help="name of pytorch optimizer",
)

parser.add_argument(
"--lr_scheduler",
type=str,
default="CosineAnnealingLR",
help="name of pytorch learning rate scheduler",
)

parser.add_argument(
"--param_idx",
type=bool,
Expand Down

0 comments on commit 41bc0e7

Please sign in to comment.