diff --git a/examples/vlm/llava_next_finetune_energon.py b/examples/vlm/llava_next_finetune_energon.py
new file mode 100644
index 000000000000..5c2ce9188e3a
--- /dev/null
+++ b/examples/vlm/llava_next_finetune_energon.py
@@ -0,0 +1,190 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+
+import torch
+from megatron.core.optimizer import OptimizerConfig
+from transformers import AutoProcessor
+
+from nemo import lightning as nl
+from nemo.collections import llm, vlm
+from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
+from nemo.collections.multimodal.data.energon import SimpleMultiModalDataModule
+from nemo.collections.multimodal.data.energon.config import MultiModalSampleConfig
+from nemo.collections.vlm import LlavaNextTaskEncoder
+from nemo.lightning.pytorch.optim import CosineAnnealingScheduler
+from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule
+from nemo.utils.exp_manager import TimingCallback
+
+
+def main(args):
+    # Global and micro batch sizes
+    gbs = 2
+    mbs = 2
+    seq_length = 4096
+
+    processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
+    data_path = args.data_path
+    image_processor = processor.image_processor
+    tokenizer = AutoTokenizer("llava-hf/llava-v1.6-vicuna-7b-hf")
+
+    multimodal_sample_config = MultiModalSampleConfig()
+
+    task_encoder = LlavaNextTaskEncoder(
+        tokenizer=tokenizer.tokenizer,
+        image_processor=image_processor,
+        multimodal_sample_config=multimodal_sample_config,
+        seq_length=seq_length,
+    )
+    data = SimpleMultiModalDataModule(
+        path=data_path,
+        tokenizer=tokenizer,
+        image_processor=image_processor,
+        num_workers=0,
+        micro_batch_size=mbs,
+        global_batch_size=gbs,
+        multimodal_sample_config=multimodal_sample_config,
+        task_encoder=task_encoder,
+    )
+
+    # Transformer configurations
+    language_transformer_config = llm.Llama2Config7B()
+    vision_transformer_config = vlm.HFCLIPVisionConfig(
+        pretrained_model_name_or_path="openai/clip-vit-large-patch14-336"
+    )
+    vision_projection_config = vlm.MultimodalProjectorConfig(
+        projector_type=args.projector_type,
+        input_size=1024,
+        hidden_size=4096,
+        ffn_hidden_size=4096,
+    )
+
+    # NEVA model configuration
+    neva_config = vlm.NevaConfig(
+        language_transformer_config=language_transformer_config,
+        vision_transformer_config=vision_transformer_config,
+        vision_projection_config=vision_projection_config,
+        language_model_from_pretrained=args.language_model_path,
+        freeze_language_model=False,
+        is_llava_next=True,
+    )
+
+    model = vlm.NevaModel(neva_config, tokenizer=data.tokenizer)
+
+    strategy = nl.MegatronStrategy(
+        tensor_model_parallel_size=args.tp_size,
+        pipeline_model_parallel_size=args.pp_size,
+        pipeline_dtype=torch.bfloat16,
+        ckpt_load_optimizer=True,
+    )
+
+    # Checkpoint callback setup
+    checkpoint_callback = nl.ModelCheckpoint(
+        save_last=True,
+        monitor="reduced_train_loss",
+        save_top_k=2,
+        every_n_train_steps=111,
+        dirpath=args.log_dir,
+    )
+
+    # Trainer setup
+    trainer = nl.Trainer(
+        devices=args.devices,
+        max_steps=5190,
+        accelerator="gpu",
+        strategy=strategy,
+        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
+        callbacks=[checkpoint_callback, TimingCallback()],
+        val_check_interval=1000,
+        limit_val_batches=gbs,
+        log_every_n_steps=1,
+        num_sanity_val_steps=0,
+    )
+
+    # Logger setup
+    from pytorch_lightning.loggers import WandbLogger
+
+    nemo_logger = nl.NeMoLogger(
+        log_dir=args.log_dir,
+        name=args.name,
+        wandb=WandbLogger(project=args.wandb_project, name=args.name) if args.wandb_project is not None else None,
+    )
+    nemo_logger.setup(
+        trainer,
+        resume_if_exists=True,
+    )
+
+    # Auto resume setup
+    from nemo.lightning.pytorch.strategies.utils import RestoreConfig
+
+    resume = nl.AutoResume(
+        resume_if_exists=True,
+        resume_ignore_no_checkpoint=True,
+        # resume_from_directory=args.log_dir,
+        restore_config=(
+            RestoreConfig(
+                path=args.restore_path,
+                load_optim_state=False,
+            )
+            if args.restore_path is not None
+            else None
+        ),
+    )
+    resume.setup(trainer, model)
+
+    # Optimizer and scheduler setup
+    opt_config = OptimizerConfig(
+        optimizer='adam',
+        lr=2.0e-5,
+        adam_beta1=0.9,
+        adam_beta2=0.95,
+        use_distributed_optimizer=False,
+        bf16=True,
+    )
+    sched = CosineAnnealingScheduler(
+        max_steps=trainer.max_steps, warmup_steps=150, constant_steps=0, min_lr=2.0e-7
+    )
+    opt = MegatronOptimizerModule(opt_config, sched)
+    opt.connect(model)
+
+    # Start training
+    trainer.fit(model, data)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="NEVA Model Training Script")
+
+    # Argument parsing
+    parser.add_argument("--data_path", type=str, required=True, help="Path to the Energon dataset")
+    parser.add_argument("--image_folder", type=str, required=True, help="Path to the image folder")
+    parser.add_argument("--log_dir", type=str, required=True, help="Directory for logging and checkpoints")
+    parser.add_argument(
+        "--language_model_path", type=str, required=False, default=None, help="Path to the pretrained language model"
+    )
+    parser.add_argument(
+        "--restore_path", type=str, required=False, default=None, help="Path to restore model from checkpoint"
+    )
+    parser.add_argument("--devices", type=int, required=False, default=8)
+    parser.add_argument("--tp_size", type=int, required=False, default=4)
+    parser.add_argument("--pp_size", type=int, required=False, default=1)
+    parser.add_argument("--projector_type", type=str, required=False, default="mlp2x_gelu")
+    parser.add_argument("--name", type=str, required=False, default="neva_finetune")
+    parser.add_argument("--wandb_project", type=str, required=False, default=None)
+
+    args = parser.parse_args()
+    main(args)
diff --git a/nemo/collections/multimodal/data/energon/conversation.py b/nemo/collections/multimodal/data/energon/conversation.py
index f0749e47dc12..155c910b06ee 100644
--- a/nemo/collections/multimodal/data/energon/conversation.py
+++ b/nemo/collections/multimodal/data/energon/conversation.py
@@ -25,15 +25,16 @@ class BaseConversationTemplateConfig:
     chat_template = None
 
 
+@dataclass
 class LLaVATemplateConfig(BaseConversationTemplateConfig):
     """LLava specific template configuration which extends the base config"""
 
-    system: Optional[str] = (
-        "A chat between a curious user and artificial assistant agent. The assistant gives helpful, detailed and polite answers to user's questions.".format()
-    )  # fmt: off
+    system: Optional[str] = field(
+        default="A chat between a curious user and artificial assistant agent. The assistant gives helpful, detailed and polite answers to user's questions."
+    )
     roles: List[str] = field(default_factory=lambda: ['user', 'assistant'])
     stop_string: str = ""
-    chat_template = """
+    chat_template: str = """
     {%- for message in messages %}
         {%- if message['role'] == 'system' %}
             {{- message['content'].strip() + ' ' -}}
diff --git a/nemo/collections/vlm/neva/data/llava_next_energon.py b/nemo/collections/vlm/neva/data/llava_next_energon.py
index ece4ef1f59ba..ac1819bffc07 100644
--- a/nemo/collections/vlm/neva/data/llava_next_energon.py
+++ b/nemo/collections/vlm/neva/data/llava_next_energon.py
@@ -16,6 +16,7 @@
 from typing import Dict, List, Optional
 
 import torch
+import torch.nn.functional as F
 from megatron.energon import VQASample, batch_list, batch_pad_stack
 from torch.nn.utils.rnn import pad_sequence
 
@@ -25,6 +26,18 @@
 from nemo.utils import logging
 
 
+def pad_or_truncate(sequence_batch, seq_length: int, padding_value: int):
+    # Pad the sequence if it's shorter than seq_length
+    if sequence_batch.size(1) < seq_length:
+        pad_size = seq_length - sequence_batch.size(1)
+        sequence_batch = F.pad(sequence_batch, (0, pad_size), value=padding_value)
+    else:
+        # Truncate the sequence if it's longer than seq_length
+        sequence_batch = sequence_batch[:, :seq_length]
+
+    return sequence_batch
+
+
 class LlavaNextTextSample(ImageTextSample):
     num_media_tiles: int = 0
     image_sizes: torch.tensor = field(default_factory=lambda: torch.tensor([]))
@@ -85,15 +98,14 @@ def encode(self, input_sample: VQASample, output_sample: LlavaNextTextSample):
             images, loss masks, and metadata.
         """
         conversation_prompt = self.apply_prompt_template(input_sample)
         logging.debug(f"task encoder encode_sample conversation_prompt {conversation_prompt}")
         # tokenize prompt
         tokens = self.tokenize(conversation_prompt)
         labels = self.compute_labels(tokens, input_sample)
-        tokens = tokens[:-1].contiguous()
         labels = labels[1:].contiguous()
         logging.debug(f"[Energon] task encoder encode_sample after tokenize prompt tokens {tokens}")
         logging.debug(f"[Energon] task encoder encode_sample lables {labels}")
         loss_mask = self.compute_loss_mask(labels)
         processed_image = self.process_image(input_sample.image)
         output_sample.__key__ = input_sample.__key__
@@ -110,7 +122,7 @@
 
 
 class LlavaNextTaskEncoder(MultiModalTaskEncoder):
-    def __init__(self, tokenizer, image_processor, multimodal_sample_config):
+    def __init__(self, tokenizer, image_processor, multimodal_sample_config, seq_length):
         """
         Initialize the LlavaNextTaskEncoder.
 
@@ -126,6 +138,8 @@ def __init__(self, tokenizer, image_processor, multimodal_sample_config):
         self.encoders: Dict[str, SampleEncoder] = {
             VQASample.__name__: LlavaNextSampleEncoder(tokenizer, image_processor, multimodal_sample_config)
         }
+        self.seq_length = seq_length
+        self.ignore_index = multimodal_sample_config.ignore_place_holder
 
     def batch(self, samples: List[LlavaNextTextSample]) -> LlavaNextTextRawBatch:
         """
@@ -170,6 +184,10 @@ def batch(self, samples: List[LlavaNextTextSample]) -> LlavaNextTextRawBatch:
         image_sizes = torch.cat(image_sizes, dim=0)
         batch_loss_mask = batch_pad_stack(loss_mask)
         batch_attention_mask = batch_pad_stack(attention_mask)
+        # batch_tokens = pad_or_truncate(batch_tokens, self.seq_length, self.tokenizer.pad_token_id)
+        # batch_labels = pad_or_truncate(batch_labels, self.seq_length, self.ignore_index)
+        # batch_loss_mask = pad_or_truncate(batch_loss_mask, self.seq_length, 0)
+        # batch_attention_mask = pad_or_truncate(batch_attention_mask, self.seq_length, 0)
         batch_num_media_tiles = torch.tensor(batch_list(num_media_tiles), dtype=torch.int)
         return LlavaNextTextRawBatch(
             __keys__=batch_keys,
diff --git a/nemo/collections/vlm/neva/model/base.py b/nemo/collections/vlm/neva/model/base.py
index 51a5aeab5d6d..c789d2d02657 100644
--- a/nemo/collections/vlm/neva/model/base.py
+++ b/nemo/collections/vlm/neva/model/base.py
@@ -815,9 +815,6 @@
                 image_token_index=media_token_index,
             )
         )
-        # if torch.distributed.get_rank() == 0:
-        #     breakpoint()
-        # torch.distributed.barrier()
 
         combined_embeddings = combined_embeddings.permute(1, 0, 2)
         combined_embeddings = combined_embeddings.contiguous()
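
Note (not part of the patch): a minimal standalone sketch of how the `pad_or_truncate` helper introduced in `llava_next_energon.py` above behaves; the toy batch and target lengths below are illustrative assumptions, not values used by the PR.

import torch
import torch.nn.functional as F


def pad_or_truncate(sequence_batch, seq_length: int, padding_value: int):
    # Right-pad the sequence dimension up to seq_length, or truncate it down to seq_length.
    if sequence_batch.size(1) < seq_length:
        pad_size = seq_length - sequence_batch.size(1)
        sequence_batch = F.pad(sequence_batch, (0, pad_size), value=padding_value)
    else:
        sequence_batch = sequence_batch[:, :seq_length]
    return sequence_batch


tokens = torch.arange(10).reshape(2, 5)             # batch of 2 sequences of length 5
print(pad_or_truncate(tokens, 8, padding_value=0))  # padded on the right to length 8
print(pad_or_truncate(tokens, 3, padding_value=0))  # truncated to length 3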