Update llama32 vision (mllama) to use attention bias (#11316)
* update recipe

Signed-off-by: yaoyu-33 <[email protected]>

* fix mllama mock ds

Signed-off-by: yaoyu-33 <[email protected]>

* update to use attention bias

Signed-off-by: yaoyu-33 <[email protected]>

* remove example

Signed-off-by: yaoyu-33 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <[email protected]>

* fix docstring mock.py

Signed-off-by: yaoyu-33 <[email protected]>

* fix docstring language.py

Signed-off-by: yaoyu-33 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <[email protected]>

* fix docstring language.py

Signed-off-by: yaoyu-33 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <[email protected]>

* fix docstring mllama/base.py

Signed-off-by: yaoyu-33 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <[email protected]>

* fix docstring mllama/language.py

Signed-off-by: yaoyu-33 <[email protected]>

* bump mcore

Signed-off-by: Oliver Koenig <[email protected]>

* Add scripts for mllama

Signed-off-by: yaoyu-33 <[email protected]>

* fix

Signed-off-by: yaoyu-33 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <[email protected]>

* update script

Signed-off-by: yaoyu-33 <[email protected]>

* fix pylint

Signed-off-by: yaoyu-33 <[email protected]>

* revert Dockerfile.ci

Signed-off-by: Yu Yao <[email protected]>

* update script match recipe

Signed-off-by: yaoyu-33 <[email protected]>

* update recipes

Signed-off-by: yaoyu-33 <[email protected]>

* update mllama 90b recipe

Signed-off-by: yaoyu-33 <[email protected]>

---------

Signed-off-by: yaoyu-33 <[email protected]>
Signed-off-by: yaoyu-33 <[email protected]>
Signed-off-by: Oliver Koenig <[email protected]>
Signed-off-by: Yu Yao <[email protected]>
Co-authored-by: yaoyu-33 <[email protected]>
Co-authored-by: Oliver Koenig <[email protected]>
3 people authored Nov 25, 2024
1 parent 3afcde0 commit 5094b2e
Showing 8 changed files with 573 additions and 55 deletions.
40 changes: 40 additions & 0 deletions nemo/collections/vlm/mllama/data/mock.py
@@ -25,6 +25,26 @@


class MockDataModule(pl.LightningDataModule):
"""
Mock DataModule for testing and development.
Generates synthetic data for training, validation, and testing purposes.
Args:
seq_length (int): Sequence length for the generated data.
decoder_seq_length (Optional[int]): Decoder sequence length, if applicable; used in pipeline parallelism.
vocab_size (int): Size of the tokenizer vocabulary.
crop_size (Tuple[int, int]): Image crop size (height, width).
micro_batch_size (int): Micro batch size for data loading.
global_batch_size (int): Global batch size across all processes.
rampup_batch_size (Optional[List[int]]): Batch size ramp-up configuration.
num_train_samples (int): Number of training samples to generate.
num_val_samples (int): Number of validation samples to generate.
num_test_samples (int): Number of test samples to generate.
num_workers (int): Number of workers for data loading.
pin_memory (bool): Whether to pin memory for data loading.
persistent_workers (bool): Whether workers should remain persistent.
"""

def __init__(
self,
seq_length: int = 2048,
@@ -66,6 +86,7 @@ def __init__(
)

def setup(self, stage: str = "") -> None:
"""Set up datasets for the specified stage."""
self._train_ds = _MockMLlamaDataset(
self.vocab_size, self.crop_size, "train", self.num_train_samples, self.decoder_seq_length
)
@@ -77,21 +98,25 @@ def setup(self, stage: str = "") -> None:
)

def train_dataloader(self) -> TRAIN_DATALOADERS:
"""Returns the DataLoader for training."""
if not hasattr(self, "_train_ds"):
self.setup()
return self._create_dataloader(self._train_ds)

def val_dataloader(self) -> EVAL_DATALOADERS:
"""Returns the DataLoader for validation."""
if not hasattr(self, "_validation_ds"):
self.setup()
return self._create_dataloader(self._validation_ds)

def test_dataloader(self) -> EVAL_DATALOADERS:
"""Returns the DataLoader for testing."""
if not hasattr(self, "_test_ds"):
self.setup()
return self._create_dataloader(self._test_ds)

def _create_dataloader(self, dataset, **kwargs) -> DataLoader:
"""Creates a DataLoader for the specified dataset."""
return DataLoader(
dataset,
num_workers=self.num_workers,
@@ -103,6 +128,18 @@ def _create_dataloader(self, dataset, **kwargs) -> DataLoader:


class _MockMLlamaDataset(Dataset):
"""
Mock dataset for generating synthetic data with text and image components.
Args:
vocab_size (int): Vocabulary size for text data.
crop_size (Tuple[int, int]): Image crop size (height, width).
name (str): Name of the dataset split ('train', 'valid', 'test').
num_samples (int): Number of samples in the dataset.
seq_length (int): Sequence length for the text data.
seed (int): Seed for random number generation.
"""

def __init__(
self,
vocab_size,
@@ -127,13 +164,16 @@ def __init__(
self.position_ids = torch.arange(self.seq_length, dtype=torch.int64)

def __len__(self) -> int:
"""Returns the number of samples in the dataset."""
return self.length

def _get_text(self, idx: int) -> np.ndarray:
"""Generates a random sequence of integers representing text tokens."""
np_gen = np.random.default_rng(seed=(self.seed + idx))
return np_gen.integers(self.vocab_size, size=[self.seq_length], dtype=np.int64)

def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
"""Generates a single data sample."""
# Generate data of the expected size and datatype (based on GPTDataset).
np_gen = np.random.default_rng(seed=(self.seed + idx))
tokens = torch.from_numpy(np_gen.integers(self.vocab_size, size=[self.seq_length + 1], dtype=np.int64))
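Not part of the diff above, just context for review: the mock dataset builds every sample deterministically by seeding a NumPy generator with (seed + idx), so nothing has to be stored on disk. A minimal, self-contained sketch of that pattern follows; the vocab size, sequence length, and seed below are illustrative, not taken from the recipe.

import numpy as np
import torch

# Illustrative values only; the real ones come from MockDataModule's arguments.
vocab_size, seq_length, seed = 128256, 2048, 42

def mock_text_sample(idx: int) -> torch.Tensor:
    """Deterministic synthetic token sample, mirroring how the mock dataset seeds its generator per index."""
    np_gen = np.random.default_rng(seed=(seed + idx))
    tokens = np_gen.integers(vocab_size, size=[seq_length + 1], dtype=np.int64)
    return torch.from_numpy(tokens)

# Re-requesting the same index reproduces the same sample.
assert torch.equal(mock_text_sample(3), mock_text_sample(3))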
61 changes: 44 additions & 17 deletions nemo/collections/vlm/mllama/model/base.py
@@ -47,7 +47,8 @@
from nemo.utils import logging


def llama_data_step(dataloader_iter) -> Dict[str, torch.Tensor]:
def mllama_data_step(dataloader_iter) -> Dict[str, torch.Tensor]:
"""Mllama data step."""
from megatron.core import parallel_state

# Based on: https://github.com/NVIDIA/Megatron-LM/blob/main/pretrain_gpt.py#L87
@@ -96,7 +97,8 @@ def llama_data_step(dataloader_iter) -> Dict[str, torch.Tensor]:
return output


def llama_forward_step(model, batch) -> torch.Tensor:
def mllama_forward_step(model, batch) -> torch.Tensor:
"""Mllama model forward step."""
forward_config = {
"batch_images": batch["batch_images"],
"batch_masks": batch["batch_masks"],
@@ -114,13 +116,15 @@ def llama_forward_step(model, batch) -> torch.Tensor:


def set_input_tensor(self, tensor):
"""Placeholder for `set_input_tensor` method for PP implementation."""
pass


@dataclass
class CrossAttentionVisionConfig(TransformerConfig, io.IOMixin):
# core params
"""Configuration for llama vision model."""

# core params
bias_activation_fusion: bool = True
bias_dropout_add_fusion: bool = True

@@ -150,16 +154,22 @@ class CrossAttentionVisionConfig(TransformerConfig, io.IOMixin):

@property
def max_aspect_ratio_id(self) -> int:
# pylint: disable=C0115,C0116
return len(self.supported_aspect_ratios)

def configure_model(self) -> "CrossAttentionVisionModel":
"""Configure mllama vision model."""
return CrossAttentionVisionModel(
self,
)


@dataclass
class CrossAttentionTextConfig(Llama31Config):
"""
Configuration for llama model with cross-attention layers to take in multimodal features.
"""

rotary_base: int = 500_000
seq_length: int = 8192
num_layers: int = 32
@@ -171,12 +181,14 @@ class CrossAttentionTextConfig(Llama31Config):
apply_rope_fusion: bool = False

def _init_fusion_schedule(self, num_layers: int) -> List[int]:
llama_layers = list(range(self.num_layers))
"""Initialize self-attention layer / cross-attention layer fusion schedule"""
mllama_layers = list(range(self.num_layers))
# uniformly spread the layers
k = math.ceil(len(llama_layers) / num_layers)
return llama_layers[::-1][::k][:num_layers][::-1]
k = math.ceil(len(mllama_layers) / num_layers)
return mllama_layers[::-1][::k][:num_layers][::-1]

def configure_model(self, tokenizer, pre_process=True, post_process=True):
"""Configure mllama text model."""
self.fusion_schedule = self._init_fusion_schedule(self.num_cross_attention_layers)
vp_size = self.virtual_pipeline_model_parallel_size
if vp_size:
@@ -225,6 +237,8 @@

@dataclass
class MLlamaModelConfig(TransformerConfig, io.IOMixin):
"""Combined configuration for multimodal vision-language model."""

language_model_config: Optional[CrossAttentionTextConfig] = None
vision_model_config: Optional[CrossAttentionVisionConfig] = None

@@ -237,15 +251,16 @@ class MLlamaModelConfig(TransformerConfig, io.IOMixin):
language_model_from_pretrained: Optional[str] = None # TODO
vision_model_from_pretrained: Optional[str] = None # TODO

forward_step_fn: Callable = llama_forward_step
data_step_fn: Callable = llama_data_step
forward_step_fn: Callable = mllama_forward_step
data_step_fn: Callable = mllama_data_step

def __post_init__(self):
if self.language_model_config is not None:
for attr in MODEL_CONFIG_ATTR:
setattr(self, attr, getattr(self.language_model_config, attr))

def configure_model(self, tokenizer) -> "MLlamaBaseModel":
"""Configure mllama model."""
from megatron.core import parallel_state as ps

self.language_model_config.tensor_model_parallel_size = self.tensor_model_parallel_size
@@ -274,6 +289,8 @@ def configure_model(self, tokenizer) -> "MLlamaBaseModel":


class CrossAttentionVisionModel(MegatronModule):
"""Mllama vision model."""

def __init__(self, config) -> None:
super().__init__(config=config)
return_intermediate = "3,7,15,23,30"
@@ -303,6 +320,7 @@ def __init__(self, config) -> None:
self.vision_projection.encoder.skip_bias_add = False # Temporary fix for a MCore side bug

def forward(self, images: torch.Tensor, aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
"""Forward."""
# vision_tokens: (B, T, D)
# aspect_ratio_ids: (B, 1)
# h: (B, T, D)
@@ -313,10 +331,13 @@ def forward(self, images: torch.Tensor, aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
return vision_tokens

def set_input_tensor(self, tensor):
# pylint: disable=C0115,C0116
pass


class MLlamaBaseModel(MegatronModule):
"""Mllama base model combining vision and text models with cross-attention."""

def __init__(
self,
config: MLlamaModelConfig,
@@ -356,10 +377,6 @@ def __init__(
self.patch_size = 14
self.image_res = vision_model_config.vision_chunk_size
self.max_num_chunks = vision_model_config.vision_max_num_chunks
logging.warning("[WARNING] NeMo Mllama will always pad images to max number of tiles. A fix is coming soon!")

def setup_cache(self, max_batch_size: int, dtype: torch.dtype):
self.language_model.setup_cache(max_batch_size, dtype)

def compute_xattn_caches_masks(
self,
@@ -369,6 +386,7 @@
num_chunks: torch.Tensor,
total_len: int,
) -> Tuple[List, torch.Tensor, torch.Tensor]:
"""Compute xattn caches masks used in text model."""
bsz, nimg, nchunk, ntok, image_token_dim = vision_orig_shape

xattn_caches = [
@@ -408,6 +426,7 @@ def forward(
full_text_row_masked_out_mask: Optional[torch.Tensor] = None,
xattn_caches: Optional[List] = None,
) -> torch.Tensor:
"""Forward."""
if xattn_caches is None:
bsz, max_num_images = batch_images.size(0), batch_images.size(1)
vision_orig_shape = (
@@ -418,8 +437,8 @@
self.config.hidden_size,
)
skip_vision_encoder = False
num_chunks[num_chunks > 0] = self.max_num_chunks
if max_num_images == 0:
num_chunks[num_chunks > 0] = self.max_num_chunks
skip_vision_encoder = True

if self.encoder_hidden_state is not None:
@@ -489,6 +508,8 @@ def set_input_tensor(self, input_tensor) -> None:


class MLlamaModel(L.LightningModule, io.IOMixin, io.ConnectorMixin, fn.FNMixin):
"""Lightning Module for the MLlama model."""

def __init__(
self,
config: MLlamaModelConfig,
@@ -506,6 +527,7 @@ def __init__(
self._validation_loss_reduction = None

def configure_model(self) -> None:
"""Configure mllama model"""
if not hasattr(self, "module"):
self.module: MLlamaBaseModel = self.config.configure_model(self.tokenizer)

@@ -522,7 +544,7 @@ def forward(
full_text_row_masked_out_mask: Optional[torch.Tensor] = None,
xattn_caches: Optional[torch.Tensor] = None,
) -> torch.Tensor:

"""Forward."""
output_tensor = self.module(
position_ids=position_ids,
tokens=tokens,
@@ -539,29 +561,34 @@ def forward(
return output_tensor

def data_step(self, dataloader_iter) -> Dict[str, torch.Tensor]:
# pylint: disable=C0115,C0116
return self.config.data_step_fn(dataloader_iter)

def forward_step(self, batch) -> torch.Tensor:
# pylint: disable=C0115,C0116
return self.config.forward_step_fn(self, batch)

def training_step(self, batch, batch_idx=None) -> torch.Tensor:
# pylint: disable=C0115,C0116
# In mcore the loss-function is part of the forward-pass (when labels are provided)
return self.forward_step(batch)

def validation_step(self, batch, batch_idx=None) -> torch.Tensor:
# pylint: disable=C0115,C0116
# In mcore the loss-function is part of the forward-pass (when labels are provided)

return self.forward_step(batch)

@property
def training_loss_reduction(self) -> MaskedTokenLossReduction:
# pylint: disable=C0115,C0116
if not self._training_loss_reduction:
self._training_loss_reduction = MaskedTokenLossReduction()

return self._training_loss_reduction

@property
def validation_loss_reduction(self) -> MaskedTokenLossReduction:
# pylint: disable=C0115,C0116
if not self._validation_loss_reduction:
self._validation_loss_reduction = MaskedTokenLossReduction(validation_step=True)

@@ -573,8 +600,8 @@ def validation_loss_reduction(self) -> MaskedTokenLossReduction:
"MLlamaModelConfig",
"CrossAttentionTextConfig",
"CrossAttentionVisionConfig",
"llama_data_step",
"llama_forward_step",
"mllama_data_step",
"mllama_forward_step",
"transformer_engine_layer_spec",
"local_layer_spec",
]
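A side note on the renamed `_init_fusion_schedule` helper above: it picks every k-th text layer counted back from the last one, so the cross-attention layers are spread uniformly and always include the final layer. A small, self-contained sketch of that arithmetic follows; the counts (32 text layers, 8 cross-attention layers) are illustrative, not read from CrossAttentionTextConfig.

import math

# Illustrative counts only; the config supplies the real num_layers and num_cross_attention_layers.
num_text_layers = 32
num_cross_attention_layers = 8

mllama_layers = list(range(num_text_layers))
k = math.ceil(len(mllama_layers) / num_cross_attention_layers)
fusion_schedule = mllama_layers[::-1][::k][:num_cross_attention_layers][::-1]

print(fusion_schedule)  # [3, 7, 11, 15, 19, 23, 27, 31] -> every 4th layer, ending at the last

With these counts the cross-attention blocks land after layers 3, 7, ..., 31, which matches the "uniformly spread the layers" comment in the diff.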