Skip to content

Commit

Permalink
change inference time image processor
Browse files Browse the repository at this point in the history
Signed-off-by: HuiyingLi <[email protected]>
  • Loading branch information
HuiyingLi committed May 14, 2024
1 parent 56a3bb4 commit 72e75cb
Showing 1 changed file with 9 additions and 34 deletions.
43 changes: 9 additions & 34 deletions nemo/collections/multimodal/parts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector
from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP
from nemo.collections.multimodal.data.neva.neva_dataset import process_image
from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision
from nemo.utils import AppState, logging
from nemo.utils.model_utils import inject_model_parallel_rank
Expand Down Expand Up @@ -425,42 +426,16 @@ def image_processor(maybe_image_path):
else:
image = maybe_image_path

if neva_cfg.mm_cfg.vision_encoder.from_hf:
processor = CLIPImageProcessor.from_pretrained(
neva_cfg.mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16
)
else:
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.bfloat16)

if neva_cfg.data.image_aspect_ratio == 'keep':
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 448, 224
shortest_edge = int(min(max_len / aspect_ratio, min_len))
image = processor.preprocess(
image, return_tensors='pt', do_center_crop=False, size={"shortest_edge": shortest_edge}
)['pixel_values'][0]
elif neva_cfg.data.image_aspect_ratio == 'pad':

def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result

image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
processor = model.model.module.image_processor \
if hasattr(model.model, "module") else model.model.image_processor
image = process_image(processor, image, neva_cfg.data.image_aspect_ratio)
if neva_cfg.precision in [16, '16', '16-mixed']:
media = image.type(torch.float16)
elif neva_cfg.precision in [32, '32', '32-true']:
media = image.type(torch.float32)
else:
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
media = image.type(torch.bfloat16)

media = image.type(torch_dtype_from_precision(neva_cfg.precision))
return media.unsqueeze(dim=0).unsqueeze(dim=0).unsqueeze(dim=0)

# add video processor for video neva
Expand Down

0 comments on commit 72e75cb

Please sign in to comment.