Skip to content

Commit

Permalink
Add FusedLinearJSD (#300)
Browse files Browse the repository at this point in the history
## Summary
similar to the fuse linear CE.

It handles the forward and backward pass of the final linear layer via
JSD by avoiding the materialization of the large logits tensor. Since
JSD is the last layer, we can compute the gradient at the forward pass.



## Testing Done
Hidden size: 4096, Vocab size: 128256

![fused_linear_jsd_memory](https://github.com/user-attachments/assets/231303d1-4734-49fb-8c69-8e60730563c2)

![fused_linear_jsd_speed](https://github.com/user-attachments/assets/d83c85ec-ab29-44e0-a3d9-ad85acf4577d)

- Hardware Type: NVIDIA H100 80GB HBM3 (SXM5)
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Qingquan Song <[email protected]>
  • Loading branch information
Tcc0403 and qingquansong authored Oct 11, 2024
1 parent 9b10f48 commit ff6650b
Show file tree
Hide file tree
Showing 11 changed files with 792 additions and 42 deletions.
24 changes: 24 additions & 0 deletions benchmark/data/all_benchmark_data.csv
Original file line number Diff line number Diff line change
Expand Up @@ -481,3 +481,27 @@ jsd,torch,full,speed,ms,V,vocab size,16384,20.9442081451416,20.94247055053711,20
jsd,torch,full,speed,ms,V,vocab size,32768,42.113216400146484,42.113216400146484,42.113216400146484,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1
jsd,torch,full,speed,ms,V,vocab size,65536,83.9959716796875,83.9959716796875,83.9959716796875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1
jsd,torch,full,speed,ms,V,vocab size,131072,167.94175720214844,167.94175720214844,167.94175720214844,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1
fused_linear_jsd,liger,forward,speed,ms,BT,B x T,1024,110.02185821533203,110.02185821533203,110.02185821533203,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:18,0.3.1
fused_linear_jsd,liger,forward,speed,ms,BT,B x T,2048,124.14070129394531,124.14070129394531,124.14070129394531,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:18,0.3.1
fused_linear_jsd,liger,forward,speed,ms,BT,B x T,4096,143.15420532226562,143.15420532226562,143.15420532226562,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:18,0.3.1
fused_linear_jsd,liger,forward,speed,ms,BT,B x T,8192,180.90406799316406,180.90406799316406,180.90406799316406,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:18,0.3.1
fused_linear_jsd,torch,forward,speed,ms,BT,B x T,1024,9.556896209716797,9.550745964050293,9.576268196105957,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:20,0.3.1
fused_linear_jsd,torch,forward,speed,ms,BT,B x T,2048,18.73731231689453,18.732704162597656,18.737701416015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:20,0.3.1
fused_linear_jsd,torch,forward,speed,ms,BT,B x T,4096,37.830482482910156,37.80821990966797,37.85274124145508,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:20,0.3.1
fused_linear_jsd,torch,forward,speed,ms,BT,B x T,8192,75.15289306640625,75.15289306640625,75.15289306640625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:20,0.3.1
fused_linear_jsd,liger,full,speed,ms,BT,B x T,1024,111.16019439697266,111.16019439697266,111.16019439697266,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:24,0.3.1
fused_linear_jsd,liger,full,speed,ms,BT,B x T,2048,125.6825942993164,125.6825942993164,125.6825942993164,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:24,0.3.1
fused_linear_jsd,liger,full,speed,ms,BT,B x T,4096,144.00784301757812,144.00784301757812,144.00784301757812,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:24,0.3.1
fused_linear_jsd,liger,full,speed,ms,BT,B x T,8192,182.5832977294922,182.5832977294922,182.5832977294922,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:24,0.3.1
fused_linear_jsd,torch,full,speed,ms,BT,B x T,1024,25.977184295654297,25.968351364135742,25.989356994628906,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:27,0.3.1
fused_linear_jsd,torch,full,speed,ms,BT,B x T,2048,49.48417663574219,49.47330093383789,49.495052337646484,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:27,0.3.1
fused_linear_jsd,torch,full,speed,ms,BT,B x T,4096,98.31510162353516,98.31510162353516,98.31510162353516,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:27,0.3.1
fused_linear_jsd,torch,full,speed,ms,BT,B x T,8192,195.29539489746094,195.29539489746094,195.29539489746094,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:27,0.3.1
fused_linear_jsd,liger,full,memory,MB,BT,B x T,1024,4652.48486328125,4652.48486328125,4652.48486328125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:33,0.3.1
fused_linear_jsd,liger,full,memory,MB,BT,B x T,2048,5231.93798828125,5231.93798828125,5231.93798828125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:33,0.3.1
fused_linear_jsd,liger,full,memory,MB,BT,B x T,4096,6391.87548828125,6391.87548828125,6391.87548828125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:33,0.3.1
fused_linear_jsd,liger,full,memory,MB,BT,B x T,8192,8711.75,8711.75,8711.75,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:33,0.3.1
fused_linear_jsd,torch,full,memory,MB,BT,B x T,1024,10609.005859375,10609.005859375,10609.005859375,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1
fused_linear_jsd,torch,full,memory,MB,BT,B x T,2048,17146.009765625,17146.009765625,17146.009765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1
fused_linear_jsd,torch,full,memory,MB,BT,B x T,4096,30220.017578125,30220.017578125,30220.017578125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1
fused_linear_jsd,torch,full,memory,MB,BT,B x T,8192,56368.015625,56368.015625,56368.015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1
250 changes: 250 additions & 0 deletions benchmark/scripts/benchmark_fused_linear_jsd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
import torch
import triton
from utils import (
QUANTILES,
SingleBenchmarkRunInput,
SingleBenchmarkRunOutput,
_test_memory,
parse_benchmark_script_args,
run_benchmarks,
)

from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD


class TorchJSD(torch.nn.Module):
def __init__(self, beta: float = 0.5, dtype: torch.dtype = torch.float):
super(TorchJSD, self).__init__()
self.kl = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)
self.beta = beta
self.dtype = dtype

def forward(
self,
log_q: torch.tensor, # input
log_p: torch.tensor, # target
):
log_p, log_q = log_p.to(torch.float), log_q.to(torch.float)
log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1))
m = torch.lerp(torch.exp(log_p), torch.exp(log_q), self.beta)
loss = self.beta * self.kl(torch.log(m), log_p) + (1 - self.beta) * self.kl(
torch.log(m), log_q
)
return loss.to(self.dtype)


