diff --git a/ocpmodels/models/equiformer_v2/trainers/forces_trainer.py b/ocpmodels/models/equiformer_v2/trainers/forces_trainer.py
index b8bbc3f66..6284a3dac 100755
--- a/ocpmodels/models/equiformer_v2/trainers/forces_trainer.py
+++ b/ocpmodels/models/equiformer_v2/trainers/forces_trainer.py
@@ -146,10 +146,17 @@ def multiply(obj, num):
             scheduler_params = self.config["optim"]["scheduler_params"]
             for k in scheduler_params.keys():
                 if "epochs" in k:
-                    if isinstance(scheduler_params[k], (int, float, list)):
+                    if isinstance(scheduler_params[k], (int, float)):
                         scheduler_params[k] = int(
                             multiply(scheduler_params[k], n_iter_per_epoch)
                         )
+                    elif isinstance(scheduler_params[k], list):
+                        scheduler_params[k] = [
+                            int(x)
+                            for x in multiply(
+                                scheduler_params[k], n_iter_per_epoch
+                            )
+                        ]
 
         self.scheduler = LRScheduler(self.optimizer, self.config["optim"])
         self.clip_grad_norm = self.config["optim"].get("clip_grad_norm")
diff --git a/ocpmodels/models/equiformer_v2/trainers/lr_scheduler.py b/ocpmodels/models/equiformer_v2/trainers/lr_scheduler.py
index 3bf202681..f3926212d 100755
--- a/ocpmodels/models/equiformer_v2/trainers/lr_scheduler.py
+++ b/ocpmodels/models/equiformer_v2/trainers/lr_scheduler.py
@@ -75,7 +75,7 @@ class MultistepLRLambda:
     def __init__(self, scheduler_params) -> None:
         self.warmup_epochs = aii(scheduler_params["warmup_epochs"], int)
         self.lr_warmup_factor = aii(scheduler_params["warmup_factor"], float)
-        self.lr_decay_epochs = aii(scheduler_params["decay_epochs"], int)
+        self.lr_decay_epochs = aii(scheduler_params["decay_epochs"], list)
         self.lr_gamma = aii(scheduler_params["decay_rate"], float)
 
     def __call__(self, current_step: int) -> float:
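
For context, here is a minimal standalone sketch of the patched epoch-to-iteration conversion. The body of `multiply` and the value of `n_iter_per_epoch` do not appear in this diff, so both are assumptions for illustration (in the trainer, `n_iter_per_epoch` would presumably come from the training dataloader length):

    def multiply(obj, num):
        # Assumed helper body: the diff only shows the signature in the
        # hunk header. Scales scalars directly and lists element-wise.
        if isinstance(obj, list):
            return [x * num for x in obj]
        return obj * num

    n_iter_per_epoch = 100  # hypothetical iterations per epoch

    # Example scheduler_params, e.g. for a multistep schedule.
    scheduler_params = {"warmup_epochs": 0.1, "decay_epochs": [24, 34, 42]}
    for k in scheduler_params.keys():
        if "epochs" in k:
            if isinstance(scheduler_params[k], (int, float)):
                # Scalar epoch counts become scalar iteration counts.
                scheduler_params[k] = int(
                    multiply(scheduler_params[k], n_iter_per_epoch)
                )
            elif isinstance(scheduler_params[k], list):
                # Lists of epoch milestones are converted element-wise.
                scheduler_params[k] = [
                    int(x)
                    for x in multiply(scheduler_params[k], n_iter_per_epoch)
                ]

    print(scheduler_params)
    # {'warmup_epochs': 10, 'decay_epochs': [2400, 3400, 4200]}

Before this patch, a list such as `decay_epochs: [24, 34, 42]` fell into the scalar branch, where calling `int()` on the list returned by `multiply` raises a `TypeError`. The second hunk makes the matching change downstream: `aii` appears to be an assert-isinstance helper, and since `MultistepLRLambda` consumes `decay_epochs` as a list of milestones, the asserted type is corrected from `int` to `list`.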