WIP: TensorParallel with new strategy #1421

Draft · wants to merge 1 commit into base: main
litgpt/generate/base.py (2 changes: 1 addition & 1 deletion)

@@ -68,7 +68,7 @@ def next_token(model: GPT, input_pos: torch.Tensor, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
     return next.to(dtype=x.dtype)
 
 
-@torch.inference_mode()
+@torch.no_grad()
 def generate(
     model: GPT,
     prompt: torch.Tensor,
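Note: `torch.inference_mode()` marks every tensor created under it as an inference tensor, which can never again be mutated or touched by autograd bookkeeping outside the block; `torch.no_grad()` only disables gradient tracking. The PR does not state the motivation, but a reasonable assumption is that the DTensor machinery introduced below is incompatible with inference tensors. A standalone sketch of the behavioral difference (not from the PR):

```python
import torch

with torch.no_grad():
    a = torch.ones(2)  # an ordinary tensor; grad tracking was merely disabled
a += 1  # fine

with torch.inference_mode():
    b = torch.ones(2)  # an "inference tensor"
try:
    b += 1  # in-place updates outside the inference_mode block are rejected
except RuntimeError as err:
    print(err)
```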
litgpt/generate/tp.py (172 changes: 79 additions & 93 deletions)

@@ -3,7 +3,6 @@
 import logging
 import sys
 import time
-from functools import partial
 from pathlib import Path
 from typing import Literal, Optional, Union
@@ -12,83 +11,86 @@
 import torch._dynamo.config
 import torch._inductor.config
 from lightning.fabric.plugins import BitsandbytesPrecision
+from lightning.fabric.strategies import ModelParallelStrategy
 from lightning.fabric.utilities import rank_zero_only
-from torch.distributed._functional_collectives import all_reduce
 
 import litgpt.generate.base as generate_base
 from litgpt import GPT, Config, Tokenizer
 from litgpt.model import CausalSelfAttention, GptNeoxMLP, LLaMAMLP, LLaMAMoE
-from litgpt.utils import CLI, check_valid_checkpoint_dir, get_default_supported_precision
-
-
-def tensor_parallel_linear(fabric: L.Fabric, linear: torch.nn.Linear, style: str) -> None:
-    world_size = fabric.world_size
-    dim, attr = {"colwise": (0, "out_features"), "rowwise": (1, "in_features")}[style]
-    size = getattr(linear, attr)
-    if size % world_size != 0:
-        raise ValueError(
-            f"This linear's {attr} value ({size}) is not evenly divisible by the world size ({world_size})"
-        )
-
-    shard = torch.tensor_split(linear.weight, world_size, dim=dim)[fabric.global_rank]
-    # overwrite `.data` instead of recreating the parameter for quantization (bitsandbytes) support.
-    # the bitsandbytes linear classes use custom `torch.nn.Parameter` subclasses
-    linear.weight.data = shard
-    setattr(linear, attr, shard.size(dim))
-
-    if linear.bias is not None and dim == 0:
-        shard = torch.tensor_split(linear.bias, world_size)[fabric.global_rank]
-        linear.bias = torch.nn.Parameter(shard, requires_grad=linear.bias.requires_grad)
-
-
-def tensor_parallel_mlp(fabric: L.Fabric, mlp: Union[GptNeoxMLP, LLaMAMLP, LLaMAMoE]) -> None:
+from litgpt.utils import check_valid_checkpoint_dir, get_default_supported_precision
+from torch.distributed._tensor import Replicate, Shard
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    RowwiseParallel,
+    SequenceParallel,
+    parallelize_module,
+)
+
+
+def tensor_parallel_mlp(mesh, mlp: Union[GptNeoxMLP, LLaMAMLP, LLaMAMoE]) -> None:
+    plan = {}
     if isinstance(mlp, LLaMAMLP):
-        tensor_parallel_linear(fabric, mlp.fc_1, "colwise")
-        tensor_parallel_linear(fabric, mlp.fc_2, "colwise")
-        tensor_parallel_linear(fabric, mlp.proj, "rowwise")
-        mlp.register_forward_hook(partial(all_reduce_output, fabric.world_size))
+        plan["fc_1"] = ColwiseParallel()
+        plan["fc_2"] = ColwiseParallel()
+        plan["proj"] = RowwiseParallel()
     elif isinstance(mlp, GptNeoxMLP):
-        tensor_parallel_linear(fabric, mlp.fc, "colwise")
-        tensor_parallel_linear(fabric, mlp.proj, "rowwise")
-        mlp.register_forward_hook(partial(all_reduce_output, fabric.world_size))
+        plan["fc"] = ColwiseParallel()
+        plan["proj"] = RowwiseParallel()
     elif isinstance(mlp, LLaMAMoE):
         # we use expert slicing across ranks, alternatively, we could create an expert parallelism group
         # when the number of experts is a multiple of the world size
         for expert in mlp.experts:
-            tensor_parallel_mlp(fabric, expert)
+            tensor_parallel_mlp(mesh, expert)
     else:
         raise NotImplementedError
 
-
-def tensor_parallel_attn(fabric: L.Fabric, attn: CausalSelfAttention) -> None:
-    tensor_parallel_linear(fabric, attn.attn, "colwise")
-    tensor_parallel_linear(fabric, attn.proj, "rowwise")
-    attn.register_forward_hook(partial(all_reduce_output, fabric.world_size))
-
-
-def all_reduce_output(world_size: int, module: torch.nn.Module, ins, outs) -> torch.Tensor:
-    return all_reduce(outs, "sum", list(range(world_size)))
-
-
-def tensor_parallel(fabric: L.Fabric, model: GPT) -> GPT:
+    parallelize_module(mlp, mesh, plan)
+
+
+def tensor_parallel_attn(mesh, attn: CausalSelfAttention) -> None:
+    plan = {
+        "attn": ColwiseParallel(),
+        "proj": RowwiseParallel(),
+    }
+    parallelize_module(attn, mesh, plan)
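Note: the removed helpers sharded `weight.data` by hand and bolted an all-reduce onto each block via `register_forward_hook`; the plan-based API expresses the same layout declaratively. `ColwiseParallel` shards a linear layer's output features across the mesh, `RowwiseParallel` shards the input features, so a colwise layer feeding a rowwise layer keeps the intermediate activation sharded and only the rowwise output is all-reduced, with the collectives inserted automatically. A self-contained sketch of the same pattern on a toy module (the `ToyMLP` class and file name are illustrative, not from the PR):

```python
# run with: torchrun --nproc_per_node=2 toy_tp.py
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module


class ToyMLP(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(16, 64)
        self.proj = nn.Linear(64, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(torch.relu(self.fc(x)))


# a 1-D mesh over all ranks; "cpu" keeps the demo runnable without GPUs
mesh = init_device_mesh("cpu", (int(os.environ["WORLD_SIZE"]),))
model = ToyMLP()
# fc's output dim is sharded colwise, proj's input dim rowwise, so only
# proj's output triggers an all-reduce
parallelize_module(model, mesh, {"fc": ColwiseParallel(), "proj": RowwiseParallel()})
out = model(torch.randn(4, 16))  # output is replicated on every rank
print(out.shape)
```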


+def parallelize(model, device_mesh):
+    tp_mesh = device_mesh["tensor_parallel"]
+    dp_mesh = device_mesh["data_parallel"]
+
+    assert tp_mesh.size() > 1
+    assert dp_mesh.size() == 1
+
+    plan = {
+        "transformer.wte": RowwiseParallel(input_layouts=Replicate()),
+        "transformer.ln_f": SequenceParallel(),
+        "lm_head": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
+    }
+    parallelize_module(model, tp_mesh, plan)
+
     for block in model.transformer.h:
-        tensor_parallel_mlp(fabric, block.mlp)
-        tensor_parallel_attn(fabric, block.attn)
+        plan = {
+            # "norm_1": SequenceParallel(),
+            # "norm_2": SequenceParallel(),
+        }
+        # parallelize_module(block, tp_mesh, plan)
+        tensor_parallel_mlp(tp_mesh, block.mlp)
+        tensor_parallel_attn(tp_mesh, block.attn)
 
     # update the config values to the shard sizes
     # this is only relevant for `tensor_parallel_attn`, but it needs to run only once
-    world_size = fabric.world_size
     attrs = ["n_head", "n_embd", "n_query_groups"]
     for attr in attrs:
         size = getattr(model.config, attr)
-        if size % world_size != 0:
-            raise ValueError(f"This {attr} value ({size}) is not evenly divisible by the world size ({world_size})")
-        setattr(model.config, attr, size // world_size)
+        if size % tp_mesh.size() != 0:
+            raise ValueError(f"This {attr} value ({size}) is not evenly divisible by the world size ({tp_mesh.size()})")
+        setattr(model.config, attr, size // tp_mesh.size())
 
     return model
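Note: `parallelize` is the `parallelize_fn` hook for Fabric's `ModelParallelStrategy` (see below); it receives the whole model plus the device mesh that the strategy builds at launch, and named sub-meshes are obtained by indexing. The top-level plan shards the embedding and the LM head and runs the final norm sequence-parallel, while the per-block norms are left commented out for now. Rewriting the config to per-shard values matters because litgpt's attention reads head counts from `model.config`: with, say, `n_head=32`, `n_embd=4096`, `n_query_groups=32` and a tensor-parallel size of 4, each rank ends up with `n_head=8`, `n_embd=1024`, `n_query_groups=8`. A sketch of building an equivalent mesh by hand (the 1×4 shape assumes a single 4-GPU node; file name illustrative):

```python
# run with: torchrun --nproc_per_node=4 mesh_demo.py
from torch.distributed.device_mesh import init_device_mesh

# one data-parallel group of size 1 and one tensor-parallel group of size 4,
# matching the asserts in `parallelize`
mesh = init_device_mesh("cuda", (1, 4), mesh_dim_names=("data_parallel", "tensor_parallel"))
tp_mesh = mesh["tensor_parallel"]
assert tp_mesh.size() == 4
```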


-@torch.inference_mode()
+@torch.no_grad()
 def main(
     prompt: str = "What food do llamas eat?",
     *,
@@ -147,8 +149,8 @@ def main(
         plugins = BitsandbytesPrecision(quantize[4:], dtype)
         precision = None
 
-    # set "ddp" as the strategy for the launching functionality, but there's no data-parallelism
-    fabric = L.Fabric(devices="auto", strategy="ddp", precision=precision, plugins=plugins)
+    strategy = ModelParallelStrategy(parallelize_fn=parallelize)
+    fabric = L.Fabric(devices="auto", strategy=strategy, precision=precision, plugins=plugins)
     fabric.launch()
 
     check_valid_checkpoint_dir(checkpoint_dir)
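Note: the `"ddp"` strategy was only ever used here for its multi-process launcher; `ModelParallelStrategy` additionally wires the `parallelize` function into model setup. A minimal usage sketch (the device count and config name are illustrative, not from the PR):

```python
import lightning as L
from lightning.fabric.strategies import ModelParallelStrategy

from litgpt import GPT, Config

strategy = ModelParallelStrategy(parallelize_fn=parallelize)
fabric = L.Fabric(devices=4, strategy=strategy)
fabric.launch()
with fabric.init_module(empty_init=True):
    model = GPT(Config.from_name("Llama-2-7b-hf"))
model = fabric.setup(model)  # invokes parallelize(model, device_mesh) on each rank
```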
@@ -164,45 +166,27 @@
 
     fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
     t0 = time.perf_counter()
-    # cannot use `init_module` because if bitsandbytes is used, the Linear layers will be replaced
-    # which means that the weights will get quantized on cuda:0 on checkpoint load. we need to load and then convert
-    # still, use init_tensor for the precision
-    with fabric.init_tensor(), torch.device("meta"):
+    with fabric.init_module(empty_init=True):
         model = GPT(config)
     fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
 
-    # sequentially do: load the checkpoint on CPU -> quantize -> apply tp -> move to device
-    # so that the CPU RAM doesn't OOM with larger models
-    for rank in range(fabric.world_size):
-        if fabric.global_rank == rank:
-            t0 = time.perf_counter()
-            state_dict = torch.load(str(checkpoint_path), mmap=True, map_location="cpu")
-            model.load_state_dict(state_dict, assign=True)
-            print(f"[{rank}] Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
-
-            # cannot use `.setup_module` because it will wrap with DDP
-            model = fabric._precision.convert_module(model)
-
-            t0 = time.perf_counter()
-            model = tensor_parallel(fabric, model)
-            print(
-                f"[{rank}] Time to tensor-parallelize the model: {time.perf_counter() - t0:.02f} seconds.",
-                file=sys.stderr,
-            )
-
-            with fabric.init_tensor():
-                # set the max_seq_length to limit the memory usage to what we need
-                model.max_seq_length = max_returned_tokens
-                # the rope cache which is on meta device
-                model.cos, model.sin = model.rope_cache()
-                # enable the kv cache
-                model.set_kv_cache(batch_size=1)
-            model.eval()
-
-            t0 = time.perf_counter()
-            model = fabric.to_device(model)
-            print(f"[{rank}] Time to move the model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
-        fabric.barrier()
+    model = fabric.setup(model)
+
+    t0 = time.perf_counter()
+    fabric.load_raw(checkpoint_path, model)
+    fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)

+    with fabric.init_tensor():
+        # set the max_seq_length to limit the memory usage to what we need
+        model.max_seq_length = max_returned_tokens
+        # the rope cache which is on meta device
+        model.cos, model.sin = model.rope_cache()
+        # enable the kv cache
+        model.set_kv_cache(batch_size=1)
+    model.eval()
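Note: capping `model.max_seq_length` before `set_kv_cache` bounds how large the per-block KV buffers get, since litgpt allocates them up front, shaped roughly `(batch, n_query_groups, max_seq_length, head_size)`. Back-of-envelope sizing (the numbers are illustrative of a 7B Llama-style config, not taken from the PR):

```python
# 2 buffers (k and v) per transformer block
n_layer, n_query_groups, head_size = 32, 32, 128
batch, max_seq, bytes_per_elem = 1, 4096, 2  # bf16
cache_bytes = 2 * n_layer * batch * n_query_groups * max_seq * head_size * bytes_per_elem
print(f"{cache_bytes / 1e9:.2f} GB")  # ~2.15 GB
```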

+    t0 = time.perf_counter()
 
     if compile:
         torch._dynamo.config.automatic_dynamic_shapes = True
@@ -214,7 +198,7 @@ def main(
     for i in range(num_samples):
         t0 = time.perf_counter()
         y = generate_base.generate(
-            model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id
+            model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id
         )
         t = time.perf_counter() - t0
         for block in model.transformer.h:
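Note: the call now also forwards `top_p` (nucleus sampling); its addition to `main`'s signature sits in a collapsed part of the diff. For reference, a minimal sketch of top-p filtering (standalone, not litgpt's exact implementation):

```python
import torch


def top_p_filter(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, idx = torch.sort(probs, descending=True, dim=-1)
    cum = torch.cumsum(sorted_probs, dim=-1)
    # drop tokens whose preceding cumulative mass already exceeds top_p;
    # the highest-probability token always survives
    sorted_probs[cum - sorted_probs > top_p] = 0.0
    filtered = torch.zeros_like(probs).scatter(-1, idx, sorted_probs)
    return filtered / filtered.sum(dim=-1, keepdim=True)


next_id = torch.multinomial(top_p_filter(torch.randn(1, 50), top_p=0.9), num_samples=1)
```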
@@ -226,3 +210,5 @@ def main(
     )
     if fabric.device.type == "cuda":
         fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)
+
+    torch.distributed.destroy_process_group()