Eliminate cuda syncs #1374

Merged · 2 commits · Apr 29, 2024
24 changes: 16 additions & 8 deletions litgpt/lora.py
@@ -264,18 +264,22 @@ def __init__(
         total_qkv = q_per_kv + 2
         head_size = out_features // (self.n_query_groups * total_qkv)
         ind = range(out_features)
-        self.lora_ind = []
+        lora_ind = []
         if enable_q:
             q_ind = [x for x in ind if (x // head_size) % total_qkv < total_qkv - 2]
-            self.lora_ind.extend(q_ind)
+            lora_ind.extend(q_ind)
         if enable_k:
             k_ind = [x for x in ind if (x // head_size) % total_qkv == total_qkv - 2]
-            self.lora_ind.extend(k_ind)
+            lora_ind.extend(k_ind)
         if enable_v:
             v_ind = [x for x in ind if (x // head_size) % total_qkv == total_qkv - 1]
-            self.lora_ind.extend(v_ind)
+            lora_ind.extend(v_ind)
+        self._lora_ind = torch.tensor(lora_ind)
Contributor:
The LoRA training script supports FSDP with meta-device initialization, but this change breaks that, because `_lora_ind` is now a tensor created on the meta device that never gets re-initialized.

    self._lora_ind = torch.tensor(lora_ind)

should probably move to reset_parameters().
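
For illustration, a minimal self-contained sketch of that suggestion (a toy module, not litgpt's actual LoRA layer; the `_lora_ind_list` name is invented):

import torch
import torch.nn as nn


class LoraIndexOwner(nn.Module):
    """Toy sketch: keep the indices as plain Python data and only materialize
    the tensor in reset_parameters(), which runs again when parameters are
    materialized after meta-device / empty initialization."""

    def __init__(self, lora_ind: list) -> None:
        super().__init__()
        self._lora_ind_list = lora_ind  # a plain list is unaffected by meta init
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # rebuilt on every (re-)initialization, so it never stays a meta tensor
        self._lora_ind = torch.tensor(self._lora_ind_list)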

Collaborator:
Or we could keep `_lora_ind` as a Python list during initialization and place it on the target device as a tensor (inside the zero_pad method) when it's not already in the cache.

Contributor:
Replacing the changes in this PR with those in #770 could also be a good alternative.

Collaborator:
In theory that should work on a multi-GPU machine too, thanks to self.register_buffer, but I haven't checked it. Because of the higher cost of multi-GPU machines I almost never use more than a single GPU, so I'm lacking knowledge in this department.

At the same time, the code in this PR looks fairly compact, so I don't have a strong preference either way.
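
For reference, a minimal sketch of what a register_buffer-based variant could look like (an assumption about the direction of #770, not its actual code):

import torch
import torch.nn as nn


class BufferedLoraIndex(nn.Module):
    """Toy sketch: a buffer travels with the module across .to(device) and
    wrappers such as FSDP, so no manual per-device cache is needed;
    persistent=False keeps it out of state_dict checkpoints."""

    def __init__(self, lora_ind: list) -> None:
        super().__init__()
        self.register_buffer("lora_ind", torch.tensor(lora_ind), persistent=False)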

Contributor (@carmocca, May 6, 2024):
It just needs a test; you can let it run on CI. A simple test like

    @RunIf(standalone=True, min_cuda_gpus=2)
    def test_lora_model_fsdp_init():
        config = ...
        fabric = Fabric(devices=2, strategy="fsdp", precision="16-true")
        fabric.launch()
        with fabric.init_module(empty_init=True):
            model = GPT(config)
        x = ...
        model = fabric.setup(model)
        y = model(x)
        assert y.shape == ...

should catch this issue.
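
For concreteness, a filled-in version of that sketch might look like the following; the tiny Config values, the input shape, and the RunIf import path are illustrative assumptions, not the exact test that eventually landed:

import torch
from lightning import Fabric

from litgpt.lora import GPT, Config

# RunIf is the repository's existing marker for CUDA/standalone tests;
# the import path below is assumed.
from conftest import RunIf


@RunIf(standalone=True, min_cuda_gpus=2)
def test_lora_model_fsdp_init():
    # deliberately tiny config so the test stays cheap; the values are illustrative
    config = Config(
        n_layer=1,
        n_head=2,
        n_embd=8,
        block_size=8,
        vocab_size=8,
        lora_r=2,
        lora_alpha=8,
        lora_query=True,
        lora_key=True,
        lora_value=True,
    )
    fabric = Fabric(devices=2, strategy="fsdp", precision="16-true")
    fabric.launch()
    with fabric.init_module(empty_init=True):
        model = GPT(config)  # parameters (and _lora_ind) start out on the meta device
    model = fabric.setup(model)
    x = torch.randint(0, config.padded_vocab_size, (2, 5), device=fabric.device)
    y = model(x)
    assert y.shape == (2, 5, config.padded_vocab_size)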

+        self._lora_ind_cache = {self._lora_ind.device: self._lora_ind}
         self.reset_parameters()

     def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
         """Properly pad weight updates with zeros.

@@ -328,15 +332,19 @@ def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
         # ⚬ enable_lora: [True, False, True]
         # Then x has embeddings_size of 256 (2 * 128 as enable_lora only for query and value, not keys) and expected
         # embeddings_size is 384 (self.linear.out_features), so that means that we need to pad from 256 to 384 with zeros, but
-        # only for key updates (this is where self.lora_ind comes in handy)
+        # only for key updates (this is where lora_ind comes in handy)
         # Note: double transpose (in the beginning and in the end) is basically a guard for two-dimensional tensors
         # for example when we want to merge/unmerge LoRA weights and pretrained weights
         x = x.transpose(0, 1)
         result = x.new_zeros((*x.shape[:-1], self.linear.out_features)) # (64, 64, 384)
         result = result.view(-1, self.linear.out_features) # (4096, 384)
-        result = result.index_copy(
-            1, torch.tensor(self.lora_ind, device=result.device), x.reshape(-1, sum(self.qkv_shapes))
-        ) # (4096, 256)
+
+        # `lora_ind` is constant, so we want to avoid copying it (and incurring an expensive cudaStreamSynchronize)
+        # every time this method is called. So instead we simply cache a copy on each device that needs it.
+        if (lora_ind := self._lora_ind_cache.get(result.device)) is None:
+            self._lora_ind_cache[result.device] = lora_ind = self._lora_ind.to(result.device)
+
+        result = result.index_copy(1, lora_ind, x.reshape(-1, sum(self.qkv_shapes))) # (4096, 256)
         return result.view((*x.shape[:-1], self.linear.out_features)).transpose(0, 1) # (64, 64, 384)

     def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
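As an aside (not part of this PR), PyTorch's sync debug mode is a convenient way to confirm that a change like this really removes hidden host synchronizations:

import torch

# Sketch: audit a code path for implicit CUDA synchronizations. In "warn" mode
# PyTorch emits a warning for operations that synchronize with the host
# (e.g. calling .item() on a CUDA tensor); "error" raises instead.
if torch.cuda.is_available():
    torch.cuda.set_sync_debug_mode("warn")
    try:
        # run the code to audit here, e.g. a forward pass that goes through the
        # patched zero_pad, and watch the output for sync warnings
        pass
    finally:
        torch.cuda.set_sync_debug_mode("default")
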
9 changes: 7 additions & 2 deletions litgpt/utils.py
@@ -272,7 +272,8 @@ def chunked_cross_entropy(
             for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
         ]
         non_masked_elems = (targets != ignore_index).sum()
-        return torch.cat(loss_chunks).sum() / max(1, non_masked_elems)
+        # See [non_masked_elems div note]
+        return torch.cat(loss_chunks).sum() / non_masked_elems.maximum(torch.ones_like(non_masked_elems))

     # no chunking at all
     logits = logits.reshape(-1, logits.size(-1))
@@ -288,7 +289,11 @@ def chunked_cross_entropy(
         for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
     ]
     non_masked_elems = (targets != ignore_index).sum()
-    return torch.cat(loss_chunks).sum() / max(1, non_masked_elems)
+    # [non_masked_elems div note]:
+    # max(1, non_masked_elems) would be more ergonomic to avoid a division by zero. However that
+    # results in a python int which is then passed back to torch division. By using the
+    # `x.maximum(torch.ones_like(x))` pattern we avoid a cudaStreamSynchronize.
+    return torch.cat(loss_chunks).sum() / non_masked_elems.maximum(torch.ones_like(non_masked_elems))


def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict:
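The divisor trick in isolation, as a minimal sketch (not the litgpt function itself):

import torch


def safe_divisor(count: torch.Tensor) -> torch.Tensor:
    """Clamp a 0-dim count tensor to at least 1 without leaving the device.

    max(1, count) compares a device tensor with a Python int, and resolving that
    comparison to a bool forces a device-to-host read (a stream synchronization).
    Staying in tensor ops keeps the whole computation asynchronous.
    """
    return count.maximum(torch.ones_like(count))  # torch.clamp(count, min=1) also works


# Usage, mirroring the chunked_cross_entropy change above:
loss_sum = torch.tensor(3.0)  # stand-in for torch.cat(loss_chunks).sum()
non_masked = torch.tensor(0)  # can legitimately be zero when everything is masked
mean_loss = loss_sum / safe_divisor(non_masked)
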
6 changes: 3 additions & 3 deletions tests/test_lora.py
@@ -107,7 +107,7 @@ def test_lora_mqa_gqa():
     assert attn.linear.weight.shape == (24, 8)
     assert attn.lora_A.shape == (4, 8)
     assert attn.lora_B.shape == (16, 2)
-    assert attn.lora_ind == lora_ind
+    torch.testing.assert_allclose(attn._lora_ind, torch.tensor(lora_ind))
     x = torch.randint(0, 8, size=(3, 5, 16), dtype=torch.int64)
     assert attn.zero_pad(x).shape == (3, 5, 24)
     bsz, ctx_len, in_dim = 2, 30, 8
@@ -128,7 +128,7 @@ def test_lora_mqa_gqa():
     assert attn.linear.weight.shape == (12, 8)
     assert attn.lora_A.shape == (4, 8)
     assert attn.lora_B.shape == (10, 2)
-    assert attn.lora_ind == lora_ind
+    torch.testing.assert_allclose(attn._lora_ind, torch.tensor(lora_ind))
     x = torch.randint(0, 8, size=(3, 5, 10), dtype=torch.int64)
     assert attn.zero_pad(x).shape == (3, 5, 12)
     bsz, ctx_len, in_dim = 2, 30, 8
@@ -149,7 +149,7 @@ def test_lora_mqa_gqa():
     assert attn.linear.weight.shape == (16, 8)
     assert attn.lora_A.shape == (4, 8)
     assert attn.lora_B.shape == (12, 2)
-    assert attn.lora_ind == lora_ind
+    torch.testing.assert_allclose(attn._lora_ind, torch.tensor(lora_ind))
     x = torch.randint(0, 8, size=(3, 5, 12), dtype=torch.int64)
     assert attn.zero_pad(x).shape == (3, 5, 16)
     bsz, ctx_len, in_dim = 2, 30, 8