Get 70b in our fork working with pp4, tp4, dp>1 #19

Open
ischlag opened this issue Sep 3, 2024 · 6 comments

ischlag commented Sep 3, 2024

Using our launcher and the latest pull of our pretrain repo, you can run a Llama3 70B model as follows. Thanks to @AleHD for getting activation recompute and async TP communication working.

(export DP=1 PP=4 BACC=16; python launcher.py nanotron=llama3_70b \
	run.slurm.time="0:19:56" \
	nanotron.tokens.train_steps=15 \
	nanotron.general.project=Llama3_70B_efficiency \
	nanotron.parallelism.dp=$DP \
	nanotron.parallelism.pp=$PP \
	nanotron.tokens.batch_accumulation_per_replica=$BACC \
	++nanotron.parallelism.tp_linear_async_communication=true \
	++nanotron.parallelism.tp_recompute_allgather=true \
	nanotron.general.run=baseline_dp${DP}_pp${PP}_tp4_baccum${BACC})

(The shell variables avoid having to set the values twice and ensure the config always matches the run name.)

This should work and reach about 780 tokens per second per GPU. Throughput will be higher with batch_accumulation_per_replica=64 and lower with batch_accumulation_per_replica=16, which lets us trade off wall time for efficiency.

The problem is that 70B only just fits into four tödi nodes with dp=1, pp=4, and tp=4. It is so close that it OOMs with dp>1.

Looking at memory usage, we can see that each pipeline stage from first to last uses about 10 GB less memory than the previous one, so the OOM happens on the node hosting the first PP stage. In nanotron, the splitting of the layers across stages is done automatically using `get_block_compute_costs()`. I think we should use this issue to check whether we can reduce the workload (and thus the memory load) of the first stage, and open a custom PR that fits the model with pp=4, tp=4, and dp>1 by changing how nanotron splits the architecture across nodes.
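
To illustrate the idea (a standalone sketch, not nanotron code): under a 1F1B schedule, stage i keeps roughly pp - i microbatches of activations in flight, so a memory-aware split would give the first stage fewer layers than a purely compute-balanced one. The function name, layer count, and cost weights below are made up for illustration.

def memory_aware_split(num_layers: int, pp: int, act_per_layer: float = 1.0,
                       weights_per_layer: float = 1.0) -> list[int]:
    """Number of decoder layers to place on each pipeline stage."""
    # Estimated memory cost of adding one layer to stage i: its weights plus its
    # activations times the number of in-flight microbatches that stage holds
    # under 1F1B (roughly pp - i, so stage 0 holds the most).
    per_layer_cost = [weights_per_layer + act_per_layer * (pp - i) for i in range(pp)]
    split = [0] * pp
    stage_mem = [0.0] * pp
    for _ in range(num_layers):
        # Greedily place the next layer on the stage that would stay cheapest.
        i = min(range(pp), key=lambda s: stage_mem[s] + per_layer_cost[s])
        split[i] += 1
        stage_mem[i] += per_layer_cost[i]
    return split

if __name__ == "__main__":
    # Llama3 70B has 80 decoder layers; pp=4 as in the command above.
    print(memory_aware_split(80, 4))                     # first stage gets the fewest layers
    print(memory_aware_split(80, 4, act_per_layer=0.0))  # uniform [20, 20, 20, 20]

With act_per_layer=0 this degenerates to the uniform 20/20/20/20 split that pure compute balancing gives for identical decoder layers.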

TJ-Solergibert self-assigned this Sep 3, 2024

C-TC commented Sep 3, 2024

A simple mod in the PP split strategy: #20
If the workload in each PP stage is imbalanced, the bubble ratio becomes larger. I'm wondering whether it still OOMs when ZeRO-1 is enabled with a higher DP size.
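
For a rough sense of why a higher DP size could help: with ZeRO-1 the fp32 optimizer states are sharded across DP ranks, while the bf16 params and grads stay replicated. A back-of-envelope sketch (my assumptions, not measurements; nanotron's actual mixed-precision layout may differ, e.g. a separate fp32 master copy would add another 4 bytes per parameter):

# Rough per-GPU static memory for a 70B model at tp=4, pp=4, assuming
# bf16 params and grads plus fp32 Adam moments sharded by ZeRO-1.
PARAMS = 70e9
TP, PP = 4, 4
params_per_rank = PARAMS / (TP * PP)  # parameters held by one model-parallel rank

bf16_params_and_grads = params_per_rank * (2 + 2)  # bytes, replicated across DP
fp32_adam_moments = params_per_rank * (4 + 4)      # bytes, sharded across DP by ZeRO-1

for dp in (1, 2, 4, 8):
    total_gib = (bf16_params_and_grads + fp32_adam_moments / dp) / 2**30
    print(f"dp={dp}: ~{total_gib:.0f} GiB static (params, grads, optimizer), before activations")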


AleHD commented Sep 5, 2024

I ran a few tests. With acc=16 it does not OOM when using ZeRO-1 for dp>1; however, with acc=128 it OOMs for dp=2 and dp=4, but works again for dp=8...


C-TC commented Sep 6, 2024

It sounds like some activation memory is not released in time. Can we run a small-scale experiment, e.g. TP=1, PP=4, DP=2,4,..., and get a memory snapshot from PyTorch? Unfortunately, this functionality is only available on x86 machines.
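
For reference, capturing a snapshot looks roughly like this (a minimal sketch assuming a recent PyTorch; the workload and file name are placeholders):

import torch

# Start recording allocation/free events (private but documented PyTorch API).
torch.cuda.memory._record_memory_history(max_entries=100_000)

try:
    # Placeholder for a few training steps of the small-scale DP=2/4 experiment.
    x = torch.randn(1024, 1024, device="cuda")
    y = x @ x
finally:
    # Dump a pickle that can be opened at https://pytorch.org/memory_viz
    torch.cuda.memory._dump_snapshot("pp_memory_snapshot.pickle")
    torch.cuda.memory._record_memory_history(enabled=None)  # stop recording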


idoh commented Sep 8, 2024

No memory profile because of the x86 requirement, but I agree that the most likely cause is an activation memory bug: in the DP=1 case the max memory is 50.2 GB and each GPU has almost 100 GB. Changing to DP=2 shouldn't double the memory.


idoh commented Sep 8, 2024

I'll try to get a memory snapshot on Bristen, as we should have x86 machines there.

ischlag commented Sep 17, 2024

Any update on this? Has someone committed to it? I saw this PR, @C-TC, does it solve this issue?
