From 6150584244038449b3547a8dd6f78ec45176cc95 Mon Sep 17 00:00:00 2001 From: Vivian Chen Date: Mon, 29 Jul 2024 22:05:05 +0000 Subject: [PATCH] inital commit on adding blenddataset for neva Signed-off-by: Vivian Chen --- .../multimodal/data/neva/neva_dataset.py | 15 ++- .../models/multimodal_llm/neva/neva_model.py | 103 +++++++++++++++++- 2 files changed, 108 insertions(+), 10 deletions(-) diff --git a/nemo/collections/multimodal/data/neva/neva_dataset.py b/nemo/collections/multimodal/data/neva/neva_dataset.py index b56c42fff274..c808fc5be0d1 100644 --- a/nemo/collections/multimodal/data/neva/neva_dataset.py +++ b/nemo/collections/multimodal/data/neva/neva_dataset.py @@ -1000,6 +1000,8 @@ def __len__(self): return len(self.list_data_dict) def __getitem__(self, i) -> Dict[str, torch.Tensor]: + if isinstance(i, np.integer): + i = int(i) sources = self.list_data_dict[i] if isinstance(i, int): sources = [sources] @@ -1186,7 +1188,6 @@ class NevaDataset(LazySupervisedDataset): """Dataset for supervised fine-tuning.""" def __init__(self, data_path: str, tokenizer, multimodal_cfg: dict, data_cfg: dict): - if data_path.endswith(".json"): super(NevaDataset, self).__init__(data_path, tokenizer, multimodal_cfg, data_cfg) @@ -1309,7 +1310,7 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: return batch -def make_supervised_data_module(tokenizer, image_processor, model_cfg) -> Dict: +def make_supervised_data_module(tokenizer, image_processor, model_cfg, data_file) -> Dict: """Make dataset and collator for supervised fine-tuning.""" data_cfg = model_cfg.data mm_cfg = model_cfg.mm_cfg @@ -1317,10 +1318,14 @@ def make_supervised_data_module(tokenizer, image_processor, model_cfg) -> Dict: if getattr(model_cfg, 'no_seqlen_plus_one_input_tokens', False): add_extra_token = 0 crop_size = mm_cfg.vision_encoder.get("crop_size", (224, 224)) - + if not data_cfg.get("data_path"): + data_path = data_file + else: + data_path = data_cfg.data_path + # use blend train_dataset = NevaDataset( tokenizer=tokenizer, - data_path=data_cfg.data_path, + data_path=data_path, multimodal_cfg=dict( is_multimodal=data_cfg.is_multimodal, sep_image_conv_front=data_cfg.sep_image_conv_front, @@ -1349,7 +1354,7 @@ def make_supervised_data_module(tokenizer, image_processor, model_cfg) -> Dict: ) return dict(train_dataset=train_dataset, eval_dataset=train_dataset) - + class NevaPackedSeqDatatset(Dataset): def __init__(self, data_path: str, crop_size: Tuple[int, int] = (224, 224)): diff --git a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py index c5805c972ad0..442f0cc400e9 100644 --- a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py +++ b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py @@ -20,8 +20,9 @@ import numpy as np import torch import torch.nn.functional as F +from torch.utils.data import ConcatDataset from einops import rearrange, reduce, repeat -from omegaconf.dictconfig import DictConfig +from omegaconf import DictConfig, ListConfig from pkg_resources import packaging from pytorch_lightning.trainer.trainer import Trainer from transformers import CLIPVisionModel, SiglipVisionModel @@ -33,6 +34,10 @@ NevaPackedSeqDatatset, make_supervised_data_module, ) +from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import ( + get_datasets_weights_and_num_samples, +) +from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset from nemo.collections.multimodal.models.vision_language_foundation.clip.megatron_clip_models import ( CLIPVisionTransformer, MegatronCLIPModel, @@ -473,7 +478,7 @@ def create_vision_encoder_and_processor(self, mm_cfg): from transformers import AutoConfig config = AutoConfig.from_pretrained(mm_cfg.vision_encoder.from_pretrained) - if config.architectures[0] == "CLIPVisionModel": + if config.architectures[0] == "CLIPVisionModel" or config.architectures[0] == "CLIPModel": vision_encoder = CLIPVisionModel.from_pretrained( mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16, @@ -483,7 +488,7 @@ def create_vision_encoder_and_processor(self, mm_cfg): for param in vision_encoder.parameters(): param.requires_grad = False vision_encoder = vision_encoder.eval() - elif config.architectures[0] == "SiglipVisionModel": + elif config.architectures[0] == "SiglipVisionModel" or config.architectures[0] == "SiglipModel": vision_encoder = SiglipVisionModel.from_pretrained( mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16, @@ -1205,7 +1210,8 @@ def setup(self, stage=None): else: # TODO: consider adding a ModelPT guard to check if model is being restored. # allowing restored models to optionally setup datasets - self.build_train_valid_test_datasets() + #self.build_train_valid_test_datasets() + self.build_train_valid_test_datasets_blend() self.setup_training_data(self.cfg.data) self.setup_validation_data(self.cfg.data) self.setup_test_data(self.cfg.data) @@ -1224,6 +1230,93 @@ def setup(self, stage=None): if self.cfg.get('transformer_engine', False): self.setup_transformer_engine_tp_groups() + + def build_train_valid_test_datasets_blend(self): + logging.info('Building Blending Neva datasets.') + + train_datasets = [] + valid_datasets = [] + + data_cfg = self.cfg.data + is_packed_sequence = data_cfg.get("packed_sequence", False) + + if is_packed_sequence: + assert self.cfg.micro_batch_size == 1, "Micro batch size must be 1 if using packed sequence" + + # Check if concat_sampling_probabilities is properly set + if data_cfg.get('concat_sampling_probabilities') is None or not isinstance( + data_cfg.concat_sampling_probabilities, ListConfig + ): + raise ValueError("concat_sampling_probabilities must be a ListConfig with the same number of entries as data_file_names.") + + if len(data_cfg.concat_sampling_probabilities) != len(data_cfg.data_file_names): + raise ValueError( + f"concat_sampling_probabilities must be of the same size as data_file_names. " + f"Provided size {len(data_cfg.concat_sampling_probabilities)}, number of datasets {len(data_cfg.data_file_names)}" + ) + + for data_file in data_cfg.data_file_names: + if is_packed_sequence: + train_dataset = NevaPackedSeqDatatset( + data_file, self.cfg.mm_cfg.vision_encoder.get("crop_size") + ) + valid_dataset = NevaPackedSeqDatatset( + data_file, self.cfg.mm_cfg.vision_encoder.get("crop_size") + ) + else: + ds_dict = make_supervised_data_module( + tokenizer=self.tokenizer, + image_processor=( + self.model.module.image_processor if hasattr(self.model, "module") else self.model.image_processor + ), + model_cfg=self.cfg, + data_file=data_file, + ) + train_dataset = ds_dict["train_dataset"] + valid_dataset = ds_dict["eval_dataset"] + + train_datasets.append(train_dataset) + valid_datasets.append(valid_dataset) + + # Create BlendableDataset for training + if self.trainer.max_steps is None or self.trainer.max_steps <= 0: + raise ValueError(f'Trainer max_steps must be set to a positive integer. Found {self.trainer.max_steps}') + + num_train_samples = self.trainer.max_steps * data_cfg.global_batch_size + _, _, num_train_samples_per_dataset = get_datasets_weights_and_num_samples( + data_prefix=[weight for pair in zip(data_cfg.concat_sampling_probabilities, data_cfg.data_file_names) for weight in pair], + num_samples=[num_train_samples] + ) + num_train_samples_after_blend = sum([x[0] for x in num_train_samples_per_dataset]) + + logging.info(f"Number of train datasets: {len(train_datasets)}") + logging.info(f"Lengths of train datasets: {[len(ds) for ds in train_datasets]}") + logging.info(f"concat_sampling_probabilities: {data_cfg.concat_sampling_probabilities}") + logging.info(f"num_train_samples_after_blend: {num_train_samples_after_blend}") + + self._train_ds = BlendableDataset( + datasets=train_datasets, + weights=data_cfg.concat_sampling_probabilities, + size=num_train_samples_after_blend + ) + + self._validation_ds = BlendableDataset( + datasets=valid_datasets, + weights=data_cfg.concat_sampling_probabilities, + size=num_train_samples_after_blend + ) + + + logging.info(f'Length of train dataset: {len(self._train_ds)}') + logging.info(f'Length of validation dataset: {len(self._validation_ds)}') + + + ######### Use ConcatDataset instead of BlendableDataset########## + # self._train_ds = ConcatDataset(train_datasets) + # self._validation_ds = ConcatDataset(valid_datasets) + ################################################################## + + return self._train_ds, self._validation_ds def build_train_valid_test_datasets(self): logging.info('Building Neva datasets.') @@ -1286,7 +1379,7 @@ def build_pretraining_data_loader( raise ValueError('cfg.data.dataloader_type must be "single" or "cyclic"') else: raise ValueError('cfg.data.dataloader_type not found. Must be "single" or "cyclic"') - + collate_func = DataCollatorForSupervisedDataset(self.cfg, self.tokenizer) return torch.utils.data.DataLoader( dataset,