LoRA: zero_pad speed improvements #770
Conversation
I did a very quick benchmark with Pythia-410m on 1xT4 comparing the code from this PR against the current main. It would be nice if someone with access to a multi-GPU machine could run a quick LoRA finetune just to make sure.
LGTM. I'll run finetuning
x = torch.randint(0, config.padded_vocab_size, size=(2, config.block_size), dtype=torch.int64, device=fabric.device)
model = fabric.setup(model)
y = model(x)
# logits shape: (batch_size, block_size, padded_vocab_size)
assert y.shape == torch.Size([2, 8, 512])
@Andrei-Aksionov Could we maybe add a sanity test that iterates over all model attributes of all submodules and asserts that, if an attribute is a tensor, its .is_meta is False (a rough sketch follows the excerpt below)? The previous bug wasn't caught simply because the LoRA defaults for query, key and value were all True, which would essentially skip this code path:
lit_gpt/lora.py, lines 330 to 342 in 90a16e4
if all(self.enable_lora):
    return x
# Let's imagine that:
# ⚬ input x has shape (64, 64, 256): (batch_size, sequence_length, embeddings_size)
# ⚬ embeddings_size: 128
# ⚬ self.linear.out_features: 384 (3 * embeddings_size)
# ⚬ enable_lora: [True, False, True]
# Then x has an embeddings_size of 256 (2 * 128, since enable_lora is True only for query and value, not key) and the
# expected embeddings_size is 384 (self.linear.out_features), so we need to pad from 256 to 384 with zeros, but
# only for key updates (this is where self.lora_ind comes in handy)
result = x.new_zeros(*x.shape[:-1], self.linear.out_features)  # (64, 64, 384)
return result.index_copy_(dim=-1, index=self.lora_ind, source=x)  # (64, 64, 384)
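Such a check could be a small helper along these lines (a rough sketch; assert_no_meta_tensors is a hypothetical name, not existing test code):

```python
import torch


def assert_no_meta_tensors(model: torch.nn.Module) -> None:
    # Fail if any tensor-valued attribute of any submodule is still on the meta device.
    for module_name, module in model.named_modules():
        for attr_name, attr in vars(module).items():
            if isinstance(attr, torch.Tensor):
                assert not attr.is_meta, f"{module_name}.{attr_name} is a meta tensor"
    # Parameters and buffers live in internal dicts, so check them explicitly as well.
    for name, tensor in list(model.named_parameters()) + list(model.named_buffers()):
        assert not tensor.is_meta, f"{name} is a meta tensor"
```

Calling it on the model right after fabric.setup(model) in the test above would flag any tensor attribute that was left on the meta device.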
Makes sense.
Sure, I'll do this.
While experimenting with GitHub Actions I deleted my fork (I know, I know) and thus all open PRs were automatically closed. This PR mirrors #630.
Hi there 👋

This PR is a result of #461. In that issue I found out that the creation of a new tensor with lora_ind (which is stored as a Python list on the CPU) for each zero_pad call ...

https://github.com/Lightning-AI/lit-gpt/blob/807c7bc17413d53961f96dc668aa03c0b970a43f/lit_gpt/lora.py#L293-L295

... implicitly calls cudaStreamSynchronize every time, which slows down the forward pass a bit.
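To illustrate the direction of the fix, here is a minimal sketch of the idea (ZeroPadSketch is an illustrative module, not the actual lit_gpt.lora code): keep the indices as a tensor that travels with the model, so zero_pad never has to rebuild it from a Python list on every call.

```python
import torch


class ZeroPadSketch(torch.nn.Module):
    """Illustrative only: pads updates with zeros along the last dimension."""

    def __init__(self, lora_ind: list, out_features: int) -> None:
        super().__init__()
        self.out_features = out_features
        # A buffer is moved to the target device once, together with the model,
        # instead of being re-created from a Python list (and synchronizing the
        # CUDA stream) on every zero_pad call.
        self.register_buffer("lora_ind", torch.tensor(lora_ind), persistent=False)

    def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
        result = x.new_zeros(*x.shape[:-1], self.out_features)
        return result.index_copy_(dim=-1, index=self.lora_ind, source=x)
```

With this layout, moving the module to a device (e.g. via fabric.setup) also moves lora_ind, so the forward pass only launches the index_copy_ kernel without an extra host-to-device copy.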
Traces

Note: Numbers are provided for an Nvidia T4 and 16-mixed precision.

Let's take a look at the traces for Pythia-410m. Currently, zero_pad takes a significant part of the time:

Note: Compare the size of cudaStreamSynchronize in the screenshot above (CUDA 12.1) with the one from the "Performance Study" issue (CUDA 11.8) - it's much smaller thanks to the newer CUDA.

After the code is optimized, the trace shows that zero_pad now takes a much smaller portion of the time:

In numbers, it's 830 μs vs 126 μs.
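For anyone who wants to reproduce this kind of measurement, a Chrome trace can be captured with torch.profiler; the snippet below is a generic sketch with a stand-in model, not the exact setup used for the numbers above.

```python
import torch
from torch.profiler import ProfilerActivity, profile

# Stand-in model and input; the traces in this PR were collected for Pythia-410m.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Sequential(torch.nn.Embedding(512, 256), torch.nn.Linear(256, 512)).to(device)
x = torch.randint(0, 512, (2, 8), device=device)

activities = [ProfilerActivity.CPU] + ([ProfilerActivity.CUDA] if device == "cuda" else [])
with profile(activities=activities) as prof:
    with torch.no_grad():
        for _ in range(10):
            model(x)

# Inspect the result in chrome://tracing or Perfetto.
prof.export_chrome_trace("trace.json")
```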
LoRA fine-tuning

Comparing LoRA fine-tuning with Pythia-410m and 1k iterations, we have:

Not a drastic difference, but still a nice optimization.