diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 4d788fb866d8..a4b2baa59550 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -2109,6 +2109,121 @@ jobs: # } # } + L2_Megatron_LM_To_NeMo_Conversion: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_Megatron_LM_To_NeMo_Conversion') || needs.cicd-test-container-setup.outputs.all == 'true' + with: + RUNNER: self-hosted-azure + SCRIPT: | + CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=1 Megatron-LM/pretrain_gpt.py \ + --mock-data \ + --distributed-timeout-minutes 60 \ + --use-mcore-models \ + --no-mmap-bin-files \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --train-samples 80 \ + --init-method-std 0.014 \ + --position-embedding-type rope \ + --rotary-base 1000000 \ + --rotary-percent 1.0 \ + --squared-relu \ + --num-layers 4 \ + --hidden-size 384 \ + --num-attention-heads 8 \ + --group-query-attention \ + --num-query-groups 8 \ + --ffn-hidden-size 1536 \ + --kv-channels 128 \ + --normalization RMSNorm \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --exit-duration-in-mins 5750 \ + --tensor-model-parallel-size 1 \ + --pipeline-model-parallel-size 1 \ + --seq-length 8192 \ + --max-position-embeddings 8192 \ + --micro-batch-size 1 \ + --global-batch-size 8 \ + --lr 6e-4 \ + --min-lr 6e-6 \ + --weight-decay 0.1 \ + --clip-grad 1.0 \ + --lr-decay-style cosine \ + --log-interval 1 \ + --eval-iters 1 \ + --eval-interval 10 \ + --tokenizer-type GPT2BPETokenizer \ + --tokenizer-model /home/TestData/nlp/gpt2_tokenizer \ + --vocab-file /home/TestData/nlp/gpt2_tokenizer/vocab.json \ + --merge-file /home/TestData/nlp/gpt2_tokenizer/merges.txt \ + --save /tmp/mlm_conversion_ckpt \ + --save-interval 10 \ + --ckpt-format torch_dist \ + --ckpt-fully-parallel-save \ + --ckpt-fully-parallel-load \ + --async-save \ + --ckpt-assume-constant-structure \ + --timing-log-option minmax \ + --log-params-norm \ + --log-num-zeros-in-grad \ + --log-throughput \ + --bf16 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --use-distributed-optimizer \ + --overlap-grad-reduce \ + --overlap-param-gather \ + --manual-gc \ + --num-workers 2 + + python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ + model.data.data_impl=mock \ + model.data.data_prefix=[] \ + model.skip_train=True \ + model.transformer_engine=True \ + model.use_flash_attention=False \ + model.normalization=rmsnorm \ + model.num_layers=4 \ + model.hidden_size=384 \ + model.ffn_hidden_size=1536 \ + model.num_attention_heads=8 \ + model.num_query_groups=8 \ + model.bias=False \ + model.bias_activation_fusion=False \ + model.bias_dropout_add_fusion=True \ + model.masked_softmax_fusion=True \ + model.encoder_seq_length=8192 \ + model.max_position_embeddings=8192 \ + model.data.seq_length=8192 \ + model.activation=squared-relu \ + model.transformer_block_type=True \ + model.micro_batch_size=1 \ + model.global_batch_size=8 \ + ++model.rotary_base=1000000 \ + model.rotary_percentage=1.0 \ + model.apply_query_key_layer_scaling=False \ + ++model.group_query_attention=True \ + model.apply_rope_fusion=True \ + model.kv_channels=128 \ + ++model.bert_binary_head=True \ + ++model.position_embedding_type=rope \ + ++model.add_position_embedding=True \ + trainer.limit_val_batches=1 \ + exp_manager.exp_dir=/tmp/nemo_conversion_ckpt + + python -m torch.distributed.launch --nproc_per_node=1 
examples/nlp/language_modeling/megatron_ckpt_to_nemo.py \ + --checkpoint_folder /tmp/mlm_conversion_ckpt \ + --checkpoint_name iter_0000010 \ + --nemo_file_path /tmp/mlm_to_nemo_test.nemo \ + --tensor_model_parallel_size 1 \ + --pipeline_model_parallel_size 1 \ + --gpus_per_node 1 \ + --model_type gpt \ + --hparams_file /tmp/nemo_conversion_ckpt/megatron_gpt/version_0/hparams.yaml \ + --convert_mlm + L2_Megatron_GPT_with_ResetLR_Pretraining_and_Resume_Training_TP2: needs: [cicd-test-container-setup] uses: ./.github/workflows/_test_template.yml @@ -4319,7 +4434,18 @@ jobs: --mbs 1 \ --model mistral \ --dist-opt + + L2_NEMO_2_LoRA_MERGE: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NEMO_2_LoRA_MERGE') || needs.cicd-test-container-setup.outputs.all == 'true' + with: + RUNNER: self-hosted-azure + SCRIPT: | + python tests/collections/llm/peft/lora_merge.py \ + --lora_checkpoint_path=/home/TestData/nemo2_ckpt/llama_lora_ci_checkpoint/ \ + --output_path=/tmp/nemo2_lora_merge/${{ github.run_id }} L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact: needs: [cicd-test-container-setup] @@ -4421,6 +4547,7 @@ jobs: - L2_RAG_Pipeline_Generating - L2_Megatron_GPT_Pretraining_and_Resume_Training_TP2 - L2_Megatron_GPT_Skip_Train + - L2_Megatron_LM_To_NeMo_Conversion - L2_Megatron_GPT_with_Rope_Pretraining_and_Resume_Training_TP2 - L2_Megatron_GPT_with_ResetLR_Pretraining_and_Resume_Training_TP2 - L2_Megatron_GPT_with_Drop_Optimizer_States_TP2 @@ -4482,6 +4609,7 @@ jobs: - L2_NeMo_2_Mixtral_LoRA_TP2PP1_MBS1 - L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1 - L2_NeMo_2_Mistral_LoRA_TP2PP1_MBS1 + - L2_NEMO_2_LoRA_MERGE - L2_NeMo_2_Mixtral_Pretraining - L2_PTQ_Llama2_FP8 - L2_Community_LLM_Checkpoints_tests_Llama3 diff --git a/examples/asr/transcribe_speech.py b/examples/asr/transcribe_speech.py index f1d61edc990e..5c4a636e8b1c 100644 --- a/examples/asr/transcribe_speech.py +++ b/examples/asr/transcribe_speech.py @@ -276,6 +276,9 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis # we will adjust this flag if the model does not support it compute_langs = cfg.compute_langs + if cfg.timestamps: + cfg.return_hypotheses = True + # Check whether model and decoder type match if isinstance(asr_model, EncDecCTCModel): if cfg.decoder_type and cfg.decoder_type != 'ctc': diff --git a/examples/asr/transcribe_speech_parallel.py b/examples/asr/transcribe_speech_parallel.py index bdf54ea67f7d..d60099acd379 100644 --- a/examples/asr/transcribe_speech_parallel.py +++ b/examples/asr/transcribe_speech_parallel.py @@ -163,6 +163,14 @@ def main(cfg: ParallelTranscriptionConfig): cfg.predict_ds.return_sample_id = True cfg.predict_ds = match_train_config(predict_ds=cfg.predict_ds, train_ds=model.cfg.train_ds) + if cfg.predict_ds.use_lhotse: + OmegaConf.set_struct(cfg.predict_ds, False) + cfg.trainer.use_distributed_sampler = False + cfg.predict_ds.force_finite = True + cfg.predict_ds.force_map_dataset = True + cfg.predict_ds.do_transcribe = True + OmegaConf.set_struct(cfg.predict_ds, True) + if isinstance(model, EncDecMultiTaskModel): cfg.trainer.use_distributed_sampler = False OmegaConf.set_struct(cfg.predict_ds, False) @@ -172,7 +180,7 @@ def main(cfg: ParallelTranscriptionConfig): trainer = ptl.Trainer(**cfg.trainer) - if isinstance(model, EncDecMultiTaskModel): + if cfg.predict_ds.use_lhotse: OmegaConf.set_struct(cfg.predict_ds, False) cfg.predict_ds.global_rank = trainer.global_rank 
cfg.predict_ds.world_size = trainer.world_size diff --git a/examples/llm/sft/hf.py b/examples/llm/sft/hf.py index 59b8b4ad3491..1d282312b130 100755 --- a/examples/llm/sft/hf.py +++ b/examples/llm/sft/hf.py @@ -19,7 +19,7 @@ from nemo import lightning as nl from nemo.collections import llm -from nemo.lightning.pytorch.accelerate.transformer_engine import is_te_accelerated, te_accelerate +from nemo.lightning.pytorch.accelerate.transformer_engine import is_te_accelerated from nemo.lightning.pytorch.callbacks import ModelCallback @@ -75,16 +75,17 @@ def squad(tokenizer) -> pl.LightningDataModule: grad_clip = None use_dist_samp = False - model = llm.HfAutoModelForCausalLM(args.model) - tokenizer = model.tokenizer + model_accelerator = None + if args.model_accelerator == "te": + from functools import partial + from nemo.lightning.pytorch.accelerate.transformer_engine import te_accelerate - callbacks = [] - if args.model_accelerator: - if args.model_accelerator == "te": - model_transform = ModelCallback( - on_train_start=lambda model: te_accelerate(model, fp8_autocast=args.fp8_autocast) - ) - callbacks.append(model_transform) + model_accelerator = partial(te_accelerate, fp8_autocast=args.fp8_autocast) + + from nemo.lightning.pytorch.accelerate.transformer_engine import te_accelerate + + model = llm.HfAutoModelForCausalLM(model_name=args.model, model_accelerator=model_accelerator) + tokenizer = model.tokenizer llm.api.finetune( model=model, @@ -100,7 +101,7 @@ def squad(tokenizer) -> pl.LightningDataModule: accumulate_grad_batches=10, gradient_clip_val=grad_clip, use_distributed_sampler=use_dist_samp, - callbacks=callbacks, + callbacks=[], logger=wandb, ), optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5)), diff --git a/examples/nlp/language_modeling/megatron_ckpt_to_nemo.py b/examples/nlp/language_modeling/megatron_ckpt_to_nemo.py index b46f8f459ff0..4b9fab987dc7 100644 --- a/examples/nlp/language_modeling/megatron_ckpt_to_nemo.py +++ b/examples/nlp/language_modeling/megatron_ckpt_to_nemo.py @@ -112,6 +112,11 @@ def get_args(): choices=['32-true', '16-mixed', 'bf16-mixed'], help="Precision value for the trainer that matches with precision of the ckpt", ) + parser.add_argument( + "--convert_mlm", + action="store_true", + help="Use this flag to convert megatron-lm checkpoints.", + ) args = parser.parse_args() return args @@ -195,7 +200,9 @@ def convert(local_rank, rank, world_size, args): ) if args.model_type == 'gpt': - model = MegatronGPTModel.load_from_checkpoint(checkpoint_path, hparams_file=args.hparams_file, trainer=trainer) + model = MegatronGPTModel.load_from_checkpoint( + checkpoint_path, hparams_file=args.hparams_file, trainer=trainer, load_mlm=args.convert_mlm + ) elif args.model_type == 'sft': model = MegatronGPTSFTModel.load_from_checkpoint( checkpoint_path, hparams_file=args.hparams_file, trainer=trainer diff --git a/nemo/collections/asr/data/audio_to_text_dataset.py b/nemo/collections/asr/data/audio_to_text_dataset.py index 76537a8b2b78..f91710de3cb3 100644 --- a/nemo/collections/asr/data/audio_to_text_dataset.py +++ b/nemo/collections/asr/data/audio_to_text_dataset.py @@ -867,10 +867,16 @@ def write_on_batch_end( sample = sample_id if isinstance(sample, lhotse.cut.MixedCut): sample = sample.first_non_padding_cut + if sample.recording.sources[0].source != '': + item["audio_filepath"] = sample.recording.sources[0].source + else: + item["audio_filepath"] = sample.id item["audio_filepath"] = sample.recording.sources[0].source item["offset"] = sample.start item["duration"] = 
sample.duration - item["text"] = sample.supervisions[0].text + item["text"] = sample.supervisions[0].text or '' + if hasattr(sample, 'shard_id'): + item["shard_id"] = sample.shard_id item["pred_text"] = transcribed_text self.outf.write(json.dumps(item) + "\n") self.samples_num += 1 diff --git a/nemo/collections/asr/data/audio_to_text_lhotse.py b/nemo/collections/asr/data/audio_to_text_lhotse.py index f916ae1de56b..0ae3059a9296 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse.py @@ -43,17 +43,18 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: 'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True), } - def __init__(self, tokenizer): + def __init__(self, tokenizer, return_cuts=False): super().__init__() self.tokenizer = TokenizerWrapper(tokenizer) self.load_audio = AudioSamples(fault_tolerant=True) + self.return_cuts = return_cuts def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]: audio, audio_lens, cuts = self.load_audio(cuts) tokens = [ torch.cat( [ - torch.as_tensor(s.tokens if hasattr(s, "tokens") else self.tokenizer(s.text, s.language)) + torch.as_tensor(s.tokens if hasattr(s, "tokens") else self.tokenizer(s.text or "", s.language)) for s in c.supervisions ], dim=0, @@ -62,6 +63,8 @@ def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]: ] token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long) tokens = collate_vectors(tokens, padding_value=0) + if self.return_cuts: + return audio, audio_lens, tokens, token_lens, cuts.drop_in_memory_data() return audio, audio_lens, tokens, token_lens diff --git a/nemo/collections/asr/models/configs/asr_models_config.py b/nemo/collections/asr/models/configs/asr_models_config.py index 29dbbe06d1f8..081233da5d32 100644 --- a/nemo/collections/asr/models/configs/asr_models_config.py +++ b/nemo/collections/asr/models/configs/asr_models_config.py @@ -41,6 +41,17 @@ class ASRDatasetConfig(nemo.core.classes.dataset.DatasetConfig): shard_manifests: bool = False shuffle_n: int = 0 + # lhotse support + use_lhotse: bool = False + tarred_random_access: bool = False + use_bucketing: bool = False + batch_duration: Optional[int] = None + quadratic_duration: Optional[int] = None + bucket_batch_size: Optional[int] = None + bucket_duration_bins: Optional[list] = None + num_buckets: Optional[int] = 0 + pin_memory: bool = False + # Optional int_values: Optional[int] = None augmentor: Optional[Dict[str, Any]] = None diff --git a/nemo/collections/asr/models/ctc_bpe_models.py b/nemo/collections/asr/models/ctc_bpe_models.py index 79c22794de01..1f84989c8ebe 100644 --- a/nemo/collections/asr/models/ctc_bpe_models.py +++ b/nemo/collections/asr/models/ctc_bpe_models.py @@ -97,9 +97,15 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): if config.get("use_lhotse"): return get_lhotse_dataloader_from_config( config, - global_rank=self.global_rank, - world_size=self.world_size, - dataset=LhotseSpeechToTextBpeDataset(tokenizer=self.tokenizer), + # During transcription, the model is initially loaded on the CPU. + # To ensure the correct global_rank and world_size are set, + # these values must be passed from the configuration. 
+ global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"), + world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"), + dataset=LhotseSpeechToTextBpeDataset( + tokenizer=self.tokenizer, + return_cuts=config.get("do_transcribe", False), + ), tokenizer=self.tokenizer, ) diff --git a/nemo/collections/asr/models/ctc_models.py b/nemo/collections/asr/models/ctc_models.py index 3df6a7352c4d..ae8c35220931 100644 --- a/nemo/collections/asr/models/ctc_models.py +++ b/nemo/collections/asr/models/ctc_models.py @@ -160,6 +160,7 @@ def transcribe( A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files """ + timestamps = timestamps or (override_config.timestamps if override_config is not None else None) if timestamps is not None: # else retain the decoder state (users can set it using change_decoding_strategy) if timestamps or (override_config is not None and override_config.timestamps): @@ -308,8 +309,11 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): if config.get("use_lhotse"): return get_lhotse_dataloader_from_config( config, - global_rank=self.global_rank, - world_size=self.world_size, + # During transcription, the model is initially loaded on the CPU. + # To ensure the correct global_rank and world_size are set, + # these values must be passed from the configuration. + global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"), + world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"), dataset=LhotseSpeechToTextBpeDataset( tokenizer=make_parser( labels=config.get('labels', None), @@ -318,6 +322,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): blank_id=config.get('blank_index', -1), do_normalize=config.get('normalize_transcripts', False), ), + return_cuts=config.get("do_transcribe", False), ), ) @@ -613,7 +618,8 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): return_hypotheses=False, ) - sample_id = sample_id.cpu().detach().numpy() + if isinstance(sample_id, torch.Tensor): + sample_id = sample_id.cpu().detach().numpy() return list(zip(sample_id, transcribed_texts)) def validation_pass(self, batch, batch_idx, dataloader_idx=0): diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py index 7e8720ee3ad8..cd04a5ad2462 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py @@ -140,10 +140,14 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): if config.get("use_lhotse"): return get_lhotse_dataloader_from_config( config, - global_rank=self.global_rank, - world_size=self.world_size, + # During transcription, the model is initially loaded on the CPU. + # To ensure the correct global_rank and world_size are set, + # these values must be passed from the configuration. 
+ global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"), + world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"), dataset=LhotseSpeechToTextBpeDataset( tokenizer=self.tokenizer, + return_cuts=config.get("do_transcribe", False), ), tokenizer=self.tokenizer, ) diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py index 34dd9aae5711..1f63c617cea2 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py @@ -519,8 +519,8 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): best_hyp_text, all_hyp_text = self.decoding.rnnt_decoder_predictions_tensor( encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False ) - - sample_id = sample_id.cpu().detach().numpy() + if isinstance(sample_id, torch.Tensor): + sample_id = sample_id.cpu().detach().numpy() return list(zip(sample_id, best_hyp_text)) def validation_pass(self, batch, batch_idx, dataloader_idx): diff --git a/nemo/collections/asr/models/rnnt_bpe_models.py b/nemo/collections/asr/models/rnnt_bpe_models.py index c92bcfaaef7a..cd8667f2f0fe 100644 --- a/nemo/collections/asr/models/rnnt_bpe_models.py +++ b/nemo/collections/asr/models/rnnt_bpe_models.py @@ -509,10 +509,14 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): if config.get("use_lhotse"): return get_lhotse_dataloader_from_config( config, - global_rank=self.global_rank, - world_size=self.world_size, + # During transcription, the model is initially loaded on the CPU. + # To ensure the correct global_rank and world_size are set, + # these values must be passed from the configuration. + global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"), + world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"), dataset=LhotseSpeechToTextBpeDataset( tokenizer=self.tokenizer, + return_cuts=config.get("do_transcribe", False), ), tokenizer=self.tokenizer, ) diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py index a6408b5e935e..78038d404107 100644 --- a/nemo/collections/asr/models/rnnt_models.py +++ b/nemo/collections/asr/models/rnnt_models.py @@ -285,7 +285,7 @@ def transcribe( * A list of greedy transcript texts / Hypothesis * An optional list of beam search transcript texts / Hypothesis / NBestHypothesis. """ - + timestamps = timestamps or (override_config.timestamps if override_config is not None else None) if timestamps is not None: if timestamps or (override_config is not None and override_config.timestamps): logging.info( @@ -469,8 +469,11 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): if config.get("use_lhotse"): return get_lhotse_dataloader_from_config( config, - global_rank=self.global_rank, - world_size=self.world_size, + # During transcription, the model is initially loaded on the CPU. + # To ensure the correct global_rank and world_size are set, + # these values must be passed from the configuration. 
+ global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"), + world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"), dataset=LhotseSpeechToTextBpeDataset( tokenizer=make_parser( labels=config.get('labels', None), @@ -479,6 +482,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): blank_id=config.get('blank_index', -1), do_normalize=config.get('normalize_transcripts', False), ), + return_cuts=config.get("do_transcribe", False), ), ) @@ -814,7 +818,8 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False ) - sample_id = sample_id.cpu().detach().numpy() + if isinstance(sample_id, torch.Tensor): + sample_id = sample_id.cpu().detach().numpy() return list(zip(sample_id, best_hyp_text)) def validation_pass(self, batch, batch_idx, dataloader_idx=0): diff --git a/nemo/collections/asr/models/transformer_bpe_models.py b/nemo/collections/asr/models/transformer_bpe_models.py index 8d0f2b2223a3..4692cb662b4b 100644 --- a/nemo/collections/asr/models/transformer_bpe_models.py +++ b/nemo/collections/asr/models/transformer_bpe_models.py @@ -225,10 +225,14 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): config = self._update_default_values(config) return get_lhotse_dataloader_from_config( config, - global_rank=self.global_rank, - world_size=self.world_size, + # During transcription, the model is initially loaded on the CPU. + # To ensure the correct global_rank and world_size are set, + # these values must be passed from the configuration. + global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"), + world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"), dataset=LhotseSpeechToTextBpeDataset( tokenizer=self.tokenizer, + return_cuts=config.get("do_transcribe", False), ), tokenizer=self.tokenizer, ) diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index 98b63a07fa9d..bf6b77ad907e 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -147,6 +147,28 @@ class LhotseDataLoadingConfig: # In most cases (such as regular multi-GPU training) it will result in a deadlock due to # a different number of steps on different DDP ranks. force_finite: bool = False + # The following two options may be used to override auto-detection of appropriate PyTorch dataset flavor + # for your data types. PyTorch DataLoader uses two objects to yield data: dataset and sampler. + # *Map-dataset flavor.* There is one sampler per GPU that lives in the training loop process; + # it selects the examples to be prepared by map-dataset class. Each batch selection determined by the sampler + # is then passed by the dataloader to one of its worker processes to be processed by the dataset class. + # *Iterable-dataset flavor.* Each dataloading worker has its own sampler replica instead; + # the sampler must have the logic for either data deduplication or unique order shuffling to avoid + # duplicated data across workers and GPUs. Lhotse relies on unique order shuffling. + # The default settings are: + # * use iterable dataset for tarred audio data. + # * use iterable dataset for any text data. 
+ # * use map dataset for non-tarred audio data (we might change this in the future) + force_map_dataset: bool = False + force_iterable_dataset: bool = False + + +def determine_use_iterable_dataset(use_iterable_dataset: bool, config: DictConfig) -> bool: + assert not ( + config.force_map_dataset and config.force_iterable_dataset + ), "Conflicting options: force_map_dataset=True and force_iterable_dataset=True" + use_iterable_dataset = (use_iterable_dataset or config.force_iterable_dataset) and not config.force_map_dataset + return use_iterable_dataset def get_lhotse_dataloader_from_config( @@ -176,7 +198,6 @@ def get_lhotse_dataloader_from_config( Note that ``tokenizer`` can be any tokenizer type (e.g. both SentencePiece and Aggregate tokenizers work). """ logging.info("We will be using a Lhotse DataLoader.") - config = make_structured_with_schema_warnings(config) maybe_set_cuda_expandable_segments(enabled=config.cuda_expandable_segments) @@ -186,8 +207,8 @@ def get_lhotse_dataloader_from_config( fix_random_seed(seed) # 1. Load a manifest as a Lhotse CutSet. - cuts, is_tarred = read_cutset_from_config(config) - + cuts, use_iterable_dataset = read_cutset_from_config(config) + use_iterable_dataset = determine_use_iterable_dataset(use_iterable_dataset, config) # Apply channel selector if config.channel_selector is not None: logging.info('Using channel selector %s.', config.channel_selector) @@ -202,7 +223,7 @@ def get_lhotse_dataloader_from_config( if tokenizer is not None and config.pretokenize: from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper - if not is_tarred: + if not use_iterable_dataset: logging.warning( "You are using a non-tarred dataset and requested tokenization during data sampling (pretokenize=True). " "This will cause the tokenization to happen in the main (GPU) process, possibly impacting the training speed " @@ -317,8 +338,8 @@ def get_lhotse_dataloader_from_config( duration_bins=determine_bucket_duration_bins(config), num_cuts_for_bins_estimate=config.num_cuts_for_bins_estimate, buffer_size=config.bucket_buffer_size, - rank=0 if is_tarred else global_rank, - world_size=1 if is_tarred else world_size, + rank=0 if use_iterable_dataset else global_rank, + world_size=1 if use_iterable_dataset else world_size, ) else: # Non-bucketing sampler, similar to original NeMo dataloading without bucketing, @@ -335,8 +356,8 @@ def get_lhotse_dataloader_from_config( drop_last=config.drop_last, shuffle_buffer_size=config.shuffle_buffer_size, seed=config.shard_seed, - rank=0 if is_tarred else global_rank, - world_size=1 if is_tarred else world_size, + rank=0 if use_iterable_dataset else global_rank, + world_size=1 if use_iterable_dataset else world_size, ) if config.concatenate_samples: @@ -368,7 +389,7 @@ def get_lhotse_dataloader_from_config( ) # 4. Creating dataloader. - if is_tarred and not config.tarred_random_access: + if use_iterable_dataset and not config.tarred_random_access: # Wrapper here is necessary when using NeMo tarred data or Lhotse Shar data, # because then I/O happens upon sampler iteration. 
Normally, the sampler resides # in the training loop process, but when we use iterable dataset, we can move it to @@ -601,8 +622,8 @@ class DurationFilter: """Callable, returns ``True`` if a cut's duration is in range [d_min, d_max] and ``False`` otherwise.""" def __init__(self, d_min: float, d_max: float) -> None: - self.d_min = d_min - self.d_max = d_max + self.d_min = d_min if d_min is not None else -1.0 + self.d_max = d_max if d_max is not None else float("inf") def __call__(self, example) -> bool: if isinstance(example, Cut): diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py index c5c2e007bc1e..c36da39b43c7 100644 --- a/nemo/collections/llm/__init__.py +++ b/nemo/collections/llm/__init__.py @@ -73,6 +73,8 @@ Llama31Config8B, Llama31Config70B, Llama31Config405B, + Llama32Config1B, + Llama32Config3B, LlamaConfig, LlamaModel, MaskedTokenLossReduction, @@ -171,6 +173,8 @@ "Llama31Config8B", "Llama31Config70B", "Llama31Config405B", + "Llama32Config1B", + "Llama32Config3B", "CodeLlamaConfig7B", "CodeLlamaConfig13B", "CodeLlamaConfig34B", diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py index aaef714ef738..4bafdd97ba21 100644 --- a/nemo/collections/llm/api.py +++ b/nemo/collections/llm/api.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import json import os from copy import deepcopy from pathlib import Path diff --git a/nemo/collections/llm/fn/activation.py b/nemo/collections/llm/fn/activation.py index 5970846d32b2..db82f95b4bcc 100644 --- a/nemo/collections/llm/fn/activation.py +++ b/nemo/collections/llm/fn/activation.py @@ -13,6 +13,7 @@ # limitations under the License. 
import torch +from megatron.core.jit import jit_fuser @torch.jit.script @@ -25,6 +26,11 @@ def openai_gelu(x): return gelu_impl(x) +@jit_fuser +def quick_gelu(x: torch.Tensor) -> torch.Tensor: + return x * torch.sigmoid(1.702 * x) + + # @torch.jit.script # remove until we have serialization def squared_relu(x): """Squared ReLU activation function.""" diff --git a/nemo/collections/llm/gpt/data/fine_tuning.py b/nemo/collections/llm/gpt/data/fine_tuning.py index 8fcef72f3bd9..0d866bb600fe 100644 --- a/nemo/collections/llm/gpt/data/fine_tuning.py +++ b/nemo/collections/llm/gpt/data/fine_tuning.py @@ -117,17 +117,28 @@ def prepare_data(self) -> None: """ Prepare packed sequence data """ - if self.packed_sequence_size > 0 and not self.train_path_packed.is_file(): + if self.packed_sequence_size > 0: from nemo.collections.llm.gpt.data.packed_sequence import prepare_packed_sequence_data - prepare_packed_sequence_data( - input_path=self.train_path, - output_path=self.train_path_packed, - packed_sequence_size=self.packed_sequence_size, - tokenizer=self.tokenizer, - max_seq_length=self.seq_length, - seed=self.seed, - ) + if not self.train_path_packed.is_file(): + prepare_packed_sequence_data( + input_path=self.train_path, + output_path=self.train_path_packed, + packed_sequence_size=self.packed_sequence_size, + tokenizer=self.tokenizer, + max_seq_length=self.seq_length, + seed=self.seed, + ) + + if not self.validation_path_packed.is_file(): + prepare_packed_sequence_data( + input_path=self.validation_path, + output_path=self.validation_path_packed, + packed_sequence_size=self.packed_sequence_size, + tokenizer=self.tokenizer, + max_seq_length=self.seq_length, + seed=self.seed, + ) def setup(self, stage: str): """Called by pytorch lightning in datamodule setup""" @@ -195,7 +206,7 @@ def val_dataloader(self) -> DataLoader: # pylint: disable=C0115,C0116 return self._create_dataloader( self._create_dataset( - self.validation_path, + self.validation_path if self.packed_sequence_size <= 0 else self.validation_path_packed, is_test=True, **self.dataset_kwargs, ), @@ -249,8 +260,8 @@ def train_path_packed(self) -> Path: """Path to training dataset file for packed sequence. The file path contains a reference to the tokenizer/model name since packed sequence dataset consists of tokenized indices.""" if self.packed_sequence_size > 0: - if self.packed_sequence_specs.packed_data_path is not None: - return self.packed_sequence_specs.packed_data_path + if self.packed_sequence_specs.packed_train_data_path is not None: + return self.packed_sequence_specs.packed_train_data_path tokenizer_model_name = self._extract_tokenizer_model_name() folder_name = self.dataset_root / "packed" / tokenizer_model_name folder_name.mkdir(parents=True, exist_ok=True) @@ -258,6 +269,20 @@ def train_path_packed(self) -> Path: else: raise ValueError("`train_path_packed` invalid since packed sequence size is not specified.") + @property + def validation_path_packed(self) -> Path: + """Path to validation dataset file for packed sequence. 
The file path contains a reference to the + tokenizer/model name since packed sequence dataset consists of tokenized indices.""" + if self.packed_sequence_size > 0: + if self.packed_sequence_specs.packed_val_data_path is not None: + return self.packed_sequence_specs.packed_val_data_path + tokenizer_model_name = self._extract_tokenizer_model_name() + folder_name = self.dataset_root / "packed" / tokenizer_model_name + folder_name.mkdir(parents=True, exist_ok=True) + return folder_name / f"validation_{self.packed_sequence_size}.npy" + else: + raise ValueError("`validation_path_packed` invalid since packed sequence size is not specified.") + @property def validation_path(self) -> Path: """Path to validation dataset file""" diff --git a/nemo/collections/llm/gpt/data/packed_sequence.py b/nemo/collections/llm/gpt/data/packed_sequence.py index 153e79f94391..345489ea0b63 100644 --- a/nemo/collections/llm/gpt/data/packed_sequence.py +++ b/nemo/collections/llm/gpt/data/packed_sequence.py @@ -101,15 +101,31 @@ class PackedSequenceSpecs: This field is set by llm.finetune api. """ - packed_data_path: str = None + packed_train_data_path: str = None """ - If specified, use the packed dataset from this file instead of the default path. + If specified, use this file for the packed training dataset instead of the default path. + """ + + packed_val_data_path: str = None + """ + If specified, use this file for the packed validation dataset instead of the default path. """ def __post_init__(self): - if self.packed_data_path is not None: - self.packed_data_path = Path(self.packed_data_path) + if self.packed_train_data_path is not None: + self.packed_train_data_path = Path(self.packed_train_data_path) + assert ( + self.packed_train_data_path.suffix == ".npy" + ), f"packed training data file must be a .npy file: {self.packed_train_data_path}" + assert ( + self.packed_train_data_path.exists() + ), f"packed training data file does not exist: {self.packed_train_data_path}" + + if self.packed_val_data_path is not None: + self.packed_val_data_path = Path(self.packed_val_data_path) + assert ( + self.packed_val_data_path.suffix == ".npy" + ), f"packed validation data file must be a .npy file: {self.packed_val_data_path}" assert ( - self.packed_data_path.suffix == ".npy" - ), f"packed data file must be a .npy file: {self.packed_data_path}" - assert self.packed_data_path.exists(), f"packed data file does not exist: {self.packed_data_path}" + self.packed_val_data_path.exists() + ), f"packed validation data file does not exist: {self.packed_val_data_path}" diff --git a/nemo/collections/llm/gpt/model/__init__.py b/nemo/collections/llm/gpt/model/__init__.py index 152309536f5b..9f186ebba90f 100644 --- a/nemo/collections/llm/gpt/model/__init__.py +++ b/nemo/collections/llm/gpt/model/__init__.py @@ -59,6 +59,8 @@ Llama31Config8B, Llama31Config70B, Llama31Config405B, + Llama32Config1B, + Llama32Config3B, LlamaConfig, LlamaModel, ) @@ -134,6 +136,8 @@ "Llama31Config8B", "Llama31Config70B", "Llama31Config405B", + "Llama32Config1B", + "Llama32Config3B", "NemotronConfig", "Nemotron3Config4B", "Nemotron3Config8B", diff --git a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py index 8c3b47835ab1..e411077aca31 100644 --- a/nemo/collections/llm/gpt/model/base.py +++ b/nemo/collections/llm/gpt/model/base.py @@ -179,7 +179,7 @@ class GPTConfig(TransformerConfig, io.IOMixin): forward_step_fn: Callable = gpt_forward_step data_step_fn: Callable = gpt_data_step - def configure_model(self, tokenizer) -> 
"MCoreGPTModel": + def configure_model(self, tokenizer, pre_process=None, post_process=None) -> "MCoreGPTModel": vp_size = self.virtual_pipeline_model_parallel_size if vp_size: p_size = self.pipeline_model_parallel_size @@ -214,8 +214,8 @@ def configure_model(self, tokenizer) -> "MCoreGPTModel": rotary_percent=self.rotary_percent, rotary_base=self.rotary_base, seq_len_interpolation_factor=self.seq_len_interpolation_factor, - pre_process=parallel_state.is_pipeline_first_stage(), - post_process=parallel_state.is_pipeline_last_stage(), + pre_process=pre_process or parallel_state.is_pipeline_first_stage(), + post_process=post_process or parallel_state.is_pipeline_last_stage(), ) # If using full TE layer, need to set TP, CP group since the module call diff --git a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py index be6b5e708604..26e4604adc43 100644 --- a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py +++ b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py @@ -39,6 +39,7 @@ def __init__( tokenizer=None, loss_fn=masked_cross_entropy, model_transform=None, + model_accelerator=None, trust_remote_code=False, ): super().__init__() @@ -50,12 +51,13 @@ def __init__( self.load_pretrained_weights = load_pretrained_weights self.is_hf_model = True self.model_transform = model_transform + self.model_accelerator = model_accelerator self.trust_remote_code = trust_remote_code @property def tokenizer(self): if self._tokenizer is None: - self._tokenizer = HfAutoModelForCausalLM.configure_tokenizer(self.model_name) + self._tokenizer = HfAutoModelForCausalLM.configure_tokenizer(self.model_name, self.trust_remote_code) return self._tokenizer @tokenizer.setter @@ -64,8 +66,8 @@ def tokenizer(self, value): self._tokenizer = value @staticmethod - def configure_tokenizer(model_name): - return AutoTokenizer(model_name) + def configure_tokenizer(model_name, trust_remote_code=False): + return AutoTokenizer(model_name, trust_remote_code=trust_remote_code) def configure_model(self): # create all your layers here @@ -78,6 +80,10 @@ def configure_model(self): config = AutoConfig.from_pretrained(self.model_name, trust_remote_code=self.trust_remote_code) self.model = AutoModelForCausalLM.from_config(config, trust_remote_code=self.trust_remote_code) + + if self.model_accelerator is not None: + self.model_accelerator(self.model) + self.model.train() def forward(self, input_ids, attention_mask=None, labels=None, loss_mask=None): diff --git a/nemo/collections/llm/gpt/model/llama.py b/nemo/collections/llm/gpt/model/llama.py index a9d18220bcaf..a7e995addb83 100644 --- a/nemo/collections/llm/gpt/model/llama.py +++ b/nemo/collections/llm/gpt/model/llama.py @@ -14,6 +14,7 @@ import math from dataclasses import dataclass +from functools import partial from pathlib import Path from typing import TYPE_CHECKING, Annotated, Callable, Optional @@ -86,7 +87,7 @@ class Llama2Config70B(LlamaConfig): @dataclass -class Llama3Config(GPTConfig): +class Llama3Config(LlamaConfig): num_query_groups: int = 8 hidden_dropout: float = 0.0 attention_dropout: float = 0.0 @@ -115,8 +116,8 @@ class Llama31Config(Llama3Config): old_context_len: int = 8192 init_method_std: float = 0.02 - def configure_model(self, tokenizer) -> "MCoreGPTModel": - model = super().configure_model(tokenizer) + def configure_model(self, tokenizer, pre_process=None, post_process=None) -> "MCoreGPTModel": + model = super().configure_model(tokenizer, pre_process, post_process) # 
Apply rope scaling for Llama3.1 model model.rotary_pos_emb.inv_freq = apply_rope_scaling( model.rotary_pos_emb.inv_freq, @@ -182,6 +183,32 @@ class Llama31Config405B(Llama31Config): make_vocab_size_divisible_by: int = 128 +@dataclass +class Llama32Config1B(Llama31Config): + scale_factor: int = 32 + share_embeddings_and_output_weights: bool = True + rotary_base: int = 500_000 + num_layers: int = 16 + hidden_size: int = 2048 + ffn_hidden_size: int = 8192 + num_attention_heads: int = 32 + num_query_groups: int = 8 + make_vocab_size_divisible_by: int = 128 + + +@dataclass +class Llama32Config3B(Llama31Config): + scale_factor: int = 32 + share_embeddings_and_output_weights: bool = True + rotary_base: int = 500_000 + num_layers: int = 28 + hidden_size: int = 3072 + ffn_hidden_size: int = 8192 + num_attention_heads: int = 24 + num_query_groups: int = 8 + make_vocab_size_divisible_by: int = 128 + + @dataclass class CodeLlamaConfig7B(Llama2Config7B): rotary_base: int = 1_000_000 @@ -252,6 +279,9 @@ def convert_state(self, source, target): "model.norm.weight": "decoder.final_layernorm.weight", "lm_head.weight": "output_layer.weight", } + if getattr(source.config, "tie_word_embeddings", False): + # llama 3.2 1B and 3B models have no shared input output embeddings + del mapping["lm_head.weight"] return io.apply_transforms(source, target, mapping=mapping, transforms=[_import_qkv, _import_linear_fc1]) @@ -275,7 +305,7 @@ def make_vocab_size_divisible_by(vocab_size): if getattr(source, 'rope_scaling', None) is not None and source.rope_scaling.get('rope_type') == 'llama3': # Apply Llama3.1 customize rope scaling - cls = Llama31Config + cls = partial(Llama31Config, scale_factor=source.rope_scaling.get("factor", 8.0)) else: cls = LlamaConfig output = cls( @@ -289,7 +319,7 @@ def make_vocab_size_divisible_by(vocab_size): rotary_base=source.rope_theta, gated_linear_unit=True, make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size), - share_embeddings_and_output_weights=False, + share_embeddings_and_output_weights=getattr(source, "tie_word_embeddings", False), fp16=(dtype_from_hf(source) == torch.float16), bf16=(dtype_from_hf(source) == torch.bfloat16), params_dtype=dtype_from_hf(source), @@ -355,6 +385,7 @@ def config(self) -> "HFLlamaConfig": num_key_value_heads=source.num_query_groups, rope_theta=source.rotary_base, vocab_size=self.tokenizer.vocab_size, + tie_word_embeddings=source.share_embeddings_and_output_weights, ) @@ -509,6 +540,8 @@ def apply_rope_scaling( "Llama31Config8B", "Llama31Config70B", "Llama31Config405B", + "Llama32Config1B", + "Llama32Config3B", "CodeLlamaConfig7B", "CodeLlamaConfig13B", "CodeLlamaConfig34B", diff --git a/nemo/collections/llm/gpt/model/ssm.py b/nemo/collections/llm/gpt/model/ssm.py index d38a690cb4ad..f4190114042e 100644 --- a/nemo/collections/llm/gpt/model/ssm.py +++ b/nemo/collections/llm/gpt/model/ssm.py @@ -86,7 +86,7 @@ class SSMConfig(TransformerConfig, io.IOMixin): data_step_fn: Callable = gpt_data_step tokenizer_model_path: str = None - def configure_model(self, tokenizer) -> "MCoreMambaModel": + def configure_model(self, tokenizer, pre_process=None, post_process=None) -> "MCoreMambaModel": return MCoreMambaModel( self, @@ -101,8 +101,8 @@ def configure_model(self, tokenizer) -> "MCoreMambaModel": rotary_percent=self.rotary_percent, rotary_base=self.rotary_base, seq_len_interpolation_factor=self.seq_len_interpolation_factor, - pre_process=parallel_state.is_pipeline_first_stage(), - post_process=parallel_state.is_pipeline_last_stage(), + 
pre_process=pre_process or parallel_state.is_pipeline_first_stage(), + post_process=post_process or parallel_state.is_pipeline_last_stage(), ) @@ -290,6 +290,7 @@ class BaseMambaConfig2_7B(SSMConfig): @dataclass class NVIDIAMambaConfig8B(SSMConfig): hybrid_override_pattern: str = "M" * 56 + num_attention_heads: int = 32 num_layers: int = 56 seq_length: int = 4096 hidden_size: int = 4096 diff --git a/nemo/collections/llm/peft/__init__.py b/nemo/collections/llm/peft/__init__.py index 11511ffe72b7..1dcc070a5a97 100644 --- a/nemo/collections/llm/peft/__init__.py +++ b/nemo/collections/llm/peft/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from nemo.collections.llm.peft.api import gpt_lora +from nemo.collections.llm.peft.api import gpt_lora, merge_lora from nemo.collections.llm.peft.dora import DoRA from nemo.collections.llm.peft.lora import LoRA @@ -23,4 +23,4 @@ "dora": DoRA, } -__all__ = ["LoRA", "DoRA", "gpt_lora", "PEFT_STR2CLS"] +__all__ = ["LoRA", "DoRA", "gpt_lora", "PEFT_STR2CLS", "merge_lora"] diff --git a/nemo/collections/llm/peft/api.py b/nemo/collections/llm/peft/api.py index 85c0ae6cae41..a089a6d17515 100644 --- a/nemo/collections/llm/peft/api.py +++ b/nemo/collections/llm/peft/api.py @@ -12,9 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -from nemo.collections.llm.peft.lora import LoRA +import json +from pathlib import Path +from typing import Tuple, Union + +import pytorch_lightning as pl +from megatron.core import dist_checkpointing +from pytorch_lightning.trainer.states import TrainerFn + +from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer +from nemo.collections.llm.peft.lora import LoRA, LoRAMerge from nemo.collections.llm.utils import factory +from nemo.lightning import MegatronStrategy, Trainer, _strategy_lib, io +from nemo.lightning.ckpt_utils import ADAPTER_META_FILENAME, ckpt_to_context_subdir +from nemo.lightning.io.pl import TrainerContext, ckpt_to_weights_subdir +from nemo.lightning.pytorch.callbacks import PEFT from nemo.lightning.pytorch.callbacks.peft import PEFT +from nemo.lightning.pytorch.strategies.utils import RestoreConfig +from nemo.utils import logging @factory @@ -22,4 +37,108 @@ def gpt_lora() -> PEFT: return LoRA() -__all__ = ["gpt_lora"] +def merge_lora( + lora_checkpoint_path: str, + output_path: str, +) -> None: + """ + Merges the LoRA adapter weights into the base model's weights. + + Python Usage: + ```python + if __name__ == '__main__': + llm.peft.merge_lora( + lora_checkpoint_path=your_lora_checkpoint_path, + output_path=your_output_path, + ) + ``` + + Args: + lora_checkpoint_path: The path to the LoRA checkpoint. + output_path: The path to save the merged checkpoint. + + """ + from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed + + trainer = Trainer( + devices=1, + accelerator="cpu", + strategy=MegatronStrategy(ddp="pytorch", setup_optimizers=False, plugins=bf16_mixed()), + ) + + model, lora = _load_base_model_and_lora(lora_checkpoint_path) + _setup_trainer_and_restore_model_and_adapter(Path(lora_checkpoint_path), trainer, model, lora) + + lora_merge = LoRAMerge() + merged_model = lora_merge(trainer.strategy.megatron_parallel) + merged_weights = {k: v for k, v in merged_model.sharded_state_dict().items() if ".adapter." 
not in k} + _save_merged_weight(output_path, merged_weights, model, trainer) + + +def _load_base_model_and_lora(lora_checkpoint_path: Path) -> Tuple[pl.LightningModule, LoRA]: + model = io.load_context(ckpt_to_context_subdir(lora_checkpoint_path), "model") + model.model_transform, model.__io__.model_transform = None, None + model.config.bf16 = False + lora: Union[io.TrainerContext, LoRA] = io.load_context( + ckpt_to_context_subdir(lora_checkpoint_path), "model.model_transform" + ) + assert isinstance(lora, LoRA), "LoRA config not found in checkpoint" + return model, lora + + +def _setup_trainer_and_restore_model_and_adapter( + lora_checkpoint_path: Path, trainer: Trainer, model: pl.LightningModule, lora: LoRA +) -> None: + if ( + adapter_meta_path := ckpt_to_weights_subdir(lora_checkpoint_path, is_saving=False) / ADAPTER_META_FILENAME + ).exists(): + with open(adapter_meta_path, "r") as f: + metadata = json.load(f) + restore_config = RestoreConfig( + path=metadata["model_ckpt_path"], + load_model_state=True, + load_optim_state=False, + ) + else: + raise ValueError(f"Cannot find adapter meta file in {lora_checkpoint_path}") + + trainer.strategy.restore_config = restore_config + trainer.strategy._setup_optimizers = False + trainer.ckpt_path = None + trainer.strategy.connect(model) + trainer.strategy.setup_environment() + + if not model.state_dict(): + with _strategy_lib.megatron_cpu_init_context(model.config): + model.configure_model() + + trainer.strategy.setup(trainer) # load base model ckpt + trainer.state.fn = TrainerFn.TESTING + trainer.strategy.setup_megatron_parallel(trainer=trainer) + trainer.strategy.trainer = trainer + model.trainer = trainer + + lora(model) + adapter_sharded_state_dict = { + k: v for k, v in trainer.strategy.megatron_parallel.sharded_state_dict().items() if ".adapter." in k + } + adapter_state = trainer.strategy.checkpoint_io.load_checkpoint( + ckpt_to_weights_subdir(lora_checkpoint_path, is_saving=False), sharded_state_dict=adapter_sharded_state_dict + ) + trainer.strategy.load_model_state_dict(adapter_state, strict=False) + + +def _save_merged_weight(output_path: str, merged_weights: dict, model: pl.LightningModule, trainer: Trainer): + weight_path = ckpt_to_weights_subdir(output_path, is_saving=True) + Path(weight_path).mkdir(parents=True, exist_ok=True) + dist_checkpointing.save(merged_weights, str(ckpt_to_weights_subdir(output_path, is_saving=True))) + if hasattr(model.tokenizer, "save_pretrained"): + model.tokenizer.save_pretrained("/tmp/nemo_tokenizer") + model.tokenizer = AutoTokenizer("/tmp/nemo_tokenizer") + if hasattr(trainer.model, "__io__") and hasattr(trainer.model.tokenizer, '__io__'): + trainer.model.__io__.tokenizer = trainer.model.tokenizer.__io__ + TrainerContext.from_trainer(trainer).io_dump(ckpt_to_context_subdir(output_path), yaml_attrs=["model"]) + logging.info(f"Merged checkpoint saved to {output_path}") + + +__all__ = ["gpt_lora", "merge_lora"] diff --git a/nemo/collections/llm/peft/lora.py b/nemo/collections/llm/peft/lora.py index 55f807bf1f9c..205cde071fa7 100644 --- a/nemo/collections/llm/peft/lora.py +++ b/nemo/collections/llm/peft/lora.py @@ -124,6 +124,7 @@ class LoRA(PEFT): dropout (float): Dropout rate for the low-rank projection. Defaults to 0.0. dropout_position (Literal['pre', 'post'], optional): Position for applying dropout. Can be 'pre' (before the low-rank projection) or 'post' (after). Defaults to 'post'. + a2a_experimental (bool): Enables the experimental All-to-All (A2A) communication strategy. Defaults to False. 
Example: -------- @@ -151,6 +152,7 @@ class LoRA(PEFT): dropout_position: Literal['pre', 'post'] = 'post' lora_A_init_method: str = "xavier" lora_B_init_method: str = "zero" + a2a_experimental: bool = False def transform(self, m: nn.Module, name=None, prefix=None): """ @@ -224,6 +226,47 @@ def wildcard_match(pattern, key): model_parallel_config=getattr(m, "config", None), alpha=self.alpha, is_expert=is_expert_linear(full_name), + a2a_experimental=self.a2a_experimental, ) return AdapterParallelAdd(m, adapter) return m + + +class LoRAMerge(PEFT): + """ + Implements the LoRA weight merge for parameter-efficient fine-tuning. + + Example: + -------- + >>> from nemo.collections.llm.peft.lora import LoRAMerge + >>> lora_merge = LoRAMerge() + >>> merged_model = lora_merge(trainer.strategy.megatron_parallel) + """ + + @torch.no_grad() + def transform(self, m: nn.Module, name=None, prefix=None): + """ + Merges the LoRA adapter with the base model weights. + + Args: + m (nn.Module): The module to apply LoRA merge to. + name (str, optional): Name of the module to merge. Defaults to None. + prefix (str, optional): Prefix for the module name. Defaults to None. + + Returns: + nn.Module: The modified module with the LoRA adapter merged into the base model weights. + """ + + if not isinstance(m, AdapterParallelAdd): + return m + logging.info(f'merging {(prefix if prefix else "") + "." + (name if name else "")}') + base_weight = m.to_wrap.weight + lora_weight = ( + m.adapter.alpha + / m.adapter.dim + * m.adapter.linear_out.weight.to(base_weight.device) + @ m.adapter.linear_in.weight.to(base_weight.device) + ) + merged_weight = base_weight + lora_weight + m.to_wrap.weight.data = merged_weight + return m diff --git a/nemo/collections/llm/quantization/quantizer.py b/nemo/collections/llm/quantization/quantizer.py index 45f72f06741e..d41ba39f39ea 100644 --- a/nemo/collections/llm/quantization/quantizer.py +++ b/nemo/collections/llm/quantization/quantizer.py @@ -24,10 +24,12 @@ from tqdm import tqdm from nemo.collections import llm -from nemo.lightning.ckpt_utils import CONTEXT_PATH +from nemo.collections.llm.inference import MCoreTokenizerWrappper, generate +from nemo.lightning.ckpt_utils import ckpt_to_context_subdir +from nemo.lightning.megatron_parallel import MegatronParallel from nemo.utils import logging -from .utils import get_unwrapped_mcore_model +from .utils import get_modelopt_decoder_type, get_unwrapped_mcore_model try: import modelopt.torch.quantization as mtq @@ -83,35 +85,12 @@ class ExportConfig: decoder_type: Optional[str] = None inference_tensor_parallel: int = 1 inference_pipeline_parallel: int = 1 + generate_sample: bool = False def __post_init__(self): self.path = Path(self.path) -def get_modelopt_decoder_type(config: llm.GPTConfig) -> str: - """Infers the modelopt decoder type from GPTConfig class.""" - mapping = [ - (llm.Baichuan2Config, "baichuan"), - (llm.ChatGLMConfig, "chatglm"), - (llm.GemmaConfig, "gemma"), - (llm.LlamaConfig, "llama"), - (llm.MistralConfig7B, "llama"), - (llm.MixtralConfig, "llama"), - (llm.NemotronConfig, "gptnext"), - (llm.Qwen2Config, "qwen"), - # TODO: (llm.StarcoderConfig, ""), - (llm.Starcoder2Config, "gptnext"), - ] - - for config_class, decoder_type in mapping: - if isinstance(config, config_class): - return decoder_type - - logging.warning("Could not directly infer the decoder type") - # TODO: Add a reasonable behavior for GPTConfig (for instance based on position_embedding_type) - return "llama" - - class Quantizer: """Post-training quantization (PTQ) and 
TensorRT-LLM export of NeMo 2.0 checkpoints. @@ -146,16 +125,37 @@ def __init__(self, quantization_config: QuantizationConfig, export_config: Expor assert dtype in SUPPORTED_DTYPE, f"Unsupported export dtype: {dtype}" self.torch_dtype = torch_dtype_from_precision(dtype) - def _setup(self, model: llm.GPTModel) -> None: + @staticmethod + def _setup(model: MegatronParallel) -> None: """Setup model for quantization.""" # TODO: disable activation checkpointing model.config.vocab_size = model.tokenizer.vocab_size model.freeze() - def _get_decoder_type(self, config: llm.GPTConfig): - return self.export_config.decoder_type or get_modelopt_decoder_type(config) + def _get_decoder_type(self, model: MegatronParallel): + if self.export_config.decoder_type is not None: + return self.export_config.decoder_type + unwrapped_model = model + while not isinstance(unwrapped_model, llm.GPTModel): + unwrapped_model = unwrapped_model.module + + return get_modelopt_decoder_type(unwrapped_model) + + @staticmethod + def _generate_sample(model: MegatronParallel): + prompts = ["Born in north-east France, Soyer trained as a", "Born in California, Soyer trained as a"] + + mcore_tokenizer = MCoreTokenizerWrappper(model.tokenizer) + mcore_inference = model.get_inference_wrapper( + params_dtype=torch.bfloat16, inference_batch_times_seqlen_threshold=30 + ) + + generated = [r.generated_text for r in generate(mcore_inference, mcore_tokenizer, prompts)] + outputs = [prompt + generation for prompt, generation in zip(prompts, generated)] + + logging.info(f'Sample generation after PTQ (with prompts): {outputs}') - def quantize(self, model: llm.GPTModel, forward_loop=None): + def quantize(self, model: MegatronParallel, forward_loop=None): """Quantize the model and calibrate using given forward loop.""" if forward_loop is None: get_dataloader = create_data_iterator_getter( @@ -185,7 +185,7 @@ def quantize(self, model: llm.GPTModel, forward_loop=None): self._setup(model) unwrapped_model = get_unwrapped_mcore_model(model) - decoder_type = self._get_decoder_type(unwrapped_model.config) + decoder_type = self._get_decoder_type(model) quant_cfg = QUANT_CFG_CHOICES[algorithm] if "awq" in algorithm: weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"] @@ -230,6 +230,10 @@ def quantize(self, model: llm.GPTModel, forward_loop=None): if dist.get_rank() == 0: mtq.print_quant_summary(unwrapped_model) + if self.export_config.generate_sample: + logging.info("Generating a sample output after model quantization.") + self._generate_sample(model) + return model def create_megatron_forward_loop( @@ -266,21 +270,34 @@ def loop(model): return loop - def export(self, model: llm.GPTModel, model_dir: str) -> None: + @staticmethod + def _validate_quantized_checkpoint(checkpoint_dir: Path, tensor_parallelism_size: int) -> bool: + """Basic validation of the model structure.""" + + saved_config = (checkpoint_dir / 'config.json').exists() + saved_weights = True + for i in range(tensor_parallelism_size): + saved_weights &= (checkpoint_dir / f'rank{i}.safetensors').exists() + + export_successful = saved_config and saved_weights + if not export_successful: + logging.error("Failed to export the quantized model.") + return export_successful + + def export(self, model: MegatronParallel, model_dir: str) -> None: """Export model to a TensorRT-LLM checkpoint.""" - assert self.export_config is not None, "Export config is not set" - # TODO: Add sample generate - # TODO: Support megatron_amp_O2 export_dir = self.export_config.path + inference_tp = 
self.export_config.inference_tensor_parallel + inference_pp = self.export_config.inference_pipeline_parallel use_nfs_workspace = model.config.pipeline_model_parallel_size > 1 export_tensorrt_llm_checkpoint( model=get_unwrapped_mcore_model(model), - decoder_type=self._get_decoder_type(model.config), + decoder_type=self._get_decoder_type(model), dtype=self.torch_dtype, export_dir=export_dir, - inference_tensor_parallel=self.export_config.inference_tensor_parallel, - inference_pipeline_parallel=self.export_config.inference_pipeline_parallel, + inference_tensor_parallel=inference_tp, + inference_pipeline_parallel=inference_pp, use_nfs_workspace=use_nfs_workspace, ) dist.barrier() @@ -288,14 +305,13 @@ def export(self, model: llm.GPTModel, model_dir: str) -> None: # Save the model context in order to restore its tokenizer later. The destination # path is "nemo_context" as this name is used in nemo.export to setup tokenizer. if dist.get_rank() == 0: + assert self._validate_quantized_checkpoint(export_dir, inference_tp) shutil.copytree( - os.path.join(model_dir, CONTEXT_PATH), + ckpt_to_context_subdir(model_dir), os.path.join(export_dir, "nemo_context"), dirs_exist_ok=True, ) - logging.info("Model context saved.") - - logging.info(f"Export succeeded, model has been exported to {export_dir}.") + logging.info(f"Export succeeded, model has been exported to {export_dir}.") def get_calib_data_iter( @@ -323,7 +339,7 @@ def get_calib_data_iter( def create_data_iterator_getter(model, dataset, seq_len, batch_size, calibration_size): """Create a function that provides iterator over a given dataset.""" - def _iterator(): + def _get_iterator(): CHARACTERS_PER_TOKEN = 4 dataloader = get_calib_data_iter( @@ -332,14 +348,13 @@ def _iterator(): batch_size=batch_size, calib_size=calibration_size, ) + + data = [] for batch in dataloader: batch = [model.tokenizer.text_to_ids(text)[:seq_len] for text in batch] batch = [ids + (seq_len - len(ids)) * [model.tokenizer.eos] for ids in batch] - yield torch.tensor(batch, device=model.device) + data.append(torch.tensor(batch, device=model.device)) - def _iterator_getter(): - dataloader = _iterator() - dataloader = [data for data in dataloader] - return iter(tqdm(dataloader)) + return iter(tqdm(data)) - return _iterator_getter + return _get_iterator diff --git a/nemo/collections/llm/quantization/utils.py b/nemo/collections/llm/quantization/utils.py index bdfccb208d06..20739c872e80 100644 --- a/nemo/collections/llm/quantization/utils.py +++ b/nemo/collections/llm/quantization/utils.py @@ -23,8 +23,33 @@ from nemo.utils import logging +def get_modelopt_decoder_type(model: llm.GPTModel) -> str: + """Infers the modelopt decoder type from GPTModel subclass.""" + mapping = [ + (llm.Baichuan2Model, "baichuan"), + (llm.ChatGLMModel, "chatglm"), + (llm.Gemma2Model, "gemma2"), + (llm.GemmaModel, "gemma"), + (llm.LlamaModel, "llama"), + (llm.MistralModel, "llama"), + (llm.MixtralModel, "llama"), + (llm.NemotronModel, "gptnext"), + (llm.Qwen2Model, "qwen"), + (llm.StarcoderModel, "gptnext"), + (llm.Starcoder2Model, "gptnext"), + (llm.Phi3Model, "phi3"), + ] + + for config_class, decoder_type in mapping: + if isinstance(model, config_class): + return decoder_type + + logging.warning("Could not infer the decoder type") + return None + + def quantizable_model_config(model_cfg: llm.GPTConfig) -> llm.GPTConfig: - """Modify model config for TensorRT Model Optimizer""" + """Modify model config for TensorRT-Model-Optimizer quantization""" from 
nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import ( get_gpt_layer_modelopt_spec, @@ -46,7 +71,9 @@ def quantizable_model_config(model_cfg: llm.GPTConfig) -> llm.GPTConfig: def load_with_modelopt_layer_spec( nemo_checkpoint_path: str, calib_tp: int = 1, calib_pp: int = 1, inference_only: bool = True ): - # TODO: setting ddp="pytorch" with manually deleting model.optim is a hackish way to disable DDP initialization. Needs a systematic solution. + """Loads a model from a NeMo 2.0 checkpoint using modelopt layer spec.""" + # TODO: setting ddp="pytorch" and deleting model.optim is a hackish way to disable DDP initialization. + # Needs a systematic solution. if inference_only: strategy = nl.MegatronStrategy( tensor_model_parallel_size=calib_tp, @@ -81,6 +108,7 @@ def load_with_modelopt_layer_spec( def get_unwrapped_mcore_model(model): + """Unwraps NeMo 2.0 to base MCore model.""" from megatron.core.models.gpt import GPTModel as MCoreGPTModel unwrapped_model = model diff --git a/nemo/collections/llm/recipes/__init__.py b/nemo/collections/llm/recipes/__init__.py index e76729d5e31a..1db88f633e89 100644 --- a/nemo/collections/llm/recipes/__init__.py +++ b/nemo/collections/llm/recipes/__init__.py @@ -33,6 +33,8 @@ llama31_8b, llama31_70b, llama31_405b, + llama32_1b, + llama32_3b, mamba2_1_3b, mamba2_2_7b, mamba2_8b, @@ -73,6 +75,7 @@ ) from nemo.collections.llm.recipes.log.default import default_log, default_resume from nemo.collections.llm.recipes.optim import adam +from nemo.collections.llm.recipes.run.executor import torchrun __all__ = [ "baichuan2_7b", @@ -88,6 +91,8 @@ "llama31_8b", "llama31_70b", "llama31_405b", + "llama32_1b", + "llama32_3b", "mamba2_130m", "mamba2_370m", "mamba2_780m", @@ -134,4 +139,5 @@ "adam", "default_log", "default_resume", + "torchrun", ] diff --git a/nemo/collections/llm/recipes/baichuan2_7b.py b/nemo/collections/llm/recipes/baichuan2_7b.py index 823f6e07cd57..1350cbaa7edd 100644 --- a/nemo/collections/llm/recipes/baichuan2_7b.py +++ b/nemo/collections/llm/recipes/baichuan2_7b.py @@ -25,7 +25,7 @@ from nemo.collections.llm import Baichuan2Config7B, Baichuan2Model from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -254,8 +254,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. 
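As a quick illustration of the decoder-type mapping added to `quantization/utils.py` above, the sketch below resolves the TensorRT-Model-Optimizer decoder type for a NeMo 2.0 Llama model; an unrecognized GPT subclass would instead hit the warning branch and return None. Import paths and constructors are taken from this diff and the snippet is an untested sketch, not a verified test.

    # Illustrative sketch only; paths and constructors come from this diff.
    from nemo.collections.llm.gpt.model.llama import Llama32Config1B, LlamaModel
    from nemo.collections.llm.quantization.utils import get_modelopt_decoder_type

    # LlamaModel (built here from the new Llama32Config1B) matches the (llm.LlamaModel, "llama") entry.
    model = LlamaModel(config=Llama32Config1B())
    print(get_modelopt_decoder_type(model))  # expected: "llama"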
@@ -279,8 +281,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 2 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/chatglm3_6b.py b/nemo/collections/llm/recipes/chatglm3_6b.py index b6c640372074..2cd424ce5bf6 100644 --- a/nemo/collections/llm/recipes/chatglm3_6b.py +++ b/nemo/collections/llm/recipes/chatglm3_6b.py @@ -25,7 +25,7 @@ from nemo.collections.llm import ChatGLM3Config6B, ChatGLMModel from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -254,8 +254,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. 
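The peft_scheme dispatch introduced in the baichuan2_7b hunk above, and repeated for chatglm3 and the other recipes below, is a plain dictionary lookup. A minimal sketch of that dispatch, assuming (as the diff implies) that PEFT_STR2CLS maps lower-case scheme names such as 'lora' and 'dora' to PEFT classes; `resolve_peft` is a hypothetical helper used only for illustration.

    # Sketch of the dispatch shared by the updated finetune recipes.
    import nemo_run as run
    from nemo.collections.llm.peft import PEFT_STR2CLS

    def resolve_peft(peft_scheme):
        if peft_scheme is None or peft_scheme.lower() == 'none':
            return None  # full fine-tuning: no adapter is attached to the recipe
        if peft_scheme.lower() in PEFT_STR2CLS:  # e.g. 'lora' or 'dora'
            return run.Config(PEFT_STR2CLS[peft_scheme.lower()])
        raise ValueError(f"Unrecognized peft scheme: {peft_scheme}")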
@@ -279,8 +281,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 2 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/finetune_default.py b/nemo/collections/llm/recipes/finetune_default.py index f05fd7cb2d13..e8af7f67bdbd 100644 --- a/nemo/collections/llm/recipes/finetune_default.py +++ b/nemo/collections/llm/recipes/finetune_default.py @@ -21,9 +21,11 @@ import nemo.lightning as nl from nemo.collections import llm from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs +from nemo.collections.llm.peft import DoRA, LoRA from nemo.collections.llm.recipes.log.default import tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed +from nemo.lightning.pytorch.callbacks import PEFT def default_finetune_recipe( @@ -158,3 +160,41 @@ def nemo_resume(model_id: str) -> run.Config[nl.AutoResume]: nl.AutoResume, restore_config=run.Config(nl.RestoreConfig, path=f"nemo://{model_id}"), ) + + +@run.cli.factory(name='lora') +def lora() -> run.Config[PEFT]: + """ + Factory function to create a LoRA configuration. + + Returns: + run.Config[PEFT]: Configuration for the LoRA class. + + Examples: + CLI usage: + $ nemo llm finetune -f llama3_8b peft=lora + + Python API usage: + >>> lora_config = lora() + >>> print(lora_config) + """ + return run.Config(LoRA) + + +@run.cli.factory(name='dora') +def dora() -> run.Config[PEFT]: + """ + Factory function to create a DoRA configuration. + + Returns: + run.Config[PEFT]: Configuration for the DoRA class. + + Examples: + CLI usage: + $ nemo llm finetune -f llama3_8b peft=dora + + Python API usage: + >>> dora_config = dora() + >>> print(dora_config) + """ + return run.Config(DoRA) diff --git a/nemo/collections/llm/recipes/gemma2_27b.py b/nemo/collections/llm/recipes/gemma2_27b.py index 2025bd570503..d6b41c0a221c 100644 --- a/nemo/collections/llm/recipes/gemma2_27b.py +++ b/nemo/collections/llm/recipes/gemma2_27b.py @@ -20,7 +20,7 @@ from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.gemma2 import gemma2_model, gemma2_trainer from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger @@ -191,8 +191,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. 
+ packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -220,8 +222,8 @@ def finetune_recipe( recipe.optim.config.lr = 5e-6 recipe.trainer.strategy.tensor_model_parallel_size = 8 recipe.trainer.strategy.pipeline_model_parallel_size = 2 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.trainer.strategy.tensor_model_parallel_size = 4 recipe.optim.config.lr = 1e-4 else: diff --git a/nemo/collections/llm/recipes/gemma2_2b.py b/nemo/collections/llm/recipes/gemma2_2b.py index e1aa3ad4be86..138140d0515d 100644 --- a/nemo/collections/llm/recipes/gemma2_2b.py +++ b/nemo/collections/llm/recipes/gemma2_2b.py @@ -20,7 +20,7 @@ from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.gemma2 import gemma2_model, gemma2_trainer from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger @@ -191,8 +191,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -218,8 +220,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/gemma2_9b.py b/nemo/collections/llm/recipes/gemma2_9b.py index 8117102f1b75..c49ac0246307 100644 --- a/nemo/collections/llm/recipes/gemma2_9b.py +++ b/nemo/collections/llm/recipes/gemma2_9b.py @@ -20,7 +20,7 @@ from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.gemma2 import gemma2_model, gemma2_trainer from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger @@ -191,8 +191,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. 
- peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -219,8 +221,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.optim.config.lr = 5e-6 recipe.trainer.strategy.tensor_model_parallel_size = 4 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/gemma_2b.py b/nemo/collections/llm/recipes/gemma_2b.py index 8798af436a9c..8bdf89696d56 100644 --- a/nemo/collections/llm/recipes/gemma_2b.py +++ b/nemo/collections/llm/recipes/gemma_2b.py @@ -24,7 +24,7 @@ from nemo.collections.llm import GemmaConfig2B, GemmaModel from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -253,8 +253,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given maximum seq_length for better efficiency. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. 
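The `lora()`/`dora()` factories registered in `finetune_default.py` above expose the same two schemes by name on the CLI. A hedged usage sketch, reusing only calls shown in this diff (the llama3_8b recipe is just an example target; passing `peft_scheme='dora'` directly is the simpler route, since it also adjusts the learning rate and parallelism defaults):

    # CLI (per the factory docstrings above):
    #   nemo llm finetune -f llama3_8b peft=dora
    #
    # Python: the factories return run.Config objects that can be attached to a recipe directly.
    from nemo.collections.llm.recipes import llama3_8b
    from nemo.collections.llm.recipes.finetune_default import dora

    recipe = llama3_8b.finetune_recipe(peft_scheme='none')  # start from a full fine-tuning recipe
    recipe.peft = dora()  # then attach a DoRA adapter config explicitly (illustration only)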
@@ -284,8 +286,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.context_parallel_size = 2 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/gemma_7b.py b/nemo/collections/llm/recipes/gemma_7b.py index 0bfd62b33e9e..46c91e27575a 100644 --- a/nemo/collections/llm/recipes/gemma_7b.py +++ b/nemo/collections/llm/recipes/gemma_7b.py @@ -24,7 +24,7 @@ from nemo.collections.llm import GemmaConfig7B, GemmaModel from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -256,8 +256,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -287,8 +289,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 2 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/llama31_405b.py b/nemo/collections/llm/recipes/llama31_405b.py index d71f3791a0af..5f08d82bd888 100644 --- a/nemo/collections/llm/recipes/llama31_405b.py +++ b/nemo/collections/llm/recipes/llama31_405b.py @@ -26,7 +26,7 @@ from nemo.collections.llm.gpt.data.mock import MockDataModule from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs from nemo.collections.llm.gpt.model.llama import Llama31Config405B, LlamaModel -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -266,7 +266,8 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. 
num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. seq_length (int): Maximum number of tokens per microbatch. packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given maximum seq_length for better efficiency. By default, this value equals performance_mode. @@ -296,7 +297,7 @@ def finetune_recipe( if num_nodes is None: if peft_scheme is None or peft_scheme.lower() == 'none': num_nodes = 12 - elif peft_scheme.lower() == 'lora': + elif peft_scheme.lower() in ['lora', 'dora']: num_nodes = 3 recipe = default_finetune_recipe( @@ -307,8 +308,8 @@ def finetune_recipe( recipe.trainer.strategy.pipeline_model_parallel_size = 14 recipe.data.global_batch_size = 6 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.peft.dim = 16 recipe.peft.alpha = 32 recipe.optim.config.use_distributed_optimizer = False @@ -348,7 +349,8 @@ def finetune_performance_optimizations( Args: recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added - peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. Returns: run.Partial: Partial configuration for performance-optimized fine-tuning. diff --git a/nemo/collections/llm/recipes/llama31_70b.py b/nemo/collections/llm/recipes/llama31_70b.py index 37e809e5bc8f..3120fedd7923 100644 --- a/nemo/collections/llm/recipes/llama31_70b.py +++ b/nemo/collections/llm/recipes/llama31_70b.py @@ -26,7 +26,7 @@ from nemo.collections.llm.gpt.data.mock import MockDataModule from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs from nemo.collections.llm.gpt.model.llama import Llama31Config70B, LlamaModel -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -266,7 +266,8 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. seq_length (int): Maximum number of tokens per microbatch. packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given maximum seq_length for better efficiency. By default, this value equals performance_mode. 
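The 405B hunks above tune the adapter after the recipe is built (dim=16, alpha=32, distributed optimizer disabled). The same attribute-override pattern works from user code; a hedged sketch with module and attribute names taken from this diff:

    # Sketch: overriding adapter hyperparameters on a built recipe, mirroring the hunks above.
    from nemo.collections.llm.recipes import llama31_70b

    recipe = llama31_70b.finetune_recipe(peft_scheme='dora', num_nodes=1)
    recipe.peft.dim = 32    # recipe default above is 16
    recipe.peft.alpha = 64  # recipe default above is 32
    # the recipe also disables the distributed optimizer whenever an adapter is attached
    assert recipe.optim.config.use_distributed_optimizer is False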
@@ -300,7 +301,7 @@ def finetune_recipe( if num_nodes is None: if peft_scheme is None or peft_scheme.lower() == 'none': num_nodes = 4 - elif peft_scheme.lower() == 'lora': + elif peft_scheme.lower() in ['lora', 'dora']: num_nodes = 1 recipe = default_finetune_recipe( @@ -310,8 +311,8 @@ def finetune_recipe( recipe.trainer.strategy.tensor_model_parallel_size = 8 recipe.trainer.strategy.pipeline_model_parallel_size = 4 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.peft.dim = 16 recipe.peft.alpha = 32 recipe.optim.config.use_distributed_optimizer = False @@ -349,7 +350,8 @@ def finetune_performance_optimizations( Args: recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added - peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. Returns: run.Partial: Partial configuration for performance-optimized fine-tuning. diff --git a/nemo/collections/llm/recipes/llama31_8b.py b/nemo/collections/llm/recipes/llama31_8b.py index 32a77ce076f2..62514940b678 100644 --- a/nemo/collections/llm/recipes/llama31_8b.py +++ b/nemo/collections/llm/recipes/llama31_8b.py @@ -26,7 +26,7 @@ from nemo.collections.llm.gpt.data.mock import MockDataModule from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs from nemo.collections.llm.gpt.model.llama import Llama31Config8B, LlamaModel -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -266,7 +266,8 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. seq_length (int): Maximum number of tokens per microbatch. packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given maximum seq_length for better efficiency. By default, this value equals performance_mode. @@ -303,8 +304,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 2 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.peft.dim = 8 recipe.peft.alpha = 16 recipe.optim.config.use_distributed_optimizer = False @@ -341,7 +342,8 @@ def finetune_performance_optimizations( Args: recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added - peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. 
+ Allowed values: 'lora'/'dora'/'none'/None. Returns: run.Partial: Partial configuration for performance-optimized fine-tuning. diff --git a/nemo/collections/llm/recipes/llama32_1b.py b/nemo/collections/llm/recipes/llama32_1b.py new file mode 100644 index 000000000000..32675adf3686 --- /dev/null +++ b/nemo/collections/llm/recipes/llama32_1b.py @@ -0,0 +1,270 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Callable, Optional + +import lightning.pytorch as pl +import nemo_run as run +import torch +from lightning.pytorch.callbacks.callback import Callback +from megatron.core.distributed import DistributedDataParallelConfig + +from nemo import lightning as nl +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs +from nemo.collections.llm.gpt.model.llama import Llama32Config1B, LlamaModel +from nemo.collections.llm.peft import PEFT_STR2CLS +from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe +from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed +from nemo.utils.exp_manager import TimingCallback + +NAME = "llama32_1b" + + +@run.cli.factory(name=NAME) +def model() -> run.Config[pl.LightningModule]: + """ + Factory function to create a Llama3.2 1B model configuration. + + Returns: + run.Config[pl.LightningModule]: Configuration for the Llama3.2 1B model. + + Examples: + CLI usage: + $ nemo llm pretrain model=llama32_1b ... + + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + conf = run.Config(Llama32Config1B) + conf.seq_length = 8192 + return run.Config(LlamaModel, config=conf) + + +def trainer( + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + pipeline_parallelism_type: Optional[torch.dtype] = None, + virtual_pipeline_parallelism: Optional[int] = None, + context_parallelism: int = 1, + sequence_parallelism: bool = False, + num_nodes: int = 1, + num_gpus_per_node: int = 8, + max_steps: int = 1168251, + callbacks: Optional[list[run.Config[Callback]]] = None, +) -> run.Config[nl.Trainer]: + """ + Configure the NeMo Lightning Trainer for Llama3.2 1B model. + + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + num_nodes (int): Number of compute nodes to use. 
+ num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations. + + Returns: + run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer. + + Examples: + CLI usage: + $ nemo llm pretrain trainer=llama32_1b ... + + Python API usage: + >>> trainer_config = trainer(num_nodes=1, num_gpus_per_node=1) + >>> print(trainer_config) + + Note: + By default this configuration disables model parallelism (tensor, pipeline and context parallel sizes of 1), which is sufficient for the 1B model; increase these arguments only for larger setups. + """ + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_type, + virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + gradient_as_bucket_view=True, + ckpt_async_save=True, + ckpt_parallel_load=True, + ddp=run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ), + ) + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + callbacks=callbacks, + devices=num_gpus_per_node, + limit_test_batches=50, + limit_val_batches=32, + log_every_n_steps=10, + max_steps=max_steps, + num_nodes=num_nodes, + plugins=bf16_mixed(), + strategy=strategy, + use_distributed_sampler=False, + val_check_interval=2000, + ) + + return trainer + + +@run.cli.factory(target=pretrain, name=NAME) +def pretrain_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + fn: Callable = pretrain, +) -> run.Partial: + """ + Create a pre-training recipe for Llama3.2 1B model. + + This function sets up a complete configuration for pre-training, including + model, trainer, data, logging, optimization, and resumption settings. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + fn (Callable): The pre-training function to use. + + Returns: + run.Partial: Partial configuration for pre-training. + + Examples: + CLI usage: + $ nemo llm pretrain --factory llama32_1b + $ nemo llm pretrain --factory "llama32_1b(num_nodes=1, name='my_1b_pretrain')" + + Python API usage: + >>> recipe = pretrain_recipe(name="llama32_1b_pretrain", num_nodes=1) + >>> print(recipe) + + Note: + This recipe is configured for the Llama3.2 1B model and runs on a single node by default.
+ """ + recipe = run.Partial( + fn, + model=model(), + trainer=trainer( + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + callbacks=[run.Config(TimingCallback)], + ), + data=run.Config(MockDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1), + log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4), + resume=default_resume(), + ) + + return recipe + + +@run.cli.factory(target=finetune, name=NAME) +def finetune_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'lora', + seq_length: Optional[int] = None, + packed_sequence: Optional[bool] = None, +) -> run.Partial: + """ + Create a fine-tuning recipe for Llama3.2 1B model. + + This function sets up a complete configuration for fine-tuning, including + model, trainer, data, logging, optimization, and resumption settings. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the fine-tuning run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + seq_length (int): Maximum number of tokens per microbatch. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given + maximum seq_length for better efficiency. + + Returns: + run.Partial: Partial configuration for fine-tuning. + + Examples: + CLI usage: + $ nemo llm finetune --factory llama32_1b + + Python API usage: + >>> recipe = finetune_recipe(name="llama32_1b_finetune", num_nodes=1) + >>> print(recipe) + + Note: + This recipe uses the SQuAD dataset for fine-tuning. For more information + on fine-tuning LLMs with NeMo, see the fine-tuning guide in the + `examples/llm/finetune/` directory. 
+ """ + + # For unpacked sequence, most samples in SQuAD dataset are shorter than 2K + if seq_length is None: + seq_length = 4096 if packed_sequence else 2048 + + recipe = default_finetune_recipe( + model(), "meta-llama/Llama-3.2-1B", dir, name, num_nodes, num_gpus_per_node, packed_sequence + ) + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 1 + recipe.optim.config.lr = 5e-6 + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) + recipe.peft.dim = 8 + recipe.peft.alpha = 16 + recipe.optim.config.use_distributed_optimizer = False + + # some settings currently do not function correctly with LoRA + recipe.model.config.cross_entropy_loss_fusion = False + + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + + # Sequence length settings in the model and dataset must agree + recipe.model.config.seq_length = seq_length + recipe.data.seq_length = seq_length + if packed_sequence: + recipe.data.dataset_kwargs = {'pad_to_max_length': True} + recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length) + + return recipe diff --git a/nemo/collections/llm/recipes/llama32_3b.py b/nemo/collections/llm/recipes/llama32_3b.py new file mode 100644 index 000000000000..d78ea0b50983 --- /dev/null +++ b/nemo/collections/llm/recipes/llama32_3b.py @@ -0,0 +1,270 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Callable, Optional + +import lightning.pytorch as pl +import nemo_run as run +import torch +from lightning.pytorch.callbacks.callback import Callback +from megatron.core.distributed import DistributedDataParallelConfig + +from nemo import lightning as nl +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs +from nemo.collections.llm.gpt.model.llama import Llama32Config3B, LlamaModel +from nemo.collections.llm.peft import PEFT_STR2CLS +from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe +from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed +from nemo.utils.exp_manager import TimingCallback + +NAME = "llama32_3b" + + +@run.cli.factory(name=NAME) +def model() -> run.Config[pl.LightningModule]: + """ + Factory function to create a Llama3.2 3B model configuration. + + Returns: + run.Config[pl.LightningModule]: Configuration for the Llama3.2 3B model. + + Examples: + CLI usage: + $ nemo llm pretrain model=llama32_3b ... 
+ + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + conf = run.Config(Llama32Config3B) + conf.seq_length = 8192 + return run.Config(LlamaModel, config=conf) + + +def trainer( + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + pipeline_parallelism_type: Optional[torch.dtype] = None, + virtual_pipeline_parallelism: Optional[int] = None, + context_parallelism: int = 1, + sequence_parallelism: bool = False, + num_nodes: int = 1, + num_gpus_per_node: int = 8, + max_steps: int = 1168251, + callbacks: Optional[list[run.Config[Callback]]] = None, +) -> run.Config[nl.Trainer]: + """ + Configure the NeMo Lightning Trainer for Llama3.2 3B model. + + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations. + + Returns: + run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer. + + Examples: + CLI usage: + $ nemo llm pretrain trainer=llama32_3b ... + + Python API usage: + >>> trainer_config = trainer(num_nodes=1, num_gpus_per_node=1) + >>> print(trainer_config) + + Note: + By default this configuration disables model parallelism (tensor, pipeline and context parallel sizes of 1), which is sufficient for the 3B model; increase these arguments only for larger setups. + """ + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_type, + virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + gradient_as_bucket_view=True, + ckpt_async_save=True, + ckpt_parallel_load=True, + ddp=run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ), + ) + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + callbacks=callbacks, + devices=num_gpus_per_node, + limit_test_batches=50, + limit_val_batches=32, + log_every_n_steps=10, + max_steps=max_steps, + num_nodes=num_nodes, + plugins=bf16_mixed(), + strategy=strategy, + use_distributed_sampler=False, + val_check_interval=2000, + ) + + return trainer + + +@run.cli.factory(target=pretrain, name=NAME) +def pretrain_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + fn: Callable = pretrain, +) -> run.Partial: + """ + Create a pre-training recipe for Llama3.2 3B model. + + This function sets up a complete configuration for pre-training, including + model, trainer, data, logging, optimization, and resumption settings. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + fn (Callable): The pre-training function to use.
+ + Returns: + run.Partial: Partial configuration for pre-training. + + Examples: + CLI usage: + $ nemo llm pretrain --factory llama32_3b + $ nemo llm pretrain --factory "llama32_3b(num_nodes=1, name='my_3b_pretrain')" + + Python API usage: + >>> recipe = pretrain_recipe(name="llama32_3b_pretrain", num_nodes=1) + >>> print(recipe) + + Note: + This recipe is configured for the Llama3.2 3B model and runs on a single node by default. + """ + recipe = run.Partial( + fn, + model=model(), + trainer=trainer( + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + callbacks=[run.Config(TimingCallback)], + ), + data=run.Config(MockDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1), + log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4), + resume=default_resume(), + ) + + return recipe + + +@run.cli.factory(target=finetune, name=NAME) +def finetune_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'lora', + seq_length: Optional[int] = None, + packed_sequence: Optional[bool] = None, +) -> run.Partial: + """ + Create a fine-tuning recipe for Llama3.2 3B model. + + This function sets up a complete configuration for fine-tuning, including + model, trainer, data, logging, optimization, and resumption settings. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the fine-tuning run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + seq_length (int): Maximum number of tokens per microbatch. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given + maximum seq_length for better efficiency. + + Returns: + run.Partial: Partial configuration for fine-tuning. + + Examples: + CLI usage: + $ nemo llm finetune --factory llama32_3b + + Python API usage: + >>> recipe = finetune_recipe(name="llama32_3b_finetune", num_nodes=1) + >>> print(recipe) + + Note: + This recipe uses the SQuAD dataset for fine-tuning. For more information + on fine-tuning LLMs with NeMo, see the fine-tuning guide in the + `examples/llm/finetune/` directory.
+ """ + + # For unpacked sequence, most samples in SQuAD dataset are shorter than 2K + if seq_length is None: + seq_length = 4096 if packed_sequence else 2048 + + recipe = default_finetune_recipe( + model(), "meta-llama/Llama-3.2-3B", dir, name, num_nodes, num_gpus_per_node, packed_sequence + ) + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 1 + recipe.optim.config.lr = 5e-6 + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) + recipe.peft.dim = 8 + recipe.peft.alpha = 16 + recipe.optim.config.use_distributed_optimizer = False + + # some settings currently do not function correctly with LoRA + recipe.model.config.cross_entropy_loss_fusion = False + + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + + # Sequence length settings in the model and dataset must agree + recipe.model.config.seq_length = seq_length + recipe.data.seq_length = seq_length + if packed_sequence: + recipe.data.dataset_kwargs = {'pad_to_max_length': True} + recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length) + + return recipe diff --git a/nemo/collections/llm/recipes/llama3_70b.py b/nemo/collections/llm/recipes/llama3_70b.py index 93aeb5c07dc1..8b61bff80e01 100644 --- a/nemo/collections/llm/recipes/llama3_70b.py +++ b/nemo/collections/llm/recipes/llama3_70b.py @@ -26,7 +26,7 @@ from nemo.collections.llm.gpt.data.mock import MockDataModule from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs from nemo.collections.llm.gpt.model.llama import Llama3Config70B, LlamaModel -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -263,7 +263,8 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. seq_length (int): Maximum number of tokens per microbatch. packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given maximum seq_length for better efficiency. By default, this value equals performance_mode. 
@@ -297,7 +298,7 @@ def finetune_recipe( if num_nodes is None: if peft_scheme is None or peft_scheme.lower() == 'none': num_nodes = 4 - elif peft_scheme.lower() == 'lora': + elif peft_scheme.lower() in ['lora', 'dora']: num_nodes = 1 recipe = default_finetune_recipe( @@ -307,8 +308,8 @@ def finetune_recipe( recipe.trainer.strategy.tensor_model_parallel_size = 8 recipe.trainer.strategy.pipeline_model_parallel_size = 4 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.peft.dim = 16 recipe.peft.alpha = 32 recipe.optim.config.use_distributed_optimizer = False @@ -346,7 +347,8 @@ def finetune_performance_optimizations( Args: recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added - peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. Returns: run.Partial: Partial configuration for performance-optimized fine-tuning. diff --git a/nemo/collections/llm/recipes/llama3_8b.py b/nemo/collections/llm/recipes/llama3_8b.py index 3e26f8fe1082..36b20c12ddb2 100644 --- a/nemo/collections/llm/recipes/llama3_8b.py +++ b/nemo/collections/llm/recipes/llama3_8b.py @@ -26,7 +26,7 @@ from nemo.collections.llm.gpt.data.mock import MockDataModule from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs from nemo.collections.llm.gpt.model.llama import Llama3Config8B, LlamaModel -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -250,7 +250,8 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. seq_length (int): Maximum number of tokens per microbatch. packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given maximum seq_length for better efficiency. By default, this value equals performance_mode. @@ -287,8 +288,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 2 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.peft.dim = 8 recipe.peft.alpha = 16 recipe.optim.config.use_distributed_optimizer = False @@ -325,7 +326,8 @@ def finetune_performance_optimizations( Args: recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added - peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. 
+ Allowed values: 'lora'/'dora'/'none'/None. Returns: run.Partial: Partial configuration for performance-optimized fine-tuning. diff --git a/nemo/collections/llm/recipes/mamba2_130m.py b/nemo/collections/llm/recipes/mamba2_130m.py index 3f13f91f6609..e70fec03b3fb 100644 --- a/nemo/collections/llm/recipes/mamba2_130m.py +++ b/nemo/collections/llm/recipes/mamba2_130m.py @@ -67,6 +67,7 @@ def model(tokenizer_model: str = None) -> run.Config[pl.LightningModule]: ) +@run.cli.factory(target=finetune, name=NAME) def trainer( tensor_parallelism: int = 1, pipeline_parallelism: int = 1, @@ -76,7 +77,11 @@ def trainer( sequence_parallelism: bool = False, num_nodes: int = 1, num_gpus_per_node: int = 8, - max_steps: int = 1168251, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, callbacks: Optional[list[run.Config[Callback]]] = None, ) -> run.Config[nl.Trainer]: """ @@ -137,15 +142,15 @@ def trainer( accumulate_grad_batches=1, callbacks=callbacks, devices=num_gpus_per_node, - limit_test_batches=50, - limit_val_batches=32, - log_every_n_steps=10, max_steps=max_steps, num_nodes=num_nodes, plugins=bf16_mixed(), strategy=strategy, use_distributed_sampler=False, - val_check_interval=2000, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, ) return trainer @@ -158,6 +163,16 @@ def pretrain_recipe( tokenizer_model: str = None, num_nodes: int = 1, num_gpus_per_node: int = 8, + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, + seq_length: int = 4096, + gbs: int = 8, + mbs: int = 1, fn=pretrain, ) -> run.Partial: """ @@ -193,16 +208,23 @@ def pretrain_recipe( fn, model=model(), trainer=trainer( + max_steps=max_steps, num_nodes=num_nodes, + tensor_parallelism=tensor_parallelism, + pipeline_parallelism=pipeline_parallelism, num_gpus_per_node=num_gpus_per_node, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, callbacks=[run.Config(TimingCallback)], ), data=run.Config( MockDataModule, - seq_length=4096, - global_batch_size=8, - micro_batch_size=1, - tokenizer=tokenizer(tokenizer_model=tokenizer_model), + seq_length=seq_length, + global_batch_size=gbs, + micro_batch_size=mbs, + tokenizer=tokenizer(), ), log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4), @@ -218,6 +240,14 @@ def finetune_recipe( tokenizer_model: str = None, num_nodes: int = 1, num_gpus_per_node: int = 8, + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + seq_length: int = 4096, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, gbs: int = 8, mbs: int = 1, peft_scheme: Optional[str] = 'none', @@ -266,8 +296,8 @@ def finetune_recipe( ) strategy = run.Config( nl.MegatronStrategy, - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, gradient_as_bucket_view=True, ckpt_load_optimizer=False, ckpt_save_optimizer=False, @@ 
-283,10 +313,11 @@ def finetune_recipe( accelerator="gpu", accumulate_grad_batches=1, devices=num_gpus_per_node, - limit_test_batches=10, - limit_val_batches=10, - log_every_n_steps=20, - max_steps=100, + max_steps=max_steps, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, num_nodes=num_nodes, plugins=run.Config( nl.MegatronMixedPrecision, @@ -296,7 +327,6 @@ def finetune_recipe( callbacks=[checkpoint_callback], strategy=strategy, use_distributed_sampler=False, - val_check_interval=20, ) recipe = run.Partial( llm.finetune, @@ -304,7 +334,7 @@ def finetune_recipe( trainer=trainer, data=run.Config( llm.SquadDataModule, - seq_length=2048, + seq_length=seq_length, global_batch_size=gbs, micro_batch_size=mbs, tokenizer=tokenizer(tokenizer_model=tokenizer_model), diff --git a/nemo/collections/llm/recipes/mamba2_1_3b.py b/nemo/collections/llm/recipes/mamba2_1_3b.py index 1a280b8b92a1..aaa263078686 100644 --- a/nemo/collections/llm/recipes/mamba2_1_3b.py +++ b/nemo/collections/llm/recipes/mamba2_1_3b.py @@ -67,6 +67,7 @@ def model(tokenizer_model: str = None) -> run.Config[pl.LightningModule]: ) +@run.cli.factory(target=finetune, name=NAME) def trainer( tensor_parallelism: int = 1, pipeline_parallelism: int = 1, @@ -76,7 +77,11 @@ def trainer( sequence_parallelism: bool = False, num_nodes: int = 1, num_gpus_per_node: int = 8, - max_steps: int = 1168251, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, callbacks: Optional[list[run.Config[Callback]]] = None, ) -> run.Config[nl.Trainer]: """ @@ -137,15 +142,15 @@ def trainer( accumulate_grad_batches=1, callbacks=callbacks, devices=num_gpus_per_node, - limit_test_batches=50, - limit_val_batches=32, - log_every_n_steps=10, max_steps=max_steps, num_nodes=num_nodes, plugins=bf16_mixed(), strategy=strategy, use_distributed_sampler=False, - val_check_interval=2000, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, ) return trainer @@ -157,7 +162,17 @@ def pretrain_recipe( name: str = "default", tokenizer_model: str = None, num_nodes: int = 1, - num_gpus_per_node: int = 8, + num_gpus_per_node: int = 1, + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, + seq_length: int = 4096, + gbs: int = 8, + mbs: int = 1, fn=pretrain, ) -> run.Partial: """ @@ -191,17 +206,24 @@ def pretrain_recipe( """ return run.Partial( fn, - model=model(), + model=model(tokenizer_model=tokenizer_model), trainer=trainer( + max_steps=max_steps, num_nodes=num_nodes, + tensor_parallelism=tensor_parallelism, + pipeline_parallelism=pipeline_parallelism, num_gpus_per_node=num_gpus_per_node, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, callbacks=[run.Config(TimingCallback)], ), data=run.Config( MockDataModule, - seq_length=4096, - global_batch_size=8, - micro_batch_size=1, + seq_length=seq_length, + global_batch_size=gbs, + micro_batch_size=mbs, tokenizer=tokenizer(tokenizer_model=tokenizer_model), ), log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), @@ -217,7 
+239,15 @@ def finetune_recipe( resume_path: str = None, tokenizer_model: str = None, num_nodes: int = 1, - num_gpus_per_node: int = 8, + num_gpus_per_node: int = 1, + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + seq_length: int = 4096, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, gbs: int = 8, mbs: int = 1, peft_scheme: Optional[str] = 'none', @@ -266,8 +296,8 @@ def finetune_recipe( ) strategy = run.Config( nl.MegatronStrategy, - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, gradient_as_bucket_view=True, ckpt_load_optimizer=False, ckpt_save_optimizer=False, @@ -283,10 +313,11 @@ def finetune_recipe( accelerator="gpu", accumulate_grad_batches=1, devices=num_gpus_per_node, - limit_test_batches=10, - limit_val_batches=10, - log_every_n_steps=20, - max_steps=100, + max_steps=max_steps, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, num_nodes=num_nodes, plugins=run.Config( nl.MegatronMixedPrecision, @@ -296,7 +327,6 @@ def finetune_recipe( callbacks=[checkpoint_callback], strategy=strategy, use_distributed_sampler=False, - val_check_interval=20, ) recipe = run.Partial( llm.finetune, @@ -304,7 +334,7 @@ def finetune_recipe( trainer=trainer, data=run.Config( llm.SquadDataModule, - seq_length=2048, + seq_length=seq_length, global_batch_size=gbs, micro_batch_size=mbs, tokenizer=tokenizer(tokenizer_model=tokenizer_model), diff --git a/nemo/collections/llm/recipes/mamba2_2_7b.py b/nemo/collections/llm/recipes/mamba2_2_7b.py index 0915cec748dd..b4fd5b487b6a 100644 --- a/nemo/collections/llm/recipes/mamba2_2_7b.py +++ b/nemo/collections/llm/recipes/mamba2_2_7b.py @@ -67,6 +67,7 @@ def model(tokenizer_model: str = None) -> run.Config[pl.LightningModule]: ) +@run.cli.factory(target=finetune, name=NAME) def trainer( tensor_parallelism: int = 1, pipeline_parallelism: int = 1, @@ -76,7 +77,11 @@ def trainer( sequence_parallelism: bool = False, num_nodes: int = 1, num_gpus_per_node: int = 8, - max_steps: int = 1168251, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, callbacks: Optional[list[run.Config[Callback]]] = None, ) -> run.Config[nl.Trainer]: """ @@ -137,15 +142,15 @@ def trainer( accumulate_grad_batches=1, callbacks=callbacks, devices=num_gpus_per_node, - limit_test_batches=50, - limit_val_batches=32, - log_every_n_steps=10, max_steps=max_steps, num_nodes=num_nodes, plugins=bf16_mixed(), strategy=strategy, use_distributed_sampler=False, - val_check_interval=2000, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, ) return trainer @@ -158,6 +163,16 @@ def pretrain_recipe( tokenizer_model: str = None, num_nodes: int = 1, num_gpus_per_node: int = 8, + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, + seq_length: int = 4096, + gbs: int = 8, + mbs: int = 1, fn=pretrain, ) -> run.Partial: """ @@ -193,16 +208,23 @@ def pretrain_recipe( fn, 
model=model(), trainer=trainer( + max_steps=max_steps, num_nodes=num_nodes, + tensor_parallelism=tensor_parallelism, + pipeline_parallelism=pipeline_parallelism, num_gpus_per_node=num_gpus_per_node, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, callbacks=[run.Config(TimingCallback)], ), data=run.Config( MockDataModule, - seq_length=4096, - global_batch_size=8, - micro_batch_size=1, - tokenizer=tokenizer(tokenizer_model=tokenizer_model), + seq_length=seq_length, + global_batch_size=gbs, + micro_batch_size=mbs, + tokenizer=tokenizer(), ), log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4), @@ -218,6 +240,14 @@ def finetune_recipe( tokenizer_model: str = None, num_nodes: int = 1, num_gpus_per_node: int = 8, + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + seq_length: int = 4096, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, gbs: int = 8, mbs: int = 1, peft_scheme: Optional[str] = 'none', @@ -266,8 +296,8 @@ def finetune_recipe( ) strategy = run.Config( nl.MegatronStrategy, - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, gradient_as_bucket_view=True, ckpt_load_optimizer=False, ckpt_save_optimizer=False, @@ -283,10 +313,11 @@ def finetune_recipe( accelerator="gpu", accumulate_grad_batches=1, devices=num_gpus_per_node, - limit_test_batches=10, - limit_val_batches=10, - log_every_n_steps=20, - max_steps=100, + max_steps=max_steps, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, num_nodes=num_nodes, plugins=run.Config( nl.MegatronMixedPrecision, @@ -296,7 +327,6 @@ def finetune_recipe( callbacks=[checkpoint_callback], strategy=strategy, use_distributed_sampler=False, - val_check_interval=20, ) recipe = run.Partial( llm.finetune, @@ -304,7 +334,7 @@ def finetune_recipe( trainer=trainer, data=run.Config( llm.SquadDataModule, - seq_length=2048, + seq_length=seq_length, global_batch_size=gbs, micro_batch_size=mbs, tokenizer=tokenizer(tokenizer_model=tokenizer_model), diff --git a/nemo/collections/llm/recipes/mamba2_370m.py b/nemo/collections/llm/recipes/mamba2_370m.py index bb063dfcfc3f..6fa619b33486 100644 --- a/nemo/collections/llm/recipes/mamba2_370m.py +++ b/nemo/collections/llm/recipes/mamba2_370m.py @@ -67,6 +67,7 @@ def model(tokenizer_model: str = None) -> run.Config[pl.LightningModule]: ) +@run.cli.factory(target=finetune, name=NAME) def trainer( tensor_parallelism: int = 1, pipeline_parallelism: int = 1, @@ -76,7 +77,11 @@ def trainer( sequence_parallelism: bool = False, num_nodes: int = 1, num_gpus_per_node: int = 8, - max_steps: int = 1168251, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, callbacks: Optional[list[run.Config[Callback]]] = None, ) -> run.Config[nl.Trainer]: """ @@ -137,15 +142,15 @@ def trainer( accumulate_grad_batches=1, callbacks=callbacks, devices=num_gpus_per_node, - limit_test_batches=50, - limit_val_batches=32, - log_every_n_steps=10, max_steps=max_steps, num_nodes=num_nodes, 
plugins=bf16_mixed(), strategy=strategy, use_distributed_sampler=False, - val_check_interval=2000, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, ) return trainer @@ -158,6 +163,16 @@ def pretrain_recipe( tokenizer_model: str = None, num_nodes: int = 1, num_gpus_per_node: int = 8, + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, + seq_length: int = 4096, + gbs: int = 8, + mbs: int = 1, fn=pretrain, ) -> run.Partial: """ @@ -193,16 +208,23 @@ def pretrain_recipe( fn, model=model(), trainer=trainer( + max_steps=max_steps, num_nodes=num_nodes, + tensor_parallelism=tensor_parallelism, + pipeline_parallelism=pipeline_parallelism, num_gpus_per_node=num_gpus_per_node, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, callbacks=[run.Config(TimingCallback)], ), data=run.Config( MockDataModule, - seq_length=4096, - global_batch_size=8, - micro_batch_size=1, - tokenizer=tokenizer(tokenizer_model=tokenizer_model), + seq_length=seq_length, + global_batch_size=gbs, + micro_batch_size=mbs, + tokenizer=tokenizer(), ), log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4), @@ -218,6 +240,14 @@ def finetune_recipe( tokenizer_model: str = None, num_nodes: int = 1, num_gpus_per_node: int = 8, + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + seq_length: int = 4096, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, gbs: int = 8, mbs: int = 1, peft_scheme: Optional[str] = 'none', @@ -266,8 +296,8 @@ def finetune_recipe( ) strategy = run.Config( nl.MegatronStrategy, - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, gradient_as_bucket_view=True, ckpt_load_optimizer=False, ckpt_save_optimizer=False, @@ -283,10 +313,11 @@ def finetune_recipe( accelerator="gpu", accumulate_grad_batches=1, devices=num_gpus_per_node, - limit_test_batches=10, - limit_val_batches=10, - log_every_n_steps=20, - max_steps=100, + max_steps=max_steps, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, num_nodes=num_nodes, plugins=run.Config( nl.MegatronMixedPrecision, @@ -296,7 +327,6 @@ def finetune_recipe( callbacks=[checkpoint_callback], strategy=strategy, use_distributed_sampler=False, - val_check_interval=20, ) recipe = run.Partial( llm.finetune, @@ -304,7 +334,7 @@ def finetune_recipe( trainer=trainer, data=run.Config( llm.SquadDataModule, - seq_length=2048, + seq_length=seq_length, global_batch_size=gbs, micro_batch_size=mbs, tokenizer=tokenizer(tokenizer_model=tokenizer_model), diff --git a/nemo/collections/llm/recipes/mamba2_780m.py b/nemo/collections/llm/recipes/mamba2_780m.py index e89905b2269a..45d28f82f779 100644 --- a/nemo/collections/llm/recipes/mamba2_780m.py +++ b/nemo/collections/llm/recipes/mamba2_780m.py @@ -67,6 +67,7 @@ def model(tokenizer_model: str = None) -> 
run.Config[pl.LightningModule]: ) +@run.cli.factory(target=finetune, name=NAME) def trainer( tensor_parallelism: int = 1, pipeline_parallelism: int = 1, @@ -76,7 +77,11 @@ def trainer( sequence_parallelism: bool = False, num_nodes: int = 1, num_gpus_per_node: int = 8, - max_steps: int = 1168251, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, callbacks: Optional[list[run.Config[Callback]]] = None, ) -> run.Config[nl.Trainer]: """ @@ -137,15 +142,15 @@ def trainer( accumulate_grad_batches=1, callbacks=callbacks, devices=num_gpus_per_node, - limit_test_batches=50, - limit_val_batches=32, - log_every_n_steps=10, max_steps=max_steps, num_nodes=num_nodes, plugins=bf16_mixed(), strategy=strategy, use_distributed_sampler=False, - val_check_interval=2000, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, ) return trainer @@ -158,6 +163,16 @@ def pretrain_recipe( tokenizer_model: str = None, num_nodes: int = 1, num_gpus_per_node: int = 8, + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, + seq_length: int = 4096, + gbs: int = 8, + mbs: int = 1, fn=pretrain, ) -> run.Partial: """ @@ -193,16 +208,23 @@ def pretrain_recipe( fn, model=model(), trainer=trainer( + max_steps=max_steps, num_nodes=num_nodes, + tensor_parallelism=tensor_parallelism, + pipeline_parallelism=pipeline_parallelism, num_gpus_per_node=num_gpus_per_node, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, callbacks=[run.Config(TimingCallback)], ), data=run.Config( MockDataModule, - seq_length=4096, - global_batch_size=8, - micro_batch_size=1, - tokenizer=tokenizer(tokenizer_model=tokenizer_model), + seq_length=seq_length, + global_batch_size=gbs, + micro_batch_size=mbs, + tokenizer=tokenizer(), ), log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4), @@ -218,6 +240,14 @@ def finetune_recipe( tokenizer_model: str = None, num_nodes: int = 1, num_gpus_per_node: int = 8, + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + seq_length: int = 4096, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, gbs: int = 8, mbs: int = 1, peft_scheme: Optional[str] = 'none', @@ -266,8 +296,8 @@ def finetune_recipe( ) strategy = run.Config( nl.MegatronStrategy, - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, gradient_as_bucket_view=True, ckpt_load_optimizer=False, ckpt_save_optimizer=False, @@ -283,10 +313,11 @@ def finetune_recipe( accelerator="gpu", accumulate_grad_batches=1, devices=num_gpus_per_node, - limit_test_batches=10, - limit_val_batches=10, - log_every_n_steps=20, - max_steps=100, + max_steps=max_steps, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, num_nodes=num_nodes, plugins=run.Config( 
nl.MegatronMixedPrecision, @@ -296,7 +327,6 @@ def finetune_recipe( callbacks=[checkpoint_callback], strategy=strategy, use_distributed_sampler=False, - val_check_interval=20, ) recipe = run.Partial( llm.finetune, @@ -304,7 +334,7 @@ def finetune_recipe( trainer=trainer, data=run.Config( llm.SquadDataModule, - seq_length=2048, + seq_length=seq_length, global_batch_size=gbs, micro_batch_size=mbs, tokenizer=tokenizer(tokenizer_model=tokenizer_model), diff --git a/nemo/collections/llm/recipes/mamba2_8b.py b/nemo/collections/llm/recipes/mamba2_8b.py index 873d79fcb0f0..8f8384b45059 100644 --- a/nemo/collections/llm/recipes/mamba2_8b.py +++ b/nemo/collections/llm/recipes/mamba2_8b.py @@ -67,6 +67,7 @@ def model(tokenizer_model: str = None) -> run.Config[pl.LightningModule]: ) +@run.cli.factory(name=NAME) def trainer( tensor_parallelism: int = 8, pipeline_parallelism: int = 1, @@ -76,7 +77,11 @@ def trainer( sequence_parallelism: bool = False, num_nodes: int = 1, num_gpus_per_node: int = 8, - max_steps: int = 1168251, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, callbacks: Optional[list[run.Config[Callback]]] = None, ) -> run.Config[nl.Trainer]: """ @@ -137,15 +142,15 @@ def trainer( accumulate_grad_batches=1, callbacks=callbacks, devices=num_gpus_per_node, - limit_test_batches=50, - limit_val_batches=32, - log_every_n_steps=10, max_steps=max_steps, num_nodes=num_nodes, plugins=bf16_mixed(), strategy=strategy, use_distributed_sampler=False, - val_check_interval=2000, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, ) return trainer @@ -158,6 +163,16 @@ def pretrain_recipe( tokenizer_model: str = None, num_nodes: int = 1, num_gpus_per_node: int = 8, + tensor_parallelism: int = 8, + pipeline_parallelism: int = 1, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, + seq_length: int = 4096, + gbs: int = 8, + mbs: int = 1, fn=pretrain, ) -> run.Partial: """ @@ -191,17 +206,24 @@ def pretrain_recipe( """ return run.Partial( fn, - model=model(), + model=model(tokenizer_model=tokenizer_model), trainer=trainer( + max_steps=max_steps, num_nodes=num_nodes, + tensor_parallelism=tensor_parallelism, + pipeline_parallelism=pipeline_parallelism, num_gpus_per_node=num_gpus_per_node, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, callbacks=[run.Config(TimingCallback)], ), data=run.Config( MockDataModule, - seq_length=4096, - global_batch_size=8, - micro_batch_size=1, + seq_length=seq_length, + global_batch_size=gbs, + micro_batch_size=mbs, tokenizer=tokenizer(tokenizer_model=tokenizer_model), ), log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), @@ -218,6 +240,14 @@ def finetune_recipe( name: str = "default", num_nodes: int = 1, num_gpus_per_node: int = 8, + tensor_model_parallel_size: int = 8, + pipeline_model_parallel_size: int = 1, + seq_length: int = 4096, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, gbs: int = 8, mbs: int = 1, peft_scheme: Optional[str] = 'none', @@ -266,8 +296,8 @@ def finetune_recipe( ) strategy = run.Config( 
nl.MegatronStrategy, - tensor_model_parallel_size=8, - pipeline_model_parallel_size=1, + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, gradient_as_bucket_view=True, ckpt_load_optimizer=False, ckpt_save_optimizer=False, @@ -283,10 +313,11 @@ def finetune_recipe( accelerator="gpu", accumulate_grad_batches=1, devices=num_gpus_per_node, - limit_test_batches=10, - limit_val_batches=10, - log_every_n_steps=20, - max_steps=100, + max_steps=max_steps, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, num_nodes=num_nodes, plugins=run.Config( nl.MegatronMixedPrecision, @@ -296,7 +327,6 @@ def finetune_recipe( callbacks=[checkpoint_callback], strategy=strategy, use_distributed_sampler=False, - val_check_interval=20, ) recipe = run.Partial( llm.finetune, @@ -304,7 +334,7 @@ def finetune_recipe( trainer=trainer, data=run.Config( llm.SquadDataModule, - seq_length=2048, + seq_length=seq_length, global_batch_size=gbs, micro_batch_size=mbs, tokenizer=tokenizer(tokenizer_model=tokenizer_model), diff --git a/nemo/collections/llm/recipes/mamba2_hybrid_8b.py b/nemo/collections/llm/recipes/mamba2_hybrid_8b.py index 09bb88a57089..b91c8e228bc9 100644 --- a/nemo/collections/llm/recipes/mamba2_hybrid_8b.py +++ b/nemo/collections/llm/recipes/mamba2_hybrid_8b.py @@ -39,7 +39,7 @@ def tokenizer(tokenizer_model: str = None) -> run.Config[pl.LightningModule]: return run.Config( get_nmt_tokenizer, - library='megatronNVIDIAMambaConfig8B', + library='megatron', model_name="GPTSentencePieceTokenizer", tokenizer_model=tokenizer_model, use_fast=True, @@ -69,6 +69,7 @@ def model(tokenizer_model: str = None) -> run.Config[pl.LightningModule]: ) +@run.cli.factory(target=finetune, name=NAME) def trainer( tensor_parallelism: int = 8, pipeline_parallelism: int = 1, @@ -78,7 +79,11 @@ def trainer( sequence_parallelism: bool = False, num_nodes: int = 1, num_gpus_per_node: int = 8, - max_steps: int = 1168251, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, callbacks: Optional[list[run.Config[Callback]]] = None, ) -> run.Config[nl.Trainer]: """ @@ -139,15 +144,15 @@ def trainer( accumulate_grad_batches=1, callbacks=callbacks, devices=num_gpus_per_node, - limit_test_batches=50, - limit_val_batches=32, - log_every_n_steps=10, max_steps=max_steps, num_nodes=num_nodes, plugins=bf16_mixed(), strategy=strategy, use_distributed_sampler=False, - val_check_interval=2000, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, ) return trainer @@ -160,6 +165,16 @@ def pretrain_recipe( tokenizer_model: str = None, num_nodes: int = 1, num_gpus_per_node: int = 8, + tensor_parallelism: int = 8, + pipeline_parallelism: int = 1, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, + seq_length: int = 4096, + gbs: int = 8, + mbs: int = 1, fn=pretrain, ) -> run.Partial: """ @@ -193,17 +208,24 @@ def pretrain_recipe( """ return run.Partial( fn, - model=model(), + model=model(tokenizer_model=tokenizer_model), trainer=trainer( + max_steps=max_steps, num_nodes=num_nodes, + tensor_parallelism=tensor_parallelism, + pipeline_parallelism=pipeline_parallelism, 
num_gpus_per_node=num_gpus_per_node, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, callbacks=[run.Config(TimingCallback)], ), data=run.Config( MockDataModule, - seq_length=4096, - global_batch_size=8, - micro_batch_size=1, + seq_length=seq_length, + global_batch_size=gbs, + micro_batch_size=mbs, tokenizer=tokenizer(tokenizer_model=tokenizer_model), ), log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), @@ -220,6 +242,14 @@ def finetune_recipe( name: str = "default", num_nodes: int = 1, num_gpus_per_node: int = 8, + tensor_model_parallel_size: int = 8, + pipeline_model_parallel_size: int = 1, + seq_length: int = 4096, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, gbs: int = 8, mbs: int = 1, peft_scheme: Optional[str] = 'none', @@ -268,8 +298,8 @@ def finetune_recipe( ) strategy = run.Config( nl.MegatronStrategy, - tensor_model_parallel_size=8, - pipeline_model_parallel_size=1, + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, gradient_as_bucket_view=True, ckpt_load_optimizer=False, ckpt_save_optimizer=False, @@ -285,10 +315,11 @@ def finetune_recipe( accelerator="gpu", accumulate_grad_batches=1, devices=num_gpus_per_node, - limit_test_batches=10, - limit_val_batches=10, - log_every_n_steps=20, - max_steps=100, + max_steps=max_steps, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, num_nodes=num_nodes, plugins=run.Config( nl.MegatronMixedPrecision, @@ -298,7 +329,6 @@ def finetune_recipe( callbacks=[checkpoint_callback], strategy=strategy, use_distributed_sampler=False, - val_check_interval=20, ) recipe = run.Partial( llm.finetune, @@ -306,7 +336,7 @@ def finetune_recipe( trainer=trainer, data=run.Config( llm.SquadDataModule, - seq_length=2048, + seq_length=seq_length, global_batch_size=gbs, micro_batch_size=mbs, tokenizer=tokenizer(tokenizer_model=tokenizer_model), diff --git a/nemo/collections/llm/recipes/mistral_7b.py b/nemo/collections/llm/recipes/mistral_7b.py index 3bc1e568185a..9e2d2e256fbe 100644 --- a/nemo/collections/llm/recipes/mistral_7b.py +++ b/nemo/collections/llm/recipes/mistral_7b.py @@ -24,9 +24,8 @@ from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.gpt.data.squad import SquadDataModule from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -207,8 +206,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. 
- packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -237,8 +238,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 2 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/mistral_nemo_12b.py b/nemo/collections/llm/recipes/mistral_nemo_12b.py index 7d9fa1d792e9..a10f8ae804b8 100644 --- a/nemo/collections/llm/recipes/mistral_nemo_12b.py +++ b/nemo/collections/llm/recipes/mistral_nemo_12b.py @@ -24,9 +24,8 @@ from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.gpt.data.squad import SquadDataModule from nemo.collections.llm.gpt.model.mistral import MistralModel, MistralNeMoConfig12B -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -255,8 +254,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. 
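A note on the mamba2_* recipe changes earlier in this diff: the previously hard-coded trainer settings (max_steps, val_check_interval, limit_test_batches, limit_val_batches, log_every_n_steps) and data settings (seq_length, gbs, mbs) are now keyword arguments of pretrain_recipe. A minimal sketch of how a caller might use them, assuming the remaining arguments keep their defaults and with a placeholder tokenizer path:

    from nemo.collections.llm.recipes import mamba2_1_3b

    # Short smoke-test pretraining run; every value below maps to one of the
    # newly exposed keyword arguments.
    recipe = mamba2_1_3b.pretrain_recipe(
        name="mamba2_1_3b_smoke",
        tokenizer_model="/path/to/mamba2/tokenizer.model",  # placeholder path
        num_nodes=1,
        num_gpus_per_node=1,
        max_steps=100,
        val_check_interval=100,
        limit_val_batches=32,
        log_every_n_steps=10,
        seq_length=4096,
        gbs=8,
        mbs=1,
    )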
@@ -285,8 +286,10 @@ def finetune_recipe( ) if peft_scheme is None or peft_scheme.lower() == 'none': recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA, target_modules=['linear_qkv', 'linear_proj'], dim=32) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config( + PEFT_STR2CLS[peft_scheme.lower()], target_modules=['linear_qkv', 'linear_proj'], dim=32 + ) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/mixtral_8x22b.py b/nemo/collections/llm/recipes/mixtral_8x22b.py index 16e6168e649b..ec1641a08d80 100644 --- a/nemo/collections/llm/recipes/mixtral_8x22b.py +++ b/nemo/collections/llm/recipes/mixtral_8x22b.py @@ -24,9 +24,8 @@ from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.gpt.data.squad import SquadDataModule from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x22B, MixtralModel -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -227,7 +226,7 @@ def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial: ), run.Config( MegatronCommOverlapCallback, - overlap_param_gather_with_optimizer_step=False, # Currently disabled due to an issue with checkpointing + overlap_param_gather_with_optimizer_step=False, # Currently disabled due to issue with checkpointing align_param_gather=True, ), ] @@ -259,8 +258,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given maximum seq_length for better efficiency. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given + maximum seq_length for better efficiency. Returns: run.Partial: Partial configuration for fine-tuning. 
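Similarly, the mamba2 finetune_recipe functions now accept the model-parallel sizes as arguments instead of hard-coding them in MegatronStrategy. A sketch for mamba2_1_3b; the checkpoint and tokenizer paths are placeholders, and the same pattern applies to the larger mamba2 and mamba2_hybrid recipes with their respective TP defaults:

    from nemo.collections.llm.recipes import mamba2_1_3b

    recipe = mamba2_1_3b.finetune_recipe(
        resume_path="/path/to/mamba2_1_3b_checkpoint",      # placeholder
        tokenizer_model="/path/to/mamba2/tokenizer.model",  # placeholder
        num_nodes=1,
        num_gpus_per_node=1,
        tensor_model_parallel_size=1,   # forwarded to MegatronStrategy
        pipeline_model_parallel_size=1,
        seq_length=4096,
        max_steps=100,
        gbs=8,
        mbs=1,
        peft_scheme="none",
    )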
@@ -286,8 +287,10 @@ def finetune_recipe( recipe.trainer.strategy.pipeline_model_parallel_size = 4 recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 14 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA, target_modules=['linear_qkv', 'linear_proj'], dim=32) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config( + PEFT_STR2CLS[peft_scheme.lower()], target_modules=['linear_qkv', 'linear_proj'], dim=32 + ) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/mixtral_8x7b.py b/nemo/collections/llm/recipes/mixtral_8x7b.py index 5fbb0ac22c61..d06e22fc2180 100644 --- a/nemo/collections/llm/recipes/mixtral_8x7b.py +++ b/nemo/collections/llm/recipes/mixtral_8x7b.py @@ -24,9 +24,8 @@ from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.gpt.data.squad import SquadDataModule from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x7B, MixtralModel -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -222,7 +221,7 @@ def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial: run.Config(MegatronTokenDropCallback), run.Config( MegatronCommOverlapCallback, - overlap_param_gather_with_optimizer_step=False, # Currently disabled due to an issue with checkpointing. + overlap_param_gather_with_optimizer_step=False, # Currently disabled due to issue with checkpointing. align_param_gather=True, ), ] @@ -254,8 +253,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. 
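The recurring switch from the LoRA import to PEFT_STR2CLS in these recipe files means peft_scheme is no longer limited to 'lora': 'dora' is resolved through the same lookup. A usage sketch for mixtral_8x7b; the run name and node count are illustrative, not taken from this PR:

    from nemo.collections.llm.recipes import mixtral_8x7b

    recipe = mixtral_8x7b.finetune_recipe(
        name="mixtral_8x7b_dora",
        num_nodes=2,
        num_gpus_per_node=8,
        peft_scheme="dora",  # looked up in PEFT_STR2CLS; 'lora' and 'none' still work
    )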
@@ -280,8 +281,10 @@ def finetune_recipe( recipe.trainer.strategy.pipeline_model_parallel_size = 4 recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 8 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA, target_modules=['linear_qkv', 'linear_proj'], dim=32) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config( + PEFT_STR2CLS[peft_scheme.lower()], target_modules=['linear_qkv', 'linear_proj'], dim=32 + ) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/nemotron3_22b.py b/nemo/collections/llm/recipes/nemotron3_22b.py index 2dd9c3ff5205..4c763301bc52 100644 --- a/nemo/collections/llm/recipes/nemotron3_22b.py +++ b/nemo/collections/llm/recipes/nemotron3_22b.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional +from typing import Optional import lightning.pytorch as pl import nemo_run as run @@ -20,7 +20,7 @@ from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.nemotron import nemotron_model, nemotron_trainer @@ -239,8 +239,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. 
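For readers who have not seen PEFT_STR2CLS: it is a string-to-class registry in nemo.collections.llm.peft, and this diff only shows it being consumed. The stand-alone sketch below mimics the dispatch logic used in the next hunk and the others like it; the stand-in classes and registry contents are assumptions, not taken from this PR:

    # Stand-in classes; the real adapter classes live in nemo.collections.llm.peft.
    class LoRA: ...
    class DoRA: ...

    PEFT_STR2CLS = {"lora": LoRA, "dora": DoRA}

    def resolve_peft(peft_scheme):
        """Mirror of the branching used by the recipes above."""
        if peft_scheme is None or peft_scheme.lower() == "none":
            return None
        if peft_scheme.lower() in PEFT_STR2CLS:
            return PEFT_STR2CLS[peft_scheme.lower()]
        raise ValueError(f"Unrecognized peft scheme: {peft_scheme}")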
@@ -265,8 +267,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 8 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/nemotron3_4b.py b/nemo/collections/llm/recipes/nemotron3_4b.py index c208ee740265..fc6f09a09358 100644 --- a/nemo/collections/llm/recipes/nemotron3_4b.py +++ b/nemo/collections/llm/recipes/nemotron3_4b.py @@ -20,7 +20,7 @@ from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.nemotron import nemotron_model, nemotron_trainer @@ -191,8 +191,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -216,8 +218,8 @@ def finetune_recipe( ) if peft_scheme is None or peft_scheme.lower() == 'none': recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/nemotron3_8b.py b/nemo/collections/llm/recipes/nemotron3_8b.py index 7799512c6260..f60463330cad 100644 --- a/nemo/collections/llm/recipes/nemotron3_8b.py +++ b/nemo/collections/llm/recipes/nemotron3_8b.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Callable, Optional +from typing import Optional import lightning.pytorch as pl import nemo_run as run @@ -21,8 +21,7 @@ from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.gpt.data.squad import SquadDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.nemotron import nemotron_model, nemotron_trainer @@ -256,8 +255,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -282,8 +283,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 2 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/nemotron4_15b.py b/nemo/collections/llm/recipes/nemotron4_15b.py index ad0f884b0d3b..49f92fcc1616 100644 --- a/nemo/collections/llm/recipes/nemotron4_15b.py +++ b/nemo/collections/llm/recipes/nemotron4_15b.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional +from typing import Optional import lightning.pytorch as pl import nemo_run as run @@ -20,7 +20,7 @@ from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.nemotron import nemotron_model, nemotron_trainer @@ -228,8 +228,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. 
+ packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -254,8 +256,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 4 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/nemotron4_340b.py b/nemo/collections/llm/recipes/nemotron4_340b.py index b22abc43d558..14d4c0f32d11 100644 --- a/nemo/collections/llm/recipes/nemotron4_340b.py +++ b/nemo/collections/llm/recipes/nemotron4_340b.py @@ -12,16 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional +from typing import Optional import lightning.pytorch as pl import nemo_run as run import torch -from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.nemotron import nemotron_model, nemotron_trainer @@ -240,8 +239,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. 
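The nemotron4_340b hunk that follows resets tensor and pipeline parallelism depending on the PEFT choice. Since finetune_recipe returns a mutable run.Partial, those presets can also be adjusted after construction using the same attribute paths the recipe itself uses; the override values below are purely illustrative:

    from nemo.collections.llm.recipes import nemotron4_340b

    recipe = nemotron4_340b.finetune_recipe(
        name="nemotron4_340b_dora",
        num_nodes=4,
        num_gpus_per_node=8,
        peft_scheme="dora",
    )
    # Override the preset parallel layout and learning rate on the returned config.
    recipe.trainer.strategy.pipeline_model_parallel_size = 8
    recipe.optim.config.lr = 5e-5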
@@ -268,8 +269,8 @@ def finetune_recipe( recipe.trainer.strategy.tensor_model_parallel_size = 8 recipe.trainer.strategy.pipeline_model_parallel_size = 12 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.trainer.strategy.tensor_model_parallel_size = 8 recipe.trainer.strategy.pipeline_model_parallel_size = 4 recipe.optim.config.lr = 1e-4 diff --git a/nemo/collections/llm/recipes/phi3_mini_4k_instruct.py b/nemo/collections/llm/recipes/phi3_mini_4k_instruct.py index 1cbc877dc33e..73bbe4735adb 100644 --- a/nemo/collections/llm/recipes/phi3_mini_4k_instruct.py +++ b/nemo/collections/llm/recipes/phi3_mini_4k_instruct.py @@ -25,7 +25,7 @@ from nemo.collections.llm.gpt.data.mock import MockDataModule from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs from nemo.collections.llm.gpt.model.phi3mini import Phi3ConfigMini, Phi3Model -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -222,8 +222,8 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', - 'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. seq_length (int): Maximum number of tokens per microbatch. packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given maximum seq_length for better efficiency. By default, this value equals performance_mode. @@ -260,8 +260,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 1 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.peft.dim = 8 recipe.peft.alpha = 16 recipe.optim.config.use_distributed_optimizer = False diff --git a/nemo/collections/llm/recipes/qwen2_1p5b.py b/nemo/collections/llm/recipes/qwen2_1p5b.py index a3d705c4fb3a..99ba5cd907fc 100644 --- a/nemo/collections/llm/recipes/qwen2_1p5b.py +++ b/nemo/collections/llm/recipes/qwen2_1p5b.py @@ -20,7 +20,7 @@ from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -194,8 +194,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. 
- peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -218,8 +220,8 @@ def finetune_recipe( ) if peft_scheme is None or peft_scheme.lower() == 'none': recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/qwen2_500m.py b/nemo/collections/llm/recipes/qwen2_500m.py index 08541ca9e421..96d99c271c85 100644 --- a/nemo/collections/llm/recipes/qwen2_500m.py +++ b/nemo/collections/llm/recipes/qwen2_500m.py @@ -20,7 +20,7 @@ from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -194,8 +194,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. 
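Besides the peft_scheme wording, these docstrings also describe packed_sequence, which concatenates several training examples into one sequence (packed length 2048 by default, per the text above). A small sketch for qwen2_500m with illustrative values:

    from nemo.collections.llm.recipes import qwen2_500m

    recipe = qwen2_500m.finetune_recipe(
        name="qwen2_500m_lora_packed",
        num_nodes=1,
        num_gpus_per_node=1,
        peft_scheme="lora",
        packed_sequence=True,  # pack multiple samples per sequence for throughput
    )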
@@ -218,8 +220,8 @@ def finetune_recipe( ) if peft_scheme is None or peft_scheme.lower() == 'none': recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/qwen2_72b.py b/nemo/collections/llm/recipes/qwen2_72b.py index c0bc9bf40611..33bb0dd40835 100644 --- a/nemo/collections/llm/recipes/qwen2_72b.py +++ b/nemo/collections/llm/recipes/qwen2_72b.py @@ -20,7 +20,7 @@ from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -194,8 +194,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -221,8 +223,8 @@ def finetune_recipe( recipe.trainer.strategy.tensor_model_parallel_size = 8 recipe.trainer.strategy.pipeline_model_parallel_size = 4 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.trainer.strategy.tensor_model_parallel_size = 8 recipe.optim.config.lr = 1e-4 else: diff --git a/nemo/collections/llm/recipes/qwen2_7b.py b/nemo/collections/llm/recipes/qwen2_7b.py index 67bcc5e953bf..2e62176a408e 100644 --- a/nemo/collections/llm/recipes/qwen2_7b.py +++ b/nemo/collections/llm/recipes/qwen2_7b.py @@ -20,7 +20,7 @@ from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -194,8 +194,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. 
- packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -219,8 +221,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 2 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/run/__init__.py b/nemo/collections/llm/recipes/run/__init__.py new file mode 100644 index 000000000000..d9155f923f18 --- /dev/null +++ b/nemo/collections/llm/recipes/run/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/llm/recipes/run/executor.py b/nemo/collections/llm/recipes/run/executor.py new file mode 100644 index 000000000000..fe14a4f55bd2 --- /dev/null +++ b/nemo/collections/llm/recipes/run/executor.py @@ -0,0 +1,31 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import nemo_run as run
+
+
+@run.cli.factory
+def torchrun(devices: int = 8) -> run.Config[run.LocalExecutor]:
+    """Local executor using torchrun."""
+    env_vars = {
+        "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",
+    }
+
+    executor = run.Config(
+        run.LocalExecutor,
+        ntasks_per_node=devices,
+        launcher="torchrun",
+        env_vars=env_vars,
+    )
+
+    return executor
diff --git a/nemo/collections/llm/recipes/starcoder2_15b.py b/nemo/collections/llm/recipes/starcoder2_15b.py
index 14b53809111a..e424cb67dba4 100644
--- a/nemo/collections/llm/recipes/starcoder2_15b.py
+++ b/nemo/collections/llm/recipes/starcoder2_15b.py
@@ -20,7 +20,7 @@
 from nemo.collections.llm.api import finetune, pretrain
 from nemo.collections.llm.gpt.data.mock import MockDataModule
-from nemo.collections.llm.peft.lora import LoRA
+from nemo.collections.llm.peft import PEFT_STR2CLS
 from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe
 from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
 from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
@@ -194,8 +194,10 @@ def finetune_recipe(
         name (str): Name of the fine-tuning run.
         num_nodes (int): Number of compute nodes to use.
         num_gpus_per_node (int): Number of GPUs per node.
-        peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None.
-        packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048.
+        peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning.
+            Allowed values: 'lora'/'dora'/'none'/None.
+        packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training
+            efficiency. Default sequence length is 2048.

     Returns:
         run.Partial: Partial configuration for fine-tuning.
@@ -219,8 +221,8 @@ def finetune_recipe(
     if peft_scheme is None or peft_scheme.lower() == 'none':
         recipe.trainer.strategy.tensor_model_parallel_size = 4
         recipe.optim.config.lr = 5e-6
-    elif peft_scheme.lower() == 'lora':
-        recipe.peft = run.Config(LoRA)
+    elif peft_scheme.lower() in ['lora', 'dora']:
+        recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()])
         recipe.optim.config.lr = 1e-4
     else:
         raise ValueError(f"Unrecognized peft scheme: {peft_scheme}")
diff --git a/nemo/collections/llm/recipes/starcoder2_3b.py b/nemo/collections/llm/recipes/starcoder2_3b.py
index 3ee81522ebc9..faf0b416c56a 100644
--- a/nemo/collections/llm/recipes/starcoder2_3b.py
+++ b/nemo/collections/llm/recipes/starcoder2_3b.py
@@ -20,7 +20,7 @@
 from nemo.collections.llm.api import finetune, pretrain
 from nemo.collections.llm.gpt.data.mock import MockDataModule
-from nemo.collections.llm.peft.lora import LoRA
+from nemo.collections.llm.peft import PEFT_STR2CLS
 from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe
 from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
 from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
@@ -194,8 +194,10 @@ def finetune_recipe(
         name (str): Name of the fine-tuning run.
         num_nodes (int): Number of compute nodes to use.
         num_gpus_per_node (int): Number of GPUs per node.
-        peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None.
- packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -219,8 +221,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 2 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/starcoder2_7b.py b/nemo/collections/llm/recipes/starcoder2_7b.py index 96b5ab36b876..091e882cd932 100644 --- a/nemo/collections/llm/recipes/starcoder2_7b.py +++ b/nemo/collections/llm/recipes/starcoder2_7b.py @@ -20,7 +20,7 @@ from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -194,8 +194,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. 
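Returning to the new nemo/collections/llm/recipes/run/executor.py added a few files above: because torchrun() is registered with @run.cli.factory, it can be selected from the nemo-run CLI or called directly in Python to build a local torchrun-based executor for these recipes. A minimal sketch; wiring it to a recipe is left to the nemo-run documentation:

    from nemo.collections.llm.recipes.run.executor import torchrun

    # Builds a run.Config wrapping run.LocalExecutor with launcher="torchrun"
    # and TORCH_NCCL_AVOID_RECORD_STREAMS=1 in its environment.
    executor_cfg = torchrun(devices=8)
    print(executor_cfg)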
@@ -219,8 +221,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 2 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/starcoder_15b.py b/nemo/collections/llm/recipes/starcoder_15b.py index d87788be5613..382d0eb4d8ca 100644 --- a/nemo/collections/llm/recipes/starcoder_15b.py +++ b/nemo/collections/llm/recipes/starcoder_15b.py @@ -23,7 +23,7 @@ from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule from nemo.collections.llm.gpt.model.starcoder import StarcoderConfig15B, StarcoderModel -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -280,7 +280,8 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. Returns: run.Partial: Partial configuration for fine-tuning. @@ -302,8 +303,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.pipeline_model_parallel_size = 8 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/t5_11b.py b/nemo/collections/llm/recipes/t5_11b.py index 8baf54b4f42f..ee7323aa044f 100644 --- a/nemo/collections/llm/recipes/t5_11b.py +++ b/nemo/collections/llm/recipes/t5_11b.py @@ -24,7 +24,7 @@ from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_trainer, nemo_resume from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed @@ -229,7 +229,8 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. Returns: run.Partial: Partial configuration for fine-tuning. 
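For context, an illustrative call into one of the updated recipes with the newly supported scheme. Argument names are taken from the docstrings above; any additional required arguments of finetune_recipe (for example a checkpoint or output directory) are omitted, so treat this as a sketch rather than a complete invocation:

    from nemo.collections.llm.recipes import starcoder2_7b

    # 'dora' is now accepted alongside 'lora'; both resolve through PEFT_STR2CLS.
    recipe = starcoder2_7b.finetune_recipe(
        name="starcoder2_7b_dora",
        num_nodes=1,
        num_gpus_per_node=8,
        peft_scheme="dora",
    )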
@@ -279,8 +280,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 4 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/t5_220m.py b/nemo/collections/llm/recipes/t5_220m.py index 27feb43837fb..975ac5519859 100644 --- a/nemo/collections/llm/recipes/t5_220m.py +++ b/nemo/collections/llm/recipes/t5_220m.py @@ -24,7 +24,7 @@ from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_trainer, nemo_resume from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed @@ -229,7 +229,8 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. Returns: run.Partial: Partial configuration for fine-tuning. @@ -247,15 +248,17 @@ def finetune_recipe( on fine-tuning LLMs with NeMo, see the fine-tuning guide in the `examples/llm/finetune/` directory. 
""" - opt_config = OptimizerConfig( + opt_config = run.Config( + OptimizerConfig, optimizer='adam', - lr=1e-4, + lr=0.0001, use_distributed_optimizer=True, bf16=True, weight_decay=0.01, ) - lr_scheduler = WarmupAnnealingScheduler( + lr_scheduler = run.Config( + WarmupAnnealingScheduler, warmup_steps=50, max_steps=2000, min_lr=0.00001, @@ -272,16 +275,17 @@ def finetune_recipe( SquadDataModule, seq_length=512, seq_length_dec=128, global_batch_size=128, micro_batch_size=1 ), log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), - optim=MegatronOptimizerModule(config=opt_config, lr_scheduler=lr_scheduler), + optim=run.Config(MegatronOptimizerModule, config=opt_config, lr_scheduler=lr_scheduler), resume=nemo_resume(checkpoint_path), ) if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 1 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + return recipe diff --git a/nemo/collections/llm/recipes/t5_3b.py b/nemo/collections/llm/recipes/t5_3b.py index 333661d97117..82772e1b865a 100644 --- a/nemo/collections/llm/recipes/t5_3b.py +++ b/nemo/collections/llm/recipes/t5_3b.py @@ -24,7 +24,7 @@ from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_trainer, nemo_resume from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed @@ -229,7 +229,8 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. Returns: run.Partial: Partial configuration for fine-tuning. 
@@ -279,8 +280,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 2 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/t5/data/squad.py b/nemo/collections/llm/t5/data/squad.py index 3e413919211c..4e90b09e622e 100644 --- a/nemo/collections/llm/t5/data/squad.py +++ b/nemo/collections/llm/t5/data/squad.py @@ -42,6 +42,7 @@ class SquadDataModule(FineTuningDataModule, IOMixin): def __init__( self, + dataset_root: str = None, seq_length: int = 512, seq_length_dec: int = 128, tokenizer: Optional["TokenizerSpec"] = None, @@ -60,7 +61,7 @@ def __init__( self.delete_raw = delete_raw super().__init__( - dataset_root=get_dataset_root("squad"), + dataset_root=get_dataset_root("squad") if dataset_root is None else dataset_root, seq_length=seq_length, seq_length_dec=seq_length_dec, tokenizer=tokenizer, diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gemma2/gemma2_modules.py b/nemo/collections/nlp/models/language_modeling/megatron/gemma2/gemma2_modules.py index 5113ee745895..9ea1b4afe318 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron/gemma2/gemma2_modules.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/gemma2/gemma2_modules.py @@ -49,7 +49,8 @@ class Gemma2DotProductAttention(MegatronModule): Region where selective activation recomputation is applied. This region is memory intensive but less compute intensive which makes activation checkpointing more efficient for LLMs (20B+). - See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + See Reducing Activation Recomputation in Large Transformer Models: + https://arxiv.org/abs/2205.05198 for more details. We use the following notation: h: hidden size @@ -126,7 +127,12 @@ def forward( attention_mask: Tensor, attn_mask_type: AttnMaskType = None, packed_seq_params: PackedSeqParams = None, + **kwargs, ): + """Forward. + Modified from mcore.transformer.dot_product_attention to support Gemma2-specific + final_logit_softcapping. + """ assert packed_seq_params is None, ( "Packed sequence is not supported by DotProductAttention." "Please use TEDotProductAttention instead." 
) @@ -243,6 +249,8 @@ def forward( class TERowParallelLinearLayerNorm(TERowParallelLinear): + """Modified From TERowParallelLinear with an additional Post-LN.""" + def __init__( self, input_size: int, @@ -270,12 +278,16 @@ def __init__( self.post_layernorm = TENorm(config, output_size) def forward(self, x): + """Forward with additional Post LN on output""" output, bias = super().forward(x) return self.post_layernorm(output), bias class Gemma2OutputLayer(ColumnParallelLinear): + """Extends from ColumnParallelLinear with logit soft capping.""" + def forward(self, *args, **kwargs): + """Forward with logit soft capping.""" output, bias = super().forward(*args, **kwargs) output = logit_softcapping(output, self.config.final_logit_softcapping) return output, bias diff --git a/nemo/collections/nlp/models/nlp_model.py b/nemo/collections/nlp/models/nlp_model.py index 0c61b085bc7f..6a87eb28723c 100644 --- a/nemo/collections/nlp/models/nlp_model.py +++ b/nemo/collections/nlp/models/nlp_model.py @@ -397,7 +397,22 @@ def dummy(): model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer) model.trainer.strategy.setup_environment() sharded_state_dict = model.sharded_state_dict() - checkpoint['state_dict'] = sharded_state_dict + if kwargs.get("load_mlm", False): + mlm_sharded_state_dict = {} + for k, v in sharded_state_dict.items(): + # Remove 'model.' from the sharded_state_dict keys + new_key = k.replace('model.', '', 1) + + # Update the key attribute of the ShardedTensor value + new_value = v + if hasattr(v, 'key'): + new_value.key = v.key.replace('model.', '', 1) + + # Add the updated key-value pair to the new dictionary + mlm_sharded_state_dict[new_key] = new_value + checkpoint['state_dict'] = mlm_sharded_state_dict + else: + checkpoint['state_dict'] = sharded_state_dict # load the checkpoint from disk checkpoint = dist_checkpointing.load(sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_dir) # restore the weights diff --git a/nemo/collections/vlm/__init__.py b/nemo/collections/vlm/__init__.py index 7d8cc2c94247..266790f3af71 100644 --- a/nemo/collections/vlm/__init__.py +++ b/nemo/collections/vlm/__init__.py @@ -29,6 +29,7 @@ DataConfig, ImageDataConfig, ImageToken, + LlavaNextTaskEncoder, MultiModalToken, NevaLazyDataModule, NevaMockDataModule, @@ -42,7 +43,8 @@ NevaConfig, NevaModel, ) -from nemo.collections.vlm.neva.model.llava import Llava1_5Config7B, Llava1_5Config13B, LlavaConfig, LlavaModel +from nemo.collections.vlm.neva.model.llava import Llava15Config7B, Llava15Config13B, LlavaConfig, LlavaModel +from nemo.collections.vlm.neva.model.vit_config import CLIPViTL_14_336_Config, SigLIPViT400M_14_384_Config from nemo.collections.vlm.peft import LoRA from nemo.collections.vlm.recipes import * @@ -59,13 +61,16 @@ "VideoToken", "CLIPViTConfig", "HFCLIPVisionConfig", + "CLIPViTL_14_336_Config", + "SigLIPViT400M_14_384_Config", "MultimodalProjectorConfig", "NevaConfig", "NevaModel", "LlavaConfig", - "Llava1_5Config7B", - "Llava1_5Config13B", + "Llava15Config7B", + "Llava15Config13B", "LlavaModel", + "LlavaNextTaskEncoder", "MLlamaModel", "MLlamaModelConfig", "CrossAttentionTextConfig", diff --git a/nemo/collections/vlm/layer_specs.py b/nemo/collections/vlm/layer_specs.py new file mode 100644 index 000000000000..11c4d697a5aa --- /dev/null +++ b/nemo/collections/vlm/layer_specs.py @@ -0,0 +1,131 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.dot_product_attention import DotProductAttention +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules + +try: + from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TEDotProductAttention, + TELayerNormColumnParallelLinear, + TENorm, + TERowParallelLinear, + ) + + HAVE_TE = True +except ImportError: + HAVE_TE = False + +try: + from megatron.core.fusions.fused_layer_norm import FusedLayerNorm + + HAVE_APEX = True + LNImpl = FusedLayerNorm +except ImportError: + import warnings + + from megatron.core.transformer.torch_layer_norm import WrappedTorchLayerNorm + + warnings.warn(f'Apex is not installed. Falling back to Torch LayerNorm') + LNImpl = WrappedTorchLayerNorm + + +def get_layer_spec(is_vit, normalization) -> ModuleSpec: + """Transformer Layer Spec""" + attn_mask_type = AttnMaskType.no_mask if is_vit else AttnMaskType.causal + if normalization == "LayerNorm": + norm = LNImpl + elif normalization == "RMSNorm": + norm = TENorm + else: + raise RuntimeError("unknown normalization", normalization) + + mlp = get_mlp_module_spec(use_te=False) # doesn't include norm. 
+ + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=norm, + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": attn_mask_type}, + submodules=SelfAttentionSubmodules( + linear_qkv=ColumnParallelLinear, + core_attention=DotProductAttention, + linear_proj=RowParallelLinear, + q_layernorm=IdentityOp, + k_layernorm=IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=norm, + mlp=mlp, + mlp_bda=get_bias_dropout_add, + ), + ) + + +def get_layer_spec_te(is_vit=False) -> ModuleSpec: + """Transformer Layer Spec w/ TE Modules""" + attn_mask_type = AttnMaskType.no_mask if is_vit else AttnMaskType.causal + + mlp = get_norm_mlp_module_spec_te() + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": attn_mask_type}, + submodules=SelfAttentionSubmodules( + linear_qkv=TELayerNormColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + q_layernorm=IdentityOp, + k_layernorm=IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=IdentityOp, + mlp=mlp, + mlp_bda=get_bias_dropout_add, + ), + ) + + +def get_mlp_module_spec(use_te: bool = True) -> ModuleSpec: + """MLP Submodule Spec""" + # Dense MLP w/ or w/o TE modules. + return ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear, + linear_fc2=TERowParallelLinear if use_te else RowParallelLinear, + ), + ) + + +def get_norm_mlp_module_spec_te() -> ModuleSpec: + """Norm + MLP Submodule Spec""" + return ModuleSpec( + module=MLP, + submodules=MLPSubmodules(linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear), + ) diff --git a/nemo/collections/vlm/mllama/data/mock.py b/nemo/collections/vlm/mllama/data/mock.py index a88838b0025f..4d078c745492 100644 --- a/nemo/collections/vlm/mllama/data/mock.py +++ b/nemo/collections/vlm/mllama/data/mock.py @@ -25,6 +25,26 @@ class MockDataModule(pl.LightningDataModule): + """ + Mock DataModule for testing and development. + Generates synthetic data for training, validation, and testing purposes. + + Args: + seq_length (int): Sequence length for the generated data. + decoder_seq_length (Optional[int]): Decoder sequence length if applicable, used in pp. + vocab_size (int): Size of the vocabulary of tokenizer. + crop_size (Tuple[int, int]): Image crop size (height, width). + micro_batch_size (int): Micro batch size for data loading. + global_batch_size (int): Global batch size across all processes. + rampup_batch_size (Optional[List[int]]): Batch size ramp-up configuration. + num_train_samples (int): Number of training samples to generate. + num_val_samples (int): Number of validation samples to generate. + num_test_samples (int): Number of test samples to generate. + num_workers (int): Number of workers for data loading. + pin_memory (bool): Whether to pin memory for data loading. + persistent_workers (bool): Whether workers should remain persistent. 
+ """ + def __init__( self, seq_length: int = 2048, @@ -34,6 +54,8 @@ def __init__( micro_batch_size: int = 4, global_batch_size: int = 8, rampup_batch_size: Optional[List[int]] = None, + tokenizer: Optional = None, + image_processor: Optional = None, num_train_samples: int = 10_000, num_val_samples: int = 10_000, num_test_samples: int = 10_000, @@ -52,6 +74,8 @@ def __init__( self.persistent_workers = persistent_workers self.vocab_size = vocab_size self.crop_size = crop_size + self.tokenizer = tokenizer + self.image_processor = image_processor self.data_sampler = MegatronDataSampler( seq_len=self.seq_length, @@ -62,6 +86,7 @@ def __init__( ) def setup(self, stage: str = "") -> None: + """Set up datasets for the specified stage.""" self._train_ds = _MockMLlamaDataset( self.vocab_size, self.crop_size, "train", self.num_train_samples, self.decoder_seq_length ) @@ -73,21 +98,25 @@ def setup(self, stage: str = "") -> None: ) def train_dataloader(self) -> TRAIN_DATALOADERS: + """Returns the DataLoader for training.""" if not hasattr(self, "_train_ds"): self.setup() return self._create_dataloader(self._train_ds) def val_dataloader(self) -> EVAL_DATALOADERS: + """Returns the DataLoader for validation.""" if not hasattr(self, "_validation_ds"): self.setup() return self._create_dataloader(self._validation_ds) def test_dataloader(self) -> EVAL_DATALOADERS: + """Returns the DataLoader for testing.""" if not hasattr(self, "_test_ds"): self.setup() return self._create_dataloader(self._test_ds) def _create_dataloader(self, dataset, **kwargs) -> DataLoader: + """Creates a DataLoader for the specified dataset.""" return DataLoader( dataset, num_workers=self.num_workers, @@ -99,6 +128,18 @@ def _create_dataloader(self, dataset, **kwargs) -> DataLoader: class _MockMLlamaDataset(Dataset): + """ + Mock dataset for generating synthetic data with text and image components. + + Args: + vocab_size (int): Vocabulary size for text data. + crop_size (Tuple[int, int]): Image crop size (height, width). + name (str): Name of the dataset split ('train', 'valid', 'test'). + num_samples (int): Number of samples in the dataset. + seq_length (int): Sequence length for the text data. + seed (int): Seed for random number generation. + """ + def __init__( self, vocab_size, @@ -123,13 +164,16 @@ def __init__( self.position_ids = torch.arange(self.seq_length, dtype=torch.int64) def __len__(self) -> int: + """Returns the number of samples in the dataset.""" return self.length def _get_text(self, idx: int) -> np.ndarray: + """Generates a random sequence of integers representing text tokens.""" np_gen = np.random.default_rng(seed=(self.seed + idx)) return np_gen.integers(self.vocab_size, size=[self.seq_length], dtype=np.int64) def __getitem__(self, idx) -> Dict[str, torch.Tensor]: + """Generates a single data sample.""" # Generate data of the expected size and datatype (based on GPTDataset). 
np_gen = np.random.default_rng(seed=(self.seed + idx)) tokens = torch.from_numpy(np_gen.integers(self.vocab_size, size=[self.seq_length + 1], dtype=np.int64)) @@ -142,8 +186,8 @@ def __getitem__(self, idx) -> Dict[str, torch.Tensor]: return { "images": images, - "masks": [[5, 512]], - "num_chunks": [4], + "masks": torch.tensor([[5, 512]]), + "num_chunks": torch.tensor([4]), "tokens": tokens, "aspect_ratio_ids": aspect_ratio_ids, "loss_mask": self.loss_mask, diff --git a/nemo/collections/vlm/mllama/model/base.py b/nemo/collections/vlm/mllama/model/base.py index 7dd84fefbb18..9279936e23d7 100644 --- a/nemo/collections/vlm/mllama/model/base.py +++ b/nemo/collections/vlm/mllama/model/base.py @@ -40,13 +40,15 @@ from nemo.collections.vlm.mllama.model.language import CrossAttentionTextModel from nemo.collections.vlm.mllama.model.utils import _generate_cross_attention_mask, _pad_attention_masks from nemo.collections.vlm.mllama.model.vision import VisionEncoder +from nemo.collections.vlm.neva.model.base import MODEL_CONFIG_ATTR from nemo.lightning import get_vocab_size, io from nemo.lightning.megatron_parallel import MaskedTokenLossReduction from nemo.lightning.pytorch.optim import MegatronOptimizerModule, OptimizerModule from nemo.utils import logging -def llama_data_step(dataloader_iter) -> Dict[str, torch.Tensor]: +def mllama_data_step(dataloader_iter) -> Dict[str, torch.Tensor]: + """Mllama data step.""" from megatron.core import parallel_state # Based on: https://github.com/NVIDIA/Megatron-LM/blob/main/pretrain_gpt.py#L87 @@ -95,7 +97,8 @@ def llama_data_step(dataloader_iter) -> Dict[str, torch.Tensor]: return output -def llama_forward_step(model, batch) -> torch.Tensor: +def mllama_forward_step(model, batch) -> torch.Tensor: + """Mllama model forward step.""" forward_config = { "batch_images": batch["batch_images"], "batch_masks": batch["batch_masks"], @@ -113,13 +116,15 @@ def llama_forward_step(model, batch) -> torch.Tensor: def set_input_tensor(self, tensor): + """Placeholder for `set_input_tensor` method for PP implementation.""" pass @dataclass class CrossAttentionVisionConfig(TransformerConfig, io.IOMixin): - # core params + """Configuration for llama vision model.""" + # core params bias_activation_fusion: bool = True bias_dropout_add_fusion: bool = True @@ -149,9 +154,11 @@ class CrossAttentionVisionConfig(TransformerConfig, io.IOMixin): @property def max_aspect_ratio_id(self) -> int: + # pylint: disable=C0115,C0116 return len(self.supported_aspect_ratios) def configure_model(self) -> "CrossAttentionVisionModel": + """Configure mllama vision model.""" return CrossAttentionVisionModel( self, ) @@ -159,6 +166,10 @@ def configure_model(self) -> "CrossAttentionVisionModel": @dataclass class CrossAttentionTextConfig(Llama31Config): + """ + Configuration for llama model with cross-attention layers to take in multimodal features. 
+ """ + rotary_base: int = 500_000 seq_length: int = 8192 num_layers: int = 32 @@ -170,12 +181,14 @@ class CrossAttentionTextConfig(Llama31Config): apply_rope_fusion: bool = False def _init_fusion_schedule(self, num_layers: int) -> List[int]: - llama_layers = list(range(self.num_layers)) + """Initialize self-attention layer / cross-attention layer fusion schedule""" + mllama_layers = list(range(self.num_layers)) # uniformly spread the layers - k = math.ceil(len(llama_layers) / num_layers) - return llama_layers[::-1][::k][:num_layers][::-1] + k = math.ceil(len(mllama_layers) / num_layers) + return mllama_layers[::-1][::k][:num_layers][::-1] def configure_model(self, tokenizer, pre_process=True, post_process=True): + """Configure mllama text model.""" self.fusion_schedule = self._init_fusion_schedule(self.num_cross_attention_layers) vp_size = self.virtual_pipeline_model_parallel_size if vp_size: @@ -224,6 +237,8 @@ def configure_model(self, tokenizer, pre_process=True, post_process=True): @dataclass class MLlamaModelConfig(TransformerConfig, io.IOMixin): + """Combined configuration for multimodal vision-language model.""" + language_model_config: Optional[CrossAttentionTextConfig] = None vision_model_config: Optional[CrossAttentionVisionConfig] = None @@ -236,42 +251,16 @@ class MLlamaModelConfig(TransformerConfig, io.IOMixin): language_model_from_pretrained: Optional[str] = None # TODO vision_model_from_pretrained: Optional[str] = None # TODO - forward_step_fn: Callable = llama_forward_step - data_step_fn: Callable = llama_data_step + forward_step_fn: Callable = mllama_forward_step + data_step_fn: Callable = mllama_data_step def __post_init__(self): - model_config_attr = [ - 'num_layers', - 'hidden_size', - 'num_attention_heads', - 'num_query_groups', - 'ffn_hidden_size', - 'kv_channels', - 'hidden_dropout', - 'attention_dropout', - 'fp32_residual_connection', - 'apply_residual_connection_post_layernorm', - 'layernorm_epsilon', - 'layernorm_zero_centered_gamma', - 'add_bias_linear', - 'add_qkv_bias', - 'gated_linear_unit', - 'activation_func', - 'activation_func_fp8_input_store', - 'num_moe_experts', - 'rotary_interleaved', - 'window_size', - 'normalization', - 'qk_layernorm', - 'test_mode', - 'calculate_per_token_loss', - ] - if self.language_model_config is not None: - for attr in model_config_attr: + for attr in MODEL_CONFIG_ATTR: setattr(self, attr, getattr(self.language_model_config, attr)) def configure_model(self, tokenizer) -> "MLlamaBaseModel": + """Configure mllama model.""" from megatron.core import parallel_state as ps self.language_model_config.tensor_model_parallel_size = self.tensor_model_parallel_size @@ -300,6 +289,8 @@ def configure_model(self, tokenizer) -> "MLlamaBaseModel": class CrossAttentionVisionModel(MegatronModule): + """Mllama vision model.""" + def __init__(self, config) -> None: super().__init__(config=config) return_intermediate = "3,7,15,23,30" @@ -329,6 +320,7 @@ def __init__(self, config) -> None: self.vision_projection.encoder.skip_bias_add = False # Temporary fix for a MCore side bug def forward(self, images: torch.Tensor, aspect_ratio_ids: torch.Tensor) -> torch.Tensor: + """Forward.""" # vision_tokens: (B, T, D) # aspect_ratio_ids: (B, 1) # h: (B, T, D) @@ -339,10 +331,13 @@ def forward(self, images: torch.Tensor, aspect_ratio_ids: torch.Tensor) -> torch return vision_tokens def set_input_tensor(self, tensor): + # pylint: disable=C0115,C0116 pass class MLlamaBaseModel(MegatronModule): + """Mllama base model combining vision and text models with 
cross-attention.""" + def __init__( self, config: MLlamaModelConfig, @@ -382,10 +377,6 @@ def __init__( self.patch_size = 14 self.image_res = vision_model_config.vision_chunk_size self.max_num_chunks = vision_model_config.vision_max_num_chunks - logging.warning("[WARNING] NeMo Mllama will always pad images to max number of tiles. A fix is coming soon!") - - def setup_cache(self, max_batch_size: int, dtype: torch.dtype): - self.language_model.setup_cache(max_batch_size, dtype) def compute_xattn_caches_masks( self, @@ -395,6 +386,7 @@ def compute_xattn_caches_masks( num_chunks: torch.Tensor, total_len: int, ) -> Tuple[List, torch.Tensor, torch.Tensor]: + """Compute xattn caches masks used in text model.""" bsz, nimg, nchunk, ntok, image_token_dim = vision_orig_shape xattn_caches = [ @@ -434,6 +426,7 @@ def forward( full_text_row_masked_out_mask: Optional[torch.Tensor] = None, xattn_caches: Optional[List] = None, ) -> torch.Tensor: + """Forward.""" if xattn_caches is None: bsz, max_num_images = batch_images.size(0), batch_images.size(1) vision_orig_shape = ( @@ -444,8 +437,8 @@ def forward( self.config.hidden_size, ) skip_vision_encoder = False - num_chunks[num_chunks > 0] = self.max_num_chunks if max_num_images == 0: + num_chunks[num_chunks > 0] = self.max_num_chunks skip_vision_encoder = True if self.encoder_hidden_state is not None: @@ -515,6 +508,8 @@ def set_input_tensor(self, input_tensor) -> None: class MLlamaModel(L.LightningModule, io.IOMixin, io.ConnectorMixin, fn.FNMixin): + """Lightning Module for the MLlama model.""" + def __init__( self, config: MLlamaModelConfig, @@ -532,6 +527,7 @@ def __init__( self._validation_loss_reduction = None def configure_model(self) -> None: + """Configure mllama model""" if not hasattr(self, "module"): self.module: MLlamaBaseModel = self.config.configure_model(self.tokenizer) @@ -548,7 +544,7 @@ def forward( full_text_row_masked_out_mask: Optional[torch.Tensor] = None, xattn_caches: Optional[torch.Tensor] = None, ) -> torch.Tensor: - + """Forward.""" output_tensor = self.module( position_ids=position_ids, tokens=tokens, @@ -565,22 +561,26 @@ def forward( return output_tensor def data_step(self, dataloader_iter) -> Dict[str, torch.Tensor]: + # pylint: disable=C0115,C0116 return self.config.data_step_fn(dataloader_iter) def forward_step(self, batch) -> torch.Tensor: + # pylint: disable=C0115,C0116 return self.config.forward_step_fn(self, batch) def training_step(self, batch, batch_idx=None) -> torch.Tensor: + # pylint: disable=C0115,C0116 # In mcore the loss-function is part of the forward-pass (when labels are provided) return self.forward_step(batch) def validation_step(self, batch, batch_idx=None) -> torch.Tensor: + # pylint: disable=C0115,C0116 # In mcore the loss-function is part of the forward-pass (when labels are provided) - return self.forward_step(batch) @property def training_loss_reduction(self) -> MaskedTokenLossReduction: + # pylint: disable=C0115,C0116 if not self._training_loss_reduction: self._training_loss_reduction = MaskedTokenLossReduction() @@ -588,6 +588,7 @@ def training_loss_reduction(self) -> MaskedTokenLossReduction: @property def validation_loss_reduction(self) -> MaskedTokenLossReduction: + # pylint: disable=C0115,C0116 if not self._validation_loss_reduction: self._validation_loss_reduction = MaskedTokenLossReduction(validation_step=True) @@ -599,8 +600,8 @@ def validation_loss_reduction(self) -> MaskedTokenLossReduction: "MLlamaModelConfig", "CrossAttentionTextConfig", "CrossAttentionVisionConfig", - "llama_data_step", - 
"llama_forward_step", + "mllama_data_step", + "mllama_forward_step", "transformer_engine_layer_spec", "local_layer_spec", ] diff --git a/nemo/collections/vlm/mllama/model/language.py b/nemo/collections/vlm/mllama/model/language.py index b8985e53c54c..5d4cc2e09f21 100644 --- a/nemo/collections/vlm/mllama/model/language.py +++ b/nemo/collections/vlm/mllama/model/language.py @@ -60,6 +60,10 @@ @dataclass class MLlamaCrossAttentionSubmodules: + """ + Defines the submodules required for cross-attention layers in the Llama architecture. + """ + linear_q: Union[ModuleSpec, type] = None linear_kv: Union[ModuleSpec, type] = None core_attention: Union[ModuleSpec, type] = None @@ -69,6 +73,10 @@ class MLlamaCrossAttentionSubmodules: class CrossAttentionTextModel(MCoreGPTModel): + """ + GPT-based model with integrated cross-attention layers for multimodal tasks. + """ + def __init__( self, config: TransformerConfig, @@ -122,6 +130,7 @@ def __init__( self._thresh = self.num_frozen_embeddings - 1 def get_partially_trainable_embedding(self, x): + """Get word embedding w/ few extra learnable tokens.""" xz = torch.zeros_like(x, device=x.device) oz = torch.ones_like(x, device=x.device) x_orig = torch.minimum(x, torch.tensor(self._thresh, device=x.device)) @@ -148,7 +157,7 @@ def forward( packed_seq_params: PackedSeqParams = None, extra_block_kwargs: dict = None, ) -> Tensor: - + """Forward.""" # Decoder embedding. if decoder_input is not None: pass @@ -171,6 +180,9 @@ def forward( ) rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) + dtype = decoder_input.dtype + cross_attention_bias = cross_attention_masks.to(dtype) * torch.finfo(dtype).min + # Run decoder. hidden_states = self.decoder( hidden_states=decoder_input, @@ -178,9 +190,10 @@ def forward( inference_params=inference_params, rotary_pos_emb=rotary_pos_emb, packed_seq_params=packed_seq_params, - cross_attention_masks=cross_attention_masks, + cross_attention_masks=None, full_text_row_masked_out_mask=full_text_row_masked_out_mask, xattn_caches=xattn_caches, + cross_attention_bias=cross_attention_bias, **(extra_block_kwargs or {}), ) @@ -203,6 +216,10 @@ def forward( class CrossAttentionTransformerBlock(TransformerBlock): + """ + Transformer block with integrated cross-attention layers for multimodal tasks. + """ + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -220,7 +237,7 @@ def __init__(self, *args, **kwargs): submodules=TransformerLayerSubmodules( cross_attention=ModuleSpec( module=MLlamaCrossAttention, - params={"attn_mask_type": AttnMaskType.arbitrary}, + params={"attn_mask_type": AttnMaskType.no_mask}, submodules=MLlamaCrossAttentionSubmodules( linear_q=TELayerNormColumnParallelLinear, # This wraps attention_norm before attention linear_kv=TEColumnParallelLinear, @@ -250,6 +267,7 @@ def __init__(self, *args, **kwargs): assert len(self.xattn_layers) == len(self.layers), 'Check PP implementation for cross attention layers!' 
def _get_layer_offset(self): + """Get correct layer offset when encoder pipeline parallel size > 0.""" encoder_pipeline_model_parallel_size = getattr(self.config, "encoder_pipeline_model_parallel_size", 0) decoder_pipeline_model_parallel_rank = ( parallel_state.get_pipeline_model_parallel_rank() - encoder_pipeline_model_parallel_size @@ -264,9 +282,12 @@ def forward( cross_attention_masks: Tensor = None, full_text_row_masked_out_mask: Tensor = None, rotary_pos_emb: Tensor = None, + attention_bias: Tensor = None, + cross_attention_bias: Tensor = None, inference_params: InferenceParams = None, packed_seq_params: PackedSeqParams = None, ): + """Forward.""" # hidden_states (float): [s, b, h] # attention_mask (bool): [1, 1, s, s] @@ -324,6 +345,7 @@ def forward( xattn_cache=xattn_caches[l_no], full_text_row_masked_out_mask=full_text_row_masked_out_mask, rotary_pos_emb=rotary_pos_emb, + cross_attention_bias=cross_attention_bias, inference_params=inference_params, packed_seq_params=packed_seq_params, ) @@ -331,6 +353,7 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, inference_params=inference_params, packed_seq_params=packed_seq_params, ) @@ -361,6 +384,7 @@ def forward( def sharded_state_dict( self, prefix: str = '', sharded_offsets: tuple = (), metadata: dict = None ) -> ShardedStateDict: + """Update shareded state dict for cross-attention layers""" sharded_state_dict = {} layer_prefix = f'{prefix}layers.' @@ -399,6 +423,10 @@ def sharded_state_dict( class CrossAttentionTransformerLayer(TransformerLayer): + """ + Transformer layer with cross-attention for integration. + """ + def __init__( self, config: TransformerConfig, @@ -417,6 +445,7 @@ def __init__( self.gate_ffn = nn.Parameter(torch.zeros(1, dtype=self.config.params_dtype)) def compute_xattn_kv_cache(self, xattn_tokens: Tensor) -> Tensor: + """Compute cross-attention kv cahce.""" return self.cross_attention._compute_xattn_kv_cache(xattn_tokens) def forward( @@ -426,9 +455,11 @@ def forward( xattn_cache=None, full_text_row_masked_out_mask=None, rotary_pos_emb=None, + cross_attention_bias=None, inference_params=None, packed_seq_params=None, ): + """Forward.""" # hidden_states: [s, b, h] # Residual connection. @@ -444,6 +475,7 @@ def forward( xattn_cache=xattn_cache, full_text_row_masked_out_mask=full_text_row_masked_out_mask, rotary_pos_emb=rotary_pos_emb, + cross_attention_bias=cross_attention_bias, inference_params=inference_params, ) @@ -507,11 +539,13 @@ def __call__( return hidden_states, None def compute_xattn_kv_cache(self, xattn_tokens: Tensor) -> Optional[Tensor]: + # pylint: disable=C0115,C0116 return None class MLlamaCrossAttention(Attention): - """Cross-attention layer class for Llama VLM support + """ + Cross-attention layer for Llama multimodal tasks. Cross-attention layer takes input with size [s, b, h] and context with size [s, b, h] and returns output of the same size. 
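A recurring change in these mllama hunks (and in vision.py further below) is passing an additive attention bias instead of a boolean mask: blocked positions contribute torch.finfo(dtype).min to the attention logits, mirroring `cross_attention_masks.to(dtype) * torch.finfo(dtype).min` above, and the updated forward calls route it through the new attention_bias / cross_attention_bias arguments. A self-contained illustration of the conversion (tensor shapes and names are illustrative only):

    import torch

    def mask_to_additive_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        # mask holds 1.0 where attention must be blocked and 0.0 where it is allowed.
        return mask.to(dtype) * torch.finfo(dtype).min

    mask = torch.tensor([[0.0, 1.0],
                         [0.0, 0.0]])
    bias = mask_to_additive_bias(mask, torch.bfloat16)
    # Adding `bias` to the pre-softmax attention scores drives blocked positions to ~0 probability.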
@@ -574,6 +608,7 @@ def __init__( ) def get_key_value_tensors(self, key_value_states): + """Get key value tensors.""" mixed_kv, _ = self.linear_kv(key_value_states) # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] @@ -590,7 +625,7 @@ def get_key_value_tensors(self, key_value_states): return key, value def get_query_tensor(self, hidden_states): - + """ "Get query tensor.""" # Attention head [sq, b, h] --> [sq, b, hp] query, _ = self.linear_q(hidden_states) @@ -607,6 +642,7 @@ def get_query_tensor(self, hidden_states): return query def get_query_key_value_tensors(self, hidden_states, key_value_states): + """Get query key value tensors.""" query = self.get_query_tensor(hidden_states) key, value = self.get_key_value_tensors(key_value_states) return query, key, value @@ -619,8 +655,17 @@ def forward( full_text_row_masked_out_mask=None, inference_params=None, rotary_pos_emb=None, + rotary_pos_cos=None, + rotary_pos_sin=None, + cross_attention_bias=None, packed_seq_params=None, ): + """Forward.""" + # hidden_states: [sq, b, h] + if self.config.flash_decode: + rotary_pos_emb = None + else: + assert rotary_pos_cos is None and rotary_pos_sin is None # For self attention we just duplicate the rotary_pos_emb if it isn't already if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple): @@ -637,8 +682,8 @@ def forward( # =================================================== # Adjust key, value, and rotary_pos_emb for inference # =================================================== - key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference( - inference_params, key, value, rotary_pos_emb + query, key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference( + inference_params, query, key, value, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin ) if packed_seq_params is not None: @@ -650,9 +695,6 @@ def forward( # core attention computation # ================================== - # In TE "True" means masked out - cross_attention_masks = torch.where(cross_attention_masks == 0, False, True) - if self.checkpoint_core_attention and self.training: core_attn_out = self._checkpointed_attention_forward( query, @@ -660,6 +702,7 @@ def forward( value, cross_attention_masks, attn_mask_type=attn_mask_type, + attention_bias=cross_attention_bias, packed_seq_params=packed_seq_params, ) else: @@ -669,6 +712,7 @@ def forward( value, cross_attention_masks, attn_mask_type=attn_mask_type, + attention_bias=cross_attention_bias, packed_seq_params=packed_seq_params, ) @@ -702,8 +746,22 @@ def apply_rope_scaling( high_freq_factor: int = 4, old_context_len: int = 8192, ): + """ + Apply scaling to rotary embeddings for positional encoding. + + Args: + inv_freq (Tensor): Tensor of inverse frequencies. + factor (int): Scaling factor for medium-to-high frequencies. + low_freq_factor (int): Factor for identifying low frequencies. + high_freq_factor (int): Factor for identifying high frequencies. + old_context_len (int): Original context length for scaling computation. + + Returns: + Tensor: Scaled inverse frequencies. + """ logging.info( - f"Apply rope scaling with factor={factor}, low_freq_factor={low_freq_factor}, high_freq_factor={high_freq_factor}, old_context_len={old_context_len}." + f"Apply rope scaling with factor={factor}, low_freq_factor={low_freq_factor}, " + f"high_freq_factor={high_freq_factor}, old_context_len={old_context_len}." 
) low_freq_wavelen = old_context_len / low_freq_factor diff --git a/nemo/collections/vlm/mllama/model/vision.py b/nemo/collections/vlm/mllama/model/vision.py index f023cc7bf943..bb58ad093cd6 100644 --- a/nemo/collections/vlm/mllama/model/vision.py +++ b/nemo/collections/vlm/mllama/model/vision.py @@ -120,15 +120,16 @@ def build_encoder_attention_mask( torch.Tensor: Tensor containing the attention mask. """ masks = [] + dtype = x.dtype for ar_id in ar_ids: arx = supported_aspect_ratios[ar_id - 1] mask_i = torch.ones((num_chunks, x.shape[1] // num_chunks), device=x.device) mask_i[: arx[0] * arx[1], :ntok] = 0 mask_i = mask_i.view(num_chunks * x.shape[1] // num_chunks, -1) - mask_i = (mask_i @ mask_i.T).type(torch.bool) + mask_i = mask_i @ mask_i.T mask_i = mask_i.unsqueeze(0) masks.append(mask_i) - masks = torch.stack(masks) + masks = torch.stack(masks).to(dtype) * torch.finfo(dtype).min return masks @@ -197,6 +198,7 @@ def forward_with_return_intermediate( context: Tensor = None, context_mask: Tensor = None, rotary_pos_emb: Tensor = None, + attention_bias: Tensor = None, inference_params: InferenceParams = None, packed_seq_params: PackedSeqParams = None, return_intermediate: List[int] = None, @@ -253,6 +255,7 @@ def forward_with_return_intermediate( context=context, context_mask=context_mask, rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, packed_seq_params=packed_seq_params, ) else: @@ -269,6 +272,7 @@ def forward_with_return_intermediate( context=context, context_mask=context_mask, rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, inference_params=inference_params, packed_seq_params=packed_seq_params, ) @@ -506,6 +510,7 @@ def forward( attention_mask=attention_mask, inference_params=inference_params, rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, packed_seq_params=packed_seq_params, ) @@ -690,11 +695,12 @@ def forward(self, images: torch.Tensor, ar_ids: torch.Tensor) -> torch.Tensor: x = x.view(bsz * num_concurrent_media, -1, dim) npad, attn_mask = 0, None - attn_mask = build_encoder_attention_mask(x, ar_ids, ntok, num_chunks, self.config.supported_aspect_ratios) + attn_bias = build_encoder_attention_mask(x, ar_ids, ntok, num_chunks, self.config.supported_aspect_ratios) x = x.transpose(0, 1).contiguous() x, int_x = self.transformer( hidden_states=x, attention_mask=attn_mask, + attention_bias=attn_bias, return_intermediate=self.return_intermediate, ) @@ -709,6 +715,7 @@ def forward(self, images: torch.Tensor, ar_ids: torch.Tensor) -> torch.Tensor: x = self.global_transformer( hidden_states=x, attention_mask=None, + attention_bias=attn_bias, ) x = x.transpose(0, 1) x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok + npad, dim) diff --git a/nemo/collections/vlm/neva/data/__init__.py b/nemo/collections/vlm/neva/data/__init__.py index f210d01a06fd..df9716fe5610 100644 --- a/nemo/collections/vlm/neva/data/__init__.py +++ b/nemo/collections/vlm/neva/data/__init__.py @@ -14,6 +14,7 @@ from nemo.collections.vlm.neva.data.config import DataConfig, ImageDataConfig, VideoDataConfig from nemo.collections.vlm.neva.data.lazy import NevaLazyDataModule +from nemo.collections.vlm.neva.data.llava_next_energon import LlavaNextTaskEncoder from nemo.collections.vlm.neva.data.mock import MockDataModule as NevaMockDataModule from nemo.collections.vlm.neva.data.multimodal_tokens import ImageToken, MultiModalToken, VideoToken @@ -26,4 +27,5 @@ "MultiModalToken", "ImageToken", "VideoToken", + "LlavaNextTaskEncoder", ] diff --git 
a/nemo/collections/vlm/neva/data/conversation.py b/nemo/collections/vlm/neva/data/conversation.py index d78d3bd28acb..58953dc53b7a 100644 --- a/nemo/collections/vlm/neva/data/conversation.py +++ b/nemo/collections/vlm/neva/data/conversation.py @@ -77,7 +77,6 @@ def process_chat_template(self, tokenizer_name_or_path, messages): def get_prompt(self): messages = self.messages - messages = self.process_prompt_with_images(messages) if self.sep_style == SeparatorStyle.SINGLE: ret = self.system + self.sep @@ -100,6 +99,8 @@ def get_prompt(self): if type(message) is tuple: message, _, _ = message ret += role + ": " + message + seps[i % 2] + # Add space to make sure the labels can be correctly generated. + self.messages[i][1] = " " + self.messages[i][1] else: ret += role + ":" @@ -155,7 +156,6 @@ def get_prompt(self): ret = self.process_chat_template(tokenizer_name_or_path, messages) elif self.sep_style == SeparatorStyle.MLLAMA: - """ """ tokenizer_name_or_path = self.tokenizer_name_or_path or "meta-llama/Llama-3.2-11B-Vision-Instruct" ret = self.process_chat_template(tokenizer_name_or_path, messages) diff --git a/nemo/collections/vlm/neva/data/lazy.py b/nemo/collections/vlm/neva/data/lazy.py index fddaca14faeb..5bc2cbe0458e 100644 --- a/nemo/collections/vlm/neva/data/lazy.py +++ b/nemo/collections/vlm/neva/data/lazy.py @@ -251,7 +251,7 @@ def __init__( data_config, tokenizer, image_processor, - sequence_length, + sequence_length=None, ): super().__init__() if data_path is not None: @@ -497,6 +497,7 @@ def __init__( weights: Optional[List[float]] = None, data_config: Optional[DataConfig] = ImageDataConfig, seq_length: int = 2048, + decoder_seq_length: Optional[int] = None, tokenizer: Optional = None, image_processor: Optional = None, micro_batch_size: int = 4, @@ -523,6 +524,7 @@ def __init__( self.weights = weights self.data_config = data_config self.seq_length = seq_length + self.decoder_seq_length = decoder_seq_length self.tokenizer = tokenizer self.image_processor = image_processor self.num_train_samples = num_train_samples @@ -538,13 +540,15 @@ def __init__( if tokenizer is None or image_processor is None: logging.warning(f"Processor and tokenizer are not provided! Fall back to `llava-hf/llava-1.5-7b-hf`.") from transformers import AutoProcessor + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") - self.tokenizer = tokenizer or processor.tokenizer + self.tokenizer = tokenizer or AutoTokenizer("llava-hf/llava-1.5-7b-hf") self.image_processor = image_processor or processor.image_processor self.data_sampler = MegatronDataSampler( seq_len=self.seq_length, + decoder_seq_len=self.decoder_seq_length, micro_batch_size=micro_batch_size, global_batch_size=global_batch_size, dataloader_type="cyclic", diff --git a/nemo/collections/vlm/neva/data/llava_next_energon.py b/nemo/collections/vlm/neva/data/llava_next_energon.py new file mode 100644 index 000000000000..c45ee50e5be3 --- /dev/null +++ b/nemo/collections/vlm/neva/data/llava_next_energon.py @@ -0,0 +1,161 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Dict, List + +import torch +from megatron.energon import VQASample, batch_list, batch_pad_stack +from torch.nn.utils.rnn import pad_sequence + +from nemo.collections.multimodal.data.energon.config import ImageTextRawBatch, ImageTextSample, MultiModalSampleConfig +from nemo.collections.multimodal.data.energon.sample_encoder import SampleEncoder, VQASampleEncoder +from nemo.collections.multimodal.data.energon.task_encoder import MultiModalTaskEncoder +from nemo.utils import logging + + +class LlavaNextTextSample(ImageTextSample): + num_media_tiles: int = 0 + + +@dataclass +class LlavaNextTextRawBatch(ImageTextRawBatch): + num_media_tiles: List[int] = field(default_factory=list) + + +class LlavaNextSampleEncoder(VQASampleEncoder): + def __init__(self, tokenizer, image_processor, multimodal_sample_config=MultiModalSampleConfig()): + """ + Initialize the LlavaNextSampleEncoder, inherited from VQASampleEncoder for multimodal samples + focused on VQA-style data to support LLaVANeXT + + Parameters: + tokenizer (Tokenizer): The HF tokenizer used for processing text. + image_processor (ImageProcessor): The HF image processor used for preprocessing images. + multimodal_sample_config (MultiModalSampleConfig, optional): Configuration object for multimodal samples. + Defaults to MultiModalSampleConfig(). + """ + super().__init__(tokenizer, image_processor, multimodal_sample_config) + + def process_image(self, image): + """ + Process and prepare an image sample for encoding. + + This method preprocesses the image using the HF image_processor, converting it to + a tensor. + + Parameters: + image: The input image to be processed. + + Returns: + torch.Tensor: The processed image tensor. + """ + image_array = self.image_processor.preprocess(image, return_tensors='pt', do_rescale=False)['pixel_values'][0] + return image_array + + def encode(self, input_sample: VQASample, output_sample: LlavaNextTextSample): + """ + Encode a single sample into a format suitable for model input. + + This method prepares the conversation prompt, tokenizes it, and processes + the associated image. It fills the output sample with tokens, labels, loss mask, + and other required fields for multimodal processing. + + Parameters: + input_sample (VQASample): The input VQA sample containing an image and conversation text. + output_sample (LlavaNextTextSample): The output sample structure where encoded results are stored. + + Returns: + LlavaNextTextSample: The encoded output sample, containing processed tokens, labels, + images, loss masks, and metadata. 
+ """ + conversation_prompt = self.apply_prompt_template(input_sample) + logging.debug(f"task encoder encode_sample conversation_prompt {conversation_prompt}") + # tokenize prompt + tokens = self.tokenize(conversation_prompt) + labels = self.compute_labels(tokens, input_sample) + + tokens = tokens[:-1].contiguous() + labels = labels[1:].contiguous() + logging.debug(f"task encoder encode_sample after tokenize prompt tokens {tokens}") + logging.debug(f"task encoder encode_sample lables {labels}") + loss_mask = self.compute_loss_mask(labels) + processed_image = self.process_image(input_sample.image) + output_sample.__key__ = input_sample.__key__ + output_sample.images = processed_image + output_sample.tokens = tokens + output_sample.labels = labels + output_sample.loss_mask = loss_mask + output_sample.num_media_tiles = processed_image.shape[0] + return output_sample + + +class LlavaNextTaskEncoder(MultiModalTaskEncoder): + def __init__(self, tokenizer, image_processor, multimodal_sample_config): + """ + Initialize the LlavaNextTaskEncoder. + + This encoder extends MultiModalTaskEncoder to specifically handle LlavaNeXT, + overriding encoders for VQA sample type. + + Parameters: + tokenizer (Tokenizer): The tokenizer for processing text data across sample types. + image_processor (ImageProcessor): The image processor for preprocessing images. + multimodal_sample_config (MultiModalSampleConfig): Configuration settings for multimodal samples. + """ + super().__init__(tokenizer, image_processor, multimodal_sample_config) + self.encoders: Dict[str, SampleEncoder] = { + VQASample.__name__: LlavaNextSampleEncoder(tokenizer, image_processor, multimodal_sample_config) + } + + def batch(self, samples: List[LlavaNextTextSample]) -> LlavaNextTextRawBatch: + """ + Batch multiple encoded samples into a single batch structure for model input. + + This method combines individual sample fields (keys, images, tokens, labels, etc.) and + pads or stacks them as needed to create a unified batch. + + Parameters: + samples (List[LlavaNextTextSample]): A list of LlavaNextTextSample instances to be batched. + + Returns: + LlavaNextTextRawBatch: A batch containing all input samples' images, tokens, labels, + loss masks, and other metadata prepared for model processing. 
+ """ + keys, images, tokens, labels, loss_mask, num_media_tiles = [], [], [], [], [], [] + for sample in samples: + keys.append(sample.__key__) + images.append(sample.images) + tokens.append(sample.tokens) + labels.append(sample.labels) + loss_mask.append(sample.loss_mask) + num_media_tiles.append(sample.num_media_tiles) + + batch_keys = batch_list(keys) + + batch_images = torch.cat(images, dim=0) + + batch_tokens = pad_sequence(tokens, batch_first=True) + batch_labels = pad_sequence(labels, batch_first=True) + + batch_loss_mask = batch_pad_stack(loss_mask) + batch_num_media_tiles = torch.tensor(batch_list(num_media_tiles), dtype=torch.int) + return LlavaNextTextRawBatch( + __keys__=batch_keys, + images=batch_images, + tokens=batch_tokens, + labels=batch_labels, + loss_mask=batch_loss_mask, + num_media_tiles=batch_num_media_tiles, + ) diff --git a/nemo/collections/vlm/neva/data/mock.py b/nemo/collections/vlm/neva/data/mock.py index ede06e9f5778..9e2308752641 100644 --- a/nemo/collections/vlm/neva/data/mock.py +++ b/nemo/collections/vlm/neva/data/mock.py @@ -23,26 +23,29 @@ from nemo.collections.vlm.neva.data.multimodal_tokens import IMAGE_TOKEN_INDEX from nemo.lightning.pytorch.plugins import MegatronDataSampler +from nemo.utils import logging class MockDataModule(pl.LightningDataModule): def __init__( self, seq_length: int = 2048, + decoder_seq_length: Optional[int] = None, tokenizer: Optional = None, image_processor: Optional = None, micro_batch_size: int = 4, global_batch_size: int = 8, rampup_batch_size: Optional[List[int]] = None, - num_train_samples: int = 10_000, - num_val_samples: int = 10_000, - num_test_samples: int = 10_000, + num_train_samples: int = 10_000_000, + num_val_samples: int = 10_000_000, + num_test_samples: int = 10_000_000, num_workers: int = 8, pin_memory: bool = True, persistent_workers: bool = False, ): super().__init__() self.seq_length = seq_length + self.decoder_seq_len = decoder_seq_length self.num_train_samples = num_train_samples self.num_val_samples = num_val_samples self.num_test_samples = num_test_samples @@ -51,13 +54,16 @@ def __init__( self.persistent_workers = persistent_workers if tokenizer is None or image_processor is None: + logging.warning(f"Processor or tokenizer are not provided! 
Fall back to `llava-hf/llava-1.5-7b-hf`.") from transformers import AutoProcessor + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") - self.tokenizer = tokenizer or processor.tokenizer + self.tokenizer = tokenizer or AutoTokenizer("llava-hf/llava-1.5-7b-hf") self.image_processor = image_processor or processor.image_processor self.data_sampler = MegatronDataSampler( seq_len=self.seq_length, + decoder_seq_len=self.decoder_seq_len, micro_batch_size=micro_batch_size, global_batch_size=global_batch_size, rampup_batch_size=rampup_batch_size, diff --git a/nemo/collections/vlm/neva/model/__init__.py b/nemo/collections/vlm/neva/model/__init__.py index 25842186ecfe..99862f97b9ed 100644 --- a/nemo/collections/vlm/neva/model/__init__.py +++ b/nemo/collections/vlm/neva/model/__init__.py @@ -19,16 +19,19 @@ NevaConfig, NevaModel, ) -from nemo.collections.vlm.neva.model.llava import Llava1_5Config7B, Llava1_5Config13B, LlavaConfig, LlavaModel +from nemo.collections.vlm.neva.model.llava import Llava15Config7B, Llava15Config13B, LlavaConfig, LlavaModel +from nemo.collections.vlm.neva.model.vit_config import CLIPViTL_14_336_Config, SigLIPViT400M_14_384_Config __all__ = [ "CLIPViTConfig", + "CLIPViTL_14_336_Config", + "SigLIPViT400M_14_384_Config", "HFCLIPVisionConfig", "MultimodalProjectorConfig", "NevaConfig", "NevaModel", "LlavaConfig", - "Llava1_5Config7B", - "Llava1_5Config13B", + "Llava15Config7B", + "Llava15Config13B", "LlavaModel", ] diff --git a/nemo/collections/vlm/neva/model/api.py b/nemo/collections/vlm/neva/model/api.py index 19e94c70381e..13444632464e 100644 --- a/nemo/collections/vlm/neva/model/api.py +++ b/nemo/collections/vlm/neva/model/api.py @@ -14,18 +14,18 @@ import lightning.pytorch as pl -from nemo.collections.vlm.neva.model import Llava1_5Config7B, Llava1_5Config13B, LlavaModel +from nemo.collections.vlm.neva.model import Llava15Config7B, Llava15Config13B, LlavaModel -def llava1_5_7b() -> pl.LightningModule: - return LlavaModel(Llava1_5Config7B()) +def llava15_7b() -> pl.LightningModule: + return LlavaModel(Llava15Config7B()) -def llava1_5_13b() -> pl.LightningModule: - return LlavaModel(Llava1_5Config13B()) +def llava15_13b() -> pl.LightningModule: + return LlavaModel(Llava15Config13B()) __all__ = [ - "llava1_5_7b", - "llava1_5_13b", + "llava15_7b", + "llava15_13b", ] diff --git a/nemo/collections/vlm/neva/model/base.py b/nemo/collections/vlm/neva/model/base.py index d4e578218ed2..360874152cf7 100644 --- a/nemo/collections/vlm/neva/model/base.py +++ b/nemo/collections/vlm/neva/model/base.py @@ -22,17 +22,20 @@ import torch.distributed import torch.nn.functional as F from megatron.core import dist_checkpointing +from megatron.core import parallel_state as ps +from megatron.core.enums import ModelType +from megatron.core.extensions.transformer_engine import TEDotProductAttention from megatron.core.inference_params import InferenceParams from megatron.core.models.multimodal.llava_model import LLaVAModel as MCoreLLaVAModel from megatron.core.models.vision.clip_vit_model import CLIPViTModel as MCoreCLIPViTModel from megatron.core.models.vision.multimodal_projector import MultimodalProjector as MCoreMultimodalProjector from megatron.core.optimizer import OptimizerConfig +from megatron.core.tensor_parallel import gather_from_sequence_parallel_region from megatron.core.transformer.custom_layers.transformer_engine import ( TEColumnParallelLinear, TENorm, TERowParallelLinear, ) -from 
megatron.core.transformer.enums import ModelType from megatron.core.transformer.mlp import MLP, MLPSubmodules from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.transformer_config import TransformerConfig @@ -41,15 +44,43 @@ from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec from nemo.collections.llm import fn -from nemo.collections.llm.gpt.model import local_layer_spec, transformer_engine_layer_spec +from nemo.collections.llm.gpt.model import transformer_engine_layer_spec from nemo.collections.llm.gpt.model.base import get_batch_on_this_context_parallel_rank, get_packed_seq_params -from nemo.collections.nlp.modules.common.megatron.module import MegatronModule -from nemo.collections.vlm.neva.data.multimodal_tokens import IGNORE_INDEX, IMAGE_TOKEN_INDEX +from nemo.collections.vlm.neva.data.multimodal_tokens import IMAGE_TOKEN_INDEX from nemo.lightning import io +from nemo.lightning.io.pl import ckpt_to_weights_subdir from nemo.lightning.megatron_parallel import MaskedTokenLossReductionWithLossMask from nemo.lightning.pytorch.optim import MegatronOptimizerModule, OptimizerModule from nemo.utils import logging +MODEL_CONFIG_ATTR = [ + 'num_layers', + 'hidden_size', + 'num_attention_heads', + 'num_query_groups', + 'ffn_hidden_size', + 'kv_channels', + 'hidden_dropout', + 'attention_dropout', + 'fp32_residual_connection', + 'apply_residual_connection_post_layernorm', + 'layernorm_epsilon', + 'layernorm_zero_centered_gamma', + 'add_bias_linear', + 'add_qkv_bias', + 'gated_linear_unit', + 'activation_func', + 'activation_func_fp8_input_store', + 'num_moe_experts', + 'rotary_interleaved', + 'window_size', + 'normalization', + 'qk_layernorm', + 'test_mode', + 'calculate_per_token_loss', + 'seq_length', +] + def get_image_sequence_length(img_h, img_w, patch_dim, add_class_token, class_token_len): """Get image sequence length given image size, patch size, and class token.""" @@ -64,9 +95,7 @@ def neva_data_step(dataloader_iter) -> Dict[str, torch.Tensor]: # Based on: https://github.com/NVIDIA/Megatron-LM/blob/main/pretrain_gpt.py#L87 # https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L828-L842 - batch = next(dataloader_iter) - _batch: dict if isinstance(batch, tuple) and len(batch) == 3: _batch = batch[0] @@ -74,11 +103,23 @@ def neva_data_step(dataloader_iter) -> Dict[str, torch.Tensor]: _batch = batch required_keys = set() - required_keys.add("attention_mask") + required_keys.update( + ( + "tokens", + "attention_mask", + "media", + "num_media_tiles", + ) + ) if parallel_state.is_pipeline_first_stage(): - required_keys.update(("media", "tokens", "position_ids")) + required_keys.update(("position_ids",)) if parallel_state.is_pipeline_last_stage(): - required_keys.update(("labels", "loss_mask")) + required_keys.update( + ( + "labels", + "loss_mask", + ) + ) _batch = { key: val.cuda(non_blocking=True) if key in required_keys and val is not None else None @@ -98,6 +139,7 @@ def neva_forward_step(model, batch) -> torch.Tensor: "attention_mask": batch.get("attention_mask", None), "loss_mask": batch.get("loss_mask", None), "labels": batch.get("labels", None), + "num_media_tiles": batch.get("num_media_tiles", None), } if 'cu_seqlens' in batch: @@ -176,10 +218,11 @@ class HFCLIPVisionConfig(CLIPVisionConfig, io.IOMixin): https://github.com/huggingface/transformers/blob/v4.44.0/src/transformers/models/clip/configuration_clip.py#L261 """ + hidden_size: int = 1024 
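
Editor's note on the HFCLIPVisionConfig change just below (replacing configure_hf_config with __post_init__): the dataclass re-runs the Hugging Face base constructor so the base config's derived defaults stay in sync with the dataclass-declared hidden_size. A minimal, self-contained sketch of that pattern, using a stand-in base class rather than the real transformers.CLIPVisionConfig:

from dataclasses import dataclass
from typing import Optional


class BaseVisionConfig:  # stand-in for transformers.CLIPVisionConfig
    def __init__(self, hidden_size: int = 768, image_size: int = 224, patch_size: int = 32):
        self.hidden_size = hidden_size
        self.image_size = image_size
        self.patch_size = patch_size


@dataclass
class VisionConfigSketch(BaseVisionConfig):
    hidden_size: int = 1024
    pretrained_model_name_or_path: Optional[str] = None

    def __post_init__(self):
        # Re-run the base constructor so its defaults are populated while the
        # dataclass-declared hidden_size (1024) takes precedence.
        BaseVisionConfig.__init__(self, hidden_size=self.hidden_size)


cfg = VisionConfigSketch()
assert cfg.hidden_size == 1024 and cfg.image_size == 224
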
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None - def configure_hf_config(self, *args, **kwargs) -> None: - CLIPVisionConfig.__init__(self, *args, **kwargs) + def __post_init__(self, *args, **kwargs) -> None: + CLIPVisionConfig.__init__(self, *args, **kwargs, hidden_size=self.hidden_size) def configure_model(self) -> "CLIPVisionModel": # Monkey patch the method to the vision encoder @@ -198,26 +241,40 @@ def configure_model(self) -> "CLIPVisionModel": @dataclass class CLIPViTConfig(TransformerConfig, io.IOMixin): ln_pre_impl: Union[ModuleSpec, type] = TENorm + ln_post_impl: Union[ModuleSpec, type] = TENorm add_class_token: bool = True class_token_len: int = 1 patch_dim: int = 14 img_h: int = 336 img_w: int = 336 + vision_model_type: str = "clip" # ["clip", "siglip"] transformer_layer_spec: ModuleSpec = transformer_engine_layer_spec - def configure_model(self) -> "MCoreCLIPViTModel": + num_layers: int = 1 # Placeholder, NOT used! + num_attention_heads: int = 8 # Placeholder, NOT used! + + def __post_init__(self): + if self.vision_model_type == "siglip": + self.add_class_token = False + self.class_token_len = 0 + + def configure_model(self) -> "CLIPViTModel": transformer_layer_spec = self.transformer_layer_spec if not isinstance(transformer_layer_spec, ModuleSpec): - transformer_layer_spec = transformer_layer_spec(self) - return MCoreCLIPViTModel( + from nemo.collections.vlm.layer_specs import get_layer_spec_te + + transformer_layer_spec = get_layer_spec_te(is_vit=True) + return CLIPViTModel( self, transformer_layer_spec, ln_pre_impl=self.ln_pre_impl, + ln_post_impl=self.ln_post_impl, add_class_token=self.add_class_token, class_token_len=self.class_token_len, patch_dim=self.patch_dim, img_h=self.img_h, img_w=self.img_w, + model_subtype=self.vision_model_type, ) @@ -226,283 +283,173 @@ class NevaConfig(TransformerConfig, io.IOMixin): language_transformer_config: Optional[TransformerConfig] = None vision_transformer_config: Optional[TransformerConfig] = None vision_projection_config: Optional[TransformerConfig] = None + drop_vision_class_token: bool = True + vision_feature_layer: int = -2 + + encoder_pipeline_model_parallel_size: int = 0 + encoder_tensor_model_parallel_size: int = 1 num_layers: int = 1 # Placeholder, NOT used! num_attention_heads: int = 8 # Placeholder, NOT used! 
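
Editor's note: the MODEL_CONFIG_ATTR list introduced earlier is consumed by NevaConfig.__post_init__ in the next hunk, which mirrors those attributes from the nested language-model config onto the wrapper config so Megatron sees consistent values. A rough, self-contained sketch of that mirroring pattern (class and field names here are invented for illustration):

from dataclasses import dataclass, field

_MIRRORED_ATTRS = ["num_layers", "hidden_size", "num_attention_heads"]


@dataclass
class LanguageCfg:
    num_layers: int = 32
    hidden_size: int = 4096
    num_attention_heads: int = 32


@dataclass
class MultimodalCfg:
    language: LanguageCfg = field(default_factory=LanguageCfg)
    # Placeholders, overwritten from the nested config in __post_init__.
    num_layers: int = 1
    hidden_size: int = 1
    num_attention_heads: int = 1

    def __post_init__(self):
        for attr in _MIRRORED_ATTRS:
            setattr(self, attr, getattr(self.language, attr))


cfg = MultimodalCfg()
assert cfg.hidden_size == 4096  # mirrored from the nested language config
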
- vision_feature_layer: int = -2 + + seq_length: int = 1024 language_model_from_pretrained: Optional[str] = None vision_model_from_pretrained: Optional[str] = None # TODO vision_projection_from_pretrained: Optional[str] = None # TODO - freeze_language_model: bool = True - freeze_vision_model: bool = True + freeze_language_model: bool = False + freeze_vision_model: bool = False freeze_vision_projection: bool = False forward_step_fn: Callable = neva_forward_step data_step_fn: Callable = neva_data_step - def configure_model(self, tokenizer) -> "MCoreLLaVAModel": - language_model = self.language_transformer_config.configure_model(tokenizer=tokenizer) - vision_model = self.vision_transformer_config.configure_model() - vision_projection = self.vision_projection_config.configure_model() - - if self.language_model_from_pretrained is not None: - sharded_state_dict = dict(state_dict=language_model.sharded_state_dict(prefix="module.")) - loaded_state_dict = dist_checkpointing.load( - sharded_state_dict=sharded_state_dict, checkpoint_dir=self.language_model_from_pretrained + def __post_init__(self): + if self.language_transformer_config is not None: + for attr in MODEL_CONFIG_ATTR: + setattr(self, attr, getattr(self.language_transformer_config, attr)) + + def configure_model(self, tokenizer) -> "MCoreNevaModel": + from megatron.core import parallel_state as ps + + self.language_transformer_config.tensor_model_parallel_size = self.tensor_model_parallel_size + self.language_transformer_config.sequence_parallel = self.sequence_parallel + self.vision_transformer_config.tensor_model_parallel_size = self.tensor_model_parallel_size + self.vision_projection_config.tensor_model_parallel_size = self.tensor_model_parallel_size + self.language_transformer_config.pipeline_model_parallel_size = self.pipeline_model_parallel_size + + if self.encoder_pipeline_model_parallel_size > 0: + assert self.encoder_pipeline_model_parallel_size == 1, "ViT can only live on 1 pipeline stage." 
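
Editor's note: the MCoreNevaModel(...) construction in the hunk that follows derives the per-chunk flags (pre_process, post_process, add_encoder, add_decoder) from the pipeline rank, with the ViT confined to a single encoder stage. A rough standalone illustration of that logic with no Megatron imports (encoder_pp stands for encoder_pipeline_model_parallel_size and is assumed to be 0 or 1, per the assert above):

def neva_chunk_flags(pp_rank: int, pp_world_size: int, encoder_pp: int) -> dict:
    is_first = pp_rank == 0
    is_last = pp_rank == pp_world_size - 1
    return {
        "pre_process": is_first or pp_rank == encoder_pp,  # text embedding lives here
        "post_process": is_last,
        "add_encoder": is_first,                           # ViT only on the first stage
        "add_decoder": is_last or pp_rank >= encoder_pp,   # language-model chunks
    }


# Example: 4 pipeline stages, vision encoder occupying its own stage 0 (encoder_pp=1).
for rank in range(4):
    print(rank, neva_chunk_flags(rank, pp_world_size=4, encoder_pp=1))
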
+ self.vision_transformer_config.pipeline_model_parallel_size = self.encoder_pipeline_model_parallel_size + self.vision_projection_config.pipeline_model_parallel_size = self.encoder_pipeline_model_parallel_size + self.language_transformer_config.encoder_pipeline_model_parallel_size = ( + self.encoder_pipeline_model_parallel_size ) - loaded_state_dict = {k.removeprefix("module."): v for k, v in loaded_state_dict["state_dict"].items()} - language_model.load_state_dict(loaded_state_dict) - logging.info(f"Restored language model weights from {self.language_model_from_pretrained}") + if self.encoder_tensor_model_parallel_size > 0: + self.vision_transformer_config.tensor_model_parallel_size = self.encoder_tensor_model_parallel_size + self.vision_projection_config.tensor_model_parallel_size = self.encoder_tensor_model_parallel_size + model = MCoreNevaModel( - transformer_config=self, - language_model=language_model, - vision_model=vision_model, - vision_projection=vision_projection, + config=self, + tokenizer=tokenizer, + pre_process=ps.is_pipeline_first_stage() + or ps.get_pipeline_model_parallel_rank() == self.encoder_pipeline_model_parallel_size, + post_process=ps.is_pipeline_last_stage(), + add_encoder=ps.is_pipeline_first_stage(), + add_decoder=ps.is_pipeline_last_stage() + or ps.get_pipeline_model_parallel_rank() >= self.encoder_pipeline_model_parallel_size, drop_vision_class_token=self.drop_vision_class_token, ) - model.freeze( - freeze_language_model=self.freeze_language_model, - freeze_vision_model=self.freeze_vision_model, - freeze_vision_projection=self.freeze_vision_projection, - ) + return model +class CLIPViTModel(MCoreCLIPViTModel): + """CLIP ViT vision model.""" + + def forward( + self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, num_unused_layers: int = 0 + ) -> torch.Tensor: + if num_unused_layers > 0: + unused_layers = self.decoder.layers[-num_unused_layers:] + self.decoder.layers = self.decoder.layers[:-num_unused_layers] + x = super().forward(x, attention_mask) + self.decoder.layers.append(unused_layers) + return x + + return super().forward(x, attention_mask) + + class MCoreNevaModel(MCoreLLaVAModel): def __init__( self, - transformer_config: TransformerConfig, - language_model: MegatronModule, - vision_model: MegatronModule, - vision_projection: MegatronModule, + config: NevaConfig, + tokenizer: Optional = None, pre_process: bool = True, post_process: bool = True, + add_encoder: bool = True, + add_decoder: bool = True, drop_vision_class_token: bool = False, ) -> None: - super(MCoreLLaVAModel, self).__init__(config=transformer_config) + super(MCoreLLaVAModel, self).__init__(config=config) - logging.warning("LLaVA model is under development and may be missing features.") + language_transformer_config = config.language_transformer_config + vision_transformer_config = config.vision_transformer_config + vision_projection_config = config.vision_projection_config self.pre_process = pre_process self.post_process = post_process + self.add_encoder = add_encoder + self.add_decoder = add_decoder self.encoder_hidden_state = None - self.vision_model = vision_model - self.vision_projection = vision_projection - self.language_model = language_model - self.model_type = ModelType.encoder_or_decoder - # This attribute is needed to check if an all-reduce is required - # on the word embeddings inside `finalize_model_grads._allreduce_word_embedding_grads`. 
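
Editor's note: the CLIPViTModel.forward override above temporarily drops the last num_unused_layers decoder layers so that a penultimate hidden state (vision_feature_layer = -2) can be taken without running the full stack. A toy stand-in showing the same layer-skipping idea on a plain PyTorch module:

import torch
from torch import nn


class TinyEncoder(nn.Module):
    def __init__(self, hidden: int = 16, depth: int = 4):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(depth))

    def forward(self, x: torch.Tensor, num_unused_layers: int = 0) -> torch.Tensor:
        # Skip the trailing layers instead of indexing into intermediate outputs.
        active = self.layers if num_unused_layers == 0 else self.layers[:-num_unused_layers]
        for layer in active:
            x = layer(x)
        return x


vision_feature_layer = -2
enc = TinyEncoder()
features = enc(torch.randn(2, 16), num_unused_layers=-vision_feature_layer - 1)  # skips 1 layer
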
- self.share_embeddings_and_output_weights = False - if self.language_model is not None: - self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights - self._language_max_sequence_length = self.language_model.max_sequence_length + self.vision_model = None + self.vision_projection = None + self.language_model = None - if self.vision_model is not None: - self._drop_vision_class_token = drop_vision_class_token + self.sequence_parallel_lm = language_transformer_config.sequence_parallel + self.tp_comm_overlap_lm = language_transformer_config.tp_comm_overlap - self.add_encoder = self.vision_model is not None - self.add_decoder = self.language_model is not None - self.vision_model_from_hf = str(self.vision_model.__class__.__module__).startswith("transformers.") + self.share_embeddings_and_output_weights = False if self.add_decoder: - vision_config = self.config.vision_transformer_config - if self.vision_model_from_hf: - # img_h, img_w, patch_dim, add_class_token, class_token_len - self._img_seq_len = get_image_sequence_length( - img_h=vision_config.image_size, - img_w=vision_config.image_size, - patch_dim=vision_config.patch_size, - add_class_token=not drop_vision_class_token, - class_token_len=0 if "siglip" in vision_config.model_type else 1, + self.language_model = language_transformer_config.configure_model( + tokenizer=tokenizer, pre_process=pre_process, post_process=post_process + ) + self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights + self._language_max_sequence_length = self.language_model.max_sequence_length + self._language_is_pipeline_parallel = language_transformer_config.pipeline_model_parallel_size > 1 + if config.language_model_from_pretrained is not None: + sharded_state_dict = dict(state_dict=self.language_model.sharded_state_dict(prefix="module.")) + loaded_state_dict = dist_checkpointing.load( + sharded_state_dict=sharded_state_dict, + checkpoint_dir=ckpt_to_weights_subdir(config.language_model_from_pretrained, is_saving=False), + validate_access_integrity=False, ) - else: - self._img_seq_len = 576 # TODO(yuya): Fix hardcode + loaded_state_dict = {k.removeprefix("module."): v for k, v in loaded_state_dict["state_dict"].items()} + self.language_model.load_state_dict(loaded_state_dict) + logging.info(f"Restored language model weights from {config.language_model_from_pretrained}") else: - self._img_seq_len = 0 - - def _preprocess_data( - self, - image_embeddings, - language_embeddings, - input_ids, - loss_mask, - labels, - use_inference_kv_cache, - image_token_index, - num_image_tiles, - ): - # TODO (yuya): remove this and use the mcore method - """Preprocess input data before input to language model. - - This function is adopted from - https://github.com/huggingface/transformers/blob/85817d98fb60977c97e3014196a462b732d2ed1a/src/transformers/models/llava_next/modeling_llava_next.py#L409 - for our input data conventions. - - image_token_index = -200 indicates the image position in the input_ids = [0, 1, -200, 2, 3] and labels = [1, -200, 2, 3, 4], for example. - We want to replace the image position (-200) with image_embeddings and return the following: - - final_embeddings = [0, 1, image_embeddings, 2, 3], - - final_labels = [1, -100, 2, 3, 4] - - final_loss_mask = [1, 0, 0, 1, 1] - - This function also handles the case where the input does not contain an image (text-only sample). It also handles the case where a single input - image is split into multiple tiles. 
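
Editor's note on the pretrained-restore path added above: distributed checkpoints store the language-model weights under a "module." prefix, which is stripped before load_state_dict. A tiny in-memory illustration of that renaming step (the real code loads the state dict via dist_checkpointing from a weights subdirectory):

from torch import nn

model = nn.Linear(4, 4)
saved = {f"module.{k}": v for k, v in model.state_dict().items()}  # keys as stored on disk
restored = {k.removeprefix("module."): v for k, v in saved.items()}
model.load_state_dict(restored)
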
- - If pipeline parallelism is not used, then self.pre_process and self.post_process are both True and we update both - input embeddings, labels and loss masks (if available). - - If pipeline parallelism is used, then we do the following - - the first language model chunk has self.pre_process = True and self.post_process = False. We update input embeddings. - - the middle language model chunk(s) has self.pre_process = False and self.post_process = False. We don't need to update anything. - - the last language model chunk has self.pre_process = False and self.post_process = True. We update labels and loss mask. + if config.language_model_from_pretrained is not None: + dist_checkpointing.load( + sharded_state_dict=dict(state_dict={}), + checkpoint_dir=config.language_model_from_pretrained, + validate_access_integrity=False, + ) - TODO: This function should adjust the attention mask too. Currently, we assume the language model uses a causal mask. + if self.add_encoder: + self.vision_model = vision_transformer_config.configure_model() + self.vision_projection = vision_projection_config.configure_model() + self._drop_vision_class_token = drop_vision_class_token - Returns: - final_embedding (torch.Tensor): image and text embeddings concated [combined_seq_len, b, h]. - final_labels (torch.Tensor): labels for image and text positions [b, combined_seq_len]. - final_loss_mask (torch.Tensor): loss mask for image and text positions [b, combined_seq_len]. - """ - assert self.add_decoder, "input text preprocessing is only needed for the language model" + self.freeze( + freeze_language_model=config.freeze_language_model, + freeze_vision_model=config.freeze_vision_model, + freeze_vision_projection=config.freeze_vision_projection, + ) - # No pre- or postprocessing needed. With pipeline parallel > 2, this means a chunk in the middle of the model. - if not self.pre_process and not self.post_process: - return language_embeddings, loss_mask, labels + self.model_type = ModelType.encoder_or_decoder + # This attribute is needed to check if an all-reduce is required + # on the word embeddings inside `finalize_model_grads._allreduce_word_embedding_grads`. - # If using the inference KV cache, the image tokens are already computed. - if use_inference_kv_cache: - return language_embeddings, loss_mask, labels - - img_seq_len = self._img_seq_len - batch_size, text_seq_len = input_ids.shape - - has_labels = labels is not None - if has_labels: - assert ( - labels.shape == loss_mask.shape - ), f"mismatching labels shape {labels.shape} and loss mask shape {loss_mask.shape}" - - # Create indices for new text and label positions. - with torch.no_grad(): - image_token_mask = input_ids == image_token_index - num_image_tokens = torch.sum(image_token_mask, dim=-1) - - # Number of tiles per sample. - num_image_tiles_batch = num_image_tiles.split(num_image_tokens.tolist(), dim=0) - num_image_tiles_batch = torch.tensor([x.sum() for x in num_image_tiles_batch], device=input_ids.device) - - # Sequence length for each sample is the image sequence length multiplied by the number of tiles for that image, minus image token indices, - # plus text sequence length. - seq_lens = num_image_tiles_batch * img_seq_len - num_image_tokens + text_seq_len - max_seq_len = seq_lens.max() - batch_indices, non_image_indices = torch.where(input_ids != image_token_index) - - # New position ids for the text tokens, shifted by the image sequence length. - # E.g. 
for input_ids = [-200, 1, 2, 3] and img_seq_len = 576, we get new_position_ids = [576, 577, 578, 579]. - # text_position_ids are then [577, 578, 579]. - image_token_mask_lens = image_token_mask.int().clone() - # -1 is for the removed image token index. - image_token_mask_lens[image_token_mask] = num_image_tiles * img_seq_len - 1 - # +1 is needed here for the cumulative sum. -1 is adjusting for zero-based indexing. - new_position_ids = torch.cumsum((image_token_mask_lens + 1), dim=-1) - 1 - text_position_ids = new_position_ids[batch_indices, non_image_indices] - - # Labels are shifted to left by one. So, shift text position ids and non-image indices to left by one. - if has_labels: - label_text_position_ids = text_position_ids - 1 - valid_label_text_position_ids = label_text_position_ids >= 0 - label_text_position_ids = label_text_position_ids[valid_label_text_position_ids] - - label_batch_indices = batch_indices[valid_label_text_position_ids] - - label_non_image_indices = non_image_indices - 1 - valid_label_non_image_indices = label_non_image_indices >= 0 - label_non_image_indices = label_non_image_indices[valid_label_non_image_indices] - - # Create a mask for the image embedding positions. - with torch.no_grad(): - images_mask = torch.full((batch_size, max_seq_len), True, dtype=torch.bool, device=input_ids.device) - # No images in the text positions. - images_mask[batch_indices, text_position_ids] = False - # Samples can have different amount of images tokens. new_position_ids[:, -1] gives the last text position id for each sample. - # Padding is needed when the number of image tokens differs. - first_padding_idx = new_position_ids[:, -1] + 1 - images_mask[ - torch.arange(max_seq_len, device=first_padding_idx.device).repeat(batch_size, 1) - >= first_padding_idx.unsqueeze(1) - ] = False - - # Create the final input embedding (if this is the first language model stage). - final_embedding = None - if self.pre_process: - embed_dim = language_embeddings.shape[-1] - final_embedding = torch.zeros( - batch_size, - max_seq_len, - embed_dim, - dtype=image_embeddings.dtype, - device=image_embeddings.device, + self.vision_model_from_hf = hasattr(vision_transformer_config, "image_size") + if self.vision_model_from_hf: + # img_h, img_w, patch_dim, add_class_token, class_token_len + self._img_seq_len = get_image_sequence_length( + img_h=vision_transformer_config.image_size, + img_w=vision_transformer_config.image_size, + patch_dim=vision_transformer_config.patch_size, + add_class_token=not drop_vision_class_token, + class_token_len=0 if "siglip" in vision_transformer_config.model_type else 1, ) - - # Put text embeddings to the text positions in the result tensor. - final_embedding[batch_indices, text_position_ids] = language_embeddings[batch_indices, non_image_indices] - - # Put image embeddings to image positions. - final_embedding[images_mask] = image_embeddings.reshape(-1, embed_dim).contiguous() - - # Create the final labels and loss mask (if this is the last language model stage). 
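
Editor's note for the _img_seq_len values assigned in the new __init__ above: the helper's body is not shown in this diff, but its arguments suggest the usual patch-count arithmetic, and the removed hardcoded 576 is consistent with a 336x336 CLIP image at patch size 14. A sketch under that assumption:

def image_seq_len(img_h: int, img_w: int, patch_dim: int, add_class_token: bool, class_token_len: int) -> int:
    # Assumed formula: patches per side multiplied together, plus the optional class token.
    num_patches = (img_h // patch_dim) * (img_w // patch_dim)
    return num_patches + (class_token_len if add_class_token else 0)


print(image_seq_len(336, 336, 14, add_class_token=False, class_token_len=1))  # 576 = 24 * 24
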
- final_labels, final_loss_mask = None, None - if has_labels: - final_labels = torch.full( - (batch_size, max_seq_len), IGNORE_INDEX, dtype=labels.dtype, device=labels.device + else: + self._img_seq_len = get_image_sequence_length( + img_h=vision_transformer_config.img_h, + img_w=vision_transformer_config.img_w, + patch_dim=vision_transformer_config.patch_dim, + add_class_token=not drop_vision_class_token, + class_token_len=vision_transformer_config.class_token_len, ) - final_loss_mask = torch.full((batch_size, max_seq_len), 0, dtype=loss_mask.dtype, device=loss_mask.device) - - # Put text labels and loss mask to the text positions. - final_labels[label_batch_indices, label_text_position_ids] = labels[ - label_batch_indices, label_non_image_indices - ] - - final_loss_mask[batch_indices, text_position_ids] = loss_mask[batch_indices, non_image_indices] - - # For labels, we need to pick the last label index that got dropped by the shift to left. - label_extra_text_position_ids = seq_lens - 1 - batch_range = torch.arange(len(label_extra_text_position_ids)) - final_labels[batch_range, label_extra_text_position_ids] = labels[batch_range, -1] - - # Loss mask the image positions. - final_loss_mask[images_mask] = 0 - - # Loss mask last text position just before an image so that text token does not need to predict the first image token. - batch_image_indices, image_indices = torch.where(image_token_mask) - # Indices just before image tokens. If it's -1, skip it. - before_image_indices = image_indices - 1 - valid = before_image_indices >= 0 - valid_batch_image_indices = batch_image_indices[valid] - valid_before_image_indices = before_image_indices[valid] - # Map those indices those position ids. - valid_before_image_indices = new_position_ids[valid_batch_image_indices, valid_before_image_indices] - - final_loss_mask[valid_batch_image_indices, valid_before_image_indices] = 0 - - if final_embedding is not None and has_labels: - assert ( - final_embedding.shape[:2] == final_labels.shape == final_loss_mask.shape - ), "unexpected shapes after data preprocessing" - - if final_embedding is not None: - final_embedding = final_embedding.transpose(1, 0).contiguous() - - # Truncate if exceeding the language model's max sequence length. - if final_embedding is not None and final_embedding.shape[0] > self._language_max_sequence_length: - final_embedding = final_embedding[: self._language_max_sequence_length] - - if has_labels and final_labels.shape[1] > self._language_max_sequence_length: - final_labels = final_labels[:, : self._language_max_sequence_length] - final_loss_mask = final_loss_mask[:, : self._language_max_sequence_length] - - return final_embedding, final_labels, final_loss_mask def forward( self, @@ -515,6 +462,7 @@ def forward( inference_params: Optional[InferenceParams] = None, num_media_tiles: Optional[List[int]] = None, media_token_index: Optional[int] = IMAGE_TOKEN_INDEX, + runtime_gather_output: Optional[bool] = None, ) -> torch.Tensor: """Forward function of the LLaVA model. @@ -533,34 +481,44 @@ def forward( output (torch.Tensor): Loss of shape [b, s] if labels are provided, otherwise logits of shape [b, s, vocab_size]. loss_mask (torch.Tensor): Loss mask expanded to combined sequence length. Shape [b, s]. """ + use_inference_kv_cache = ( inference_params is not None and "image_tokens_count" in inference_params.key_value_memory_dict ) + has_images = media.shape[0] > 0 + # If running inference, we can skip media token computation if they were computed already earlier for this sample. 
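
Editor's note: the branch that follows computes the media embeddings. The vision tower output of shape [num_tiles, img_seq_len, h_vision] is permuted (and made contiguous, since permute only adjusts strides and pipeline-parallel sends expect a contiguous buffer) before the projection maps it to the language hidden size. A shape-only walkthrough with toy sizes and a plain Linear standing in for the multimodal projector:

import torch

num_tiles, img_seq_len, h_vision, h_language = 3, 576, 1024, 4096
media_embeddings = torch.randn(num_tiles, img_seq_len, h_vision)

media_embeddings = media_embeddings.permute(1, 0, 2).contiguous()  # [img_seq_len, num_tiles, h_vision]
projection = torch.nn.Linear(h_vision, h_language)                 # stand-in for vision_projection
media_embeddings = projection(media_embeddings)
print(media_embeddings.shape)  # torch.Size([576, 3, 4096])
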
- if use_inference_kv_cache or media is None: + if use_inference_kv_cache: media_embeddings = None - elif self.add_encoder: + elif self.add_encoder and not has_images: + # If no images provided, use an empty image embeddings tensor. + media_embeddings = torch.tensor([], dtype=media.dtype, device=media.device).reshape(0, 0, 0) + elif self.add_encoder and has_images: # media is in shape of (num_images_in_mbs, c, h, w) # note num_images_in_mbs is not mbs but total images in this mbs. if self.vision_model_from_hf: - media_embeddings = self.vision_model( - media, output_hidden_states=True - ) # [num_images, img_seq_len, h_vision] + self.vision_model = self.vision_model.eval() + media_embeddings = self.vision_model(media, output_hidden_states=True) media_embeddings = media_embeddings[-1][ self.config.vision_feature_layer - ] # take second from last layer + ] # [num_images, img_seq_len, h_vision] else: # TODO(yuya): MCore Clip path not yet support taking a specific layer hidden states - media_embeddings = self.vision_model(media) + media = media.to(next(self.vision_model.parameters()).dtype) + media_embeddings = self.vision_model(media, num_unused_layers=-self.config.vision_feature_layer - 1) if self._drop_vision_class_token: class_token_len = getattr(self.vision_model, "class_token_len", 1) media_embeddings = media_embeddings[:, class_token_len:, :] + # contiguous() required as `permute` can sparsify the tensor and this breaks pipelining + media_embeddings = media_embeddings.permute(1, 0, 2).contiguous() # [img_seq_len, num_tiles, h_vision] + # map vision model output size to language model input size. - media_embeddings = self.vision_projection(media_embeddings) # [img_seq_len, num_tiles, h_vision] + media_embeddings = self.vision_projection(media_embeddings) # [img_seq_len, num_tiles, h_language] - # If running inference, the language model KV cache will be updated for media token positions. - # Here we store the media tokens sequence length, which can be used as an offset to the KV cache later. + # TODO: Support batched inference. + # In inference, the language model KV cache will be updated for image token positions. + # Store the image tokens sequence length to be used as an offset to the KV cache later. if inference_params is not None: inference_params.key_value_memory_dict["media_tokens_count"] = ( media_embeddings.shape[0] * media_embeddings.shape[1] @@ -569,40 +527,61 @@ def forward( media_embeddings = self.encoder_hidden_state if not self.add_decoder: - return media_embeddings, loss_mask + return media_embeddings language_embeddings = None if self.pre_process: input_ids_text = input_ids.clone() # MultiModal Token indices are assumed to be values input_ids_text[input_ids_text < 0] = 0 - # Note: This adds absolute position embedding but not RoPE. Each image is counted as one position. - # RoPE is added in language_model forward call. Each image embedding is one position. + # Note: This adds absolute position embedding but not RoPE. + # Each image is counted as one position. + # RoPE is added in language_model forward. Each image embedding is one position. + if self.sequence_parallel_lm: + # Pad to nearest multiple of TP world size for embedding. 
+ tp_world_size = ps.get_tensor_model_parallel_world_size() + padded_seq_len = ( + int((input_ids_text.shape[1] + tp_world_size - 1) // tp_world_size * tp_world_size) + - input_ids_text.shape[1] + ) + if padded_seq_len != 0: + input_ids_text = torch.nn.functional.pad(input_ids_text, (0, padded_seq_len)) + if position_ids is not None: + position_ids = torch.nn.functional.pad(position_ids, (0, padded_seq_len)) language_embeddings = self.language_model.embedding( input_ids=input_ids_text, position_ids=position_ids ) # [text_seq_len, b, h_language] + if self.sequence_parallel_lm: + # Gather the language embeddings back. + # We use the full embedding to insert image embeddings + # and then scatter to avoid load imbalance. + language_embeddings = gather_from_sequence_parallel_region( + language_embeddings, tensor_parallel_output_grad=False + ) + # Remove the padding done for SP as we'll need new padding calculation + # after image embeddings are inserted. + if padded_seq_len != 0: + language_embeddings = language_embeddings[:-padded_seq_len] language_embeddings = language_embeddings.transpose(1, 0).contiguous() # [b, text_seq_len, h_language] - if media is None: - combined_embeddings = language_embeddings.transpose(1, 0).contiguous() - final_labels = labels - final_loss_mask = loss_mask - else: - # Assume 1 tile per image if the number of tiles is not provided. - if num_media_tiles is None: - num_media_tiles = torch.ones(media.shape[0], dtype=torch.int, device=input_ids.device) - - # Preprocess input, labels and loss mask. - combined_embeddings, final_labels, final_loss_mask = self._preprocess_data( - media_embeddings, - language_embeddings, - input_ids, - loss_mask, - labels, - use_inference_kv_cache, - media_token_index, - num_media_tiles, - ) # [combined_seq_len, b, h_language], [b, combined_seq_len], [b, combined_seq_len] + # Assume 1 tile per image if the number of tiles is not provided. + if num_media_tiles is None: + num_media_tiles = torch.ones(media.shape[0], dtype=torch.int, device=input_ids.device) + elif isinstance(num_media_tiles, list): + num_media_tiles = torch.tensor(num_media_tiles, dtype=torch.int, device=input_ids.device) + + # Preprocess input, labels and loss mask. 
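
Editor's note: the sequence-parallel branch above pads the text to a multiple of the tensor-parallel world size before the embedding lookup, gathers the full sequence back, and strips the pad again; the preprocessing call below then splices the image embeddings into the gathered text embeddings. The padding arithmetic in isolation:

import torch


def pad_to_multiple(input_ids: torch.Tensor, tp_world_size: int):
    seq_len = input_ids.shape[1]
    pad = (seq_len + tp_world_size - 1) // tp_world_size * tp_world_size - seq_len
    if pad:
        input_ids = torch.nn.functional.pad(input_ids, (0, pad))
    return input_ids, pad


ids = torch.randint(0, 100, (1, 10))
padded, pad = pad_to_multiple(ids, tp_world_size=4)
print(padded.shape[1], pad)  # 12 2
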
+ combined_embeddings, final_labels, final_loss_mask, final_attention_mask = self._preprocess_data( + media_embeddings, + language_embeddings, + input_ids, + loss_mask, + labels, + use_inference_kv_cache, + media_token_index, + num_media_tiles, + attention_mask, + ) # [combined_seq_len, b, h_language], [b, combined_seq_len], [b, combined_seq_len] output = self.language_model( input_ids=None, @@ -611,6 +590,7 @@ def forward( decoder_input=combined_embeddings, labels=final_labels, inference_params=inference_params, + runtime_gather_output=runtime_gather_output, ) if labels is None or loss_mask is None: @@ -618,6 +598,23 @@ def forward( return output, final_loss_mask.contiguous() + def set_input_tensor(self, input_tensor) -> None: + """Set model chunk input tensor.""" + # This is usually handled in schedules.py but some inference code still + # gives us non-lists or None + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + assert len(input_tensor) == 1, 'input_tensor should only be length 1 for llava' + + if self.add_encoder and self.add_decoder: + self.vision_model.set_input_tensor(input_tensor[0]) + elif self.add_encoder: + self.vision_model.set_input_tensor(input_tensor[0]) + elif self.pre_process: + self.encoder_hidden_state = input_tensor[0] + else: + self.language_model.set_input_tensor(input_tensor[0]) + class NevaModel(L.LightningModule, io.IOMixin, io.ConnectorMixin, fn.FNMixin): def __init__( @@ -649,6 +646,7 @@ def forward( media: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, inference_params: InferenceParams = None, + num_media_tiles: Optional[List[int]] = None, ) -> torch.Tensor: output_tensor = self.module( media=media, @@ -658,6 +656,7 @@ def forward( attention_mask=attention_mask, labels=labels, inference_params=inference_params, + num_media_tiles=num_media_tiles, ) return output_tensor @@ -697,6 +696,4 @@ def validation_loss_reduction(self) -> MaskedTokenLossReductionWithLossMask: "NevaConfig", "neva_data_step", "neva_forward_step", - "transformer_engine_layer_spec", - "local_layer_spec", ] diff --git a/nemo/collections/vlm/neva/model/llava.py b/nemo/collections/vlm/neva/model/llava.py index 52b55b6f9c2d..5e02b4f9e9d7 100644 --- a/nemo/collections/vlm/neva/model/llava.py +++ b/nemo/collections/vlm/neva/model/llava.py @@ -43,7 +43,7 @@ class LlavaConfig(NevaConfig): @dataclass -class Llava1_5Config7B(LlavaConfig): +class Llava15Config7B(LlavaConfig): from transformers import PretrainedConfig language_transformer_config: TransformerConfig = field(default_factory=lambda: Llama2Config7B()) @@ -56,7 +56,7 @@ class Llava1_5Config7B(LlavaConfig): @dataclass -class Llava1_5Config13B(LlavaConfig): +class Llava15Config13B(LlavaConfig): from transformers import PretrainedConfig language_transformer_config: TransformerConfig = field(default_factory=lambda: Llama2Config13B()) @@ -111,7 +111,6 @@ def convert_state(self, source, target): "language_model.model.layers.*.post_attention_layernorm.weight": "language_model.decoder.layers.*.mlp.linear_fc1.layer_norm_weight", "language_model.model.norm.weight": "language_model.decoder.final_layernorm.weight", "language_model.lm_head.weight": "language_model.output_layer.weight", - "vision_tower.vision_model.**": "vision_model.vision_model.**", } if "vision_projection.encoder.linear_fc1.weight" in target.module.state_dict().keys(): mapping.update( @@ -134,7 +133,45 @@ def convert_state(self, source, target): else: raise KeyError("Unable to map vision projection keys.") - return 
io.apply_transforms(source, target, mapping=mapping, transforms=[_import_qkv, _import_linear_fc1]) + if "vision_model.vision_model.embeddings.class_embedding" in target.module.state_dict().keys(): + mapping.update( + { + "vision_tower.vision_model.**": "vision_model.vision_model.**", + } + ) + elif "vision_model.class_token" in target.module.state_dict().keys(): + mapping.update( + { + "vision_tower.vision_model.embeddings.patch_embedding.weight": "vision_model.conv1.weight", + "vision_tower.vision_model.embeddings.position_embedding.weight": "vision_model.position_embeddings.weight", + "vision_tower.vision_model.encoder.layers.*.layer_norm1.weight": "vision_model.decoder.layers.*.self_attention.linear_qkv.layer_norm_weight", + "vision_tower.vision_model.encoder.layers.*.layer_norm1.bias": "vision_model.decoder.layers.*.self_attention.linear_qkv.layer_norm_bias", + "vision_tower.vision_model.encoder.layers.*.layer_norm2.weight": "vision_model.decoder.layers.*.mlp.linear_fc1.layer_norm_weight", + "vision_tower.vision_model.encoder.layers.*.layer_norm2.bias": "vision_model.decoder.layers.*.mlp.linear_fc1.layer_norm_bias", + "vision_tower.vision_model.encoder.layers.*.self_attn.out_proj.weight": "vision_model.decoder.layers.*.self_attention.linear_proj.weight", + "vision_tower.vision_model.encoder.layers.*.self_attn.out_proj.bias": "vision_model.decoder.layers.*.self_attention.linear_proj.bias", + "vision_tower.vision_model.encoder.layers.*.mlp.fc1.weight": "vision_model.decoder.layers.*.mlp.linear_fc1.weight", + "vision_tower.vision_model.encoder.layers.*.mlp.fc1.bias": "vision_model.decoder.layers.*.mlp.linear_fc1.bias", + "vision_tower.vision_model.encoder.layers.*.mlp.fc2.weight": "vision_model.decoder.layers.*.mlp.linear_fc2.weight", + "vision_tower.vision_model.encoder.layers.*.mlp.fc2.bias": "vision_model.decoder.layers.*.mlp.linear_fc2.bias", + "vision_tower.vision_model.pre_layrnorm.weight": "vision_model.ln_pre.weight", + "vision_tower.vision_model.pre_layrnorm.bias": "vision_model.ln_pre.bias", + } + ) + else: + raise KeyError("Unable to map vision encoder keys.") + return io.apply_transforms( + source, + target, + mapping=mapping, + transforms=[ + _import_language_qkv, + _import_vision_qkv, + _import_vision_qkv_bias, + _import_cls_token, + _import_linear_fc1, + ], + ) @property def tokenizer(self) -> "AutoTokenizer": @@ -183,80 +220,7 @@ def make_vocab_size_divisible_by(vocab_size): return output -@io.model_exporter(LlavaModel, "hf") -class HFLlavaExporter(io.ModelConnector[LlavaModel, "LlavaForConditionalGeneration"]): - def init(self) -> "LlavaForConditionalGeneration": - raise NotImplementedError("Neva Exporter hasn't been verified!") - - from transformers import AutoModelForCausalLM - - return AutoModelForCausalLM.from_config(self.config) - - def apply(self, output_path: Path) -> Path: - target = self.init() - source, _ = self.nemo_load(str(self)) - - target = self.convert_state(source, target) - - target = target.cpu() - target.save_pretrained(output_path) - self.tokenizer.save_pretrained(output_path) - - return output_path - - def convert_state(self, source, target): - mapping = { - "language_model.embedding.word_embeddings.weight": "language_model.model.embed_tokens.weight", - "language_model.decoder.layers.*.self_attention.linear_proj.weight": "language_model.model.layers.*.self_attn.o_proj.weight", - "language_model.decoder.layers.*.mlp.linear_fc2.weight": "language_model.model.layers.*.mlp.down_proj.weight", - 
"language_model.decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "language_model.model.layers.*.input_layernorm.weight", - "language_model.decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "language_model.model.layers.*.post_attention_layernorm.weight", - "language_model.decoder.final_layernorm.weight": "language_model.model.norm.weight", - "language_model.output_layer.weight": "language_model.lm_head.weight", - } - - return io.apply_transforms(source, target, mapping=mapping, transforms=[_export_qkv, _export_linear_fc1]) - - @property - def tokenizer(self): - return io.load_context(str(self)).model.tokenizer.tokenizer - - @property - def config(self) -> "HFLlavaConfig": - source: LlavaConfig = io.load_context(str(self)).model.config - - from transformers import LlavaConfig as HFLlavaConfig - - return HFLlavaConfig( - num_hidden_layers=source.num_layers, - hidden_size=source.hidden_size, - intermediate_size=source.ffn_hidden_size, - num_attention_heads=source.num_attention_heads, - max_position_embeddings=source.seq_length, - initializer_range=source.init_method_std, - rms_norm_eps=source.layernorm_epsilon, - num_key_value_heads=source.num_query_groups, - rope_theta=source.rotary_base, - vocab_size=self.tokenizer.vocab_size, - ) - - -@io.state_transform( - source_key=( - "language_model.model.layers.*.self_attn.q_proj.weight", - "language_model.model.layers.*.self_attn.k_proj.weight", - "language_model.model.layers.*.self_attn.v_proj.weight", - ), - target_key="language_model.decoder.layers.*.self_attention.linear_qkv.weight", -) -def _import_qkv(ctx: io.TransformCTX, q, k, v): - megatron_config = ctx.target.config.language_transformer_config - head_num = megatron_config.num_attention_heads - num_query_groups = megatron_config.num_query_groups - heads_per_group = head_num // num_query_groups - hidden_size = megatron_config.hidden_size - head_size = megatron_config.kv_channels - +def import_qkv(q, k, v, head_num, num_query_groups, heads_per_group, hidden_size, head_size): old_tensor_shape = q.size() new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:] new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:] @@ -282,59 +246,85 @@ def _import_qkv(ctx: io.TransformCTX, q, k, v): @io.state_transform( - source_key="language_model.decoder.layers.*.self_attention.linear_qkv.weight", - target_key=( + source_key=( "language_model.model.layers.*.self_attn.q_proj.weight", "language_model.model.layers.*.self_attn.k_proj.weight", "language_model.model.layers.*.self_attn.v_proj.weight", ), + target_key="language_model.decoder.layers.*.self_attention.linear_qkv.weight", ) -def _export_qkv(ctx: io.TransformCTX, linear_qkv): - megatron_config = ctx.source.config - - head_num = megatron_config.num_attention_heads - num_query_groups = megatron_config.num_query_groups - heads_per_group = head_num // num_query_groups - hidden_size = megatron_config.hidden_size - head_size = megatron_config.kv_channels - qkv_total_dim = head_num + 2 * num_query_groups - - linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, hidden_size]) - q_slice = torch.cat( - [ - torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) - for i in range(num_query_groups) - ] +def _import_language_qkv(ctx: io.TransformCTX, q, k, v): + megatron_config = ctx.target.config.language_transformer_config + return import_qkv( + q, + k, + v, + head_num=megatron_config.num_attention_heads, + num_query_groups=megatron_config.num_query_groups, + 
heads_per_group=megatron_config.num_attention_heads // megatron_config.num_query_groups, + hidden_size=megatron_config.hidden_size, + head_size=megatron_config.kv_channels, ) - k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) - v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) - q_proj = linear_qkv[q_slice].reshape(-1, hidden_size).cpu() - k_proj = linear_qkv[k_slice].reshape(-1, hidden_size).cpu() - v_proj = linear_qkv[v_slice].reshape(-1, hidden_size).cpu() - return q_proj, k_proj, v_proj +@io.state_transform( + source_key=( + "vision_tower.vision_model.encoder.layers.*.self_attn.q_proj.weight", + "vision_tower.vision_model.encoder.layers.*.self_attn.k_proj.weight", + "vision_tower.vision_model.encoder.layers.*.self_attn.v_proj.weight", + ), + target_key="vision_model.decoder.layers.*.self_attention.linear_qkv.weight", +) +def _import_vision_qkv(ctx: io.TransformCTX, q, k, v): + megatron_config = ctx.target.config.vision_transformer_config + return import_qkv( + q, + k, + v, + head_num=megatron_config.num_attention_heads, + num_query_groups=megatron_config.num_query_groups, + heads_per_group=megatron_config.num_attention_heads // megatron_config.num_query_groups, + hidden_size=megatron_config.hidden_size, + head_size=megatron_config.kv_channels, + ) @io.state_transform( source_key=( - "language_model.model.layers.*.mlp.gate_proj.weight", - "language_model.model.layers.*.mlp.up_proj.weight", + "vision_tower.vision_model.encoder.layers.*.self_attn.q_proj.bias", + "vision_tower.vision_model.encoder.layers.*.self_attn.k_proj.bias", + "vision_tower.vision_model.encoder.layers.*.self_attn.v_proj.bias", ), - target_key="language_model.decoder.layers.*.mlp.linear_fc1.weight", + target_key="vision_model.decoder.layers.*.self_attention.linear_qkv.bias", ) -def _import_linear_fc1(down, gate): - return torch.cat((down, gate), axis=0) +def _import_vision_qkv_bias(ctx: io.TransformCTX, q_bias, k_bias, v_bias): + megatron_config = ctx.target.config.vision_transformer_config + return import_qkv( + q_bias.unsqueeze(-1), + k_bias.unsqueeze(-1), + v_bias.unsqueeze(-1), + head_num=megatron_config.num_attention_heads, + num_query_groups=megatron_config.num_query_groups, + heads_per_group=megatron_config.num_attention_heads // megatron_config.num_query_groups, + hidden_size=1, + head_size=megatron_config.kv_channels, + ).squeeze(-1) + + +@io.state_transform( + source_key=("vision_tower.vision_model.embeddings.class_embedding",), + target_key="vision_model.class_token", +) +def _import_cls_token(ctx: io.TransformCTX, cls_token): + return cls_token.reshape(1, 1, -1) @io.state_transform( - source_key="language_model.decoder.layers.*.mlp.linear_fc1.weight", - target_key=( + source_key=( "language_model.model.layers.*.mlp.gate_proj.weight", "language_model.model.layers.*.mlp.up_proj.weight", ), + target_key="language_model.decoder.layers.*.mlp.linear_fc1.weight", ) -def _export_linear_fc1(linear_fc1): - gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0) - - return gate_proj, up_proj +def _import_linear_fc1(down, gate): + return torch.cat((down, gate), axis=0) diff --git a/nemo/collections/vlm/neva/model/vit_config.py b/nemo/collections/vlm/neva/model/vit_config.py new file mode 100644 index 000000000000..5d60a84313ca --- /dev/null +++ b/nemo/collections/vlm/neva/model/vit_config.py @@ -0,0 +1,79 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from nemo.collections.llm.fn.activation import openai_gelu, quick_gelu + +from nemo.collections.vlm.neva.model.base import CLIPViTConfig + + +@dataclass +class CLIPViTL_14_336_Config(CLIPViTConfig): + """Clip vit large patch14 config""" + + vision_model_type = "clip" + patch_dim = 14 + img_h = 336 + img_w = 336 + num_layers = 24 + num_attention_heads = 16 + add_bias_linear = True + add_qkv_bias = True + hidden_size = 1024 + hidden_dropout = 0.0 + attention_dropout = 0.0 + ffn_hidden_size = 4096 + gated_linear_unit = False + activation_func = quick_gelu + kv_channels = 64 + num_query_groups = 16 + layernorm_zero_centered_gamma = False + apply_query_key_layer_scaling = False + bias_activation_fusion = False + bias_dropout_fusion = False + attention_softmax_in_fp32 = True + normalization = 'LayerNorm' + apply_rope_fusion = False + + +@dataclass +class SigLIPViT400M_14_384_Config(CLIPViTConfig): + """Siglip so400m patch14 384 config""" + + vision_model_type = "siglip" + patch_dim = 14 + img_h = 384 + img_w = 384 + num_layers = 27 + num_attention_heads = 16 + add_bias_linear = True + add_qkv_bias = True + hidden_size = 1152 + hidden_dropout = 0.0 + attention_dropout = 0.0 + ffn_hidden_size = 4304 + gated_linear_unit = False + activation_func = openai_gelu + kv_channels = 72 + num_query_groups = 16 + layernorm_zero_centered_gamma = False + apply_query_key_layer_scaling = False + bias_activation_fusion = False + bias_dropout_fusion = False + attention_softmax_in_fp32 = True + normalization = 'LayerNorm' + apply_rope_fusion = False + qk_layernorm = False + layernorm_epsilon = 1e-6 diff --git a/nemo/collections/vlm/peft/lora.py b/nemo/collections/vlm/peft/lora.py index 1e394daa8ead..7a80b7e06883 100644 --- a/nemo/collections/vlm/peft/lora.py +++ b/nemo/collections/vlm/peft/lora.py @@ -48,7 +48,7 @@ class LoRA(LLMLoRA): """ freeze_language_model: bool = True - freeze_vision_model: bool = False + freeze_vision_model: bool = True def freeze_model(self, model: nn.Module) -> None: modules = [] diff --git a/nemo/collections/vlm/recipes/__init__.py b/nemo/collections/vlm/recipes/__init__.py index 2b71ecc50f8f..ba8706437c56 100644 --- a/nemo/collections/vlm/recipes/__init__.py +++ b/nemo/collections/vlm/recipes/__init__.py @@ -13,9 +13,11 @@ # limitations under the License. -from nemo.collections.vlm.recipes import mllama_11b, mllama_90b +from nemo.collections.vlm.recipes import llava15_7b, llava15_13b, mllama_11b, mllama_90b __all__ = [ + "llava15_7b", + "llava15_13b", "mllama_11b", "mllama_90b", ] diff --git a/nemo/collections/vlm/recipes/llava15_13b.py b/nemo/collections/vlm/recipes/llava15_13b.py new file mode 100644 index 000000000000..97b77b82d3de --- /dev/null +++ b/nemo/collections/vlm/recipes/llava15_13b.py @@ -0,0 +1,150 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional + +import nemo_run as run +import pytorch_lightning as pl +import torch + +from nemo import lightning as nl +from nemo.collections import llm, vlm +from nemo.collections.llm.recipes.finetune_default import nemo_resume +from nemo.collections.llm.recipes.log.default import tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed +from nemo.collections.vlm.neva.data.mock import MockDataModule + +NAME = "llava15_13b" + + +@run.cli.factory(name=NAME) +def model() -> run.Config[pl.LightningModule]: + """ + Factory function to create a Llava 1.5 13B model configuration. + + Returns: + run.Config[pl.LightningModule]: Configuration for the Llava 1.5 13B model model. + + Examples: + CLI usage: + $ nemo llm pretrain model=llava15_13b ... + + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + return run.Config(vlm.LlavaModel, config=run.Config(vlm.Llava15Config13B)) + + +@run.cli.factory(target=llm.finetune, name=NAME) +def finetune_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'lora', +) -> run.Partial: + """ + Create a fine-tuning recipe for Llava1.5 13B model. + + This function sets up a complete configuration for fine-tuning, including + model, trainer, data, logging, optimization, and resumption settings. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the fine-tuning run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + + Returns: + run.Partial: Partial configuration for fine-tuning. + + Examples: + CLI usage: + $ nemo llm finetune --factory llava15_13b + + Python API usage: + >>> recipe = finetune_recipe(name="llava15_13b_finetune", num_nodes=1) + >>> print(recipe) + + Note: + This recipe uses the SQuAD dataset for fine-tuning. For more information + on fine-tuning LLMs with NeMo, see the fine-tuning guide in the + `examples/llm/finetune/` directory. 
+ """ + + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + encoder_pipeline_model_parallel_size=0, + pipeline_dtype=torch.bfloat16, + ) + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + devices=num_gpus_per_node, + limit_val_batches=10, + log_every_n_steps=1, + max_steps=5190, + num_nodes=num_nodes, + plugins=bf16_mixed(), + strategy=strategy, + val_check_interval=1000, + ) + + recipe = run.Partial( + llm.finetune, + model=model(), + trainer=trainer, + data=run.Config( + MockDataModule, + seq_length=4096, + global_batch_size=128, + micro_batch_size=1, + tokenizer=None, + image_processor=None, + num_workers=4, + ), + log=llm.default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing(max_lr=2.0e-05, min_lr=2.0e-07, warmup_steps=150), + resume=nemo_resume("llava-hf/llava-1.5-13b-hf"), + ) + + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.optim.config.lr = 2e-05 + elif peft_scheme.lower() == 'lora': + recipe.peft = run.Config( + vlm.LoRA, + freeze_vision_model=False, + target_modules=[ + "*.language_model.*.linear_qkv", + "*.language_model.*.linear_q", + "*.language_model.*.linear_kv", + "*.language_model.*.linear_proj", + "*.language_model.*.linear_fc1", + "*.language_model.*.linear_fc2", + ], + ) + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + + return recipe diff --git a/nemo/collections/vlm/recipes/llava15_7b.py b/nemo/collections/vlm/recipes/llava15_7b.py new file mode 100644 index 000000000000..04e6bd36f4d4 --- /dev/null +++ b/nemo/collections/vlm/recipes/llava15_7b.py @@ -0,0 +1,152 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional + +import nemo_run as run +import pytorch_lightning as pl +import torch + +from nemo import lightning as nl +from nemo.collections import llm, vlm +from nemo.collections.llm.recipes.finetune_default import nemo_resume +from nemo.collections.llm.recipes.log.default import tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed +from nemo.collections.vlm.neva.data.mock import MockDataModule +from nemo.utils.exp_manager import TimingCallback + +NAME = "llava15_7b" + + +@run.cli.factory(name=NAME) +def model() -> run.Config[pl.LightningModule]: + """ + Factory function to create a Llava 1.5 7B model configuration. + + Returns: + run.Config[pl.LightningModule]: Configuration for the Llava 1.5 7B model model. + + Examples: + CLI usage: + $ nemo llm pretrain model=llava15_7b ... 
+ + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + return run.Config(vlm.LlavaModel, config=run.Config(vlm.Llava15Config7B)) + + +@run.cli.factory(target=llm.finetune, name=NAME) +def finetune_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'none', +) -> run.Partial: + """ + Create a fine-tuning recipe for Llava1.5 7B model. + + This function sets up a complete configuration for fine-tuning, including + model, trainer, data, logging, optimization, and resumption settings. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the fine-tuning run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + + Returns: + run.Partial: Partial configuration for fine-tuning. + + Examples: + CLI usage: + $ nemo llm finetune --factory llava15_7b + + Python API usage: + >>> recipe = finetune_recipe(name="llava15_7b_finetune", num_nodes=1) + >>> print(recipe) + + Note: + This recipe uses the SQuAD dataset for fine-tuning. For more information + on fine-tuning LLMs with NeMo, see the fine-tuning guide in the + `examples/llm/finetune/` directory. + """ + + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + encoder_pipeline_model_parallel_size=0, + pipeline_dtype=torch.bfloat16, + ) + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + devices=num_gpus_per_node, + limit_val_batches=10, + log_every_n_steps=1, + max_steps=5190, + num_nodes=num_nodes, + plugins=bf16_mixed(), + strategy=strategy, + val_check_interval=1000, + callbacks=[run.Config(TimingCallback)], + ) + + recipe = run.Partial( + llm.finetune, + model=model(), + trainer=trainer, + data=run.Config( + MockDataModule, + seq_length=4096, + global_batch_size=128, + micro_batch_size=2, + tokenizer=None, + image_processor=None, + num_workers=4, + ), + log=llm.default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing(max_lr=2.0e-05, min_lr=2.0e-07, warmup_steps=150), + resume=nemo_resume("llava-hf/llava-1.5-7b-hf"), + ) + + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.optim.config.lr = 2e-05 + elif peft_scheme.lower() == 'lora': + recipe.peft = run.Config( + vlm.LoRA, + freeze_vision_model=False, + target_modules=[ + "*.language_model.*.linear_qkv", + "*.language_model.*.linear_q", + "*.language_model.*.linear_kv", + "*.language_model.*.linear_proj", + "*.language_model.*.linear_fc1", + "*.language_model.*.linear_fc2", + ], + ) + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + + return recipe diff --git a/nemo/collections/vlm/recipes/mllama_11b.py b/nemo/collections/vlm/recipes/mllama_11b.py index e4842ae63d52..4b08606900e3 100644 --- a/nemo/collections/vlm/recipes/mllama_11b.py +++ b/nemo/collections/vlm/recipes/mllama_11b.py @@ -26,6 +26,7 @@ from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed from nemo.collections.vlm.mllama.data.mock import MockDataModule +from nemo.utils.exp_manager import 
TimingCallback NAME = "mllama_11b" @@ -46,7 +47,7 @@ def model() -> run.Config[pl.LightningModule]: >>> model_config = model() >>> print(model_config) """ - return run.Config(vlm.MLlamaModel, config=run.Config(vlm.MLlamaConfig11B)) + return run.Config(vlm.MLlamaModel, config=run.Config(vlm.MLlamaConfig11BInstruct)) @run.cli.factory(target=llm.finetune, name=NAME) @@ -107,6 +108,7 @@ def finetune_recipe( plugins=bf16_mixed(), strategy=strategy, val_check_interval=100, + callbacks=[run.Config(TimingCallback)], ) recipe = run.Partial( @@ -115,34 +117,37 @@ def finetune_recipe( trainer=trainer, data=run.Config( MockDataModule, - seq_length=4100, # encoder (vision) seq length - decoder_seq_length=512, # decoder (llm) seq length - global_batch_size=16, - micro_batch_size=2, + seq_length=6404, # encoder (vision) seq length + decoder_seq_length=2048, # decoder (llm) seq length + global_batch_size=2, + micro_batch_size=1, vocab_size=128256, - crop_size=(448, 448), + crop_size=(560, 560), num_workers=0, ), log=llm.default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), optim=distributed_fused_adam_with_cosine_annealing(max_lr=1e-4, min_lr=2.0e-07, warmup_steps=150), - resume=nemo_resume("meta-llama/Llama-3.2-11B-Vision"), + resume=nemo_resume("meta-llama/Llama-3.2-11B-Vision-Instruct"), ) if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 2 recipe.optim.config.lr = 2e-05 elif peft_scheme.lower() == 'lora': + # pylint: disable=line-too-long + """Adapted from https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/configs/peft.py""" recipe.peft = run.Config( vlm.LoRA, - freeze_vision_model=False, + freeze_vision_model=True, target_modules=[ - "*.language_model.*.linear_qkv", - "*.language_model.*.linear_q", - "*.language_model.*.linear_kv", - "*.language_model.*.linear_proj", - "*.language_model.*.linear_fc1", - "*.language_model.*.linear_fc2", + "linear_qkv", + "linear_q", + "linear_kv", ], + dim=8, + alpha=32, + dropout=0.05, + dropout_position="pre", ) recipe.optim.config.lr = 1e-4 else: diff --git a/nemo/collections/vlm/recipes/mllama_90b.py b/nemo/collections/vlm/recipes/mllama_90b.py index 28a6ff7ff9a6..12e0329fc6dd 100644 --- a/nemo/collections/vlm/recipes/mllama_90b.py +++ b/nemo/collections/vlm/recipes/mllama_90b.py @@ -26,6 +26,7 @@ from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed from nemo.collections.vlm.mllama.data.mock import MockDataModule +from nemo.utils.exp_manager import TimingCallback NAME = "mllama_90b" @@ -46,7 +47,7 @@ def model() -> run.Config[pl.LightningModule]: >>> model_config = model() >>> print(model_config) """ - return run.Config(vlm.MLlamaModel, config=run.Config(vlm.MLlamaConfig90B)) + return run.Config(vlm.MLlamaModel, config=run.Config(vlm.MLlamaConfig90BInstruct)) @run.cli.factory(target=llm.finetune, name=NAME) @@ -107,6 +108,7 @@ def finetune_recipe( plugins=bf16_mixed(), strategy=strategy, val_check_interval=100, + callbacks=[run.Config(TimingCallback)], ) recipe = run.Partial( @@ -116,7 +118,7 @@ def finetune_recipe( data=run.Config( MockDataModule, seq_length=6404, # encoder (vision) seq length - decoder_seq_length=512, # decoder (llm) seq length + decoder_seq_length=2048, # decoder (llm) seq length global_batch_size=16, micro_batch_size=2, vocab_size=128256, @@ -125,23 +127,26 @@ def finetune_recipe( ), log=llm.default_log(dir=dir, 
name=name, tensorboard_logger=tensorboard_logger(name=name)), optim=distributed_fused_adam_with_cosine_annealing(max_lr=1e-4, min_lr=2.0e-07, warmup_steps=150), - resume=nemo_resume("meta-llama/Llama-3.2-90B-Vision"), + resume=nemo_resume("meta-llama/Llama-3.2-90B-Vision-Instruct"), ) if peft_scheme is None or peft_scheme.lower() == 'none': raise ValueError("Full finetuning recipe for Llama-3.2-90B model will be supported soon.") elif peft_scheme.lower() == 'lora': + # pylint: disable=line-too-long + """Adapted from https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/configs/peft.py""" recipe.peft = run.Config( vlm.LoRA, - freeze_vision_model=False, + freeze_vision_model=True, target_modules=[ - "*.language_model.*.linear_qkv", - "*.language_model.*.linear_q", - "*.language_model.*.linear_kv", - "*.language_model.*.linear_proj", - "*.language_model.*.linear_fc1", - "*.language_model.*.linear_fc2", + "linear_qkv", + "linear_q", + "linear_kv", ], + dim=8, + alpha=32, + dropout=0.05, + dropout_position="pre", ) recipe.optim.config.lr = 1e-4 else: diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py index a1e6cb0e03c4..8f2b0db20341 100644 --- a/nemo/export/tensorrt_llm.py +++ b/nemo/export/tensorrt_llm.py @@ -30,10 +30,11 @@ import wrapt from tensorrt_llm._utils import numpy_to_torch +from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision from nemo.deploy import ITritonDeployable from nemo.export.tarutils import TarPath, unpack_tarball from nemo.export.trt_llm.converter.model_converter import model_to_trtllm_ckpt -from nemo.export.trt_llm.converter.model_to_trt_llm_ckpt import dist_model_to_trt_llm_ckpt +from nemo.export.trt_llm.converter.model_to_trt_llm_ckpt import dist_model_to_trt_llm_ckpt, get_layer_prefix from nemo.export.trt_llm.converter.utils import init_model_parallel_from_nemo from nemo.export.trt_llm.nemo_ckpt_loader.nemo_file import ( build_tokenizer, @@ -65,6 +66,8 @@ @wrapt.decorator def noop_decorator(func): + """No op decorator""" + def wrapper(*args, **kwargs): return func(*args, **kwargs) @@ -80,6 +83,7 @@ def wrapper(*args, **kwargs): use_pytriton = False +# pylint: disable=line-too-long class TensorRTLLM(ITritonDeployable): """ Exports nemo checkpoints to TensorRT-LLM and run fast inference. 
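For orientation before the export() hunks that follow: the class being modified here wraps checkpoint conversion, engine building, and Triton-facing inference in a single object. A rough usage sketch is below; the constructor and keyword names (model_dir, nemo_checkpoint_path, tensor_parallelism_size, max_output_len) are assumptions inferred from the methods touched in this diff, not a verbatim copy of the class API.

# Hypothetical usage sketch of the TensorRTLLM exporter modified below.
# Paths and keyword names are placeholders/assumptions, not taken verbatim from this diff.
from nemo.export.tensorrt_llm import TensorRTLLM

exporter = TensorRTLLM(model_dir="/tmp/trtllm_engine")    # directory that will hold the built engine
exporter.export(
    nemo_checkpoint_path="/path/to/model.nemo",           # placeholder checkpoint path
    model_type="llama",                                   # one of the supported decoder types
    tensor_parallelism_size=1,
)

# Run inference through the exported engine.
output = exporter.forward(["Tell me the capital of France."], max_output_len=32)
print(output)

The build()/refit() path extended in this PR follows the same engine-directory and forward() flow, but starts from an already-instantiated model-parallel NeMo model rather than a .nemo file.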
@@ -347,43 +351,14 @@ def export( DEFAULT_CONVERSION_DICT, ) from megatron.core.export.trtllm.trtllm_helper import TRTLLMHelper - from megatron.core.transformer.transformer_config import TransformerConfig from tensorrt_llm.layers import MoeConfig - def get_transformer_config(nemo_model_config): - normalization = nemo_model_config.get('normalization', 'layernorm') - transformer_config_normalization = 'LayerNorm' - layernorm_zero_centered_gamma = False - if normalization == 'layernorm1p': - layernorm_zero_centered_gamma = True - elif normalization == 'rmsnorm': - transformer_config_normalization = 'RMSNorm' - - conf = TransformerConfig( - num_layers=nemo_model_config.get('num_layers'), - moe_router_topk=nemo_model_config.get('moe_router_topk', 0), - num_attention_heads=nemo_model_config.get('num_attention_heads'), - num_query_groups=nemo_model_config.get( - 'num_query_groups', nemo_model_config['num_attention_heads'] - ), - kv_channels=nemo_model_config.get("kv_channels", None), - hidden_size=nemo_model_config.get('hidden_size'), - ffn_hidden_size=nemo_model_config.get('ffn_hidden_size'), - layernorm_epsilon=nemo_model_config.get('layernorm_epsilon'), - add_bias_linear=nemo_model_config.get('bias'), - num_moe_experts=nemo_model_config.get('num_moe_experts', None), - normalization=transformer_config_normalization, - layernorm_zero_centered_gamma=layernorm_zero_centered_gamma, - ) - - return conf - # We build the transformer config using the nemo model config. - transformer_config = get_transformer_config(model_configs) + transformer_config = self.get_transformer_config(model_configs) input_model_type = getattr(ModelType, model_type) # MCore export supports some default conversion dictionaries - mcore_model_conversion_dict = DEFAULT_CONVERSION_DICT[input_model_type] + mcore_model_conversion_dict = DEFAULT_CONVERSION_DICT # All Mcore conversion dicts start with "decoder.layers.4.blah.blah" , while nemo models start with "model.decoder.layers.4.blahblah". so we append model. 
to the keys nemo_model_conversion_dict = { f'model.{key}': value for key, value in mcore_model_conversion_dict.items() @@ -524,6 +499,34 @@ def get_transformer_config(nemo_model_config): if load_model: self._load() + def get_transformer_config(self, nemo_model_config): + """Given nemo model config get transformer config""" + from megatron.core.transformer.transformer_config import TransformerConfig + + normalization = nemo_model_config.get('normalization', 'layernorm') + transformer_config_normalization = 'LayerNorm' + layernorm_zero_centered_gamma = False + if normalization == 'layernorm1p': + layernorm_zero_centered_gamma = True + elif normalization == 'rmsnorm': + transformer_config_normalization = 'RMSNorm' + + conf = TransformerConfig( + num_layers=nemo_model_config.get('num_layers'), + moe_router_topk=nemo_model_config.get('moe_router_topk', 0), + num_attention_heads=nemo_model_config.get('num_attention_heads'), + num_query_groups=nemo_model_config.get('num_query_groups', nemo_model_config['num_attention_heads']), + kv_channels=nemo_model_config.get("kv_channels", None), + hidden_size=nemo_model_config.get('hidden_size'), + ffn_hidden_size=nemo_model_config.get('ffn_hidden_size'), + layernorm_epsilon=nemo_model_config.get('layernorm_epsilon'), + add_bias_linear=nemo_model_config.get('bias'), + num_moe_experts=nemo_model_config.get('num_moe_experts', None), + normalization=transformer_config_normalization, + layernorm_zero_centered_gamma=layernorm_zero_centered_gamma, + ) + return conf + def convert_to_safe_tensors( self, nemo_checkpoint_path: str, @@ -536,6 +539,7 @@ def convert_to_safe_tensors( use_embedding_sharing: bool = False, dtype: str = "bfloat16", ): + """Convert to safe tensor""" gpus_per_node = tensor_parallelism_size if gpus_per_node is None else gpus_per_node if Path(self.model_dir).exists(): @@ -601,6 +605,167 @@ def convert_to_safe_tensors( if tensorrt_llm.mpi_world_size() > 1: tensorrt_llm.mpi_barrier() + def gather_and_reshard_model(self, model_config, model, storage_dtype): + """ + Accumulate all vp model chunks together, and reshard model (i.e) gather all pp ranks + if required and return the final model state dict + """ + + def _get_layer_index(split_key): + for index, key in enumerate(split_key): + if key == "layers": + return index + 1 + raise ValueError(f"Unknown layer name format: {split_key}") + + def rename_layer_num(param_name, layer_num): + split_key = param_name.split(".") + layer_index = int(_get_layer_index(split_key)) + split_key[layer_index] = str(layer_num) + return ".".join(split_key) + + def get_layer_num(param_name): + split_key = param_name.split(".") + layer_index = int(_get_layer_index(split_key)) + return int(split_key[layer_index]) + + from megatron.core import parallel_state + + tp_size = parallel_state.get_tensor_model_parallel_world_size() + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + pp_first_rank = parallel_state.get_pipeline_model_parallel_first_rank() + pp_last_rank = parallel_state.get_pipeline_model_parallel_last_rank() + pp_size = parallel_state.get_pipeline_model_parallel_world_size() + pp_group = parallel_state.get_pipeline_model_parallel_group() + vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size() + if not vp_size: + vp_size = 1 + + inference_tp_size = self.tp_size + inference_pp_size = self.pp_size + reshard_model = False + if inference_tp_size != tp_size or inference_pp_size != pp_size: + LOGGER.info("Training/Generation model parallelism resharding enabled") + if inference_pp_size == 1 
and pp_size > 1 and inference_tp_size == tp_size: + reshard_model = True + else: + raise NotImplementedError( + f"NeMo currently only supports PP>1 -> PP=1 resharding, other types of resharding will come in future releases." + ) + + num_layers = model_config["num_layers"] + layers_per_pp = num_layers // pp_size + layers_per_chunk = layers_per_pp // vp_size + + tl_params = {} + model_level_params = {} + if vp_size > 1: # consolidate params across model chunks + for idx, model_chunk in enumerate(model): + for key, val in model_chunk.state_dict().items(): + if torch.is_tensor(val): + if 'layers' in key: + key2 = rename_layer_num(key, get_layer_num(key) + idx * pp_size * layers_per_chunk) + tl_params[key2] = val + else: + model_level_params[key] = val + else: + for key, val in model.state_dict().items(): + if torch.is_tensor(val): + if 'decoder.layers' in key: + tl_params[key] = val + else: + model_level_params[key] = val + + if vp_size > 1 or reshard_model: + # gather layers across pp ranks + gathered_params = {} + for key, val in tl_params.items(): + weight_list = [torch.zeros_like(val) for _ in range(pp_size)] + torch.distributed.all_gather(weight_list, val, group=pp_group) + for idx in range(pp_size): + layer_num = get_layer_num(key) + idx * layers_per_chunk + key2 = rename_layer_num(key, layer_num) + if not reshard_model: # Save only layers of 1 single PP stage + layers_start = layers_per_pp * pp_rank + layers_end = layers_per_pp * (pp_rank + 1) - 1 + if layer_num >= layers_start and layer_num <= layers_end: + key2 = rename_layer_num(key, layer_num % layers_per_pp) + gathered_params[key2] = weight_list[idx] + else: + gathered_params[key2] = weight_list[idx] + tl_params = gathered_params + + model_state_dict = model_level_params + model_state_dict.update(tl_params) + + def get_tensor_if_available(key, pp_src_idx, group): + tensor = model_state_dict.get(key) + if tensor is not None: + tensor_shape = [tensor.shape] + else: + tensor_shape = [None] + + torch.distributed.broadcast_object_list(tensor_shape, pp_src_idx, group=group) + + if tensor_shape[0] is None: + return None + if torch.distributed.get_rank() != pp_src_idx: + tensor = torch.empty(tensor_shape[0], dtype=storage_dtype).cuda() + + torch.distributed.broadcast(tensor.contiguous(), pp_src_idx, group=pp_group) + return tensor + + if reshard_model: + key = 'decoder.final_layernorm.weight' + tensor = get_tensor_if_available(key, pp_last_rank, pp_group) + if tensor is not None: + model_state_dict[key] = tensor + + key = 'decoder.final_layernorm.bias' + tensor = get_tensor_if_available(key, pp_last_rank, pp_group) + if tensor is not None: + model_state_dict[key] = tensor + + key = 'embedding.word_embeddings.weight' + tensor = get_tensor_if_available(key, pp_first_rank, pp_group) + if tensor is not None: + model_state_dict[key] = tensor + + key = 'output_layer.weight' + tensor = get_tensor_if_available(key, pp_last_rank, pp_group) + if tensor is not None: + model_state_dict[key] = tensor + + return model_state_dict + + def get_input_dtype(self, storage_dtype): + """ + Return mcore export dtype given torch dtype + """ + from megatron.core.export.data_type import DataType + + if storage_dtype == torch.bfloat16: + return DataType.bfloat16 + elif storage_dtype == torch.float32: + return DataType.float32 + elif storage_dtype == torch.float16: + return DataType.float16 + + def get_nemo_to_trtllm_conversion_dict(self, model_state_dict): + """MCore export supports some default conversion dictionaries + All Mcore conversion dicts start with 
"decoder.layers.4.blah.blah" , while nemo models sometimes start with "model.decoder.layers.4.blahblah". so we append model prefix. to the keys + """ + from megatron.core.export.trtllm.model_to_trllm_mapping.default_conversion_dict import DEFAULT_CONVERSION_DICT + + model_prefix, _ = get_layer_prefix(layer_names=model_state_dict.keys(), is_mcore=True) + + nemo_model_conversion_dict = {} + for key, value in DEFAULT_CONVERSION_DICT.items(): + if 'layers' in key: + nemo_model_conversion_dict[f'{model_prefix}.{key}'] = value + else: + nemo_model_conversion_dict[key] = value + return nemo_model_conversion_dict + def build( self, model, @@ -613,6 +778,7 @@ def build( max_batch_size: int = 4, use_refit: bool = True, reshard_model: bool = False, + use_mcore_path: bool = True, ): """ Convert a model parallel nemo model to TensorRT-LLM. @@ -627,31 +793,103 @@ def build( if self.dp_size > 1: self.model_dir = os.path.join(self.model_dir, f"dp_rank{self.dp_rank}") - weights, model_config = model_to_trtllm_ckpt( - model=model, - nemo_model_config=model_config, - nemo_export_dir=self.model_dir, - decoder_type=model_type, - tensor_parallel_size=self.tp_size, - pipeline_parallel_size=self.pp_size, - gpus_per_node=gpus_per_node, - use_parallel_embedding=True, - use_distributed_convert=True, - model_parallel_rank=self.mp_rank, - vocab_size=self.tokenizer.vocab_size, - ) + if use_mcore_path: + from megatron.core.export.model_type import ModelType + from megatron.core.export.trtllm.trtllm_helper import TRTLLMHelper + from tensorrt_llm.layers import MoeConfig + + storage_dtype = torch_dtype_from_precision(model_config.precision) + model_state_dict = self.gather_and_reshard_model(model_config, model, storage_dtype) + # We build the transformer config using the nemo model config. + transformer_config = self.get_transformer_config(model_config) + input_model_type = getattr(ModelType, model_type) + + nemo_model_conversion_dict = self.get_nemo_to_trtllm_conversion_dict(model_state_dict) + + self.trtllm_helper = TRTLLMHelper( + transformer_config=transformer_config, + model_type=input_model_type, + trtllm_conversion_dict=nemo_model_conversion_dict, + position_embedding_type=model_config.get('position_embedding_type'), + max_position_embeddings=model_config.get('max_position_embeddings'), + rotary_percentage=model_config.get('rotary_percentage', 1.0), + rotary_base=model_config.get('rotary_base', 10000), + moe_tp_mode=model_config.get('moe_tp_mode', 2), + multi_query_mode=model_config.get("multi_query_mode", False), + activation=model_config.get('activation', "gelu"), + seq_len_interpolation_factor=model_config.get("seq_len_interpolation_factor"), + moe_renorm_mode=model_config.get( + 'moe_renorm_mode', MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE + ), + share_embeddings_and_output_weights=model_config.get("share_embeddings_and_output_weights", False), + ) + + input_dtype = self.get_input_dtype(storage_dtype) + + trtllm_model_weights_list, trtllm_model_config_list = ( + self.trtllm_helper.get_trtllm_pretrained_config_and_model_weights( + model_state_dict=model_state_dict, + dtype=input_dtype, + state_dict_split_by_layer_numbers=True, + on_device_distributed_conversion=True, + vocab_size=self.tokenizer.vocab_size, + gpus_per_node=gpus_per_node, + ) + ) + trtllm_model_config = trtllm_model_config_list[0] + trtllm_model_weights = trtllm_model_weights_list[0] + + if reshard_model: + assert self.pp_size == 1, 'Reshard is true, but pp size is not one' + # MCORE Export will use parallel_state to determine pp . 
+ # Since we reshard to pp = 1, we need to modify the config and mapping + world_size = self.tp_size * self.pp_size + trtllm_model_config.pp_size = self.pp_size + trtllm_model_config.world_size = world_size + trtllm_model_config.mapping = tensorrt_llm.Mapping( + world_size=world_size, + rank=self.mp_rank, + tp_size=self.tp_size, + pp_size=self.pp_size, + ) + + engine = self.trtllm_helper.build_and_save_engine( + max_input_len=max_input_len, + max_output_len=max_output_len, + max_seq_len=max_input_len + max_output_len, + max_batch_size=max_batch_size, + trtllm_model_config=trtllm_model_config, + trtllm_model_weights=trtllm_model_weights, + engine_dir=self.model_dir, + use_refit=use_refit, + ) + else: + weights, model_config = model_to_trtllm_ckpt( + model=model, + nemo_model_config=model_config, + nemo_export_dir=self.model_dir, + decoder_type=model_type, + tensor_parallel_size=self.tp_size, + pipeline_parallel_size=self.pp_size, + gpus_per_node=gpus_per_node, + use_parallel_embedding=True, + use_distributed_convert=True, + model_parallel_rank=self.mp_rank, + vocab_size=self.tokenizer.vocab_size, + ) + + engine = build_and_save_engine( + max_input_len=max_input_len, + max_output_len=max_output_len, + max_seq_len=max_input_len + max_output_len, + max_batch_size=max_batch_size, + model_config=model_config[0], + model_weights=weights[0], + model_dir=self.model_dir, + model_type=model_type, + use_refit=use_refit, + ) - engine = build_and_save_engine( - max_input_len=max_input_len, - max_output_len=max_output_len, - max_seq_len=max_input_len + max_output_len, - max_batch_size=max_batch_size, - model_config=model_config[0], - model_weights=weights[0], - model_dir=self.model_dir, - model_type=model_type, - use_refit=use_refit, - ) torch.distributed.barrier() cfg_path = Path(os.path.join(self.model_dir, f'config_{torch.distributed.get_rank()}.json')) @@ -660,18 +898,33 @@ def build( load_distributed(self.model_dir, self.mp_rank, gpus_per_node) - def refit(self, model, model_config): + def refit(self, model, model_config, use_mcore_path=True): """ Refits an TensorRT engine using an instantiated nemo model. 
This function should only be used after calling build() """ - weights_dict = dist_model_to_trt_llm_ckpt( - model=model, - nemo_model_config=model_config, - inference_tp_size=self.tp_size, - inference_pp_size=self.pp_size, - tokenizer_vocab_size=self.tokenizer.vocab_size, - ) + weights_dict = None + if use_mcore_path: + storage_dtype = torch_dtype_from_precision(model_config.precision) + + model_state_dict = self.gather_and_reshard_model(model_config, model, storage_dtype) + + nemo_model_conversion_dict = self.get_nemo_to_trtllm_conversion_dict(model_state_dict) + self.trtllm_helper.weights_converter.convert( + model_state_dict=model_state_dict, + tokenizer_vocab_size=self.tokenizer.vocab_size, + trtllm_conversion_dict=nemo_model_conversion_dict, + ) + weights_dict = self.trtllm_helper.weights_converter.trtllm_model_weights + + else: + weights_dict = dist_model_to_trt_llm_ckpt( + model=model, + nemo_model_config=model_config, + inference_tp_size=self.tp_size, + inference_pp_size=self.pp_size, + tokenizer_vocab_size=self.tokenizer.vocab_size, + ) load_distributed(self.model_dir, self.mp_rank, self.gpus_per_node) gc.collect() torch.cuda.empty_cache() @@ -815,6 +1068,7 @@ def forward( ) def add_prompt_table(self, task_name: str, prompt_embeddings_checkpoint_path: str): + """Add prompt table""" if self.model is None: raise Exception( "A nemo checkpoint should be exported to TensorRT-LLM and " @@ -836,6 +1090,7 @@ def add_prompt_table(self, task_name: str, prompt_embeddings_checkpoint_path: st self._prep_ptuning_table() def remove_prompt_table(self, task_name: str): + """Remove prompt table""" if self.ptuning_tables is not None: for i in range(len(self.ptuning_tables)): if self.ptuning_tables[i]["task_name"] == task_name: @@ -847,11 +1102,13 @@ def remove_prompt_table(self, task_name: str): @property def get_supported_models_list(self): + """Supported model list""" # gpt and gptnext are the same. Keeping the gptnext due to backward compatibility. 
return ["gpt", "gptnext", "llama", "falcon", "starcoder", "mixtral", "gemma"] @property def get_hidden_size(self): + """Get hidden size""" if self.config is None: return None else: @@ -859,6 +1116,7 @@ def get_hidden_size(self): @property def get_triton_input(self): + """Get triton input""" inputs = ( Tensor(name="prompts", shape=(-1,), dtype=bytes), Tensor(name="max_output_len", shape=(-1,), dtype=np.int_, optional=True), @@ -885,6 +1143,7 @@ def get_triton_output(self): @batch def triton_infer_fn(self, **inputs: np.ndarray): + """Triton infer function for streaming""" output_dict = {} try: infer_input = {"input_texts": str_ndarray2list(inputs.pop("prompts"))} @@ -929,6 +1188,7 @@ def triton_infer_fn(self, **inputs: np.ndarray): @batch def triton_infer_fn_streaming(self, **inputs: np.ndarray): + """Triton infer function for streaming""" try: infer_input = {"input_texts": str_ndarray2list(inputs.pop("prompts"))} if "max_output_len" in inputs: @@ -1138,4 +1398,5 @@ def _load(self): ) from error def unload_engine(self): + """Unload engine""" unload_engine() diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py index 9ace6425f533..f3cb73811af1 100644 --- a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py +++ b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py @@ -318,14 +318,14 @@ def build_tokenizer(tokenizer): tokenizer.add_special_tokens({"eos_token": ""}) else: # For NeMo tokenizers, monkey patch encode & batch_decode methods for unified interface - from nemo.collections.common.tokenizers import AutoTokenizer, SentencePieceTokenizer, TokenizerSpec + import nemo.collections.common.tokenizers as nemo_tokenizers - if isinstance(tokenizer, TokenizerSpec): - if isinstance(tokenizer, AutoTokenizer): + if isinstance(tokenizer, nemo_tokenizers.TokenizerSpec): + if isinstance(tokenizer, nemo_tokenizers.AutoTokenizer): # Unwrap the original methods of HF tokenizer batch_decode = tokenizer.tokenizer.batch_decode encode = tokenizer.tokenizer.encode - elif isinstance(tokenizer, SentencePieceTokenizer): + elif isinstance(tokenizer, nemo_tokenizers.SentencePieceTokenizer): # Define HF equivalents based on available SP methods def batch_decode(self, ids): if torch.is_tensor(ids): @@ -340,8 +340,8 @@ def batch_decode(self, ids): tokenizer.bos_token_id = tokenizer.bos_id tokenizer.eos_token_id = tokenizer.eos_id - TokenizerSpec.encode = encode - TokenizerSpec.batch_decode = batch_decode + nemo_tokenizers.TokenizerSpec.encode = encode + nemo_tokenizers.TokenizerSpec.batch_decode = batch_decode return tokenizer diff --git a/nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py b/nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py index eac1ab743849..f601c8cb1c5a 100644 --- a/nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py +++ b/nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py @@ -18,7 +18,6 @@ import warnings from typing import List, Optional -import tensorrt_llm from tensorrt_llm.models import PretrainedConfig from nemo.export.trt_llm.qnemo.utils import CONFIG_NAME, WEIGHTS_NAME @@ -51,7 +50,7 @@ def qnemo_to_tensorrt_llm( warnings.warn( "Note that setting tensor_parallel_size, pipeline_parallel_size and use_parallel_embedding " - " parameters for quantized models is done on calibration step with nemo.export.quantize module." + " parameters for quantized models is done on the calibration step (in PTQ workflow)." 
" These parameters are ignored when building and running TensorRT-LLM engine below.", UserWarning, stacklevel=3, @@ -93,11 +92,7 @@ def qnemo_to_tensorrt_llm( build_cmd += f"--remove_input_padding {'enable' if remove_input_padding else 'disable'} " build_cmd += f"--multiple_profiles {'enable' if multiple_profiles else 'disable'} " build_cmd += f"--reduce_fusion {'enable' if reduce_fusion else 'disable'} " - # TODO: resolve version check for setting use_fused_mlp once we move to 0.13.0 in the NeMo container - if tensorrt_llm.__version__ >= "0.13.0": - build_cmd += f"--use_fused_mlp {'enable' if use_fused_mlp else 'disable'} " - else: - build_cmd += "--use_fused_mlp " if use_fused_mlp else "" + build_cmd += f"--use_fused_mlp {'enable' if use_fused_mlp else 'disable'} " if not use_qdq: build_cmd += f"--gemm_plugin auto " diff --git a/nemo/export/trt_llm/qnemo/tokenizer_utils.py b/nemo/export/trt_llm/qnemo/tokenizer_utils.py index beca40bcd3d7..5a6b6280d7c1 100644 --- a/nemo/export/trt_llm/qnemo/tokenizer_utils.py +++ b/nemo/export/trt_llm/qnemo/tokenizer_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os from omegaconf import OmegaConf @@ -24,22 +25,23 @@ TOKENIZER_CONFIG_FILE = "tokenizer_config.yaml" TOKENIZER_DIR = "tokenizer" +LOGGER = logging.getLogger("NeMo") def get_nmt_tokenizer(nemo_checkpoint_path: str): """Build tokenizer from Nemo tokenizer config.""" - print(f"Initializing tokenizer from {TOKENIZER_CONFIG_FILE}") + LOGGER.info(f"Initializing tokenizer from {TOKENIZER_CONFIG_FILE}") tokenizer_cfg = OmegaConf.load(os.path.join(nemo_checkpoint_path, TOKENIZER_CONFIG_FILE)) library = tokenizer_cfg.library legacy = tokenizer_cfg.get("sentencepiece_legacy", library == "sentencepiece") if library == "huggingface": - print(f"Getting HuggingFace AutoTokenizer with pretrained_model_name: {tokenizer_cfg.type}") + LOGGER.info(f"Getting HuggingFace AutoTokenizer with pretrained_model_name: {tokenizer_cfg.type}") tokenizer = AutoTokenizer.from_pretrained(tokenizer_cfg["type"], use_fast=tokenizer_cfg.get("use_fast", False)) elif library == "sentencepiece": - print(f"Getting SentencePieceTokenizer with model: {tokenizer_cfg.model}") + LOGGER.info(f"Getting SentencePieceTokenizer with model: {tokenizer_cfg.model}") tokenizer = SentencePieceTokenizer( model_path=os.path.join(nemo_checkpoint_path, tokenizer_cfg.model), legacy=legacy ) diff --git a/nemo/export/vllm_exporter.py b/nemo/export/vllm_exporter.py index 0ce7d49126d3..97575058bd1c 100644 --- a/nemo/export/vllm_exporter.py +++ b/nemo/export/vllm_exporter.py @@ -222,7 +222,6 @@ def export( max_num_seqs=256, # Note: max_model_len can be derived by model_config if the input value is None max_model_len=model_config.max_model_len, - use_v2_block_manager=False, num_lookahead_slots=0, delay_factor=0.0, enable_chunked_prefill=False, @@ -403,6 +402,7 @@ def get_triton_input(self): Tensor(name="top_p", shape=(-1,), dtype=numpy.single, optional=True), Tensor(name="temperature", shape=(-1,), dtype=numpy.single, optional=True), Tensor(name="lora_uids", shape=(-1,), dtype=bytes, optional=True), + Tensor(name="output_generation_logits", shape=(-1,), dtype=numpy.bool_, optional=True), ) return inputs @@ -455,6 +455,7 @@ def forward( prompt_embeddings_checkpoint_path: Optional[str] = None, streaming: bool = False, output_log_probs: bool = False, + output_generation_logits: bool = False, ) -> Union[List[List[str]], Iterable[List[List[str]]]]: """ The 
forward function performs LLM evaluation on the provided array of prompts with other parameters shared, @@ -484,6 +485,9 @@ def forward( if output_log_probs: raise NotImplementedError("output_log_probs is not supported") + if output_generation_logits: + raise NotImplementedError("output_generation_logits is not supported") + request_ids = [] for index in range(len(input_texts)): prompt = input_texts[index] diff --git a/nemo/lightning/io/state.py b/nemo/lightning/io/state.py index 6632768ec8dd..f2c26aa4d495 100644 --- a/nemo/lightning/io/state.py +++ b/nemo/lightning/io/state.py @@ -242,7 +242,12 @@ def __call__(self, ctx: TransformCTX) -> TransformCTX: source_matches_dict = {k: _match_keys(list(source_dict.keys()), v) for k, v in source_key_dict.items()} target_matches = _match_keys(list(target_dict.keys()), target_key) param_names = list(filter(lambda x: x in source_matches_dict, fn_params)) - for layer_names_group in zip(*([source_matches_dict[v] for v in param_names] + [target_matches])): + source_matches = [ + source_matches_dict[v] if source_matches_dict[v].ndim > 0 else [source_matches_dict[v].item()] + for v in param_names + ] + target_matches = [target_matches if target_matches.ndim > 0 else [target_matches.item()]] + for layer_names_group in zip(*(source_matches + target_matches)): # Wrap in a list if it's a single layer (ie non-expert) if isinstance(layer_names_group[0], str): layer_names_group = [[x] for x in layer_names_group] diff --git a/nemo/lightning/pytorch/accelerate/transformer_engine.py b/nemo/lightning/pytorch/accelerate/transformer_engine.py index a96a9590ea75..8e621352d099 100755 --- a/nemo/lightning/pytorch/accelerate/transformer_engine.py +++ b/nemo/lightning/pytorch/accelerate/transformer_engine.py @@ -39,7 +39,7 @@ def te_accelerate(model, fp8_autocast=False): @torch.no_grad def _apply_basic_module_replacement(model): - for name, module in model.named_modules(): + for name, module in model.named_children(): if isinstance(module, torch.nn.Linear): has_bias = module.bias is not None if any(p % 16 != 0 for p in module.weight.shape): @@ -51,17 +51,19 @@ def _apply_basic_module_replacement(model): if has_bias: te_module.bias.copy_(module.bias) - setattr(module, name.split(".")[-1], te_module) + setattr(model, name, te_module) elif isinstance(module, torch.nn.LayerNorm): te_module = te.LayerNorm(module.normalized_shape[0], eps=module.eps, params_dtype=module.weight.dtype) te_module.weight.copy_(module.weight) te_module.bias.copy_(module.bias) - setattr(module, name.split(".")[-1], te_module) + setattr(model, name, te_module) elif isinstance(module, torch.nn.RMSNorm): te_module = te.RMSNorm(module.normalized_shape[0], eps=module.eps, dtype=module.weight.dtype) te_module.weight.copy_(module.weight) te_module.bias.copy_(module.bias) - setattr(module, name.split(".")[-1], te_module) + setattr(model, name, te_module) + else: + _apply_basic_module_replacement(module) def is_te_accelerated(model): diff --git a/requirements/requirements_multimodal.txt b/requirements/requirements_multimodal.txt index 18abe82c9f96..aa33b3b55127 100644 --- a/requirements/requirements_multimodal.txt +++ b/requirements/requirements_multimodal.txt @@ -5,7 +5,7 @@ diffusers>=0.19.3 einops_exts imageio kornia -megatron-energon +megatron-energon<3.0.0 nerfacc>=0.5.3 open_clip_torch==2.24.0 PyMCubes diff --git a/requirements/requirements_vllm.txt b/requirements/requirements_vllm.txt index 414e05078680..6f5c8880f632 100644 --- a/requirements/requirements_vllm.txt +++ b/requirements/requirements_vllm.txt 
@@ -1 +1,19 @@ -vllm==0.5.3.post1 +# Minimal set of NeMo requirements to run vLLM export & deployment in /opt/venv in a NeMo container +braceexpand +faiss-cpu +h5py +hydra-core>1.3,<=1.3.2 +ijson +jieba +lightning>2.2.1 +matplotlib>=3.3.2 +omegaconf<=2.3 +onnx>=1.7.0 +OpenCC +pangu +rouge_score +sacrebleu +scikit-learn +vllm==0.6.3 +webdataset>=0.2.86 +wget diff --git a/scripts/checkpoint_converters/convert_nemotron_nemo_to_hf.py b/scripts/checkpoint_converters/convert_nemotron_nemo_to_hf.py index 392e3628ccdb..2f66773f8724 100644 --- a/scripts/checkpoint_converters/convert_nemotron_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_nemotron_nemo_to_hf.py @@ -21,7 +21,7 @@ import torch from lightning.pytorch import Trainer from transformers import LlamaTokenizer, PreTrainedTokenizerFast -from transformers.convert_slow_tokenizer import LlamaConverter +from transformers.convert_slow_tokenizer import LlamaConverter, TikTokenConverter from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel @@ -130,6 +130,20 @@ def convert_hf_config(nemo_config, tokenizer, vocab_size, dtype, hf_output_path, json.dump(hf_config, open(f"{hf_output_path}/config.json", "w"), indent=2) +def convert_tiktoken(vocab_file) -> None: + with open(vocab_file, 'r') as f: + vocab = json.load(f) + os.remove(vocab_file) + + lines = [] + for line in vocab: + lines.append(f"{line['token_bytes']} {line['rank']}") + + for line in lines: + with open(vocab_file, 'a') as f: + f.write(line + '\n') + + def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> None: """ Convert NeMo weights to HF weights @@ -323,6 +337,28 @@ def extract_nemotron_tokenizer(nemo_file, model_config, output_hf_path, nemo_tok ) tokenizer.save_pretrained(output_hf_path) logging.info(f"Setencepiece tokenizer has been saved to {output_tokenizer}") + elif tokenizer_cfg.library == "tiktoken": + tokenizer_fn = tokenizer_cfg.model[5:] + special_tokens = ["", "", ""] + import tarfile + + archive = tarfile.open(nemo_file, "r") + tokenizer_filename = "./" + tokenizer_fn # exclude 'nemo:' prefix + archive.extract(tokenizer_filename, output_hf_path) + archive.close() + vocab_file = os.path.join(output_hf_path, tokenizer_fn) + convert_tiktoken(vocab_file) + converted_tokenizer = TikTokenConverter( + vocab_file=vocab_file, additional_special_tokens=special_tokens + ).converted() + os.remove(vocab_file) + tokenizer = PreTrainedTokenizerFast( + tokenizer_object=converted_tokenizer, + model_input_names=["input_ids", "attention_mask"], + bos_token="", + eos_token="", + ) + tokenizer.save_pretrained(output_hf_path) elif isinstance(nemo_tokenizer, AutoTokenizer): nemo_tokenizer.tokenizer.save_pretrained(output_hf_path) logging.info(f"HF AutoTokenizer has been saved to {output_hf_path}") diff --git a/scripts/deploy/nlp/deploy_vllm_triton.py b/scripts/deploy/nlp/deploy_vllm_triton.py index ab9f13a1b8da..a3cf5e8ec762 100755 --- a/scripts/deploy/nlp/deploy_vllm_triton.py +++ b/scripts/deploy/nlp/deploy_vllm_triton.py @@ -41,7 +41,7 @@ def get_args(argv): "-mt", "--model_type", type=str, - required=False, + required=True, choices=["llama", "mistral", "mixtral", "starcoder2", "gemma"], help="Type of the model", ) diff --git a/scripts/llm/ptq.py b/scripts/llm/ptq.py index c04d32290e5f..2afe38c37b4d 100644 --- a/scripts/llm/ptq.py +++ b/scripts/llm/ptq.py @@ -17,6 +17,8 @@ def get_args(): + """Parses PTQ arguments""" + parser = 
argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="NeMo PTQ argument parser", @@ -58,6 +60,10 @@ def get_args(): type=str, help='Calibration dataset to be used. Should be \"wikitext\", \"cnn_dailymail\" or path to a local .json file', ) + parser.add_argument( + '--generate_sample', help='Generate sample model output after performing PTQ', action='store_true' + ) + parser.set_defaults(generate_sample=False) args = parser.parse_args() if args.output_path is None: @@ -68,6 +74,8 @@ def get_args(): def main(): + """Example NeMo 2.0 Post Training Quantization workflow""" + args = get_args() quantization_config = quantization.QuantizationConfig( @@ -87,6 +95,7 @@ def main(): inference_tensor_parallel=args.tensor_parallelism_size, inference_pipeline_parallel=args.pipeline_parallelism_size, dtype=args.dtype, + generate_sample=args.generate_sample, ) quantizer = quantization.Quantizer(quantization_config, export_config) diff --git a/scripts/vlm/mllama_finetune.py b/scripts/vlm/mllama_finetune.py new file mode 100644 index 000000000000..2b6990a03aa5 --- /dev/null +++ b/scripts/vlm/mllama_finetune.py @@ -0,0 +1,212 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import torch +from megatron.core.optimizer import OptimizerConfig +from pytorch_lightning.loggers import WandbLogger +from transformers import AutoProcessor + +from nemo import lightning as nl +from nemo.collections import llm, vlm +from nemo.collections.vlm import ImageDataConfig +from nemo.collections.vlm.mllama.data.lazy import MLlamaLazyDataModule +from nemo.lightning.pytorch.optim import CosineAnnealingScheduler +from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule +from nemo.utils.exp_manager import TimingCallback + + +def main(args): + """ + Main function for setting up and training the MLLama model. + + This function prepares the data module, model, training strategy, + logger, checkpointing, and optimizer configuration. It then starts + the training loop using PyTorch Lightning's trainer. + + Args: + args (argparse.Namespace): The command-line arguments passed to the script. 
+ """ + # Setting gbs, mbs, and max_steps from arguments + gbs = args.gbs + mbs = args.mbs + max_steps = args.max_steps + + # encoder (vision) seq length + # ((img_res / patch_size) ** 2 + cls_token) * num_tiles, = ((560 / 14) ** 2 + 1) * 4 = 6404 + seq_length = 6404 + decoder_seq_length = 1024 # decoder (llm) seq length + + if args.restore_path is not None and args.restore_path.startswith("nemo://"): + model_id = args.restore_path[len("nemo://") :] + else: + model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" + + processor = AutoProcessor.from_pretrained(model_id) + image_processor = processor.image_processor + tokenizer = processor.tokenizer + + # Data configuration + data_config = ImageDataConfig( + image_folder=args.image_folder, + conv_template="mllama", + ) + + # Data module setup + data = MLlamaLazyDataModule( + paths=args.data_path, + data_config=data_config, + seq_length=seq_length, + decoder_seq_length=decoder_seq_length, + global_batch_size=gbs, + micro_batch_size=mbs, + tokenizer=tokenizer, + image_processor=image_processor, + num_workers=16, + ) + + model_configs = { + "meta-llama/Llama-3.2-11B-Vision": vlm.MLlamaConfig11B, + "meta-llama/Llama-3.2-11B-Vision-Instruct": vlm.MLlamaConfig11BInstruct, + "meta-llama/Llama-3.2-90B-Vision": vlm.MLlamaConfig90B, + "meta-llama/Llama-3.2-90B-Vision-Instruct": vlm.MLlamaConfig90BInstruct, + } + conf = model_configs[model_id]() + if args.pp_size > 1: + conf.language_model_config.first_pipeline_num_layers = 0 + model = vlm.MLlamaModel(conf, tokenizer=tokenizer) + + # Training strategy setup + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=args.tp_size, + pipeline_model_parallel_size=args.pp_size, + encoder_pipeline_model_parallel_size=args.encoder_pp_size, + pipeline_dtype=torch.bfloat16, + ) + + # Checkpoint callback setup + checkpoint_callback = nl.ModelCheckpoint( + save_last=True, + monitor="reduced_train_loss", + save_top_k=6, + every_n_train_steps=100, + dirpath=args.log_dir, + ) + + # Trainer setup + trainer = nl.Trainer( + num_nodes=args.num_nodes, + devices=args.devices, + max_steps=max_steps, + accelerator="gpu", + strategy=strategy, + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), + callbacks=[checkpoint_callback, TimingCallback()], + val_check_interval=500, + limit_val_batches=gbs, + log_every_n_steps=1, + num_sanity_val_steps=0, + ) + + # Logger setup + nemo_logger = nl.NeMoLogger( + log_dir=args.log_dir, + name=args.name, + wandb=WandbLogger(project=args.wandb_project, name=args.name) if args.wandb_project is not None else None, + ) + + # Auto resume setup + resume = nl.AutoResume( + resume_if_exists=True, + resume_ignore_no_checkpoint=True, + resume_from_directory=args.log_dir, + restore_config=nl.RestoreConfig(path=args.restore_path) if args.restore_path is not None else None, + ) + + # Optimizer and scheduler setup + opt_config = OptimizerConfig( + optimizer='adam', + lr=args.lr, + adam_beta1=0.9, + adam_beta2=0.95, + use_distributed_optimizer=True, + bf16=True, + ) + sched = CosineAnnealingScheduler( + max_steps=trainer.max_steps, + warmup_steps=100, + constant_steps=0, + min_lr=args.lr, + ) + opt = MegatronOptimizerModule(opt_config, sched) + + # PEFT setup + if args.peft == 'lora': + peft = vlm.peft.LoRA( + freeze_vision_model=True, + target_modules=[ + "linear_qkv", + "linear_q", + "linear_kv", + ], + dim=8, + alpha=32, + dropout=0.05, + dropout_position="pre", + ) + else: + peft = None + + llm.finetune( + model=model, + data=data, + trainer=trainer, + peft=peft, + log=nemo_logger, + 
optim=opt, + resume=resume, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Mllama Model Training Script") + + parser.add_argument( + "--restore_path", type=str, required=False, default=None, help="Path to restore model from checkpoint" + ) + parser.add_argument("--data_path", type=str, required=True, help="Path to the dataset") + parser.add_argument("--image_folder", type=str, required=True, help="Path to the image folder") + parser.add_argument( + "--log_dir", + type=str, + required=False, + default="/results", + help="Directory for logging and checkpoints", + ) + parser.add_argument("--devices", type=int, required=False, default=1) + parser.add_argument("--num_nodes", type=int, required=False, default=1) + parser.add_argument("--max_steps", type=int, required=False, default=5190) + parser.add_argument("--tp_size", type=int, required=False, default=1) + parser.add_argument("--pp_size", type=int, required=False, default=1) + parser.add_argument("--encoder_pp_size", type=int, required=False, default=0) + parser.add_argument("--name", type=str, required=False, default="neva_pretrain") + parser.add_argument("--peft", type=str, default='none', help="none | lora") + parser.add_argument("--wandb_project", type=str, required=False, default=None) + parser.add_argument("--gbs", type=int, required=False, default=64, help="Global batch size") + parser.add_argument("--mbs", type=int, required=False, default=2, help="Micro batch size") + parser.add_argument("--lr", type=float, required=False, default=2.0e-06, help="Learning rate") + + args = parser.parse_args() + main(args) diff --git a/scripts/vlm/mllama_generation.py b/scripts/vlm/mllama_generation.py new file mode 100644 index 000000000000..4ebf2d0055ad --- /dev/null +++ b/scripts/vlm/mllama_generation.py @@ -0,0 +1,164 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
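Stepping back to scripts/vlm/mllama_finetune.py above: the hard-coded encoder sequence length of 6404 follows from the image geometry noted in its comment (560-pixel tiles, 14-pixel patches, one CLS token per tile, four tiles per image). A minimal sanity check of that arithmetic, using only the values from that comment:

# Encoder (vision) sequence length used in mllama_finetune.py above.
# Values come from the comment in that script: 560 px tiles, 14 px patches,
# one CLS token per tile, and 4 tiles per image.
img_res, patch_size, num_tiles = 560, 14, 4
seq_length = ((img_res // patch_size) ** 2 + 1) * num_tiles  # (40**2 + 1) * 4
assert seq_length == 6404

The same value appears as seq_length in the mllama_11b and mllama_90b recipe data configs earlier in this diff, matching the 560x560 crop size set in the mllama_11b recipe.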
+ +import argparse + +import requests +import torch +from PIL import Image +from transformers import AutoProcessor + +from nemo import lightning as nl +from nemo.collections import vlm +from nemo.collections.vlm.mllama.model.utils import create_vision_mask_tensor + +model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" + + +def load_image(image_url: str) -> Image.Image: + # pylint: disable=C0115,C0116 + try: + response = requests.get(image_url, stream=True) + response.raise_for_status() + image = Image.open(response.raw) + return image + except requests.exceptions.RequestException as e: + print(f"Error loading image from {image_url}: {e}") + return None + + +def generate(model, processor, image, text): + # pylint: disable=C0115,C0116 + tokenizer = processor.tokenizer + + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": text}], + } + ] + input_text = processor.apply_chat_template(messages, add_generation_prompt=True) + batch = processor(image, input_text, add_special_tokens=False, return_tensors="pt") + + input_ids = batch["input_ids"].cuda(non_blocking=True) + position_ids = ( + torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device).unsqueeze(0).expand_as(input_ids) + ) + num_tiles = processor.image_processor.preprocess(image, return_tensors='pt')["num_tiles"] + + min_prompt_len = position_ids.shape[-1] + + input_ids = input_ids[:, :min_prompt_len] + generated_ids = input_ids.clone() + + from tqdm import tqdm + + for cur_pos in tqdm(range(min_prompt_len, min_prompt_len + 100)): + with torch.no_grad(): + position_ids = torch.arange(0, cur_pos, dtype=torch.long, device="cuda").reshape(1, -1) + batch_masks = create_vision_mask_tensor(generated_ids[0]) + + output = model( + batch_images=batch["pixel_values"].cuda(non_blocking=True), + batch_masks=[batch_masks], + num_chunks=torch.tensor(num_tiles), + aspect_ratio_ids=batch["aspect_ratio_ids"].cuda(non_blocking=True), + tokens=generated_ids, + position_ids=position_ids, + ) + + next_token_ids = torch.argmax(output[:, -1], dim=-1, keepdim=True) + # Broadcast the tensor from rank 0 to all other ranks + torch.distributed.broadcast(next_token_ids, src=0) + generated_ids = torch.cat([generated_ids, next_token_ids], dim=-1) + if (next_token_ids == tokenizer.eos_token_id).all(): + break + + generated_ids = generated_ids.tolist() + generated_texts = tokenizer.decode(generated_ids[0][min_prompt_len:]) + + if torch.distributed.get_rank() == 0: + print("======== GENERATED TEXT OUTPUT ========") + print(f"{generated_texts}") + print("=======================================") + return generated_texts + + +def main(args) -> None: + # pylint: disable=C0115,C0116 + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=args.tp_size, + ckpt_load_optimizer=False, + ckpt_save_optimizer=False, + ) + trainer = nl.Trainer( + devices=args.tp_size, + max_steps=1000, + accelerator="gpu", + strategy=strategy, + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), + val_check_interval=1000, + limit_val_batches=50, + ) + + processor = AutoProcessor.from_pretrained(model_id) + tokenizer = processor.tokenizer + + fabric = trainer.to_fabric() + + if args.load_from_hf: + model = fabric.import_model(f"hf://{model_id}", vlm.MLlamaModel) + else: + model = vlm.MLlamaModel(vlm.MLlamaConfig11BInstruct(), tokenizer=tokenizer) + model = fabric.load_model(args.local_model_path, model) + + model = model.module.cuda() + model.eval() + model = model.to(torch.bfloat16) + + # Load the image + raw_image = load_image(args.image_url) + 
if raw_image is None: + return # Exit if the image can't be loaded + + generate(model, processor, image=raw_image, text="<|image|>\nDescribe the image.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="") + parser.add_argument( + "--load_from_hf", + action="store_true", + help="Flag to indicate whether to load the model from Hugging Face hub.", + ) + parser.add_argument( + "--local_model_path", + type=str, + default=None, + help="Local path to the model if not loading from Hugging Face.", + ) + parser.add_argument( + "--image_url", + type=str, + # pylint: disable=line-too-long + default="https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + help="URL of the image to use for inference.", + ) + parser.add_argument("--devices", type=int, required=False, default=1) + parser.add_argument("--tp_size", type=int, required=False, default=1) + parser.add_argument("--pp_size", type=int, required=False, default=1) + parser.add_argument("--encoder_pp_size", type=int, required=False, default=0) + + args = parser.parse_args() + main(args) diff --git a/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py b/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py index 247906247091..02442291a918 100644 --- a/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py +++ b/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py @@ -19,9 +19,12 @@ import pytest import torch +from lhotse import CutSet, MonoCut +from lhotse.testing.dummies import DummyManifest from omegaconf import DictConfig from nemo.collections.asr.data import audio_to_text +from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset from nemo.collections.asr.models import configs from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE from nemo.collections.asr.parts.submodules import ctc_beam_decoding as beam_decode @@ -118,6 +121,18 @@ def test_forward(self, asr_model): diff = torch.max(torch.abs(logprobs_instance - logprobs_batch)) assert diff <= 1e-6 + @pytest.mark.unit + def test_predict_step(self, asr_model): + asr_model = asr_model.eval() + cuts = DummyManifest(CutSet, begin_id=0, end_id=1, with_data=True) + dataset = LhotseSpeechToTextBpeDataset(tokenizer=asr_model.tokenizer, return_cuts=True) + batch = dataset[cuts] + outputs = asr_model.predict_step(batch, 0) + assert len(outputs) == 1 + assert len(outputs[0]) == 2 + assert isinstance(outputs[0][0], MonoCut) + assert isinstance(outputs[0][1], str) + @pytest.mark.with_downloads() @pytest.mark.unit def test_save_restore_artifact(self, asr_model): @@ -333,6 +348,15 @@ def test_ASRDatasetConfig_for_AudioToBPEDataset(self): 'bucketing_strategy', 'bucketing_weights', 'channel_selector', + 'use_lhotse', + 'tarred_random_access', + 'use_bucketing', + 'batch_duration', + 'quadratic_duration', + 'bucket_batch_size', + 'bucket_duration_bins', + 'num_buckets', + 'pin_memory', ] REMAP_ARGS = {'trim_silence': 'trim', 'labels': 'tokenizer'} @@ -372,6 +396,15 @@ def test_ASRDatasetConfig_for_TarredAudioToBPEDataset(self): 'bucketing_strategy', 'bucketing_weights', 'max_utts', + 'use_lhotse', + 'tarred_random_access', + 'use_bucketing', + 'batch_duration', + 'quadratic_duration', + 'bucket_batch_size', + 'bucket_duration_bins', + 'num_buckets', + 'pin_memory', ] REMAP_ARGS = { diff --git a/tests/collections/asr/test_asr_ctcencdec_model.py b/tests/collections/asr/test_asr_ctcencdec_model.py index 28a07fd54663..55451758578f 100644 --- 
a/tests/collections/asr/test_asr_ctcencdec_model.py +++ b/tests/collections/asr/test_asr_ctcencdec_model.py @@ -15,12 +15,16 @@ import pytest import torch +from lhotse import CutSet, MonoCut +from lhotse.testing.dummies import DummyManifest from omegaconf import DictConfig, OmegaConf, open_dict import nemo.collections.asr as nemo_asr from nemo.collections.asr.data import audio_to_text +from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset from nemo.collections.asr.models import EncDecCTCModel, configs from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecoding, CTCDecodingConfig +from nemo.collections.common.parts.preprocessing.parsers import make_parser from nemo.utils.config_utils import assert_dataclass_signature_match, update_model_config @@ -131,6 +135,19 @@ def test_forward(self, asr_model): diff = torch.max(torch.abs(logprobs_instance - logprobs_batch)) assert diff <= 1e-6 + @pytest.mark.unit + def test_predict_step(self, asr_model): + token_list = [" ", "a", "b", "c"] + asr_model = asr_model.eval() + cuts = DummyManifest(CutSet, begin_id=0, end_id=1, with_data=True) + dataset = LhotseSpeechToTextBpeDataset(tokenizer=make_parser(labels=token_list), return_cuts=True) + batch = dataset[cuts] + outputs = asr_model.predict_step(batch, 0) + assert len(outputs) == 1 + assert len(outputs[0]) == 2 + assert isinstance(outputs[0][0], MonoCut) + assert isinstance(outputs[0][1], str) + @pytest.mark.unit def test_vocab_change(self, asr_model): old_vocab = copy.deepcopy(asr_model.decoder.vocabulary) @@ -274,6 +291,15 @@ def test_ASRDatasetConfig_for_AudioToCharDataset(self): 'bucketing_strategy', 'bucketing_weights', 'channel_selector', + 'use_lhotse', + 'tarred_random_access', + 'use_bucketing', + 'batch_duration', + 'quadratic_duration', + 'bucket_batch_size', + 'bucket_duration_bins', + 'num_buckets', + 'pin_memory', ] REMAP_ARGS = {'trim_silence': 'trim'} @@ -307,6 +333,15 @@ def test_ASRDatasetConfig_for_TarredAudioToCharDataset(self): 'bucketing_strategy', 'bucketing_weights', 'max_utts', + 'use_lhotse', + 'tarred_random_access', + 'use_bucketing', + 'batch_duration', + 'quadratic_duration', + 'bucket_batch_size', + 'bucket_duration_bins', + 'num_buckets', + 'pin_memory', ] REMAP_ARGS = { diff --git a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py index 1743acc6878c..d13c879e47f9 100644 --- a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py +++ b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py @@ -18,8 +18,11 @@ import pytest import torch +from lhotse import CutSet, MonoCut +from lhotse.testing.dummies import DummyManifest from omegaconf import DictConfig +from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset from nemo.collections.asr.models.hybrid_rnnt_ctc_bpe_models import EncDecHybridRNNTCTCBPEModel from nemo.collections.asr.parts.submodules import rnnt_beam_decoding as beam_decode from nemo.collections.asr.parts.submodules import rnnt_greedy_decoding as greedy_decode @@ -166,6 +169,18 @@ def test_forward(self, hybrid_asr_model): diff = torch.max(torch.abs(logits_instance - logprobs_batch)) assert diff <= 1e-6 + @pytest.mark.unit + def test_predict_step(self, hybrid_asr_model): + hybrid_asr_model = hybrid_asr_model.eval() + cuts = DummyManifest(CutSet, begin_id=0, end_id=1, with_data=True) + dataset = LhotseSpeechToTextBpeDataset(tokenizer=hybrid_asr_model.tokenizer, return_cuts=True) + batch = dataset[cuts] + 
outputs = hybrid_asr_model.predict_step(batch, 0) + assert len(outputs) == 1 + assert len(outputs[0]) == 2 + assert isinstance(outputs[0][0], MonoCut) + assert isinstance(outputs[0][1], str) + @pytest.mark.with_downloads() @pytest.mark.skipif( not NUMBA_RNNT_LOSS_AVAILABLE, diff --git a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py index 5362966e2e9e..b5c34e197237 100644 --- a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py +++ b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py @@ -16,14 +16,18 @@ import pytest import torch +from lhotse import CutSet, MonoCut +from lhotse.testing.dummies import DummyManifest from omegaconf import DictConfig, ListConfig +from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset from nemo.collections.asr.models import EncDecHybridRNNTCTCModel from nemo.collections.asr.modules import RNNTDecoder, RNNTJoint, SampledRNNTJoint, StatelessTransducerDecoder from nemo.collections.asr.parts.submodules import rnnt_beam_decoding as beam_decode from nemo.collections.asr.parts.submodules import rnnt_greedy_decoding as greedy_decode from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecoding, CTCDecodingConfig from nemo.collections.asr.parts.utils import rnnt_utils +from nemo.collections.common.parts.preprocessing.parsers import make_parser from nemo.core.utils import numba_utils from nemo.core.utils.numba_utils import __NUMBA_MINIMUM_VERSION__ from nemo.utils.config_utils import assert_dataclass_signature_match @@ -164,6 +168,19 @@ def test_forward(self, hybrid_asr_model): diff = torch.max(torch.abs(logprobs_instance - logprobs_batch)) assert diff <= 1e-6 + @pytest.mark.unit + def test_predict_step(self, hybrid_asr_model): + token_list = [" ", "a", "b", "c"] + hybrid_asr_model = hybrid_asr_model.eval() + cuts = DummyManifest(CutSet, begin_id=0, end_id=1, with_data=True) + dataset = LhotseSpeechToTextBpeDataset(tokenizer=make_parser(labels=token_list), return_cuts=True) + batch = dataset[cuts] + outputs = hybrid_asr_model.predict_step(batch, 0) + assert len(outputs) == 1 + assert len(outputs[0]) == 2 + assert isinstance(outputs[0][0], MonoCut) + assert isinstance(outputs[0][1], str) + @pytest.mark.skipif( not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', diff --git a/tests/collections/asr/test_asr_lhotse_dataset.py b/tests/collections/asr/test_asr_lhotse_dataset.py index 5a1450e606ac..c131fac70310 100644 --- a/tests/collections/asr/test_asr_lhotse_dataset.py +++ b/tests/collections/asr/test_asr_lhotse_dataset.py @@ -65,3 +65,35 @@ def test_lhotse_asr_dataset(tokenizer): assert tokens[2].tolist() == [1, 7, 10, 19, 20, 21, 1, 20, 6, 4, 16, 15, 5] assert token_lens.tolist() == [11, 11, 13] + + +def test_lhotse_asr_dataset_metadata(tokenizer): + + cuts = DummyManifest(CutSet, begin_id=0, end_id=2, with_data=True) + + cuts[0].id = "cuts0" + cuts[1].id = "cuts1" + cuts[0].supervisions = [ + SupervisionSegment(id="cuts0-sup0", recording_id=cuts[0].recording_id, start=0.2, duration=0.5, text="first"), + ] + cuts[1].supervisions = [ + SupervisionSegment(id="cuts1-sup0", recording_id=cuts[1].recording_id, start=0, duration=1, text=""), + ] + + datasets_metadata = LhotseSpeechToTextBpeDataset(tokenizer=tokenizer, return_cuts=True) + batch = datasets_metadata[cuts] + assert isinstance(batch, tuple) + assert len(batch) == 5 + + _, _, _, _, cuts_metadata = batch + + assert 
cuts_metadata[0].supervisions[0].text == "first" + assert cuts_metadata[1].supervisions[0].text == "" + assert cuts_metadata[0].id == "cuts0" + assert cuts_metadata[1].id == "cuts1" + + assert cuts_metadata[0].supervisions[0].duration == 0.5 + assert cuts_metadata[0].supervisions[0].start == 0.2 + + assert cuts_metadata[1].supervisions[0].duration == 1 + assert cuts_metadata[1].supervisions[0].start == 0.0 diff --git a/tests/collections/asr/test_asr_rnnt_encdec_model.py b/tests/collections/asr/test_asr_rnnt_encdec_model.py index d68088fce376..5e810243c919 100644 --- a/tests/collections/asr/test_asr_rnnt_encdec_model.py +++ b/tests/collections/asr/test_asr_rnnt_encdec_model.py @@ -17,13 +17,17 @@ import pytest import torch import torch.nn.functional as F +from lhotse import CutSet, MonoCut +from lhotse.testing.dummies import DummyManifest from omegaconf import DictConfig, ListConfig +from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset from nemo.collections.asr.models import EncDecRNNTModel from nemo.collections.asr.modules import HATJoint, RNNTDecoder, RNNTJoint, SampledRNNTJoint, StatelessTransducerDecoder from nemo.collections.asr.parts.submodules import rnnt_beam_decoding as beam_decode from nemo.collections.asr.parts.submodules import rnnt_greedy_decoding as greedy_decode from nemo.collections.asr.parts.utils import rnnt_utils +from nemo.collections.common.parts.preprocessing.parsers import make_parser from nemo.core.utils import numba_utils from nemo.core.utils.numba_utils import __NUMBA_MINIMUM_VERSION__ from nemo.utils.config_utils import assert_dataclass_signature_match @@ -296,6 +300,19 @@ def test_forward(self, asr_model): diff = torch.max(torch.abs(logprobs_instance - logprobs_batch)) assert diff <= 1e-6 + @pytest.mark.unit + def test_predict_step(self, asr_model): + token_list = [" ", "a", "b", "c"] + asr_model = asr_model.eval() + cuts = DummyManifest(CutSet, begin_id=0, end_id=1, with_data=True) + dataset = LhotseSpeechToTextBpeDataset(tokenizer=make_parser(labels=token_list), return_cuts=True) + batch = dataset[cuts] + outputs = asr_model.predict_step(batch, 0) + assert len(outputs) == 1 + assert len(outputs[0]) == 2 + assert isinstance(outputs[0][0], MonoCut) + assert isinstance(outputs[0][1], str) + @pytest.mark.skipif( not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', diff --git a/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py b/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py index 960445061e24..aba364868e88 100644 --- a/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py +++ b/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py @@ -18,8 +18,11 @@ import pytest import torch +from lhotse import CutSet, MonoCut +from lhotse.testing.dummies import DummyManifest from omegaconf import DictConfig +from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset from nemo.collections.asr.models import ASRModel from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel from nemo.collections.asr.parts.submodules import rnnt_beam_decoding as beam_decode @@ -64,12 +67,18 @@ def asr_model(test_data_dir): decoder = { '_target_': 'nemo.collections.asr.modules.RNNTDecoder', - 'prednet': {'pred_hidden': model_defaults['pred_hidden'], 'pred_rnn_layers': 1,}, + 'prednet': { + 'pred_hidden': model_defaults['pred_hidden'], + 'pred_rnn_layers': 1, + }, } joint = { '_target_': 'nemo.collections.asr.modules.RNNTJoint', - 'jointnet': 
{'joint_hidden': 32, 'activation': 'relu',}, + 'jointnet': { + 'joint_hidden': 32, + 'activation': 'relu', + }, } decoding = {'strategy': 'greedy_batch', 'greedy': {'max_symbols': 30}} @@ -123,7 +132,8 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): class TestEncDecRNNTBPEModel: @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.with_downloads() @pytest.mark.unit @@ -137,7 +147,8 @@ def test_constructor(self, asr_model): @pytest.mark.with_downloads() @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_forward(self, asr_model): @@ -170,9 +181,22 @@ def test_forward(self, asr_model): diff = torch.max(torch.abs(logits_instance - logprobs_batch)) assert diff <= 1e-6 + @pytest.mark.unit + def test_predict_step(self, asr_model): + asr_model = asr_model.eval() + cuts = DummyManifest(CutSet, begin_id=0, end_id=1, with_data=True) + dataset = LhotseSpeechToTextBpeDataset(tokenizer=asr_model.tokenizer, return_cuts=True) + batch = dataset[cuts] + outputs = asr_model.predict_step(batch, 0) + assert len(outputs) == 1 + assert len(outputs[0]) == 2 + assert isinstance(outputs[0][0], MonoCut) + assert isinstance(outputs[0][1], str) + @pytest.mark.with_downloads() @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_save_restore_artifact(self, asr_model): @@ -190,7 +214,8 @@ def test_save_restore_artifact(self, asr_model): @pytest.mark.with_downloads() @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_save_restore_artifact_spe(self, asr_model, test_data_dir): @@ -236,7 +261,8 @@ def test_save_restore_artifact_agg(self, asr_model, test_data_dir): @pytest.mark.with_downloads() @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_vocab_change(self, test_data_dir, asr_model): @@ -266,7 +292,8 @@ def test_vocab_change(self, test_data_dir, asr_model): @pytest.mark.with_downloads() @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_decoding_change(self, asr_model): @@ -309,7 +336,8 @@ def test_decoding_change(self, asr_model): @pytest.mark.with_downloads() @pytest.mark.unit @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) def test_save_restore_nested_model(self): with 
tempfile.TemporaryDirectory() as tmp_dir:
@@ -330,7 +358,7 @@ def test_save_restore_nested_model(self):

             # Check size of the checkpoint, which contains weights from pretrained model + linear layer
             fp_weights = os.path.join(tmp_dir, 'model_weights.ckpt')
-            assert os.path.getsize(fp_weights) > 50 * (2 ** 20)  # Assert the weights are more than 50 MB
+            assert os.path.getsize(fp_weights) > 50 * (2**20)  # Assert the weights are more than 50 MB

             # Check if param after restoration is exact match
             original_state_dict = model.inner_model.state_dict()
diff --git a/tests/collections/llm/peft/lora_merge.py b/tests/collections/llm/peft/lora_merge.py
new file mode 100644
index 000000000000..2ca7390ea7e6
--- /dev/null
+++ b/tests/collections/llm/peft/lora_merge.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+from dataclasses import dataclass
+
+from nemo.collections import llm
+
+
+@dataclass
+class Llama3ConfigCI(llm.Llama3Config8B):
+    seq_length: int = 2048
+    num_layers: int = 2
+    hidden_size: int = 768
+    ffn_hidden_size: int = 3072
+    num_attention_heads: int = 8
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='Merge LoRA weights with base LLM')
+    parser.add_argument('--lora_checkpoint_path', type=str, help="Path to finetuned LoRA checkpoint")
+    parser.add_argument('--output_path', type=str, help="Path to save merged checkpoint")
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    args = get_args()
+
+    llm.peft.merge_lora(
+        lora_checkpoint_path=args.lora_checkpoint_path,
+        output_path=args.output_path,
+    )
diff --git a/tests/export/nemo_export.py b/tests/export/nemo_export.py
index e929f2601022..16aca9ccea4b 100644
--- a/tests/export/nemo_export.py
+++ b/tests/export/nemo_export.py
@@ -849,6 +849,9 @@ def run_inference_tests(args):
             "Use the same value for --min_tps and --max_tps."
         )

+    if args.debug:
+        LOGGER.setLevel(logging.DEBUG)
+
     result_dic: Dict[int, Tuple[FunctionalResult, Optional[AccuracyResult]]] = {}

     if args.existing_test_models:
diff --git a/tutorials/llm/llama-3/nemo2-sft-peft/nemo2-peft.ipynb b/tutorials/llm/llama-3/nemo2-sft-peft/nemo2-peft.ipynb
index cd3bae1cc627..aa463e2b84be 100644
--- a/tutorials/llm/llama-3/nemo2-sft-peft/nemo2-peft.ipynb
+++ b/tutorials/llm/llama-3/nemo2-sft-peft/nemo2-peft.ipynb
@@ -499,6 +499,31 @@
     "```"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Step 5. Calculate Evaluation Metrics\n",
+    "\n",
+    "We can evaluate the model's predictions by calculating the Exact Match (EM) and F1 scores.\n",
+    "- Exact Match is a binary measure (0 or 1) checking if the model output matches one of the\n",
+    "ground truth answers exactly.\n",
+    "- F1 score is the harmonic mean of precision and recall for the answer words.\n",
+    "\n",
+    "Below is a script that computes these metrics. The sample scores can be improved by training the model further and performing hyperparameter tuning. In this notebook, we only train for 20 steps.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "!python /opt/NeMo/scripts/metric_calculation/peft_metric_calc.py --pred_file peft_prediction.jsonl --label_field \"original_answers\" --pred_field \"prediction\""
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
diff --git a/tutorials/llm/llama-3/nemo2-sft-peft/nemo2-sft.ipynb b/tutorials/llm/llama-3/nemo2-sft-peft/nemo2-sft.ipynb
index 479d81928e98..e84ff916fc4e 100644
--- a/tutorials/llm/llama-3/nemo2-sft-peft/nemo2-sft.ipynb
+++ b/tutorials/llm/llama-3/nemo2-sft-peft/nemo2-sft.ipynb
@@ -606,6 +606,31 @@
     "{\"input\": \"Muckle Water is a long, narrow fresh water loch on Ward Hill on Rousay, Orkney, Scotland. It is the biggest loch on the island and is popular for fishing. It can be reached by a track from the roadside. The Suso Burn on the north eastern shore drains the loch into the Sound of Rousay.\\n\\nWhere is Muckle Water?\", \"category\": \"closed_qa\", \"label\": \"Muckle water is located in Rousay, Orkney, Scotland.\", \"prediction\": \" Muckle Water is a long, narrow fresh water loch on Ward Hill on Rousay,\"}\n",
     "```"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Step 5. Calculate Evaluation Metrics\n",
+    "\n",
+    "We can evaluate the model's predictions by calculating the Exact Match (EM) and F1 scores.\n",
+    "- Exact Match is a binary measure (0 or 1) checking if the model output matches one of the\n",
+    "ground truth answers exactly.\n",
+    "- F1 score is the harmonic mean of precision and recall for the answer words.\n",
+    "\n",
+    "Below is a script that computes these metrics. The sample scores can be improved by training the model further and performing hyperparameter tuning. In this notebook, we only train for 20 steps."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "!python /opt/NeMo/scripts/metric_calculation/peft_metric_calc.py --pred_file sft_prediction.jsonl --label_field \"label\" --pred_field \"prediction\""
+   ]
  }
 ],
 "metadata": {
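For reference, the EM/F1 computation described in the two tutorial cells above can be sketched as follows. This is a minimal illustration, not the implementation of peft_metric_calc.py: the SQuAD-style normalization (lowercasing, stripping punctuation and articles) is an assumption, and the field names mirror the tutorial prediction files ("label"/"prediction" for SFT, "original_answers" for PEFT).

```python
# Illustrative sketch only -- not the actual peft_metric_calc.py implementation.
import json
import re
import string
from collections import Counter


def normalize(text: str) -> str:
    """Lowercase, drop punctuation and articles, and collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, truth: str) -> float:
    # 1.0 if the normalized prediction equals the normalized ground truth, else 0.0.
    return float(normalize(prediction) == normalize(truth))


def f1_score(prediction: str, truth: str) -> float:
    # Harmonic mean of word-level precision and recall.
    pred_tokens = normalize(prediction).split()
    truth_tokens = normalize(truth).split()
    common = Counter(pred_tokens) & Counter(truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(truth_tokens)
    return 2 * precision * recall / (precision + recall)


def best_score(metric, prediction, truths) -> float:
    # The ground truth may be a single string or a list of acceptable answers;
    # score against the best-matching answer.
    if isinstance(truths, str):
        truths = [truths]
    return max(metric(prediction, t) for t in truths)


em, f1 = [], []
# For the PEFT notebook, the file would be peft_prediction.jsonl and the
# label field "original_answers" (assumed here, per the commands above).
with open("sft_prediction.jsonl") as f:
    for line in f:
        example = json.loads(line)
        em.append(best_score(exact_match, example["prediction"], example["label"]))
        f1.append(best_score(f1_score, example["prediction"], example["label"]))

print(f"EM: {100 * sum(em) / len(em):.2f}  F1: {100 * sum(f1) / len(f1):.2f}")
```

Running the provided script, as shown in the notebook cells, remains the supported path; the sketch above only makes the metric definitions concrete.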