Want to try row split + all_reduce for MLP and attn #614
Row-parallel doesn't work for attn because the attn output projection is shuffled, and it would need a little extra work to unshuffle it so the input columns actually align with the rows of each slice of the split tensor. I've fixed this for the MLP so far by applying the input permutation from the down projection to the up and gate tensors (which is a small optimization regardless of TP), meaning the down projection can be trivially split by row (see the sketch at the end of this comment). It's definitely not abandoned, just disabled for now since it's incomplete.

For the record, it's not really six communications, but four: attn:
The intermediate state between layers is the same pinned CPU tensor that exists in the middle of each all-gather anyway. This simplifies the implementation somewhat and allows attn and MLP layers to use different sets of devices in the case of uneven splits.

Everything could be reduced to two all-reduces in principle, or one all-reduce plus two all-gathers without a solution to the permutation issue during attn. But the reason I put it on hold and released the (experimental) feature in its current state is that I haven't yet found a way to do the all-reduce that's actually more efficient, not to mention user-friendly enough to be worth considering.

There are already solutions for people who want maximally efficient inference on server hardware with 2^n identical, headless, P2P-enabled devices, appropriate BIOS settings for compute workloads, and so on. There's no point in trying to compete with NVIDIA for performance, but I think ExLlama can still offer flexibility and usability advantages, especially for affordable hardware. Libraries like NCCL come with really annoying limitations, like an inability to gather tensors of different shapes, incompatibility with Windows, and apparently a strong preference for multiprocessing. At the very least, a single-process all-reduce with NCCL does not appear to be any faster than two all-gathers, even when one of the latter is the MLP intermediate state, which is 3x as large as the residual stream. Obviously P2P would help, but for consumer hardware that currently comes down to the hacked NVIDIA driver that geohot doesn't seem to be too interested in outside the context of TinyGrad (which is perfectly fair, to be clear).

So, where I'm currently at: I'm experimenting, whenever I have time, with a more efficient all-reduce operation that ideally works within a single process (though multithreaded is fine, of course) and doesn't completely tank performance without P2P or on slow PCIe links.
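To make the permutation fix above concrete, here is a minimal sketch in plain PyTorch. Quantization is omitted and the names (`perm`, `gate_w`, `up_w`, `down_w`) are illustrative rather than ExLlamaV2's actual internals; it only shows why folding the down projection's input permutation into the up/gate output columns makes the down projection trivially row-splittable:

```python
# Minimal sketch (plain PyTorch, quantization omitted). Names are illustrative,
# not ExLlamaV2's real internals.
import torch

hidden, inter, batch = 64, 256, 4
x = torch.randn(batch, hidden)

gate_w = torch.randn(hidden, inter)
up_w   = torch.randn(hidden, inter)
down_w = torch.randn(inter, hidden)   # stored so that it expects permuted input
perm   = torch.randperm(inter)        # down projection's input permutation

act = torch.nn.functional.silu

# Original flow: permute the intermediate right before the down projection.
h = act(x @ gate_w) * (x @ up_w)
y_ref = h[:, perm] @ down_w

# Fixed flow: bake the permutation into the up/gate output columns once,
# so the intermediate already comes out in the order down_w expects.
gate_w_p, up_w_p = gate_w[:, perm], up_w[:, perm]
h_p = act(x @ gate_w_p) * (x @ up_w_p)

# down_w can now be trivially split by rows across devices; each slice produces
# a full-size partial output, and the partials simply sum up (that sum is what
# an all-reduce would do across GPUs).
down_slices = torch.chunk(down_w, 2, dim=0)
h_slices    = torch.chunk(h_p, 2, dim=1)
y = sum(hs @ ws for hs, ws in zip(h_slices, down_slices))

assert torch.allclose(y, y_ref, rtol=1e-4, atol=1e-3)
```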
Thanks for answering! So for the MLP, I should be able to directly split

And you mentioned that:

How do I do this? I didn't find your mlp example on
Problem is it's not

Changing to row split just in the MLP also requires a few other changes. There are two different versions of the forward pass, one in Python and one in C++, and they'd both have to be updated. The logic is probably easiest to follow in the (poorly named) `forward_tp_old`. It currently goes like this:
The all-reduce version would have to be:
The output could then be passed directly into the next attention layer, but that attention layer has to skip broadcasting the already-broadcast tensor (conditionally, since the embedding layer would still output a single tensor in pinned memory). And that should be it.
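For reference, here is a rough single-process sketch of that flow in plain PyTorch (not ExLlamaV2's actual API; `mlp_forward_allreduce` and the shard setup are made up for illustration): column-split gate/up, row-split down, and a naive sum over partial outputs standing in for the all-reduce, with the result left on every device so the next attention layer can skip its broadcast.

```python
# Rough sketch only: plain PyTorch, one process, no streams, no quantization,
# and a naive gather+sum+re-broadcast standing in for a real all-reduce.
import torch
import torch.nn.functional as F

hidden, inter = 64, 256
devices = ["cuda:0", "cuda:1"] if torch.cuda.device_count() >= 2 else ["cpu", "cpu"]

# Column-split gate/up (each device owns a slice of the intermediate dim) and
# row-split down (each device owns the matching rows), assuming the down
# projection's input permutation has already been folded into gate/up as above.
gate_w = torch.randn(hidden, inter)
up_w   = torch.randn(hidden, inter)
down_w = torch.randn(inter, hidden)

gate_shards = [w.to(d) for w, d in zip(torch.chunk(gate_w, 2, dim=1), devices)]
up_shards   = [w.to(d) for w, d in zip(torch.chunk(up_w,   2, dim=1), devices)]
down_shards = [w.to(d) for w, d in zip(torch.chunk(down_w, 2, dim=0), devices)]

def mlp_forward_allreduce(x_bcast):
    """x_bcast: list of per-device copies of the hidden state (already broadcast)."""
    partials = []
    for x, gw, uw, dw in zip(x_bcast, gate_shards, up_shards, down_shards):
        h = F.silu(x @ gw) * (x @ uw)      # local slice of the intermediate
        partials.append(h @ dw)            # full-size partial output
    # "All-reduce": sum the partials and leave the result on every device,
    # so the next attention layer can skip its broadcast.
    total = sum(p.to(devices[0]) for p in partials)
    return [total.to(d) for d in devices]

x = torch.randn(4, hidden)
y_bcast = mlp_forward_allreduce([x.to(d) for d in devices])
```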
This seems to be just what I did. I modified `forward_tp_old`, but I got a segmentation fault when applying the row-split down projection. What did I miss?
We are trying to use all-reduce TP to cut down the communication time. I noticed that you have implemented row split + all_reduce for the MLP (not faster, disabled). Why was this version abandoned? With row split we only need 2 communications (all-reduce) per layer, but the release version uses 6 communications per layer using all-gather. Can you share the row-split version so that I can test/modify it?
Anyway, I noticed that your `ExLlamaV2Linear` implements `tp_split_row` and `forward_tp_row`. So I tried to modify `ExLlamaV2MLP.tp_split` and `ExLlamaV2MLP.forward_tp_old` to implement an all-reduce version of TP, but I got a segmentation fault in down_proj's cpp kernel. The error message from gdb is as below: