- replace all occurrences of `torch.cuda.amp.autocast(args...)` with `torch.autocast("cuda", args...)`

- replace all occurrences of `torch.cuda.amp.GradScaler(args...)` with `torch.GradScaler("cuda", args...)`
IliasChair14 committed Nov 30, 2024
1 parent e11e78e commit f9bf3f2
Showing 9 changed files with 18 additions and 18 deletions.
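In short, the commit moves the codebase from the CUDA-specific AMP entry points to the device-generic ones that recent PyTorch releases recommend (the old spellings emit deprecation warnings). A minimal sketch of the pattern on a throwaway tensor, not on any fairchem model:

```python
import torch

# Deprecated spellings replaced by this commit:
#   torch.cuda.amp.autocast(enabled=...)   and   torch.cuda.amp.GradScaler(...)
# Device-generic replacements used below:
#   torch.autocast("cuda", enabled=...)    and   torch.GradScaler("cuda", ...)
# (torch.amp.GradScaler("cuda", ...) is the fully qualified equivalent.)

use_cuda = torch.cuda.is_available()
scaler = torch.GradScaler("cuda") if use_cuda else None

x = torch.randn(4, 4, device="cuda" if use_cuda else "cpu")
with torch.autocast("cuda", enabled=scaler is not None):
    y = x @ x  # matmuls run in reduced precision while autocast is enabled
```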
4 changes: 2 additions & 2 deletions docs/tutorials/advanced/embedding_monkeypatch.py
@@ -93,7 +93,7 @@ def newforward(self, data):
x_E = self.out_mlp_E(torch.cat(xs_E, dim=-1))
if self.direct_forces:
x_F = self.out_mlp_F(torch.cat(xs_F, dim=-1))
- with torch.cuda.amp.autocast(False):
+ with torch.autocast("cuda", enabled=False):
E_t = self.out_energy(x_E.float())
if self.direct_forces:
F_st = self.out_forces(x_F.float())
@@ -185,7 +185,7 @@ def embed(self, atoms):
self.trainer.ema.copy_to()

with (
- torch.cuda.amp.autocast(enabled=self.trainer.scaler is not None),
+ torch.autocast("cuda", enabled=self.trainer.scaler is not None),
torch.no_grad(),
):
out = self.trainer.model(batch_list)

@@ -591,7 +591,7 @@ def forward(self, data):
# We can also write this as
# \hat{E_DFT} = E_std * (\hat{E} + E_ref / E_std) + E_mean,
# which is why we save E_ref / E_std as the linear reference.
- with torch.cuda.amp.autocast(False):
+ with torch.autocast("cuda", enabled=False):
energy = energy.to(self.energy_lin_ref.dtype).index_add(
0,
graph.batch_full,
8 changes: 4 additions & 4 deletions src/fairchem/core/models/equiformer_v2/layer_norm.py
@@ -72,7 +72,7 @@ def __init__(
def __repr__(self) -> str:
return f"{self.__class__.__name__}(lmax={self.lmax}, num_channels={self.num_channels}, eps={self.eps})"

- @torch.cuda.amp.autocast(enabled=False)
+ @torch.autocast("cuda", enabled=False)
def forward(self, node_input):
"""
Assume input is of shape [N, sphere_basis, C]
@@ -172,7 +172,7 @@ def __init__(
def __repr__(self) -> str:
return f"{self.__class__.__name__}(lmax={self.lmax}, num_channels={self.num_channels}, eps={self.eps}, std_balance_degrees={self.std_balance_degrees})"

- @torch.cuda.amp.autocast(enabled=False)
+ @torch.autocast("cuda", enabled=False)
def forward(self, node_input):
"""
Assume input is of shape [N, sphere_basis, C]
@@ -260,7 +260,7 @@ def __init__(
def __repr__(self) -> str:
return f"{self.__class__.__name__}(lmax={self.lmax}, num_channels={self.num_channels}, eps={self.eps})"

- @torch.cuda.amp.autocast(enabled=False)
+ @torch.autocast("cuda", enabled=False)
def forward(self, node_input):
"""
Assume input is of shape [N, sphere_basis, C]
@@ -354,7 +354,7 @@ def __init__(
def __repr__(self) -> str:
return f"{self.__class__.__name__}(lmax={self.lmax}, num_channels={self.num_channels}, eps={self.eps}, centering={self.centering}, std_balance_degrees={self.std_balance_degrees})"

- @torch.cuda.amp.autocast(enabled=False)
+ @torch.autocast("cuda", enabled=False)
def forward(self, node_input):
"""
Assume input is of shape [N, sphere_basis, C]

@@ -392,7 +392,7 @@ def train(self, disable_eval_tqdm=False):
)

# Forward, loss, backward. #TODO update this with new signatures
- with torch.cuda.amp.autocast(enabled=self.scaler is not None):
+ with torch.autocast("cuda", enabled=self.scaler is not None):
out = self._forward(batch)
loss = self._compute_loss(out, batch)

@@ -767,7 +767,7 @@ def predict(
desc=f"device {rank}",
disable=disable_tqdm,
):
- with torch.cuda.amp.autocast(enabled=self.scaler is not None):
+ with torch.autocast("cuda", enabled=self.scaler is not None):
out = self._forward(batch)

for key in out:
6 changes: 3 additions & 3 deletions src/fairchem/core/models/gemnet_oc/gemnet_oc.py
@@ -1267,7 +1267,7 @@ def forward(self, data):
x_E = self.out_mlp_E(torch.cat(xs_E, dim=-1))
if self.direct_forces:
x_F = self.out_mlp_F(torch.cat(xs_F, dim=-1))
- with torch.cuda.amp.autocast(False):
+ with torch.autocast("cuda", enabled=False):
E_t = self.out_energy(x_E.float())
if self.direct_forces:
F_st = self.out_forces(x_F.float())
@@ -1465,7 +1465,7 @@ def forward(
) -> dict[str, torch.Tensor]:
# Global output block for final predictions
x_E = self.out_mlp_E(torch.cat(emb["xs_E"], dim=-1))
- with torch.cuda.amp.autocast(False):
+ with torch.autocast("cuda", enabled=False):
E_t = self.out_energy(x_E.float())

nMolecules = torch.max(data.batch) + 1
@@ -1530,7 +1530,7 @@ def forward(
self, data: Batch, emb: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
if self.direct_forces:
- with torch.cuda.amp.autocast(False):
+ with torch.autocast("cuda", enabled=False):
x_F = self.out_mlp_F(torch.cat(emb["xs_F"], dim=-1).float())
F_st = self.out_forces(x_F)

2 changes: 1 addition & 1 deletion src/fairchem/core/models/gemnet_oc/layers/force_scaler.py
@@ -15,7 +15,7 @@ class ForceScaler:
"""
Scales up the energy and then scales down the forces
to prevent NaNs and infs in calculations using AMP.
- Inspired by torch.cuda.amp.GradScaler.
+ Inspired by torch.GradScaler("cuda", args...).
"""

def __init__(
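The docstring above describes the idea behind ForceScaler: scale the energy up before differentiating so fp16 gradients don't underflow, then scale the resulting forces back down. A rough, self-contained sketch of that idea (fixed scale factor and toy energy, not the actual fairchem class, which is described as GradScaler-inspired and may adapt its factor):

```python
import torch

scale_factor = 2.0 ** 8  # assumed fixed here; a GradScaler-style class would adjust it

def scaled_forces(energy: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    # Scale the energy before the backward pass, then undo the scaling on the forces.
    scaled_energy = energy * scale_factor
    (grad,) = torch.autograd.grad(scaled_energy.sum(), positions, create_graph=True)
    return -grad / scale_factor

positions = torch.randn(5, 3, requires_grad=True)
energy = (positions ** 2).sum(dim=1)  # toy quadratic per-atom energy
forces = scaled_forces(energy, positions)
```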
2 changes: 1 addition & 1 deletion src/fairchem/core/modules/scaling/fit.py
@@ -32,7 +32,7 @@ def _prefilled_input(prompt: str, prefill: str = "") -> str:

def _train_batch(trainer: BaseTrainer, batch) -> None:
with torch.no_grad():
- with torch.cuda.amp.autocast(enabled=trainer.scaler is not None):
+ with torch.autocast("cuda", enabled=trainer.scaler is not None):
out = trainer._forward(batch)
loss = trainer._compute_loss(out, batch)
del out, loss
4 changes: 2 additions & 2 deletions src/fairchem/core/trainers/base_trainer.py
@@ -153,7 +153,7 @@ def __init__(
"gp_gpus": gp_gpus,
}
# AMP Scaler
- self.scaler = torch.cuda.amp.GradScaler() if amp and not self.cpu else None
+ self.scaler = torch.GradScaler("cuda") if amp and not self.cpu else None

# Fill in SLURM information in config, if applicable
if "SLURM_JOB_ID" in os.environ and "folder" in self.config["slurm"]:
@@ -883,7 +883,7 @@ def validate(self, split: str = "val", disable_tqdm: bool = False):
disable=disable_tqdm,
):
# Forward.
- with torch.cuda.amp.autocast(enabled=self.scaler is not None):
+ with torch.autocast("cuda", enabled=self.scaler is not None):
batch.to(self.device)
out = self._forward(batch)
loss = self._compute_loss(out, batch)
4 changes: 2 additions & 2 deletions src/fairchem/core/trainers/ocp_trainer.py
@@ -161,7 +161,7 @@ def train(self, disable_eval_tqdm: bool = False) -> None:
# Get a batch.
batch = next(train_loader_iter)
# Forward, loss, backward.
- with torch.cuda.amp.autocast(enabled=self.scaler is not None):
+ with torch.autocast("cuda", enabled=self.scaler is not None):
out = self._forward(batch)
loss = self._compute_loss(out, batch)

@@ -468,7 +468,7 @@ def predict(
desc=f"device {rank}",
disable=disable_tqdm,
):
- with torch.cuda.amp.autocast(enabled=self.scaler is not None):
+ with torch.autocast("cuda", enabled=self.scaler is not None):
out = self._forward(batch)

for target_key in self.config["outputs"]:
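For context on how the scaler created in base_trainer.py interacts with these autocast blocks, here is a generic AMP training step; the model, loss, and optimizer below are placeholders, and only the scaler/autocast calls mirror the code in this commit:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(16, 1).to(device)          # placeholder model
optimizer = torch.optim.AdamW(model.parameters())  # placeholder optimizer
scaler = torch.GradScaler("cuda") if device == "cuda" else None

x = torch.randn(32, 16, device=device)
target = torch.randn(32, 1, device=device)

# Forward under autocast, mirroring the `enabled=self.scaler is not None` checks above.
with torch.autocast("cuda", enabled=scaler is not None):
    loss = torch.nn.functional.mse_loss(model(x), target)

if scaler is not None:
    scaler.scale(loss).backward()  # scale the loss so fp16 gradients don't underflow
    scaler.step(optimizer)         # unscales gradients, then calls optimizer.step()
    scaler.update()                # adjust the scale factor for the next iteration
else:
    loss.backward()
    optimizer.step()
optimizer.zero_grad()
```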
