Eliminate cuda syncs (#1374)
Taylor Robie authored Apr 29, 2024
1 parent 6014075 commit 4780604
Showing 3 changed files with 26 additions and 13 deletions.
24 changes: 16 additions & 8 deletions litgpt/lora.py
@@ -264,18 +264,22 @@ def __init__(
         total_qkv = q_per_kv + 2
         head_size = out_features // (self.n_query_groups * total_qkv)
         ind = range(out_features)
-        self.lora_ind = []
+        lora_ind = []
         if enable_q:
             q_ind = [x for x in ind if (x // head_size) % total_qkv < total_qkv - 2]
-            self.lora_ind.extend(q_ind)
+            lora_ind.extend(q_ind)
         if enable_k:
             k_ind = [x for x in ind if (x // head_size) % total_qkv == total_qkv - 2]
-            self.lora_ind.extend(k_ind)
+            lora_ind.extend(k_ind)
         if enable_v:
             v_ind = [x for x in ind if (x // head_size) % total_qkv == total_qkv - 1]
-            self.lora_ind.extend(v_ind)
+            lora_ind.extend(v_ind)
+        self._lora_ind = torch.tensor(lora_ind)
+        self._lora_ind_cache = {self._lora_ind.device: self._lora_ind}
         self.reset_parameters()

     def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
         """Properly pad weight updates with zeros.
@@ -328,15 +332,19 @@ def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
         # ⚬ enable_lora: [True, False, True]
         # Then x has embeddings_size of 256 (2 * 128 as enable_lora only for query and value, not keys) and expected
         # embeddings_size is 384 (self.linear.out_features), so that means that we need to pad from 256 to 384 with zeros, but
-        # only for key updates (this is where self.lora_ind comes in handy)
+        # only for key updates (this is where lora_ind comes in handy)
         # Note: double transpose (in the beginning and in the end) is basically a guard for two-dimensional tensors
         # for example when we want to merge/unmerge LoRA weights and pretrained weights
         x = x.transpose(0, 1)
         result = x.new_zeros((*x.shape[:-1], self.linear.out_features))  # (64, 64, 384)
         result = result.view(-1, self.linear.out_features)  # (4096, 384)
-        result = result.index_copy(
-            1, torch.tensor(self.lora_ind, device=result.device), x.reshape(-1, sum(self.qkv_shapes))
-        )  # (4096, 256)
+
+        # `lora_ind` is constant, so we want to avoid copying it (and incurring an expensive cudaStreamSynchronize)
+        # every time this method is called. So instead we simply cache a copy on each device that needs it.
+        if (lora_ind := self._lora_ind_cache.get(result.device)) is None:
+            self._lora_ind_cache[result.device] = lora_ind = self._lora_ind.to(result.device)
+
+        result = result.index_copy(1, lora_ind, x.reshape(-1, sum(self.qkv_shapes)))  # (4096, 256)
         return result.view((*x.shape[:-1], self.linear.out_features)).transpose(0, 1)  # (64, 64, 384)

     def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
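The core of the lora.py change is to build the constant `lora_ind` index tensor once and keep one copy per device, so `zero_pad` never rebuilds it from a Python list (a host-to-device copy, and with it a cudaStreamSynchronize) on the hot path. Below is a minimal standalone sketch of that caching pattern; the `ZeroPadExample` class and the `keep_ind`/`out_features` names are illustrative, not litgpt API.

import torch
from typing import List


class ZeroPadExample(torch.nn.Module):
    """Scatter a packed (batch, len(keep_ind)) tensor into a zero-padded (batch, out_features) one."""

    def __init__(self, out_features: int, keep_ind: List[int]) -> None:
        super().__init__()
        self.out_features = out_features
        # Build the constant index tensor once (on CPU) ...
        self._keep_ind = torch.tensor(keep_ind)
        # ... and cache one copy per device so repeated calls never transfer it again.
        self._keep_ind_cache = {self._keep_ind.device: self._keep_ind}

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if (ind := self._keep_ind_cache.get(x.device)) is None:
            # The first call on a new device pays for the transfer; every later call is a dict lookup.
            self._keep_ind_cache[x.device] = ind = self._keep_ind.to(x.device)
        result = x.new_zeros(x.shape[0], self.out_features)
        return result.index_copy(1, ind, x)


pad = ZeroPadExample(out_features=6, keep_ind=[0, 1, 4, 5])
print(pad(torch.ones(2, 4)))  # columns 2 and 3 stay zero

A registered buffer would also follow the module across `.to()`/`.cuda()` calls, but the explicit per-device dict keeps the constant out of the state dict and still works when inputs arrive on a device the module was never moved to.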
9 changes: 7 additions & 2 deletions litgpt/utils.py
@@ -272,7 +272,8 @@ def chunked_cross_entropy(
             for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
         ]
         non_masked_elems = (targets != ignore_index).sum()
-        return torch.cat(loss_chunks).sum() / max(1, non_masked_elems)
+        # See [non_masked_elems div note]
+        return torch.cat(loss_chunks).sum() / non_masked_elems.maximum(torch.ones_like(non_masked_elems))

     # no chunking at all
     logits = logits.reshape(-1, logits.size(-1))
@@ -288,7 +289,11 @@ def chunked_cross_entropy(
         for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
     ]
     non_masked_elems = (targets != ignore_index).sum()
-    return torch.cat(loss_chunks).sum() / max(1, non_masked_elems)
+    # [non_masked_elems div note]:
+    # max(1, non_masked_elems) would be more ergonomic to avoid a division by zero. However that
+    # results in a python int which is then passed back to torch division. By using the
+    # `x.maximum(torch.ones_like(x))` pattern we avoid a cudaStreamSynchronize.
+    return torch.cat(loss_chunks).sum() / non_masked_elems.maximum(torch.ones_like(non_masked_elems))


 def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict:
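The utils.py change replaces the Python-level `max(1, non_masked_elems)` with an on-device clamp: Python's `max()` has to inspect the value of the count tensor on the host before it can pick a winner, which blocks on the CUDA stream, whereas a tensor-level clamp stays asynchronous. A small sketch of the same guard in isolation; `masked_mean_loss` is a hypothetical helper, not litgpt API.

import torch


def masked_mean_loss(loss_sum: torch.Tensor, non_masked_elems: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: guard against dividing by zero without a device-to-host read.
    # max(1, non_masked_elems) would need the count's value on the host; the tensor-level
    # clamp below keeps the whole computation on-device and asynchronous.
    denom = non_masked_elems.maximum(torch.ones_like(non_masked_elems))
    return loss_sum / denom


# An all-masked batch: the denominator becomes 1 instead of 0, with no host round-trip.
print(masked_mean_loss(torch.tensor(0.0), torch.tensor(0)))  # tensor(0.)

`non_masked_elems.clamp_min(1)` is an equivalent, slightly shorter spelling of the same on-device clamp.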
6 changes: 3 additions & 3 deletions tests/test_lora.py
@@ -107,7 +107,7 @@ def test_lora_mqa_gqa():
     assert attn.linear.weight.shape == (24, 8)
     assert attn.lora_A.shape == (4, 8)
     assert attn.lora_B.shape == (16, 2)
-    assert attn.lora_ind == lora_ind
+    torch.testing.assert_allclose(attn._lora_ind, torch.tensor(lora_ind))
     x = torch.randint(0, 8, size=(3, 5, 16), dtype=torch.int64)
     assert attn.zero_pad(x).shape == (3, 5, 24)
     bsz, ctx_len, in_dim = 2, 30, 8
@@ -128,7 +128,7 @@ def test_lora_mqa_gqa():
     assert attn.linear.weight.shape == (12, 8)
     assert attn.lora_A.shape == (4, 8)
     assert attn.lora_B.shape == (10, 2)
-    assert attn.lora_ind == lora_ind
+    torch.testing.assert_allclose(attn._lora_ind, torch.tensor(lora_ind))
     x = torch.randint(0, 8, size=(3, 5, 10), dtype=torch.int64)
     assert attn.zero_pad(x).shape == (3, 5, 12)
     bsz, ctx_len, in_dim = 2, 30, 8
@@ -149,7 +149,7 @@ def test_lora_mqa_gqa():
     assert attn.linear.weight.shape == (16, 8)
     assert attn.lora_A.shape == (4, 8)
     assert attn.lora_B.shape == (12, 2)
-    assert attn.lora_ind == lora_ind
+    torch.testing.assert_allclose(attn._lora_ind, torch.tensor(lora_ind))
     x = torch.randint(0, 8, size=(3, 5, 12), dtype=torch.int64)
     assert attn.zero_pad(x).shape == (3, 5, 16)
     bsz, ctx_len, in_dim = 2, 30, 8
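The updated tests verify that `_lora_ind` holds the expected indices, but they cannot observe whether a synchronization still occurs. One way to check that on a CUDA machine (a hypothetical addition, not part of this commit's test suite) is PyTorch's sync debug mode, which raises on synchronizing CUDA calls:

import torch

if torch.cuda.is_available():
    torch.cuda.set_sync_debug_mode("error")  # raise if any op triggers a stream synchronization
    try:
        x = torch.randn(4096, 256, device="cuda")
        lora_ind = torch.arange(256, device="cuda")   # index already lives on the GPU
        out = torch.zeros(4096, 384, device="cuda")
        out.index_copy_(1, lora_ind, x)               # completes without a cudaStreamSynchronize
    finally:
        torch.cuda.set_sync_debug_mode("default")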
