Skip to content

Commit

Permalink
Integrate MS-AMP Support for FP8 as a seperate backend (#2232)
Browse files Browse the repository at this point in the history
* Redo with new version

* Store

* Working version

* Seperate for now

* Min diff

* check if available

* Better docstring

* Check for multiple models and optimizers

* Check for TE and MSAMP args seperately

* String clarity

* Better docstring and types

* Quality

* Simplify a bunch for fp8

* Convert literals to type alias

* Better err

* Docs

* toc typo

* Apply suggestions from code review

Co-authored-by: Marc Sun <[email protected]>

* Apply suggestions from code review

Co-authored-by: Maria Khalusova <[email protected]>

* Address doc nits

---------

Co-authored-by: Marc Sun <[email protected]>
Co-authored-by: Maria Khalusova <[email protected]>
  • Loading branch information
3 people authored Dec 15, 2023
1 parent 0606784 commit b052839
Show file tree
Hide file tree
Showing 10 changed files with 337 additions and 27 deletions.
4 changes: 4 additions & 0 deletions docs/source/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
title: Using experiment trackers
- local: usage_guides/mps
title: How to use Apple Silicon M1 GPUs
- local: usage_guides/low_precision_training
title: How to train in low precision (FP8)
- local: usage_guides/deepspeed
title: How to use DeepSpeed
- local: usage_guides/fsdp
Expand All @@ -63,6 +65,8 @@
title: Executing and deferring jobs
- local: concept_guides/gradient_synchronization
title: Gradient synchronization
- local: concept_guides/low_precision_training
title: How training in low-precision environments is possible (FP8)
- local: concept_guides/training_tpu
title: TPU best practices
title: Concepts and fundamentals
Expand Down
74 changes: 74 additions & 0 deletions docs/source/concept_guides/low_precision_training.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# Low Precision Training Methods

The release of new kinds of hardware led to the emergence of new training paradigms that better utilize them. Currently, this is in the form of training
in 8-bit precision using packages such as [TranformersEngine](https://github.com/NVIDIA/TransformerEngine) (TE) or [MS-AMP](https://github.com/Azure/MS-AMP/tree/main).

For an introduction to the topics discussed today, we recommend reviewing the [low-precision usage guide](../usage_guides/low_precision_training.md) as this documentation will reference it regularly.

## A Quick Chart

Below is a quick chart from the MS-AMP documentation showing the different bit-precisions for each solution during training:

Optimization Level | Computation(GEMM) | Comm | Weight | Master Weight | Weight Gradient | Optimizer States
-- | -- | -- | -- | -- | -- | --
FP16 AMP | FP16 | FP32 | FP32 | N/A | FP32 | FP32+FP32
Nvidia TE | FP8 | FP32 | FP32 | N/A | FP32 | FP32+FP32
MS-AMP O1 | FP8 | FP8 | FP16 | N/A | FP8 | FP32+FP32
MS-AMP O2 | FP8 | FP8 | FP16 | N/A | FP8 | FP8+FP16
MS-AMP O3 | FP8 | FP8 | FP8 | FP16 | FP8 | FP8+FP16

## `TransformersEngine`

`TranformersEngine` is the first solution to trying to train in 8-bit floating point. It works by using drop-in replacement layers for certain ones in a model that utilize their FP8-engine to reduce the number of bits (such as 32 to 8) without degrading the final accuracy of the model.

Specifically, 🤗 Accelerate will find and replace the following layers with `TranformersEngine` versions:

* `nn.LayerNorm` for `te.LayerNorm`
* `nn.Linear` for `te.Linear`

As a result we wind up with a model that has most of its layers in BF16, while some layers are in FP8 reducing some of the memory.

Anecdotally, we have noticed that performance gains don't really start showing when using `TransformerEngine` until a large majority of the layers
in the model are made up of those two layers to replace. As a result, only larger models have shown performance improvements when the number of parameters is around and upwards of a few billion.

The `TransformerEngine` can receive many different arguments that customize how it performs FP8 calculations and what they do. A full list of the arguments is available below:

* `margin`: The margin to use for the gradient scaling.
* `interval`: The interval to use for how often the scaling factor is recomputed.
* `fp8_format``: The format to use for the FP8 recipe. Must be one of `E4M3` or `HYBRID`.
* `amax_history_len`: The length of the history to use for the scaling factor computation
* `amax_compute_algo`: The algorithm to use for the scaling factor computation. Must be one of `max` or `most_recent`.
* `override_linear_precision`: Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision.

You can customize each of these as part of [`utils.FP8RecipeKwargs`] to help optimize performance of your models.

If we notice in the chart mentioned earlier, TE simply casts the computation layers into FP8, while everything else is in FP32. As a result this winds up utilizing the most memory but does so with the benefit of guaranteeing the least amount of loss in end accuracy during training.

## `MS-AMP`

MS-AMP takes a different approach to `TransformersEngine` by providing three different optimization levels to convert more operations in FP8 or FP16.

* The base optimization level (`O1`), passes communications of the weights (such as in DDP) in FP8, stores the weights of the model in FP16, and leaves the optimizer states in FP32. The main benefit of this optimization level is that we can reduce the communication bandwidth by essentially half. Additionally, more GPU memory is saved due to 1/2 of everything being cast in FP8, and the weights being cast to FP16. Notably, both the optimizer states remain in FP32.

* The second optimization level (`O2`) improves upon this by also reducing the precision of the optimizer states. One is in FP8 while the other is in FP16. Generally it's been shown that this will only provide a net-gain of no degredated end accuracy, increased training speed, and reduced memory as now every state is either in FP16 or FP8.

* Finally, MS-AMP has a third optimization level (`O3`) which helps during DDP scenarios such as DeepSpeed. The weights of the model in memory are fully cast to FP8, and the master weights are now stored in FP16. This fully reduces memory by the highest factor as now not only is almost everything in FP8, only two states are left in FP16. Currently, only DeepSpeed versions up through 0.9.2 are supported, so this capability is not included in the 🤗 Accelerate integration

## Combining the two

More experiments need to be performed but it's been noted that combining both MS-AMP and TransformersEngine can lead to the highest throughput by relying on NVIDIA's optimized FP8 operators and utilizing how MS-AMP reduces the memory overhead.
2 changes: 2 additions & 0 deletions docs/source/package_reference/utilities.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ These are basic dataclasses used throughout 🤗 Accelerate and they can be pass

[[autodoc]] utils.PrecisionType

[[autodoc]] utils.FP8RecipeKwargs

[[autodoc]] utils.ProjectConfiguration

## Environmental Variables
Expand Down
92 changes: 92 additions & 0 deletions docs/source/usage_guides/low_precision_training.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# Low Precision Training Methods

🤗 Accelerate provides integrations to train on lower precision methods using specified supported hardware through the `TransformersEngine` and `MS-AMP` packages. This documentation will help guide you through what hardware is supported, how to configure your [`Accelerator`] to leverage the low precision methods, and what you can expect when training.

## What training on FP8 means

To explore more of the nitty-gritty in traninig in FP8 with PyTorch and 🤗 Accelerate, check out the [concept_guide](../concept_guides/low_precision_training.md) on why this can be difficult. But essentially rather than training in BF16, some (or all) aspects of training a model can be performed using 8 bits instead of 16. The challenge is doing so without degrading final performance.

This is only enabled on specific NVIDIA hardware, namely:

* Anything after the 3000 series consumer graphics cards (such as the 4090)
* Hopper-based GPU architectures (such as the `H100` and `H200`)

What this will result in is some gain in the memory used (as we've cut the needed memory in half for some parts of training) and an increase in throughput *should* be seen as well for larger models that can replace certain layers with FP8-enabled ones.

## Configuring the Accelerator

Currently two different backends for FP8 are supported (`TransformersEngine` and `MS-AMP`), each with different capabilities and configurations.

To use either, the same core API is used. Just pass `mixed_precision="fp8"` to either the [`Accelerator`], during `accelerate config` when prompted about mixed precision, or as part of your `config.yaml` file in the `mixed_precision` key:

```{python}
from accelerate import Accelerator
accelerator = Accelerator(mixed_precision="fp8")
```

By default, if `MS-AMP` is available in your environment, 🤗 Accelerate will automatically utilize it as a backend. To specify it yourself (and customize other parts of the FP8 mixed precision setup), you can utilize the [`utils.FP8RecipeKwargs`]:

```{python}
from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs
kwargs = [FP8RecipeKwargs(backend="msamp")]
# Or to specify the backend as `TransformersEngine` even if MS-AMP is installed
# kwargs = [FP8RecipeKwargs(backend="te")]
accelerator = Accelerator(mixed_precision="fp8", kwarg_handlers=kwargs)
```

## Configuring MS-AMP

Of the two, `MS-AMP` is traditionally the easier one to configure as there is only a single argument: the optimization level.

Currently two levels of optimization are supported in the 🤗 Accelerate integration, `"O1"` and `"O2"` (using the letter 'o', not zero).

* `"O1"` will cast the weight gradients and `all_reduce` communications to happen in 8-bit, while the rest are done in 16 bit. This reduces the general GPU memory usage and speeds up communication bandwidths.
* `"O2"` will also cast first-order optimizer states into 8 bit, while the second order states are in FP16. (Currently just the `Adam` optimizer is supported). This tries it's best to minimize final accuracy degredation and will save the highest potential memory.

To specify an optimization level, pass it to the `FP8KwargsHandler` by setting the `optimization_level` argument:

```{python}
from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs
kwargs = [FP8RecipeKwargs(backend="msamp", optimization_level="O2")]
accelerator = Accelerator(mixed_precision="fp8", kwarg_handlers=kwargs)
```

## Configuring TransformersEngine

TransformersEngine has much more available for customizing how and what FP8 calculations are performed. A full list of supported arguments and what they mean are available in [NVIDIA's documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html), however they are restated as part of [`FP8KwargsHandler`]'s docstring for your convience.

🤗 Accelerate tries to set sensible defaults, but exploring and tweaking the various parameters yourself can lead to better performance potentially.

To use it, specify `backend="te"` and modify any of the arguments you want as part of your kwarg handler:

```{python}
from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs
kwargs = [FP8RecipeKwargs(backend="te", ...)]
accelerator = Accelerator(mixed_precision="fp8", kwarg_handlers=kwargs)
```

## Futher Reading

To learn more about training in FP8 please check out the following resources:

* [Our concept guide](../concept_guides/low_precision_training.md) detailing into more about both TransformersEngine and MS-AMP
* [The `transformers-engine` documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html)
* [The `MS-AMP` documentation](https://azure.github.io/MS-AMP/docs/)
59 changes: 47 additions & 12 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
is_fp8_available,
is_ipex_available,
is_megatron_lm_available,
is_msamp_available,
is_npu_available,
is_torch_version,
is_tpu_available,
Expand Down Expand Up @@ -366,6 +367,8 @@ def __init__(
raise ValueError("You can only pass one `AutocastKwargs` in `kwargs_handler`.")
else:
self.autocast_handler = handler
if self.fp8_recipe_handler is None and mixed_precision == "fp8":
self.fp8_recipe_handler = FP8RecipeKwargs()

kwargs = self.init_handler.to_kwargs() if self.init_handler is not None else {}
self.state = AcceleratorState(
Expand Down Expand Up @@ -1196,7 +1199,7 @@ def prepare(self, *args, device_placement=None):

# If we're dealing with device placement, this deals with that by...
tpu_should_fix_optimizer = self.device_placement and self.distributed_type == DistributedType.TPU
if tpu_should_fix_optimizer or self.mixed_precision == "fp8":
if tpu_should_fix_optimizer or (self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "TE"):
# 1. grabbing old model parameters
old_named_params = self._get_named_parameters(*args)

Expand All @@ -1210,12 +1213,16 @@ def prepare(self, *args, device_placement=None):
elif self.distributed_type == DistributedType.MEGATRON_LM:
result = self._prepare_megatron_lm(*args)
else:
if self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "MSAMP":
args = self._prepare_msamp(*args)
# MS-AMP will handle the device placement
device_placement = [False for _ in args]
result = tuple(
self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
)
result = tuple(self._prepare_one(obj, device_placement=d) for obj, d in zip(result, device_placement))

if tpu_should_fix_optimizer or self.mixed_precision == "fp8":
if tpu_should_fix_optimizer or (self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "TE"):
# 2. grabbing new model parameters
new_named_params = self._get_named_parameters(*result)
# 3. building a map from the first to the second
Expand Down Expand Up @@ -1284,7 +1291,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model)
else:
model.forward = convert_outputs_to_fp32(new_forward)
elif self.mixed_precision == "fp8":
elif self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "TE":
if not has_transformer_engine_layers(model):
with torch.no_grad():
convert_model(model)
Expand All @@ -1295,15 +1302,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
if "fp8_format" in kwargs:
kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"])
fp8_recipe = te_recipe.DelayedScaling(**kwargs)
cuda_device_capacity = torch.cuda.get_device_capability()
fp8_enabled = cuda_device_capacity >= (8, 9)
if not fp8_enabled:
logger.warn(
f"The current device has compute capability of {cuda_device_capacity} which is "
"insufficient for FP8 mixed precision training (requires a GPU Hopper/Ada Lovelace "
"or higher, compute capability of 8.9 or higher). Will use FP16 instead."
)
model.forward = fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe)(model.forward)
model.forward = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe)(model.forward)

if (getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)) and getattr(
model, "hf_device_map", False
Expand Down Expand Up @@ -1749,6 +1748,42 @@ def _prepare_ipex(self, *args):
result[i] = optimizer
return tuple(result)

def _prepare_msamp(self, *args):
if not is_msamp_available():
raise ImportError(
"MS-AMP was not found on your system. Please ensure that MS-AMP is available "
" or choose `'te'` as the backend for FP8 mixed precision training."
)
else:
import msamp

model, optimizer = None, None
num_models, num_optimizers = 0, 0
result = [obj for obj in args]
for obj in result:
if isinstance(obj, torch.nn.Module):
model = obj
num_models += 1
elif isinstance(obj, (torch.optim.Optimizer)):
optimizer = obj
num_optimizers += 1
if optimizer is None or model is None:
raise ValueError(
"You must pass a model and an optimizer together to `accelerate.prepare()` when using MS-AMP."
)
elif num_models > 1 or num_optimizers > 1:
raise ValueError(
f"You can't use multiple models ({num_models}) or optimizers {num_optimizers} with MS-AMP."
)
else:
model, optimizer = msamp.initialize(model, optimizer, opt_level=self.fp8_recipe_handler.opt_level)
for i in range(len(result)):
if isinstance(result[i], torch.nn.Module):
result[i] = model
elif isinstance(result[i], (torch.optim.Optimizer)):
result[i] = optimizer
return tuple(result)

def prepare_data_loader(
self, data_loader: torch.utils.data.DataLoader, device_placement=None, slice_fn_for_dispatch=None
):
Expand Down
19 changes: 17 additions & 2 deletions src/accelerate/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

import logging
import math
import os
import threading
Expand All @@ -29,6 +30,7 @@
DynamoBackend,
GradientAccumulationPlugin,
check_cuda_p2p_ib_support,
check_fp8_capability,
get_ccl_version,
get_int_from_env,
is_ccl_available,
Expand All @@ -52,6 +54,8 @@
if is_npu_available(check_device=False):
import torch_npu # noqa: F401

logger = logging.getLogger(__name__)


def is_initialized() -> bool:
"""
Expand Down Expand Up @@ -765,8 +769,19 @@ def __init__(
if mixed_precision is None
else mixed_precision.lower()
)
if mixed_precision == "fp8" and not is_fp8_available():
raise ValueError("Using `fp8` precision requires `transformer_engine` to be installed.")
if mixed_precision == "fp8":
if not is_fp8_available():
raise ValueError(
"Using `fp8` precision requires `transformer_engine` or `MS-AMP` to be installed."
)
elif not check_fp8_capability():
logger.warning(
f"The current device has compute capability of {torch.cuda.get_device_capability()} which is "
"insufficient for FP8 mixed precision training (requires a GPU Hopper/Ada Lovelace "
"or higher, compute capability of 8.9 or higher). Will use FP16 instead."
)
mixed_precision = "fp16"

self.dynamo_plugin = dynamo_plugin
if not _from_accelerator:
raise ValueError(
Expand Down
Loading

0 comments on commit b052839

Please sign in to comment.