From 47806045a376cd4dbf48d6bfc41c97659239de4d Mon Sep 17 00:00:00 2001
From: Taylor Robie
Date: Mon, 29 Apr 2024 15:49:35 -0700
Subject: [PATCH] Eliminate cuda syncs (#1374)

---
 litgpt/lora.py     | 24 ++++++++++++++++--------
 litgpt/utils.py    |  9 +++++++--
 tests/test_lora.py |  6 +++---
 3 files changed, 26 insertions(+), 13 deletions(-)

diff --git a/litgpt/lora.py b/litgpt/lora.py
index 51fd66713d..8fee63cbb6 100644
--- a/litgpt/lora.py
+++ b/litgpt/lora.py
@@ -264,18 +264,22 @@ def __init__(
         total_qkv = q_per_kv + 2
         head_size = out_features // (self.n_query_groups * total_qkv)
         ind = range(out_features)
-        self.lora_ind = []
+        lora_ind = []
         if enable_q:
             q_ind = [x for x in ind if (x // head_size) % total_qkv < total_qkv - 2]
-            self.lora_ind.extend(q_ind)
+            lora_ind.extend(q_ind)
         if enable_k:
             k_ind = [x for x in ind if (x // head_size) % total_qkv == total_qkv - 2]
-            self.lora_ind.extend(k_ind)
+            lora_ind.extend(k_ind)
         if enable_v:
             v_ind = [x for x in ind if (x // head_size) % total_qkv == total_qkv - 1]
-            self.lora_ind.extend(v_ind)
+            lora_ind.extend(v_ind)
+        self._lora_ind = torch.tensor(lora_ind)
+        self._lora_ind_cache = {self._lora_ind.device: self._lora_ind}
         self.reset_parameters()
+
+
 
     def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
         """Properly pad weight updates with zeros.
 
@@ -328,15 +332,19 @@ def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
         # ⚬ enable_lora: [True, False, True]
         # Then x has embeddings_size of 256 (2 * 128 as enable_lora only for query and value, not keys) and expected
         # embeddings_size is 384 (self.linear.out_features), so that means that we need to pad from 256 to 384 with zeros, but
-        # only for key updates (this is where self.lora_ind comes in handy)
+        # only for key updates (this is where lora_ind comes in handy)
         # Note: double transpose (in the beginning and in the end) is basically a guard for two-dimensional tensors
         # for example when we want to merge/unmerge LoRA weights and pretrained weights
         x = x.transpose(0, 1)
         result = x.new_zeros((*x.shape[:-1], self.linear.out_features))  # (64, 64, 384)
         result = result.view(-1, self.linear.out_features)  # (4096, 384)
-        result = result.index_copy(
-            1, torch.tensor(self.lora_ind, device=result.device), x.reshape(-1, sum(self.qkv_shapes))
-        )  # (4096, 256)
+
+        # `lora_ind` is constant, so we want to avoid copying it (and incurring an expensive cudaStreamSynchronize)
+        # every time this method is called. So instead we simply cache a copy on each device that needs it.
+        if (lora_ind := self._lora_ind_cache.get(result.device)) is None:
+            self._lora_ind_cache[result.device] = lora_ind = self._lora_ind.to(result.device)
+
+        result = result.index_copy(1, lora_ind, x.reshape(-1, sum(self.qkv_shapes)))  # (4096, 256)
         return result.view((*x.shape[:-1], self.linear.out_features)).transpose(0, 1)  # (64, 64, 384)
 
     def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
diff --git a/litgpt/utils.py b/litgpt/utils.py
index 21f7f34a98..6eb7efbff4 100644
--- a/litgpt/utils.py
+++ b/litgpt/utils.py
@@ -272,7 +272,8 @@ def chunked_cross_entropy(
             for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
         ]
         non_masked_elems = (targets != ignore_index).sum()
-        return torch.cat(loss_chunks).sum() / max(1, non_masked_elems)
+        # See [non_masked_elems div note]
+        return torch.cat(loss_chunks).sum() / non_masked_elems.maximum(torch.ones_like(non_masked_elems))
 
     # no chunking at all
     logits = logits.reshape(-1, logits.size(-1))
@@ -288,7 +289,11 @@ def chunked_cross_entropy(
         for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
     ]
     non_masked_elems = (targets != ignore_index).sum()
-    return torch.cat(loss_chunks).sum() / max(1, non_masked_elems)
+    # [non_masked_elems div note]:
+    # max(1, non_masked_elems) would be more ergonomic to avoid a division by zero. However that
+    # results in a python int which is then passed back to torch division. By using the
+    # `x.maximum(torch.ones_like(x))` pattern we avoid a cudaStreamSynchronize.
+    return torch.cat(loss_chunks).sum() / non_masked_elems.maximum(torch.ones_like(non_masked_elems))
 
 
 def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict:
diff --git a/tests/test_lora.py b/tests/test_lora.py
index f8764c39bb..c09d07ee66 100644
--- a/tests/test_lora.py
+++ b/tests/test_lora.py
@@ -107,7 +107,7 @@ def test_lora_mqa_gqa():
     assert attn.linear.weight.shape == (24, 8)
     assert attn.lora_A.shape == (4, 8)
     assert attn.lora_B.shape == (16, 2)
-    assert attn.lora_ind == lora_ind
+    torch.testing.assert_allclose(attn._lora_ind, torch.tensor(lora_ind))
     x = torch.randint(0, 8, size=(3, 5, 16), dtype=torch.int64)
     assert attn.zero_pad(x).shape == (3, 5, 24)
     bsz, ctx_len, in_dim = 2, 30, 8
@@ -128,7 +128,7 @@ def test_lora_mqa_gqa():
     assert attn.linear.weight.shape == (12, 8)
     assert attn.lora_A.shape == (4, 8)
     assert attn.lora_B.shape == (10, 2)
-    assert attn.lora_ind == lora_ind
+    torch.testing.assert_allclose(attn._lora_ind, torch.tensor(lora_ind))
     x = torch.randint(0, 8, size=(3, 5, 10), dtype=torch.int64)
     assert attn.zero_pad(x).shape == (3, 5, 12)
     bsz, ctx_len, in_dim = 2, 30, 8
@@ -149,7 +149,7 @@ def test_lora_mqa_gqa():
     assert attn.linear.weight.shape == (16, 8)
     assert attn.lora_A.shape == (4, 8)
     assert attn.lora_B.shape == (12, 2)
-    assert attn.lora_ind == lora_ind
+    torch.testing.assert_allclose(attn._lora_ind, torch.tensor(lora_ind))
     x = torch.randint(0, 8, size=(3, 5, 12), dtype=torch.int64)
     assert attn.zero_pad(x).shape == (3, 5, 16)
     bsz, ctx_len, in_dim = 2, 30, 8
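
Below is a minimal standalone sketch (outside the diff, not part of the commit) of the two sync-avoidance patterns this patch applies: clamping a count tensor on-device instead of calling Python `max(1, ...)`, and caching a constant index tensor per device instead of re-materializing it with `torch.tensor(..., device=...)` on every call. The names `masked_mean_loss` and `ZeroPadCache` are illustrative only, not litgpt APIs; the sketch assumes a recent PyTorch where `Tensor.maximum` and `Tensor.index_copy` are available.

import torch


def masked_mean_loss(loss: torch.Tensor, targets: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    # Hypothetical helper: average an unreduced loss over non-ignored targets.
    non_masked_elems = (targets != ignore_index).sum()
    # `max(1, non_masked_elems)` would pull the count back to the host (cudaStreamSynchronize);
    # clamping with an on-device tensor keeps the computation asynchronous.
    return loss.sum() / non_masked_elems.maximum(torch.ones_like(non_masked_elems))


class ZeroPadCache(torch.nn.Module):
    # Hypothetical module: scatter the columns of `x` into a wider zero tensor at a fixed set of indices.
    def __init__(self, out_features: int, indices: list) -> None:
        super().__init__()
        self.out_features = out_features
        self._ind = torch.tensor(indices)  # built once, on CPU
        self._ind_cache = {self._ind.device: self._ind}

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (N, len(indices)) -> (N, out_features), zeros everywhere except the indexed columns.
        result = x.new_zeros(x.shape[0], self.out_features)
        # Rebuilding the index with `torch.tensor(indices, device=x.device)` on every call copies
        # host memory to the device and can synchronize; caching one copy per device avoids that.
        if (ind := self._ind_cache.get(x.device)) is None:
            self._ind_cache[x.device] = ind = self._ind.to(x.device)
        return result.index_copy(1, ind, x)

For example, `ZeroPadCache(24, lora_ind)(x)` pads an `(N, len(lora_ind))` update to `(N, 24)` while paying the host-to-device copy of the index only once per device, mirroring what `zero_pad` does above.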