
Eliminate cuda syncs #1374

Merged: 2 commits into main on Apr 29, 2024
Conversation

@robieta commented Apr 29, 2024

This PR fixes two CUDA syncs that I ran across when optimizing Gemma:

1) `max(1, non_masked_elems)`
This punts to a Python int before the value is implicitly converted back to a Tensor, which forces a host-device sync. (I'm pretty sure I'm responsible for this one.) We need to use the uglier but more performant `non_masked_elems.maximum(torch.ones_like(non_masked_elems))`.

2) `torch.tensor(self.lora_ind, device=result.device)`
This one is a little harder because we genuinely do need to move data from host to device. However, `lora_ind` is set in `__init__` and doesn't change, so the best we can do is cache the tensor the first time we see it on a given device.

NOTE: It's very important that we do our own caching rather than use `functools.cache`, as the latter extends the life of `self` by storing it in the cache.
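A minimal sketch of the two patterns described above (module and attribute names are illustrative, not the actual litgpt implementation):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Fix 1: clamp on-device instead of round-tripping through a Python int.
targets = torch.tensor([1, -100, 2, -100], device=device)            # toy data
non_masked_elems = (targets != -100).sum()                            # 0-dim tensor
denom = non_masked_elems.maximum(torch.ones_like(non_masked_elems))   # no host sync


# Fix 2: pay the host-to-device copy of the constant index list at most once
# per device, using a plain dict on the module (not functools.cache, which
# would keep `self` alive inside a module-level cache).
class LoRAModule(torch.nn.Module):  # hypothetical stand-in
    def __init__(self, lora_ind: list) -> None:
        super().__init__()
        self._lora_ind = torch.tensor(lora_ind)   # built once from the constant list
        self._lora_ind_cache: dict = {}           # device -> cached copy

    def _lora_ind_on(self, device: torch.device) -> torch.Tensor:
        if device not in self._lora_ind_cache:
            self._lora_ind_cache[device] = self._lora_ind.to(device)
        return self._lora_ind_cache[device]
```

With the dict keyed by device, the copy happens only on the first forward pass per device.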

@rasbt (Collaborator) commented Apr 29, 2024

Thanks a lot for the PR! Do you have rough estimates of the performance before and after? E.g., if it's a noticeable difference, it could potentially be related to #1369.

@robieta (Author) commented Apr 29, 2024

@rasbt It's going to be super case-dependent. (The LoRA one is definitely the much more important one.) I saw ~5%, but host-device syncs can vary from no difference to a several-fold slowdown. For #1369 it's impossible to say anything without a profile. (It's not clear to me that it should be related, but stranger things have happened.)
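Not part of the PR, but for anyone trying to gauge this on their own workload: PyTorch's CUDA sync debug mode can surface implicit host-device syncs. A small sketch with toy tensors (illustrative only):

```python
import torch

if torch.cuda.is_available():
    # Warn (or raise, with "error") whenever an op forces a CUDA synchronization.
    torch.cuda.set_sync_debug_mode("warn")

    n = (torch.randn(1024, device="cuda") > 0).sum()  # 0-dim CUDA tensor
    _ = max(1, n)                        # builtin max() reads the value on the host -> sync
    _ = n.maximum(torch.ones_like(n))    # stays on device, no warning

    torch.cuda.set_sync_debug_mode("default")
```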

@robieta (Author) commented Apr 29, 2024

By the way, I did an audit of other uses of torch.tensor and was pleasantly surprised to find no other cases that looked problematic. (Which is very unusual for a codebase of this size and complexity.) Thanks for keeping the bar high everyone!

@lantiga (Contributor) left a review:

Looks great. Let's add a couple of comments so future readers understand.

@robieta (Author) commented Apr 29, 2024

Added comments and fixed the lora_ind issue.

@robieta merged commit 4780604 into main on Apr 29, 2024 (9 checks passed)
@robieta deleted the robieta/eliminate_syncs branch on Apr 29, 2024 at 22:49
@awaelchli mentioned this pull request on Apr 30, 2024
Diff excerpt under discussion (litgpt/lora.py):

```diff
 if enable_v:
     v_ind = [x for x in ind if (x // head_size) % total_qkv == total_qkv - 1]
-    self.lora_ind.extend(v_ind)
+    lora_ind.extend(v_ind)
+self._lora_ind = torch.tensor(lora_ind)
```
Contributor:

The LoRA training script supports FSDP with meta-device initialization, but this change breaks that: `_lora_ind` is now a tensor created on the meta device that never gets re-initialized.

`self._lora_ind = torch.tensor(lora_ind)`

should probably move to `reset_parameters()`.
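A hedged sketch of that suggestion, assuming `reset_parameters()` is re-run when the meta-device model is materialized and that the raw indices are also kept as a plain Python list (names are illustrative):

```python
import torch

class LoRAModule(torch.nn.Module):  # hypothetical stand-in
    def __init__(self, lora_ind: list) -> None:
        super().__init__()
        self._lora_ind_list = list(lora_ind)  # a plain list survives meta-device init
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Rebuilding the tensor here means it is re-created on a real device at
        # materialization time instead of staying a stale meta-device tensor.
        self._lora_ind = torch.tensor(self._lora_ind_list)
```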

Collaborator:

Or we can keep `_lora_ind` as a Python list during initialization and place it on the target device as a tensor (inside the `zero_pad` method) if it's not already in the cache.
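Roughly, that alternative could look like the following (the real `zero_pad` does more than this; the attribute names and the padding-logic placeholder are assumptions):

```python
import torch

class LoRAModule(torch.nn.Module):  # hypothetical stand-in
    def __init__(self, lora_ind: list) -> None:
        super().__init__()
        self._lora_ind = list(lora_ind)  # no tensor (and no meta-device issue) at init
        self._lora_ind_cache: dict = {}  # device -> tensor

    def zero_pad(self, result: torch.Tensor) -> torch.Tensor:
        ind = self._lora_ind_cache.get(result.device)
        if ind is None:  # first call on this device: one host-to-device copy
            ind = torch.tensor(self._lora_ind, device=result.device)
            self._lora_ind_cache[result.device] = ind
        # ... the actual padding/scatter logic would use `ind` here ...
        return result
```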

Contributor:

Replacing the changes in this PR with those in #770 could also be a good alternative

Collaborator:

In theory that should work on a multi-GPU machine too, thanks to `self.register_buffer`, but I haven't checked it. Due to the higher cost of a multi-GPU machine, I almost never use more than a single GPU, so I'm lacking knowledge in this department.

At the same time, the code in this PR looks fairly compact, so I don't have any preference.
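For context, a minimal sketch of the `register_buffer` approach being referenced (class and attribute names assumed; #770's actual change may differ):

```python
import torch

class LoRAModule(torch.nn.Module):  # hypothetical stand-in
    def __init__(self, lora_ind: list) -> None:
        super().__init__()
        # A non-persistent buffer follows the module across devices
        # (model.to(device), distributed wrappers, etc.) and is excluded from
        # the state_dict, so no per-forward host-to-device copy is needed.
        self.register_buffer("_lora_ind", torch.tensor(lora_ind), persistent=False)
```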

@carmocca (Contributor) commented May 6, 2024

It just needs a test. You can let it run on CI. A simple test like:

```python
@RunIf(standalone=True, min_cuda_gpus=2)
def test_lora_model_fsdp_init():
    config = ...
    fabric = Fabric(devices=2, strategy="fsdp", precision="16-true")
    fabric.launch()
    with fabric.init_module(empty_init=True):
        model = GPT(config)
    x = ...
    model = fabric.setup(model)
    y = model(x)
    assert y.shape == ...
```

should catch this issue.

@Andrei-Aksionov (Collaborator) commented:

Hah, I had already forgotten that I created a PR to eliminate an unnecessary CUDA sync during the `zero_pad` call. I remember there was an issue with a CUDA stream overflow during the backward pass, which made the backward call slower, but thanks to the speedup during the forward pass the overall training time was still lower.
The funniest part is that I started to investigate it after @carmocca recommended watching the video where Taylor explained how to do profiling.
Eventually, @robieta fixed the issue himself 🙃.
