From 6150d04ff3b199ddefbe55e58d593ecae587b9d9 Mon Sep 17 00:00:00 2001 From: ebsmothers Date: Mon, 5 Feb 2024 11:23:18 -0800 Subject: [PATCH] Fix LoRA indices for interleaved QKV weights (#900) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- lit_gpt/lora.py | 45 ++++++++++++++++++++++++++++++++------------- tests/test_lora.py | 37 +++++++++++++++++++++++++++++++++---- 2 files changed, 65 insertions(+), 17 deletions(-) diff --git a/lit_gpt/lora.py b/lit_gpt/lora.py index 80da55a242..fed1a11cb1 100644 --- a/lit_gpt/lora.py +++ b/lit_gpt/lora.py @@ -261,23 +261,21 @@ def __init__( self.scaling = self.lora_alpha / self.r # Compute the indices - # Indices are needed to properly pad weight updates with zeros. If we want to fine-tune queries and values, - # but not keys, then the weights update should be: - # - # [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,], - # [....................................], - # [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]] - # ↑ ↑ ↑ - # ________________________________________ - # | query | key | value | - # ---------------------------------------- + # Indices are needed to properly pad weight updates with zeros in `zero_pad` method. + q_per_kv = self.n_head // self.n_query_groups + total_qkv = q_per_kv + 2 + head_size = out_features // (self.n_query_groups * total_qkv) + ind = range(out_features) self.lora_ind = [] if enable_q: - self.lora_ind.extend(range(0, self.linear.in_features)) + q_ind = [x for x in ind if (x // head_size) % total_qkv < total_qkv - 2] + self.lora_ind.extend(q_ind) if enable_k: - self.lora_ind.extend(range(self.linear.in_features, self.linear.in_features + self.kv_embd_size)) + k_ind = [x for x in ind if (x // head_size) % total_qkv == total_qkv - 2] + self.lora_ind.extend(k_ind) if enable_v: - self.lora_ind.extend(range(self.linear.in_features + self.kv_embd_size, self.linear.out_features)) + v_ind = [x for x in ind if (x // head_size) % total_qkv == total_qkv - 1] + self.lora_ind.extend(v_ind) self.reset_parameters() def zero_pad(self, x: torch.Tensor) -> torch.Tensor: @@ -293,6 +291,27 @@ def zero_pad(self, x: torch.Tensor) -> torch.Tensor: ________________________________________ | query | key | value | ---------------------------------------- + For Llama2's GQA support, Q, K, and V weights are interleaved, so that weights for grouped + queries are adjacent to their associated key and value weights. + For example, suppose we have n_head = 12 with 3 query groups. + Then along the embedding dimension the interleaved weights would look like + + [Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V], + + where each Q, K, and V has size head_size. + + In this case, the previously-described weight update applies separately to each + individual block, so the update will take the form + + [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ...], + [.............................................................................], + [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ...]] + ↑ ↑ ↑ ↑ ↑ ↑ + ________________________________________________________________________________ + | q block 1 | k block 1 | v block 1 | q block 2 | k block 2 | v block 2 | ... + -------------------------------------------------------------------------------- + Note that in the above diagram, the size of each q block will equal q_per_kv + times the size of each k and v block. Args: x: tensor with weights update that will be padded with zeros if necessary diff --git a/tests/test_lora.py b/tests/test_lora.py index da9967a52b..c6fa964f74 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -98,35 +98,64 @@ def test_lora_mqa_gqa(): assert config.n_query_groups == config.n_head model = GPT(config) attn = model.transformer.h[0].attn.attn + for p in attn.linear.parameters(): + torch.nn.init.zeros_(p) + torch.nn.init.ones_(attn.lora_B) + lora_ind = [0, 1, 6, 7, 12, 13, 18, 19, 4, 5, 10, 11, 16, 17, 22, 23] assert attn.linear.weight.shape == (24, 8) assert attn.lora_A.shape == (4, 8) assert attn.lora_B.shape == (16, 2) - assert attn.lora_ind == [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] + assert attn.lora_ind == lora_ind x = torch.randint(0, 8, size=(3, 5, 16), dtype=torch.int64) assert attn.zero_pad(x).shape == (3, 5, 24) + bsz, ctx_len, in_dim = 2, 30, 8 + x_in = torch.randn(bsz, ctx_len, in_dim) + out = attn(x_in) + non_lora_ind = list(set(range(24)).difference(lora_ind)) + assert torch.count_nonzero(out[:, :, lora_ind]) == bsz * ctx_len * len(lora_ind) + assert torch.count_nonzero(out[:, :, non_lora_ind]) == 0 # MQA config.n_query_groups = 1 model = GPT(config) attn = model.transformer.h[0].attn.attn + for p in attn.linear.parameters(): + torch.nn.init.zeros_(p) + torch.nn.init.ones_(attn.lora_B) + lora_ind = [0, 1, 2, 3, 4, 5, 6, 7, 10, 11] assert attn.linear.weight.shape == (12, 8) assert attn.lora_A.shape == (4, 8) assert attn.lora_B.shape == (10, 2) - assert attn.lora_ind == [0, 1, 2, 3, 4, 5, 6, 7, 10, 11] + assert attn.lora_ind == lora_ind x = torch.randint(0, 8, size=(3, 5, 10), dtype=torch.int64) assert attn.zero_pad(x).shape == (3, 5, 12) + bsz, ctx_len, in_dim = 2, 30, 8 + x_in = torch.randn(bsz, ctx_len, in_dim) + out = attn(x_in) + non_lora_ind = list(set(range(12)).difference(lora_ind)) + assert torch.count_nonzero(out[:, :, lora_ind]) == bsz * ctx_len * len(lora_ind) + assert torch.count_nonzero(out[:, :, non_lora_ind]) == 0 # GQA config.n_query_groups = 2 model = GPT(config) attn = model.transformer.h[0].attn.attn + for p in attn.linear.parameters(): + torch.nn.init.zeros_(p) + torch.nn.init.ones_(attn.lora_B) + lora_ind = [0, 1, 2, 3, 8, 9, 10, 11, 6, 7, 14, 15] assert attn.linear.weight.shape == (16, 8) assert attn.lora_A.shape == (4, 8) assert attn.lora_B.shape == (12, 2) - assert attn.lora_ind == [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15] + assert attn.lora_ind == lora_ind x = torch.randint(0, 8, size=(3, 5, 12), dtype=torch.int64) assert attn.zero_pad(x).shape == (3, 5, 16) - + bsz, ctx_len, in_dim = 2, 30, 8 + x_in = torch.randn(bsz, ctx_len, in_dim) + out = attn(x_in) + non_lora_ind = list(set(range(16)).difference(lora_ind)) + assert torch.count_nonzero(out[:, :, lora_ind]) == bsz * ctx_len * len(lora_ind) + assert torch.count_nonzero(out[:, :, non_lora_ind]) == 0 def test_lora_filter(tmp_path): from lit_gpt.lora import GPT, lora_filter