class TorchLMHeadJSD(torch.nn.Module):
"""Ground truth implementation of the linear fused with torch based jsd loss.
:param H: hidden size
:param V: vocab size
:param temperature: softmax temperature
:param beta: jsd beta
"""

def __init__(
self,
H: int,
V: int,
dtype: torch.dtype,
device: torch.device,
temperature: float = 1.0,
beta: float = 0.5,
):
super().__init__()
self.student_lin = torch.nn.Linear(
in_features=H, out_features=V, bias=False, dtype=dtype, device=device
)
self.teacher_lin = torch.nn.Linear(
in_features=H, out_features=V, bias=False, dtype=dtype, device=device
)
self.jsd = TorchJSD(beta, dtype=dtype)
self.temperature = temperature

def forward(self, student_input, teacher_input):
student_logits = self.student_lin(student_input)
teacher_logits = self.teacher_lin(teacher_input)
student_prob = torch.log_softmax(student_logits / self.temperature, dim=-1)
teacher_prob = torch.log_softmax(teacher_logits / self.temperature, dim=-1)

return self.jsd(student_prob, teacher_prob)


class LigerLMHeadJSD(torch.nn.Module):
def __init__(
self,
H: int,
V: int,
dtype: torch.dtype,
device: torch.device,
temperature: float = 1.0,
beta: float = 0.5,
):
super().__init__()
self.student_lin = torch.nn.Linear(
in_features=H, out_features=V, bias=False, dtype=dtype, device=device
)
self.teacher_lin = torch.nn.Linear(
in_features=H, out_features=V, bias=False, dtype=dtype, device=device
)
self.fused_jsd = LigerFusedLinearJSD(beta, temperature)

def forward(self, student_input, teacher_input):
return self.fused_jsd(
student_input,
self.student_lin.weight,
teacher_input,
self.teacher_lin.weight,
)


#############################################################################
# Test the memory consumption of the fused linear JSD
#############################################################################


def bench_memory_fused_linear_jsd(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
BT = input.x
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider

device = "cuda"
torch_lm_head_jsd = TorchLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device)
liger_lm_head_jsd = LigerLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device)

# init the linear in all FusedLinearJSDs with the same weights
torch_lm_head_jsd.student_lin.weight.data = (
liger_lm_head_jsd.student_lin.weight.data
) = torch.rand(V, H, device=device, dtype=dtype)
torch_lm_head_jsd.teacher_lin.weight.data = (
liger_lm_head_jsd.teacher_lin.weight.data
) = torch.rand(V, H, device=device, dtype=dtype)

