Eliminate cuda syncs #1374
Conversation
Thanks a lot for the PR! Do you have some rough estimates of the performance before and after? E.g., if it's a noticeable difference, it could potentially be related to #1369
@rasbt It's going to be super case-dependent. (The LoRA one is definitely the much more important one.) I saw ~5%, but host-device syncs can vary from no difference to several-fold slowdown. For #1369 it's impossible to say anything without a profile. (It's not clear to me that it should be related, but stranger things have happened.)
By the way, I did an audit of other uses of …
Looks great, let's add a couple of comments so future readers understand
Added comments and fixed the …
 if enable_v:
     v_ind = [x for x in ind if (x // head_size) % total_qkv == total_qkv - 1]
-    self.lora_ind.extend(v_ind)
+    lora_ind.extend(v_ind)
+self._lora_ind = torch.tensor(lora_ind)
The LoRA training script supports FSDP with meta-device initialization, but this change breaks it: _lora_ind is now a tensor created on the meta device that never gets re-initialized.
self._lora_ind = torch.tensor(lora_ind)
should probably move to reset_parameters()
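For illustration, a rough sketch of that suggestion; the class is a simplified stand-in, not the repo's actual LoRA layer, and it assumes the FSDP/meta-device path re-runs reset_parameters() when the module is materialized:

import torch
import torch.nn as nn

class LoRAQKVSketch(nn.Module):
    def __init__(self, lora_ind: list):
        super().__init__()
        # Keep the plain Python list around; it is harmless under meta-device init.
        self._lora_ind_list = list(lora_ind)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Building the tensor here (instead of only in __init__) means a module
        # initialized on the meta device gets a real tensor once it is materialized.
        self._lora_ind = torch.tensor(self._lora_ind_list)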
Or we can keep _lora_ind as a Python list during initialization and place it on the target device as a tensor (inside the zero_pad method) if it's not already in the cache.
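A rough sketch of that alternative (illustrative names; the real zero_pad body is elided):

import torch

class ZeroPadCacheSketch:
    def __init__(self, lora_ind: list):
        # Stays a plain Python list at init time, so meta-device init is unaffected.
        self._lora_ind = list(lora_ind)
        self._lora_ind_cache = {}  # maps torch.device -> index tensor

    def _lora_ind_tensor(self, device: torch.device) -> torch.Tensor:
        # One host-to-device copy per device, instead of one per forward pass.
        if device not in self._lora_ind_cache:
            self._lora_ind_cache[device] = torch.tensor(self._lora_ind, device=device)
        return self._lora_ind_cache[device]

    def zero_pad(self, result: torch.Tensor) -> torch.Tensor:
        ind = self._lora_ind_tensor(result.device)
        ...  # the actual scatter of the LoRA outputs into `result` is omitted
        return result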
Replacing the changes in this PR with those in #770 could also be a good alternative
In theory that should work on a multi-GPU machine too, thanks to self.register_buffer, but I haven't checked it. Because of the higher cost of a multi-GPU machine, I almost never use more than a single GPU, so I'm lacking knowledge in this department. At the same time, the code in this PR looks fairly compact, so I don't have any particular preference.
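For reference, a sketch of the register_buffer mechanism being discussed (the exact change in #770 isn't reproduced here):

import torch
import torch.nn as nn

class LoRAIndexBufferSketch(nn.Module):
    def __init__(self, lora_ind: list):
        super().__init__()
        # A non-persistent buffer follows the module across .to()/.cuda() calls,
        # so zero_pad never needs to create the index tensor on the fly.
        self.register_buffer("lora_ind", torch.tensor(lora_ind), persistent=False)

m = LoRAIndexBufferSketch([0, 1, 4, 5])
if torch.cuda.is_available():
    m = m.to("cuda")
print(m.lora_ind.device)  # the buffer moves together with the module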
It just needs a test. You can let it run on CI. A simple test like:
@RunIf(standalone=True, min_cuda_gpus=2)
def test_lora_model_fsdp_init():
    config = ...
    fabric = Fabric(devices=2, strategy="fsdp", precision="16-true")
    fabric.launch()
    with fabric.init_module(empty_init=True):  # meta-device initialization
        model = GPT(config)
    x = ...
    model = fabric.setup(model)  # FSDP wrapping; meta tensors get materialized here
    y = model(x)
    assert y.shape == ...
Should catch this issue
Hah, I had already forgotten that I created a PR to eliminate an unnecessary CUDA sync during the …
This PR fixes two CUDA syncs that I ran across when optimizing Gemma:

1) max(1, non_masked_elems)
This punts to a Python int before being implicitly converted back to a Tensor. (I'm pretty sure I'm responsible for this one.) We need to use the uglier but more performant non_masked_elems.maximum(torch.ones_like(non_masked_elems)). (A sketch follows after the NOTE below.)
2) torch.tensor(self.lora_ind, device=result.device)
This one is a little harder because we genuinely do need to move data from host to device. However, lora_ind is set in __init__ and doesn't change, so the best we can do is cache it the first time we see it on a given device.

NOTE: It's very important that we do our own caching rather than use functools.cache, as the latter extends the life of self by storing it in the cache.
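For the first fix, here is a minimal before/after sketch; non_masked_elems and the -1 ignore value are stand-ins, not the repo's actual loss code:

import torch

# Stand-in for the masked-loss bookkeeping: a 0-dim count tensor on the compute device.
targets = torch.randint(-1, 10, (8,))        # -1 marks ignored positions (assumption)
non_masked_elems = (targets != -1).sum()

# Before: max() calls bool() on a device tensor to compare it with 1,
# which forces a device-to-host read, i.e. a CUDA sync.
# denom = max(1, non_masked_elems)

# After: the clamp-to-one stays on the device, so no sync is triggered.
denom = non_masked_elems.maximum(torch.ones_like(non_masked_elems))

And a small toy demonstration of the NOTE (not the repo's code), showing how functools.cache on a method keeps self alive:

import functools
import gc
import weakref

import torch

class Cached(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lora_ind = [0, 1, 2]

    @functools.cache  # the cache key includes `self`, stored on a class-level wrapper
    def lora_ind_on(self, device):
        return torch.tensor(self.lora_ind, device=device)

m = Cached()
m.lora_ind_on(torch.device("cpu"))
ref = weakref.ref(m)
del m
gc.collect()
print(ref() is None)  # False: the cache still holds a reference to the module

A per-device dict stored on the module itself avoids this, since that cache dies together with the module.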