🐛 Describe the bug
I'm using Flyte to reproduce the token throughput and memory savings results reported in this repo's README under slightly different conditions: using the microsoft/Phi-3-mini-4k-instruct model on a single A100 GPU.

Are the performance benefits of Liger only applicable to multi-GPU training workloads, or should they also take effect with single-GPU training?
Reproduce
The code I used for this is essentially the same as the code in this repo's huggingface example: https://github.com/linkedin/Liger-Kernel/tree/main/examples/huggingface
The full Flyte code is here: https://github.com/unionai/unionai-examples/pull/56/files
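For quick reference, the core of the benchmark is roughly the following (a simplified sketch, not the exact script from the PR above; the Phi-3 patching helper, the dataset, and the hyperparameters are stand-ins I'm assuming here):

```python
# Simplified sketch of the benchmark, adapted from the repo's huggingface example.
# Assumes the Phi-3 patching helper is available in the installed liger-kernel;
# the dataset and hyperparameters below are placeholders, not my exact values.
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer

from liger_kernel.transformers import apply_liger_kernel_to_phi3

USE_LIGER = True  # flip to False for the baseline run

if USE_LIGER:
    # Monkey-patch the HF Phi-3 modeling code with Liger's fused kernels
    # (RMSNorm, RoPE, SwiGLU, cross-entropy) before the model is instantiated.
    apply_liger_kernel_to_phi3()

model_name = "microsoft/Phi-3-mini-4k-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

dataset = load_dataset("tatsu-lab/alpaca", split="train[:2000]")  # placeholder dataset

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    args=SFTConfig(
        output_dir="/tmp/phi3-liger-bench",
        per_device_train_batch_size=4,
        max_seq_length=512,
        num_train_epochs=1,
        bf16=True,
        dataset_text_field="text",
    ),
)
trainer.train()
```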
Running the benchmark produces this Flyte deck with the basic comparison of Liger vs. the stock Hugging Face Transformers model:

As you can see, Liger does reduce peak reserved memory, but token throughput is slightly lower.
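For completeness, the two numbers above are read out roughly like this (a simplified version of what the example logs via a Trainer callback; the token count assumes fixed-length 512-token batches, so it is only approximate):

```python
# Rough readout of throughput and peak memory around trainer.train();
# `trainer` is the SFTTrainer from the sketch above.
import time
import torch

torch.cuda.reset_peak_memory_stats()
start = time.perf_counter()
trainer.train()
elapsed = time.perf_counter() - start

# steps * per-device batch size * max_seq_length (approximation for padded batches)
tokens = trainer.state.global_step * 4 * 512
print(f"token throughput: {tokens / elapsed:.1f} tokens/sec")
print(f"peak reserved memory: {torch.cuda.max_memory_reserved() / 2**30:.2f} GiB")
```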
Please advise! I can put together a Google Colab example if that would help with reproducing the issue.
Versions
datasets==2.21.0
pandas==2.2.2
matplotlib==3.9.2
huggingface-hub==0.24.6
transformers==4.42.2
trl==0.10.1
torch==2.4.0
liger-kernel==0.2.1
Hmm, as a sanity check, can you try running your benchmark with a per-device batch size of 8 instead of 4? Using tensors whose dimensions are multiples of 8 can be important for tensor-core utilization on modern NVIDIA GPUs (although that statement has a lot of caveats, and I'm not sure it is the issue here).

I think this is because you are using an A100 40GB, so the run is heavily memory-bound, while we use an A100 80GB. You could also try the SGD optimizer instead of AdamW so that it takes less memory (leaving more headroom for compute).
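For concreteness, both suggestions are plain training-config changes; a minimal sketch (field names are the standard SFTConfig/TrainingArguments ones, values illustrative):

```python
# Illustrative only: the two suggested tweaks expressed as SFTConfig fields;
# everything else stays the same as in the benchmark sketch above.
from trl import SFTConfig

args = SFTConfig(
    output_dir="/tmp/phi3-liger-bench",
    per_device_train_batch_size=8,  # multiple of 8, for better tensor-core utilization
    max_seq_length=512,
    bf16=True,
    optim="sgd",                    # SGD instead of AdamW to shrink optimizer-state memory
)
```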