LoRA: zero_pad speed improvements #630

Closed
wants to merge 16 commits into from

Conversation

Andrei-Aksionov (Collaborator) commented:

Hi there 👋

This PR is a result of #461.
In that issue I found out that creating a new tensor from lora_ind (which is stored as a Python list on the CPU) on each zero_pad call ...
https://github.com/Lightning-AI/lit-gpt/blob/807c7bc17413d53961f96dc668aa03c0b970a43f/lit_gpt/lora.py#L293-L295

... implicitly calls cudaStreamSynchronize every time, which slows down the forward pass a bit.
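To make the change concrete, here is a minimal sketch of the before/after. This is not the exact lit_gpt code: the function names are illustrative and the scatter is written with index_copy along the last dimension as an assumption.

```python
import torch

def zero_pad_per_call(x: torch.Tensor, lora_ind: list, padded_size: int) -> torch.Tensor:
    # Old behaviour: `lora_ind` is a Python list on the CPU, so this builds a
    # fresh index tensor and copies it host-to-device on every forward pass,
    # which is where the implicit cudaStreamSynchronize comes from.
    ind = torch.tensor(lora_ind, device=x.device)
    result = x.new_zeros(*x.shape[:-1], padded_size)
    return result.index_copy(-1, ind, x)

def zero_pad_cached(x: torch.Tensor, lora_ind: torch.Tensor, padded_size: int) -> torch.Tensor:
    # New behaviour: `lora_ind` is already a tensor on the target device
    # (registered once as a non-persistent buffer), so no per-call copy or
    # synchronization is needed.
    result = x.new_zeros(*x.shape[:-1], padded_size)
    return result.index_copy(-1, lora_ind, x)
```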

Traces

Note

Numbers are provided for an Nvidia T4 with 16-mixed precision.

Let's take a look at the traces for Pythia-410m.

Currently zero_pad takes a significant part of the time:
[Trace screenshot: before the optimization]

Note

Compare the size of cudaStreamSynchronize in the screenshot above (CUDA 12.1) with the one in the "Performance Study" issue (CUDA 11.8): it's much smaller thanks to the newer CUDA version.

After the code is optimized, the trace shows that zero_pad now takes a much smaller portion of the time:
[Trace screenshot: after the optimization]

In numbers: 830 μs vs. 126 μs.

LoRA fine-tuning

Comparing LoRA fine-tuning runs over 1k iterations, we have:

| Model | Loss $_{control}$ | Loss $_{test}$ | Time $_{control}$ | Time $_{test}$ |
|---|---|---|---|---|
| Pythia-70m | 2.5835 | 2.5802 | 30.90 | 28.51 |
| Pythia-410m | 1.7976 | 1.7976 | 124.63 | 114.51 |

Not a drastic difference, but still a nice optimization.

finetune/lora.py Outdated
@@ -250,7 +250,7 @@ def train(
save_lora_checkpoint(fabric, model, checkpoint_path)


-@torch.inference_mode()
+@torch.no_grad()
Andrei-Aksionov (Collaborator, Author) commented on Oct 9, 2023:

In inference mode, every new tensor is created as an inference tensor.
In order to use such tensors for training, we have to clone them.
Since every other fine-tuning script uses torch.no_grad for validation, I think it's easier/better to use this decorator here too.
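A tiny standalone illustration of the difference (not code from the repo; the parameter and shapes are arbitrary):

```python
import torch

w = torch.nn.Parameter(torch.randn(4))

with torch.inference_mode():
    buf = torch.randn(4)  # created as an inference tensor

try:
    # Fails: inference tensors cannot be saved for backward, so they cannot
    # take part in training later unless they are cloned first.
    (w * buf).sum().backward()
except RuntimeError as err:
    print(err)

with torch.no_grad():
    buf2 = torch.randn(4)  # a regular tensor; no autograd is recorded here

(w * buf2).sum().backward()  # works fine
```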

Andrei-Aksionov (Collaborator, Author) commented:

Since validation (as a sanity check) is the first step of training, the lora_ind property is called here first. So if validation runs in inference_mode, the indices are also stored in an inference tensor, which explains the issue above.

indices.append(torch.arange(in_features + self.kv_embd_size, out_features, device=device))
self.register_buffer("_lora_ind", torch.cat(indices), persistent=False)

return self._lora_ind
Andrei-Aksionov (Collaborator, Author) commented:

One might ask why we don't place these indices on the target device (i.e. the GPU) in the __init__ method.
There is an issue with FSDP and meta devices if they are placed during init, so a "lazy" initialization is used as a workaround.
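A minimal sketch of this lazy-initialization pattern, loosely following the names in the diff above. The stand-in class, its arguments, and the two arange calls are illustrative only; the real lit_gpt property builds the index list from the enabled LoRA projections.

```python
import torch
import torch.nn as nn

class QKVWithLoRA(nn.Module):
    """Illustrative stand-in, not the real lit_gpt.lora class."""

    def __init__(self, in_features: int, out_features: int, kv_embd_size: int):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.kv_embd_size = kv_embd_size
        # The indices are deliberately NOT built here: during FSDP setup the
        # module may still live on a meta device, so building them in __init__
        # would pin them to the wrong device.

    @property
    def lora_ind(self) -> torch.Tensor:
        # Built on the first access, once the module sits on its final device,
        # then cached as a non-persistent buffer so later calls reuse it.
        if not hasattr(self, "_lora_ind"):
            device = self.linear.weight.device
            in_features, out_features = self.linear.in_features, self.linear.out_features
            indices = [
                torch.arange(0, in_features, device=device),  # query block (illustrative)
                torch.arange(in_features + self.kv_embd_size, out_features, device=device),
            ]
            self.register_buffer("_lora_ind", torch.cat(indices), persistent=False)
        return self._lora_ind
```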

@@ -345,7 +339,7 @@ def merge(self):
0
) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128)
# W = W + delta_W (merge)
-self.linear.weight.data += self.zero_pad(delta_w * self.scaling)  # (256, 128) after zero_pad (384, 128)
+self.linear.weight.data += self.zero_pad(delta_w.T * self.scaling).T  # (256, 128) after zero_pad (384, 128)
Andrei-Aksionov (Collaborator, Author) commented on Oct 9, 2023:

I think it's better to do the double transpose here (and only once) rather than every time in zero_pad, including the cases where it's not needed at all.
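A hedged sketch of the trade-off (shapes taken from the inline comment; this zero_pad stand-in is assumed to fill along the last dimension, which is the layout the forward pass already produces, and lora_ind here is illustrative):

```python
import torch

def zero_pad_last_dim(x: torch.Tensor, lora_ind: torch.Tensor, padded_size: int) -> torch.Tensor:
    # Scatter `x` into zeros along the LAST dimension at the cached indices.
    result = x.new_zeros(*x.shape[:-1], padded_size)
    return result.index_copy(-1, lora_ind, x)

delta_w = torch.randn(256, 128)  # in merge, the padded dimension comes FIRST
lora_ind = torch.cat([torch.arange(0, 128), torch.arange(256, 384)])

# Transpose once around the call in `merge` instead of branching inside
# `zero_pad`, which runs on every forward pass.
merged_update = zero_pad_last_dim(delta_w.T, lora_ind, 384).T  # (384, 128)
print(merged_update.shape)
```

Since merge runs once while zero_pad is on the hot path, the merge call site is the cheaper place for the transposes.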

Andrei-Aksionov closed this by deleting the head repository on Nov 21, 2023.