diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py index a484812d4..a51c34c14 100644 --- a/domainlab/algos/trainers/a_trainer.py +++ b/domainlab/algos/trainers/a_trainer.py @@ -14,7 +14,8 @@ def mk_opt(model, aconf): create optimizer """ if model._decoratee is None: - optimizer = optim.Adam(model.parameters(), lr=aconf.lr) + class_opt = getattr(optim, aconf.opt) + optimizer = class_opt(model.parameters(), lr=aconf.lr) else: var1 = model.parameters() var2 = model._decoratee.parameters() diff --git a/domainlab/arg_parser.py b/domainlab/arg_parser.py index d456e3045..046810a66 100644 --- a/domainlab/arg_parser.py +++ b/domainlab/arg_parser.py @@ -257,6 +257,13 @@ def mk_parser_main(): "Default is zoutput/benchmarks/shell_benchmark", ) + parser.add_argument( + "--opt", + type=str, + default="Adam", + help="name of pytorch optimizer", + ) + parser.add_argument( "--param_idx", type=bool,