diff --git a/examples/multimodal/multimodal_llm/neva/conf/neva_config.yaml b/examples/multimodal/multimodal_llm/neva/conf/neva_config.yaml
index 9ec6e51bb004..694bf0df95f1 100644
--- a/examples/multimodal/multimodal_llm/neva/conf/neva_config.yaml
+++ b/examples/multimodal/multimodal_llm/neva/conf/neva_config.yaml
@@ -60,6 +60,7 @@ model:
   tensor_model_parallel_size: 1 # intra-layer model parallelism
   pipeline_model_parallel_size: 1 # inter-layer model parallelism
+  context_parallel_size: 1 # kqv model parallelism
   virtual_pipeline_model_parallel_size: null # interleaved pipeline

   restore_from_path: null # used in fine-tuning
@@ -186,12 +187,23 @@ model:
     num_workers: 8
     dataloader_type: cyclic
     data_path:
+    # could be a path to a single file or a list of files for data blending like below
+    # - /path/to/json
+    # - /path/to/json
+    global_batch_size: ${model.global_batch_size}
+    micro_batch_size: ${model.micro_batch_size}
+    concat_sampling_probabilities: null
+    # - 0.5
+    # - 0.5
     lazy_preprocess: True
     is_multimodal: True
     media_type: image # currently supported: image
+    num_frames: -1
     sep_image_conv_front: False
     conv_template: ${model.mm_cfg.llm.model_type} # check `nemo/collections/multimodal/data/neva/conversation.py`
+    image_token_len: 576
     image_folder: null
+    video_folder: null
     image_aspect_ratio: 'square'

   # Nsys profiling options
diff --git a/examples/multimodal/multimodal_llm/neva/sequence_packing/preprocess_dataset.py b/examples/multimodal/multimodal_llm/neva/sequence_packing/preprocess_dataset.py
old mode 100644
new mode 100755
index be1edd66aeb0..b670d171fd1d
--- a/examples/multimodal/multimodal_llm/neva/sequence_packing/preprocess_dataset.py
+++ b/examples/multimodal/multimodal_llm/neva/sequence_packing/preprocess_dataset.py
@@ -243,9 +243,38 @@ def pack_sequence(args, seq_lens):
     bins = packing_fn(seq_lens, args.max_seq_length)
     return bins

-def process_data_file(train_dl, prefix_path, data_file):
+
+def main():
+    torch.multiprocessing.set_sharing_strategy('file_system')
+
+    args = get_args()
+    nemo_config = OmegaConf.load(args.hparams_file)
+    nemo_config.model.mm_cfg.vision_encoder.from_pretrained = args.hf_vision_encoder
+    nemo_config.model.data.data_path = args.data_path
+    nemo_config.model.data.image_folder = args.image_folder
+    nemo_config.model.data.conv_template = args.conv_template
+    nemo_config.model.data.image_aspect_ratio = args.image_aspect_ratio
+
+    tokenizer = get_nmt_tokenizer(
+        library="sentencepiece",
+        tokenizer_model=args.tokenizer_path,
+    )
+    image_processor = create_image_processor(nemo_config.model.mm_cfg)
+    train_ds = make_supervised_data_module(
+        tokenizer=tokenizer, image_processor=image_processor, model_cfg=nemo_config.model
+    )["train_dataset"]
+    train_dl = DataLoader(train_ds, num_workers=32, collate_fn=None, shuffle=False)
+    # Example shape: {'tokens': torch.Size([1, 344]), 'labels': torch.Size([1, 344]), 'image': torch.Size([1, 1, 3, 224, 224])}
+
+    output_dir = args.output_dir
+    os.makedirs(output_dir, exist_ok=True)
+    logging.info(f"Output directory: {output_dir}")
+
+    prefix_path = f"{output_dir}/packed_seq_dataset"
+    os.makedirs(prefix_path, exist_ok=True)
+    # Original Datasets to Sequence Lengths Files
     builders = {}
-    for item_dict in tqdm(train_dl, desc=f"Building indexed datasets for {data_file}"):
+    for item_dict in tqdm(train_dl, desc="Building indexed datasets"):
         item_dict = {k: v[0] for k, v in item_dict.items()}
         seq_len = len(item_dict['tokens'])
         if seq_len in builders:
@@ -266,7 +295,7 @@ def process_data_file(train_dl, prefix_path, data_file):
logging.info(f"Finalizing builder for sequence length {seq_len} at {idx_path}") builder.finalize(idx_path) -def pack_sequences_into_bins(args, output_dir, prefix_path): + # Packing Sequences into Bins files = os.listdir(f"{output_dir}/packed_seq_dataset") pattern = rf"seqlen_(\d+).bin" seq_len_list = [] @@ -275,22 +304,16 @@ def pack_sequences_into_bins(args, output_dir, prefix_path): if match: seq_len = int(match.group(1)) seq_len_list.append(seq_len) - + aggregated_seq_lens = [] doc_pop_order = {} indexed_datasets = {} - error_len = 0 for seq_len in seq_len_list: dataset_path = f"{prefix_path}/seqlen_{seq_len}" - try: - dataset = IndexedDataset(dataset_path, multimodal=True) - aggregated_seq_lens.extend([seq_len] * (len(dataset.document_indices) - 1)) - doc_pop_order[seq_len] = list(np.random.permutation(len(dataset.document_indices) - 1)) - indexed_datasets[seq_len] = dataset - except Exception as e: - error_len += 1 - logging.error(f"Error while processing {dataset_path}: {e}") - logging.info(f"Number of errors: {error_len}") + dataset = IndexedDataset(dataset_path, multimodal=True) + aggregated_seq_lens.extend([seq_len] * (len(dataset.document_indices) - 1)) + doc_pop_order[seq_len] = list(np.random.permutation(len(dataset.document_indices) - 1)) + indexed_datasets[seq_len] = dataset logging.info("Getting bins") bins = pack_sequence(args, aggregated_seq_lens) @@ -301,6 +324,7 @@ def pack_sequences_into_bins(args, output_dir, prefix_path): avg_bins_sum = sum([sum(x) for x in bins]) / num_bins logging.info(f"Number of bins: {num_bins}, Average bin length: {avg_bins_len}, Average bin sum: {avg_bins_sum}") + # Reading Sequence Lengths and Packing into New Files final_builder_path = get_bin_path(f"{prefix_path}") logging.info(f"Creating final builder at {final_builder_path}") final_builder = IndexedDatasetBuilder(final_builder_path, dtype=np.float32, multimodal=True) @@ -333,41 +357,6 @@ def pack_sequences_into_bins(args, output_dir, prefix_path): final_builder.finalize(idx_path) logging.info(f"Number of bins: {num_bins}, Average bin length: {avg_bins_len}, Average bin sum: {avg_bins_sum}") -def main(): - torch.multiprocessing.set_sharing_strategy('file_system') - - args = get_args() - nemo_config = OmegaConf.load(args.hparams_file) - nemo_config.model.mm_cfg.vision_encoder.from_pretrained = args.hf_vision_encoder - nemo_config.model.data.data_path = args.data_path - nemo_config.model.data.image_folder = args.image_folder - nemo_config.model.data.conv_template = args.conv_template - nemo_config.model.data.image_aspect_ratio = args.image_aspect_ratio - tokenizer = get_nmt_tokenizer( - library="sentencepiece", - tokenizer_model=args.tokenizer_path, - ) - image_processor = create_image_processor(nemo_config.model.mm_cfg) - output_dir = args.output_dir - os.makedirs(output_dir, exist_ok=True) - logging.info(f"Output directory: {output_dir}") - - prefix_path = f"{output_dir}/packed_seq_dataset" - os.makedirs(prefix_path, exist_ok=True) - - data_files = nemo_config.model.data.data_file_names - for data_file in data_files: - logging.info(f"Processing data file: {data_file}") - - train_ds = make_supervised_data_module( - tokenizer=tokenizer, image_processor=image_processor, model_cfg=nemo_config.model, data_file=data_file - )["train_dataset"] - train_dl = DataLoader(train_ds, num_workers=32, collate_fn=None, shuffle=False) - - process_data_file(train_dl, prefix_path, data_file) - - pack_sequences_into_bins(args, output_dir, prefix_path) - if __name__ == "__main__": main() diff --git 
index c808fc5be0d1..96aa556cff47
--- a/nemo/collections/multimodal/data/neva/neva_dataset.py
+++ b/nemo/collections/multimodal/data/neva/neva_dataset.py
@@ -1229,7 +1229,7 @@ class DataCollatorForSupervisedDataset(object):
     tokenizer: transformers.PreTrainedTokenizer

     def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
-        packed_sequence = "cu_seqlens" in instances[0]
+        packed_sequence = "cu_seqlens" in instances[0]
         max_len = max(instance['tokens'].shape[0] for instance in instances)
         max_len = (max_len - 1) // 64 * 64 + 64
         for instance in instances:
@@ -1310,7 +1310,7 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
         return batch


-def make_supervised_data_module(tokenizer, image_processor, model_cfg, data_file) -> Dict:
+def make_supervised_data_module(tokenizer, image_processor, model_cfg, data_file=None) -> Dict:
     """Make dataset and collator for supervised fine-tuning."""
     data_cfg = model_cfg.data
     mm_cfg = model_cfg.mm_cfg
@@ -1318,11 +1318,7 @@ def make_supervised_data_module(tokenizer, image_processor, model_cfg, data_file
     if getattr(model_cfg, 'no_seqlen_plus_one_input_tokens', False):
         add_extra_token = 0
     crop_size = mm_cfg.vision_encoder.get("crop_size", (224, 224))
-    if not data_cfg.get("data_path"):
-        data_path = data_file
-    else:
-        data_path = data_cfg.data_path
-    # use blend
+    data_path = data_file if data_file is not None else data_cfg.data_path
     train_dataset = NevaDataset(
         tokenizer=tokenizer,
         data_path=data_path,
diff --git a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py
index 442f0cc400e9..7ba8e9d3b2fe 100644
--- a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py
+++ b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py
@@ -1210,8 +1210,7 @@ def setup(self, stage=None):
         else:
             # TODO: consider adding a ModelPT guard to check if model is being restored.
             # allowing restored models to optionally setup datasets
-            #self.build_train_valid_test_datasets()
-            self.build_train_valid_test_datasets_blend()
+            self.build_train_valid_test_datasets()
             self.setup_training_data(self.cfg.data)
             self.setup_validation_data(self.cfg.data)
             self.setup_test_data(self.cfg.data)
@@ -1247,15 +1246,15 @@ def build_train_valid_test_datasets_blend(self):
         if data_cfg.get('concat_sampling_probabilities') is None or not isinstance(
             data_cfg.concat_sampling_probabilities, ListConfig
         ):
-            raise ValueError("concat_sampling_probabilities must be a ListConfig with the same number of entries as data_file_names.")
+            raise ValueError("concat_sampling_probabilities must be a ListConfig with the same number of entries as data_path.")

-        if len(data_cfg.concat_sampling_probabilities) != len(data_cfg.data_file_names):
+        if len(data_cfg.concat_sampling_probabilities) != len(data_cfg.data_path):
             raise ValueError(
                 f"concat_sampling_probabilities must be of the same size as data_file_names. "
" - f"Provided size {len(data_cfg.concat_sampling_probabilities)}, number of datasets {len(data_cfg.data_file_names)}" + f"Provided size {len(data_cfg.concat_sampling_probabilities)}, number of datasets {len(data_cfg.data_path)}" ) - for data_file in data_cfg.data_file_names: + for data_file in data_cfg.data_path: if is_packed_sequence: train_dataset = NevaPackedSeqDatatset( data_file, self.cfg.mm_cfg.vision_encoder.get("crop_size") @@ -1277,22 +1276,24 @@ def build_train_valid_test_datasets_blend(self): train_datasets.append(train_dataset) valid_datasets.append(valid_dataset) - + # Create BlendableDataset for training if self.trainer.max_steps is None or self.trainer.max_steps <= 0: raise ValueError(f'Trainer max_steps must be set to a positive integer. Found {self.trainer.max_steps}') num_train_samples = self.trainer.max_steps * data_cfg.global_batch_size _, _, num_train_samples_per_dataset = get_datasets_weights_and_num_samples( - data_prefix=[weight for pair in zip(data_cfg.concat_sampling_probabilities, data_cfg.data_file_names) for weight in pair], + data_prefix=[weight for pair in zip(data_cfg.concat_sampling_probabilities, data_cfg.data_path) for weight in pair], num_samples=[num_train_samples] ) num_train_samples_after_blend = sum([x[0] for x in num_train_samples_per_dataset]) logging.info(f"Number of train datasets: {len(train_datasets)}") logging.info(f"Lengths of train datasets: {[len(ds) for ds in train_datasets]}") - logging.info(f"concat_sampling_probabilities: {data_cfg.concat_sampling_probabilities}") - logging.info(f"num_train_samples_after_blend: {num_train_samples_after_blend}") + logging.info(f"Number of train datasets after blending: {num_train_samples_after_blend}") + + if is_packed_sequence: + num_train_samples_after_blend = sum([len(ds) for ds in train_datasets]) self._train_ds = BlendableDataset( datasets=train_datasets, @@ -1306,20 +1307,16 @@ def build_train_valid_test_datasets_blend(self): size=num_train_samples_after_blend ) - logging.info(f'Length of train dataset: {len(self._train_ds)}') logging.info(f'Length of validation dataset: {len(self._validation_ds)}') - - ######### Use ConcatDataset instead of BlendableDataset########## - # self._train_ds = ConcatDataset(train_datasets) - # self._validation_ds = ConcatDataset(valid_datasets) - ################################################################## - return self._train_ds, self._validation_ds def build_train_valid_test_datasets(self): logging.info('Building Neva datasets.') + if isinstance(self.cfg.data.data_path, ListConfig) and self.cfg.data.get('concat_sampling_probabilities'): + return self.build_train_valid_test_datasets_blend() + if self.cfg.data.get("packed_sequence", False): assert self.cfg.micro_batch_size == 1, "Micro batch size must be 1 if using packed sequence" self._train_ds = NevaPackedSeqDatatset( diff --git a/nemo/collections/nlp/data/language_modeling/megatron/blendable_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/blendable_dataset.py index ae2b5fff6be1..c6c23aa16bdd 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/blendable_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/blendable_dataset.py @@ -25,7 +25,6 @@ class BlendableDataset(torch.utils.data.Dataset): def __init__(self, datasets, weights, size): - self.datasets = datasets num_datasets = len(datasets) assert num_datasets == len(weights) @@ -43,6 +42,7 @@ def __init__(self, datasets, weights, size): assert num_datasets < 255 self.dataset_index = np.zeros(self.size, 
         self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)
+
         app_state = AppState()
         try:
             if app_state.local_rank == 0:
@@ -74,7 +74,13 @@ def __len__(self):
     def __getitem__(self, idx):
         dataset_idx = self.dataset_index[idx]
         sample_idx = self.dataset_sample_index[idx]
-        return self.datasets[dataset_idx][sample_idx]
+        # Ensure the sample index doesn't exceed the dataset size
+        # original build_index function does not handle the extreme case properly
+        sample_idx = sample_idx % len(self.datasets[dataset_idx])
+        data = self.datasets[dataset_idx][sample_idx]
+
+        return data
+
     def create_data_mmap(self):
         for dataset in self.datasets:
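
Note (reviewer sketch, not part of the patch): with this change, data_path may be a list of files, each entry is paired with a concat_sampling_probabilities weight, the blended training set is sized to max_steps * global_batch_size, and the patched BlendableDataset.__getitem__ wraps the per-dataset sample index so an oversampled dataset cycles instead of indexing out of range. The minimal standalone Python sketch below illustrates that sampling behaviour; ToyDataset and build_blend_indices are illustrative stand-ins, not NeMo APIs.

import numpy as np


# Toy stand-ins for the per-file NevaDataset instances; only __len__/__getitem__ matter here.
class ToyDataset:
    def __init__(self, name, length):
        self.name, self.length = name, length

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        return f"{self.name}[{idx}]"


def build_blend_indices(weights, size, seed=0):
    # For each of `size` global samples, pick a dataset in proportion to its weight,
    # then assign the next running sample index within that dataset (the real
    # BlendableDataset builds equivalent index arrays with a compiled helper).
    rng = np.random.default_rng(seed)
    probs = np.asarray(weights, dtype=np.float64)
    probs = probs / probs.sum()
    dataset_index = rng.choice(len(weights), size=size, p=probs)
    dataset_sample_index = np.zeros(size, dtype=np.int64)
    counters = [0] * len(weights)
    for i, d in enumerate(dataset_index):
        dataset_sample_index[i] = counters[d]
        counters[d] += 1
    return dataset_index, dataset_sample_index


datasets = [ToyDataset("a", 3), ToyDataset("b", 5)]
dataset_index, dataset_sample_index = build_blend_indices([0.5, 0.5], size=10)
for d, s in zip(dataset_index, dataset_sample_index):
    # Wrap the per-dataset index the way the patched __getitem__ does, so a dataset
    # shorter than its share of the blend is revisited instead of raising IndexError.
    print(datasets[d][s % len(datasets[d])])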