Context Parallel SFT Support for dataset in THD format (NVIDIA#10688)
* Add context parallel support for packed dataset in THD format

* add changes with debug print

* remove debug print

Signed-off-by: Lifu Zhang <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tomlifu <[email protected]>

* fix cu_seqlens and cu_seqlens_padded

Signed-off-by: Lifu Zhang <[email protected]>

* cu_seqlens and cu_seqlens_padded fix

Signed-off-by: Lifu Zhang <[email protected]>

* more fixes on cu_seqlens and cu_seqlens_padded

Signed-off-by: Lifu Zhang <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tomlifu <[email protected]>

* addressing Xiaowei's review

Signed-off-by: Lifu Zhang <[email protected]>

* addressing more review comments

Signed-off-by: Lifu Zhang <[email protected]>

* fix for the case where cp=1

Signed-off-by: Lifu Zhang <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tomlifu <[email protected]>

* more fixes to address Xiaowei's comment

Signed-off-by: Lifu Zhang <[email protected]>

* fix the loss_mask for THD formatted data

Signed-off-by: Lifu Zhang <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tomlifu <[email protected]>

* fixed eos_idx[0][0] out of bounds issue

Signed-off-by: Lifu Zhang <[email protected]>

* fixed CP=1 case

Signed-off-by: Lifu Zhang <[email protected]>

* fixed thd_get_partitioned_indices assertion issue when pp=1

Signed-off-by: Lifu Zhang <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tomlifu <[email protected]>

* fixed data packing issue

Signed-off-by: root <[email protected]>

* fixed an issue where cp>1 has different loss curves

Signed-off-by: Lifu Zhang <[email protected]>

* remove redundant check for cu_seqlens

Signed-off-by: Lifu Zhang <[email protected]>

* fixed NeMo CI failure issue due to old TE version in CI

Signed-off-by: Lifu Zhang <[email protected]>

* Apply isort and black reformatting

Signed-off-by: tomlifu <[email protected]>

---------

Signed-off-by: Lifu Zhang <[email protected]>
Signed-off-by: tomlifu <[email protected]>
Signed-off-by: root <[email protected]>
Signed-off-by: tomlifu <[email protected]>
Co-authored-by: tomlifu <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: Xiaowei Ren <[email protected]>
4 people authored Nov 28, 2024
1 parent a6b08a6 commit 48349e4
Showing 4 changed files with 150 additions and 28 deletions.
@@ -573,15 +573,23 @@ def _build_samples_mapping(self):
self.samples_mapping = None

def _build_loss_mask(self, processed_example):
seq_boundaries = processed_example['seq_boundaries']
if self.answer_only_loss:
seq_boundaries = processed_example['seq_boundaries']
return np.concatenate(
[
processed_example['loss_mask'][seq_boundaries[i] + 1 : seq_boundaries[i + 1]]
for i in range(len(seq_boundaries) - 1)
]
)
return [1.0] * (len(processed_example['input_ids']) - len(processed_example['seq_boundaries']) + 1)
return np.concatenate(
[
[
0 if x == self.tokenizer.eos_id else 1.0
for x in processed_example['input_ids'][seq_boundaries[i] : seq_boundaries[i + 1] - 1]
]
for i in range(len(seq_boundaries) - 1)
]
)

def _maybe_cast_to_list(self, x):
return [item.tolist() if isinstance(item, np.ndarray) else item for item in x]
@@ -622,16 +630,40 @@ def collate_fn(self, batch):

position_ids: List[List[int]] = []
cu_seqlens: List[List[int]] = []
cu_seqlens_unpadded: List[List[int]] = []
for item in batch:
position_ids.append([])
cu_seqlens.append([0])
cu_seqlens_unpadded.append([0])
seqlens = np.array(item['seq_boundaries'][1:]) - np.array(item['seq_boundaries'][:-1])
for l in seqlens:
# length minus 1 because input_ids is truncated by 1 for labels
position_ids[-1].extend(list(range(l - 1)))
cu_seqlens[-1].append(cu_seqlens[-1][-1] + l - 1)
# set last seq to the max seq len because rope and attn kernels expect no padding
cu_seqlens[-1][-1] = max_length

# the last seq needs to be the max seq len because rope and attn kernels expect no padding
assert cu_seqlens[-1][-1] <= max_length

# since data is prepadded when cp_size > 1, there may be some extra padding at the end
# of the packed sequence. In this case, we need to add the max seq len to the end.
if cu_seqlens[-1][-1] != max_length:
cu_seqlens[-1].append(max_length)

for i in range(len(item['seq_boundaries']) - 1):
current_seq = item['input_ids'][item['seq_boundaries'][i] : item['seq_boundaries'][i + 1] - 1]

# since the data could be pre-padded with the tokenizer's eos_id, find the indices of all eos_id tokens
eos_idx = np.where(np.array(current_seq) == self.tokenizer.eos_id)

# The second eos_id index marks the length of the original unpadded sequence if the sequence is
# prepadded for cp_size > 1. Otherwise, there is no extra padding.
seqlen_unpadded = eos_idx[0][0] + 1 if eos_idx[0].any() else len(current_seq)
cu_seqlens_unpadded[-1].append(cu_seqlens_unpadded[-1][-1] + seqlen_unpadded)

# if extra paddings are added in the packed sequence, they can't be counted as
# actual tokens for training
if len(cu_seqlens[-1]) > len(cu_seqlens_unpadded[-1]):
cu_seqlens_unpadded[-1].append(cu_seqlens_unpadded[-1][-1])

assert len(input_ids[0]) == len(
position_ids[0]
@@ -652,12 +684,16 @@ def collate_fn(self, batch):

if self.return_cu_seqlen:
cu_seqlens = self._collate_item(cu_seqlens, max_length=max(len(l) for l in cu_seqlens) + 1, pad_id=-1)

cu_seqlens_unpadded = self._collate_item(
cu_seqlens_unpadded, max_length=max(len(l) for l in cu_seqlens_unpadded) + 1, pad_id=-1
)
# Pre-generate `cu_seqlens_argmin` and `max_seqlen` as CPU tensor to avoid device-to-host copies.
cu_seqlens = torch.IntTensor(cu_seqlens)
cu_seqlens_argmin = torch.argmin(cu_seqlens, dim=1, keepdim=True)
seqlens = cu_seqlens[:, 1:] - cu_seqlens[:, :-1]
max_seqlen, _ = seqlens.max(dim=1, keepdim=True)
cu_seqlens_unpadded = torch.IntTensor(cu_seqlens_unpadded)
cu_seqlens_unpadded_argmin = torch.argmin(cu_seqlens_unpadded, dim=1, keepdim=True)

processed_batch.update(
{
Expand All @@ -667,6 +703,8 @@ def collate_fn(self, batch):
'cu_seqlens': torch.IntTensor(cu_seqlens), # cu_seqlens_q must be in dtype torch.int32
'cu_seqlens_argmin': cu_seqlens_argmin, # only required for perf
'max_seqlen': max_seqlen, # only required for perf
'cu_seqlens_unpadded': torch.IntTensor(cu_seqlens_unpadded),
'cu_seqlens_unpadded_argmin': cu_seqlens_unpadded_argmin,
}
)
else:
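For reference, a minimal standalone sketch of the bookkeeping the updated collate_fn performs for one pre-padded packed sample. All token values, the eos_id, and the boundaries below are made up for illustration; only the padded/unpadded cu_seqlens logic mirrors the change above.

```python
import numpy as np

# Toy packed sample (values are illustrative, not taken from the dataset class)
eos_id = 2
max_length = 16
input_ids = [5, 6, 7, 8, 9, 2,      # sequence 1: five real tokens + eos
             3, 4, 2, 2, 2, 2,      # sequence 2: two real tokens, eos, then padding
             2, 2, 2, 2]            # extra padding added when cp_size > 1
seq_boundaries = [0, 6, 12]

cu_seqlens, cu_seqlens_unpadded = [0], [0]
for l in np.diff(seq_boundaries):
    # minus 1 because input_ids is truncated by 1 for the labels
    cu_seqlens.append(cu_seqlens[-1] + int(l) - 1)

# padded boundaries must cover the full packed length expected by the attention kernels
if cu_seqlens[-1] != max_length:
    cu_seqlens.append(max_length)

for i in range(len(seq_boundaries) - 1):
    seg = np.array(input_ids[seq_boundaries[i]: seq_boundaries[i + 1] - 1])
    eos_pos = np.where(seg == eos_id)[0]
    # the first eos inside the segment marks the end of the real tokens when pre-padded
    seqlen_unpadded = int(eos_pos[0]) + 1 if eos_pos.size else len(seg)
    cu_seqlens_unpadded.append(cu_seqlens_unpadded[-1] + seqlen_unpadded)

# trailing padding never counts as trainable tokens
if len(cu_seqlens) > len(cu_seqlens_unpadded):
    cu_seqlens_unpadded.append(cu_seqlens_unpadded[-1])

print(cu_seqlens)           # [0, 5, 10, 16]
print(cu_seqlens_unpadded)  # [0, 5, 8, 8]
```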
81 changes: 62 additions & 19 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -1231,22 +1231,23 @@ def get_batch_on_this_context_parallel_rank(self, batch):
cp_size = parallel_state.get_context_parallel_world_size()
if cp_size > 1:
cp_rank = parallel_state.get_context_parallel_rank()
for key, val in batch.items():
if val is not None and key != "context_lengths":
seq_dim = 1 if key != 'attention_mask' else 2
val = val.view(
*val.shape[0:seq_dim],
2 * cp_size,
val.shape[seq_dim] // (2 * cp_size),
*val.shape[(seq_dim + 1) :],
)
index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda(
non_blocking=True
)
val = val.index_select(seq_dim, index)
val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
batch[key] = val

# check if the batch is not in THD format
if 'cu_seqlens' not in batch:
for key, val in batch.items():
if val is not None and key != "context_lengths":
seq_dim = 1 if key != 'attention_mask' else 2
val = val.view(
*val.shape[0:seq_dim],
2 * cp_size,
val.shape[seq_dim] // (2 * cp_size),
*val.shape[(seq_dim + 1) :],
)
index = torch.tensor(
[cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True
).cuda(non_blocking=True)
val = val.index_select(seq_dim, index)
val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
batch[key] = val
batch['num_valid_tokens_in_ub'] = num_valid_tokens_in_ub

return batch
@@ -1261,12 +1262,17 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
required_keys = set()
max_seqlen = batch['max_seqlen'].squeeze() if 'max_seqlen' in batch else None
cu_seqlens_argmin = batch['cu_seqlens_argmin'] if 'cu_seqlens_argmin' in batch else None
cu_seqlens_unpadded_argmin = (
batch['cu_seqlens_unpadded_argmin'] if 'cu_seqlens_unpadded_argmin' in batch else None
)
if parallel_state.get_pipeline_model_parallel_world_size() == 1:
required_keys.update(batch.keys())
else:
required_keys.add('attention_mask')
if 'cu_seqlens' in batch:
required_keys.add('cu_seqlens')
if 'cu_seqlens_unpadded' in batch:
required_keys.add('cu_seqlens_unpadded')
if parallel_state.is_pipeline_first_stage():
required_keys.update(('tokens', 'position_ids'))
if parallel_state.is_pipeline_last_stage():
@@ -1301,12 +1307,16 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
if 'cu_seqlens' in batch: # packed sequence from GPTSFTPackedDataset
# these args are passed eventually into TEDotProductAttention.forward()
cu_seqlens = batch['cu_seqlens'].squeeze() # remove batch size dimension (mbs=1)
cu_seqlens_unpadded = batch['cu_seqlens_unpadded'].squeeze()
# remove -1 "paddings" added in collate_fn
if cu_seqlens_argmin is not None:
cu_seqlens = cu_seqlens[: cu_seqlens_argmin.item()]
else:
cu_seqlens = cu_seqlens[: torch.argmin(cu_seqlens)]

if cu_seqlens_unpadded_argmin is not None:
cu_seqlens_unpadded = cu_seqlens_unpadded[: cu_seqlens_unpadded_argmin.item()]
else:
cu_seqlens_unpadded = cu_seqlens_unpadded[: torch.argmin(cu_seqlens_unpadded)]
try:
from megatron.core.packed_seq_params import PackedSeqParams
except (ImportError, ModuleNotFoundError) as e:
@@ -1317,9 +1327,42 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
)
raise e

# get packed sequences for this context parallel rank
cp_size = parallel_state.get_context_parallel_world_size()
if cp_size > 1:
try:
import transformer_engine_torch as tex
except ModuleNotFoundError as e:
logging.error(
"Please update Transformer Engine to >= 1.10 to use Context Parallel with THD format data"
)
raise e
cp_rank = parallel_state.get_context_parallel_rank()
for key in required_keys:
val = batch[key]
if key not in {
"cu_seqlens",
"cu_seqlens_unpadded",
"cu_seqlens_argmin",
"cu_seqlens_unpadded_argmin",
"max_seqlen",
"token_count",
}:
index = tex.thd_get_partitioned_indices(cu_seqlens, val.size(1), cp_size, cp_rank)
val = val.index_select(1, index)
batch[key] = val
forward_args = {
'input_ids': batch['tokens'],
'position_ids': batch['position_ids'],
'attention_mask': None if self.get_attention_mask_from_fusion else batch['attention_mask'],
'labels': batch['labels'] if 'labels' in batch else None,
}

forward_args['packed_seq_params'] = PackedSeqParams(
cu_seqlens_q=cu_seqlens,
cu_seqlens_kv=cu_seqlens,
cu_seqlens_q=cu_seqlens_unpadded,
cu_seqlens_kv=cu_seqlens_unpadded,
cu_seqlens_q_padded=cu_seqlens,
cu_seqlens_kv_padded=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_kv=max_seqlen,
qkv_format='thd',
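For intuition, the per-rank token selection delegated to tex.thd_get_partitioned_indices can be pictured as the same chunk-pairing scheme used by the BSHD branch above (rank r keeps chunks r and 2*cp_size-1-r of the sequence dimension), applied to every packed sequence separately. The sketch below is a rough pure-PyTorch approximation of that idea for illustration only; it is not Transformer Engine's actual kernel or API, and it assumes each packed sequence length is already a multiple of 2*cp_size.

```python
import torch

def thd_partition_indices_sketch(cu_seqlens: torch.Tensor, cp_size: int, cp_rank: int) -> torch.Tensor:
    # Illustrative approximation of what the partitioned-index selection achieves:
    # split every packed sequence into 2*cp_size chunks and keep chunks
    # cp_rank and (2*cp_size - 1 - cp_rank) for causal-attention load balancing.
    indices = []
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        chunk = (end - start) // (2 * cp_size)  # assumes length is a multiple of 2*cp_size
        for c in (cp_rank, 2 * cp_size - 1 - cp_rank):
            indices.extend(range(start + c * chunk, start + (c + 1) * chunk))
    return torch.tensor(indices, dtype=torch.long)

# toy packed batch in THD layout (mbs=1): two sequences of length 8 and 4, cp_size=2
cu_seqlens = torch.tensor([0, 8, 12], dtype=torch.int32)
tokens = torch.arange(12).unsqueeze(0)                     # shape [1, 12]
idx = thd_partition_indices_sketch(cu_seqlens, cp_size=2, cp_rank=0)
print(tokens.index_select(1, idx))                         # tensor([[ 0,  1,  6,  7,  8, 11]])
```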
4 changes: 2 additions & 2 deletions nemo/utils/sequence_packing_utils.py
@@ -115,7 +115,7 @@ def create_hist(dataset: np.array, truncate_seq_len: int):
logging.info("Creating histogram from tokenized dataset...")

sequences = collections.defaultdict(list)
counts = [0] * truncate_seq_len
counts = [0] * (truncate_seq_len + 1)

for item_dict in dataset:
# Minus 1 here to account for the fact that transformer input and label have one less token than the full sequence
@@ -129,7 +129,7 @@ def create_hist(dataset: np.array, truncate_seq_len: int):
logging.debug(counts)

histogram = []
for seq_len in range(truncate_seq_len):
for seq_len in range(truncate_seq_len + 1):
histogram.append(len(sequences[seq_len]))

return sequences, histogram
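A toy illustration of the off-by-one this histogram change fixes; the sequence lengths below are hypothetical.

```python
# After the "minus 1" adjustment, per-sample lengths range from 0 to truncate_seq_len
# inclusive, i.e. truncate_seq_len + 1 possible bins. With only truncate_seq_len bins,
# the longest bucket would be dropped (or raise an IndexError).
truncate_seq_len = 4
counts = [0] * (truncate_seq_len + 1)
for seq_len in [1, 3, 4, 4]:        # hypothetical token counts
    counts[seq_len] += 1
print(counts)                       # [0, 1, 0, 1, 2]
```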
45 changes: 43 additions & 2 deletions scripts/nlp_language_modeling/prepare_packed_ft_dataset.py
@@ -88,6 +88,14 @@ def tokenize_dataset(cfg: 'DictConfig'):
# using the same template as SFT/PEFT script. This may be overkill but guarantees the preprocess settings
# are identical to normal SFT training
data_cfg = cfg.model.data.train_ds
pad_seq_length_to_mult = 16
cp_size = cfg.model.get("context_parallel_size", 1)

# if context parallel is used, each individual data length in one packed dataset sample
# needs to be a multiple of (cp_size * 2): https://github.com/NVIDIA/TransformerEngine/pull/641
if cp_size > 1:
pad_seq_length_to_mult = max(pad_seq_length_to_mult, cp_size * 2)

if os.path.isdir(cfg.tokenizer_path):
# pass in a Hugging Face folder which contains tokenizer.json
tokenizer = get_nmt_tokenizer(library="huggingface", model_name=cfg.tokenizer_path, use_fast=True)
@@ -99,7 +107,7 @@ def tokenize_dataset(cfg: 'DictConfig'):
tokenizer=tokenizer,
max_seq_length=data_cfg.max_seq_length,
min_seq_length=data_cfg.min_seq_length,
pad_seq_length_to_mult=16, # adds padding in collate_fn so this value is irrelevant here
pad_seq_length_to_mult=pad_seq_length_to_mult,
add_bos=data_cfg.get('add_bos', False),
add_eos=data_cfg.get('add_eos', True),
add_sep=data_cfg.get('add_sep', False),
@@ -121,7 +129,40 @@ def tokenize_dataset(cfg: 'DictConfig'):
is_test=True,
)

return np.array([dataset[i] for i in range(len(dataset))])
max_seq_length = dataset.max_seq_length
pad_id = dataset.tokenizer.eos_id
pad_seq_length_to_mult = dataset.pad_seq_length_to_mult
dataset = np.array([dataset[i] for i in range(len(dataset))])
if cp_size > 1:

def pre_pad_dataset(data, max_seq_length, max_length_to_pad, pad_id):
'''
pad each individual data point to max_length_to_pad
'''
assert max_seq_length >= max_length_to_pad
for key, val in data.items():
if key in {'input_ids', 'context_ids'}:
if len(val) <= max_length_to_pad:
# because input_ids are truncated by 1 for inputs and labels,
# we add 1 extra padding token here to make sure the padded inputs and labels
# are a multiple of (cp_size * 2)
val = val + [pad_id] * (max_length_to_pad - len(val) + 1)
data[key] = val
elif len(val) > max_seq_length:
logging.info(
f"""The current sequence length {len(val)} for packing is
larger than the max_seq_length specified ({max_seq_length}).
The current sequence length is truncated to max_seq_length.
Please consider increasing the sequence packing size"""
)
data[key] = val[:max_seq_length]
return

ceil_to_nearest = lambda n, m: (n + m - 1) // m * m
for data in dataset:
max_length_to_pad = min(max_seq_length, ceil_to_nearest(len(data['input_ids']), pad_seq_length_to_mult))
pre_pad_dataset(data, max_seq_length, max_length_to_pad, pad_id)
return dataset


@dataclass
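A small sketch of how the pre-padding target length plays out under these rules. The cp_size, sample lengths, and max_seq_length below are hypothetical, and ceil_to_nearest mirrors the lambda defined in the script.

```python
def ceil_to_nearest(n: int, m: int) -> int:
    return (n + m - 1) // m * m

cp_size = 2
pad_seq_length_to_mult = max(16, cp_size * 2)   # stays 16 until cp_size exceeds 8
max_seq_length = 2048

for raw_len in (37, 120, 2049):
    max_length_to_pad = min(max_seq_length, ceil_to_nearest(raw_len, pad_seq_length_to_mult))
    if raw_len <= max_length_to_pad:
        # one extra pad token so the length is still a multiple of cp_size * 2
        # after input_ids is truncated by one token for the labels
        padded_len = max_length_to_pad + 1
    else:
        padded_len = max_seq_length            # overlong samples are truncated instead
    print(raw_len, "->", padded_len)           # 37 -> 49, 120 -> 129, 2049 -> 2048
```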
