From b0528392c83c1d7afb823d87abb94a924f7f69b1 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Fri, 15 Dec 2023 13:07:55 -0500
Subject: [PATCH] Integrate MS-AMP Support for FP8 as a separate backend (#2232)

* Redo with new version
* Store
* Working version
* Separate for now
* Min diff
* check if available
* Better docstring
* Check for multiple models and optimizers
* Check for TE and MSAMP args separately
* String clarity
* Better docstring and types
* Quality
* Simplify a bunch for fp8
* Convert literals to type alias
* Better err
* Docs
* toc typo
* Apply suggestions from code review

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Maria Khalusova

* Address doc nits

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Maria Khalusova
---
 docs/source/_toctree.yml                      |  4 +
 .../concept_guides/low_precision_training.md  | 74 +++++++++++++++
 docs/source/package_reference/utilities.md    |  2 +
 .../usage_guides/low_precision_training.md    | 92 +++++++++++++++++++
 src/accelerate/accelerator.py                 | 59 +++++++++---
 src/accelerate/state.py                       | 19 +++-
 src/accelerate/utils/__init__.py              |  3 +
 src/accelerate/utils/dataclasses.py           | 81 +++++++++++++---
 src/accelerate/utils/environment.py           | 12 +++
 src/accelerate/utils/imports.py               | 18 +++-
 10 files changed, 337 insertions(+), 27 deletions(-)
 create mode 100644 docs/source/concept_guides/low_precision_training.md
 create mode 100644 docs/source/usage_guides/low_precision_training.md

diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 7e0391019d1..23d07dcd843 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -41,6 +41,8 @@
     title: Using experiment trackers
   - local: usage_guides/mps
     title: How to use Apple Silicon M1 GPUs
+  - local: usage_guides/low_precision_training
+    title: How to train in low precision (FP8)
   - local: usage_guides/deepspeed
     title: How to use DeepSpeed
   - local: usage_guides/fsdp
@@ -63,6 +65,8 @@
     title: Executing and deferring jobs
   - local: concept_guides/gradient_synchronization
     title: Gradient synchronization
+  - local: concept_guides/low_precision_training
+    title: How training in low-precision environments is possible (FP8)
   - local: concept_guides/training_tpu
     title: TPU best practices
   title: Concepts and fundamentals
diff --git a/docs/source/concept_guides/low_precision_training.md b/docs/source/concept_guides/low_precision_training.md
new file mode 100644
index 00000000000..39476090540
--- /dev/null
+++ b/docs/source/concept_guides/low_precision_training.md
@@ -0,0 +1,74 @@
+
+# Low Precision Training Methods
+
+The release of new kinds of hardware led to the emergence of new training paradigms that better utilize them. Currently, this is in the form of training
+in 8-bit precision using packages such as [TransformersEngine](https://github.com/NVIDIA/TransformerEngine) (TE) or [MS-AMP](https://github.com/Azure/MS-AMP/tree/main).
+
+For an introduction to the topics discussed in this guide, we recommend reviewing the [low-precision usage guide](../usage_guides/low_precision_training.md), as this documentation will reference it regularly.
+
+## A Quick Chart
+
+Below is a quick chart from the MS-AMP documentation showing the different bit precisions for each solution during training:
+
+Optimization Level | Computation(GEMM) | Comm | Weight | Master Weight | Weight Gradient | Optimizer States
+-- | -- | -- | -- | -- | -- | --
+FP16 AMP | FP16 | FP32 | FP32 | N/A | FP32 | FP32+FP32
+Nvidia TE | FP8 | FP32 | FP32 | N/A | FP32 | FP32+FP32
+MS-AMP O1 | FP8 | FP8 | FP16 | N/A | FP8 | FP32+FP32
+MS-AMP O2 | FP8 | FP8 | FP16 | N/A | FP8 | FP8+FP16
+MS-AMP O3 | FP8 | FP8 | FP8 | FP16 | FP8 | FP8+FP16
+
+## `TransformersEngine`
+
+`TransformersEngine` is the first solution for training in 8-bit floating point. It works by using drop-in replacements for certain layers in a model, which use its FP8 engine to reduce the number of bits (such as from 32 to 8) without degrading the final accuracy of the model.
+
+Specifically, 🤗 Accelerate will find and replace the following layers with `TransformersEngine` versions:
+
+* `nn.LayerNorm` with `te.LayerNorm`
+* `nn.Linear` with `te.Linear`
+
+As a result, we wind up with a model that has most of its layers in BF16, while some layers are in FP8, reducing some of the memory.
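+
+Conceptually, the swap looks something like the following sketch (illustrative only, assuming `transformer_engine` is installed; the `swap_te_layers` helper here is hypothetical, while 🤗 Accelerate ships its own `convert_model` utility that handles this for you):
+
+```python
+import torch.nn as nn
+import transformer_engine.pytorch as te
+
+
+def swap_te_layers(model: nn.Module) -> nn.Module:
+    """Recursively replace `nn.Linear` and `nn.LayerNorm` modules with their TE equivalents."""
+    for name, module in model.named_children():
+        if isinstance(module, nn.Linear):
+            te_module = te.Linear(module.in_features, module.out_features, bias=module.bias is not None)
+            # Copy over the existing parameters so the swap is numerically transparent
+            te_module.weight.data = module.weight.data.clone()
+            if module.bias is not None:
+                te_module.bias.data = module.bias.data.clone()
+            setattr(model, name, te_module)
+        elif isinstance(module, nn.LayerNorm):
+            te_module = te.LayerNorm(module.normalized_shape[0], eps=module.eps)
+            te_module.weight.data = module.weight.data.clone()
+            te_module.bias.data = module.bias.data.clone()
+            setattr(model, name, te_module)
+        else:
+            # Recurse into submodules that are neither Linear nor LayerNorm
+            swap_te_layers(module)
+    return model
+```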
+
+Anecdotally, we have noticed that performance gains don't really start showing when using `TransformerEngine` until a large majority of the layers
+in the model are made up of those two replaceable layers. As a result, performance improvements have only been observed for larger models, with parameter counts around and upwards of a few billion.
+
+`TransformerEngine` can receive many different arguments that customize how it performs FP8 calculations. A full list of the arguments is available below:
+
+* `margin`: The margin to use for the gradient scaling.
+* `interval`: The interval to use for how often the scaling factor is recomputed.
+* `fp8_format`: The format to use for the FP8 recipe. Must be one of `E4M3` or `HYBRID`.
+* `amax_history_len`: The length of the history to use for the scaling factor computation.
+* `amax_compute_algo`: The algorithm to use for the scaling factor computation. Must be one of `max` or `most_recent`.
+* `override_linear_precision`: Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision.
+
+You can customize each of these as part of [`utils.FP8RecipeKwargs`] to help optimize the performance of your models.
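+
+Under the hood, these arguments are used to build a `DelayedScaling` recipe that wraps the model's forward pass, roughly like the sketch below (illustrative only; `model` and `batch` are placeholders, and 🤗 Accelerate performs this wiring for you when `mixed_precision="fp8"` is set):
+
+```python
+import transformer_engine.common.recipe as te_recipe
+from transformer_engine.pytorch import fp8_autocast
+
+# Build a delayed-scaling recipe from the arguments listed above
+fp8_recipe = te_recipe.DelayedScaling(
+    margin=0,
+    interval=1,
+    fp8_format=te_recipe.Format.HYBRID,
+    amax_history_len=1024,
+    amax_compute_algo="max",
+    override_linear_precision=(False, False, False),
+)
+
+# Run the forward pass of a TE-converted model under FP8 autocasting
+with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+    outputs = model(batch)
+```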
+
+As the chart earlier in this guide shows, TE simply casts the computation layers into FP8, while everything else is kept in FP32. As a result, this winds up utilizing the most memory, but it does so with the benefit of guaranteeing the least amount of loss in end accuracy during training.
+
+## `MS-AMP`
+
+MS-AMP takes a different approach from `TransformersEngine` by providing three different optimization levels that convert more of the operations to FP8 or FP16 (a short sketch of selecting a level directly with MS-AMP follows the list below).
+
+* The base optimization level (`O1`) passes communications of the weights (such as in DDP) in FP8, stores the weights of the model in FP16, and leaves the optimizer states in FP32. The main benefit of this optimization level is that communication bandwidth is essentially halved. Additionally, more GPU memory is saved because the weight gradients are kept in FP8 and the weights in FP16. Notably, both optimizer states remain in FP32.
+
+* The second optimization level (`O2`) improves upon this by also reducing the precision of the optimizer states: one is kept in FP8 while the other is in FP16. Generally this has been shown to be a net gain, with no degradation in end accuracy, increased training speed, and reduced memory, since every state is now in either FP16 or FP8.
+
+* Finally, MS-AMP has a third optimization level (`O3`) which helps during DDP scenarios such as DeepSpeed. The model weights in memory are fully cast to FP8, and the master weights are now stored in FP16. This reduces memory by the highest factor, as nearly everything is now in FP8 and only two states remain in FP16. Currently, only DeepSpeed versions up through 0.9.2 are supported, so this capability is not included in the 🤗 Accelerate integration.
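+
+At the lowest level, enabling one of these levels amounts to handing the model and optimizer over to MS-AMP, as in the rough sketch below (assuming MS-AMP is installed and a CUDA device is available; when using 🤗 Accelerate you do not call this yourself, since [`Accelerator.prepare`] does it for you when the MS-AMP backend is selected):
+
+```python
+import torch
+import msamp
+
+# A toy model and optimizer pair; MS-AMP wraps both according to the requested level
+model = torch.nn.Linear(64, 64).cuda()
+optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
+
+# "O1" and "O2" correspond to the optimization levels described above
+model, optimizer = msamp.initialize(model, optimizer, opt_level="O2")
+```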
+
+## Combining the two
+
+More experiments need to be performed, but it has been noted that combining both MS-AMP and TransformersEngine can lead to the highest throughput, relying on NVIDIA's optimized FP8 operators while also taking advantage of how MS-AMP reduces the memory overhead.
\ No newline at end of file
diff --git a/docs/source/package_reference/utilities.md b/docs/source/package_reference/utilities.md
index f4ea1cbcebb..7483267472f 100644
--- a/docs/source/package_reference/utilities.md
+++ b/docs/source/package_reference/utilities.md
@@ -48,6 +48,8 @@ These are basic dataclasses used throughout 🤗 Accelerate and they can be pass
 
 [[autodoc]] utils.PrecisionType
 
+[[autodoc]] utils.FP8RecipeKwargs
+
 [[autodoc]] utils.ProjectConfiguration
 
 ## Environmental Variables
diff --git a/docs/source/usage_guides/low_precision_training.md b/docs/source/usage_guides/low_precision_training.md
new file mode 100644
index 00000000000..a1899b7ccb1
--- /dev/null
+++ b/docs/source/usage_guides/low_precision_training.md
@@ -0,0 +1,92 @@
+
+# Low Precision Training Methods
+
+🤗 Accelerate provides integrations for training in lower precision on supported hardware through the `TransformersEngine` and `MS-AMP` packages. This documentation will help guide you through what hardware is supported, how to configure your [`Accelerator`] to leverage the low precision methods, and what you can expect when training.
+
+## What training on FP8 means
+
+To explore more of the nitty-gritty of training in FP8 with PyTorch and 🤗 Accelerate, check out the [concept guide](../concept_guides/low_precision_training.md) on why this can be difficult. Essentially, rather than training in BF16, some (or all) aspects of training a model can be performed using 8 bits instead of 16. The challenge is doing so without degrading final performance.
+
+This is only enabled on specific NVIDIA hardware, namely:
+
+* Consumer graphics cards newer than the 3000 series (such as the 4090)
+* Hopper-based GPU architectures (such as the `H100` and `H200`)
+
+What this results in is some savings in memory (as we've cut the needed memory in half for some parts of training), and an increase in throughput *should* also be seen for larger models that can replace certain layers with FP8-enabled ones.
+
+## Configuring the Accelerator
+
+Currently two different backends for FP8 are supported (`TransformersEngine` and `MS-AMP`), each with different capabilities and configurations.
+
+To use either, the same core API is used. Just pass `mixed_precision="fp8"` to the [`Accelerator`], answer `fp8` when prompted about mixed precision during `accelerate config`, or set it in the `mixed_precision` key of your `config.yaml` file:
+
+```python
+from accelerate import Accelerator
+accelerator = Accelerator(mixed_precision="fp8")
+```
+
+By default, if `MS-AMP` is available in your environment, 🤗 Accelerate will automatically utilize it as a backend. To specify it yourself (and customize other parts of the FP8 mixed precision setup), you can utilize the [`utils.FP8RecipeKwargs`]:
+
+```python
+from accelerate import Accelerator
+from accelerate.utils import FP8RecipeKwargs
+kwargs = [FP8RecipeKwargs(backend="msamp")]
+# Or to specify the backend as `TransformersEngine` even if MS-AMP is installed
+# kwargs = [FP8RecipeKwargs(backend="te")]
+accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs)
+```
+
+## Configuring MS-AMP
+
+Of the two, `MS-AMP` is traditionally the easier one to configure, as there is only a single argument: the optimization level.
+
+Currently two levels of optimization are supported in the 🤗 Accelerate integration, `"O1"` and `"O2"` (using the letter 'O', not the number zero).
+
+* `"O1"` will cast the weight gradients and `all_reduce` communications to 8-bit, while the rest are done in 16-bit. This reduces general GPU memory usage and speeds up communication.
+* `"O2"` will also cast the first-order optimizer states to 8-bit, while the second-order states are kept in FP16 (currently only the `Adam` and `AdamW` optimizers are supported). This does its best to minimize final accuracy degradation and will save the most memory.
+
+To specify an optimization level, pass it to [`utils.FP8RecipeKwargs`] via the `opt_level` argument:
+
+```python
+from accelerate import Accelerator
+from accelerate.utils import FP8RecipeKwargs
+kwargs = [FP8RecipeKwargs(backend="msamp", opt_level="O2")]
+accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs)
+```
+
+## Configuring TransformersEngine
+
+TransformersEngine has much more available for customizing how and what FP8 calculations are performed. A full list of supported arguments and what they mean is available in [NVIDIA's documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html); they are also restated as part of [`utils.FP8RecipeKwargs`]'s docstring for your convenience.
+
+🤗 Accelerate tries to set sensible defaults, but exploring and tweaking the various parameters yourself can potentially lead to better performance.
+
+To use it, specify `backend="te"` and modify any of the arguments you want as part of your kwargs handler:
+
+```python
+from accelerate import Accelerator
+from accelerate.utils import FP8RecipeKwargs
+kwargs = [FP8RecipeKwargs(backend="te", ...)]
+accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs)
+```
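+
+For example, a recipe using the `HYBRID` format and a longer amax history could look like the following (the values here are purely illustrative, not recommendations; tune them for your own workload):
+
+```python
+from accelerate import Accelerator
+from accelerate.utils import FP8RecipeKwargs
+
+kwargs = [
+    FP8RecipeKwargs(
+        backend="te",
+        fp8_format="HYBRID",      # E4M3 for the forward pass, E5M2 for gradients
+        amax_history_len=32,      # how many steps of amax history to keep
+        amax_compute_algo="max",  # take the max over that history
+    )
+]
+accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=kwargs)
+```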
+
+## Further Reading
+
+To learn more about training in FP8, please check out the following resources:
+
+* [Our concept guide](../concept_guides/low_precision_training.md) with more details on both TransformersEngine and MS-AMP
+* [The `transformers-engine` documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html)
+* [The `MS-AMP` documentation](https://azure.github.io/MS-AMP/docs/)
\ No newline at end of file
diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index bc4e351152f..e2544da791f 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -79,6 +79,7 @@
     is_fp8_available,
     is_ipex_available,
     is_megatron_lm_available,
+    is_msamp_available,
     is_npu_available,
     is_torch_version,
     is_tpu_available,
@@ -366,6 +367,8 @@ def __init__(
                         raise ValueError("You can only pass one `AutocastKwargs` in `kwargs_handler`.")
                     else:
                         self.autocast_handler = handler
+        if self.fp8_recipe_handler is None and mixed_precision == "fp8":
+            self.fp8_recipe_handler = FP8RecipeKwargs()
 
         kwargs = self.init_handler.to_kwargs() if self.init_handler is not None else {}
         self.state = AcceleratorState(
@@ -1196,7 +1199,7 @@ def prepare(self, *args, device_placement=None):
 
         # If we're dealing with device placement, this deals with that by...
         tpu_should_fix_optimizer = self.device_placement and self.distributed_type == DistributedType.TPU
-        if tpu_should_fix_optimizer or self.mixed_precision == "fp8":
+        if tpu_should_fix_optimizer or (self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "TE"):
             # 1. grabbing old model parameters
             old_named_params = self._get_named_parameters(*args)
 
@@ -1210,12 +1213,16 @@ def prepare(self, *args, device_placement=None):
         elif self.distributed_type == DistributedType.MEGATRON_LM:
             result = self._prepare_megatron_lm(*args)
         else:
+            if self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "MSAMP":
+                args = self._prepare_msamp(*args)
+                # MS-AMP will handle the device placement
+                device_placement = [False for _ in args]
             result = tuple(
                 self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
             )
             result = tuple(self._prepare_one(obj, device_placement=d) for obj, d in zip(result, device_placement))
 
-        if tpu_should_fix_optimizer or self.mixed_precision == "fp8":
+        if tpu_should_fix_optimizer or (self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "TE"):
             # 2. grabbing new model parameters
             new_named_params = self._get_named_parameters(*result)
             # 3.
building a map from the first to the second @@ -1284,7 +1291,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model) else: model.forward = convert_outputs_to_fp32(new_forward) - elif self.mixed_precision == "fp8": + elif self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "TE": if not has_transformer_engine_layers(model): with torch.no_grad(): convert_model(model) @@ -1295,15 +1302,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e if "fp8_format" in kwargs: kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"]) fp8_recipe = te_recipe.DelayedScaling(**kwargs) - cuda_device_capacity = torch.cuda.get_device_capability() - fp8_enabled = cuda_device_capacity >= (8, 9) - if not fp8_enabled: - logger.warn( - f"The current device has compute capability of {cuda_device_capacity} which is " - "insufficient for FP8 mixed precision training (requires a GPU Hopper/Ada Lovelace " - "or higher, compute capability of 8.9 or higher). Will use FP16 instead." - ) - model.forward = fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe)(model.forward) + model.forward = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe)(model.forward) if (getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)) and getattr( model, "hf_device_map", False @@ -1749,6 +1748,42 @@ def _prepare_ipex(self, *args): result[i] = optimizer return tuple(result) + def _prepare_msamp(self, *args): + if not is_msamp_available(): + raise ImportError( + "MS-AMP was not found on your system. Please ensure that MS-AMP is available " + " or choose `'te'` as the backend for FP8 mixed precision training." + ) + else: + import msamp + + model, optimizer = None, None + num_models, num_optimizers = 0, 0 + result = [obj for obj in args] + for obj in result: + if isinstance(obj, torch.nn.Module): + model = obj + num_models += 1 + elif isinstance(obj, (torch.optim.Optimizer)): + optimizer = obj + num_optimizers += 1 + if optimizer is None or model is None: + raise ValueError( + "You must pass a model and an optimizer together to `accelerate.prepare()` when using MS-AMP." + ) + elif num_models > 1 or num_optimizers > 1: + raise ValueError( + f"You can't use multiple models ({num_models}) or optimizers {num_optimizers} with MS-AMP." 
+            )
+        else:
+            model, optimizer = msamp.initialize(model, optimizer, opt_level=self.fp8_recipe_handler.opt_level)
+        for i in range(len(result)):
+            if isinstance(result[i], torch.nn.Module):
+                result[i] = model
+            elif isinstance(result[i], torch.optim.Optimizer):
+                result[i] = optimizer
+        return tuple(result)
+
     def prepare_data_loader(
         self, data_loader: torch.utils.data.DataLoader, device_placement=None, slice_fn_for_dispatch=None
     ):
diff --git a/src/accelerate/state.py b/src/accelerate/state.py
index b4f95b03fbd..17b43708b15 100644
--- a/src/accelerate/state.py
+++ b/src/accelerate/state.py
@@ -14,6 +14,7 @@
 from __future__ import annotations
 
+import logging
 import math
 import os
 import threading
@@ -29,6 +30,7 @@
     DynamoBackend,
     GradientAccumulationPlugin,
     check_cuda_p2p_ib_support,
+    check_fp8_capability,
     get_ccl_version,
     get_int_from_env,
     is_ccl_available,
@@ -52,6 +54,8 @@
 if is_npu_available(check_device=False):
     import torch_npu  # noqa: F401
 
+logger = logging.getLogger(__name__)
+
 
 def is_initialized() -> bool:
     """
@@ -765,8 +769,19 @@ def __init__(
             if mixed_precision is None
             else mixed_precision.lower()
         )
-        if mixed_precision == "fp8" and not is_fp8_available():
-            raise ValueError("Using `fp8` precision requires `transformer_engine` to be installed.")
+        if mixed_precision == "fp8":
+            if not is_fp8_available():
+                raise ValueError(
+                    "Using `fp8` precision requires `transformer_engine` or `MS-AMP` to be installed."
+                )
+            elif not check_fp8_capability():
+                logger.warning(
+                    f"The current device has compute capability of {torch.cuda.get_device_capability()} which is "
+                    "insufficient for FP8 mixed precision training (requires a GPU Hopper/Ada Lovelace "
+                    "or higher, compute capability of 8.9 or higher). Will use FP16 instead."
+                )
+                mixed_precision = "fp16"
+
         self.dynamo_plugin = dynamo_plugin
         if not _from_accelerator:
             raise ValueError(
diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py
index 83bb19502e5..4c9cd006547 100644
--- a/src/accelerate/utils/__init__.py
+++ b/src/accelerate/utils/__init__.py
@@ -40,6 +40,7 @@
 from .environment import (
     are_libraries_initialized,
     check_cuda_p2p_ib_support,
+    check_fp8_capability,
     get_int_from_env,
     parse_choice_from_env,
     parse_flag_from_env,
@@ -65,6 +66,7 @@
     is_megatron_lm_available,
     is_mlflow_available,
     is_mps_available,
+    is_msamp_available,
     is_npu_available,
     is_pandas_available,
     is_rich_available,
@@ -72,6 +74,7 @@
     is_tensorboard_available,
     is_timm_available,
     is_tpu_available,
+    is_transformer_engine_available,
     is_transformers_available,
     is_wandb_available,
     is_xpu_available,
diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py
index 6bc51c399e3..917861eed64 100644
--- a/src/accelerate/utils/dataclasses.py
+++ b/src/accelerate/utils/dataclasses.py
@@ -26,7 +26,7 @@
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from datetime import timedelta
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
+from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, get_args
 
 import torch
 
@@ -169,36 +169,93 @@
     timeout: timedelta = timedelta(seconds=1800)
 
 
+# Literals
+Backend = Literal["MSAMP", "TE"]
+OptLevel = Literal["O1", "O2"]
+FP8Format = Literal["E4M3", "HYBRID"]
+AmaxComputeAlgorithm = Literal["max", "most_recent"]
+
+
 @dataclass
 class FP8RecipeKwargs(KwargsHandler):
     """
     Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision
-    training. Please refer to the documentation of this
-    [class](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html#transformer_engine.common.recipe.DelayedScaling)
-    for more information on each argument.
+    training with `transformer-engine` or `ms-amp`.
+
+    For more information on `transformer-engine` args, please refer to the API
+    [documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html).
+
+    For more information on the `ms-amp` args, please refer to the Optimization Level
+    [documentation](https://azure.github.io/MS-AMP/docs/user-tutorial/optimization-level).
+
     ```python
     from accelerate import Accelerator
     from accelerate.utils import FP8RecipeKwargs
 
-    kwargs = FP8RecipeKwargs(fp8_format="HYBRID")
+    kwargs = FP8RecipeKwargs(backend="te", fp8_format="HYBRID")
     accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[kwargs])
     ```
+
+    To use MS-AMP as an engine, pass `backend="msamp"` and the `opt_level`:
+
+    ```python
+    kwargs = FP8RecipeKwargs(backend="msamp", opt_level="O2")
+    ```
+
+    Args:
+        backend (`str`, *optional*, defaults to `"msamp"`):
+            Which FP8 engine to use. Must be one of `"msamp"` (MS-AMP) or `"te"` (TransformerEngine).
+        margin (`int`, *optional*, defaults to 0):
+            The margin to use for the gradient scaling.
+        interval (`int`, *optional*, defaults to 1):
+            The interval to use for how often the scaling factor is recomputed.
+        fp8_format (`str`, *optional*, defaults to `"E4M3"`):
+            The format to use for the FP8 recipe. Must be one of `"E4M3"` or `"HYBRID"`.
+        amax_history_len (`int`, *optional*, defaults to 1):
+            The length of the history to use for the scaling factor computation.
+        amax_compute_algo (`str`, *optional*, defaults to `"most_recent"`):
+            The algorithm to use for the scaling factor computation. Must be one of `"max"` or `"most_recent"`.
+        override_linear_precision (`tuple` of three `bool`, *optional*, defaults to `(False, False, False)`):
+            Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision.
+        opt_level (`str`, *optional*, defaults to `"O2"`):
+            What level of 8-bit collective communication should be used with MS-AMP. In general:
+                * O1: Weight gradients and `all_reduce` communications are done in FP8, reducing GPU
+                  memory usage and communication bandwidth.
+                * O2: First-order optimizer states are in 8-bit, and second-order states are in FP16.
+                  Only available when using Adam or AdamW. This maintains accuracy and can potentially save the
+                  highest amount of memory.
+                * O3: Specifically for DeepSpeed, implements capabilities so weights and master weights of models
+                  are stored in FP8. If `fp8` is selected and DeepSpeed is enabled, it will be used by default. (Not
+                  available currently.)
""" + backend: Backend = "msamp" + opt_level: OptLevel = "O2" margin: int = 0 interval: int = 1 - fp8_format: str = "E4M3" + fp8_format: FP8Format = "E4M3" amax_history_len: int = 1 - amax_compute_algo: str = "most_recent" + amax_compute_algo: AmaxComputeAlgorithm = "most_recent" override_linear_precision: Tuple[bool, bool, bool] = (False, False, False) def __post_init__(self): - self.fp8_format = self.fp8_format.upper() - if self.fp8_format not in ["E4M3", "HYBRID"]: - raise ValueError("`fp8_format` must be 'E4M3' or 'HYBRID'.") - if self.amax_compute_algo not in ["max", "most_recent"]: - raise ValueError("`amax_compute_algo` must be 'max' or 'most_recent'") + self.backend = self.backend.upper() + if self.backend not in get_args(Backend): + raise ValueError("`backend` must be 'MSAMP' or 'TE' (TransformerEngine).") + # Check TE args + if self.backend == "TE": + self.fp8_format = self.fp8_format.upper() + if self.fp8_format not in get_args(FP8Format): + raise ValueError(f"`fp8_format` must be one of {' or '.join(get_args(FP8Format))}.") + if self.amax_compute_algo not in get_args(AmaxComputeAlgorithm): + raise ValueError(f"`amax_compute_algo` must be one of {' or '.join(get_args(AmaxComputeAlgorithm))}") + elif self.backend == "MSAMP": + if self.opt_level not in get_args(OptLevel): + raise ValueError(f"`optimization_level` must be one of {' or '.join(get_args(OptLevel))}") class EnumWithContains(enum.EnumMeta): diff --git a/src/accelerate/utils/environment.py b/src/accelerate/utils/environment.py index 12169b8d852..ac0e6a8d9f4 100644 --- a/src/accelerate/utils/environment.py +++ b/src/accelerate/utils/environment.py @@ -19,6 +19,8 @@ from distutils import spawn from typing import Dict +import torch + def str_to_bool(value) -> int: """ @@ -108,3 +110,13 @@ def check_cuda_p2p_ib_support(): except Exception: pass return True + + +def check_fp8_capability(): + """ + Checks if all the current GPUs available support FP8. + + Notably must initialize `torch.cuda` to check. + """ + cuda_device_capacity = torch.cuda.get_device_capability() + return cuda_device_capacity >= (8, 9) diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py index 27389eab107..14ed9f7328c 100644 --- a/src/accelerate/utils/imports.py +++ b/src/accelerate/utils/imports.py @@ -72,10 +72,26 @@ def get_ccl_version(): return importlib.metadata.version("oneccl_bind_pt") -def is_fp8_available(): +def is_msamp_available(): + package_exists = importlib.util.find_spec("msamp") is not None + if package_exists: + try: + # MS-AMP has a different metadata name + _ = importlib.metadata.metadata("ms-amp") + return True + except importlib.metadata.PackageNotFoundError: + return False + return False + + +def is_transformer_engine_available(): return _is_package_available("transformer_engine") +def is_fp8_available(): + return is_msamp_available() or is_transformer_engine_available() + + def is_cuda_available(): """ Checks if `cuda` is available via an `nvml-based` check which won't trigger the drivers and leave cuda