
Commit 21cab59
Use the distributed_type directly
muellerzr committed Oct 8, 2024
1 parent c8c5f87 commit 21cab59
Showing 4 changed files with 9 additions and 17 deletions.
8 changes: 2 additions & 6 deletions benchmarks/fp8/ms_amp/ddp.py
@@ -22,12 +22,11 @@
 import msamp
 import torch
 from fp8_utils import evaluate_model, get_training_utilities
-from packaging import version
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 from accelerate import Accelerator
 from accelerate.state import AcceleratorState
-from accelerate.utils import FP8RecipeKwargs, set_seed
+from accelerate.utils import FP8RecipeKwargs, get_grad_scaler, set_seed
 
 
 MODEL_NAME = "bert-base-cased"
@@ -36,10 +35,7 @@
 
 def train_baseline(opt_level="O2"):
     set_seed(42)
-    if version.parse(torch.__version__) > version.parse("2.3"):
-        scaler = torch.amp.GradScaler("cuda")
-    else:
-        scaler = torch.cuda.amp.GradScaler()
+    scaler = get_grad_scaler()
     model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(MODEL_NAME)
     accelerator = Accelerator()
     device = accelerator.device
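For context, the version check removed above is the logic the `get_grad_scaler` helper is assumed to encapsulate for the plain (non-FSDP) case. A minimal sketch of that selection, reconstructed only from the deleted benchmark lines (the helper's actual implementation in `accelerate.utils` may differ; `pick_grad_scaler` is a hypothetical name):

```python
import torch
from packaging import version


def pick_grad_scaler(**kwargs):
    # PyTorch newer than 2.3 exposes the device-agnostic torch.amp.GradScaler;
    # older releases only provide the CUDA-specific variant.
    if version.parse(torch.__version__) > version.parse("2.3"):
        return torch.amp.GradScaler("cuda", **kwargs)
    return torch.cuda.amp.GradScaler(**kwargs)
```

With the helper in place, each benchmark collapses the four-line check into `scaler = get_grad_scaler()`.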
8 changes: 2 additions & 6 deletions benchmarks/fp8/ms_amp/non_distributed.py
@@ -22,11 +22,10 @@
 import msamp
 import torch
 from fp8_utils import evaluate_model, get_training_utilities
-from packaging import version
 
 from accelerate import Accelerator
 from accelerate.state import AcceleratorState
-from accelerate.utils import FP8RecipeKwargs, set_seed
+from accelerate.utils import FP8RecipeKwargs, get_grad_scaler, set_seed
 
 
 MODEL_NAME = "bert-base-cased"
@@ -42,10 +41,7 @@ def train_baseline(opt_level="O2"):
 
     base_model_results = evaluate_model(model, eval_dataloader, METRIC)
     model.train()
-    if version.parse(torch.__version__) > version.parse("2.3"):
-        scaler = torch.amp.GradScaler("cuda")
-    else:
-        scaler = torch.cuda.amp.GradScaler()
+    scaler = get_grad_scaler()
 
     for batch in train_dataloader:
         batch = batch.to("cuda")
2 changes: 1 addition & 1 deletion src/accelerate/accelerator.py
@@ -483,7 +483,7 @@ def __init__(
             ):
                 raise ValueError(f"fp16 mixed precision requires a GPU (not {self.device.type!r}).")
             kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}
-            self.scaler = get_grad_scaler(self.distributed_type == DistributedType.FSDP, **kwargs)
+            self.scaler = get_grad_scaler(self.distributed_type, **kwargs)
 
         elif self.state.mixed_precision == "bf16" and self.distributed_type not in (
             DistributedType.DEEPSPEED,
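Passing the `DistributedType` itself rather than a pre-computed boolean moves the FSDP decision into `get_grad_scaler`, so this call site no longer needs to know which distributed modes require a special scaler, and other backends can presumably be special-cased inside the helper later without touching callers.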
8 changes: 4 additions & 4 deletions src/accelerate/utils/modeling.py
@@ -1873,18 +1873,18 @@ def get_mixed_precision_context_manager(native_amp: bool = False, autocast_kwarg
     return contextlib.nullcontext()
 
 
-def get_grad_scaler(use_fsdp: bool = False, **kwargs):
+def get_grad_scaler(distributed_type: DistributedType = None, **kwargs):
     """
     A generic helper which will initialize the correct `GradScaler` implementation based on the environment and return
     it.
 
     Args:
-        use_fsdp (`bool`, *optional*, defaults to False):
-            Whether FSDP is enabled.
+        distributed_type (`DistributedType`, *optional*, defaults to None):
+            The type of distributed environment.
         kwargs:
             Additional arguments for the utilized `GradScaler` constructor.
     """
-    if use_fsdp:
+    if distributed_type == DistributedType.FSDP:
         from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 
         return ShardedGradScaler(**kwargs)
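A hedged usage sketch of the updated helper, based only on the signature and branches shown in this diff (the `init_scale` value is illustrative; extra kwargs are simply forwarded to the chosen scaler's constructor):

```python
from accelerate.utils import DistributedType, get_grad_scaler

# Default path: no distributed type given, so a regular GradScaler suited to
# the installed PyTorch version is returned.
scaler = get_grad_scaler()

# FSDP path: torch's ShardedGradScaler is returned instead, so loss-scale
# bookkeeping works with sharded gradients.
fsdp_scaler = get_grad_scaler(DistributedType.FSDP, init_scale=2.0**16)
```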
