Unsloth kernels as a thunder executor #1174

Merged
carmocca merged 38 commits into main from carmocca/unsloth on Mar 21, 2024

Conversation

@carmocca (Contributor) commented Mar 20, 2024

RoPE equivalence check
from unsloth.kernels.rope_embedding import Fast_RoPE_Embedding
import torch

# litgpt's implementation

def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    head_size = x.size(-1)
    x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
    x2 = x[..., head_size // 2 :]  # (B, nh, T, hs/2)
    rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
    roped = (x * cos) + (rotated * sin)
    return roped.to(dtype=x.dtype)


def build_rope_cache(
    seq_len: int, n_elem: int, device = None, base: int = 10000, condense_ratio: int = 1
):
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
    seq_idx = torch.arange(seq_len, device=device) / condense_ratio
    idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)
    return torch.cos(idx_theta), torch.sin(idx_theta)


def test_rope(dtype):
    B, nh, T, hs = 2, 128, 4096, 16
    cos, sin = build_rope_cache(T, hs, device="cuda")
    q_unsloth = torch.rand((B, nh, T, hs), dtype=dtype, device='cuda', requires_grad=True)
    q_ref = q_unsloth.clone()
    q_unsloth.retain_grad()
    q_ref.retain_grad()

    # forward pass
    y_unsloth_q = Fast_RoPE_Embedding.apply(q_unsloth.transpose(1, 2).contiguous(), cos, sin).transpose(1, 2)
    y_ref_q = apply_rope(q_ref, cos, sin)

    print(y_unsloth_q.sum(), y_ref_q.sum())
    print("===" * 10)

    dy = torch.randn_like(y_ref_q, requires_grad=False)

    y_ref_q.backward(dy, retain_graph=True, inputs=q_ref)
    dts_ref = q_ref.grad.clone()

    y_unsloth_q.backward(dy, retain_graph=True)
    dts_unsloth = q_unsloth.grad.clone()

    print("Unsloth", dts_unsloth.sum() if dts_unsloth is not None else None)
    print("Reference", dts_ref.sum() if dts_ref is not None else None)
    print("===" * 10)


test_rope(torch.float32)
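
A stricter variant of the same comparison, as a sketch: torch.testing.assert_close checks the outputs and input gradients elementwise instead of eyeballing printed sums. It assumes the definitions above (apply_rope, build_rope_cache, Fast_RoPE_Embedding) are in scope; the tolerances are arbitrary guesses, not values from this PR.

def assert_rope_matches(dtype, atol=1e-4, rtol=1e-4):
    B, nh, T, hs = 2, 128, 4096, 16
    cos, sin = build_rope_cache(T, hs, device="cuda")
    q = torch.rand((B, nh, T, hs), dtype=dtype, device="cuda", requires_grad=True)
    # independent leaf copy so the two backward passes cannot contaminate each other
    q_ref = q.detach().clone().requires_grad_(True)

    # forward pass: same transposes as in the check above
    y = Fast_RoPE_Embedding.apply(q.transpose(1, 2).contiguous(), cos, sin).transpose(1, 2)
    y_ref = apply_rope(q_ref, cos, sin)
    torch.testing.assert_close(y, y_ref, atol=atol, rtol=rtol)

    # backward pass: push the same upstream gradient through both paths
    dy = torch.randn_like(y_ref)
    y.backward(dy)
    y_ref.backward(dy)
    torch.testing.assert_close(q.grad, q_ref.grad, atol=atol, rtol=rtol)


assert_rope_matches(torch.float32)
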
RMSNorm equivalence check
from unsloth.kernels.rms_layernorm import Fast_RMS_Layernorm
import torch

# litgpt's implementation
def rms_norm(x, weight, dim: int, eps: float, add_unit_offset: bool):
    dtype = x.dtype
    x = x.float()
    norm_x = torch.mean(x * x, dim=dim, keepdim=True)
    x_normed = x * torch.rsqrt(norm_x + eps)
    x_normed = x_normed.to(dtype=dtype)
    if add_unit_offset:
        return x_normed * (1 + weight)
    return x_normed * weight

# unsloth's implementation
def rms_norm_2(X, weight, eps):
    old_dtype = X.dtype
    XX = X.to(torch.float32)
    variance = XX.square().mean(-1, keepdim = True)
    variance += eps
    XX = XX * variance.rsqrt()
    X = XX.to(old_dtype) # Must preserve due to residual
    X *= weight
    return X

def test_layer_norm(M, N, dtype, eps=1e-5):
    weight = torch.rand((N,), dtype=dtype, device='cuda', requires_grad=True)
    x = torch.randn((M, N), dtype=dtype, device='cuda')
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    # forward pass
    y_unsloth = Fast_RMS_Layernorm.apply(x, weight, eps, False)
    # y_unsloth = rms_norm_2(x, weight, eps)
    y_ref = rms_norm(x, weight, -1, eps, False)

    # backward pass
    y_unsloth.backward(dy, retain_graph=True)
    dx_unsloth, dw_unsloth = [(t.grad.clone() if t.grad is not None else None) for t in [x, weight]]
    x.grad, weight.grad = None, None
    y_ref.backward(dy, retain_graph=True)
    dx_ref, dw_ref = [(t.grad.clone() if t.grad is not None else None) for t in [x, weight]]

    # compare
    print(y_unsloth.sum(), y_ref.sum())
    if not torch.allclose(y_unsloth, y_ref, atol=1e-2, rtol=0):
        print("y diff")
    print(dx_unsloth.sum(), dx_ref.sum())
    if not torch.allclose(dx_unsloth, dx_ref, atol=1e-2, rtol=0):
        print("dx diff")
    print(dw_unsloth.sum(), dw_ref.sum())
    if not torch.allclose(dw_unsloth, dw_ref, atol=1e-2, rtol=0):
        print("dw diff")


test_layer_norm(1151, 8192, torch.float32)
print("Test pased")
SwiGLU equivalence check
from unsloth.kernels.swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
import torch

# litgpt's implementation

def swiglu(x: torch.Tensor, w1, w2, w3) -> torch.Tensor:
    x_fc_1 = torch.nn.functional.linear(x, w1)
    x_fc_2 = torch.nn.functional.linear(x, w2)
    x = torch.nn.functional.silu(x_fc_1) * x_fc_2
    return torch.nn.functional.linear(x, w3)

def unsloth_swiglu(x: torch.Tensor, w1, w2, w3) -> torch.Tensor:
    x_fc_1 = torch.nn.functional.linear(x, w1)
    x_fc_2 = torch.nn.functional.linear(x, w2)
    x = SwigluFunction.apply(x_fc_1, x_fc_2)
    return torch.nn.functional.linear(x, w3)


class SwigluFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, e, g):
        ctx.save_for_backward(e, g)
        return swiglu_fg_kernel(e, g)

    @staticmethod
    def backward(ctx, dW):
        e, g = ctx.saved_tensors
        batch, seq_len, hd = e.shape
        e = e.view(-1, e.shape[-1])
        g = g.view(-1, g.shape[-1])
        DW, e, g = swiglu_DWf_DW_dfg_kernel(dW, e, g)
        e = e.view(batch, seq_len, hd)
        g = g.view(batch, seq_len, hd)
        return g, e


def test_swiglu(intermediate_size, n_embd, dtype, eps=1e-5):
    w1 = torch.rand((n_embd, intermediate_size), dtype=dtype, device='cuda', requires_grad=True).T
    w2 = torch.rand((n_embd, intermediate_size), dtype=dtype, device='cuda', requires_grad=True).T
    w3 = torch.rand((intermediate_size, n_embd), dtype=dtype, device='cuda', requires_grad=True).T
    x = torch.randn((2, 10, n_embd), dtype=dtype, device='cuda') # B, T, C
    dy = torch.randn_like(x)
    x.requires_grad_(True)
    w1.retain_grad()
    w2.retain_grad()
    w3.retain_grad()

    # forward pass
    y_unsloth = unsloth_swiglu(x, w1, w2, w3)
    y_ref = swiglu(x, w1, w2, w3)

    print(y_unsloth.sum(), y_ref.sum())
    if not torch.allclose(y_unsloth, y_ref, atol=1e-2, rtol=0):
        print("y diff")

    ts = [x, w1, w2, w3]
    # backward pass
    y_ref.backward(dy, retain_graph=True)
    dts_ref = [(t.grad.clone() if t.grad is not None else None) for t in ts]
    for t in ts:
        t.grad = None
    y_unsloth.backward(dy, retain_graph=True)
    dts = [(t.grad.clone() if t.grad is not None else None) for t in ts]

    # compare
    for dt_unsloth, dt_ref in zip(dts, dts_ref):
        print("Unsloth", dt_unsloth.sum() if dt_unsloth is not None else None)
        print("Reference", dt_ref.sum() if dt_ref is not None else None)
        print("===" * 10)


test_swiglu(115, 819, torch.float32)
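
For context, these checks back the Thunder executor this PR adds. A rough, hedged sketch of how one of the kernels could be exposed to Thunder is below; the names (unsloth_rms_norm and the meta/impl helpers) are illustrative, and the thunder.extend calls should be checked against the executor actually added in this PR rather than taken as its exact code.

import torch
from thunder.core.proxies import TensorProxy
from thunder.extend import OperatorExecutor, register_executor
from unsloth.kernels.rms_layernorm import Fast_RMS_Layernorm

unsloth_ex = OperatorExecutor("unsloth", version="0.1")
register_executor(unsloth_ex)


def unsloth_rms_norm_meta(x: TensorProxy, weight: TensorProxy, eps: float) -> TensorProxy:
    # meta function: only describes the output's shape/dtype while Thunder traces
    return TensorProxy(like=x)


def unsloth_rms_norm_impl(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # actual implementation: the same autograd.Function exercised in the check above
    return Fast_RMS_Layernorm.apply(x, weight, eps, False)


unsloth_rms_norm = unsloth_ex.register_operator(
    "unsloth_rms_norm", meta=unsloth_rms_norm_meta, fn=unsloth_rms_norm_impl
)

The registered operator would then be mapped onto the matching litgpt/thunder operations, with a corresponding gradient rule, so that jitted models pick up the Triton kernels; that wiring lives in the executor added by this PR.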

@carmocca marked this pull request as ready for review March 21, 2024 01:33
@carmocca merged commit 24d5eba into main Mar 21, 2024
1 check passed
@carmocca deleted the carmocca/unsloth branch March 21, 2024 01:33