Skip to content

Commit

Permalink
Type fixes for using the MultistepLR scheduler (#578)
Browse files Browse the repository at this point in the history
  • Loading branch information
abhshkdz authored Sep 14, 2023
1 parent 07b69c2 commit 0505017
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
9 changes: 8 additions & 1 deletion ocpmodels/models/equiformer_v2/trainers/forces_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,17 @@ def multiply(obj, num):
scheduler_params = self.config["optim"]["scheduler_params"]
for k in scheduler_params.keys():
if "epochs" in k:
if isinstance(scheduler_params[k], (int, float, list)):
if isinstance(scheduler_params[k], (int, float)):
scheduler_params[k] = int(
multiply(scheduler_params[k], n_iter_per_epoch)
)
elif isinstance(scheduler_params[k], list):
scheduler_params[k] = [
int(x)
for x in multiply(
scheduler_params[k], n_iter_per_epoch
)
]
self.scheduler = LRScheduler(self.optimizer, self.config["optim"])

self.clip_grad_norm = self.config["optim"].get("clip_grad_norm")
Expand Down
2 changes: 1 addition & 1 deletion ocpmodels/models/equiformer_v2/trainers/lr_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class MultistepLRLambda:
def __init__(self, scheduler_params) -> None:
self.warmup_epochs = aii(scheduler_params["warmup_epochs"], int)
self.lr_warmup_factor = aii(scheduler_params["warmup_factor"], float)
self.lr_decay_epochs = aii(scheduler_params["decay_epochs"], int)
self.lr_decay_epochs = aii(scheduler_params["decay_epochs"], list)
self.lr_gamma = aii(scheduler_params["decay_rate"], float)

def __call__(self, current_step: int) -> float:
Expand Down

0 comments on commit 0505017

Please sign in to comment.