diff --git a/nemo/collections/llm/gpt/data/core.py b/nemo/collections/llm/gpt/data/core.py
index af9686167dbd..b327e0494a89 100644
--- a/nemo/collections/llm/gpt/data/core.py
+++ b/nemo/collections/llm/gpt/data/core.py
@@ -47,34 +47,42 @@ def create_sft_dataset(
     memmap_workers: int = 2,
     hf_dataset: bool = False,
     global_sample_mapping: bool = False,
+    pack_metadata_file_path: Path = None,
+    pad_cu_seqlens: bool = False,
     **kwargs,
 ) -> "GPTSFTDataset":
+
+    gpt_sft_dataset_kwargs = {
+        'file_path': str(path),
+        'tokenizer': tokenizer,
+        'max_seq_length': seq_length,
+        'memmap_workers': memmap_workers,
+        'hf_dataset': hf_dataset,
+        'global_sample_mapping': global_sample_mapping,
+        'add_bos': add_bos,
+        'add_eos': add_eos,
+        'add_sep': add_sep,
+        'seed': seed,
+        'label_key': label_key,
+        'answer_only_loss': answer_only_loss,
+        'truncation_field': truncation_field,
+        'pad_to_max_length': pad_to_max_length,
+        'index_mapping_dir': index_mapping_dir,
+        'prompt_template': prompt_template,
+        'truncation_method': truncation_method,
+    }
+
     if path.suffix == '.npy':
         from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset import GPTSFTPackedDataset
-
-        dataset_cls = GPTSFTPackedDataset
+        return GPTSFTPackedDataset(
+            pack_metadata_file_path=pack_metadata_file_path,
+            pad_cu_seqlens=pad_cu_seqlens,
+            **gpt_sft_dataset_kwargs,
+            **kwargs,
+        )
     else:
         from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset import GPTSFTDataset
-
-        dataset_cls = GPTSFTDataset
-
-    return dataset_cls(
-        file_path=str(path),
-        tokenizer=tokenizer,
-        max_seq_length=seq_length,
-        memmap_workers=memmap_workers,
-        hf_dataset=hf_dataset,
-        global_sample_mapping=global_sample_mapping,
-        add_bos=add_bos,
-        add_eos=add_eos,
-        add_sep=add_sep,
-        seed=seed,
-        label_key=label_key,
-        answer_only_loss=answer_only_loss,
-        truncation_field=truncation_field,
-        pad_to_max_length=pad_to_max_length,
-        index_mapping_dir=index_mapping_dir,
-        prompt_template=prompt_template,
-        truncation_method=truncation_method,
-        **kwargs,
-    )
+        return GPTSFTDataset(
+            **gpt_sft_dataset_kwargs,
+            **kwargs,
+        )
diff --git a/nemo/collections/llm/gpt/data/fine_tuning.py b/nemo/collections/llm/gpt/data/fine_tuning.py
index 0d866bb600fe..4f67a94539eb 100644
--- a/nemo/collections/llm/gpt/data/fine_tuning.py
+++ b/nemo/collections/llm/gpt/data/fine_tuning.py
@@ -74,6 +74,7 @@ def __init__(
         persistent_workers: bool = False,
         packed_sequence_specs: Optional["PackedSequenceSpecs"] = None,
         dataset_kwargs: Optional[Dict[str, Any]] = None,
+        pad_cu_seqlens: Optional[bool] = False,
     ):
         super().__init__()
         self.seq_length = seq_length
@@ -93,6 +94,7 @@ def __init__(
         self.packed_sequence_size = -1 if not packed_sequence_specs else packed_sequence_specs.packed_sequence_size
         self.validate_batch_size_for_packed_sequence()
         self.dataset_kwargs = dataset_kwargs or {}
+        self._pad_cu_seqlens = pad_cu_seqlens

     def validate_batch_size_for_packed_sequence(self):
         """
@@ -128,6 +130,7 @@ def prepare_data(self) -> None:
                 tokenizer=self.tokenizer,
                 max_seq_length=self.seq_length,
                 seed=self.seed,
+                output_metadata_path=self.train_pack_metadata,
             )

         if not self.validation_path_packed.is_file():
@@ -138,6 +141,7 @@ def prepare_data(self) -> None:
                 tokenizer=self.tokenizer,
                 max_seq_length=self.seq_length,
                 seed=self.seed,
+                output_metadata_path=self.val_pack_metadata,
             )

     def setup(self, stage: str):
@@ -196,6 +200,7 @@ def train_dataloader(self) -> DataLoader:
         return self._create_dataloader(
             self._create_dataset(
                 self.train_path if self.packed_sequence_size <= 0 else self.train_path_packed,
+                pack_metadata_path=None if self.packed_sequence_size <= 0 else self.train_pack_metadata,
                 max_num_samples=self.max_train_samples,
                 **self.dataset_kwargs,
             ),
@@ -207,6 +212,7 @@ def val_dataloader(self) -> DataLoader:
         return self._create_dataloader(
             self._create_dataset(
                 self.validation_path if self.packed_sequence_size <= 0 else self.validation_path_packed,
+                pack_metadata_path=None if self.packed_sequence_size <= 0 else self.val_pack_metadata,
                 is_test=True,
                 **self.dataset_kwargs,
             ),
@@ -226,15 +232,18 @@ def test_dataloader(self) -> DataLoader:
         )

     @lru_cache
-    def _create_dataset(self, path, is_test=False, **kwargs):
+    def _create_dataset(self, path, pack_metadata_path=None, is_test=False, **kwargs):
         # pylint: disable=C0115,C0116
+        is_not_packing = is_test or self.packed_sequence_size <= 0
         return create_sft_dataset(
             path,
             tokenizer=self.tokenizer,
-            seq_length=(self.seq_length if is_test or self.packed_sequence_size <= 0 else self.packed_sequence_size),
+            seq_length=(self.seq_length if is_not_packing else self.packed_sequence_size),
             memmap_workers=self.memmap_workers,
             seed=self.seed,
             is_test=is_test,
+            pack_metadata_file_path=None if is_not_packing else pack_metadata_path,
+            pad_cu_seqlens=False if is_not_packing else self.pad_cu_seqlens,
             **kwargs,
         )

@@ -255,6 +264,32 @@ def train_path(self) -> Path:
         """Path to training dataset file"""
         return self.dataset_root / "training.jsonl"

+    @property
+    def train_pack_metadata(self) -> Path:
+        """Path to the packing metadata file for the packed training dataset."""
+        if self.packed_sequence_size > 0:
+            if self.packed_sequence_specs.packed_train_metadata_path is not None:
+                return self.packed_sequence_specs.packed_train_metadata_path
+            tokenizer_model_name = self._extract_tokenizer_model_name()
+            folder_name = self.dataset_root / "packed" / tokenizer_model_name
+            folder_name.mkdir(parents=True, exist_ok=True)
+            return folder_name / f"train_{self.packed_sequence_size}_metadata.jsonl"
+        else:
+            raise ValueError("`train_pack_metadata` is invalid since packed sequence size is not specified.")
+
+    @property
+    def val_pack_metadata(self) -> Path:
+        """Path to the packing metadata file for the packed validation dataset."""
+        if self.packed_sequence_size > 0:
+            if self.packed_sequence_specs.packed_val_metadata_path is not None:
+                return self.packed_sequence_specs.packed_val_metadata_path
+            tokenizer_model_name = self._extract_tokenizer_model_name()
+            folder_name = self.dataset_root / "packed" / tokenizer_model_name
+            folder_name.mkdir(parents=True, exist_ok=True)
+            return folder_name / f"val_{self.packed_sequence_size}_metadata.jsonl"
+        else:
+            raise ValueError("`val_pack_metadata` is invalid since packed sequence size is not specified.")
+
     @property
     def train_path_packed(self) -> Path:
         """Path to training dataset file for packed sequence. The file path contains a reference to the
@@ -293,6 +328,16 @@ def test_path(self) -> Path:
         """Path to test dataset file"""
         return self.dataset_root / "test.jsonl"

+    @property
+    def pad_cu_seqlens(self) -> bool:
+        """Whether to pad cu_seqlens to a constant shape."""
+        if self.packed_sequence_size > 0:
+            if self.packed_sequence_specs.pad_cu_seqlens is not None:
+                return self.packed_sequence_specs.pad_cu_seqlens
+            else:
+                return self._pad_cu_seqlens
+        return False
+
     def _extract_tokenizer_model_name(self) -> str:
         """Automatically get the model name from model path."""
         if self.packed_sequence_specs.tokenizer_model_name is not None:
diff --git a/nemo/collections/llm/gpt/data/packed_sequence.py b/nemo/collections/llm/gpt/data/packed_sequence.py
index 345489ea0b63..596e5724221d 100644
--- a/nemo/collections/llm/gpt/data/packed_sequence.py
+++ b/nemo/collections/llm/gpt/data/packed_sequence.py
@@ -15,6 +15,7 @@
 from pathlib import Path
 from typing import Optional

+import json
 import numpy as np

 from nemo.collections.common.tokenizers import TokenizerSpec
@@ -50,6 +51,7 @@ def tokenize_dataset(path: Path, tokenizer: TokenizerSpec, max_seq_length: int,
 def prepare_packed_sequence_data(
     input_path: Path,
     output_path: Path,
+    output_metadata_path: Path,
     packed_sequence_size: int,
     tokenizer: TokenizerSpec,
     max_seq_length: int,
@@ -77,11 +79,15 @@ def prepare_packed_sequence_data(
     dataset = tokenize_dataset(input_path, tokenizer, max_seq_length, seed)
     sequences, histogram = create_hist(dataset, max_seq_length)

-    assignments = create_packing_strategy(histogram, packed_sequence_size, packing_algorithm)
+    assignments, packing_metadata = create_packing_strategy(histogram, packed_sequence_size, packing_algorithm)
     output_data = fill_packing_strategy(assignments, sequences, packed_sequence_size)

     # save output data
     np.save(output_path, output_data)
+    # save packing metadata
+    if output_metadata_path is not None:
+        with open(output_metadata_path, "w") as f:
+            json.dump(packing_metadata, f)

     logging.info(f"Packed sequence is prepared and saved to {output_path}")

@@ -111,6 +117,21 @@ class PackedSequenceSpecs:
     If specified, use this file for the packed validation dataset instead of the default path.
     """

+    packed_train_metadata_path: str = None
+    """
+    If specified, use this file for the train packing metadata instead of the default path.
+    """
+
+    packed_val_metadata_path: str = None
+    """
+    If specified, use this file for the val packing metadata instead of the default path.
+    """
+
+    pad_cu_seqlens: Optional[bool] = None
+    """
+    If specified, pad cu_seqlens to a constant size, which is required for use with CUDA graphs.
+    """
+
     def __post_init__(self):
         if self.packed_train_data_path is not None:
             self.packed_train_data_path = Path(self.packed_train_data_path)
diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
index 898ddb7d716b..3af7301a5972 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import json
 import math
 import re
 from typing import List, Mapping, Optional
@@ -524,7 +525,13 @@ class GPTSFTPackedDataset(GPTSFTDataset):
-    def __init__(self, file_path: str, tokenizer: TokenizerSpec, return_cu_seqlen: bool = True, **kwargs):
+    def __init__(self,
+                 file_path: str,
+                 tokenizer: TokenizerSpec,
+                 return_cu_seqlen: bool = True,
+                 pad_cu_seqlens: bool = False,
+                 pack_metadata_file_path: Optional[str] = None,
+                 **kwargs):
         """
         file_path: See `file_path` in the parent class.
         tokenizer: See `tokenizer` in the parent class.
@@ -537,6 +544,16 @@ def __init__(self, file_path: str, tokenizer: TokenizerSpec, return_cu_seqlen: b
         assert self.virtual_tokens == 0, "P-Tuning with packed sequence is not supported."
         self.return_cu_seqlen = return_cu_seqlen

+        self.pad_cu_seqlens = pad_cu_seqlens
+        if self.pad_cu_seqlens:
+            assert pack_metadata_file_path is not None, \
+                "A packing metadata JSON file is required when pad_cu_seqlens is enabled."
+
+        self.pack_metadata = None
+        if pack_metadata_file_path is not None:
+            with open(pack_metadata_file_path) as f:
+                self.pack_metadata = json.load(f)
+
     def __getitem__(self, idx):
         if self.samples_mapping is not None:
             # assert idx < len(self.samples_mapping)
@@ -633,6 +650,11 @@ def collate_fn(self, batch):
                 # set last seq to the max seq len because rope and attn kernels expect no padding
                 cu_seqlens[-1][-1] = max_length

+            if self.pad_cu_seqlens:
+                # pad cu_seqlens with zero-length sequences so its length is constant across batches
+                pad_num = self.pack_metadata['max_samples_per_bin'] - len(cu_seqlens[-1])
+                cu_seqlens[-1].extend([max_length] * pad_num)
+
         assert len(input_ids[0]) == len(
             position_ids[0]
         ), "Dataset problem: input_ids and position_ids lengths don't match"
@@ -656,8 +678,15 @@ def collate_fn(self, batch):
             # Pre-generate `cu_seqlens_argmin` and `max_seqlen` as CPU tensor to avoid device-to-host copies.
             cu_seqlens = torch.IntTensor(cu_seqlens)
             cu_seqlens_argmin = torch.argmin(cu_seqlens, dim=1, keepdim=True)
-            seqlens = cu_seqlens[:, 1:] - cu_seqlens[:, :-1]
-            max_seqlen, _ = seqlens.max(dim=1, keepdim=True)
+
+            if self.pad_cu_seqlens:
+                # Use the global max seqlen: 'pad_cu_seqlens' is mainly used to support
+                # cudagraphs, and 'max_seqlen' is a CPU tensor, so it must be the same
+                # across all batches.
+                max_seqlen = torch.IntTensor([self.pack_metadata['dataset_max_seqlen']] * len(cu_seqlens))
+            else:
+                seqlens = cu_seqlens[:, 1:] - cu_seqlens[:, :-1]
+                max_seqlen, _ = seqlens.max(dim=1, keepdim=True)

         processed_batch.update(
             {
diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py
index ec97bcb90853..729433fa43f4 100644
--- a/nemo/lightning/pytorch/strategies/megatron_strategy.py
+++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -99,7 +99,7 @@ class ParallelismConfig:
     pipeline_dtype: torch.dtype
     encoder_tensor_model_parallel_size: int = 0
     encoder_pipeline_model_parallel_size: int = 0
-    use_te_rng_tracker: bool
+    use_te_rng_tracker: bool = False


 class MegatronStrategy(DDPStrategy, io.IOMixin):
diff --git a/nemo/utils/sequence_packing_utils.py b/nemo/utils/sequence_packing_utils.py
index cee2be248f73..bbb90d0d5c3c 100644
--- a/nemo/utils/sequence_packing_utils.py
+++ b/nemo/utils/sequence_packing_utils.py
@@ -153,6 +153,7 @@ def create_packing_strategy(
     Returns:
         assignments: A list of lists, where each inner list represents a bin and contains the indices
                      of the sequence lengths assigned to that bin.
+        packing_metadata: A dict of packing statistics ('dataset_max_seqlen', 'max_samples_per_bin').
     """
     logging.info(f"Packing sequences to length {pack_size}...")
@@ -166,13 +167,17 @@ def create_packing_strategy(
     packed_seq_lens = [sum(x) for x in assignments]
     packing_factor = len(all_seq_lens) / len(packed_seq_lens)

+    max_seqlen = max(all_seq_lens)
+    max_samples_per_bin = max(len(b) for b in assignments)
+    packing_metadata = {'dataset_max_seqlen': max_seqlen, 'max_samples_per_bin': max_samples_per_bin}
+
     logging.debug("Packed sequence lengths:")
     logging.debug(packed_seq_lens)
     logging.info(f"Packing is {sum(packed_seq_lens)/len(packed_seq_lens)/pack_size*100:.2f}% efficient")
     logging.info(
         f">>>>> For pack size {pack_size}, average number of sequences per pack is n = {packing_factor:.3f} <<<<<"
     )
-    return assignments
+    return assignments, packing_metadata


 def fill_packing_strategy(
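
Usage note (not part of the diff): a minimal, hypothetical sketch of how the options introduced here could be wired together. The PackedSequenceSpecs fields and the pad_cu_seqlens flag come from this change; the data module class name (assumed to be FineTuningDataModule) and its other constructor arguments (dataset_root, seq_length) are assumptions about the surrounding API, not shown in this diff.

from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule
from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs

# Pack samples into 4096-token bins and pad cu_seqlens to a constant shape
# (constant shapes are what make the packed batches CUDA-graph friendly).
specs = PackedSequenceSpecs(
    packed_sequence_size=4096,
    pad_cu_seqlens=True,
    # packed_train_metadata_path / packed_val_metadata_path may be set explicitly;
    # otherwise prepare_data() writes train_4096_metadata.jsonl / val_4096_metadata.jsonl
    # under <dataset_root>/packed/<tokenizer_model_name>/.
)

datamodule = FineTuningDataModule(
    dataset_root="/data/my_sft_dataset",  # assumed argument name
    seq_length=4096,                      # assumed argument name
    packed_sequence_specs=specs,
)
# prepare_data() packs the jsonl splits, saves the .npy files, and dumps the packing
# metadata ('dataset_max_seqlen', 'max_samples_per_bin') that GPTSFTPackedDataset uses
# to pad each batch's cu_seqlens and to hold max_seqlen constant across batches.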