Triton error on AMD GPUs #231
Comments
cc @Jokeren, do you have any ideas?
Same error as triton-lang/triton#4128. Also, @helloworld1 has tested on AMD GPUs before. Can you share your experience?
Same error on MI210; I was not able to resolve it myself. It looks like a Triton / ROCm compatibility issue.
@ByronHsu Thanks a lot for the pointer, I really appreciate the help. So, when I manually change the […] I should note that this doesn't use the minimal code above. It's a little more complicated, with an FSDP wrapper around the model (I couldn't immediately create a minimal example), but I was wondering if you had any ideas as to what might be triggering this size mismatch error.
I got this same error a few weeks ago trying to train on an MI300X with axolotl (on torch 2.4.0+rocm6.1). There was one time I got the training run to start by fiddling with various deps, but unfortunately I could never reproduce that. EDIT: I was able to get the rope and rms_norm Liger kernels to run without this error for a Llama 3.1 model on the setup mentioned above; swiglu, cross entropy, and fused linear cross entropy all result in this error, in case that helps narrow anything down a little.
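For anyone else trying to narrow this down the same way, here is a minimal sketch of enabling only the kernels reported to work. It assumes the `apply_liger_kernel_to_llama` patching API and its per-kernel keyword flags; the flag names are my reading of the library, not something taken from this thread.

```python
# Hedged sketch: enable only the Liger kernels reported to work on this setup,
# leaving the ones that reportedly trigger the error disabled.
# Assumes the liger_kernel.transformers patching API with per-kernel flags.
from liger_kernel.transformers import apply_liger_kernel_to_llama

apply_liger_kernel_to_llama(
    rope=True,                         # reported to run fine on MI300X
    rms_norm=True,                     # reported to run fine on MI300X
    swiglu=False,                      # reportedly triggers the error
    cross_entropy=False,               # reportedly triggers the error
    fused_linear_cross_entropy=False,  # reportedly triggers the error
)
```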
I don't maintain the AMD backend. Better to try out triton/main or contact the AMD people.
Following the logic in the issue linked here (#231 (comment)), and noting that the warp size of AMD Instinct processors is 64 compared to 32 for NVIDIA GPUs, I halved […]. Training appears to be working fine judging by my logs (slightly faster and significantly less memory, while having similar loss and grad norms):

No liger:

With liger:

I know nothing about Triton kernels, so I wanted to ask: are there any potential adverse consequences to this?
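To illustrate the reasoning above (this is not code from Liger or from this thread): Triton launches take a `num_warps` argument, and the number of threads per program is `num_warps * warp_size`. With a wavefront size of 64 on AMD Instinct parts, a value tuned for NVIDIA's 32-thread warps requests twice as many threads, so halving `num_warps` keeps the thread count the same. The exact change made by the commenter isn't preserved here; the sketch below is a hypothetical vendor-aware launch of a toy kernel.

```python
# Hypothetical sketch of a warp-size-aware Triton launch; the kernel and
# sizes are illustrative, not taken from Liger's source.
import torch
import triton
import triton.language as tl


@triton.jit
def _double_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * 2.0, mask=mask)


def launch_double(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    BLOCK_SIZE = 32768
    # Wavefront size is 64 on AMD Instinct (ROCm builds expose torch.version.hip),
    # 32 on NVIDIA. Keep the total threads per program at 1024 either way.
    warp_size = 64 if torch.version.hip else 32
    num_warps = 1024 // warp_size  # 32 on NVIDIA, 16 on AMD
    grid = (triton.cdiv(n, BLOCK_SIZE),)
    _double_kernel[grid](x, out, n, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
    return out
```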
Ha! I had set […]
🐛 Describe the bug
I'm trying to test this library on an HPC cluster with AMD MI250X GPUs, but I'm getting a weird, seemingly Triton-related error specifically when I turn on `model.train()`. The following is a minimal example; I get the following error when I run it on an MI250X:
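The snippet itself didn't survive in this copy of the issue. A hypothetical reconstruction of the kind of minimal example described, assuming the `apply_liger_kernel_to_llama` patching API and a placeholder Llama checkpoint, would look roughly like this:

```python
# Hypothetical reconstruction of the kind of minimal example described;
# the model name and inputs are placeholders, not the reporter's code.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from liger_kernel.transformers import apply_liger_kernel_to_llama

apply_liger_kernel_to_llama()  # monkey-patch Llama modules with Liger kernels

model_id = "meta-llama/Meta-Llama-3.1-8B"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)

model.train()  # the error reportedly only appears with training mode on

inputs = tokenizer("Hello from an MI250X", return_tensors="pt").to("cuda")
outputs = model(**inputs, labels=inputs["input_ids"])
outputs.loss.backward()
```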
The same code snippet works fine when I turn off `model.train()`. I also have access to another cluster with NVIDIA GPUs, and I can confirm that it works fine (with or without `model.train()`) on NVIDIA GPUs (A100 and H100), so this is an AMD-specific issue. I would appreciate any help you could provide in debugging this issue.

Reproduce
No response
Versions
I'm running on PyTorch-nightly + ROCm 6.2 + liger-kernel-nightly: