
Commit f4f6a0c
fix lint issues
yashaswikarnati committed Nov 25, 2024
1 parent 7aebd33 commit f4f6a0c
Showing 4 changed files with 17 additions and 35 deletions.
21 changes: 0 additions & 21 deletions .devcontainer/devcontainer.json

This file was deleted.

13 changes: 8 additions & 5 deletions nemo/collections/vlm/llava_next/model/base.py
@@ -44,7 +44,8 @@ def llava_next_data_step(dataloader_iter) -> Dict[str, torch.Tensor]:
from megatron.core import parallel_state

# Based on: https://github.com/NVIDIA/Megatron-LM/blob/main/pretrain_gpt.py#L87
- # https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L828-L842
+ # https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/
+ # megatron_gpt_model.py#L828-L842
batch = next(dataloader_iter)
_batch: dict
if isinstance(batch, tuple) and len(batch) == 3:
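
For context, this data step follows the batch-unwrapping pattern in the NeMo code referenced above. A minimal sketch, assuming the dataloader can yield a (batch, batch_idx, dataloader_idx) tuple (the tuple layout is an assumption, since the branch body is collapsed in this diff):

    batch = next(dataloader_iter)
    _batch: dict
    if isinstance(batch, tuple) and len(batch) == 3:
        # Assumed layout: (batch, batch_idx, dataloader_idx); keep only the batch dict.
        _batch = batch[0]
    else:
        _batch = batch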
@@ -227,19 +228,20 @@ def forward(
"""Forward function of the LLaVA Next model.
Args:
- images (torch.Tensor): input image of shape [num_tiles, img_h, img_w]. num_tiles means the number of image tiles in this batch.
+ images (torch.Tensor): input image of shape [num_tiles, img_h, img_w].
+     num_tiles means the number of image tiles in this batch.
input_ids (torch.Tensor): input text ids [batch, text_seq_len].
position_ids (torch.Tensor): input text position ids [batch, text_seq_len].
image_sizes (torch.Tensor): Raw image sizes before tiling (N,2).
- attention_mask (torch.Tensor): Attention mask for the language model [batch, 1, combined_seq_len, combined_seq_len].
+ attention_mask (torch.Tensor): Attention mask for the language model [batch, text seq length].
labels (torch.Tensor): Optional target text labels [batch, combined_seq_len].
loss_mask (torch.Tensor): Text loss mask [batch, text_seq_len].
inference_params (InferenceParams): Inference-time parameters including KV cache.
num_media_tiles (list of int): Number of tiles per image. Default None assumes 1 tile per image.
image_token_index (int): ID for input images.
Returns:
- output (torch.Tensor): Loss of shape [b, s] if labels are provided, otherwise logits of shape [b, s, vocab_size].
+ output (torch.Tensor): Loss ([b, s]) if labels are provided; logits ([b, s, vocab_size]) otherwise.
loss_mask (torch.Tensor): Loss mask expanded to combined sequence length. Shape [b, s].
"""

@@ -248,7 +250,8 @@ def forward(
)
has_images = media.shape[0] > 0

- # If running inference, we can skip media token computation if they were computed already earlier for this sample.
+ # If running inference, we can skip media token computation
+ # if they were computed already earlier for this sample.
if use_inference_kv_cache:
media_embeddings = None
elif self.add_encoder and not has_images:
6 changes: 3 additions & 3 deletions nemo/collections/vlm/llava_next/model/llava_next.py
@@ -60,7 +60,7 @@ class LlavaNextModel(NevaModel):
Attributes:
config (LlavaNextConfig): Configuration object for the model.
- optim (Optional[OptimizerModule]): Optimizer module for training the model. Defaults to a Megatron optimizer.
+ optim (Optional[OptimizerModule]): Optimizer module. Defaults to a Megatron optimizer.
tokenizer (Optional[TokenizerSpec]): Tokenizer specification for processing text inputs.
model_transform (Optional[Callable[[torch.nn.Module], torch.nn.Module]]):
Optional transformation applied to the model after initialization.
@@ -78,7 +78,7 @@ def __init__(
Args:
config (LlavaNextConfig): Configuration object for the model.
- optim (Optional[OptimizerModule]): Optional optimizer module. If not provided, a default Megatron optimizer is used.
+ optim (Optional[OptimizerModule]): optimizer module. Defaults to Megatron optimizer.
tokenizer (Optional[TokenizerSpec]): Optional tokenizer specification for processing text inputs.
model_transform (Optional[Callable[[torch.nn.Module], torch.nn.Module]]):
Optional transformation function applied to the model after initialization.
@@ -121,7 +121,7 @@ def forward(
position_ids (torch.Tensor): Position IDs of shape [batch, text_seq_len].
image_sizes (torch.Tensor): Raw image sizes before tiling, of shape [batch, 2].
loss_mask (Optional[torch.Tensor]): Text loss mask of shape [batch, text_seq_len].
- attention_mask (Optional[torch.Tensor]): Attention mask (before merging image embeddings) of shape [batch, text_seq_len].
+ attention_mask (Optional[torch.Tensor]): Attention mask shape [batch, text_seq_len].
media (Optional[torch.Tensor]): Input media tensor.
labels (Optional[torch.Tensor]): Target labels of shape [batch, combined_seq_len].
inference_params (InferenceParams): Inference-time parameters.
12 changes: 6 additions & 6 deletions nemo/collections/vlm/llava_next/model/utils.py
@@ -158,8 +158,8 @@ def merge_input_ids_with_image_features(
total_num_special_image_tokens = torch.sum(special_image_token_mask)
if total_num_special_image_tokens != num_images:
raise ValueError(
f"Number of image tokens in input_ids ({total_num_special_image_tokens})
different from num_images ({num_images})."
f"Number of image tokens in input_ids ({total_num_special_image_tokens}) "
f"different from num_images ({num_images})."
)
# Compute the maximum embed dimension
# max_image_feature_lens is max_feature_lens per batch
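
The fixed message relies on Python's implicit concatenation of adjacent string literals; each fragment that interpolates a value needs its own f prefix, and the trailing space keeps the joined text readable. A minimal standalone illustration (example values invented):

    total_num_special_image_tokens, num_images = 10, 8  # example values
    msg = (
        f"Number of image tokens in input_ids ({total_num_special_image_tokens}) "
        f"different from num_images ({num_images})."
    )
    # Adjacent literals are joined into a single string:
    assert msg == "Number of image tokens in input_ids (10) different from num_images (8)."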
@@ -274,8 +274,8 @@ def unpad_image(tensor, original_size):
if not isinstance(original_size, (list, tuple)):
if not isinstance(original_size, (torch.Tensor, np.ndarray)):
raise TypeError(
f"image_size invalid type: {type(original_size)} not valid,
should be either list, tuple, np.ndarray or tensor"
f"image_size invalid type: {type(original_size)} not valid ",
"should be either list, tuple, np.ndarray or tensor",
)
original_size = original_size.tolist()
original_height, original_width = original_size
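
Note a subtle difference from the similar fix in get_anyres_image_grid_shape below: here the trailing commas make TypeError receive two separate arguments (its args become a tuple of two strings), whereas adjacent literals without a comma are concatenated into one message. A quick comparison of the two forms:

    # Two comma-separated arguments: the exception carries a tuple of strings.
    e1 = TypeError("image_size invalid type: <class 'int'> not valid ",
                   "should be either list, tuple, np.ndarray or tensor")
    print(e1.args)  # ('image_size invalid type: ...', 'should be either ...')

    # Implicit concatenation (no comma): a single message string.
    e2 = TypeError("image_size invalid type: <class 'int'> not valid, "
                   "should be either list, tuple, np.ndarray or tensor")
    print(str(e2))  # image_size invalid type: <class 'int'> not valid, should be ...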
@@ -355,8 +355,8 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
if not isinstance(image_size, (list, tuple)):
if not isinstance(image_size, (torch.Tensor, np.ndarray)):
raise TypeError(
f"image_size invalid type: {type(image_size)} not valid,
should be either list, tuple, np.ndarray or tensor"
f"image_size invalid type: {type(image_size)} not valid, "
"should be either list, tuple, np.ndarray or tensor"
)
image_size = image_size.tolist()

