TE acceleration using callbacks (#11261)
* TE acceleration using callbacks

Signed-off-by: Onur Yilmaz <[email protected]>

* TE accelerator example added

Signed-off-by: Onur Yilmaz <[email protected]>

* fp8 autocast added

Signed-off-by: Onur Yilmaz <[email protected]>

* single GPU support with TE

Signed-off-by: Onur Yilmaz <[email protected]>

* generalized model callback added

Signed-off-by: Onur Yilmaz <[email protected]>

* Apply isort and black reformatting

Signed-off-by: oyilmaz-nvidia <[email protected]>

* remove te_transform

Signed-off-by: Onur Yilmaz <[email protected]>

* DDP with TE is working

Signed-off-by: Onur Yilmaz <[email protected]>

* address feedback

Signed-off-by: Onur Yilmaz <[email protected]>

* updated function name and added export guard

Signed-off-by: Onur Yilmaz <[email protected]>

* move the torch.no_grad

Signed-off-by: Onur Yilmaz <[email protected]>

* gemma hf example added

Signed-off-by: Onur Yilmaz <[email protected]>

* gemma hf example added

Signed-off-by: Onur Yilmaz <[email protected]>

* add docstrings

Signed-off-by: Onur Yilmaz <[email protected]>

* Fix the hf te test param

Signed-off-by: Onur Yilmaz <[email protected]>

* fix minor issue

Signed-off-by: Onur Yilmaz <[email protected]>

* fix the recent callback related issue

Signed-off-by: Onur Yilmaz <[email protected]>

---------

Signed-off-by: Onur Yilmaz <[email protected]>
Signed-off-by: oyilmaz-nvidia <[email protected]>
Co-authored-by: oyilmaz-nvidia <[email protected]>
oyilmaz-nvidia and oyilmaz-nvidia authored Nov 20, 2024
1 parent cae18e6 commit fcd128e
Showing 7 changed files with 330 additions and 18 deletions.
44 changes: 29 additions & 15 deletions .github/workflows/cicd-main.yml
@@ -3457,7 +3457,20 @@ jobs:
inference.repetition_penalty=1.0 \
inference.outfile_path=/tmp/nlp_mcore_t5_lora_tuning_tp2/out.jsonl
# L2: Megatron Mock Data Generation
L2_HF_Transformer_SFT_TE_Acceleration:
needs: [ cicd-test-container-setup ]
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_HF_Transformer_SFT_TE_Acceleration') || needs.cicd-test-container-setup.outputs.all == 'true'
with:
RUNNER: self-hosted-azure
SCRIPT: |
python examples/llm/sft/hf.py --model /home/TestData/nlp/hf_gemma/hf_gemma_2b --model-accelerator te
AFTER_SCRIPT: |
rm -rf nemo_experiments
# L2: Megatron Mock Data Generation
L2_Megatron_Mock_Data_Generation_MockGPTDataset:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
@@ -3578,12 +3591,12 @@ jobs:
# timeout-minutes: 10
# container:
# image: nemoci.azurecr.io/nemo_container:${{ github.run_id }}
# options:
# options:
# # --user 0:128
# --device=/dev/nvidia0
# --gpus all
# --shm-size=8g
# --env TRANSFORMERS_OFFLINE=0
# --env TRANSFORMERS_OFFLINE=0
# --env HYDRA_FULL_ERROR=1
# --volume /mnt/datadrive/TestData:/home/TestData
# steps:
@@ -3643,12 +3656,12 @@ jobs:
# runs-on: self-hosted-azure
# container:
# image: nemoci.azurecr.io/nemo_container:${{ github.run_id }}
# options:
# options:
# # --user 0:128
# --device=/dev/nvidia0
# --gpus all
# --shm-size=8g
# --env TRANSFORMERS_OFFLINE=0
# --shm-size=8g
# --env TRANSFORMERS_OFFLINE=0
# --env HYDRA_FULL_ERROR=1
# --volume /mnt/datadrive/TestData:/home/TestData
# steps:
@@ -4333,7 +4346,7 @@ jobs:
rm -rf /tmp/nemo2_ptq_engine
Nemo_CICD_Test:
needs:
needs:
- pre-flight
- cicd-test-container-setup

@@ -4348,7 +4361,7 @@
- L0_Unit_Tests_GPU_Hydra
- L0_Unit_Tests_GPU_Lightning
- L0_Unit_Tests_GPU_Others

- L0_Unit_Tests_CPU_ASR
- L0_Unit_Tests_CPU_Audio
- L0_Unit_Tests_CPU_Common
@@ -4446,7 +4459,8 @@ jobs:
- L2_NeMo_2_GPT_Pretraining_no_transformer_engine
- L2_NeMo_2_GPT_DDP_Param_Parity_check
- L2_NeMo_2_HF_MODEL_IMPORT
- L2_NeMo_2_llama3_pretraining_recipe
- L2_NeMo_2_llama3_pretraining_recipe
- L2_HF_Transformer_SFT_TE_Acceleration
- L2_NeMo_2_SSM_Pretraining
- L2_NeMo_2_SSM_Finetuning
- L2_NeMo_2_T5_Pretraining
@@ -4485,7 +4499,7 @@ jobs:
- L2_NeMo_2_PTQ_Llama2_FP8
if: always()
runs-on: ubuntu-latest
steps:
steps:
- name: Evaluate conclusion
if: ${{ always() }}
id: pipeline-conclusion
@@ -4499,14 +4513,14 @@
echo "SUCCESS=$SUCCESS" >> $GITHUB_OUTPUT
# This should depend on all the tests so we block/unblock based on all tests passing
- name: Pipeline successful, set exit code to 0
- name: Pipeline successful, set exit code to 0
if: ${{ always() && steps.pipeline-conclusion.outputs.SUCCESS == 'true' }}
run: exit 0

- name: Pipeline successful, add PR comment
- name: Pipeline successful, add PR comment
if: ${{ always() && steps.pipeline-conclusion.outputs.SUCCESS == 'true' && github.event_name == 'pull_request' && env.SLACK_WEBHOOK != '' }}
uses: peter-evans/create-or-update-comment@v4
env:
env:
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
REPOSITORY: ${{ github.repository }}
RUN_ID: ${{ github.run_id }}
@@ -4525,7 +4539,7 @@ jobs:
- name: "Pipeline not successful and not cancelled: Send Slack alert & create step summary"
if: ${{ always() && steps.pipeline-conclusion.outputs.FAILED == 'true' && env.SLACK_WEBHOOK != '' }}
env:
env:
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
SLACK_WEBHOOK_ADMIN: <!subteam^${{ secrets.SLACK_WEBHOOK_ADMIN }}>
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
@@ -4618,4 +4632,4 @@ jobs:
- name: "Pipeline not successful, set exit code to 1"
if: ${{ always() && steps.pipeline-conclusion.outputs.SUCCESS == 'false' }}
run: exit 1
run: exit 1
19 changes: 19 additions & 0 deletions examples/llm/sft/hf.py
100644 → 100755
@@ -19,6 +19,8 @@

from nemo import lightning as nl
from nemo.collections import llm
from nemo.lightning.pytorch.accelerate.transformer_engine import is_te_accelerated, te_accelerate
from nemo.lightning.pytorch.callbacks import ModelCallback


class SquadDataModuleWithPthDataloader(llm.SquadDataModule):
@@ -53,7 +55,9 @@ def squad(tokenizer) -> pl.LightningDataModule:
parser.add_argument('--strategy', type=str, default='auto', choices=['auto', 'ddp', 'fsdp'])
parser.add_argument('--devices', default=1)
parser.add_argument('--accelerator', default='gpu', choices=['gpu'])
parser.add_argument('--model-accelerator', default=None, choices=['te'])
parser.add_argument('--max-steps', type=int, default=100)
parser.add_argument("--fp8-autocast", default=False, action='store_true')
parser.add_argument('--wandb-project', type=str, default=None)
parser.add_argument('--model-save-path', type=str, default=None)
args = parser.parse_args()
@@ -74,6 +78,14 @@ def squad(tokenizer) -> pl.LightningDataModule:
model = llm.HfAutoModelForCausalLM(args.model)
tokenizer = model.tokenizer

callbacks = []
if args.model_accelerator:
if args.model_accelerator == "te":
model_transform = ModelCallback(
on_train_start=lambda model: te_accelerate(model, fp8_autocast=args.fp8_autocast)
)
callbacks.append(model_transform)

llm.api.finetune(
model=model,
data=squad(tokenizer),
@@ -88,11 +100,18 @@ def squad(tokenizer) -> pl.LightningDataModule:
accumulate_grad_batches=10,
gradient_clip_val=grad_clip,
use_distributed_sampler=use_dist_samp,
callbacks=callbacks,
logger=wandb,
),
optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5)),
log=None,
)

if args.model_accelerator:
if args.model_accelerator == "te":
te_acc = is_te_accelerated(model.model)
assert te_acc, "Transformer Engine acceleration was unsuccessful"
print("TE Accelerated: ", te_acc)

if args.model_save_path is not None:
model.save_pretrained(args.model_save_path)
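
With these changes the SFT example can be launched with Transformer Engine acceleration. An illustrative invocation is shown below; the checkpoint path is a placeholder (the CI job above runs the same script against a local Gemma 2B checkpoint), and --fp8-autocast additionally requires an FP8-capable GPU:

python examples/llm/sft/hf.py --model <path-to-hf-checkpoint> --model-accelerator te --fp8-autocast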
@@ -55,9 +55,7 @@ def __init__(
@property
def tokenizer(self):
if self._tokenizer is None:
self._tokenizer = HfAutoModelForCausalLM.configure_tokenizer(
self.model_name, trust_remote_code=self.trust_remote_code
)
self._tokenizer = HfAutoModelForCausalLM.configure_tokenizer(self.model_name)
return self._tokenizer

@tokenizer.setter
13 changes: 13 additions & 0 deletions nemo/lightning/pytorch/accelerate/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
121 changes: 121 additions & 0 deletions nemo/lightning/pytorch/accelerate/transformer_engine.py
@@ -0,0 +1,121 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from types import MethodType

import torch
from nemo.utils import logging
from nemo.utils.import_utils import safe_import_from

te, HAVE_TE = safe_import_from("transformer_engine", "pytorch")


def te_accelerate(model, fp8_autocast=False):
"""
Replaces original model layers with TE's accelerated layers
Args:
model: HF model
fp8_autocast (bool): apply autocast or not
"""

if not HAVE_TE:
logging.warning("Transformer Engine is not available and the module replacements " "will not be applied.")
else:
_apply_basic_module_replacement(model)
if fp8_autocast:
apply_fp8_autocast(model)


@torch.no_grad
def _apply_basic_module_replacement(model):
    for name, module in model.named_modules():
        # Resolve the parent module so the TE replacement is actually used in the forward pass.
        parent_name, _, attr_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
if isinstance(module, torch.nn.Linear):
has_bias = module.bias is not None
            # TE GEMMs require dimensions divisible by 16; leave other Linear layers untouched.
            if any(p % 16 != 0 for p in module.weight.shape):
                continue
te_module = te.Linear(
module.in_features, module.out_features, bias=has_bias, params_dtype=module.weight.dtype
)
te_module.weight.copy_(module.weight)
if has_bias:
te_module.bias.copy_(module.bias)

            setattr(parent, attr_name, te_module)
elif isinstance(module, torch.nn.LayerNorm):
te_module = te.LayerNorm(module.normalized_shape[0], eps=module.eps, params_dtype=module.weight.dtype)
te_module.weight.copy_(module.weight)
te_module.bias.copy_(module.bias)
            setattr(parent, attr_name, te_module)
elif isinstance(module, torch.nn.RMSNorm):
te_module = te.RMSNorm(module.normalized_shape[0], eps=module.eps, dtype=module.weight.dtype)
te_module.weight.copy_(module.weight)
            # torch.nn.RMSNorm has no bias parameter, so only the weight is copied.
            setattr(parent, attr_name, te_module)


def is_te_accelerated(model):
"""
Checks whether model has TE layers or not
Args:
model: HF model
"""

if not HAVE_TE:
logging.warning("Transformer Engine is not available.")
return False
else:
for name, module in model.named_modules():
if isinstance(module, (te.LayerNorm, te.Linear, te.TransformerLayer)):
return True

return False


def apply_fp8_autocast(model, fp8_recipe_handler=None):
"""
Applies TE's autocast
Args:
model: HF model
fp8_recipe_handler: fpt handler
"""

if not HAVE_TE:
logging.warning("Transformer Engine is not available and the FP8 autocast " "will not be applied.")
else:
import transformer_engine.common.recipe as te_recipe

kwargs = fp8_recipe_handler.to_kwargs() if fp8_recipe_handler is not None else {}
if "fp8_format" in kwargs:
kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"])
use_during_eval = kwargs.pop("use_autocast_during_eval", False)
fp8_recipe = te_recipe.DelayedScaling(**kwargs)
new_forward = _contextual_fp8_autocast(model.forward, fp8_recipe, use_during_eval)

if hasattr(model.forward, "__func__"):
model.forward = MethodType(new_forward, model)
else:
model.forward = new_forward


def _contextual_fp8_autocast(model_forward, fp8_recipe, use_during_eval=False):
from transformer_engine.pytorch import fp8_autocast

def forward(self, *args, **kwargs):
enabled = use_during_eval or self.training
with fp8_autocast(enabled=enabled, fp8_recipe=fp8_recipe):
return model_forward(*args, **kwargs)

forward.__wrapped__ = model_forward

return forward
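
The helpers in this new module can also be exercised outside the Lightning callback path. Below is a minimal sketch, assuming Transformer Engine is installed; the "gpt2" checkpoint name is only an illustrative choice and is not part of this change:

from transformers import AutoModelForCausalLM

from nemo.lightning.pytorch.accelerate.transformer_engine import is_te_accelerated, te_accelerate

# Any Hugging Face causal LM works; "gpt2" here is purely illustrative.
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Replace eligible Linear/LayerNorm/RMSNorm layers with their TE equivalents.
# Set fp8_autocast=True to also wrap forward() in an FP8 autocast region (FP8-capable GPU required).
te_accelerate(model, fp8_autocast=False)

print("TE accelerated:", is_te_accelerated(model))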
2 changes: 2 additions & 0 deletions nemo/lightning/pytorch/callbacks/__init__.py
100644 → 100755
@@ -16,6 +16,7 @@
from nemo.lightning.pytorch.callbacks.debugging import ParameterDebugger
from nemo.lightning.pytorch.callbacks.garbage_collection import GarbageCollectionCallback
from nemo.lightning.pytorch.callbacks.memory_profiler import MemoryProfileCallback
from nemo.lightning.pytorch.callbacks.model_callback import ModelCallback
from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform
from nemo.lightning.pytorch.callbacks.nsys import NsysCallback
@@ -36,4 +37,5 @@
"DdpParityChecker",
"GarbageCollectionCallback",
"ParameterDebugger",
"ModelCallback",
]
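
ModelCallback, exported above and used in examples/llm/sft/hf.py, is what defers the TE layer swap until training starts. A minimal sketch of that wiring with illustrative trainer settings (not a complete training script):

from nemo import lightning as nl
from nemo.lightning.pytorch.accelerate.transformer_engine import te_accelerate
from nemo.lightning.pytorch.callbacks import ModelCallback

# Apply te_accelerate to the model once training starts, as in examples/llm/sft/hf.py.
te_callback = ModelCallback(on_train_start=lambda model: te_accelerate(model, fp8_autocast=False))

trainer = nl.Trainer(
    devices=1,
    accelerator="gpu",
    max_steps=100,
    callbacks=[te_callback],
)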