
Commit

refac image processor loading to util
Signed-off-by: HuiyingLi <[email protected]>
HuiyingLi committed Jun 4, 2024
1 parent 46c9339 commit 2233fa6
Showing 3 changed files with 39 additions and 19 deletions.
@@ -60,6 +60,7 @@
 from tqdm import tqdm
 
 from nemo.collections.multimodal.data.neva.neva_dataset import make_supervised_data_module
+from nemo.collections.multimodal.parts.utils import create_image_processor
 from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
 from nemo.utils import logging
 
@@ -254,8 +255,14 @@ def main():
     nemo_config.model.data.conv_template = args.conv_template
     nemo_config.model.data.image_aspect_ratio = args.image_aspect_ratio
 
-    tokenizer = get_nmt_tokenizer(library="sentencepiece", tokenizer_model=args.tokenizer_path,)
-    train_ds = make_supervised_data_module(tokenizer=tokenizer, model_cfg=nemo_config.model)["train_dataset"]
+    tokenizer = get_nmt_tokenizer(
+        library="sentencepiece",
+        tokenizer_model=args.tokenizer_path,
+    )
+    image_processor = create_image_processor(nemo_config.model.mm_cfg)
+    train_ds = make_supervised_data_module(
+        tokenizer=tokenizer, image_processor=image_processor, model_cfg=nemo_config.model
+    )["train_dataset"]
     train_dl = DataLoader(train_ds, num_workers=32, collate_fn=None, shuffle=False)
     # Example shape: {'tokens': torch.Size([1, 344]), 'labels': torch.Size([1, 344]), 'image': torch.Size([1, 1, 3, 224, 224])}
 
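As a quick sanity check of the new loading path, one can iterate the DataLoader built above and print the per-key tensor shapes. This is an illustrative sketch, not part of the commit; it assumes `train_dl` was constructed exactly as in the hunk above, with batch keys matching the example-shape comment.

# Illustrative only: inspect one batch from the DataLoader built in main().
# Assumes `train_dl` exists as in the diff; keys follow the example-shape
# comment ('tokens', 'labels', 'image').
batch = next(iter(train_dl))
for name, tensor in batch.items():
    print(f"{name}: {tuple(tensor.shape)}")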
@@ -23,10 +23,9 @@
 from omegaconf.dictconfig import DictConfig
 from pkg_resources import packaging
 from pytorch_lightning.trainer.trainer import Trainer
-from transformers import CLIPImageProcessor, CLIPVisionModel, SiglipImageProcessor, SiglipVisionModel
+from transformers import CLIPVisionModel, SiglipVisionModel
 
 from nemo.collections.common.parts.utils import extend_instance
-from nemo.collections.multimodal.data.clip.augmentations.augmentations import image_transform
 from nemo.collections.multimodal.data.neva.conversation import DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN
 from nemo.collections.multimodal.data.neva.neva_dataset import (
     DataCollatorForSupervisedDataset,
@@ -37,7 +36,7 @@
     CLIPVisionTransformer,
     MegatronCLIPModel,
 )
-from nemo.collections.multimodal.parts.utils import load_nemo_model_weights
+from nemo.collections.multimodal.parts.utils import create_image_processor, load_nemo_model_weights
 from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import MegatronPretrainingSampler
 from nemo.collections.nlp.models.language_modeling.megatron.gpt_model import GPTModel
 from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel, get_specs
@@ -308,9 +307,6 @@ def create_vision_encoder_and_processor(self, mm_cfg):
                     for param in vision_encoder.parameters():
                         param.requires_grad = False
                     vision_encoder = vision_encoder.eval()
-                image_processor = CLIPImageProcessor.from_pretrained(
-                    mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16
-                )
             elif "siglip" in mm_cfg.vision_encoder.from_pretrained:
                 vision_encoder = SiglipVisionModel.from_pretrained(
                     mm_cfg.vision_encoder.from_pretrained,
@@ -321,9 +317,6 @@ def create_vision_encoder_and_processor(self, mm_cfg):
                     for param in vision_encoder.parameters():
                         param.requires_grad = False
                     vision_encoder = vision_encoder.eval()
-                image_processor = SiglipImageProcessor.from_pretrained(
-                    mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16
-                )
             else:
                 raise (ValueError("Currently only support CLIPVisionModel and SigLipVisionModel from Huggingface"))
         else:
@@ -334,13 +327,8 @@ def create_vision_encoder_and_processor(self, mm_cfg):
             self.load_vision_encoder_weights(vision_encoder, mm_cfg.vision_encoder.from_pretrained)
             if mm_cfg.vision_encoder.freeze:
                 vision_encoder.freeze()
-            crop_size = mm_cfg.get("crop_size", (224, 224))
-            image_processor = image_transform(
-                crop_size,
-                is_train=False,
-                mean=None,
-                std=None,
-            )
+
+        image_processor = create_image_processor(mm_cfg)
 
         return vision_encoder, image_processor
 
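With both HuggingFace branches and the MegatronCLIP fallback now delegating to the shared helper, this method only has to construct the vision encoder. A rough sketch of the resulting call pattern follows; the attribute path `self.cfg.mm_cfg` is an assumption about the surrounding model wiring, not something shown in this diff.

# Sketch only (attribute path assumed): one call now yields both pieces, and the
# processor-selection logic lives in nemo/collections/multimodal/parts/utils.py.
vision_encoder, image_processor = self.create_vision_encoder_and_processor(self.cfg.mm_cfg)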
27 changes: 26 additions & 1 deletion nemo/collections/multimodal/parts/utils.py
@@ -21,7 +21,8 @@
 from PIL import Image
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins.environments import TorchElasticEnvironment
-from transformers import CLIPImageProcessor
+from transformers import CLIPImageProcessor, SiglipImageProcessor
+from nemo.collections.multimodal.data.clip.augmentations.augmentations import image_transform
 
 from nemo.collections.multimodal.data.neva.neva_dataset import process_image
 from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel
@@ -508,3 +509,27 @@ def expand2square(pil_img, background_color):
         return media_tensors.unsqueeze(dim=0).unsqueeze(dim=0)
 
     return model, image_processor, video_processor
+
+
+def create_image_processor(mm_cfg):
+    if mm_cfg.vision_encoder.get("from_hf", False):
+        if "clip" in mm_cfg.vision_encoder.from_pretrained:
+            image_processor = CLIPImageProcessor.from_pretrained(
+                mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16
+            )
+        elif "siglip" in mm_cfg.vision_encoder.from_pretrained:
+            image_processor = SiglipImageProcessor.from_pretrained(
+                mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16
+            )
+        else:
+            raise (ValueError("Currently only support CLIPImageProcessor and SiglipImageProcessor from Huggingface"))
+    else:
+        # Corresponds to MegatronCLIPModel
+        crop_size = mm_cfg.get("crop_size", (224, 224))
+        image_processor = image_transform(
+            crop_size,
+            is_train=False,
+            mean=None,
+            std=None,
+        )
+    return image_processor
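For reference, a minimal sketch of exercising the new helper on its own. The OmegaConf layout mirrors the keys the function reads (`vision_encoder.from_hf`, `vision_encoder.from_pretrained`, and `crop_size` for the non-HF branch); the checkpoint id is an illustrative assumption, not taken from this commit.

# Illustrative usage of create_image_processor (not part of this commit).
from omegaconf import OmegaConf

from nemo.collections.multimodal.parts.utils import create_image_processor

mm_cfg = OmegaConf.create(
    {
        "vision_encoder": {
            "from_hf": True,  # take the HuggingFace branch
            "from_pretrained": "openai/clip-vit-large-patch14",  # example id containing "clip"
        },
        "crop_size": [224, 224],  # only read by the non-HF (MegatronCLIPModel) branch
    }
)

image_processor = create_image_processor(mm_cfg)  # resolves to a CLIPImageProcessor here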
