Context Parallel SFT Support for dataset in THD format #10688

Open: wants to merge 29 commits into main (changes shown from 25 commits)

Commits
9a2de52
Add context parallel support for packed dataset in THD format
tomlifu Aug 19, 2024
b1ef8f0
Merge remote-tracking branch 'origin/main' into thd_cp_support
tomlifu Aug 21, 2024
a11c351
add changes with debug print
tomlifu Sep 30, 2024
2fd0456
remove debug print
tomlifu Sep 30, 2024
cf1a88d
Merge branch 'main' into thd_cp_support
tomlifu Sep 30, 2024
4ad3511
Apply isort and black reformatting
tomlifu Sep 30, 2024
43d60a3
Merge branch 'main' into thd_cp_support
tomlifu Oct 1, 2024
63447f7
fix cu_seqlens and cu_seqlens_padded
tomlifu Oct 9, 2024
e287857
Merge branch 'thd_cp_support' of https://github.com/tomlifu/NeMo into…
tomlifu Oct 9, 2024
850d9ae
cu_seqlens and cu_seqlens_padded fix
tomlifu Oct 9, 2024
d50d88e
more fix on cu_seqlens and cu_seqlens_padded
tomlifu Oct 12, 2024
adea017
Apply isort and black reformatting
tomlifu Oct 12, 2024
8c76e48
addressing Xiaowei's review
tomlifu Oct 21, 2024
78b3b1c
addressing more review comments
tomlifu Oct 23, 2024
8852431
fix for the case where cp=1
tomlifu Oct 25, 2024
a4bbb20
Apply isort and black reformatting
tomlifu Oct 25, 2024
bcda0db
more fix to address Xiaowei's comment
tomlifu Oct 26, 2024
aa59854
fix the loss_mask for THD formated data
tomlifu Nov 1, 2024
ab02643
Apply isort and black reformatting
tomlifu Nov 7, 2024
2a8b21f
fixed eos_idx[0][0] out of bounds issue
tomlifu Nov 8, 2024
f930f77
Merge branch 'thd_cp_support' of https://github.com/tomlifu/NeMo into…
tomlifu Nov 8, 2024
d3e9354
Merge branch 'main' into thd_cp_support
tomlifu Nov 8, 2024
02bccd7
fixed CP=1 case
tomlifu Nov 9, 2024
11d68b4
fixed thd_get_partitioned_indices assertion issue when pp=1
tomlifu Nov 14, 2024
463a478
Apply isort and black reformatting
tomlifu Nov 14, 2024
12de6bb
fixed data packing issue
Nov 22, 2024
cc236ba
fixed an issue where cp>1 has different loss curves
tomlifu Nov 25, 2024
a988522
Merge branch 'main' into thd_cp_support
tomlifu Nov 26, 2024
29a8dea
remove redudant check for cu_seqlens
tomlifu Nov 27, 2024
@@ -573,15 +573,23 @@ def _build_samples_mapping(self):
self.samples_mapping = None

def _build_loss_mask(self, processed_example):
seq_boundaries = processed_example['seq_boundaries']
if self.answer_only_loss:
seq_boundaries = processed_example['seq_boundaries']
return np.concatenate(
[
processed_example['loss_mask'][seq_boundaries[i] + 1 : seq_boundaries[i + 1]]
for i in range(len(seq_boundaries) - 1)
]
)
return [1.0] * (len(processed_example['input_ids']) - len(processed_example['seq_boundaries']) + 1)
return np.concatenate(
[
[
0 if x == self.tokenizer.eos_id else 1.0
for x in processed_example['input_ids'][seq_boundaries[i] : seq_boundaries[i + 1] - 1]
]
for i in range(len(seq_boundaries) - 1)
]
)

def _maybe_cast_to_list(self, x):
return [item.tolist() if isinstance(item, np.ndarray) else item for item in x]
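
To illustrate the _build_loss_mask change above: a minimal standalone sketch (toy token values, eos_id assumed to be 0, answer_only_loss disabled) of the mask produced for a packed example; eos/padding positions are dropped from the loss:

import numpy as np

eos_id = 0  # hypothetical tokenizer.eos_id
# two sequences packed together; the trailing eos_id tokens are pre-padding
processed_example = {
    'input_ids': [5, 6, 7, 0, 8, 9, 0, 0],
    'seq_boundaries': [0, 4, 8],
}
seq_boundaries = processed_example['seq_boundaries']
loss_mask = np.concatenate(
    [
        [
            0 if x == eos_id else 1.0
            for x in processed_example['input_ids'][seq_boundaries[i] : seq_boundaries[i + 1] - 1]
        ]
        for i in range(len(seq_boundaries) - 1)
    ]
)
print(loss_mask)  # [1. 1. 1. 1. 1. 0.] -> eos/padding positions contribute no loss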
@@ -622,16 +630,40 @@ def collate_fn(self, batch):

position_ids: List[List[int]] = []
cu_seqlens: List[List[int]] = []
cu_seqlens_unpadded: List[List[int]] = []
for item in batch:
position_ids.append([])
cu_seqlens.append([0])
cu_seqlens_unpadded.append([0])
seqlens = np.array(item['seq_boundaries'][1:]) - np.array(item['seq_boundaries'][:-1])
for l in seqlens:
# length minus 1 because input_ids is truncated by 1 for labels
position_ids[-1].extend(list(range(l - 1)))
cu_seqlens[-1].append(cu_seqlens[-1][-1] + l - 1)
# set last seq to the max seq len because rope and attn kernels expect no padding
cu_seqlens[-1][-1] = max_length

# the last seq needs to be the max seq len because rope and attn kernels expect no padding
assert cu_seqlens[-1][-1] <= max_length

# since data is prepadded when cp_size > 1, there may be some extra padding at the end
# of the packed sequence. In this case, we need to add the max seq len to the end.
if cu_seqlens[-1][-1] != max_length:
cu_seqlens[-1].append(max_length)

for i in range(len(item['seq_boundaries']) - 1):
current_seq = item['input_ids'][item['seq_boundaries'][i] : item['seq_boundaries'][i + 1] - 1]

# since the data could be prepadded with the tokenizer's eos_id, find the indices of all eos_id tokens
eos_idx = np.where(np.array(current_seq) == self.tokenizer.eos_id)

# The second eos_id index marks the length of the original unpadded sequence if the sequence is
# prepadded for cp_size > 1. Otherwise, there is no extra padding.
seqlen_unpadded = eos_idx[0][0] + 1 if eos_idx[0].any() else len(current_seq)
cu_seqlens_unpadded[-1].append(cu_seqlens_unpadded[-1][-1] + seqlen_unpadded)

# if extra paddings are added in the packed sequence, they can't be counted as
# actual tokens for training
if len(cu_seqlens[-1]) > len(cu_seqlens_unpadded[-1]):
cu_seqlens_unpadded[-1].append(cu_seqlens_unpadded[-1][-1])

assert len(input_ids[0]) == len(
position_ids[0]
@@ -652,12 +684,16 @@

if self.return_cu_seqlen:
cu_seqlens = self._collate_item(cu_seqlens, max_length=max(len(l) for l in cu_seqlens) + 1, pad_id=-1)

cu_seqlens_unpadded = self._collate_item(
cu_seqlens_unpadded, max_length=max(len(l) for l in cu_seqlens_unpadded) + 1, pad_id=-1
)
# Pre-generate `cu_seqlens_argmin` and `max_seqlen` as CPU tensor to avoid device-to-host copies.
cu_seqlens = torch.IntTensor(cu_seqlens)
cu_seqlens_argmin = torch.argmin(cu_seqlens, dim=1, keepdim=True)
seqlens = cu_seqlens[:, 1:] - cu_seqlens[:, :-1]
max_seqlen, _ = seqlens.max(dim=1, keepdim=True)
cu_seqlens_unpadded = torch.IntTensor(cu_seqlens_unpadded)
cu_seqlens_unpadded_argmin = torch.argmin(cu_seqlens_unpadded, dim=1, keepdim=True)

processed_batch.update(
{
@@ -667,6 +703,8 @@
'cu_seqlens': torch.IntTensor(cu_seqlens), # cu_seqlens_q must be in dtype torch.int32
'cu_seqlens_argmin': cu_seqlens_argmin, # only required for perf
'max_seqlen': max_seqlen, # only required for perf
'cu_seqlens_unpadded': torch.IntTensor(cu_seqlens_unpadded),
'cu_seqlens_unpadded_argmin': cu_seqlens_unpadded_argmin,
}
)
else:
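To make the padded vs. unpadded bookkeeping in collate_fn concrete, here is a minimal standalone sketch, with toy token values and assuming eos_id = 0, cp_size = 2, and a batch max_length of 16, of what the new logic computes for one packed item:

import numpy as np

eos_id, max_length = 0, 16
# two sub-sequences pre-padded with eos_id to 9 and 5 stored tokens
# (true lengths 5 and 3, each including its own eos)
item = {
    'input_ids': [11, 12, 13, 14, 0, 0, 0, 0, 0, 21, 22, 0, 0, 0],
    'seq_boundaries': [0, 9, 14],
}

cu_seqlens, cu_seqlens_unpadded = [0], [0]
for l in np.diff(item['seq_boundaries']):            # sub-sequence lengths [9, 5]
    cu_seqlens.append(cu_seqlens[-1] + int(l) - 1)   # minus 1: inputs are shifted by one for labels
if cu_seqlens[-1] != max_length:
    cu_seqlens.append(max_length)                    # extra padding at the end of the pack

for i in range(len(item['seq_boundaries']) - 1):
    seq = item['input_ids'][item['seq_boundaries'][i] : item['seq_boundaries'][i + 1] - 1]
    eos_idx = np.where(np.array(seq) == eos_id)[0]
    seqlen_unpadded = int(eos_idx[0]) + 1 if eos_idx.size else len(seq)
    cu_seqlens_unpadded.append(cu_seqlens_unpadded[-1] + seqlen_unpadded)
if len(cu_seqlens) > len(cu_seqlens_unpadded):
    cu_seqlens_unpadded.append(cu_seqlens_unpadded[-1])

print(cu_seqlens)           # [0, 8, 12, 16] -> becomes cu_seqlens_q_padded / cu_seqlens_kv_padded
print(cu_seqlens_unpadded)  # [0, 5, 8, 8]   -> becomes cu_seqlens_q / cu_seqlens_kv

The padded boundaries feed the *_padded fields of PackedSeqParams later in this PR, while the unpadded ones describe only the real tokens.
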
80 changes: 61 additions & 19 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -24,6 +24,7 @@

import packaging
import torch
import transformer_engine_torch as tex
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from pytorch_lightning.accelerators import CPUAccelerator
@@ -1230,22 +1231,23 @@ def get_batch_on_this_context_parallel_rank(self, batch):
cp_size = parallel_state.get_context_parallel_world_size()
if cp_size > 1:
cp_rank = parallel_state.get_context_parallel_rank()
for key, val in batch.items():
if val is not None and key != "context_lengths":
seq_dim = 1 if key != 'attention_mask' else 2
val = val.view(
*val.shape[0:seq_dim],
2 * cp_size,
val.shape[seq_dim] // (2 * cp_size),
*val.shape[(seq_dim + 1) :],
)
index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda(
non_blocking=True
)
val = val.index_select(seq_dim, index)
val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
batch[key] = val

# check if the batch is not in THD format
if 'cu_seqlens' not in batch:
for key, val in batch.items():
if val is not None and key != "context_lengths":
seq_dim = 1 if key != 'attention_mask' else 2
val = val.view(
*val.shape[0:seq_dim],
2 * cp_size,
val.shape[seq_dim] // (2 * cp_size),
*val.shape[(seq_dim + 1) :],
)
index = torch.tensor(
[cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True
).cuda(non_blocking=True)
val = val.index_select(seq_dim, index)
val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
batch[key] = val
batch['num_valid_tokens_in_ub'] = num_valid_tokens_in_ub

return batch
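
For reference, a small self-contained sketch of the token slicing kept in the non-THD branch above (the path still used when 'cu_seqlens' is absent): with cp_size = 2, each rank keeps chunk cp_rank and its mirror chunk 2 * cp_size - cp_rank - 1, which balances causal-attention work across ranks:

import torch

cp_size, seq_dim = 2, 1
tokens = torch.arange(16).view(1, 16)                        # [batch=1, seq=16]
for cp_rank in range(cp_size):
    val = tokens.view(1, 2 * cp_size, 16 // (2 * cp_size))   # split seq dim into 2*cp_size chunks
    index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1])
    part = val.index_select(seq_dim, index).view(1, -1)      # this rank's two chunks, re-flattened
    print(cp_rank, part.tolist())
# 0 [[0, 1, 2, 3, 12, 13, 14, 15]]
# 1 [[4, 5, 6, 7, 8, 9, 10, 11]]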
@@ -1260,12 +1262,17 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
required_keys = set()
max_seqlen = batch['max_seqlen'].squeeze() if 'max_seqlen' in batch else None
cu_seqlens_argmin = batch['cu_seqlens_argmin'] if 'cu_seqlens_argmin' in batch else None
cu_seqlens_unpadded_argmin = (
batch['cu_seqlens_unpadded_argmin'] if 'cu_seqlens_unpadded_argmin' in batch else None
)
if parallel_state.get_pipeline_model_parallel_world_size() == 1:
required_keys.update(batch.keys())
else:
required_keys.add('attention_mask')
if 'cu_seqlens' in batch:
required_keys.add('cu_seqlens')
if 'cu_seqlens_unpadded' in batch:
required_keys.add('cu_seqlens_unpadded')
if parallel_state.is_pipeline_first_stage():
required_keys.update(('tokens', 'position_ids'))
if parallel_state.is_pipeline_last_stage():
@@ -1300,12 +1307,16 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
if 'cu_seqlens' in batch: # packed sequence from GPTSFTPackedDataset
# these args are passed eventually into TEDotProductAttention.forward()
cu_seqlens = batch['cu_seqlens'].squeeze() # remove batch size dimension (mbs=1)
cu_seqlens_unpadded = batch['cu_seqlens_unpadded'].squeeze()
# remove -1 "paddings" added in collate_fn
if cu_seqlens_argmin is not None:
cu_seqlens = cu_seqlens[: cu_seqlens_argmin.item()]
else:
cu_seqlens = cu_seqlens[: torch.argmin(cu_seqlens)]

if cu_seqlens_unpadded_argmin is not None:
cu_seqlens_unpadded = cu_seqlens_unpadded[: cu_seqlens_unpadded_argmin.item()]
else:
cu_seqlens_unpadded = cu_seqlens_unpadded[: torch.argmin(cu_seqlens_unpadded)]
try:
from megatron.core.packed_seq_params import PackedSeqParams
except (ImportError, ModuleNotFoundError) as e:
@@ -1316,9 +1327,40 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
)
raise e

# get packed sequences for this context parallel rank
cp_size = parallel_state.get_context_parallel_world_size()
if cp_size > 1:
cp_rank = parallel_state.get_context_parallel_rank()
for key in required_keys:
val = batch[key]
if key not in {
"cu_seqlens",
"cu_seqlens_unpadded",
"cu_seqlens_argmin",
"cu_seqlens_unpadded_argmin",
"max_seqlen",
"token_count",
}:
index = tex.thd_get_partitioned_indices(cu_seqlens, val.size(1), cp_size, cp_rank)
val = val.index_select(1, index)
batch[key] = val
forward_args = {
'input_ids': batch['tokens'],
'position_ids': batch['position_ids'],
'attention_mask': None if self.get_attention_mask_from_fusion else batch['attention_mask'],
'labels': batch['labels'] if 'labels' in batch else None,
}

# since the data is pre-padded, we need to check if there is extra padding
# at the end of the sequence
if len(cu_seqlens) > len(cu_seqlens_unpadded):
cu_seqlens = cu_seqlens[:-1]

forward_args['packed_seq_params'] = PackedSeqParams(
cu_seqlens_q=cu_seqlens,
cu_seqlens_kv=cu_seqlens,
cu_seqlens_q=cu_seqlens_unpadded,
cu_seqlens_kv=cu_seqlens_unpadded,
cu_seqlens_q_padded=cu_seqlens,
cu_seqlens_kv_padded=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_kv=max_seqlen,
qkv_format='thd',
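The THD path partitions each padded sub-sequence of the packed sample instead. A rough pure-Python sketch of the kind of indices such a partitioning selects (illustrative only; the PR relies on transformer_engine's tex.thd_get_partitioned_indices, whose exact layout may differ):

import torch

def thd_partition_indices_sketch(cu_seqlens_padded, cp_size, cp_rank):
    # for each padded sub-sequence, keep chunk cp_rank and its mirror chunk
    # (2*cp_size - 1 - cp_rank); padded lengths are multiples of 2*cp_size,
    # which is why prepare_packed_ft_dataset.py pre-pads the data
    indices = []
    for start, end in zip(cu_seqlens_padded[:-1], cu_seqlens_padded[1:]):
        chunk = (end - start) // (2 * cp_size)
        for c in (cp_rank, 2 * cp_size - 1 - cp_rank):
            indices.extend(range(start + c * chunk, start + (c + 1) * chunk))
    return torch.tensor(indices)

# padded cu_seqlens [0, 8, 12, 16] from the collate_fn sketch earlier, cp_size = 2
print(thd_partition_indices_sketch([0, 8, 12, 16], 2, 0).tolist())  # [0, 1, 6, 7, 8, 11, 12, 15]
print(thd_partition_indices_sketch([0, 8, 12, 16], 2, 1).tolist())  # [2, 3, 4, 5, 9, 10, 13, 14]

Each rank then holds 1/cp_size of every padded sub-sequence, and PackedSeqParams carries both the unpadded boundaries (cu_seqlens_q/kv) and the padded ones (cu_seqlens_q_padded/kv_padded) so the attention kernel knows where the pre-padding sits.
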
2 changes: 1 addition & 1 deletion nemo/utils/sequence_packing_utils.py
@@ -115,7 +115,7 @@ def create_hist(dataset: np.array, truncate_seq_len: int):
logging.info("Creating histogram from tokenized dataset...")

sequences = collections.defaultdict(list)
counts = [0] * truncate_seq_len
counts = [0] * (truncate_seq_len + 1)

for item_dict in dataset:
# Minus 1 here to account for the fact that transformer input and label have one less token than the full sequence
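Presumably the extra bin is needed because, with the context-parallel pre-padding added in prepare_packed_ft_dataset.py below, a sample's post-shift length (len(input_ids) - 1) can now equal truncate_seq_len itself; a toy sketch:

truncate_seq_len = 4
counts = [0] * (truncate_seq_len + 1)           # bins for post-shift lengths 0..truncate_seq_len
for input_ids in ([7, 8, 9], [1, 2, 3, 4, 5]):  # second sample pre-padded to truncate_seq_len + 1 tokens
    counts[len(input_ids) - 1] += 1             # index 4 would overflow the old truncate_seq_len-bin list
print(counts)                                   # [0, 0, 1, 0, 1]
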
47 changes: 45 additions & 2 deletions scripts/nlp_language_modeling/prepare_packed_ft_dataset.py
@@ -88,6 +88,14 @@ def tokenize_dataset(cfg: 'DictConfig'):
# using the same template as SFT/PEFT script. This may be overkill but guarantees the preprocess settings
# are identical to normal SFT training
data_cfg = cfg.model.data.train_ds
pad_seq_length_to_mult = 16
cp_size = cfg.model.context_parallel_size

# if context parallel is used, each individual data length in one packed dataset sample
# needs to be a multiple of (cp_size * 2): https://github.com/NVIDIA/TransformerEngine/pull/641
if cp_size > 1:
pad_seq_length_to_mult = max(pad_seq_length_to_mult, cp_size * 2)

if os.path.isdir(cfg.tokenizer_path):
# pass in a Hugging Face folder which contains tokenizer.json
tokenizer = get_nmt_tokenizer(library="huggingface", model_name=cfg.tokenizer_path, use_fast=True)
@@ -99,7 +107,7 @@ def tokenize_dataset(cfg: 'DictConfig'):
tokenizer=tokenizer,
max_seq_length=data_cfg.max_seq_length,
min_seq_length=data_cfg.min_seq_length,
pad_seq_length_to_mult=16, # adds padding in collate_fn so this value is irrelevant here
pad_seq_length_to_mult=pad_seq_length_to_mult,
add_bos=data_cfg.get('add_bos', False),
add_eos=data_cfg.get('add_eos', True),
add_sep=data_cfg.get('add_sep', False),
@@ -121,7 +129,42 @@ def tokenize_dataset(cfg: 'DictConfig'):
is_test=True,
)

return np.array([dataset[i] for i in range(len(dataset))])
max_seq_length = dataset.max_seq_length
pad_id = dataset.tokenizer.eos_id
pad_seq_length_to_mult = dataset.pad_seq_length_to_mult
dataset = np.array([dataset[i] for i in range(len(dataset))])
if cp_size > 1:

def pre_pad_dataset(data, max_seq_length, max_length_to_pad, pad_id):
'''
pad each individual data point to the length of max_length_to_pad
'''
assert max_seq_length >= max_length_to_pad
for key, val in data.items():
if key in {'input_ids', 'context_ids'}:
if len(val) <= max_length_to_pad:
# because input_ids are truncated by 1 for inputs and labels,
# we add 1 extra padding here to make sure padded inputs and labels
# are a multiple of (cp_size * 2)
val = val + [pad_id] * (max_length_to_pad - len(val) + 1)
data[key] = val
elif len(val) > max_seq_length:
logging.info(
f"""The current sequence length {len(val)} for packing is
larger than the max_seq_length specified ({max_seq_length}).
The current sequence length is truncated to the size of max_seq_length.
Please consider increasing the sequence packing size"""
)
data[key] = val[:max_seq_length]
return

ceil_to_nearest = lambda n, m: (n + m - 1) // m * m
for data in dataset:
max_length_to_pad = min(max_seq_length, ceil_to_nearest(len(data['input_ids']), pad_seq_length_to_mult))
if max_length_to_pad < 512:
max_length_to_pad = 512
pre_pad_dataset(data, max_seq_length, max_length_to_pad, pad_id)
return dataset


@dataclass
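
Finally, a small worked sketch of the pre-padding performed by pre_pad_dataset above, with toy values (eos_id = 0, cp_size = 2) and ignoring the 512-token floor for brevity:

ceil_to_nearest = lambda n, m: (n + m - 1) // m * m

eos_id, cp_size, max_seq_length = 0, 2, 2048
pad_seq_length_to_mult = max(16, cp_size * 2)        # 16
input_ids = list(range(1, 22))                       # a sequence of true length 21
max_length_to_pad = min(max_seq_length, ceil_to_nearest(len(input_ids), pad_seq_length_to_mult))  # 32
input_ids = input_ids + [eos_id] * (max_length_to_pad - len(input_ids) + 1)
print(len(input_ids))  # 33: one token beyond the multiple, so after the 1-token input/label
                       # shift both inputs and labels are 32 tokens, a multiple of 2 * cp_size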