student_input = torch.rand(BT, H, requires_grad=True, dtype=dtype, device=device)
teacher_input = torch.rand(BT, H, dtype=dtype, device=device)

def fwd():
if provider == "liger":
return liger_lm_head_jsd(student_input, teacher_input)
elif provider == "torch":
return torch_lm_head_jsd(student_input, teacher_input)

def full():
y = fwd()
y.backward()

mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)


# #############################################################################
# # Test the speed of the fused linear JSD
# #############################################################################


def bench_speed_fused_linear_jsd(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
BT = input.x
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
mode = input.kernel_operation_mode

dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider

device = "cuda"
torch_lm_head_jsd = TorchLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device)
liger_lm_head_jsd = LigerLMHeadJSD(H=H, V=V, dtype=dtype, device=device).to(device)

# init the linear in all FusedLinearJSDs with the same weights
torch_lm_head_jsd.student_lin.weight.data = (
liger_lm_head_jsd.student_lin.weight.data
) = torch.rand(V, H, device=device, dtype=dtype)
torch_lm_head_jsd.teacher_lin.weight.data = (
liger_lm_head_jsd.teacher_lin.weight.data
) = torch.rand(V, H, device=device, dtype=dtype)

student_input = torch.rand(BT, H, requires_grad=True, dtype=dtype, device=device)
teacher_input = torch.rand(BT, H, dtype=dtype, device=device)

def fwd():
if provider == "liger":
return liger_lm_head_jsd(student_input, teacher_input)
elif provider == "torch":
return torch_lm_head_jsd(student_input, teacher_input)

if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
rep=100,
quantiles=QUANTILES,
)
elif mode == "backward":
y = fwd()

ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(retain_graph=True),
grad_to_none=[
student_input,
torch_lm_head_jsd.student_lin.weight,
torch_lm_head_jsd.teacher_lin.weight,
],
rep=100,
quantiles=QUANTILES,
)
elif mode == "full":

def full():
y = fwd()
y.backward()

ms_50, ms_20, ms_80 = triton.testing.do_bench(
full,
rep=100,
quantiles=QUANTILES,
)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)


if __name__ == "__main__":
args = parse_benchmark_script_args()

common_configs = {
"kernel_name": "fused_linear_jsd",
"x_name": "BT",
"x_label": "B x T",
"x_values": [2**i for i in range(10, 14)],
"kernel_providers": ["liger", "torch"],
"extra_benchmark_configs": [
{"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16}
],
"overwrite": args.overwrite,
}

run_benchmarks(
bench_test_fn=bench_speed_fused_linear_jsd,
kernel_operation_modes=["forward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs
)
run_benchmarks(
bench_test_fn=bench_memory_fused_linear_jsd,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs
)
38 changes: 2 additions & 36 deletions src/liger_kernel/ops/cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import triton
import triton.language as tl

from liger_kernel.ops.utils import element_mul_kernel


@triton.jit
def liger_cross_entropy_kernel(
Expand Down Expand Up @@ -159,42 +161,6 @@ def liger_cross_entropy_kernel(
MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning


@triton.jit
def element_mul_kernel(
X_ptr,
X_stride,
grad_output_ptr,
n_cols,
BLOCK_SIZE: tl.constexpr,
):
"""
This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr.
The multiplication is performed in-place on the tensor pointed by X_ptr.
Parameters:
X_ptr: Pointer to the input tensor.
X_stride (int): The stride of the input tensor.
grad_output_ptr: Pointer to the gradient output value.
n_cols (int): The number of columns in the input tensor.
BLOCK_SIZE (int): The block size for Triton operations.
"""

# Get the program ID and convert it to int64 to avoid overflow
program_id = tl.program_id(0).to(tl.int64)

# Locate the start index
X_ptr += program_id * X_stride

# Load the gradient output value
grad_output = tl.load(grad_output_ptr)

# Perform the element-wise multiplication
for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)


def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction):
BT, V = _input.shape
n_rows = BT
Expand Down
6 changes: 2 additions & 4 deletions src/liger_kernel/ops/fused_linear_cross_entropy.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import torch
import triton

from liger_kernel.ops.cross_entropy import (
element_mul_kernel,
liger_cross_entropy_kernel,
)
from liger_kernel.ops.cross_entropy import liger_cross_entropy_kernel
from liger_kernel.ops.utils import element_mul_kernel

# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
Expand Down
Loading

0 comments on commit ff6650b

Please sign in to comment.