From 48349e462bf695df017e0624b4d891614542ef2f Mon Sep 17 00:00:00 2001
From: tomlifu
Date: Wed, 27 Nov 2024 16:03:04 -0800
Subject: [PATCH] Context Parallel SFT Support for dataset in THD format (#10688)

* Add context parallel support for packed dataset in THD format

* add changes with debug print

* remove debug print

Signed-off-by: Lifu Zhang

* Apply isort and black reformatting

Signed-off-by: tomlifu

* fix cu_seqlens and cu_seqlens_padded

Signed-off-by: Lifu Zhang

* cu_seqlens and cu_seqlens_padded fix

Signed-off-by: Lifu Zhang

* more fix on cu_seqlens and cu_seqlens_padded

Signed-off-by: Lifu Zhang

* Apply isort and black reformatting

Signed-off-by: tomlifu

* addressing Xiaowei's review

Signed-off-by: Lifu Zhang

* addressing more review comments

Signed-off-by: Lifu Zhang

* fix for the case where cp=1

Signed-off-by: Lifu Zhang

* Apply isort and black reformatting

Signed-off-by: tomlifu

* more fix to address Xiaowei's comment

Signed-off-by: Lifu Zhang

* fix the loss_mask for THD formated data

Signed-off-by: Lifu Zhang

* Apply isort and black reformatting

Signed-off-by: tomlifu

* fixed eos_idx[0][0] out of bounds issue

Signed-off-by: Lifu Zhang

* fixed CP=1 case

Signed-off-by: Lifu Zhang

* fixed thd_get_partitioned_indices assertion issue when pp=1

Signed-off-by: Lifu Zhang

* Apply isort and black reformatting

Signed-off-by: tomlifu

* fixed data packing issue

Signed-off-by: root

* fixed an issue where cp>1 has different loss curves

Signed-off-by: Lifu Zhang

* remove redudant check for cu_seqlens

Signed-off-by: Lifu Zhang

* fixed NeMo CI failure issue due to old TE version in CI

Signed-off-by: Lifu Zhang

* Apply isort and black reformatting

Signed-off-by: tomlifu

---------

Signed-off-by: Lifu Zhang
Signed-off-by: tomlifu
Signed-off-by: root
Signed-off-by: tomlifu
Co-authored-by: tomlifu
Co-authored-by: root
Co-authored-by: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com>
---
 .../megatron/gpt_sft_dataset.py               | 48 +++++++++--
 .../language_modeling/megatron_gpt_model.py   | 81 ++++++++++++++-----
 nemo/utils/sequence_packing_utils.py          |  4 +-
 .../prepare_packed_ft_dataset.py              | 45 ++++++++++-
 4 files changed, 150 insertions(+), 28 deletions(-)

diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
index 898ddb7d716b..9da2419520c2 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
@@ -573,15 +573,23 @@ def _build_samples_mapping(self):
             self.samples_mapping = None
 
     def _build_loss_mask(self, processed_example):
+        seq_boundaries = processed_example['seq_boundaries']
         if self.answer_only_loss:
-            seq_boundaries = processed_example['seq_boundaries']
             return np.concatenate(
                 [
                     processed_example['loss_mask'][seq_boundaries[i] + 1 : seq_boundaries[i + 1]]
                     for i in range(len(seq_boundaries) - 1)
                 ]
             )
-        return [1.0] * (len(processed_example['input_ids']) - len(processed_example['seq_boundaries']) + 1)
+        return np.concatenate(
+            [
+                [
+                    0 if x == self.tokenizer.eos_id else 1.0
+                    for x in processed_example['input_ids'][seq_boundaries[i] : seq_boundaries[i + 1] - 1]
+                ]
+                for i in range(len(seq_boundaries) - 1)
+            ]
+        )
 
     def _maybe_cast_to_list(self, x):
         return [item.tolist() if isinstance(item, np.ndarray) else item for item in x]
@@ -622,16 +630,40 @@ def collate_fn(self, batch):
         position_ids: List[List[int]] = []
         cu_seqlens: List[List[int]] = []
+        cu_seqlens_unpadded: List[List[int]] = []
         for item in batch:
             position_ids.append([])
             cu_seqlens.append([0])
+            cu_seqlens_unpadded.append([0])
             seqlens = np.array(item['seq_boundaries'][1:]) - np.array(item['seq_boundaries'][:-1])
             for l in seqlens:
                 # length minus 1 because input_ids is truncated by 1 for labels
                 position_ids[-1].extend(list(range(l - 1)))
                 cu_seqlens[-1].append(cu_seqlens[-1][-1] + l - 1)
 
-            # set last seq to the max seq len because rope and attn kernels expect no padding
-            cu_seqlens[-1][-1] = max_length
+            # the last seq needs to be the max seq len because rope and attn kernels expect no padding
+            assert cu_seqlens[-1][-1] <= max_length
+
+            # since data is prepadded when cp_size > 1, there may be some extra padding at the end
+            # of the packed sequence. In this case, we need to add the max seq len to the end.
+            if cu_seqlens[-1][-1] != max_length:
+                cu_seqlens[-1].append(max_length)
+
+            for i in range(len(item['seq_boundaries']) - 1):
+                current_seq = item['input_ids'][item['seq_boundaries'][i] : item['seq_boundaries'][i + 1] - 1]
+
+                # since the data could be prepadded with tokenizer's eos_id, we can find out the index of all the eos_id
+                eos_idx = np.where(np.array(current_seq) == self.tokenizer.eos_id)
+
+                # The second eos_id index marks the length of the original unpadded sequence if the sequence is
+                # prepadded for cp_size > 1. Otherwise, there is no extra padding.
+                seqlen_unpadded = eos_idx[0][0] + 1 if eos_idx[0].any() else len(current_seq)
+                cu_seqlens_unpadded[-1].append(cu_seqlens_unpadded[-1][-1] + seqlen_unpadded)
+
+            # if extra paddings are added in the packed sequence, they can't be counted as
+            # actual tokens for training
+            if len(cu_seqlens[-1]) > len(cu_seqlens_unpadded[-1]):
+                cu_seqlens_unpadded[-1].append(cu_seqlens_unpadded[-1][-1])
 
         assert len(input_ids[0]) == len(
             position_ids[0]
@@ -652,12 +684,16 @@
 
         if self.return_cu_seqlen:
             cu_seqlens = self._collate_item(cu_seqlens, max_length=max(len(l) for l in cu_seqlens) + 1, pad_id=-1)
-
+            cu_seqlens_unpadded = self._collate_item(
+                cu_seqlens_unpadded, max_length=max(len(l) for l in cu_seqlens_unpadded) + 1, pad_id=-1
+            )
             # Pre-generate `cu_seqlens_argmin` and `max_seqlen` as CPU tensor to avoid device-to-host copies.
             cu_seqlens = torch.IntTensor(cu_seqlens)
             cu_seqlens_argmin = torch.argmin(cu_seqlens, dim=1, keepdim=True)
             seqlens = cu_seqlens[:, 1:] - cu_seqlens[:, :-1]
             max_seqlen, _ = seqlens.max(dim=1, keepdim=True)
+            cu_seqlens_unpadded = torch.IntTensor(cu_seqlens_unpadded)
+            cu_seqlens_unpadded_argmin = torch.argmin(cu_seqlens_unpadded, dim=1, keepdim=True)
 
             processed_batch.update(
                 {
@@ -667,6 +703,8 @@
                     'cu_seqlens': torch.IntTensor(cu_seqlens),  # cu_seqlens_q must be in dtype torch.int32
                     'cu_seqlens_argmin': cu_seqlens_argmin,  # only required for perf
                     'max_seqlen': max_seqlen,  # only required for perf
+                    'cu_seqlens_unpadded': torch.IntTensor(cu_seqlens_unpadded),
+                    'cu_seqlens_unpadded_argmin': cu_seqlens_unpadded_argmin,
                 }
             )
         else:
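The collate_fn changes above track two cumulative-length arrays per packed sample: cu_seqlens over the padded sub-sequence lengths and cu_seqlens_unpadded over the real token counts. The following standalone sketch (not part of the patch; the toy sample, eos_id, and max_length values are assumptions made up for illustration) mirrors that logic for one packed sample whose two sub-sequences were pre-padded with eos_id:

    import numpy as np

    eos_id = 2          # assumed tokenizer.eos_id
    max_length = 15     # assumed collated model-input length (len(input_ids) - 1)

    # one packed sample with two sub-sequences, each pre-padded with eos_id to length 8
    seq_boundaries = [0, 8, 16]
    input_ids = [5, 6, 7, 2, 2, 2, 2, 2,    # 4 real tokens (incl. its eos), 4 pad tokens
                 9, 9, 9, 9, 9, 2, 2, 2]    # 6 real tokens (incl. its eos), 2 pad tokens

    cu_seqlens = [0]            # boundaries over padded lengths (minus 1 for the label shift)
    cu_seqlens_unpadded = [0]   # boundaries over real token counts
    for i in range(len(seq_boundaries) - 1):
        cu_seqlens.append(cu_seqlens[-1] + (seq_boundaries[i + 1] - seq_boundaries[i]) - 1)
        current_seq = input_ids[seq_boundaries[i] : seq_boundaries[i + 1] - 1]
        eos_idx = np.where(np.array(current_seq) == eos_id)
        seqlen_unpadded = eos_idx[0][0] + 1 if eos_idx[0].any() else len(current_seq)
        cu_seqlens_unpadded.append(cu_seqlens_unpadded[-1] + seqlen_unpadded)

    # as in the patch, leftover tail padding is closed off with max_length so the padded
    # boundaries cover the whole buffer while the unpadded boundaries stay flat
    if cu_seqlens[-1] != max_length:
        cu_seqlens.append(max_length)
        cu_seqlens_unpadded.append(cu_seqlens_unpadded[-1])

    print(cu_seqlens)           # [0, 7, 14, 15]
    print(cu_seqlens_unpadded)  # [0, 4, 10, 10]

The padding tokens therefore stay addressable through cu_seqlens but contribute nothing to cu_seqlens_unpadded, which is what keeps them out of attention and the loss later on.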
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index f71b1ad13c6d..a4b8242e0185 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -1231,22 +1231,23 @@ def get_batch_on_this_context_parallel_rank(self, batch):
         cp_size = parallel_state.get_context_parallel_world_size()
         if cp_size > 1:
             cp_rank = parallel_state.get_context_parallel_rank()
-            for key, val in batch.items():
-                if val is not None and key != "context_lengths":
-                    seq_dim = 1 if key != 'attention_mask' else 2
-                    val = val.view(
-                        *val.shape[0:seq_dim],
-                        2 * cp_size,
-                        val.shape[seq_dim] // (2 * cp_size),
-                        *val.shape[(seq_dim + 1) :],
-                    )
-                    index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda(
-                        non_blocking=True
-                    )
-                    val = val.index_select(seq_dim, index)
-                    val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
-                    batch[key] = val
-
+            # check if the batch is not in THD format
+            if 'cu_seqlens' not in batch:
+                for key, val in batch.items():
+                    if val is not None and key != "context_lengths":
+                        seq_dim = 1 if key != 'attention_mask' else 2
+                        val = val.view(
+                            *val.shape[0:seq_dim],
+                            2 * cp_size,
+                            val.shape[seq_dim] // (2 * cp_size),
+                            *val.shape[(seq_dim + 1) :],
+                        )
+                        index = torch.tensor(
+                            [cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True
+                        ).cuda(non_blocking=True)
+                        val = val.index_select(seq_dim, index)
+                        val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
+                        batch[key] = val
         batch['num_valid_tokens_in_ub'] = num_valid_tokens_in_ub
 
         return batch
@@ -1261,12 +1262,17 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
             required_keys = set()
             max_seqlen = batch['max_seqlen'].squeeze() if 'max_seqlen' in batch else None
             cu_seqlens_argmin = batch['cu_seqlens_argmin'] if 'cu_seqlens_argmin' in batch else None
+            cu_seqlens_unpadded_argmin = (
+                batch['cu_seqlens_unpadded_argmin'] if 'cu_seqlens_unpadded_argmin' in batch else None
+            )
             if parallel_state.get_pipeline_model_parallel_world_size() == 1:
                 required_keys.update(batch.keys())
             else:
                 required_keys.add('attention_mask')
                 if 'cu_seqlens' in batch:
                     required_keys.add('cu_seqlens')
+                if 'cu_seqlens_unpadded' in batch:
+                    required_keys.add('cu_seqlens_unpadded')
                 if parallel_state.is_pipeline_first_stage():
                     required_keys.update(('tokens', 'position_ids'))
                 if parallel_state.is_pipeline_last_stage():
@@ -1301,12 +1307,16 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
             if 'cu_seqlens' in batch:  # packed sequence from GPTSFTPackedDataset
                 # these args are passed eventually into TEDotProductAttention.forward()
                 cu_seqlens = batch['cu_seqlens'].squeeze()  # remove batch size dimension (mbs=1)
+                cu_seqlens_unpadded = batch['cu_seqlens_unpadded'].squeeze()
                 # remove -1 "paddings" added in collate_fn
                 if cu_seqlens_argmin is not None:
                     cu_seqlens = cu_seqlens[: cu_seqlens_argmin.item()]
                 else:
                     cu_seqlens = cu_seqlens[: torch.argmin(cu_seqlens)]
-
+                if cu_seqlens_unpadded_argmin is not None:
+                    cu_seqlens_unpadded = cu_seqlens_unpadded[: cu_seqlens_unpadded_argmin.item()]
+                else:
+                    cu_seqlens_unpadded = cu_seqlens_unpadded[: torch.argmin(cu_seqlens_unpadded)]
                 try:
                     from megatron.core.packed_seq_params import PackedSeqParams
                 except (ImportError, ModuleNotFoundError) as e:
@@ -1317,9 +1327,42 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
                     )
                     raise e
 
+                # get packed sequences for this context parallel rank
+                cp_size = parallel_state.get_context_parallel_world_size()
+                if cp_size > 1:
+                    try:
+                        import transformer_engine_torch as tex
+                    except ModuleNotFoundError as e:
+                        logging.error(
+                            "Please update Transformer Engine to >= 1.10 to use Context Parallel with THD format data"
+                        )
+                        raise e
+                    cp_rank = parallel_state.get_context_parallel_rank()
+                    for key in required_keys:
+                        val = batch[key]
+                        if key not in {
+                            "cu_seqlens",
+                            "cu_seqlens_unpadded",
+                            "cu_seqlens_argmin",
+                            "cu_seqlens_unpadded_argmin",
+                            "max_seqlen",
+                            "token_count",
+                        }:
+                            index = tex.thd_get_partitioned_indices(cu_seqlens, val.size(1), cp_size, cp_rank)
+                            val = val.index_select(1, index)
+                            batch[key] = val
+
+                forward_args = {
+                    'input_ids': batch['tokens'],
+                    'position_ids': batch['position_ids'],
+                    'attention_mask': None if self.get_attention_mask_from_fusion else batch['attention_mask'],
+                    'labels': batch['labels'] if 'labels' in batch else None,
+                }
+
                 forward_args['packed_seq_params'] = PackedSeqParams(
-                    cu_seqlens_q=cu_seqlens,
-                    cu_seqlens_kv=cu_seqlens,
+                    cu_seqlens_q=cu_seqlens_unpadded,
+                    cu_seqlens_kv=cu_seqlens_unpadded,
+                    cu_seqlens_q_padded=cu_seqlens,
+                    cu_seqlens_kv_padded=cu_seqlens,
                     max_seqlen_q=max_seqlen,
                     max_seqlen_kv=max_seqlen,
                     qkv_format='thd',
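For the THD path above, tex.thd_get_partitioned_indices selects which token positions of the packed sample each context-parallel rank keeps. The following pure-Python approximation is my own sketch of that selection rule, reusing the same rank / (2 * cp_size - 1 - rank) chunk pairing the existing BSHD branch applies; the actual index order comes from Transformer Engine, and all values below are assumed toy inputs:

    import torch

    def thd_partition_indices_sketch(cu_seqlens_padded, cp_size, cp_rank):
        # each padded sub-sequence is cut into 2 * cp_size equal chunks; rank r keeps
        # chunks r and (2 * cp_size - 1 - r), which balances early and late chunks
        # (cheap and expensive under a causal mask) across ranks
        indices = []
        for start, end in zip(cu_seqlens_padded[:-1], cu_seqlens_padded[1:]):
            chunk = (end - start) // (2 * cp_size)  # padded length must divide evenly
            for c in (cp_rank, 2 * cp_size - 1 - cp_rank):
                indices.extend(range(start + c * chunk, start + (c + 1) * chunk))
        return torch.tensor(indices)

    cp_size = 4
    cu_seqlens_padded = [0, 8, 16]            # two sub-sequences, padded length 8 each
    tokens = torch.arange(16).unsqueeze(0)    # shaped [1, t] like a THD batch entry
    for cp_rank in range(cp_size):
        idx = thd_partition_indices_sketch(cu_seqlens_padded, cp_size, cp_rank)
        print(cp_rank, tokens.index_select(1, idx).tolist())
    # 0 [[0, 7, 8, 15]]
    # 1 [[1, 6, 9, 14]]
    # 2 [[2, 5, 10, 13]]
    # 3 [[3, 4, 11, 12]]

The patch then passes the unpadded boundaries as cu_seqlens_q/cu_seqlens_kv so that only real tokens count for attention, while the padded boundaries go to cu_seqlens_q_padded/cu_seqlens_kv_padded to describe where each sub-sequence physically sits in the packed buffer.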
diff --git a/nemo/utils/sequence_packing_utils.py b/nemo/utils/sequence_packing_utils.py
index cee2be248f73..2ca03ce44b67 100644
--- a/nemo/utils/sequence_packing_utils.py
+++ b/nemo/utils/sequence_packing_utils.py
@@ -115,7 +115,7 @@ def create_hist(dataset: np.array, truncate_seq_len: int):
     logging.info("Creating histogram from tokenized dataset...")
 
     sequences = collections.defaultdict(list)
-    counts = [0] * truncate_seq_len
+    counts = [0] * (truncate_seq_len + 1)
 
     for item_dict in dataset:
         # Minus 1 here to account for the fact that transformer input and label have one less token than the full sequence
@@ -129,7 +129,7 @@ def create_hist(dataset: np.array, truncate_seq_len: int):
     logging.debug(counts)
 
     histogram = []
-    for seq_len in range(truncate_seq_len):
+    for seq_len in range(truncate_seq_len + 1):
         histogram.append(len(sequences[seq_len]))
 
     return sequences, histogram
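The sequence_packing_utils.py change is an off-by-one fix: a sequence whose "minus 1" length equals truncate_seq_len exactly still needs its own bucket. A tiny illustration with assumed lengths:

    truncate_seq_len = 4
    seq_lens = [2, 4, 4, 3]                  # already the "minus 1" lengths; 4 is allowed
    counts = [0] * (truncate_seq_len + 1)    # with only truncate_seq_len buckets, counts[4] would not exist
    for seq_len in seq_lens:
        counts[seq_len] += 1
    print(counts)                            # [0, 0, 1, 1, 2]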
diff --git a/scripts/nlp_language_modeling/prepare_packed_ft_dataset.py b/scripts/nlp_language_modeling/prepare_packed_ft_dataset.py
index 7ff2342e4087..19a3e6a78228 100644
--- a/scripts/nlp_language_modeling/prepare_packed_ft_dataset.py
+++ b/scripts/nlp_language_modeling/prepare_packed_ft_dataset.py
@@ -88,6 +88,14 @@ def tokenize_dataset(cfg: 'DictConfig'):
     # using the same template as SFT/PEFT script. This may be overkill but guarantees the preprocess settings
     # are identical to normal SFT training
     data_cfg = cfg.model.data.train_ds
+    pad_seq_length_to_mult = 16
+    cp_size = cfg.model.get("context_parallel_size", 1)
+
+    # if context parallel is used, each individual data length in one packed dataset sample
+    # needs to be a multiple of (cp_size * 2): https://github.com/NVIDIA/TransformerEngine/pull/641
+    if cp_size > 1:
+        pad_seq_length_to_mult = max(pad_seq_length_to_mult, cp_size * 2)
+
     if os.path.isdir(cfg.tokenizer_path):
         # pass in a Hugging Face folder which contains tokenizer.json
         tokenizer = get_nmt_tokenizer(library="huggingface", model_name=cfg.tokenizer_path, use_fast=True)
@@ -99,7 +107,7 @@ def tokenize_dataset(cfg: 'DictConfig'):
         tokenizer=tokenizer,
         max_seq_length=data_cfg.max_seq_length,
         min_seq_length=data_cfg.min_seq_length,
-        pad_seq_length_to_mult=16,  # adds padding in collate_fn so this value is irrelevant here
+        pad_seq_length_to_mult=pad_seq_length_to_mult,
         add_bos=data_cfg.get('add_bos', False),
         add_eos=data_cfg.get('add_eos', True),
         add_sep=data_cfg.get('add_sep', False),
@@ -121,7 +129,40 @@ def tokenize_dataset(cfg: 'DictConfig'):
         is_test=True,
     )
 
-    return np.array([dataset[i] for i in range(len(dataset))])
+    max_seq_length = dataset.max_seq_length
+    pad_id = dataset.tokenizer.eos_id
+    pad_seq_length_to_mult = dataset.pad_seq_length_to_mult
+    dataset = np.array([dataset[i] for i in range(len(dataset))])
+    if cp_size > 1:
+
+        def pre_pad_dataset(data, max_seq_length, max_length_to_pad, pad_id):
+            '''
+            pad each individual data point to the length of max_length_to_pad
+            '''
+            assert max_seq_length >= max_length_to_pad
+            for key, val in data.items():
+                if key in {'input_ids', 'context_ids'}:
+                    if len(val) <= max_length_to_pad:
+                        # because input_ids are truncated by 1 for inputs and labels,
+                        # we add 1 extra padding here to make sure padded inputs and labels
+                        # are a multiple of (cp_size * 2)
+                        val = val + [pad_id] * (max_length_to_pad - len(val) + 1)
+                        data[key] = val
+                    elif len(val) > max_seq_length:
+                        logging.info(
+                            f"""The current sequence length {len(val)} for packing is
+                            larger than the max_seq_length specified ({max_seq_length}).
+                            The current sequence length is truncated to the size of max_seq_length.
+                            Please consider increasing the sequence packing size."""
+                        )
+                        data[key] = val[:max_seq_length]
+            return
+
+        ceil_to_nearest = lambda n, m: (n + m - 1) // m * m
+        for data in dataset:
+            max_length_to_pad = min(max_seq_length, ceil_to_nearest(len(data['input_ids']), pad_seq_length_to_mult))
+            pre_pad_dataset(data, max_seq_length, max_length_to_pad, pad_id)
+    return dataset
 
 
 @dataclass
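The prepare_packed_ft_dataset.py changes pre-pad each example so that, after the usual one-token truncation into inputs and labels, its length is a multiple of cp_size * 2 (on top of the existing multiple of 16), capped at max_seq_length. A standalone sketch of that arithmetic, using assumed values rather than anything from the repo:

    cp_size = 2
    pad_seq_length_to_mult = max(16, cp_size * 2)    # 16 here; cp_size * 2 dominates only for cp_size > 8
    max_seq_length = 64
    eos_id = 2                                       # assumed tokenizer.eos_id, used as the pad token

    ceil_to_nearest = lambda n, m: (n + m - 1) // m * m

    input_ids = list(range(10, 31))                  # toy example with 21 tokens
    max_length_to_pad = min(max_seq_length, ceil_to_nearest(len(input_ids), pad_seq_length_to_mult))
    if len(input_ids) <= max_length_to_pad:
        # the extra "+ 1" compensates for input_ids later being split into inputs vs. labels,
        # so the truncated length (here 32) stays a multiple of cp_size * 2
        input_ids = input_ids + [eos_id] * (max_length_to_pad - len(input_ids) + 1)
    print(len(input_ids))                            # 33 = 32 usable positions + 1 for the label shift

These eos paddings are what the packed dataset's collate_fn later detects with np.where(... == eos_id) in order to recover the unpadded sequence lengths.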