[RFC] adding Tensor and Pipeline Parallelism to transformers #13690

Closed
stas00 opened this issue Sep 22, 2021 · 113 comments
Labels: Performance, Pipeline Parallel, Tensor Parallel, WIP


@stas00
Contributor

stas00 commented Sep 22, 2021

Following up on this proposal (#12772): I just had a discussion with @hyunwoongko (with great help from @jaketae, who patiently translated for us), and we discussed a strategy for how best to integrate Tensor Parallelism (TP) and Pipeline Parallelism (PP) into transformers, making it easy for reviewers and contributors. Note that
parallelformers currently implements only TP.

So here is a great example of how TP can be added, as @hyunwoongko already implemented it in his fork for GPTNeo:
tunib-ai@5bf8655 (he didn't use GPT2 since it already has the naive PP implemented). So you can see exactly what we want to merge. It's a very thin layer on top of the model, and most of the functionality is in the helper parallel utils. At the end of the change are multiple tests/examples that need to be converted to our test framework.

Now, while adding TP is relatively easy, adding PP is very complex in the current state of HF models, because they include many features that interfere with implementing PP due to its requirements:

  1. for the model to be nn.Sequential and
  2. inputs/outputs to be simple tensors with the first dimension of batch size.

So to implement PP we will most likely have to fork each model, strip the features that are unnecessary for scalability, and only then be able to implement PP.
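
To make those two requirements concrete, here is a minimal, hypothetical sketch (plain PyTorch, not actual transformers code) of the shape a PP engine expects a model to have: a flat nn.Sequential of stages that pass plain, batch-first tensors between each other.

import torch
from torch import nn

# Hypothetical illustration only: a pipeline engine wants the model as a flat
# sequence of stages, where each stage consumes and produces plain tensors with
# the batch dimension first, so micro-batches can be split along dim 0 and
# shipped between pipeline ranks.
class Stage(nn.Module):
    def __init__(self, hidden: int):
        super().__init__()
        self.block = nn.Linear(hidden, hidden)  # stand-in for a transformer block

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # no dicts, no ModelOutput, no kwargs: one tensor in, one tensor out
        return torch.relu(self.block(hidden_states))

model = nn.Sequential(*[Stage(64) for _ in range(4)])
out = model(torch.randn(8, 10, 64))  # (batch, seq, hidden) - batch dim first

HF models instead pass around dicts/ModelOutputs, attention masks, head masks, caches, etc., which is exactly what gets in the way.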

So my thinking is that perhaps we do it from the get-go? Instead of integrating TP into the normal model - say GPTNeo - we fork it to, say, GPTNeo3D from the get-go and do all the work, including TP and PP, on that new model. Once everybody is happy we can rinse and repeat for other models.

I added 3D to GPTNeo to make GPTNeo3D - 3D = DP/TP/PP. I'm not exactly sure about this particular name, nor attached to it; it's just something to start with.

Also, once TP is implemented in, say, GPTNeo3D, we can start replicating it to other models, because parallelformers has them all covered already. PP will be much harder, and we can do that work in parallel.

I wanted to check in with the team to see if this approach resonates better, rather than modifying the existing models.

Thank you!

Also see this blog post explaining parallelformers.


Additionally see the main pytorch Parallelism discussion at pytorch/rfcs#32

@LysandreJik, @sgugger, @patrickvonplaten

@siddk
Contributor

siddk commented Sep 22, 2021

@stas00 - I like this a lot! And as we've been dragging our feet with implementing some of the Megatron 3D parallelism into mistral - I think it might be a great way for us to collaborate; we can just start with the base GPT-2 model perhaps?

I think my (and the Mistral team's) addition over the next few weeks will be trying to do some benchmarking of Megatron and the existing gains with various subsets of parallelism (at a very fundamental level - profiling which kernels are being called, etc.) and maybe creating a set of unit tests to verify correctness?

Separately - might be worth keeping logs of how to "3D-ify" new models, and ways we might make that procedure even easier moving forward.

Let me know if this makes sense!

@hyunwoongko
Contributor

hyunwoongko commented Sep 22, 2021

@stas00 @siddk If we are creating a new class, we do not need to modify the existing parallelize() method, so we do not need to work with GPTNeo. I think GPT2 would be better.

@stas00
Contributor Author

stas00 commented Sep 22, 2021

Thanks for the feedback, Sidd.

The reason @hyunwoongko thought of starting with GPTNeo was because GPT2 already has the naive PP parallelize(). But the problem is that it's not just in the model, it's also in the Trainer. So we probably need to choose some other name for that function altogether, at least for the time being, so that we can move forward.

Note that the intention is to do simple things first and not do too many things at once. So I think starting with GPTNeo on a clean slate is a better idea. Once everybody is happy with it, it'd be trivial to replicate that to GPT2. And it's already done, as you can see from the link in the OP.

Here is my vision of 3Difying transformers:

step 1. implement TP in one model
step 2a. start replicating TP to other models
step 2b. start working on PP in one model
step 3a. start replicating PP to other models.

note how step 2 can be done in parallel by different people.

So I can see that Mistral's team efforts would be parallel work and not sequential. So for example:

step 3b. implement Mistral's GPT2 improvements to GPT2
step 4a. start replicating it to other models.

If we were to start with GPT2 we would interfere with your work, Sidd, so I think it's actually best if we pick 2 different starting models.

But let's stay focused in this discussion on TP+PP, otherwise it'd be too easy to get side-tracked. We already spent too much time talking - let's see some code going into transformers! :)

wrt trainers, it'll be a natural part of the work - I'm not worried too much about it. I don't know much about accelerate yet, but HF Trainer should be relatively easy.

@siddk
Contributor

siddk commented Sep 22, 2021

This makes a lot of sense to me - thanks @stas00 and @hyunwoongko for the clarifications! The steps above form a pretty good concrete plan - but if you both are already planning on tackling it, maybe it makes sense for us to tackle some of the other Megatron-LM improvements first, like the custom loss scaling/kernels/etc. (in mistral, so we can break things 😅)? And as y'all build the "main API" for 3D parallelism, we can just drop that in, and train larger models!

The PR with Mistral's first set of GPT-2 improvements is waiting on approval right now - once that's in we can move a bit faster as well.

@stas00
Contributor Author

stas00 commented Sep 22, 2021

That sounds like a perfect plan to me, Sidd.

@hyunwoongko
Contributor

hyunwoongko commented Sep 22, 2021

@stas00 I don't think the following plan works well for the megatron-friendly method.

step 1. implement megatron-friendly TP in one model
step 2a. start replicating megatron-friendly TP to other models
step 2b. start working on megatron-friendly PP in one model
step 3a. start replicating megatron-friendly PP to other models.

Ultimately, implementing PP requires rewriting all the modeling code (including GPT2Attention, GPT2MLP, GPT2Model, ...). I wasn't familiar with PP until not long ago, but recently I became very familiar with it and found out that we have to rewrite all the code (generation_utils.py, used for inference, also has to be changed). Therefore, I recommend that megatron-friendly TP and PP be implemented together. (I think it's inefficient to implement megatron-friendly TP alone.)

@hyunwoongko
Contributor

hyunwoongko commented Sep 22, 2021

The transformers-friendly method (=parallelformers) has the advantage of being able to extend to models quickly, because it does not need to rewrite the modeling code (it uses the existing transformers code), but it is not compatible with PP. So we have to remove all the transformers-friendly TP when implementing PP. Which strategy we take is a matter of choice. We can quickly expand coverage in a transformers-friendly way, and then change models one by one to be megatron-friendly, like:

step 1. implement transformers-friendly TP in one model
step 2a. start replicating transformers-friendly TP to other models
step 2b. start working on megatron-friendly TP + PP in one model
step 3a. start replicating megatron-friendly TP + PP to other models.

Or we could skip the transformers-friendly method entirely, because it will be removed anyway. But since there are thousands of lines of code to write for the megatron-friendly approach and only tens of lines for the transformers-friendly one, the megatron-friendly approach will scale very slowly:

step 1. start working on megatron-friendly TP + PP in one model
step 2. start replicating megatron-friendly TP + PP to other models.

One thing to note is that the transformers-friendly TP implementation is completely eliminated when implementing the megatron-friendly TP. A megatron-friendly TP is implemented differently from a transformers-friendly TP.
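
To make the difference concrete, here is a toy, single-process sketch (my own illustration, not parallelformers or Megatron code) of the Megatron-style sharding idea: the first linear of an MLP is split column-wise across ranks, the second row-wise, and the partial outputs are summed (an all-reduce in the real multi-GPU case).

import torch

# Toy illustration: verify that the column/row split reproduces the single-GPU
# result when the per-rank partial outputs are summed (the sum stands in for
# the all-reduce that real tensor parallelism performs across GPUs).
torch.manual_seed(0)
world_size, hidden, ffn = 2, 8, 32
x = torch.randn(4, hidden)
w1 = torch.randn(ffn, hidden)   # column-parallel: sharded along dim 0 (output features)
w2 = torch.randn(hidden, ffn)   # row-parallel: sharded along dim 1 (input features)

full = torch.relu(x @ w1.t()) @ w2.t()          # what a single GPU would compute

partial_sum = torch.zeros_like(full)
for rank in range(world_size):                  # each iteration = one TP rank
    w1_shard = w1.chunk(world_size, dim=0)[rank]
    w2_shard = w2.chunk(world_size, dim=1)[rank]
    partial_sum += torch.relu(x @ w1_shard.t()) @ w2_shard.t()

print(torch.allclose(full, partial_sum, atol=1e-5))  # True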

@sgugger
Collaborator

sgugger commented Sep 22, 2021

Adding a GPTNeo3D to experiment with seems like a good idea to me. At the end of the day, that modeling file can live in the same folder as modeling_gptneo.py.

Note that while you experiment, you can leverage #13467 to share models on the Hub that have no implementation in Transformers and still work with the auto-model API.

@stas00
Contributor Author

stas00 commented Sep 22, 2021

Adding a GPTNeo3D to experiment with seems like a good idea to me. At the end of the day, that modeling file can live in the same folder as modeling_gptneo.py.

Great!

Note that while you experiment, you can leverage #13467 to share models on the Hub that have no implementation in Transformers and still work with the auto-model API.

The 3D GPTNeo model's weights are the same as a normal GPTNeo model's - i.e. it can be used w/ or w/o PP/TP, so I'm not sure why we need a special API?

And I guess we won't be able to use AutoModel, because the config.model_type will say 'gpt_neo', but we will want to load it with GPTNeo3D* classes.

@stas00
Contributor Author

stas00 commented Sep 22, 2021

@hyunwoongko, you're bringing up excellent points.

I suppose the main question is how much of a benefit we can give to users by having just TP. My thinking is that since it's easy to add TP to all models, and you have already done this, let's do it.

I'm concerned that adding PP will be a very slow process because as you said it requires massive rewrites to the model's code, and meanwhile those models that are waiting their turn won't be very scalable (except with Deepspeed ZeRO).

Besides, we can delegate adding TP to the rest of the models to others (other developers and even the community), since it's mostly just replaying the code you have already written. But it still requires work, at least in adding tests and documentation, and then PRs.

The only concern with adding TP the transformers-friendly way is making sure the external API remains the same when we add PP.

How does that sound?

@hyunwoongko
Contributor

hyunwoongko commented Sep 22, 2021

@stas00 But anyway, I don't prefer PP. As you know, PP is memory inefficient because it is not compatible with ZeRO stages 2 and 3. In fact, we also decided not to use PP when developing language models. So adding just TP would be helpful for many people. So let's go with the following strategy. But, as you said, the API for both methods should remain the same.

step 1. implement transformers-friendly TP in one model
step 2a. start replicating transformers-friendly TP to other models
step 2b. start working on megatron-friendly TP + PP in one model
step 3a. start replicating megatron-friendly TP + PP to other models.

But the transformers-friendly TP has no reason to rewrite the modeling code. What should we do?

@stas00
Contributor Author

stas00 commented Sep 22, 2021

That's great, @hyunwoongko!

And once we complete GPTNeo3D with TP we can decide whether to fold it back into the normal GPTNeo model or keep it separate. I'm saying that if in the end we do PP only for a few select models (which is a real possibility too), then there is absolutely no need to fork 60 models and create a lot more maintenance work for transformers if they will have just TP+DP.

@hyunwoongko
Contributor

@stas00

In my opinion, the transformers-friendly TP has no reason to have its own modeling code like GPTNeo3D.

  1. So the transformers-friendly TP will just use the existing model
  2. And let's make a new modeling class such as GPT2For3D when we develop the megatron-friendly TP + PP (GPT2, Bert, T5, etc. - it will probably be some models, not all).

@hyunwoongko
Contributor

hyunwoongko commented Sep 22, 2021

I'm thinking of an API like this.

from transformers import GPTNeoModel

model = GPTNeoModel.from_pretrained("EleutherAI/gpt-neo-1.3B", tensor_model_parallel_size=4)

or

model = GPTNeoModel.from_pretrained("EleutherAI/gpt-neo-1.3B", tp=4)

I implemented the megatron-friendly model internally like this:

@classmethod
def from_yaml(
    cls,
    cfg_path: str,
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    tp: int = None,
    pp: int = None,
):
    """
    Create model from yaml config file

    Args:
        cfg_path: path of configurations
        tensor_model_parallel_size: tensor model parallel world size
        pipeline_model_parallel_size: pipeline model parallel world size
        tp (int): equivalent with `tensor_model_parallel_size`
        pp (int): equivalent with `pipeline_model_parallel_size`
    """

    if tp is not None:
        assert tensor_model_parallel_size == 1, (
            "you can't use param `tensor_model_parallel_size` and `tp` at the same time. "
            "they are equivalent. so please use one of them."
        )
        tensor_model_parallel_size = tp

    if pp is not None:
        assert pipeline_model_parallel_size == 1, (
            "you can't use param `pipeline_model_parallel_size` and `pp` at the same time. "
            "they are equivalent. so please use one of them."
        )
        pipeline_model_parallel_size = pp

@stas00
Contributor Author

stas00 commented Sep 22, 2021

I totally agree that this is a much better way to proceed.

@sgugger, is it ok if we change the initial proposal and add TP to the normal model classes? As we continued discussing this, and based on my experience with trying to add PP to transformers, it'll be a huge amount of work to do it for all models, and so it's very likely many models will never get it. And since TP requires no changes to the models, there is no reason to make it difficult for users and maintainers by forking the model for that feature to work.

And we believe just having TP+DP will already be a great boon to the scalability of the models (if Deepspeed ZeRO doesn't already address this for whatever reason).

For PP new classes will be needed 100%.

Thank you.

@sgugger
Collaborator

sgugger commented Sep 22, 2021

As long as the changes are minimal, no objection from my side. I agree it makes much more sense to get that out if it's faster and deliver the PP later on.

@hyunwoongko
Contributor

hyunwoongko commented Sep 22, 2021

The problem is the 'parallelize()' method, the API for layerwise naive parallelism in GPT2 and T5. Do you agree to remove this method? The megatron-friendly TP + PP cannot handle it that way. This is because in the case of PP, parallelization occurs at the time of model creation. That's why I let from_pretrained take the tp and pp sizes as input.

@stas00
Contributor Author

stas00 commented Sep 22, 2021

I'm thinking of an API like this.

from transformers import GPTNeoModel

model = GPTNeoModel.from_pretrained("EleutherAI/gpt-neo-1.3B", tensor_model_parallel_size=4)

or

model = GPTNeoModel.from_pretrained("EleutherAI/gpt-neo-1.3B", tp=4)

I think transformers tends to go with more spelled-out args, but not overly long ones, so perhaps tensor_parallel_size=4

the problem is the 'parallelize()' method, the naive parallelism (layer-wise) implementation. Do you agree to remove this method? The megatron-friendly TP + PP cannot handle it that way. This is because in the case of PP, parallelization occurs at the time of model creation. That's why I let from_pretrained take the tp and pp sizes as input.

The naive PP is experimental:

PARALLELIZE_DOCSTRING = r"""
This is an experimental feature and is a subject to change at a moment's notice.

but we shouldn't remove it until we replace it with real PP, because users actively use the naive PP at the moment.

That's why we proposed working on GPTNeo first, so that it's easier to take our time and not have that older code interfere.

@hyunwoongko
Contributor

hyunwoongko commented Sep 22, 2021

@stas00

I think transformers tends to go with more spelled out args, but not too too long, so perhaps tensor_parallel_size=4

So I made it support both variables (a long name and a short name). Not good?

but we shouldn't remove it until we replace it with real PP, because users actively use the naive PP at the moment. That's why we proposed working on GPTNeo first, so that it's easier to take our time and not have that older code interfere.

I totally agree with you. Let's start from GPTNeo.


The second thing to discuss is the embedding layer. When I implemented parallelformers, I didn't actually parallelize the embedding layer. In this case, the embedding layer is copied to all GPUs, and is therefore memory inefficient. But in fact we can apply VocabParallelEmbedding and VocabParallelCrossEntropy. (However, we should not use the original CrossEntropy in this case.) We also need to decide whether or not to add VocabParallelEmbedding to the transformers-friendly TP.

I didn't tell you guys, but I actually experimented little by little. I already figured out that I can do VocabParallelEmbedding internally with transformers-friendly TPs.
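
For reference, here is a toy, single-process sketch of the vocab-parallel embedding idea (in the spirit of Megatron's VocabParallelEmbedding, not its actual code): each TP rank owns a slice of the vocabulary rows, out-of-range token ids contribute zeros, and an all-reduce (a plain sum below) restores the full output.

import torch

torch.manual_seed(0)
vocab, dim, world_size = 16, 4, 2
weight = torch.randn(vocab, dim)                  # full embedding table
ids = torch.tensor([[1, 7, 12, 15]])

reference = torch.nn.functional.embedding(ids, weight)

out = torch.zeros(ids.shape + (dim,))
per_rank = vocab // world_size
for rank in range(world_size):                    # each iteration = one TP rank
    lo, hi = rank * per_rank, (rank + 1) * per_rank
    shard = weight[lo:hi]                         # this rank's slice of the vocab
    mask = (ids >= lo) & (ids < hi)
    local_ids = (ids - lo).clamp(0, per_rank - 1)
    out += torch.nn.functional.embedding(local_ids, shard) * mask.unsqueeze(-1)

print(torch.allclose(reference, out))             # True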

@stas00
Contributor Author

stas00 commented Sep 22, 2021

@stas00

I think transformers tends to go with more spelled out args, but not too too long, so perhaps tensor_parallel_size=4

So I made it support both variables (a long name and a short name). Not good?

At the moment I don't recall transformers using shortcut aliases for arg names, so probably just having tensor_parallel_size is fine. (No need to include "model_", as the shorter name I proposed is not ambiguous.)

The second thing to discuss is the embedding layer. When I implemented parallelformers, I didn't actually parallelize the embedding layer. In this case, the embedding layer is copied to all GPUs. Therefore, it is memory inefficient. But in fact we can apply VocabParallelEmbedding and VocabParallelCrossEntropy. (However, we should not use the original CrossEntropy in this case) we also need to decide whether or not to add VocabParallelEmbedding to the transformers-friendly TP.

Was CrossEntropy the reason for not doing it in the first place in parallelformers? I guess the integration will allow us to overcome this, then, if I understood your comment correctly.

But otherwise by all means let's make TP as efficient as possible.

@hyunwoongko
Contributor

hyunwoongko commented Sep 22, 2021

  1. I like the name tensor_parallel_size more, but I named it tensor_model_parallel_size because I wanted to follow the Megatron-LM nomenclature. In fact, if we pass the mpu to DeepSpeed, methods such as mpu.XXX_model_parallel_rank() are called inside it. Therefore, it is better to unify the names.

  2. Since parallelformers is an inference-only toolkit, there was no reason to worry about CrossEntropy. The reason I didn't do it at the time was because it was a bit complicated. (But it's not difficult.)

How about implementing it with options first?

from_pretrained(tensor_model_parallel_size=4, embedding_parallelism=True)

@stas00
Contributor Author

stas00 commented Sep 22, 2021

  1. I like the name tensor_parallel_size more, but I named it tensor_model_parallel_size because I wanted to follow the Megatron-LM nomenclature. In fact, if we input the mpu to DeepSpeed, methods such as mpu.XXX_model_parallel_rank() are called inside it.

Ah, ok, we can use tensor_model_parallel_size then, to make it easier to cross-reference the codebases. Maybe then add a note on why this particular name has been chosen.

  1. Since parallelformers is inference-only in the first place, there was no reason to worry about CrossEntropy. The reason I didn't do it at the time was because it was a bit complicated. (But it's not difficult.)

Ah, right, I forgot that parallelformers was intended for inference only in the first place. Yes, so what you proposed is a good idea.

@stas00
Contributor Author

stas00 commented Sep 22, 2021

How about implementing it with options first?

from_pretrained(tensor_model_parallel_size=4, embedding_parallelism=True)

Is there a technical reason for not always doing the latter?

@hyunwoongko
Contributor

hyunwoongko commented Sep 22, 2021

Because of VocabParallelCrossEntropy. The user should be able to use a loss function other than CrossEntropy with the Transformers model (RMS, Center Loss, Large-margin softmax, ...). With VocabParallelEmbedding, the loss function has to handle this appropriately. You can check this: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/mpu/cross_entropy.py

So I thought of making the default value of embedding_parallelism false and letting the user turn it on when they want to.
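
To spell out why a vanilla CrossEntropy no longer works once the output projection is vocab-parallel: each rank only holds a slice of the logits, so the softmax denominator has to be assembled with cross-rank reductions (a max, then a sum of exponentials), which is what Megatron's vocab-parallel cross-entropy does. A toy, single-process sketch of that math (not the Megatron code itself):

import torch

torch.manual_seed(0)
vocab, world_size = 8, 2
logits = torch.randn(3, vocab)                    # (tokens, vocab)
target = torch.tensor([1, 5, 6])

reference = torch.nn.functional.cross_entropy(logits, target)

# pretend each "rank" only sees its own vocab slice of the logits
shards = logits.chunk(world_size, dim=-1)
global_max = torch.stack([s.max(dim=-1).values for s in shards]).max(dim=0).values
sum_exp = sum((s - global_max[:, None]).exp().sum(dim=-1) for s in shards)
# in the real implementation the target logit is fetched via a masked
# gather + all-reduce; here we read it from the full tensor for brevity
target_logit = logits.gather(1, target[:, None]).squeeze(1)
loss = (torch.log(sum_exp) - (target_logit - global_max)).mean()

print(torch.allclose(reference, loss, atol=1e-6))  # True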

@stas00
Contributor Author

stas00 commented Sep 22, 2021

Thank you for the explanation, Hyunwoongko.

Then yes we need that arg. Should the default be False then, so the priority is for the user code to work out of the box and we document embedding_parallelism=True as an optimization?

Further, embedding is ambiguous since we have different types; should we explicitly say word_embed_parallelism?

@hyunwoongko
Contributor

hyunwoongko commented Sep 22, 2021

Oh, I was wrong. Only tied output embeddings have problems with the loss function. I checked, and it doesn't matter since neither GPT2 nor GPTNeo ties its output embeddings.

In most cases, we don't need to worry about the loss function. Therefore, I will implement embedding parallelism so that it works every time, which makes this option unnecessary, and users do not need to worry about it. If I later find a model that ties input and output embeddings without an lm head, I will think about it then.

@hyunwoongko
Contributor

hyunwoongko commented Sep 22, 2021

But maybe Meg-DS and GPT-NeoX use embedding tying. So this option will be needed in the future.

@stas00
Contributor Author

stas00 commented Sep 22, 2021

If I'm not mistaken many models have input and output embeddings tied.

@deepakn94

Hi all, I helped implement pipeline parallelism in Megatron (and was also one of the lead authors on the PipeDream project). Happy to answer any questions.

I had a question too: what is the current plan for the new PP-friendly model classes? What is going into these, and how will they be different from the vanilla model classes?

Thanks!

@lucasleesw

Hi @hyunwoongko. Thank you for your reply. I got your idea about the linear or conv1d layers in the Attention module or MLP module, and I think it is great.
And indeed, this could be kind of risky. For example, what about the conv2d layers in src/transformers/models/vit/modeling_vit.py?
In our implementation, we built some API to help users build the dict map to handle this exception, but we found it is hard to use for now.
I think it would be helpful if the metadata class had some API to let users extend the tensor parallelism methods.

@hyunwoongko
Contributor

hyunwoongko commented Jan 26, 2022

@lucasleesw I totally agree with you. Defining both columns and rows is probably the safest and most extensible way.

@hyunwoongko
Contributor

hyunwoongko commented Jan 26, 2022

@stas00 I initially used the method of defining both column and row parallel parameters, but since the process of defining them is quite difficult, I experimented with many ways to create a simpler tensor parallelization map. But the more I simplified, the more likely exceptions became. So, like @lucasleesw's method, it would be best to use all three pieces of information: column, row, and mp_param.

We all parallelize in a similar way, and so does SageMaker. Therefore, it would be convenient if we unified and managed this inside transformers.
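
For illustration, a per-model mapping of that kind could look roughly like the following (names and structure are hypothetical, not the actual OSLO/parallelformers API): each architecture declares which parameters are sharded column-wise, which row-wise, and which scalar attributes have to be divided by the TP world size.

# Hypothetical sketch of per-model tensor-parallel metadata; the module paths
# are illustrative and would need to be checked against the real GPTNeo modules.
TP_MAPPING = {
    "gpt_neo": dict(
        column_parallel=[
            "attn.attention.q_proj", "attn.attention.k_proj",
            "attn.attention.v_proj", "mlp.c_fc",
        ],
        row_parallel=["attn.attention.out_proj", "mlp.c_proj"],
        mp_params=["embed_dim", "num_heads"],  # divided by tp_size on each rank
    ),
}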

@hyunwoongko
Contributor

hyunwoongko commented Jan 26, 2022

@lucasleesw Omitting the column parallel linear and the tracing method won't cause any problems in ViT, because the embedding is not in the module list. https://github.com/huggingface/transformers/blob/master/src/transformers/models/vit/modeling_vit.py#L146

They parallelize only the layers inside the base layer module (like BertLayer), not all existing layers. Even so, these simplifications can always produce exceptions.

@hyunwoongko
Contributor

@lucasleesw I'm also wondering about your PP implementation. Could you let me know? I used DeepSpeed PP in the beginning, but now we are implementing the same method as SageMaker PP.

@lucasleesw

@hyunwoongko You are right, thanks again for your inspiration.
Our implementation will be available very soon; we look forward to your advice.

@hyunwoongko
Contributor

hyunwoongko commented Jan 26, 2022

@lucasleesw I will upload a PR for tensor parallel mapping today. It would be great if you could reply so we can make a more general PR. How did you deal with the fused attention module (in gpt2, transfo_xl)? I mean an attention layer that has a size like linear(3 * dim, dim). And if we create GPT2 with EncoderDecoderModel, then GPT2 has cross attention (q_attn), which is linear(2 * dim, dim). These can't be handled simply, because they are q, k, and v concatenated (or k and v for cross attention). How did you deal with them? Were you able to automate this without some mapping?
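
To illustrate the fused-QKV issue: with a weight of shape (3 * dim, dim) you cannot simply chunk along the fused dimension across TP ranks (rank 0 would end up with only Q); each of Q, K and V has to be split separately and the per-rank slices re-concatenated. A hypothetical helper, not the actual OSLO code:

import torch

def shard_fused_qkv(weight: torch.Tensor, world_size: int, n_fused: int = 3):
    # weight: (n_fused * dim, dim); split each fused projection separately,
    # then rebuild one smaller fused weight per TP rank
    per_proj = weight.chunk(n_fused, dim=0)                 # [Q, K, V]
    return [
        torch.cat([p.chunk(world_size, dim=0)[rank] for p in per_proj], dim=0)
        for rank in range(world_size)
    ]                                                       # each: (n_fused * dim / world_size, dim)

dim, world_size = 8, 2
shards = shard_fused_qkv(torch.randn(3 * dim, dim), world_size)
print([tuple(s.shape) for s in shards])                     # [(12, 8), (12, 8)]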

@stas00
Contributor Author

stas00 commented Jan 26, 2022

So we have at least 3 possible "consumers" of additional model metadata at the moment: @hyunwoongko, @lucasleesw and @RezaYazdaniAminabadi - so perhaps instead of maintaining 3 different tables, would you agree on having one that contains all the fields that you need? You can discuss among yourselves what you prefer to call those. We can give the new names a period of "experimental and subject to change" until the dust settles, and then they will get carved in stone at a later date to support backward compatibility. And I'm sure there will be other consumers of that type of metadata.

I don't see any reason not to have all the desired components written out explicitly, instead of being derived automatically. There is absolutely no reason to take a risk here; this is software engineering, not the stock market.

I propose to start with having a dedicated file for it with the first few models and then down the road we can see if it makes sense to move these into their model files. I just want to create minimal disturbance to the models code until we are ready to do so.

@hyunwoongko
Contributor

hyunwoongko commented Jan 26, 2022

I opened a PR about tensor parallel mappings !

@stas00 stas00 self-assigned this Feb 20, 2022
@stas00 stas00 added the WIP, Performance, Pipeline Parallel, Tensor Parallel labels Feb 20, 2022
@huggingface huggingface deleted a comment from github-actions bot Feb 20, 2022
@codego7250

@stas00 For TP, it will go with Megatron-LM's way of doing the parallelism in Python, if I understood correctly. If that's the case, it leaves us the opportunity to support TPUs etc., since the only question is about the allgather & allreduce API.
If that's the case, I'd like to move this direction forward for transformers. I'm not sure where we are now and what's the right branch to start from. It would be great if you could share the end-to-end implementation and I can start from there.

@stas00
Contributor Author

stas00 commented Feb 22, 2022

At the moment we have 2 projects that support TP (tensor parallelism): OSLO and Deepspeed-Inference.

Neither is yet integrated into transformers. OSLO we are just slow to integrate, since I'm busy with BigScience and @jaketae is backing me up and has started to work on the integration. Deepspeed-Inference is still a work in progress on the core side, and I have an initial PR that integrates it, but there are some outstanding issues as the HF Trainer is not MPU-aware yet.

So at the moment Deepspeed-ZeRO is the only solid and working solution for scalability on the free side, and Sagemaker on the paid side (though I have never tried the latter myself).

PP is much more difficult, and we are leaving it to the end, in the hope that pytorch will provide a new, much easier to use PP API that is somewhat similar to the one in Sagemaker's paper: https://arxiv.org/abs/2111.05972

@hyunwoongko
Contributor

hyunwoongko commented Feb 22, 2022

@stas00

  1. OSLO has an MPU, and it is compatible with deepspeed, ds-inference and megatron. If you need an mpu, how about using this one? Maybe from oslo import MPU could work with ds-inference.
  2. I am wondering about your ds-inference integration plan. We need to integrate it without the Trainer (because it's not about training). What's your plan? Since OSLO TP can be used for both training and inference, we need to discuss how to provide it from the inference point of view.
  3. I have almost finished implementing sagemaker-like PP internally, but I am not currently integrating it into the main branch, because it could interfere with the TP integration. So, when the TP integration work is finished, the PP will be merged into the main branch.

@stas00
Contributor Author

stas00 commented Feb 22, 2022

I was just saying that it doesn't have an MPU at the moment ;) And it's needed to sync the ds-inference TP processes, I think. But my mind is on BigScience at the moment, so I don't have the brain cycles for an in-depth analysis right now.

Making the ds-inference integration depend on OSLO would be odd, but it could be ok at the beginning, with an internal one eventually - it's just one module that's already written.


Why integrate ds-inference w/o the Trainer? The Trainer is just a name - it covers both inference and training.

@AaronZLT
Contributor

Any progress now? I think integrating TP and PP with transformers and deepspeed or megatron, offering easier access to users, would be a great contribution today.

@stas00
Contributor Author

stas00 commented Sep 24, 2023

Given the complexity of TP/PP, which requires massive changes to the modeling code, and the hundreds of architectures we have, it's hard to tell if this will ever happen in transformers.

You already have the Deepspeed integration in transformers, so it's trivial to scale pretty much any model in transformers' arsenal to any number of GPUs. And you should get about the same throughput with DeepSpeed ZeRO-3 as you would with TP/PP/DP, as long as you are on a fast internode connection.
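
For context, using that integration is roughly this small (a minimal sketch - the exact option names should be checked against the current transformers/DeepSpeed docs, and deepspeed must be installed and used as the launcher); no modeling changes are involved, only a ZeRO-3 config handed to the Trainer:

from transformers import TrainingArguments

# Minimal ZeRO-3 config passed as a dict; the "auto" values are filled in by
# the transformers integration from the TrainingArguments.
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": True,
        "stage3_gather_16bit_weights_on_model_save": True,
    },
    "bf16": {"enabled": True},
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
}

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=4,
    deepspeed=ds_config,  # then launch with the `deepspeed` launcher across GPUs/nodes
)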

@stas00
Contributor Author

stas00 commented Sep 24, 2023

I'm actually going to close this, since it's too old and clearly isn't happening.

@stas00 stas00 closed this as completed Sep 24, 2023
@marianna13

Hey @stas00
I'm curious about the claim that "you should get about the same throughput with DeepSpeed ZeRO-3 as you'd with TP/PP/DP as long as you are on a fast internode connection". Why does everyone still use Megatron-LM and not just Hugging Face transformers with DeepSpeed ZeRO-3 if the throughput is the same?
Thanks!

@stas00
Contributor Author

stas00 commented Nov 29, 2023

That's an excellent question, @marianna13

First, Megatron-LM does more than just 3D parallelism - it has a superb set of various other features, so we can't compare projects just on the base of how well they solve one specific problem.

The main con of 3D parallelism is that it requires the modeling code to be modified, which is far from trivial; it's very easy to introduce bugs, and unfortunately we have seen some of those bugs stay invisible until it's too late to fix them.

The main pro of ZeRO (PyTorch FSDP or Deepspeed ZeRO-DP) is that the modeling code and the scalability code are separate, and the user only needs to get the modeling code right for a successful training. You just write the code as if you were training on a single GPU, and ZeRO can then scale it to any number of GPUs without you needing to do anything about it.

Now specifically to the claim "you should get about the same throughput with DeepSpeed ZeRO-3 as you'd with TP/PP/DP as long as you are on a fast internode connection" - I personally haven't seen that yet because I'm yet to be given a chance to run on a cluster where one gets high inter-node speed. The fastest I have used so far was 340Gbps w/ A100 which is very slow.

Given that ZeRO implementations prefetch the sharded data, this comms overhead is overlapped with the compute of the previous stage, and if the inter-node speed is fast enough it'll be mostly hidden, not contributing additional overhead. I walked through the math here:
https://github.com/stas00/ml-engineering/tree/master/model-parallelism#inter-node-speed-requirements-to-use-zero
and another similar but more general walk-through here:
https://github.com/stas00/ml-engineering/tree/master/network#understanding-why-inter-node-network-speed-is-of-a-huge-importance

Additionally, pytorch and Deepspeed recently released hybrid versions of ZeRO (Hybrid FSDP and ZeRO++): if you can fit the sharded model on a single node, they will use the super-fast NVLink for all comms and only do a DDP-style grads reduction over the slower inter-node link if multiple nodes are used (except it uses a faster version of comms than DDP, since each GPU needs only its gradient shard - so it's 1/2 of the DDP traffic). This should lead to much higher TFLOPS for models up to a certain size, e.g. an 8x80GB node can at most fit a ~30B-param model for mixed half-precision training. For larger models this won't help and the slower inter-node link becomes the defining speed for all comms :(
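
A quick back-of-the-envelope check of that ~30B figure (my arithmetic, using the usual ~16 bytes/param for mixed-precision Adam - fp16 weights and grads plus fp32 master weights, momentum and variance - with activations and buffers on top):

params = 30e9
bytes_per_param = 2 + 2 + 4 + 4 + 4     # fp16 weights, fp16 grads, fp32 master, momentum, variance
node_memory_gb = 8 * 80                  # 8 x 80GB GPUs
needed_gb = params * bytes_per_param / 1e9
print(f"{needed_gb:.0f}GB for states vs {node_memory_gb}GB on the node")
# -> 480GB for states vs 640GB on the node; the headroom goes to activations, buffers, fragmentation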

And even if the inter-node speed is slow and one gets fewer TFLOPS using ZeRO, the important consideration is whether your team will finish the training sooner or later using 3D parallelism, because it'll certainly take more development/testing time than the ZeRO equivalent. If you just need to train something that Megatron-LM has implemented fully, you should have close to 0 dev overhead. But if you need to introduce changes, you might spend weeks or months of dev time, depending on the complexity of the changes and the skills of your engineers. So it's possible that the slower ZeRO training will still deliver a faster outcome, and less hair will be lost in the process.

Also, did you know that Megatron-LM implemented ZeRO-1 and ZeRO-2 as well? And there is https://github.com/microsoft/Megatron-DeepSpeed/ - in other words, the awesome developers who give us these incredibly useful tools are working on combining the best features of both approaches.

@marianna13

Thank you very much for this detailed answer, @stas00 !
