
Commit efd0252

Add ModelOpt transformer model pruning example for Llama models, default to llama3.1-8b-base (#10294)

* Add ModelOpt transformer model pruning example for Llama3 model

Signed-off-by: Shengliang Xu <[email protected]>

* Apply isort and black reformatting

Signed-off-by: shengliangxu <[email protected]>
Signed-off-by: Shengliang Xu <[email protected]>

* example code is in the wrong dir; move it

Signed-off-by: Shengliang Xu <[email protected]>

* changes as suggested in review comments

remove some logging and unused config code, update the example model to
llama3.1

Signed-off-by: Shengliang Xu <[email protected]>

* Add pruning of hidden_size into example

Signed-off-by: Shengliang Xu <[email protected]>

* Apply isort and black reformatting

Signed-off-by: shengliangxu <[email protected]>
Signed-off-by: Shengliang Xu <[email protected]>

* Update examples/nlp/language_modeling/conf/megatron_gpt_prune.yaml

Signed-off-by: Keval Morabia <[email protected]>

* Add pruning test to cicd-main.yml

Signed-off-by: Keval Morabia <[email protected]>

* Update cicd-main.yml

Signed-off-by: Keval Morabia <[email protected]>

* Update cicd-main.yml

Signed-off-by: Keval Morabia <[email protected]>

* Update cicd-main.yml

Signed-off-by: Keval Morabia <[email protected]>

* Update cicd-main.yml

Signed-off-by: Keval Morabia <[email protected]>

* Update cicd-main.yml

Signed-off-by: Keval Morabia <[email protected]>

---------

Signed-off-by: Shengliang Xu <[email protected]>
Signed-off-by: shengliangxu <[email protected]>
Signed-off-by: Keval Morabia <[email protected]>
Co-authored-by: shengliangxu <[email protected]>
Co-authored-by: Keval Morabia <[email protected]>
3 people authored Oct 8, 2024
1 parent fd5c978 commit efd0252
Showing 3 changed files with 192 additions and 0 deletions.
24 changes: 24 additions & 0 deletions .github/workflows/cicd-main.yml
@@ -641,6 +641,29 @@ jobs:
      AFTER_SCRIPT: |
        rm -rf examples/nlp/megatron_llama_distill

  L2_Prune_Width_Llama2:
    needs: [cicd-test-container-setup]
    uses: ./.github/workflows/_test_template.yml
    if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_Prune_Width_Llama2') || needs.cicd-test-container-setup.outputs.all == 'true'
    with:
      RUNNER: self-hosted-azure
      SCRIPT: |
        python examples/nlp/language_modeling/megatron_gpt_prune.py \
            trainer.devices=2 \
            trainer.num_nodes=1 \
            trainer.precision=bf16 \
            model.restore_from_path=/home/TestData/nlp/megatron_llama/llama_ci.nemo \
            model.tensor_model_parallel_size=1 \
            model.pipeline_model_parallel_size=2 \
            prune.num_calib_size=8 \
            prune.ffn_hidden_size=192 \
            prune.num_attention_heads=2 \
            prune.num_query_groups=2 \
            prune.hidden_size=null \
            export.save_path=examples/nlp/language_modeling/ci_prune_width.nemo
      AFTER_SCRIPT: |
        rm -rf examples/nlp/language_modeling/ci_prune_width.nemo

  # L2: ASR dev run
  ASR_dev_run_Speech_to_Text:
    needs: [cicd-test-container-setup]
@@ -5350,6 +5373,7 @@ jobs:
      - L2_Community_LLM_Checkpoints_tests_Llama3
      - L2_PTQ_Llama2_Export_Only
      - L2_Distill_Llama2
      - L2_Prune_Width_Llama2
      - L2_Speech_to_Text_AED
      - L2_Speech_Estimate_Duration_Bins
      - L2_Speech_Batch_Size_OOMptimizer
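A quick way to verify what the CI job above produced is to inspect the pruned checkpoint's config. Below is a minimal, hypothetical sanity check, assuming the `.nemo` file is a tar archive containing a `model_config.yaml` (NeMo's usual packaging); the archive member name may differ.

```python
# Hypothetical sanity check of the pruned checkpoint produced by the CI job above.
# Assumes .nemo is a tar archive with a model_config.yaml inside (NeMo's usual layout).
import tarfile

import yaml

with tarfile.open("examples/nlp/language_modeling/ci_prune_width.nemo") as archive:
    # The member may be stored as "model_config.yaml" or "./model_config.yaml".
    name = next(n for n in archive.getnames() if n.lstrip("./") == "model_config.yaml")
    cfg = yaml.safe_load(archive.extractfile(name))

for key in ("ffn_hidden_size", "num_attention_heads", "num_query_groups", "hidden_size"):
    # Expect 192, 2, 2, and an unchanged hidden_size (prune.hidden_size=null in the job).
    print(key, cfg.get(key))
```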
41 changes: 41 additions & 0 deletions examples/nlp/language_modeling/conf/megatron_gpt_prune.yaml
@@ -0,0 +1,41 @@
inference:
  greedy: false # Whether or not to use sampling; use greedy decoding otherwise
  top_k: 0 # The number of highest-probability vocabulary tokens to keep for top-k filtering.
  top_p: 0.9 # If set to a float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
  temperature: 1.0 # sampling temperature
  add_BOS: true # add the BOS token at the beginning of the prompt
  tokens_to_generate: 30 # The maximum number of tokens to generate.
  all_probs: false # whether to return the log prob for all the tokens in the vocab
  repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty.
  min_tokens_to_generate: 0 # The minimum length of the sequence to be generated.
  compute_logprob: false # a flag used to compute the logprob of all the input text, a very special case of running inference, default false
  batch_size: 64 # batch size for inference
  max_context_length: 512 # max length of the context; the input sequence will be truncated if it is longer than this

trainer:
  devices: 1
  num_nodes: 1
  accelerator: gpu
  logger: false # logger provided by exp_manager
  precision: bf16 # 16, 32, or bf16
  enable_checkpointing: false

model:
  tensor_model_parallel_size: 1 # Pruning currently only supports tensor_model_parallel_size=1
  pipeline_model_parallel_size: 1
  restore_from_path: llama3.1-8b-base.nemo # NeMo file path

  ## Activation Checkpoint
  activations_checkpoint_granularity: null # 'selective' or 'full'
  activations_checkpoint_method: null # 'uniform', 'block'; not used with 'selective'

prune:
  calib_dataset: cnn_dailymail # wikitext, cnn_dailymail, or a local dataset
  num_calib_size: 512 # number of samples used for calibration
  ffn_hidden_size: 3584 # ffn_hidden_size of the pruned model (original ffn_hidden_size // 4)
  num_attention_heads: 8 # num_attention_heads of the pruned model (original num_attention_heads // 4)
  num_query_groups: 4 # num_query_groups of the pruned model (original num_query_groups // 2)
  hidden_size: 2048 # hidden_size of the pruned model (original hidden_size // 2)

export:
  save_path: llama3.1-8b-base-pruned.nemo # Path where the pruned model will be saved
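For reference, the pruned sizes in this default config are simple integer divisions of the original model's dimensions. A minimal sketch, assuming the standard Llama-3.1-8B hyperparameters (14336 FFN hidden size, 32 attention heads, 8 query groups, 4096 hidden size; check your checkpoint's model_config.yaml):

```python
# Hypothetical illustration of how the default pruned sizes above are derived.
# Original dimensions are assumed to be the standard Llama-3.1-8B values.
original = {
    "ffn_hidden_size": 14336,
    "num_attention_heads": 32,
    "num_query_groups": 8,
    "hidden_size": 4096,
}
pruned = {
    "ffn_hidden_size": original["ffn_hidden_size"] // 4,        # 3584
    "num_attention_heads": original["num_attention_heads"] // 4,  # 8
    "num_query_groups": original["num_query_groups"] // 2,        # 4
    "hidden_size": original["hidden_size"] // 2,                   # 2048
}
print(pruned)
```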
127 changes: 127 additions & 0 deletions examples/nlp/language_modeling/megatron_gpt_prune.py
@@ -0,0 +1,127 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import modelopt.torch.prune as mtp
import torch
import torch.multiprocessing as mp
from datasets import load_dataset
from omegaconf import OmegaConf
from pytorch_lightning.trainer.trainer import Trainer
from tqdm import tqdm

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
from nemo.core.config import hydra_runner
from nemo.utils.model_utils import load_config

mp.set_start_method("spawn", force=True)

"""
Nemo pruning example script.
Please consult examples/nlp/language_modeling/conf/megatron_gpt_prune.yaml config on available pruning arguments,
models supported as well as how to set up data and inference for calibration (with defaults recommended).
Example usage:
```
python examples/nlp/language_modeling/megatron_gpt_prune.py \
model.restore_from_path=llama3.1-8b-base.nemo \
model.tensor_model_parallel_size=1 \
model.pipeline_model_parallel_size=8 \
trainer.num_nodes=1 \
trainer.precision=bf16 \
trainer.devices=8 \
prune.ffn_hidden_size=3584 \
prune.num_attention_heads=8 \
prune.num_query_groups=4 \
prune.hidden_size=2048 \
export.save_path=llama3.1-8b-base-pruned.nemo
```
where tensor_model_parallel_size must be 1 because of the current prune API limitation
"""


def get_calib_data_iter(data="cnn_dailymail", batch_size=64, calib_size=512, max_sequence_length=512):
    if data == "wikitext":
        dataset = load_dataset("wikitext", "wikitext-103-v1", split="train")
        text_column = "text"
    elif data == "cnn_dailymail":
        dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train")
        text_column = "article"
    else:
        # Assume a local JSON dataset with a column named "text"
        dataset = load_dataset("json", data_files=data, split="train")
        text_column = "text"
    calib_size = max(min(len(dataset), calib_size), batch_size)
    for i in range(calib_size // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size][text_column]
        for j in range(len(batch)):
            batch[j] = batch[j][:max_sequence_length]
        yield batch


@hydra_runner(config_path="conf", config_name="megatron_gpt_prune")
def main(cfg) -> None:
    if not torch.cuda.is_available():
        raise EnvironmentError("A GPU is required for pruning.")

    # Overwrite model config with the one from the model checkpoint and apply pruning modifications
    model_cfg = load_config(cfg.model.restore_from_path)
    model_cfg.update(cfg.model)
    model_cfg.name = "modelopt"  # Use modelopt transformer spec for pruning

    assert cfg.model.tensor_model_parallel_size == 1, "Pruning currently only supports tensor_model_parallel_size=1"
    assert (
        not hasattr(cfg.model, "sequence_parallel") or not cfg.model.sequence_parallel
    ), "Pruning currently does not support sequence parallelism"

    trainer = Trainer(strategy=NLPDDPStrategy(), **cfg.trainer)
    model = MegatronGPTModel.restore_from(
        restore_path=cfg.model.restore_from_path, override_config_path=model_cfg, trainer=trainer
    )

    data_iter = get_calib_data_iter(
        cfg.prune.calib_dataset,
        cfg.inference.batch_size,
        cfg.prune.num_calib_size,
        cfg.inference.max_context_length,
    )
    dataloader = [data for data in data_iter]

    def forward_loop(model):
        # NOTE: Alternatively, you can also use `model.forward_bwd_step(data_iter, forward_only=True)`
        # if your model is set up for training.
        model.set_inference_config(OmegaConf.to_container(cfg.inference))
        for i, batch in enumerate(tqdm(dataloader, desc="Calibrating")):
            model.predict_step(batch, i)

    model_pruned, _ = mtp.prune(
        model,
        mode="mcore_gpt_minitron",
        constraints={
            "export_config": {
                k: cfg.prune.get(k)
                for k in ["ffn_hidden_size", "num_attention_heads", "num_query_groups", "hidden_size"]
                if cfg.prune.get(k) is not None
            },
        },
        dummy_input=None,  # Not used
        config={"forward_loop": forward_loop},
    )

    model_pruned.save_to(cfg.export.save_path)


if __name__ == '__main__':
    main()
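One detail of the `mtp.prune` call worth highlighting: only pruning dimensions explicitly set (non-null) in the config end up in `export_config`, so e.g. `prune.hidden_size=null` leaves the hidden size untouched. A small, standalone sketch of that filtering logic, with hypothetical values:

```python
# Standalone illustration of the export_config construction used in the script above.
# The values here are hypothetical; null/None entries are simply skipped.
prune_cfg = {
    "ffn_hidden_size": 3584,
    "num_attention_heads": 8,
    "num_query_groups": 4,
    "hidden_size": None,  # null in YAML: hidden_size will not be pruned
}
export_config = {
    k: prune_cfg[k]
    for k in ["ffn_hidden_size", "num_attention_heads", "num_query_groups", "hidden_size"]
    if prune_cfg[k] is not None
}
assert export_config == {"ffn_hidden_size": 3584, "num_attention_heads": 8, "num_query_groups": 4}
```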
