
Commit

add simulation_train_formulaic
SolenoidWGT committed Aug 5, 2024
1 parent fa23d1e commit b536cd3
Showing 51 changed files with 4,041 additions and 872 deletions.
5 changes: 5 additions & 0 deletions gen_profiler_data.py
@@ -0,0 +1,5 @@

from internlm.simulator.profiler.perf_comm import gen_perf

if __name__ == "__main__":
    gen_perf()
349 changes: 312 additions & 37 deletions internlm/core/context/parallel_context.py

Large diffs are not rendered by default.

229 changes: 229 additions & 0 deletions internlm/core/context/process_group_initializer_simplified.py
@@ -0,0 +1,229 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from copy import deepcopy
from enum import Enum

import torch
import torch.distributed as dist

from internlm.utils.timeout import LLM_NCCL_TIMEOUT
from internlm.core.context.process_group_initializer import ParallelMode

class ParallelMeta:
    def __init__(self, parallel_size, mode) -> None:
        self.parallel_size = parallel_size
        self.mode = mode

    def __str__(self) -> str:
        return self.__repr__()

    def __repr__(self) -> str:
        return f"{self.mode}, {self.parallel_size}"


def determine_intra_inter_size_of_group(one_group_indexs, intra_range=8):
    """Determine the intra size and inter size of a rank group."""
    group_size = len(one_group_indexs)
    if group_size == 1:
        return 1, 1
    else:
        group_stride = one_group_indexs[1] - one_group_indexs[0]
        if group_stride >= intra_range:
            return 1, group_size
        else:
            intra_size = intra_range // group_stride
            inter_size = group_size // intra_size
            return max(1, intra_size), max(1, inter_size)


class Initializer:
    def __init__(
        self,
        rank: int,
        world_size: int,
        fake_mode: bool = False,
        tensor_mode: str = "fsp",
        parallel_info: dict = None,
    ):
        """Initialize communication groups.

        Args:
            rank (int): global rank.
            world_size (int): world size.
            fake_mode (bool, optional): whether to skip creating actual NCCL communication
                groups (simulation mode). Defaults to False.
            tensor_mode (str, optional): tensor parallel mode, one of "isp"/"fsp"/"msp".
                Defaults to "fsp".
            parallel_info (dict, optional): mapping from parallel dimension name to its
                ParallelMeta. Defaults to None.
        """
        self.rank = rank
        self.world_size = world_size
        self.fake_mode = fake_mode
        self.tensor_mode = tensor_mode
        self.parallel_info = parallel_info

        # assert sequence_parallel_size == tensor_parallel_size
        super().__init__()

    def init_dist_group(self, use_cpu: bool = False):
        parallel_info, world_size = self.parallel_info, self.world_size

        wp_size = parallel_info["wp"].parallel_size
        # tp_size = parallel_info["tp"].parallel_size
        # pp_size = parallel_info["pp"].parallel_size
        wdp_size = parallel_info["wdp"].parallel_size
        zero1_size = parallel_info["zero1"].parallel_size
        ep_size = parallel_info["ep"].parallel_size
        edp_size = parallel_info["edp"].parallel_size

        re_group_args = {}

        # stride_order defines the placement priority of the process groups.
        stride_order = ["tp", "dp", "pp"]
        strides = {}

        def assemble_group(all_ranks, dim_name):
            for ranks in all_ranks:
                if self.fake_mode or len(all_ranks) == 1:
                    group, group_cpu = None, None
                else:
                    group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
                    if use_cpu:
                        group_cpu = (
                            dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
                            if dist.get_backend() != "gloo"
                            else group
                        )
                    else:
                        group_cpu = None

                if self.rank in ranks:
                    local_rank = ranks.tolist().index(self.rank)
                    group_world_size = len(ranks)
                    process_group = group
                    cpu_group = group_cpu
                    ranks_in_group = ranks.tolist()

            new_all_ranks = []
            for ranks in all_ranks:
                new_all_ranks.append(ranks.tolist())

            return (
                local_rank,
                group_world_size,
                process_group,
                cpu_group,
                ranks_in_group,
                new_all_ranks,
                parallel_info[dim_name].mode,
            )

        def split_orthogonal_sub_group(dim_name, indexs, size, stride):
            assert size <= world_size, f"{dim_name} size: {size} should not exceed world size: {world_size}!"

            indexs = indexs.reshape(-1, stride).T.reshape(-1)
            all_ranks = torch.split(indexs, size)

            return indexs, assemble_group(all_ranks, dim_name)

        def split_horizontal_sub_group(dim_name, indexs, size, stride):
            assert size <= world_size, f"{dim_name} size: {size} should not exceed world size: {world_size}!"

            indexs = indexs.reshape(stride, -1).reshape(-1)
            all_ranks = torch.split(indexs, size)

            return indexs, assemble_group(all_ranks, dim_name)

        count = 0
        for dim_name in stride_order:
            parallel_size = parallel_info[dim_name].parallel_size
            if parallel_size == 1:
                continue

            if count == 0:
                strides[dim_name] = 1
            else:
                strides[dim_name] = strides[old_dim_name] * parallel_info[old_dim_name].parallel_size

            father_indexs, group_args = split_orthogonal_sub_group(
                dim_name, torch.arange(start=0, end=world_size), size=parallel_size, stride=strides[dim_name]
            )
            re_group_args[dim_name] = group_args

            if dim_name == "dp":
                # EP, EDP, and ZeRO are auxiliary parallel modes within DP.
                if wp_size == 1 and self.tensor_mode != "isp":
                    re_group_args["zero1"] = split_horizontal_sub_group(
                        "zero1", father_indexs, zero1_size, zero1_size
                    )[1]
                    print(f"re_group_args['zero1']: {re_group_args['zero1']}")

                # The MoE expert group is a subgroup of the data parallel group.
                if ep_size > 1:
                    ep_indexs, group_ep_args = split_horizontal_sub_group(
                        "ep", father_indexs, size=ep_size, stride=ep_size
                    )
                    re_group_args["ep"] = group_ep_args
                    re_group_args["edp"] = split_orthogonal_sub_group("edp", ep_indexs, edp_size, ep_size)[1]

                one_group_indexs = group_args[4]  # ranks of one group
                intra_dp_size, inter_dp_size = determine_intra_inter_size_of_group(one_group_indexs)

                # These sizes are later used when drawing the heatmap.
                parallel_info["intra_dp"].parallel_size = intra_dp_size
                parallel_info["inter_dp"].parallel_size = inter_dp_size

                # The only parallel group with a higher priority than DP is TP,
                # see: stride_order = ["tp", "dp", "pp"].
                high_priority_group = parallel_info["tp"].parallel_size

                re_group_args["intra_dp"] = split_horizontal_sub_group(
                    "intra_dp", father_indexs, size=intra_dp_size, stride=high_priority_group
                )[1]

                re_group_args["inter_dp"] = split_orthogonal_sub_group(
                    "inter_dp", father_indexs, size=inter_dp_size, stride=intra_dp_size
                )[1]

            elif dim_name == "tp":
                # The ISP case is somewhat complex. With ISP, the head/embedding is partitioned
                # in the Megatron-TP manner and uses the TP communication group, while the other
                # modules are partitioned along WP and reuse the TP communication group
                # (but run DeepSpeed-Ulysses instead of Megatron-TP). Therefore:
                #   - for head/embedding, the ZeRO-1 communication group is orthogonal to the TP group;
                #   - for the other modules, the ZeRO-1 communication group is the WDP communication group
                #     (orthogonal to the WP/TP communication groups).
                # FIXME: Can this be further simplified?
                if self.tensor_mode == "isp":
                    if wp_size == 1:
                        re_group_args["zero1"] = split_horizontal_sub_group(
                            "zero1", father_indexs, zero1_size, zero1_size
                        )[1]
                    else:
                        wp_index, re_group_args["wp"] = split_horizontal_sub_group(
                            "wp", torch.arange(start=0, end=world_size), wp_size, wp_size
                        )
                        re_group_args["wdp"] = split_orthogonal_sub_group("wdp", wp_index, wdp_size, wp_size)[1]
                        re_group_args["zero1"] = split_orthogonal_sub_group(
                            "zero1", father_indexs, zero1_size, wp_size
                        )[1]

            count += 1
            old_dim_name = dim_name

        for name, info in parallel_info.items():
            if info.parallel_size == 1:
                # If the degree of parallelism is 1, for logical consistency,
                # we still need to create a logical communication group.
                re_group_args[name] = assemble_group([torch.tensor([self.rank])], name)

        # If two groups are orthogonal to each other and one group has a parallelism degree of 1,
        # then the parallelism degree of the other group is world_size.
        if parallel_info["wp"].parallel_size == 1:
            re_group_args["wdp"] = tuple(list(deepcopy(re_group_args["dp"]))[0:-1] + [parallel_info["wdp"].mode])

        return re_group_args
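
For readers unfamiliar with the stride trick used by split_orthogonal_sub_group above, the following minimal sketch (illustrative only, not part of the commit) reproduces the grouping for an assumed world size of 8 with tp=2 and dp=4. Per stride_order = ["tp", "dp", "pp"], tp gets stride 1 and dp gets stride tp_size = 2; split_horizontal_sub_group differs only in keeping the ranks contiguous instead of interleaving them.

import torch

def split_orthogonal(indexs, size, stride):
    # Same reshape(-1, stride).T.reshape(-1) interleaving as in the diff:
    # ranks that are `stride` apart land in the same group.
    indexs = indexs.reshape(-1, stride).T.reshape(-1)
    return [r.tolist() for r in torch.split(indexs, size)]

world = torch.arange(8)                                 # assumed world_size = 8
tp_groups = split_orthogonal(world, size=2, stride=1)   # [[0, 1], [2, 3], [4, 5], [6, 7]]
dp_groups = split_orthogonal(world, size=4, stride=2)   # [[0, 2, 4, 6], [1, 3, 5, 7]]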
15 changes: 15 additions & 0 deletions internlm/core/context/random.py
@@ -2,6 +2,7 @@
 # -*- encoding: utf-8 -*-
 # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context
 
+import os
 from contextlib import contextmanager
 
 from torch import Tensor
@@ -10,6 +11,8 @@
 
 from .process_group_initializer import ParallelMode
 
+fake_mode = "fake_mode" in os.environ
+
 internlm_accelerator = get_accelerator()
 
 
@@ -35,11 +38,15 @@ def seed_states(self):
 
     def set_state(self, parallel_mode: ParallelMode, state: Tensor):
         """Sets the state of the seed manager for `parallel_mode`."""
+        if fake_mode:
+            return
         assert parallel_mode in self._seed_states, f"{parallel_mode} not found in seed manager"
         self._seed_states[parallel_mode] = state
 
     def set_mode(self, parallel_mode: ParallelMode, update_rng_current_mode: bool = True):
         """Sets the current mode of the seed manager."""
+        if fake_mode:
+            return
         if update_rng_current_mode and self.current_mode:
             # save state for current mode
             self._seed_states[self._current_mode] = internlm_accelerator.get_rng_state()
@@ -50,6 +57,8 @@ def set_mode(self, parallel_mode: ParallelMode, update_rng_current_mode: bool =
 
     def add_seed(self, parallel_mode: ParallelMode, seed: int, overwrite: bool = False):
         """Adds a seed to the seed manager for `parallel_mode`."""
+        if fake_mode:
+            return
         assert isinstance(parallel_mode, ParallelMode), "Invalid ParallelMode"
         if not overwrite:
             assert parallel_mode not in self._seed_states, f"Seed for {parallel_mode} exists"
@@ -63,6 +72,8 @@ def add_seed(self, parallel_mode: ParallelMode, seed: int, overwrite: bool = Fal
         internlm_accelerator.set_rng_state(current_state)
 
     def reset(self):
+        if fake_mode:
+            return
         self._current_mode = None
         self._seeds = {}
         self._seed_states = {}
@@ -131,3 +142,7 @@ def seed(parallel_mode: ParallelMode):
         yield _SEED_MANAGER.set_mode(parallel_mode)
     finally:
         _SEED_MANAGER.set_mode(current_mode)
+
+
+def reset_seed():
+    _SEED_MANAGER.reset()
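
A rough usage sketch of the new fake_mode switch and reset_seed helper (assumed usage for simulation runs, not shown in the commit): the flag is read from os.environ at import time, so it has to be set before the module is imported.

import os

# Must be set before internlm.core.context.random is imported, because the
# module evaluates `"fake_mode" in os.environ` once at import time.
os.environ["fake_mode"] = "1"

from internlm.core.context.random import reset_seed

# Under fake_mode, SeedManager.set_state/set_mode/add_seed/reset all return early,
# so no accelerator RNG state is touched; reset_seed() delegates to
# _SEED_MANAGER.reset(), which is likewise a no-op while the flag is set.
reset_seed()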
12 changes: 6 additions & 6 deletions internlm/core/parallel/comm/isp.py
@@ -523,7 +523,7 @@ def register_prerequisite_for_forward_prefetch_hooks(self, prerequisite_func: Ca
     def weight_hook(
         self, tensor: torch.Tensor, async_op: bool = False, module: nn.Module = None, is_bias: bool = False
     ) -> torch.Tensor:
-        if dist.get_world_size(self.process_group) <= 1:
+        if gpc.get_group_size(self.process_group) <= 1:
             return tensor
 
         if not self.overlap:
@@ -545,7 +545,7 @@ def grad_hook(
         reduce_op: dist.ReduceOp = dist.ReduceOp.AVG,
         is_bias: bool = False,
     ) -> Tuple[torch.Tensor, AsyncCommHandle]:
-        if dist.get_world_size(self.process_group) <= 1:
+        if gpc.get_group_size(self.process_group) <= 1:
             return tensor, DUMMY_HANDLE_CONST
 
         if not self.overlap:
@@ -573,7 +573,7 @@ def grad_hook(
             result, handle = (
                 self._get_constant_zero(
                     (
-                        tensor.shape[0] // dist.get_world_size(self.process_group),
+                        tensor.shape[0] // gpc.get_group_size(self.process_group),
                         *tensor.shape[1:],
                     )
                 ),
@@ -634,10 +634,10 @@ def forward(ctx, group: dist.ProcessGroup, input_: torch.Tensor, scatter_idx: in
         ctx.scatter_idx = scatter_idx
         ctx.gather_idx = gather_idx
 
-        if dist.get_world_size(group) <= 1:
+        if gpc.get_group_size(group) <= 1:
             return input_
 
-        seq_world_size = dist.get_world_size(group)
+        seq_world_size = gpc.get_group_size(group)
 
         input_list = [t.contiguous() for t in torch.tensor_split(input_, seq_world_size, scatter_idx)]
         output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
@@ -647,7 +647,7 @@ def forward(ctx, group: dist.ProcessGroup, input_: torch.Tensor, scatter_idx: in
 
     @staticmethod
     def backward(ctx, *grad_output: torch.Tensor) -> Tuple[None, torch.Tensor, None, None]:
-        if dist.get_world_size(ctx.group) <= 1:
+        if gpc.get_group_size(ctx.group) <= 1:
             return (None, *grad_output, None, None)
 
         return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None)
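
The switch from dist.get_world_size to gpc.get_group_size looks intended to keep these paths usable when process groups are only simulated: in fake mode the groups created by the initializer above are None, so torch.distributed cannot be queried directly. The diff does not show gpc.get_group_size itself; the sketch below is only a guess at the kind of wrapper it might be, and every name in it is hypothetical.

import torch.distributed as dist

class SimulationAwareContext:
    """Hypothetical stand-in for the global context's group-size lookup."""

    def __init__(self, fake_mode: bool, recorded_sizes: dict):
        self.fake_mode = fake_mode
        # Sizes remembered when the (possibly fake) groups were assembled.
        self.recorded_sizes = recorded_sizes

    def get_group_size(self, group) -> int:
        # In fake mode there is no NCCL group to ask, so answer from bookkeeping.
        if self.fake_mode or group is None:
            return self.recorded_sizes.get(group, 1)
        return dist.get_world_size(group)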