Skip to content

Commit

Permalink
Merge pull request #896 from marrlab/opt_str_cmd
Browse files Browse the repository at this point in the history
optimizer name string as command line input
  • Loading branch information
smilesun authored Nov 29, 2024
2 parents 6da7252 + 495c314 commit 69e8b8c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
3 changes: 2 additions & 1 deletion domainlab/algos/trainers/a_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ def mk_opt(model, aconf):
create optimizer
"""
if model._decoratee is None:
optimizer = optim.Adam(model.parameters(), lr=aconf.lr)
class_opt = getattr(optim, aconf.opt)
optimizer = class_opt(model.parameters(), lr=aconf.lr)
else:
var1 = model.parameters()
var2 = model._decoratee.parameters()
Expand Down
7 changes: 7 additions & 0 deletions domainlab/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,13 @@ def mk_parser_main():
"Default is zoutput/benchmarks/shell_benchmark",
)

parser.add_argument(
"--opt",
type=str,
default="Adam",
help="name of pytorch optimizer",
)

parser.add_argument(
"--param_idx",
type=bool,
Expand Down

0 comments on commit 69e8b8c

Please sign in to comment.