Add ZigzagRing Support (NVlabs#157)
Qinghao-Hu authored Aug 15, 2024
1 parent 8ee8912 commit 118c349
Showing 9 changed files with 161 additions and 93 deletions.
50 changes: 36 additions & 14 deletions llava/data/dataset.py
@@ -2590,6 +2590,7 @@ class DataCollatorForSupervisedDatasetSeqParallel:
sp_degree: int
sp_rank: int
ring_degree: int
ring_type: str

def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels, images = [], [], []
@@ -2689,20 +2690,39 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
# Handle RingAttn_Varlen which requires `seqlens_in_batch` should be divisible by `ring_degree`
if self.ring_degree > 1:
RING_PAD_TOKEN_INDEX = 2
if num_incoming_tokens % self.sp_degree != 0:
pad_len = self.sp_degree - num_incoming_tokens % self.sp_degree
num_incoming_tokens += pad_len
# pad `input_ids`
pad_tensor = torch.full(
(pad_len,), RING_PAD_TOKEN_INDEX, dtype=sorted_ids[i].dtype, device=sorted_ids[i].device
)
sorted_ids[i] = torch.cat([sorted_ids[i], pad_tensor])

# pad `label`
pad_label_tensor = torch.full(
(pad_len,), IGNORE_INDEX, dtype=sorted_labels[i].dtype, device=sorted_labels[i].device
)
sorted_labels[i] = torch.cat([sorted_labels[i], pad_label_tensor])
if self.ring_type == "ring_varlen":
if num_incoming_tokens % self.sp_degree != 0:
pad_len = self.sp_degree - num_incoming_tokens % self.sp_degree
num_incoming_tokens += pad_len
# pad `input_ids`
pad_tensor = torch.full(
(pad_len,), RING_PAD_TOKEN_INDEX, dtype=sorted_ids[i].dtype, device=sorted_ids[i].device
)
sorted_ids[i] = torch.cat([sorted_ids[i], pad_tensor])

# pad `label`
pad_label_tensor = torch.full(
(pad_len,), IGNORE_INDEX, dtype=sorted_labels[i].dtype, device=sorted_labels[i].device
)
sorted_labels[i] = torch.cat([sorted_labels[i], pad_label_tensor])
elif self.ring_type == "zigzag_ring_varlen":
self.zigzag_sp_degree = self.sp_degree * 2
if num_incoming_tokens % self.zigzag_sp_degree != 0:
pad_len = self.zigzag_sp_degree - num_incoming_tokens % self.zigzag_sp_degree
num_incoming_tokens += pad_len
# pad `input_ids`
pad_tensor = torch.full(
(pad_len,), RING_PAD_TOKEN_INDEX, dtype=sorted_ids[i].dtype, device=sorted_ids[i].device
)
sorted_ids[i] = torch.cat([sorted_ids[i], pad_tensor])

# pad `label`
pad_label_tensor = torch.full(
(pad_len,), IGNORE_INDEX, dtype=sorted_labels[i].dtype, device=sorted_labels[i].device
)
sorted_labels[i] = torch.cat([sorted_labels[i], pad_label_tensor])
else:
raise ValueError(f"Invalid ring_type: {self.ring_type}")

if num_incoming_tokens > max_seq_length:
print(
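
Below is a minimal standalone sketch of the padding rule introduced in this hunk (the helper name `pad_for_ring` and the literal `IGNORE_INDEX` value are illustrative, not the repository API): zigzag ring gives each rank two chunks, one from each end of the sequence, so the packed length must be divisible by `2 * sp_degree` rather than `sp_degree`.

```python
import torch

IGNORE_INDEX = -100         # assumed value; the repository imports its own constant
RING_PAD_TOKEN_INDEX = 2

def pad_for_ring(ids, labels, sp_degree, ring_type):
    # zigzag_ring_varlen shards every sequence into 2 * sp_degree chunks,
    # plain ring_varlen into sp_degree chunks; pad up to the next multiple.
    multiple = 2 * sp_degree if ring_type == "zigzag_ring_varlen" else sp_degree
    pad_len = (-ids.numel()) % multiple
    if pad_len:
        ids = torch.cat([ids, torch.full((pad_len,), RING_PAD_TOKEN_INDEX, dtype=ids.dtype)])
        labels = torch.cat([labels, torch.full((pad_len,), IGNORE_INDEX, dtype=labels.dtype)])
    return ids, labels

ids, labels = pad_for_ring(torch.arange(10), torch.arange(10), sp_degree=4, ring_type="zigzag_ring_varlen")
print(ids.numel())  # 16, the next multiple of 2 * sp_degree
```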
@@ -2855,13 +2875,15 @@ def make_supervised_data_module(
sp_degree = training_args.seq_parallel_size
sp_rank = PROCESS_GROUP_MANAGER.sp_rank
ring_degree = PROCESS_GROUP_MANAGER.ring_degree
ring_type = PROCESS_GROUP_MANAGER.ring_type
data_collator = DataCollatorForSupervisedDatasetSeqParallel(
tokenizer=tokenizer,
data_args=data_args,
training_args=training_args,
sp_degree=sp_degree,
sp_rank=sp_rank,
ring_degree=ring_degree,
ring_type=ring_type,
)

return dict(
61 changes: 54 additions & 7 deletions llava/model/llava_arch.py
@@ -538,6 +538,10 @@ def repack_multimodal_data(
sp_rank = PROCESS_GROUP_MANAGER.sp_rank
sp_group = PROCESS_GROUP_MANAGER.sp_pg
ring_degree = PROCESS_GROUP_MANAGER.ring_degree
ring_rank = PROCESS_GROUP_MANAGER.ring_rank
ring_type = PROCESS_GROUP_MANAGER.ring_type
ulysses_degree = PROCESS_GROUP_MANAGER.ulysses_degree
ulysses_rank = PROCESS_GROUP_MANAGER.ulysses_rank

bs, shard_seqlen = position_ids.shape
sp_seq_len = [torch.zeros(1, dtype=torch.int64, device=position_ids.device) for _ in range(sp_degree)]
@@ -642,13 +646,56 @@ def repack_multimodal_data(
dtype=global_inputs_embeds.dtype,
device=global_inputs_embeds.device,
)
for i in range(bs):
start_idx = new_seqlen_per_rank[i] * sp_rank
end_idx = start_idx + new_seqlen_per_rank[i]
new_attention_mask[i, : new_seqlen_per_rank[i]] = global_attention_mask[i, start_idx:end_idx]
new_position_ids[i, : new_seqlen_per_rank[i]] = global_position_ids[i, start_idx:end_idx]
new_labels[i, : new_seqlen_per_rank[i]] = global_labels[i, start_idx:end_idx]
new_inputs_embeds[i, : new_seqlen_per_rank[i], :] = global_inputs_embeds[i, start_idx:end_idx, :]

if ring_type == "ring_varlen":
for i in range(bs):
start_idx = new_seqlen_per_rank[i] * sp_rank
end_idx = start_idx + new_seqlen_per_rank[i]
new_attention_mask[i, : new_seqlen_per_rank[i]] = global_attention_mask[i, start_idx:end_idx]
new_position_ids[i, : new_seqlen_per_rank[i]] = global_position_ids[i, start_idx:end_idx]
new_labels[i, : new_seqlen_per_rank[i]] = global_labels[i, start_idx:end_idx]
new_inputs_embeds[i, : new_seqlen_per_rank[i], :] = global_inputs_embeds[
i, start_idx:end_idx, :
]
elif ring_type == "zigzag_ring_varlen":
chunk_size = total_effective_seqlen // (2 * sp_degree)
for i in range(bs):
# Zigzag pattern indices
if sp_degree == ring_degree:
forward_rank_idx = sp_rank
backward_rank_idx = 2 * sp_degree - sp_rank - 1
else:
ulysses_offset = ulysses_rank * ring_degree * 2
forward_rank_idx = ring_rank + ulysses_offset
backward_rank_idx = sp_degree - ring_rank - 1 + ulysses_offset

# Calculate start and end indices for the forward and backward zigzag
start_idx_fwd = forward_rank_idx * chunk_size[i]
end_idx_fwd = start_idx_fwd + chunk_size[i]

start_idx_bwd = backward_rank_idx * chunk_size[i]
end_idx_bwd = start_idx_bwd + chunk_size[i]

# Fill new tensors with zigzag data
new_attention_mask[i, : chunk_size[i]] = global_attention_mask[i, start_idx_fwd:end_idx_fwd]
new_attention_mask[i, chunk_size[i] : 2 * chunk_size[i]] = global_attention_mask[
i, start_idx_bwd:end_idx_bwd
]

new_position_ids[i, : chunk_size[i]] = global_position_ids[i, start_idx_fwd:end_idx_fwd]
new_position_ids[i, chunk_size[i] : 2 * chunk_size[i]] = global_position_ids[
i, start_idx_bwd:end_idx_bwd
]

new_labels[i, : chunk_size[i]] = global_labels[i, start_idx_fwd:end_idx_fwd]
new_labels[i, chunk_size[i] : 2 * chunk_size[i]] = global_labels[i, start_idx_bwd:end_idx_bwd]

new_inputs_embeds[i, : chunk_size[i], :] = global_inputs_embeds[i, start_idx_fwd:end_idx_fwd, :]
new_inputs_embeds[i, chunk_size[i] : 2 * chunk_size[i], :] = global_inputs_embeds[
i, start_idx_bwd:end_idx_bwd, :
]
else:
raise ValueError(f"Invalid ring_type: {ring_type}")
else:
global_seq_len = global_attention_mask.shape[-1]
seq_len_sharded = global_seq_len // sp_degree
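
As an aside on why chunks are paired this way (a toy illustration, not repository code): under causal attention a later chunk attends to more key chunks, so giving rank r chunk r together with its mirror chunk 2 * sp_degree - 1 - r equalizes the attention work across ranks.

```python
# Toy load-balance check for the zigzag layout (illustration only).
# Chunk i of a causal sequence attends to chunks 0..i, so its relative cost
# is i + 1; pairing chunk r with its mirror 2W - 1 - r makes all ranks equal.
W = 4                                   # ring world size -> 2W = 8 chunks
cost = [i + 1 for i in range(2 * W)]
pairs = [(r, 2 * W - 1 - r) for r in range(W)]
print([cost[a] + cost[b] for a, b in pairs])  # [9, 9, 9, 9]
```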
4 changes: 3 additions & 1 deletion llava/train/args.py
@@ -111,7 +111,9 @@ class TrainingArguments(transformers.TrainingArguments):
)
seq_parallel_ring_type: str = field(
default="ring_varlen",
metadata={"help": "Ring Attention implementation."},
metadata={
"help": "Ring Attention implementation. Support ['ring_varlen', 'zigzag_ring_varlen'] in 2D attention. Only works when `seq_parallel_ring_size` > 1."
},
)
debug_e2e: bool = field(
default=False,
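
For reference, the new option can be selected through the training arguments; a hedged sketch using HfArgumentParser (field names come from the diff above, the remaining script arguments are omitted, and the values are illustrative):

```python
from transformers import HfArgumentParser

from llava.train.args import TrainingArguments

parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    [
        "--output_dir", "./checkpoints/demo",      # illustrative path
        "--seq_parallel_size", "8",                # total sequence-parallel degree
        "--seq_parallel_ring_size", "2",           # > 1 enables ring attention
        "--seq_parallel_ring_type", "zigzag_ring_varlen",
    ]
)
print(training_args.seq_parallel_ring_type)  # zigzag_ring_varlen
```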
16 changes: 12 additions & 4 deletions llava/train/sequence_parallel/globals.py
@@ -39,10 +39,11 @@ class ProcessGroupManager(Singleton):
sp_degree = sp_ring_degree x sp_ulysses_degree
"""

def __init__(self, ulysses_degree, ring_degree, dp_degree, use_ulysses_low):
def __init__(self, ulysses_degree, ring_degree, dp_degree, use_ulysses_low, ring_type):
if not hasattr(self, "__initialized"):
super().__init__()
self.ulysses_degree = ulysses_degree
self.ring_type = ring_type
self.ulysses_seq_len = None

self.ring_degree = ring_degree
@@ -148,7 +149,7 @@ def __init__(self, ulysses_degree, ring_degree, dp_degree, use_ulysses_low):
PROCESS_GROUP_MANAGER = None


def set_pg_manager(sp_degree, sp_ring_degree=1, use_ulysses_low=True):
def set_pg_manager(sp_degree, sp_ring_degree=1, use_ulysses_low=True, ring_type=None):
"""
Set the process group manager for sequence parallelism.
sp_degree = sp_ring_degree x sp_ulysses_degree
@@ -185,7 +186,9 @@ def set_pg_manager(sp_degree, sp_ring_degree=1, use_ulysses_low=True):

# Init the process group manager
global PROCESS_GROUP_MANAGER
PROCESS_GROUP_MANAGER = ProcessGroupManager(sp_ulysses_degree, sp_ring_degree, dp_degree, use_ulysses_low)
PROCESS_GROUP_MANAGER = ProcessGroupManager(
sp_ulysses_degree, sp_ring_degree, dp_degree, use_ulysses_low, ring_type
)


def get_pg_manager():
@@ -243,10 +246,15 @@ def get_ring_sp_rank():


def get_ring_sp_pg():
"""Get the Ulysses sequence parallel process group."""
"""Get the RingAttn sequence parallel process group."""
return PROCESS_GROUP_MANAGER.ring_pg


def get_ring_type():
"""Get the RingAttn implementation type."""
return PROCESS_GROUP_MANAGER.ring_type


def get_data_parallel_size():
"""Get the size of the data parallel group."""
return PROCESS_GROUP_MANAGER.dp_degree
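
With this change the ring type travels with the process-group manager, so downstream components (collator, attention patches) can query it after setup. A hedged sketch of the call sequence (this normally happens once during training initialization, after torch.distributed is initialized):

```python
# Sketch only: requires an initialized torch.distributed environment, and the
# argument values are illustrative (sp_degree = sp_ring_degree * sp_ulysses_degree).
from llava.train.sequence_parallel.globals import get_ring_type, set_pg_manager

set_pg_manager(sp_degree=8, sp_ring_degree=2, use_ulysses_low=True, ring_type="zigzag_ring_varlen")
print(get_ring_type())  # "zigzag_ring_varlen"
```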
19 changes: 3 additions & 16 deletions llava/train/sequence_parallel/hybrid_attn.py
@@ -23,19 +23,7 @@
from torch.nn import Module

from .all_to_all import SeqAllToAll4D, SeqAllToAll5D
from .globals import (
get_pg_manager,
get_ring_sp_pg,
get_ring_sp_rank,
get_ring_sp_size,
get_sequence_parallel_pg,
get_sequence_parallel_rank,
get_sequence_parallel_size,
get_ulysses_seq_len,
get_ulysses_sp_pg,
get_ulysses_sp_rank,
get_ulysses_sp_size,
)
from .globals import get_ring_sp_pg, get_ring_type, get_ulysses_sp_pg
from .ring import (
ring_flash_attn_func,
ring_flash_attn_qkvpacked_func,
@@ -54,7 +42,7 @@
"zigzag": zigzag_ring_flash_attn_func,
"strip": stripe_flash_attn_func,
"ring_varlen": ring_flash_attn_varlen_func,
"zigzag_varlen": zigzag_ring_flash_attn_varlen_func,
"zigzag_ring_varlen": zigzag_ring_flash_attn_varlen_func,
}

RING_IMPL_QKVPACKED_DICT = {
@@ -80,7 +68,6 @@ def __init__(
self,
scatter_idx: int = 2,
gather_idx: int = 1,
ring_impl_type: str = "ring_varlen",
use_pack_qkv: bool = False,
attention_warper: Module = None,
) -> None:
@@ -96,7 +83,7 @@ def __init__(
self.scatter_idx = scatter_idx
self.gather_idx = gather_idx
if attention_warper is None:
self.ring_attn_fn = RING_IMPL_DICT[ring_impl_type]
self.ring_attn_fn = RING_IMPL_DICT[get_ring_type()]
else:
self.ring_attn_fn = attention_warper
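
Note that the `ring_impl_type` constructor argument is removed: the implementation is now resolved from the globally configured ring type, and the dictionary key is renamed to `zigzag_ring_varlen` so it matches the `seq_parallel_ring_type` values. A small sketch of the lookup, with stand-in callables (not the real ring kernels):

```python
# Stand-ins for ring_flash_attn_varlen_func / zigzag_ring_flash_attn_varlen_func.
def _ring_varlen(*args, **kwargs):
    raise NotImplementedError

def _zigzag_ring_varlen(*args, **kwargs):
    raise NotImplementedError

RING_IMPL_DICT = {"ring_varlen": _ring_varlen, "zigzag_ring_varlen": _zigzag_ring_varlen}

def resolve_ring_fn(ring_type: str):
    # Unknown types fail loudly, matching the ValueError raised in the
    # collator and monkey patch for the same condition.
    if ring_type not in RING_IMPL_DICT:
        raise ValueError(f"Invalid ring_type: {ring_type}")
    return RING_IMPL_DICT[ring_type]

print(resolve_ring_fn("zigzag_ring_varlen").__name__)  # _zigzag_ring_varlen
```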

61 changes: 23 additions & 38 deletions llava/train/sequence_parallel/input_utils.py
@@ -17,11 +17,31 @@
import torch


def extract_local_from_list(vaule_list, sp_rank, sp_size):
quotient, remainder = divmod(len(vaule_list), sp_size)
def extract_local_zigzag(value, rank, world_size, device, dim=1):
value_chunks = value.chunk(2 * world_size, dim=dim)
local_value = torch.cat([value_chunks[rank], value_chunks[2 * world_size - rank - 1]], dim=dim)
return local_value.to(device)


def extract_local_from_list(value_list, sp_rank, sp_size):
quotient, remainder = divmod(len(value_list), sp_size)
start_idx = sp_rank * quotient + min(sp_rank, remainder)
end_idx = (sp_rank + 1) * quotient + min(sp_rank + 1, remainder)
return vaule_list[start_idx:end_idx]
return value_list[start_idx:end_idx]


def extract_local_from_list_zigzag(value_list, sp_rank, sp_size):
chunk_size, remainder = divmod(len(value_list), (2 * sp_size))
value_chunks = []
start_idx = 0
for i in range(2 * sp_size):
extra = 1 if i < remainder else 0
end_idx = start_idx + chunk_size + extra
value_chunks.append(value_list[start_idx:end_idx])
start_idx = end_idx

local_value = value_chunks[sp_rank] + value_chunks[2 * sp_size - sp_rank - 1]
return local_value
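
A quick worked example of the list-based zigzag extraction above (the function body is reproduced so the snippet runs standalone): with sp_size = 2 the list splits into 4 chunks, and rank r receives chunk r plus its mirror chunk 2 * sp_size - r - 1.

```python
def extract_local_from_list_zigzag(value_list, sp_rank, sp_size):
    # Split into 2 * sp_size chunks (earlier chunks absorb any remainder),
    # then take this rank's forward chunk and its mirrored backward chunk.
    chunk_size, remainder = divmod(len(value_list), 2 * sp_size)
    chunks, start = [], 0
    for i in range(2 * sp_size):
        end = start + chunk_size + (1 if i < remainder else 0)
        chunks.append(value_list[start:end])
        start = end
    return chunks[sp_rank] + chunks[2 * sp_size - sp_rank - 1]

data = list(range(8))
print(extract_local_from_list_zigzag(data, sp_rank=0, sp_size=2))  # [0, 1, 6, 7]
print(extract_local_from_list_zigzag(data, sp_rank=1, sp_size=2))  # [2, 3, 4, 5]
```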


def extract_local_input_ids(input_ids, image_positions, sp_rank, sp_size, bos_token_id=1, image_token_len=3):
@@ -58,38 +78,3 @@ def extract_local_position_ids(input_ids, image_positions, image_ids, sp_rank, s
return input_ids[start_position_idx:]
else:
return input_ids[start_position_idx:end_position_idx]


def extract_local(value, rank, world_size, dim=1):
value_chunks = value.chunk(2 * world_size, dim=dim)
local_value = torch.cat([value_chunks[rank], value_chunks[2 * world_size - rank - 1]], dim=dim)
return local_value


def prepare_hybrid_attn_inputs(input_ids, position_ids, target_ids, rank, world_size, device):
local_input_ids = extract_local(
input_ids,
rank,
world_size,
device,
)
local_position_ids = extract_local(
position_ids,
rank,
world_size,
device,
)
if target_ids is not None:
local_target_ids = extract_local(
target_ids,
rank,
world_size,
device,
)
else:
local_target_ids = None
return {
"local_input_ids": local_input_ids,
"local_position_ids": local_position_ids,
"local_target_ids": local_target_ids,
}
40 changes: 28 additions & 12 deletions llava/train/sequence_parallel/monkey_patch.py
@@ -27,7 +27,7 @@
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaAttention, _get_unpad_data, apply_rotary_pos_emb

from llava.train.sequence_parallel.globals import get_pg_manager, get_ring_sp_pg, get_ulysses_sp_pg
from llava.train.sequence_parallel.globals import get_pg_manager, get_ring_sp_pg, get_ring_type, get_ulysses_sp_pg

from .hybrid_attn import HybridAttention
from .ring import (
@@ -167,17 +167,33 @@ def hybrid_attn_varlen_func_helper(
# print("rank", dist.get_rank(), "cu_seq_lens", cu_seq_lens)
# exit()

attn_output_unpad = ring_flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seq_lens,
max_seq_lens[0],
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=self.is_causal,
group=group,
)
ring_type = get_ring_type()
if ring_type == "ring_varlen":
attn_output_unpad = ring_flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seq_lens,
max_seq_lens[0],
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=self.is_causal,
group=group,
)
elif ring_type == "zigzag_ring_varlen":
attn_output_unpad = zigzag_ring_flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seq_lens,
max_seq_lens[0],
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=self.is_causal,
group=group,
)
else:
raise ValueError(f"Invalid ring_type: {ring_type}")

# print(dist.get_rank(), "finish ring_flash_attn_varlen_func")

(one additional changed file; file name not shown)
@@ -113,6 +113,7 @@ def forward(q, k, v, causal):
window_size=window_size,
alibi_slopes=alibi_slopes,
return_softmax=True and dropout_p > 0,
block_table=None,
)
return block_out, block_lse
