LoRA: zero_pad speed improvements #630
Conversation
finetune/lora.py (Outdated)

@@ -250,7 +250,7 @@ def train(
         save_lora_checkpoint(fabric, model, checkpoint_path)


-@torch.inference_mode()
+@torch.no_grad()
In inference mode every new tensor is created as an inference tensor. In order to use such tensors for training, we have to clone them. Since every other fine-tuning script uses torch.no_grad for validation, I think it's easier/better to use this decorator here too.
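For context, a minimal sketch of the general PyTorch behavior behind this (not code from the PR): tensors created under torch.inference_mode() cannot be reused by autograd later unless they are cloned, whereas tensors created under torch.no_grad() are ordinary tensors.

```python
import torch

# A tensor created under inference_mode becomes an "inference tensor".
with torch.inference_mode():
    buf = torch.arange(4, dtype=torch.float32)

w = torch.randn(4, requires_grad=True)
try:
    (w * buf).sum().backward()  # autograd must save `buf` for backward
except RuntimeError as err:
    print(err)  # "Inference tensors cannot be saved for backward. ... make a clone ..."

# Cloning (or creating the tensor under no_grad instead) yields a normal tensor.
buf_ok = buf.clone()
(w * buf_ok).sum().backward()  # works
print(w.grad)
```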
Since validation (as a sanity check) is the first step of training, the lora_ind property is called for the first time there. So if validation runs in inference_mode, the indices are also stored in an inference tensor. That explains the issue above.
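A small illustration of that interaction (a simplified, hypothetical cache rather than the actual property): whichever context accesses a lazily cached tensor first determines what kind of tensor ends up in the cache.

```python
import torch

_cache = {}

def lora_ind() -> torch.Tensor:
    # Lazily built and cached on first access, like the lora_ind property.
    if "ind" not in _cache:
        _cache["ind"] = torch.arange(8)
    return _cache["ind"]

with torch.inference_mode():  # the sanity-check validation runs first...
    lora_ind()

# ...so the cached indices are an inference tensor, and later training
# steps that rely on them would fail.
print(lora_ind().is_inference())  # True
```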
indices.append(torch.arange(in_features + self.kv_embd_size, out_features, device=device))
self.register_buffer("_lora_ind", torch.cat(indices), persistent=False)

return self._lora_ind
One might ask why we don't place these indices on the target device (i.e. the GPU) in the init method. There is an issue with FSDP and meta devices if they are placed during init, so as a workaround a "lazy" initialization is used.
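A hedged sketch of that lazy-initialization pattern (names, shapes, and index layout are illustrative, not the exact lit-gpt code): the index buffer is only materialized on first use, at which point the layer's weight already lives on its real device, so FSDP/meta-device initialization never has to deal with it.

```python
import torch
from torch import nn

class QKVLinearSketch(nn.Module):
    def __init__(self, in_features: int, out_features: int, kv_embd_size: int):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=False)
        self.kv_embd_size = kv_embd_size
        # Note: no index tensor is created here, so meta-device / FSDP
        # initialization has nothing extra to materialize or shard.

    @property
    def lora_ind(self) -> torch.Tensor:
        if not hasattr(self, "_lora_ind"):
            # Resolved lazily: by now the weight sits on the real target device.
            device = self.linear.weight.device
            out_features = self.linear.out_features
            indices = [
                torch.arange(0, self.kv_embd_size, device=device),
                torch.arange(out_features - self.kv_embd_size, out_features, device=device),
            ]
            self.register_buffer("_lora_ind", torch.cat(indices), persistent=False)
        return self._lora_ind
```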
@@ -345,7 +339,7 @@ def merge(self):
            0
        )  # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128)
        # W = W + delta_W (merge)
-       self.linear.weight.data += self.zero_pad(delta_w * self.scaling)  # (256, 128) after zero_pad (384, 128)
+       self.linear.weight.data += self.zero_pad(delta_w.T * self.scaling).T  # (256, 128) after zero_pad (384, 128)
I think it's better to do the double transpose here (and only once), rather than every time in zero_pad, including the cases where it's not needed at all.
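A minimal sketch of that trade-off (hypothetical shapes and a simplified zero_pad, not the exact lit-gpt implementation): zero_pad stays specialized to the layout used by the hot forward path, and merge(), which runs only once, pays for the two transposes.

```python
import torch

def zero_pad(x: torch.Tensor, lora_ind: torch.Tensor, out_dim: int) -> torch.Tensor:
    # Scatter the trained slices into a zero tensor along the last dimension,
    # the layout the forward pass already uses.
    result = x.new_zeros(x.shape[0], out_dim)
    return result.index_copy_(1, lora_ind, x)

lora_ind = torch.tensor([0, 1, 2, 3, 8, 9, 10, 11])  # trained columns of a 12-wide output
delta_w = torch.randn(8, 4)                           # (out_trained, in) = lora_B @ lora_A
weight = torch.zeros(12, 4)                           # full (out, in) weight
scaling = 0.5

# merge(): transpose once so the padded dimension is last, pad, transpose back.
weight += zero_pad(delta_w.T * scaling, lora_ind, 12).T
print(weight.shape)  # torch.Size([12, 4])
```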
Hi there 👋

This PR is a result of #461. In that issue I found out that the creation of a new tensor from lora_ind (which is stored as a Python list on the CPU) for each zero_pad call ...

https://github.com/Lightning-AI/lit-gpt/blob/807c7bc17413d53961f96dc668aa03c0b970a43f/lit_gpt/lora.py#L293-L295

... implicitly calls cudaStreamSynchronize every time, and that slows down the forward pass a bit.
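To make the before/after concrete, here is a hedged sketch (simplified shapes and a standalone zero_pad, not the exact lit-gpt code) of the difference between rebuilding the index tensor from the Python list on every call and reusing a tensor that is materialized once on the device:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
lora_ind_list = list(range(0, 4096)) + list(range(8192, 12288))  # host-side Python list

def zero_pad_before(x: torch.Tensor, out_dim: int = 12288) -> torch.Tensor:
    # Before: a fresh index tensor is built from the Python list on every call,
    # forcing a host-to-device copy (and the implicit cudaStreamSynchronize).
    ind = torch.tensor(lora_ind_list, device=x.device)
    result = x.new_zeros(x.shape[0], out_dim)
    return result.index_copy_(1, ind, x)

# After: the indices are created once (in lit-gpt they are registered as a
# non-persistent buffer) and every zero_pad call reuses the on-device tensor.
lora_ind = torch.tensor(lora_ind_list, device=device)

def zero_pad_after(x: torch.Tensor, out_dim: int = 12288) -> torch.Tensor:
    result = x.new_zeros(x.shape[0], out_dim)
    return result.index_copy_(1, lora_ind, x)

x = torch.randn(4, len(lora_ind_list), device=device)
assert torch.equal(zero_pad_before(x), zero_pad_after(x))
```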
Traces

Note: numbers are provided for an Nvidia T4 with 16-mixed precision.

Let's take a look at the traces for Pythia-410m. Currently, zero_pad takes a significant portion of the time.

Note: compare the size of cudaStreamSynchronize in the screenshot above (CUDA 12.1) with the one from the "Performance Study" issue (CUDA 11.8) - it's much smaller thanks to the newer CUDA.

After the code is optimized, the trace shows that zero_pad takes a much smaller portion of the time. In numbers it's 830 μs vs 126 μs.

LoRA fine-tuning

Comparing LoRA fine-tuning with Pythia-410m over 1k iterations: not a drastic difference, but still a nice optimization.