constant batch, zero lr for loss spike debug
yashaswikarnati committed Nov 21, 2024
1 parent 6757a5e commit b1efbea
Showing 4 changed files with 250 additions and 17 deletions.
190 changes: 190 additions & 0 deletions examples/vlm/llava_next_finetune_energon.py
@@ -0,0 +1,190 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import torch
from megatron.core.optimizer import OptimizerConfig
from transformers import AutoProcessor

from nemo import lightning as nl
from nemo.collections import llm, vlm
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.multimodal.data.energon import SimpleMultiModalDataModule
from nemo.collections.multimodal.data.energon.config import MultiModalSampleConfig
from nemo.collections.vlm import ImageDataConfig, LlavaNextTaskEncoder
from nemo.lightning.pytorch.optim import CosineAnnealingScheduler
from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule
from nemo.utils.exp_manager import TimingCallback


def main(args):
# Global and micro batch sizes
gbs = 2
mbs = 2
seq_length = 4096

processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
data_path = args.data_path
image_processor = processor.image_processor
# tokenizer = processor.tokenizer
tokenizer = AutoTokenizer("llava-hf/llava-v1.6-vicuna-7b-hf")

multimodal_sample_config = MultiModalSampleConfig()

task_encoder = LlavaNextTaskEncoder(
tokenizer=tokenizer.tokenizer,
image_processor=image_processor,
multimodal_sample_config=multimodal_sample_config,
seq_length=seq_length,
)
data = SimpleMultiModalDataModule(
path=data_path,
tokenizer=tokenizer,
image_processor=image_processor,
num_workers=0,
micro_batch_size=mbs,
global_batch_size=gbs,
multimodal_sample_config=multimodal_sample_config,
task_encoder=task_encoder,
)

# Transformer configurations
language_transformer_config = llm.Llama2Config7B()
vision_transformer_config = vlm.HFCLIPVisionConfig(
pretrained_model_name_or_path="openai/clip-vit-large-patch14-336"
)
vision_projection_config = vlm.MultimodalProjectorConfig(
projector_type=args.projector_type,
input_size=1024,
hidden_size=4096,
ffn_hidden_size=4096,
)

# NEVA model configuration
neva_config = vlm.NevaConfig(
language_transformer_config=language_transformer_config,
vision_transformer_config=vision_transformer_config,
vision_projection_config=vision_projection_config,
language_model_from_pretrained=args.language_model_path,
freeze_language_model=False,
is_llava_next=True,
)

model = vlm.NevaModel(neva_config, tokenizer=data.tokenizer)

strategy = nl.MegatronStrategy(
tensor_model_parallel_size=args.tp_size,
pipeline_model_parallel_size=args.pp_size,
pipeline_dtype=torch.bfloat16,
ckpt_load_optimizer=True,
)

# Checkpoint callback setup
checkpoint_callback = nl.ModelCheckpoint(
save_last=True,
monitor="reduced_train_loss",
save_top_k=2,
every_n_train_steps=111,
dirpath=args.log_dir,
)

# Trainer setup
trainer = nl.Trainer(
devices=args.devices,
max_steps=5190,
accelerator="gpu",
strategy=strategy,
plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
callbacks=[checkpoint_callback, TimingCallback()],
val_check_interval=211,  # was 1000
limit_val_batches=gbs,
log_every_n_steps=1,
num_sanity_val_steps=0,
)

# Logger setup
from pytorch_lightning.loggers import WandbLogger

nemo_logger = nl.NeMoLogger(
log_dir=args.log_dir,
name=args.name,
wandb=WandbLogger(project=args.wandb_project, name=args.name) if args.wandb_project is not None else None,
)
nemo_logger.setup(
trainer,
resume_if_exists=True,
)

# Auto resume setup
from nemo.lightning.pytorch.strategies.utils import RestoreConfig

resume = nl.AutoResume(
resume_if_exists=True,
resume_ignore_no_checkpoint=True,
# resume_from_directory=args.log_dir,
restore_config=(
RestoreConfig(
path=args.restore_path,
load_optim_state=False,
)
if args.restore_path is not None
else None
),
)
resume.setup(trainer, model)

# Optimizer and scheduler setup
opt_config = OptimizerConfig(
optimizer='adam',
lr=0,  # was 2.0e-5; zeroed for the loss-spike debug run
adam_beta1=0.9,
adam_beta2=0.95,
use_distributed_optimizer=False,
bf16=True,
)
sched = CosineAnnealingScheduler(
max_steps=trainer.max_steps, warmup_steps=150, constant_steps=0, min_lr=0  # min_lr was 2.0e-07; zeroed for the debug run
)
opt = MegatronOptimizerModule(opt_config, sched)
opt.connect(model)

# Start training

trainer.fit(model, data)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="NEVA Model Training Script")

# Argument parsing
parser.add_argument("--data_path", type=str, required=True, help="Path to the dataset JSON file")
parser.add_argument("--image_folder", type=str, required=True, help="Path to the image folder")
parser.add_argument("--log_dir", type=str, required=True, help="Directory for logging and checkpoints")
parser.add_argument(
"--language_model_path", type=str, required=False, default=None, help="Path to the pretrained language model"
)
parser.add_argument(
"--restore_path", type=str, required=False, default=None, help="Path to restore model from checkpoint"
)
parser.add_argument("--devices", type=int, required=False, default=8)
# parser.add_argument("--tp_size", type=int, required=False, default=4)
parser.add_argument("--tp_size", type=int, required=False, default=4)
parser.add_argument("--pp_size", type=int, required=False, default=1)
parser.add_argument("--projector_type", type=str, required=False, default="mlp2x_gelu")
parser.add_argument("--name", type=str, required=False, default="neva_finetune")
parser.add_argument("--wandb_project", type=str, required=False, default=None)

args = parser.parse_args()
main(args)
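For context, here is a minimal standalone PyTorch sketch (not part of the commit; the model and data are made up) of what the constant-batch, zero-LR setup in the script above is meant to establish: with lr=0 and the same batch replayed every step, the reported loss should stay identical from step to step, so any drift or spike points at non-determinism in the data or forward pass rather than at the optimizer.

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(16, 4)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0)  # zero learning rate

x = torch.randn(8, 16)            # one fixed "cached" batch
y = torch.randint(0, 4, (8,))

losses = []
for _ in range(5):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()              # lr=0, so the weights never move
    losses.append(loss.item())

print(losses)                     # expect five identical values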
9 changes: 5 additions & 4 deletions nemo/collections/multimodal/data/energon/conversation.py
@@ -25,15 +25,16 @@ class BaseConversationTemplateConfig:
chat_template = None


@dataclass
class LLaVATemplateConfig(BaseConversationTemplateConfig):
"""LLava specific template configuration which extends the base config"""

system: Optional[str] = (
"A chat between a curious user and artificial assistant agent. The assistant gives helpful, detailed and polite answers to user's questions.".format()
) # fmt: off
system: Optional[str] = field(
default="A chat between a curious user and artificial assistant agent. The assistant gives helpful, detailed and polite answers to user's questions."
)
roles: List[str] = field(default_factory=lambda: ['user', 'assistant'])
stop_string: str = "</s>"
chat_template = """
chat_template: str = """
{%- for message in messages %}
{%- if message['role'] == 'system' %}
{{- message['content'].strip() + ' ' -}}
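The conversation.py change above turns system and chat_template into real dataclass fields. A short illustrative sketch (class name and values are hypothetical, not the NeMo config) of why the annotation matters: in a dataclass, only annotated names become fields, while an un-annotated assignment stays a plain class attribute.

from dataclasses import dataclass, field, fields

@dataclass
class TemplateConfig:
    system: str = field(default="A chat between a user and an assistant.")
    chat_template: str = "{{ messages }}"  # annotated, so it becomes a dataclass field
    stop_string = "</s>"                   # no annotation: plain class attribute, skipped by __init__

print([f.name for f in fields(TemplateConfig())])  # ['system', 'chat_template']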
28 changes: 23 additions & 5 deletions nemo/collections/vlm/neva/data/llava_next_energon.py
@@ -16,6 +16,7 @@
from typing import Dict, List, Optional

import torch
import torch.nn.functional as F
from megatron.energon import VQASample, batch_list, batch_pad_stack
from torch.nn.utils.rnn import pad_sequence

@@ -25,6 +26,18 @@
from nemo.utils import logging


def pad_or_truncate(sequence_batch, seq_length: int, padding_value: int):
# Pad the sequence if it's shorter than seq_length
if sequence_batch.size(1) < seq_length:
pad_size = seq_length - sequence_batch.size(1)
sequence_batch = F.pad(sequence_batch, (0, pad_size), value=padding_value)
else:
# Truncate the sequence if it's longer than seq_length
sequence_batch = sequence_batch[:, :seq_length]

return sequence_batch


class LlavaNextTextSample(ImageTextSample):
num_media_tiles: int = 0
image_sizes: torch.tensor = field(default_factory=lambda: torch.tensor([]))
@@ -85,15 +98,14 @@ def encode(self, input_sample: VQASample, output_sample: LlavaNextTextSample):
images, loss masks, and metadata.
"""
conversation_prompt = self.apply_prompt_template(input_sample)
logging.debug(f"task encoder encode_sample conversation_prompt {conversation_prompt}")
logging.info(f"task encoder encode_sample conversation_prompt {conversation_prompt}")
# tokenize prompt
tokens = self.tokenize(conversation_prompt)
labels = self.compute_labels(tokens, input_sample)

tokens = tokens[:-1].contiguous()
labels = labels[1:].contiguous()
logging.debug(f"[Energon] task encoder encode_sample after tokenize prompt tokens {tokens}")
logging.debug(f"[Energon] task encoder encode_sample lables {labels}")
logging.info(f"[Energon] task encoder encode_sample after tokenize prompt tokens {tokens}")
logging.info(f"[Energon] task encoder encode_sample lables {labels}")
loss_mask = self.compute_loss_mask(labels)
processed_image = self.process_image(input_sample.image)
output_sample.__key__ = input_sample.__key__
@@ -110,7 +122,7 @@ def encode(self, input_sample: VQASample, output_sample: LlavaNextTextSample):


class LlavaNextTaskEncoder(MultiModalTaskEncoder):
def __init__(self, tokenizer, image_processor, multimodal_sample_config):
def __init__(self, tokenizer, image_processor, multimodal_sample_config, seq_length):
"""
Initialize the LlavaNextTaskEncoder.
@@ -126,6 +138,8 @@ def __init__(self, tokenizer, image_processor, multimodal_sample_config):
self.encoders: Dict[str, SampleEncoder] = {
VQASample.__name__: LlavaNextSampleEncoder(tokenizer, image_processor, multimodal_sample_config)
}
self.seq_length = seq_length
self.ignore_index = multimodal_sample_config.ignore_place_holder

def batch(self, samples: List[LlavaNextTextSample]) -> LlavaNextTextRawBatch:
"""
@@ -170,6 +184,10 @@ def batch(self, samples: List[LlavaNextTextSample]) -> LlavaNextTextRawBatch:
image_sizes = torch.cat(image_sizes, dim=0)
batch_loss_mask = batch_pad_stack(loss_mask)
batch_attention_mask = batch_pad_stack(attention_mask)
# batch_tokens = pad_or_truncate(batch_tokens, self.seq_length, self.tokenizer.pad_token_id)
# batch_labels = pad_or_truncate(batch_labels, self.seq_length, self.ignore_index)
# batch_loss_mask = pad_or_truncate(batch_loss_mask, self.seq_length, 0)
# batch_attention_mask = pad_or_truncate(batch_attention_mask, self.seq_length, 0)
batch_num_media_tiles = torch.tensor(batch_list(num_media_tiles), dtype=torch.int)
return LlavaNextTextRawBatch(
__keys__=batch_keys,
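For reference, a self-contained sketch of the pad_or_truncate helper introduced above (its call sites are still commented out in this commit), with a small shape check on made-up tensors.

import torch
import torch.nn.functional as F

def pad_or_truncate(sequence_batch, seq_length: int, padding_value: int):
    # Right-pad sequences shorter than seq_length, truncate longer ones.
    if sequence_batch.size(1) < seq_length:
        pad_size = seq_length - sequence_batch.size(1)
        sequence_batch = F.pad(sequence_batch, (0, pad_size), value=padding_value)
    else:
        sequence_batch = sequence_batch[:, :seq_length]
    return sequence_batch

short_batch = torch.ones(2, 3, dtype=torch.long)
long_batch = torch.ones(2, 6, dtype=torch.long)
print(pad_or_truncate(short_batch, 4, padding_value=0).shape)  # torch.Size([2, 4])
print(pad_or_truncate(long_batch, 4, padding_value=0).shape)   # torch.Size([2, 4])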
40 changes: 32 additions & 8 deletions nemo/collections/vlm/neva/model/base.py
@@ -177,17 +177,38 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
return height // patch_size, width // patch_size


cached_batch = None


def neva_data_step(dataloader_iter) -> Dict[str, torch.Tensor]:
from megatron.core import parallel_state

# Based on: https://github.com/NVIDIA/Megatron-LM/blob/main/pretrain_gpt.py#L87
# https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L828-L842
batch = next(dataloader_iter)
_batch: dict
if isinstance(batch, tuple) and len(batch) == 3:
_batch = batch[0]
# batch = next(dataloader_iter)
# _batch: dict
# if isinstance(batch, tuple) and len(batch) == 3:
# _batch = batch[0]
# else:
# _batch = batch

global cached_batch # Use the cached batch if already available

if cached_batch is None:
# Fetch a new batch if the cache is empty
batch = next(dataloader_iter)
_batch: dict
if isinstance(batch, tuple) and len(batch) == 3:
_batch = batch[0]
else:
_batch = batch

# Cache the processed batch
cached_batch = _batch
else:
_batch = batch
# Use the cached batch
_batch = cached_batch

required_keys = set()
required_keys.add("attention_mask")
required_keys.add("num_media_tiles")
@@ -803,6 +824,9 @@ def forward(
vision_feature_select_strategy='default',
image_newline=self.image_newline,
)
# if torch.distributed.get_rank() == 0:
# breakpoint()
# torch.distributed.barrier()
combined_embeddings, attention_mask, position_ids, final_labels, final_input_ids, final_loss_mask = (
merge_input_ids_with_image_features(
media_embeddings,
@@ -815,9 +839,6 @@ def forward(
image_token_index=media_token_index,
)
)
# if torch.distributed.get_rank() == 0:
# breakpoint()
# torch.distributed.barrier()

combined_embeddings = combined_embeddings.permute(1, 0, 2)
combined_embeddings = combined_embeddings.contiguous()
@@ -840,6 +861,9 @@ def forward(
labels=final_labels,
inference_params=inference_params,
)
# if torch.distributed.get_rank() == 0:
# breakpoint()
# torch.distributed.barrier()
if labels is None or loss_mask is None:
return output

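The neva_data_step change above applies a common loss-spike-debugging trick: fetch one batch and replay it on every step. A minimal, framework-free sketch of the same pattern (function and variable names here are illustrative, not NeMo APIs):

_cached_batch = None

def constant_batch_step(dataloader_iter):
    """Return the first batch ever fetched on every subsequent call."""
    global _cached_batch
    if _cached_batch is None:
        batch = next(dataloader_iter)
        # Some wrappers yield (batch, batch_idx, dataloader_idx) tuples.
        _cached_batch = batch[0] if isinstance(batch, tuple) and len(batch) == 3 else batch
    return _cached_batch

# Usage: every call now yields the same dict, regardless of what the iterator holds.
it = iter([{"step": i} for i in range(3)])
assert constant_batch_step(it) == constant_batch_step(it) == {"step": 0}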
