diff --git a/examples/multimodal/multimodal_llm/neva/sequence_packing/preprocess_dataset.py b/examples/multimodal/multimodal_llm/neva/sequence_packing/preprocess_dataset.py
index ee96ff6489d3..60f882fa9821 100644
--- a/examples/multimodal/multimodal_llm/neva/sequence_packing/preprocess_dataset.py
+++ b/examples/multimodal/multimodal_llm/neva/sequence_packing/preprocess_dataset.py
@@ -60,6 +60,7 @@
 from tqdm import tqdm
 
 from nemo.collections.multimodal.data.neva.neva_dataset import make_supervised_data_module
+from nemo.collections.multimodal.parts.utils import create_image_processor
 from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
 from nemo.utils import logging
 
@@ -254,8 +255,14 @@ def main():
     nemo_config.model.data.conv_template = args.conv_template
     nemo_config.model.data.image_aspect_ratio = args.image_aspect_ratio
 
-    tokenizer = get_nmt_tokenizer(library="sentencepiece", tokenizer_model=args.tokenizer_path,)
-    train_ds = make_supervised_data_module(tokenizer=tokenizer, model_cfg=nemo_config.model)["train_dataset"]
+    tokenizer = get_nmt_tokenizer(
+        library="sentencepiece",
+        tokenizer_model=args.tokenizer_path,
+    )
+    image_processor = create_image_processor(nemo_config.model.mm_cfg)
+    train_ds = make_supervised_data_module(
+        tokenizer=tokenizer, image_processor=image_processor, model_cfg=nemo_config.model
+    )["train_dataset"]
     train_dl = DataLoader(train_ds, num_workers=32, collate_fn=None, shuffle=False)
 
     # Example shape: {'tokens': torch.Size([1, 344]), 'labels': torch.Size([1, 344]), 'image': torch.Size([1, 1, 3, 224, 224])}
diff --git a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py
index c4daa13877f5..2fc9bb7e2022 100644
--- a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py
+++ b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py
@@ -23,10 +23,9 @@
 from omegaconf.dictconfig import DictConfig
 from pkg_resources import packaging
 from pytorch_lightning.trainer.trainer import Trainer
-from transformers import CLIPImageProcessor, CLIPVisionModel, SiglipImageProcessor, SiglipVisionModel
+from transformers import CLIPVisionModel, SiglipVisionModel
 
 from nemo.collections.common.parts.utils import extend_instance
-from nemo.collections.multimodal.data.clip.augmentations.augmentations import image_transform
 from nemo.collections.multimodal.data.neva.conversation import DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN
 from nemo.collections.multimodal.data.neva.neva_dataset import (
     DataCollatorForSupervisedDataset,
@@ -37,7 +36,7 @@
     CLIPVisionTransformer,
     MegatronCLIPModel,
 )
-from nemo.collections.multimodal.parts.utils import load_nemo_model_weights
+from nemo.collections.multimodal.parts.utils import create_image_processor, load_nemo_model_weights
 from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import MegatronPretrainingSampler
 from nemo.collections.nlp.models.language_modeling.megatron.gpt_model import GPTModel
 from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel, get_specs
@@ -308,9 +307,6 @@ def create_vision_encoder_and_processor(self, mm_cfg):
                 for param in vision_encoder.parameters():
                     param.requires_grad = False
                 vision_encoder = vision_encoder.eval()
-                image_processor = CLIPImageProcessor.from_pretrained(
-                    mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16
-                )
             elif "siglip" in mm_cfg.vision_encoder.from_pretrained:
                 vision_encoder = SiglipVisionModel.from_pretrained(
                     mm_cfg.vision_encoder.from_pretrained,
@@ -321,9 +317,6 @@ def create_vision_encoder_and_processor(self, mm_cfg):
                 for param in vision_encoder.parameters():
                     param.requires_grad = False
                 vision_encoder = vision_encoder.eval()
-                image_processor = SiglipImageProcessor.from_pretrained(
-                    mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16
-                )
             else:
                 raise (ValueError("Currently only support CLIPVisionModel and SigLipVisionModel from Huggingface"))
         else:
@@ -334,13 +327,8 @@ def create_vision_encoder_and_processor(self, mm_cfg):
             self.load_vision_encoder_weights(vision_encoder, mm_cfg.vision_encoder.from_pretrained)
             if mm_cfg.vision_encoder.freeze:
                 vision_encoder.freeze()
-            crop_size = mm_cfg.get("crop_size", (224, 224))
-            image_processor = image_transform(
-                crop_size,
-                is_train=False,
-                mean=None,
-                std=None,
-            )
+
+        image_processor = create_image_processor(mm_cfg)
 
         return vision_encoder, image_processor
 
diff --git a/nemo/collections/multimodal/parts/utils.py b/nemo/collections/multimodal/parts/utils.py
index 6d55ad015444..dc7ad86a0f23 100644
--- a/nemo/collections/multimodal/parts/utils.py
+++ b/nemo/collections/multimodal/parts/utils.py
@@ -21,7 +21,8 @@
 from PIL import Image
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins.environments import TorchElasticEnvironment
-from transformers import CLIPImageProcessor
+from transformers import CLIPImageProcessor, SiglipImageProcessor
 
+from nemo.collections.multimodal.data.clip.augmentations.augmentations import image_transform
 from nemo.collections.multimodal.data.neva.neva_dataset import process_image
 from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel
@@ -508,3 +509,27 @@ def expand2square(pil_img, background_color):
         return media_tensors.unsqueeze(dim=0).unsqueeze(dim=0)
 
     return model, image_processor, video_processor
+
+
+def create_image_processor(mm_cfg):
+    if mm_cfg.vision_encoder.get("from_hf", False):
+        if "clip" in mm_cfg.vision_encoder.from_pretrained:
+            image_processor = CLIPImageProcessor.from_pretrained(
+                mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16
+            )
+        elif "siglip" in mm_cfg.vision_encoder.from_pretrained:
+            image_processor = SiglipImageProcessor.from_pretrained(
+                mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16
+            )
+        else:
+            raise (ValueError("Currently only support CLIPImageProcessor and SiglipImageProcessor from Huggingface"))
+    else:
+        #Corresponds to MegatronCLIPModel
+        crop_size = mm_cfg.get("crop_size", (224, 224))
+        image_processor = image_transform(
+            crop_size,
+            is_train=False,
+            mean=None,
+            std=None,
+        )
+    return image_processor
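
Usage sketch (not part of the patch): a minimal illustration of how the new create_image_processor helper picks a processor from mm_cfg. The hand-built OmegaConf configs and the Hugging Face checkpoint name below are illustrative assumptions; only the config fields that the helper reads in the diff above are used.

# Minimal sketch, assuming hand-built configs; the HF checkpoint name is illustrative only.
from omegaconf import OmegaConf

from nemo.collections.multimodal.parts.utils import create_image_processor

# HF branch: "clip" in from_pretrained -> CLIPImageProcessor, "siglip" -> SiglipImageProcessor.
hf_cfg = OmegaConf.create(
    {"vision_encoder": {"from_hf": True, "from_pretrained": "openai/clip-vit-large-patch14"}}
)
image_processor = create_image_processor(hf_cfg)

# Non-HF branch (MegatronCLIPModel): falls back to image_transform with the configured crop_size.
megatron_cfg = OmegaConf.create(
    {"vision_encoder": {"from_hf": False, "from_pretrained": None}, "crop_size": [224, 224]}
)
image_transform_fn = create_image_processor(megatron_cfg)

Both branches mirror the two code paths previously inlined in create_vision_encoder_and_processor, so the sequence-packing script and the NeVA model now construct the image processor through the same helper.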