From f9bf3f2cbeb9af195bf144861e0b423a03cd5d2c Mon Sep 17 00:00:00 2001 From: iliaschair Date: Sat, 30 Nov 2024 14:51:11 +0100 Subject: [PATCH] - replace all occurrences of `torch.cuda.amp.autocast(args...)` with `torch.autocast("cuda", args...)` - replace all occurrences of `torch.cuda.amp.GradScaler(args...)` with `torch.GradScaler("cuda", args...)` --- docs/tutorials/advanced/embedding_monkeypatch.py | 4 ++-- .../core/models/equiformer_v2/equiformer_v2_deprecated.py | 2 +- src/fairchem/core/models/equiformer_v2/layer_norm.py | 8 ++++---- .../core/models/equiformer_v2/trainers/dens_trainer.py | 4 ++-- src/fairchem/core/models/gemnet_oc/gemnet_oc.py | 6 +++--- src/fairchem/core/models/gemnet_oc/layers/force_scaler.py | 2 +- src/fairchem/core/modules/scaling/fit.py | 2 +- src/fairchem/core/trainers/base_trainer.py | 4 ++-- src/fairchem/core/trainers/ocp_trainer.py | 4 ++-- 9 files changed, 18 insertions(+), 18 deletions(-) diff --git a/docs/tutorials/advanced/embedding_monkeypatch.py b/docs/tutorials/advanced/embedding_monkeypatch.py index b2983c29ac..e16028abd0 100644 --- a/docs/tutorials/advanced/embedding_monkeypatch.py +++ b/docs/tutorials/advanced/embedding_monkeypatch.py @@ -93,7 +93,7 @@ def newforward(self, data): x_E = self.out_mlp_E(torch.cat(xs_E, dim=-1)) if self.direct_forces: x_F = self.out_mlp_F(torch.cat(xs_F, dim=-1)) - with torch.cuda.amp.autocast(False): + with torch.autocast("cuda", enabled=False): E_t = self.out_energy(x_E.float()) if self.direct_forces: F_st = self.out_forces(x_F.float()) @@ -185,7 +185,7 @@ def embed(self, atoms): self.trainer.ema.copy_to() with ( - torch.cuda.amp.autocast(enabled=self.trainer.scaler is not None), + torch.autocast("cuda", enabled=self.trainer.scaler is not None), torch.no_grad(), ): out = self.trainer.model(batch_list) diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2_deprecated.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2_deprecated.py index 1da2ed3adb..5af270045e 100644 --- a/src/fairchem/core/models/equiformer_v2/equiformer_v2_deprecated.py +++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2_deprecated.py @@ -591,7 +591,7 @@ def forward(self, data): # We can also write this as # \hat{E_DFT} = E_std * (\hat{E} + E_ref / E_std) + E_mean, # which is why we save E_ref / E_std as the linear reference. - with torch.cuda.amp.autocast(False): + with torch.autocast("cuda", enabled=False): energy = energy.to(self.energy_lin_ref.dtype).index_add( 0, graph.batch_full, diff --git a/src/fairchem/core/models/equiformer_v2/layer_norm.py b/src/fairchem/core/models/equiformer_v2/layer_norm.py index 8edcfd62fa..e23f573c37 100755 --- a/src/fairchem/core/models/equiformer_v2/layer_norm.py +++ b/src/fairchem/core/models/equiformer_v2/layer_norm.py @@ -72,7 +72,7 @@ def __init__( def __repr__(self) -> str: return f"{self.__class__.__name__}(lmax={self.lmax}, num_channels={self.num_channels}, eps={self.eps})" - @torch.cuda.amp.autocast(enabled=False) + @torch.autocast("cuda", enabled=False) def forward(self, node_input): """ Assume input is of shape [N, sphere_basis, C] @@ -172,7 +172,7 @@ def __init__( def __repr__(self) -> str: return f"{self.__class__.__name__}(lmax={self.lmax}, num_channels={self.num_channels}, eps={self.eps}, std_balance_degrees={self.std_balance_degrees})" - @torch.cuda.amp.autocast(enabled=False) + @torch.autocast("cuda", enabled=False) def forward(self, node_input): """ Assume input is of shape [N, sphere_basis, C] @@ -260,7 +260,7 @@ def __init__( def __repr__(self) -> str: return f"{self.__class__.__name__}(lmax={self.lmax}, num_channels={self.num_channels}, eps={self.eps})" - @torch.cuda.amp.autocast(enabled=False) + @torch.autocast("cuda", enabled=False) def forward(self, node_input): """ Assume input is of shape [N, sphere_basis, C] @@ -354,7 +354,7 @@ def __init__( def __repr__(self) -> str: return f"{self.__class__.__name__}(lmax={self.lmax}, num_channels={self.num_channels}, eps={self.eps}, centering={self.centering}, std_balance_degrees={self.std_balance_degrees})" - @torch.cuda.amp.autocast(enabled=False) + @torch.autocast("cuda", enabled=False) def forward(self, node_input): """ Assume input is of shape [N, sphere_basis, C] diff --git a/src/fairchem/core/models/equiformer_v2/trainers/dens_trainer.py b/src/fairchem/core/models/equiformer_v2/trainers/dens_trainer.py index 11735d7bb9..8bacc49133 100644 --- a/src/fairchem/core/models/equiformer_v2/trainers/dens_trainer.py +++ b/src/fairchem/core/models/equiformer_v2/trainers/dens_trainer.py @@ -392,7 +392,7 @@ def train(self, disable_eval_tqdm=False): ) # Forward, loss, backward. #TODO update this with new signatures - with torch.cuda.amp.autocast(enabled=self.scaler is not None): + with torch.autocast("cuda", enabled=self.scaler is not None): out = self._forward(batch) loss = self._compute_loss(out, batch) @@ -767,7 +767,7 @@ def predict( desc=f"device {rank}", disable=disable_tqdm, ): - with torch.cuda.amp.autocast(enabled=self.scaler is not None): + with torch.autocast("cuda", enabled=self.scaler is not None): out = self._forward(batch) for key in out: diff --git a/src/fairchem/core/models/gemnet_oc/gemnet_oc.py b/src/fairchem/core/models/gemnet_oc/gemnet_oc.py index c982b7d43a..d29a7ed614 100644 --- a/src/fairchem/core/models/gemnet_oc/gemnet_oc.py +++ b/src/fairchem/core/models/gemnet_oc/gemnet_oc.py @@ -1267,7 +1267,7 @@ def forward(self, data): x_E = self.out_mlp_E(torch.cat(xs_E, dim=-1)) if self.direct_forces: x_F = self.out_mlp_F(torch.cat(xs_F, dim=-1)) - with torch.cuda.amp.autocast(False): + with torch.autocast("cuda", enabled=False): E_t = self.out_energy(x_E.float()) if self.direct_forces: F_st = self.out_forces(x_F.float()) @@ -1465,7 +1465,7 @@ def forward( ) -> dict[str, torch.Tensor]: # Global output block for final predictions x_E = self.out_mlp_E(torch.cat(emb["xs_E"], dim=-1)) - with torch.cuda.amp.autocast(False): + with torch.autocast("cuda", enabled=False): E_t = self.out_energy(x_E.float()) nMolecules = torch.max(data.batch) + 1 @@ -1530,7 +1530,7 @@ def forward( self, data: Batch, emb: dict[str, torch.Tensor] ) -> dict[str, torch.Tensor]: if self.direct_forces: - with torch.cuda.amp.autocast(False): + with torch.autocast("cuda", enabled=False): x_F = self.out_mlp_F(torch.cat(emb["xs_F"], dim=-1).float()) F_st = self.out_forces(x_F) diff --git a/src/fairchem/core/models/gemnet_oc/layers/force_scaler.py b/src/fairchem/core/models/gemnet_oc/layers/force_scaler.py index fe5ae1810a..d41e144657 100644 --- a/src/fairchem/core/models/gemnet_oc/layers/force_scaler.py +++ b/src/fairchem/core/models/gemnet_oc/layers/force_scaler.py @@ -15,7 +15,7 @@ class ForceScaler: """ Scales up the energy and then scales down the forces to prevent NaNs and infs in calculations using AMP. - Inspired by torch.cuda.amp.GradScaler. + Inspired by torch.GradScaler("cuda", args...). """ def __init__( diff --git a/src/fairchem/core/modules/scaling/fit.py b/src/fairchem/core/modules/scaling/fit.py index 462088318e..e3eea3d018 100644 --- a/src/fairchem/core/modules/scaling/fit.py +++ b/src/fairchem/core/modules/scaling/fit.py @@ -32,7 +32,7 @@ def _prefilled_input(prompt: str, prefill: str = "") -> str: def _train_batch(trainer: BaseTrainer, batch) -> None: with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=trainer.scaler is not None): + with torch.autocast("cuda", enabled=trainer.scaler is not None): out = trainer._forward(batch) loss = trainer._compute_loss(out, batch) del out, loss diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index 90cdce0e58..b51e9507ab 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -153,7 +153,7 @@ def __init__( "gp_gpus": gp_gpus, } # AMP Scaler - self.scaler = torch.cuda.amp.GradScaler() if amp and not self.cpu else None + self.scaler = torch.GradScaler("cuda") if amp and not self.cpu else None # Fill in SLURM information in config, if applicable if "SLURM_JOB_ID" in os.environ and "folder" in self.config["slurm"]: @@ -883,7 +883,7 @@ def validate(self, split: str = "val", disable_tqdm: bool = False): disable=disable_tqdm, ): # Forward. - with torch.cuda.amp.autocast(enabled=self.scaler is not None): + with torch.autocast("cuda", enabled=self.scaler is not None): batch.to(self.device) out = self._forward(batch) loss = self._compute_loss(out, batch) diff --git a/src/fairchem/core/trainers/ocp_trainer.py b/src/fairchem/core/trainers/ocp_trainer.py index a8976773c6..e9f2b01d50 100644 --- a/src/fairchem/core/trainers/ocp_trainer.py +++ b/src/fairchem/core/trainers/ocp_trainer.py @@ -161,7 +161,7 @@ def train(self, disable_eval_tqdm: bool = False) -> None: # Get a batch. batch = next(train_loader_iter) # Forward, loss, backward. - with torch.cuda.amp.autocast(enabled=self.scaler is not None): + with torch.autocast("cuda", enabled=self.scaler is not None): out = self._forward(batch) loss = self._compute_loss(out, batch) @@ -468,7 +468,7 @@ def predict( desc=f"device {rank}", disable=disable_tqdm, ): - with torch.cuda.amp.autocast(enabled=self.scaler is not None): + with torch.autocast("cuda", enabled=self.scaler is not None): out = self._forward(batch) for target_key in self.config["outputs"]: