diff --git a/src/fairchem/core/models/equiformer_v2/trainers/forces_trainer.py b/src/fairchem/core/models/equiformer_v2/trainers/forces_trainer.py
index 96364ae901..49cb6b4978 100755
--- a/src/fairchem/core/models/equiformer_v2/trainers/forces_trainer.py
+++ b/src/fairchem/core/models/equiformer_v2/trainers/forces_trainer.py
@@ -59,7 +59,8 @@ def multiply(obj, num):
         ]

         self.scheduler = LRScheduler(self.optimizer, self.config["optim"])
-        self.clip_grad_norm = self.config["optim"].get("clip_grad_norm")
+        self.clip_grad_norm = self.config["optim"].get("clip_grad_norm", None)
+        self.clip_grad_value = self.config["optim"].get("clip_grad_value", None)
         self.ema_decay = self.config["optim"].get("ema_decay")
         if self.ema_decay:
             self.ema = ExponentialMovingAverage(
diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py
index 94becb924c..16e60519eb 100644
--- a/src/fairchem/core/trainers/base_trainer.py
+++ b/src/fairchem/core/trainers/base_trainer.py
@@ -702,6 +702,9 @@ def load_extras(self) -> None:
         self.clip_grad_norm = aii(
             self.config["optim"].get("clip_grad_norm", None), (int, float)
         )
+        self.clip_grad_value = aii(
+            self.config["optim"].get("clip_grad_value", None), (int, float)
+        )
         self.ema_decay = aii(self.config["optim"].get("ema_decay"), float)
         if self.ema_decay:
             self.ema = ExponentialMovingAverage(
@@ -889,15 +892,18 @@ def _backward(self, loss) -> None:
                     "Please check if all shared parameters are used "
                     "and point to PyTorch parameters."
                 )
-        if self.clip_grad_norm:
+        if self.clip_grad_norm or self.clip_grad_value:
             if self.scaler:
                 self.scaler.unscale_(self.optimizer)
-            grad_norm = torch.nn.utils.clip_grad_norm_(
-                self.model.parameters(),
-                max_norm=self.clip_grad_norm,
-            )
-            if self.logger is not None:
-                self.logger.log({"grad_norm": grad_norm}, step=self.step, split="train")
+            if self.clip_grad_norm:
+                grad_norm = torch.nn.utils.clip_grad_norm_(
+                    self.model.parameters(),
+                    max_norm=self.clip_grad_norm,
+                )
+                if self.logger is not None:
+                    self.logger.log({"grad_norm": grad_norm}, step=self.step, split="train")
+            if self.clip_grad_value:
+                torch.nn.utils.clip_grad_value_(self.model.parameters(), clip_value=self.clip_grad_value)
         if self.scaler:
             self.scaler.step(self.optimizer)
             self.scaler.update()
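
A minimal, self-contained sketch (not the fairchem trainer itself) of how the two clipping options introduced by this diff interact in one backward/optimizer step; the model, optimizer, and the two config values below are hypothetical stand-ins for config["optim"]["clip_grad_norm"] and config["optim"]["clip_grad_value"].

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

# Hypothetical config values; either may be None to disable that form of clipping.
clip_grad_norm = 10.0   # cap on the total gradient norm
clip_grad_value = 0.5   # element-wise cap on gradient magnitudes

loss = model(torch.randn(8, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()

if clip_grad_norm or clip_grad_value:
    if clip_grad_norm:
        # Returns the pre-clipping norm, which the trainer logs as "grad_norm".
        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=clip_grad_norm
        )
    if clip_grad_value:
        # Applied independently of (and after) norm clipping, as in the diff.
        torch.nn.utils.clip_grad_value_(
            model.parameters(), clip_value=clip_grad_value
        )

optimizer.step()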