diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml
index 96d54dbc8324..7aa6cdbfa00a 100644
--- a/.github/workflows/cicd-main.yml
+++ b/.github/workflows/cicd-main.yml
@@ -641,6 +641,29 @@ jobs:
       AFTER_SCRIPT: |
         rm -rf examples/nlp/megatron_llama_distill
 
+  L2_Prune_Width_Llama2:
+    needs: [cicd-test-container-setup]
+    uses: ./.github/workflows/_test_template.yml
+    if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_Prune_Width_Llama2') || needs.cicd-test-container-setup.outputs.all == 'true'
+    with:
+      RUNNER: self-hosted-azure
+      SCRIPT: |
+        python examples/nlp/language_modeling/megatron_gpt_prune.py \
+          trainer.devices=2 \
+          trainer.num_nodes=1 \
+          trainer.precision=bf16 \
+          model.restore_from_path=/home/TestData/nlp/megatron_llama/llama_ci.nemo \
+          model.tensor_model_parallel_size=1 \
+          model.pipeline_model_parallel_size=2 \
+          prune.num_calib_size=8 \
+          prune.ffn_hidden_size=192 \
+          prune.num_attention_heads=2 \
+          prune.num_query_groups=2 \
+          prune.hidden_size=null \
+          export.save_path=examples/nlp/language_modeling/ci_prune_width.nemo
+      AFTER_SCRIPT: |
+        rm -rf examples/nlp/language_modeling/ci_prune_width.nemo
+
   # L2: ASR dev run
   ASR_dev_run_Speech_to_Text:
     needs: [cicd-test-container-setup]
@@ -5350,6 +5373,7 @@ jobs:
       - L2_Community_LLM_Checkpoints_tests_Llama3
       - L2_PTQ_Llama2_Export_Only
       - L2_Distill_Llama2
+      - L2_Prune_Width_Llama2
       - L2_Speech_to_Text_AED
       - L2_Speech_Estimate_Duration_Bins
       - L2_Speech_Batch_Size_OOMptimizer
diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_prune.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_prune.yaml
new file mode 100644
index 000000000000..cb26d5744b5b
--- /dev/null
+++ b/examples/nlp/language_modeling/conf/megatron_gpt_prune.yaml
@@ -0,0 +1,41 @@
+inference:
+  greedy: false # Whether or not to use sampling; use greedy decoding otherwise
+  top_k: 0 # The number of highest-probability vocabulary tokens to keep for top-k filtering.
+  top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
+  temperature: 1.0 # sampling temperature
+  add_BOS: true # add the bos token at the beginning of the prompt
+  tokens_to_generate: 30 # The maximum number of tokens to generate.
+  min_tokens_to_generate: 0 # The minimum length of the sequence to be generated.
+  all_probs: false # whether to return the log probs for all the tokens in the vocab
+  repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty.
+  compute_logprob: false # a flag used to compute the logprob of all the input text; a very special case of running inference, default false
+  batch_size: 64 # batch size for inference
+  max_context_length: 512 # max length of the context; the input sequence will be truncated if it is longer than this
+
+trainer:
+  devices: 1
+  num_nodes: 1
+  accelerator: gpu
+  logger: false # logger provided by exp_manager
+  precision: bf16 # 16, 32, or bf16
+  enable_checkpointing: false
+
+model:
+  tensor_model_parallel_size: 1 # Pruning currently only supports tensor_model_parallel_size=1
+  pipeline_model_parallel_size: 1
+  restore_from_path: llama3.1-8b-base.nemo # NeMo checkpoint file path
+
+  ## Activation Checkpoint
+  activations_checkpoint_granularity: null # 'selective' or 'full'
+  activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective'
+
+prune:
+  calib_dataset: cnn_dailymail # wikitext, cnn_dailymail, or a local dataset
+  num_calib_size: 512 # number of samples used for calibration
+  ffn_hidden_size: 3584 # ffn_hidden_size in the pruned model (original ffn_hidden_size // 4)
+  num_attention_heads: 8 # num_attention_heads in the pruned model (original num_attention_heads // 4)
+  num_query_groups: 4 # num_query_groups in the pruned model (original num_query_groups // 2)
+  hidden_size: 2048 # hidden_size in the pruned model (original hidden_size // 2)
+
+export:
+  save_path: llama3.1-8b-base-pruned.nemo # Path where the pruned model will be saved
diff --git a/examples/nlp/language_modeling/megatron_gpt_prune.py b/examples/nlp/language_modeling/megatron_gpt_prune.py
new file mode 100644
index 000000000000..b9bf8edbfb1a
--- /dev/null
+++ b/examples/nlp/language_modeling/megatron_gpt_prune.py
@@ -0,0 +1,127 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import modelopt.torch.prune as mtp
+import torch
+import torch.multiprocessing as mp
+from datasets import load_dataset
+from omegaconf import OmegaConf
+from pytorch_lightning.trainer.trainer import Trainer
+from tqdm import tqdm
+
+from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
+from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
+from nemo.core.config import hydra_runner
+from nemo.utils.model_utils import load_config
+
+mp.set_start_method("spawn", force=True)
+
+"""
+NeMo pruning example script.
+
+Please consult the examples/nlp/language_modeling/conf/megatron_gpt_prune.yaml config for the available pruning
+arguments, the supported models, and how to set up data and inference for calibration (the defaults are recommended).
+
+Example usage:
+```
+python examples/nlp/language_modeling/megatron_gpt_prune.py \
+    model.restore_from_path=llama3.1-8b-base.nemo \
+    model.tensor_model_parallel_size=1 \
+    model.pipeline_model_parallel_size=8 \
+    trainer.num_nodes=1 \
+    trainer.precision=bf16 \
+    trainer.devices=8 \
+    prune.ffn_hidden_size=3584 \
+    prune.num_attention_heads=8 \
+    prune.num_query_groups=4 \
+    prune.hidden_size=2048 \
+    export.save_path=llama3.1-8b-base-pruned.nemo
+```
+where tensor_model_parallel_size must be 1 due to a limitation of the current pruning API.
+"""
+
+
+def get_calib_data_iter(data="cnn_dailymail", batch_size=64, calib_size=512, max_sequence_length=512):
+    if data == "wikitext":
+        dataset = load_dataset("wikitext", "wikitext-103-v1", split="train")
+        text_column = "text"
+    elif data == "cnn_dailymail":
+        dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train")
+        text_column = "article"
+    else:
+        # Assume a local JSON dataset with a column named "text"
+        dataset = load_dataset("json", data_files=data, split="train")
+        text_column = "text"
+    calib_size = max(min(len(dataset), calib_size), batch_size)
+    for i in range(calib_size // batch_size):
+        batch = dataset[i * batch_size : (i + 1) * batch_size][text_column]
+        for j in range(len(batch)):
+            batch[j] = batch[j][:max_sequence_length]
+        yield batch
+
+
+@hydra_runner(config_path="conf", config_name="megatron_gpt_prune")
+def main(cfg) -> None:
+    if not torch.cuda.is_available():
+        raise EnvironmentError("GPU is required for pruning.")
+
+    # Overwrite model config with the one from the model checkpoint and apply pruning modifications
+    model_cfg = load_config(cfg.model.restore_from_path)
+    model_cfg.update(cfg.model)
+    model_cfg.name = "modelopt"  # Use modelopt transformer spec for pruning
+
+    assert cfg.model.tensor_model_parallel_size == 1, "Pruning currently only supports tensor_model_parallel_size=1"
+    assert (
+        not hasattr(cfg.model, "sequence_parallel") or not cfg.model.sequence_parallel
+    ), "Pruning currently does not support sequence parallelism"
+
+    trainer = Trainer(strategy=NLPDDPStrategy(), **cfg.trainer)
+    model = MegatronGPTModel.restore_from(
+        restore_path=cfg.model.restore_from_path, override_config_path=model_cfg, trainer=trainer
+    )
+
+    data_iter = get_calib_data_iter(
+        cfg.prune.calib_dataset,
+        cfg.inference.batch_size,
+        cfg.prune.num_calib_size,
+        cfg.inference.max_context_length,
+    )
+    dataloader = [data for data in data_iter]
+
+    def forward_loop(model):
+        # NOTE: Alternatively you can also use `model.forward_bwd_step(data_iter, forward_only=True)`
+        # if your model is set up for training.
+        model.set_inference_config(OmegaConf.to_container(cfg.inference))
+        for i, batch in enumerate(tqdm(dataloader, desc="Calibrating")):
+            model.predict_step(batch, i)
+
+    model_pruned, _ = mtp.prune(
+        model,
+        mode="mcore_gpt_minitron",
+        constraints={
+            "export_config": {
+                k: cfg.prune.get(k)
+                for k in ["ffn_hidden_size", "num_attention_heads", "num_query_groups", "hidden_size"]
+                if cfg.prune.get(k) is not None
+            },
+        },
+        dummy_input=None,  # Not used
+        config={"forward_loop": forward_loop},
+    )
+
+    model_pruned.save_to(cfg.export.save_path)
+
+
+if __name__ == '__main__':
+    main()
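Reviewer note (not part of the diff): a minimal standalone sketch of how the `export_config` constraint passed to `mtp.prune` is assembled from the `prune` section of the config. Keys set to null are dropped, which is why the L2_Prune_Width_Llama2 CI job above can set `prune.hidden_size=null` and prune only the dimensions it specifies. The values below simply mirror the CI overrides for illustration.

```python
# Illustration only: mirrors the dict comprehension in megatron_gpt_prune.py.
# Values are the L2_Prune_Width_Llama2 CI overrides; hidden_size=null is omitted.
prune_cfg = {
    "ffn_hidden_size": 192,
    "num_attention_heads": 2,
    "num_query_groups": 2,
    "hidden_size": None,  # prune.hidden_size=null in the CI job, so it is filtered out
}

export_config = {
    k: prune_cfg[k]
    for k in ["ffn_hidden_size", "num_attention_heads", "num_query_groups", "hidden_size"]
    if prune_cfg[k] is not None
}

print(export_config)
# {'ffn_hidden_size': 192, 'num_attention_heads': 2, 'num_query_groups': 2}
```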