In-place operations in triton kernel might result in incorrect gradient calculations #272

Open
Tcc0403 opened this issue Sep 26, 2024 · 3 comments · May be fixed by #273
Labels
bug Something isn't working

Comments

Collaborator

Tcc0403 commented Sep 26, 2024

🐛 Describe the bug

#254 #262 (comments)

PyTorch’s autograd system records operations on tensors to build a computational graph, which is used to compute gradients. Some operations save tensors (their inputs or outputs) for the backward pass, so when an in-place operation modifies such a tensor, autograd must detect it; otherwise the backward pass would silently use the modified values.

https://pytorch.org/docs/stable/autograd.html#in-place-correctness-checks

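Each tensor in PyTorch has an internal version counter that is incremented every time an in-place operation is performed on it. When a tensor is saved for backward, autograd records its version and raises an error during the backward pass if the version has changed since.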

https://github.com/pytorch/pytorch/blob/190e09d8b6a13f789b143f0fbd1325f924550967/c10/core/TensorImpl.h#L382

Since a Triton kernel writes to tensor memory directly instead of calling a PyTorch in-place operation, the version counter does not change when the kernel modifies a tensor in place. As a result, PyTorch's "In-place correctness checks" mechanism cannot work, and the user sees no error even when the gradients are wrong.
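
The difference can be observed on CPU without Triton. This is a minimal sketch that assumes writing through `Tensor.data` is a fair stand-in for a Triton store: both change the tensor's memory without going through a PyTorch in-place op on that tensor. `Tensor._version` is an internal attribute, used here only for inspection.

```python
import torch


def version_after_torch_inplace() -> int:
    t = torch.zeros(3)
    t.add_(1.0)  # PyTorch in-place op: increments t's version counter
    t.mul_(2.0)  # another in-place op, another increment
    return t._version


def version_after_untracked_write() -> int:
    t = torch.zeros(3)
    # .data shares t's storage but carries its own version counter, so this
    # write changes t's values while t._version stays put. A Triton kernel
    # storing into t's memory looks the same from autograd's point of view.
    t.data.fill_(5.0)
    return t._version


print(version_after_torch_inplace())    # 2
print(version_after_untracked_write())  # 0
```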

Reproduce

import torch
import torch.nn.functional as F

from liger_kernel.transformers.functional import liger_cross_entropy


def run_inplace_experiment(logits_p, logits_q, cross_entropy_fn):
    _p = logits_p.clone().detach().requires_grad_(True)
    _p.retain_grad()
    softmax = torch.nn.Softmax(dim=-1)
    p = softmax(_p)
    p.retain_grad()
    loss = cross_entropy_fn(p, logits_q)
    loss.backward(retain_graph=True)

    print(f"Cross Entropy Loss: {loss.item()}")
    print(f"Input _p: {_p}")
    print(f"Input logits_q: {logits_q}")
    print(f"Gradients of p (batch item 0): {p.grad[0]}")
    print(f"Gradients of _p (batch item 0): {_p.grad[0]}")


torch.manual_seed(0)
logits_p = torch.randn(8, 8, requires_grad=True, device="cuda")
logits_q = torch.randint(0, 8, (8,), device="cuda", dtype=torch.long)


run_inplace_experiment(logits_p, logits_q, cross_entropy_fn=F.cross_entropy)

print()
print("LIGER:")
run_inplace_experiment(logits_p, logits_q, cross_entropy_fn=liger_cross_entropy)

$ python3 inplace_bug.py
Cross Entropy Loss: 2.08567214012146
Input _p: tensor([[-0.9247, -0.4253, -2.6438,  0.1452, -0.1209, -0.5797, -0.6229, -0.3284],
        [-1.0745, -0.3631, -1.6711,  2.2655,  0.3117, -0.1842,  1.2866,  1.1820],
        [-0.1271,  1.2169,  1.4353,  1.0605, -0.4941, -1.4244, -0.7244, -1.2973],
        [ 0.0697, -0.0074,  1.8969,  0.6878, -0.0779, -0.8373,  1.3506, -0.2879],
        [-0.5965, -0.3283, -0.9086, -0.8059, -0.7407, -0.0504,  0.5435,  1.5150],
        [ 0.0141,  0.4532,  1.6349,  0.7124, -0.1806,  1.0252, -1.4622, -0.7554],
        [-0.1836,  0.3824,  0.3918, -0.0830,  0.8971, -1.1123,  0.1116,  0.4863],
        [-0.5499, -0.3231, -0.5469,  0.9049,  0.2837,  0.1210,  0.4730, -1.0823]],
       device='cuda:0', requires_grad=True)
Input logits_q: tensor([4, 6, 7, 2, 2, 6, 5, 5], device='cuda:0')
Gradients of p (batch item 0): tensor([ 0.0149,  0.0157,  0.0140,  0.0174, -0.1086,  0.0154,  0.0153,  0.0159],
       device='cuda:0')
Gradients of _p (batch item 0): tensor([ 0.0017,  0.0029,  0.0003,  0.0055, -0.0182,  0.0024,  0.0023,  0.0032],
       device='cuda:0')

LIGER:
Cross Entropy Loss: 2.08567214012146
Input _p: tensor([[-0.9247, -0.4253, -2.6438,  0.1452, -0.1209, -0.5797, -0.6229, -0.3284],
        [-1.0745, -0.3631, -1.6711,  2.2655,  0.3117, -0.1842,  1.2866,  1.1820],
        [-0.1271,  1.2169,  1.4353,  1.0605, -0.4941, -1.4244, -0.7244, -1.2973],
        [ 0.0697, -0.0074,  1.8969,  0.6878, -0.0779, -0.8373,  1.3506, -0.2879],
        [-0.5965, -0.3283, -0.9086, -0.8059, -0.7407, -0.0504,  0.5435,  1.5150],
        [ 0.0141,  0.4532,  1.6349,  0.7124, -0.1806,  1.0252, -1.4622, -0.7554],
        [-0.1836,  0.3824,  0.3918, -0.0830,  0.8971, -1.1123,  0.1116,  0.4863],
        [-0.5499, -0.3231, -0.5469,  0.9049,  0.2837,  0.1210,  0.4730, -1.0823]],
       device='cuda:0', requires_grad=True)
Input logits_q: tensor([4, 6, 7, 2, 2, 6, 5, 5], device='cuda:0')
Gradients of p (batch item 0): tensor([ 0.0149,  0.0157,  0.0140,  0.0174, -0.1086,  0.0154,  0.0153,  0.0159],
       device='cuda:0')
Gradients of _p (batch item 0): tensor([2.1320e-05, 3.4830e-05, 6.8024e-06, 6.7467e-05, 1.3247e-02, 2.9687e-05,
        2.8429e-05, 3.8656e-05], device='cuda:0')

Solution

One trivial solution is to perform a no-op in-place operation, such as .add_(0) or .mul_(1), to explicitly declare that we have changed the tensor values in place, so that the error will be raised.

With this approach, I suggest adding an inplace=True/False parameter to the functions that perform in-place operations, so that users who hit the error can set it to False (at the cost of allocating extra tensors).

Versions

Environment Report:

Operating System: Linux-5.15.133.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
Python version: 3.10.12
PyTorch version: 2.4.1+cu121
CUDA version: 12.1
Triton version: 3.0.0
Transformers version: 4.45.0

Tcc0403 linked a pull request Sep 26, 2024 that will close this issue
Collaborator

ByronHsu commented Oct 3, 2024

Should we adopt the second solution, since the first one introduces quite a lot of overhead? Also, can you elaborate on which cases trigger this behavior?

Collaborator Author

Tcc0403 commented Oct 3, 2024

@ByronHsu

also, can you elaborate under which case will this behavior happen?

Consider the following forward graph:

graph TD
    A[input] -->|a| B[exp]
    B -->|b| C[liger_ce]
    C -->|loss| output

To calculate the gradients of the exp layer, whose local derivative is exp(input), we can either:

  1. save the input tensor a in the forward pass, then recompute exp(a) in the backward pass, or
  2. save the output tensor b in the forward pass, so no further computation is needed in the backward pass (assume torch marks it as version 0)

Normally, we take the option with the least computation/memory, which is option 2 in this case (PyTorch's own exp backward also uses the saved output).
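
The two options can be written as custom autograd Functions. This is an illustrative sketch; the class names `ExpSaveInput` and `ExpSaveOutput` and the helper `input_grad` are hypothetical, and PyTorch's built-in exp is not implemented this way in Python.

```python
import torch


class ExpSaveInput(torch.autograd.Function):
    """Option 1: save the input a, recompute exp(a) in backward."""

    @staticmethod
    def forward(ctx, a):
        ctx.save_for_backward(a)
        return torch.exp(a)

    @staticmethod
    def backward(ctx, grad_b):
        (a,) = ctx.saved_tensors
        return grad_b * torch.exp(a)  # extra exp computed in backward


class ExpSaveOutput(torch.autograd.Function):
    """Option 2: save the output b = exp(a); backward is a single multiply."""

    @staticmethod
    def forward(ctx, a):
        b = torch.exp(a)
        ctx.save_for_backward(b)
        return b

    @staticmethod
    def backward(ctx, grad_b):
        (b,) = ctx.saved_tensors
        return grad_b * b  # no recomputation, but relies on b being unmodified


def input_grad(fn, a):
    # Fresh leaf so each call accumulates its own gradient.
    a = a.detach().requires_grad_(True)
    fn.apply(a).sum().backward()
    return a.grad
```

Both give the same gradient as long as nothing touches b between forward and backward; option 2 is exactly the case that an in-place write into b breaks.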

graph TD
    A[input] -->|a| B["exp <br> saved tensors: b (v0)"]
    B -->|b| C[liger_ce]
    C -->|loss| output

After a complete forward pass from input a to loss, we call loss.backward().

graph TD
    A[input] <-->|dx * grad_ce = b' * grad_ce| B["exp <br> saved tensors: b' (v0)<br>(changed by liger_ce)"]
    B <-->|grad_ce| C[liger_ce]
    C <-->|loss| output

Notice that in the forward pass, liger_ce stored its gradients in b, its own input tensor, so the tensor b saved by the exp layer was overwritten as well. Since the saved tensor is corrupted, the exp layer can't produce correct gradients.

Replacing exp with any layer that saves its output tensor, and liger_ce with any layer that performs in-place operations on its input, results in the same behavior.

tl;dr
The saved tensors are corrupted by in-place operations.
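
A CPU-only sketch of the corruption, assuming a write through `Tensor.data` mimics the Triton store in liger_ce (neither bumps the version counter). Here `loss = b.sum()` stands in for liger_ce, so d(loss)/db is 1 everywhere; the function name `exp_then_sum_grad` is hypothetical.

```python
import torch


def exp_then_sum_grad(store_grad_in_input: bool) -> torch.Tensor:
    x = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)
    b = torch.exp(x)  # ExpBackward0 saves its output b
    loss = b.sum()    # d(loss)/db = 1 for every element; sum saves no tensor
    if store_grad_in_input:
        # Mimic liger_ce: overwrite b with d(loss)/db during the forward pass.
        # .data bypasses the version counter, as a Triton store does.
        b.data.fill_(1.0)
    loss.backward()   # no error is raised in either case
    return x.grad


print(exp_then_sum_grad(False))  # exp(x): correct
print(exp_then_sum_grad(True))   # all ones: exp's backward used the overwritten b
```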

Why no error?

No error is raised because the Triton kernel doesn't bump the version counter when it writes in place, so the saved tensor is still at v0 when gradients are computed in the backward pass.

If we perform the in-place operation outside the kernel by calling a torch function, the version is correctly updated.

graph TD
    A[input] <-->|"dx * grad_output <br>= b' * grad_output"| B["exp <br> saved tensors: b' (v1)<br>(changed by inplace op)"]
    B <-->|grad_output| C["torch's inplace op"]
    C <-->|something| something

Thus, the error can be detected.
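
A minimal sketch of the graph above in plain PyTorch, using the no-op `add_(0)` as the torch in-place op; the function name is hypothetical. The value of b does not change, but its version goes from 0 to 1, so exp's backward refuses to use the saved tensor.

```python
import torch


def backward_after_noop_inplace() -> str:
    x = torch.zeros(3, requires_grad=True)
    b = torch.exp(x)  # ExpBackward0 saves b at version 0
    b.add_(0)         # no-op value-wise, but bumps b's version to 1
    try:
        b.sum().backward()
    except RuntimeError as err:
        return str(err)  # autograd's in-place correctness check fired
    return "no error"


print("modified by an inplace operation" in backward_after_noop_inplace())  # True
```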

Collaborator Author

Tcc0403 commented Oct 12, 2024

We can keep separate pointers for gradients when designing a kernel, and add a boolean argument to the autograd.Function so users can decide whether to store gradients in place or not.

If False, we allocate new memory and pass it to the kernel, i.e. X_ptr and dX_ptr point to different tensors as below:

_jsd_kernel[(n_rows,)](
    X_ptr=_input,  # input in logspace, X = log Q
    X_stride=_input.stride(-2),
    Y_ptr=target,  # ground truth in logspace, Y = log P
    Y_stride=target.stride(-2),
    loss_ptr=loss,
    loss_stride=loss.stride(-2),
    dX_ptr=dX,
    dX_stride=dX.stride(-2),
    beta=beta,
    n_rows=n_rows,
    n_cols=V,
    BLOCK_SIZE=BLOCK_SIZE,
)

If True, we just pass the existing tensor in which we want to store the gradients, i.e. X_ptr and dX_ptr point to the same tensor as below:

_jsd_kernel[(chunk_n_rows,)](
    X_ptr=student_prob_chunk,
    X_stride=student_prob_chunk.stride(-2),
    Y_ptr=teacher_prob_chunk,
    Y_stride=teacher_prob_chunk.stride(-2),
    loss_ptr=loss_1d_slice,
    loss_stride=loss_1d_slice.stride(-2),
    dX_ptr=student_prob_chunk,
    dX_stride=student_prob_chunk.stride(-2),
    beta=jsd_beta,
    n_rows=BT,  # batchmean
    n_cols=V,
    BLOCK_SIZE=BLOCK_SIZE,
)

The examples above show that we can design a kernel that looks "out-of-place" but can still achieve "in-place" storing.
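
A pure-PyTorch sketch of why aliasing X_ptr and dX_ptr can work: if each row-wise program loads its whole row before storing the matching row of dX, passing the same buffer for both gives the same result as a separate buffer. `toy_row_kernel` and `launch` are hypothetical stand-ins, not the real `_jsd_kernel`, and the per-row computation is arbitrary.

```python
import torch


def toy_row_kernel(X: torch.Tensor, dX: torch.Tensor) -> None:
    """Hypothetical stand-in for a Triton kernel launched with one program per row."""
    for row in range(X.shape[0]):
        x = X[row].clone()      # like tl.load: copy the row before any store
        dX[row] = x - x.max()   # like tl.store: write the row's result


def launch(X: torch.Tensor, inplace: bool) -> torch.Tensor:
    # Host side picks the dX buffer: the input itself, or freshly allocated memory.
    dX = X if inplace else torch.empty_like(X)
    toy_row_kernel(X, dX)
    return dX


X = torch.tensor([[1.0, 3.0], [0.0, -2.0]])
print(launch(X.clone(), inplace=False))  # tensor([[-2., 0.], [0., -2.]])
```

The kernel signature is identical in both cases; only the host-side choice of buffer decides whether the input is overwritten.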

One trivial solution is to perform a no-op in-place operation, such as .add_(0) or .mul_(1), to explicitly declare that we have changed the tensor values in place, so that the error will be raised.

Since the trivial solution introduces quite a lot of overhead, we can do it only in the first pass, as an in-place correctness check.

A possible implementation could be like this:

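import torch
import triton


@triton.jit
def _kernel(
    x_ptr,  # input tensor
    y_ptr,  # output tensor
    dx_ptr,  # gradients of input (may point to the same memory as x_ptr)
    ...
):
    ...  # do something


def forward(_input, inplace: bool, ...):
    ...  # do something
    if inplace:
        # Store gradients directly into the input buffer (no extra memory).
        dx = _input
        if first_pass:  # I haven't come up with a good way to detect first pass or not
            # No-op in-place op: bumps _input's version counter, so autograd
            # raises an error if another layer saved _input for backward.
            _input.add_(0)
    else:
        # Allocate a separate gradient buffer and leave _input untouched.
        dx = torch.zeros_like(_input)
    _kernel[(...)](
        x_ptr=_input,
        y_ptr=output,
        dx_ptr=dx,
        ...
    )
    return output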
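
A runnable CPU sketch of the same idea, under several assumptions: `_ToyInplaceLoss` is a hypothetical loss sum(b ** 2) whose gradient is written into b via `Tensor.data` (standing in for a Triton store), the `add_(0)` check is done in a thin wrapper before `.apply` rather than inside the kernel wrapper, and "first pass" is approximated by a per-instance flag, which is only one possible answer to the open question above.

```python
import torch


class _ToyInplaceLoss(torch.autograd.Function):
    """Hypothetical loss = sum(b ** 2) that may store its gradient in b."""

    @staticmethod
    def forward(ctx, b, inplace):
        loss = (b * b).sum()
        if inplace:
            b.data.mul_(2.0)  # untracked write of d(loss)/db into b
            grad_b = b
        else:
            grad_b = 2.0 * b  # separate buffer, b left intact
        ctx.save_for_backward(grad_b)
        return loss

    @staticmethod
    def backward(ctx, grad_loss):
        (grad_b,) = ctx.saved_tensors
        return grad_loss * grad_b, None  # no gradient for the inplace flag


class ToyLoss:
    """Runs the add_(0) correctness check only on its first call (hypothetical)."""

    def __init__(self, inplace: bool = True):
        self.inplace = inplace
        self.checked = False

    def __call__(self, b):
        if self.inplace and not self.checked:
            self.checked = True
            b.add_(0)  # bumps b's version so a stale saved copy raises in backward
        return _ToyInplaceLoss.apply(b, self.inplace)


def run(loss_fn):
    x = torch.zeros(3, requires_grad=True)
    loss_fn(torch.exp(x)).backward()  # exp saves its output for backward
    return x.grad
```

With `ToyLoss(inplace=True)`, the first `run` raises a RuntimeError, while a second call on the same instance skips the check and silently returns 4 per element instead of the correct 2. That is the trade-off of checking only once: the user is expected to switch to `inplace=False` as soon as the first check fires.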

cc @ByronHsu @lancerts
