
Commit

Add HF siglip vision encoder
Signed-off-by: HuiyingLi <[email protected]>
HuiyingLi committed May 13, 2024
1 parent a0e9ee3 commit 79f042b
Showing 2 changed files with 105 additions and 62 deletions.
84 changes: 41 additions & 43 deletions nemo/collections/multimodal/data/neva/neva_dataset.py
@@ -29,7 +29,7 @@
from omegaconf import DictConfig
from PIL import Image
from torch.utils.data import Dataset, default_collate
from transformers import CLIPImageProcessor
from transformers import CLIPImageProcessor, SiglipImageProcessor

import nemo.collections.multimodal.data.neva.conversation as conversation_lib
from nemo.collections.multimodal.data.clip.augmentations.augmentations import image_transform
@@ -295,6 +295,44 @@ def preprocess_multimodal(sources: dict, multimodal_cfg: dict, cur_token_len: in
return sources


def process_image(processor, image, image_aspect_ratio="square"):
if isinstance(processor, CLIPImageProcessor) or isinstance(processor, SiglipImageProcessor):
# image processor from HF
if image_aspect_ratio == 'keep':
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 448, 224
shortest_edge = int(min(max_len / aspect_ratio, min_len))
image = processor.preprocess(
image, return_tensors='pt', do_center_crop=False, size={"shortest_edge": shortest_edge}
)['pixel_values'][0]
elif image_aspect_ratio == 'pad':

def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result

image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
else:
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
else:
assert (
image_aspect_ratio == 'square'
), 'NeMo image transform requires `image_aspect_ratio` to be `square`.'
image = processor(image)
return image
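
A minimal sketch of calling the new `process_image` helper through the SigLIP path; the checkpoint id and image path below are assumptions for illustration, not taken from this commit:

from PIL import Image
from transformers import SiglipImageProcessor

# Placeholder checkpoint and image; any Hugging Face SigLIP checkpoint is handled
# by the same isinstance(processor, SiglipImageProcessor) branch above.
processor = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")
image = Image.open("example.jpg").convert("RGB")

# 'pad' letterboxes the image to a square filled with the processor's mean color,
# then applies the standard HF preprocessing; the result is a (3, H, W) tensor.
pixel_values = process_image(processor, image, image_aspect_ratio="pad")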


def preprocess_llama_2(sources: dict, tokenizer, cfg,) -> Dict:
"""
Preprocesses sources for the LLaMA 2 model configuration.
@@ -760,40 +798,7 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
image = self.image_loader.open_image(image_file)
if image is None:
logging.warning(f"Image {image_file} could not be found!")
if isinstance(self.processor, CLIPImageProcessor):
# image processor from HF
if self.multimodal_cfg['image_aspect_ratio'] == 'keep':
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 448, 224
shortest_edge = int(min(max_len / aspect_ratio, min_len))
image = self.processor.preprocess(
image, return_tensors='pt', do_center_crop=False, size={"shortest_edge": shortest_edge}
)['pixel_values'][0]
elif self.multimodal_cfg['image_aspect_ratio'] == 'pad':

def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result

image = expand2square(image, tuple(int(x * 255) for x in self.processor.image_mean))
image = self.processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
else:
image = self.processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
else:
assert (
self.multimodal_cfg['image_aspect_ratio'] == 'square'
), 'NeMo image transform with setting `image_aspect_ratio` to `square`.'
image = self.processor(image)
image = process_image(self.processor, image, self.multimodal_cfg['image_aspect_ratio'])
images.append(image)
media_tensors = torch.tensor([])
if images:
@@ -1034,21 +1039,14 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
return batch


def make_supervised_data_module(tokenizer, model_cfg) -> Dict:
def make_supervised_data_module(tokenizer, image_processor, model_cfg) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
data_cfg = model_cfg.data
mm_cfg = model_cfg.mm_cfg
add_extra_token = 1
if getattr(model_cfg, 'no_seqlen_plus_one_input_tokens', False):
add_extra_token = 0
crop_size = data_cfg.get("crop_size", (224, 224))
if mm_cfg.vision_encoder.from_hf:
image_processor = CLIPImageProcessor.from_pretrained(
mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16
)
else:
# TODO(yuya): Fix this hard-code for our own CLIP
image_processor = image_transform(crop_size, is_train=False, mean=None, std=None,)

train_dataset = NevaDataset(
tokenizer=tokenizer,
83 changes: 64 additions & 19 deletions nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py
@@ -23,7 +23,7 @@
from omegaconf.dictconfig import DictConfig
from pkg_resources import packaging
from pytorch_lightning.trainer.trainer import Trainer
from transformers import CLIPVisionModel
from transformers import CLIPVisionModel, CLIPImageProcessor, SiglipVisionModel, SiglipImageProcessor

from nemo.collections.common.parts.utils import extend_instance
from nemo.collections.multimodal.data.neva.conversation import DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN
@@ -37,6 +37,7 @@
MegatronCLIPModel,
)
from nemo.collections.multimodal.parts.utils import load_nemo_model_weights
from nemo.collections.multimodal.data.clip.augmentations.augmentations import image_transform
from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import MegatronPretrainingSampler
from nemo.collections.nlp.models.language_modeling.megatron.gpt_model import GPTModel
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel, get_specs
@@ -136,7 +137,8 @@ def init_vision(
use_im_start_end=False,
):
self.vision_encoder = vision_encoder
self.from_hf = isinstance(vision_encoder, CLIPVisionModel)
self.from_hf = isinstance(vision_encoder, CLIPVisionModel) or isinstance(vision_encoder, SiglipVisionModel)
self.from_open_clip = "open_clip" in str(vision_encoder.__module__)
self.media_start_id = media_start_id
self.media_end_id = media_end_id
self.class_token_length = class_token_length
@@ -261,11 +263,58 @@ def __init__(
if mm_cfg.llm.freeze:
self.freeze_llm(mm_cfg)

vision_encoder, self.image_processor = self.create_vision_encoder_and_processor(mm_cfg)

# Monkey patch embedding
if kwargs.get("pre_process", True):
extend_instance(self.embedding.word_embeddings, NevaWordEmbeddingMixin)
self.embedding.word_embeddings.init_vision(
vision_encoder,
media_start_id,
media_end_id,
vision_select_layer=mm_cfg.vision_encoder.get("vision_select_layer", -2),
class_token_length=mm_cfg.vision_encoder.get("class_token_length", 1),
use_im_start_end=mm_cfg.get("use_im_start_end", False),
)

def create_vision_encoder_and_processor(self, mm_cfg):
# Initialize vision encoder and freeze it
if mm_cfg.vision_encoder.from_hf:
vision_encoder = CLIPVisionModel.from_pretrained(
mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16,
).cuda()
if mm_cfg.vision_encoder.get("from_hf", False):
if "clip" in mm_cfg.vision_encoder.from_pretrained:
vision_encoder = CLIPVisionModel.from_pretrained(
mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16,
).cuda()
vision_encoder = vision_encoder.to(torch.bfloat16)
if mm_cfg.vision_encoder.freeze:
for param in vision_encoder.parameters():
param.requires_grad = False
vision_encoder = vision_encoder.eval()
image_processor = CLIPImageProcessor.from_pretrained(
mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16
)
elif "siglip" in mm_cfg.vision_encoder.from_pretrained:
vision_encoder = SiglipVisionModel.from_pretrained(
mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16,
).cuda()
vision_encoder = vision_encoder.to(torch.bfloat16)
if mm_cfg.vision_encoder.freeze:
for param in vision_encoder.parameters():
param.requires_grad = False
vision_encoder = vision_encoder.eval()
image_processor = SiglipImageProcessor.from_pretrained(
mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16
)
else:
raise ValueError("Currently only CLIPVisionModel and SiglipVisionModel from Hugging Face are supported")
elif mm_cfg.vision_encoder.get("from_open_clip", False):
assert mm_cfg.vision_encoder.get("open_clip_model_name") is not None, \
f"`open_clip_model_name` needs to be set."
model, _, image_processor = open_clip.create_model_and_transforms(
mm_cfg.vision_encoder.open_clip_model_name,
pretrained=mm_cfg.vision_encoder.from_pretrained, precision=torch.bfloat16,
)
vision_encoder = model.visual.cuda()
del model
vision_encoder = vision_encoder.to(torch.bfloat16)
if mm_cfg.vision_encoder.freeze:
for param in vision_encoder.parameters():
@@ -279,18 +328,10 @@ def __init__(
self.load_vision_encoder_weights(vision_encoder, mm_cfg.vision_encoder.from_pretrained)
if mm_cfg.vision_encoder.freeze:
vision_encoder.freeze()
crop_size = mm_cfg.get("crop_size", (224, 224))
image_processor = image_transform(crop_size, is_train=False, mean=None, std=None, )

# Monkey patch embedding
if kwargs.get("pre_process", True):
extend_instance(self.embedding.word_embeddings, NevaWordEmbeddingMixin)
self.embedding.word_embeddings.init_vision(
vision_encoder,
media_start_id,
media_end_id,
vision_select_layer=mm_cfg.vision_encoder.get("vision_select_layer", -2),
class_token_length=mm_cfg.vision_encoder.get("class_token_length", 1),
use_im_start_end=mm_cfg.get("use_im_start_end", False),
)
return vision_encoder, image_processor
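
A rough sketch of the vision-encoder config this method reads, with keys mirroring the `mm_cfg.vision_encoder` accesses above (the SigLIP checkpoint id is an assumption):

from omegaconf import OmegaConf

# Keys correspond to the lookups in create_vision_encoder_and_processor and init_vision;
# the checkpoint id is a placeholder SigLIP model hosted on Hugging Face.
mm_cfg = OmegaConf.create(
    {
        "vision_encoder": {
            "from_hf": True,
            "from_pretrained": "google/siglip-so400m-patch14-384",
            "freeze": True,
            "vision_select_layer": -2,
            "class_token_length": 1,
        },
        "use_im_start_end": False,
    }
)

Because the checkpoint id contains "siglip", `create_vision_encoder_and_processor` takes the SiglipVisionModel/SiglipImageProcessor branch; an id containing "clip" takes the CLIP branch, and any other Hugging Face checkpoint raises the ValueError above.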

def freeze_llm(self, mm_cfg):
raise NotImplementedError
@@ -981,10 +1022,14 @@ def build_train_valid_test_datasets(self):
self._train_ds = NevaPackedSeqDatatset(self.cfg.data.data_prefix, self.cfg.data.get("crop_size"))
self._validation_ds = NevaPackedSeqDatatset(self.cfg.data.data_prefix, self.cfg.data.get("crop_size"))
else:
ds_dict = make_supervised_data_module(tokenizer=self.tokenizer, model_cfg=self.cfg,)
ds_dict = make_supervised_data_module(
tokenizer=self.tokenizer,
image_processor=self.model.module.image_processor if hasattr(self.model, "module") else self.model.image_processor,
model_cfg=self.cfg,
)
self._train_ds = ds_dict["train_dataset"]
self._validation_ds = ds_dict["eval_dataset"]

return self._train_ds, self._validation_ds
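
As a rough caller-side sketch of the new wiring (`neva` is a hypothetical, already-constructed MegatronNevaModel; the attribute names follow the call above):

# The image processor is created once by the model's inner module (see
# create_vision_encoder_and_processor) and reused by the data module, instead of
# being rebuilt from the config inside make_supervised_data_module.
inner = neva.model
image_processor = inner.module.image_processor if hasattr(inner, "module") else inner.image_processor
ds_dict = make_supervised_data_module(
    tokenizer=neva.tokenizer,
    image_processor=image_processor,
    model_cfg=neva.cfg,
)
train_ds, val_ds = ds_dict["train_dataset"], ds_dict["eval_dataset"]

The hasattr check covers the case where the inner model has been wrapped (e.g. for mixed precision) and the processor lives one level down on `.module`.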

def build_pretraining_data_loader(