Fix LoRA indices for interleaved QKV weights (#900)
Co-authored-by: Carlos Mocholí <[email protected]>
ebsmothers and carmocca authored Feb 5, 2024
1 parent d80842c commit 6150d04
Showing 2 changed files with 65 additions and 17 deletions.
45 changes: 32 additions & 13 deletions lit_gpt/lora.py
@@ -261,23 +261,21 @@ def __init__(
        self.scaling = self.lora_alpha / self.r

        # Compute the indices
-        # Indices are needed to properly pad weight updates with zeros. If we want to fine-tune queries and values,
-        # but not keys, then the weights update should be:
-        #
-        # [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],
-        #  [....................................],
-        #  [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]
-        #      ↑              ↑            ↑
-        # ________________________________________
-        # | query         | key       | value    |
-        # ----------------------------------------
+        # Indices are needed to properly pad weight updates with zeros in `zero_pad` method.
+        q_per_kv = self.n_head // self.n_query_groups
+        total_qkv = q_per_kv + 2
+        head_size = out_features // (self.n_query_groups * total_qkv)
+        ind = range(out_features)
        self.lora_ind = []
        if enable_q:
-            self.lora_ind.extend(range(0, self.linear.in_features))
+            q_ind = [x for x in ind if (x // head_size) % total_qkv < total_qkv - 2]
+            self.lora_ind.extend(q_ind)
        if enable_k:
-            self.lora_ind.extend(range(self.linear.in_features, self.linear.in_features + self.kv_embd_size))
+            k_ind = [x for x in ind if (x // head_size) % total_qkv == total_qkv - 2]
+            self.lora_ind.extend(k_ind)
        if enable_v:
-            self.lora_ind.extend(range(self.linear.in_features + self.kv_embd_size, self.linear.out_features))
+            v_ind = [x for x in ind if (x // head_size) % total_qkv == total_qkv - 1]
+            self.lora_ind.extend(v_ind)
        self.reset_parameters()
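As a quick illustration of what the interleaved indexing above produces, here is a minimal standalone sketch (not part of the commit) using a hypothetical toy configuration with n_head = 12, n_query_groups = 3, and head_size = 2, with LoRA enabled for queries and values only:

```python
# Hypothetical toy configuration; mirrors the list comprehensions in the new code above.
n_head = 12
n_query_groups = 3
head_size = 2
enable_q, enable_k, enable_v = True, False, True

q_per_kv = n_head // n_query_groups                    # 4 query heads per KV head
total_qkv = q_per_kv + 2                               # 4 Q slots + 1 K slot + 1 V slot per group
out_features = n_query_groups * total_qkv * head_size  # 3 * 6 * 2 = 36
ind = range(out_features)

lora_ind = []
if enable_q:
    # output positions whose head slot within a group is one of the q_per_kv query slots
    lora_ind.extend(x for x in ind if (x // head_size) % total_qkv < total_qkv - 2)
if enable_k:
    # output positions in the single key slot of each group
    lora_ind.extend(x for x in ind if (x // head_size) % total_qkv == total_qkv - 2)
if enable_v:
    # output positions in the single value slot of each group
    lora_ind.extend(x for x in ind if (x // head_size) % total_qkv == total_qkv - 1)

print(lora_ind)
# query positions 0-7, 12-19, 24-31 followed by value positions 10-11, 22-23, 34-35
```

Each group of total_qkv = 6 head slots contributes its first four (query) slots and its last (value) slot, which is exactly the interleaved pattern the `zero_pad` docstring below describes.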

@@ -293,6 +291,27 @@ def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
        ________________________________________
        | query         | key       | value    |
        ----------------------------------------
+        For Llama2's GQA support, Q, K, and V weights are interleaved, so that weights for grouped
+        queries are adjacent to their associated key and value weights.
+        For example, suppose we have n_head = 12 with 3 query groups.
+        Then along the embedding dimension the interleaved weights would look like
+        [Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V],
+        where each Q, K, and V has size head_size.
+        In this case, the previously-described weight update applies separately to each
+        individual block, so the update will take the form
+        [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ...],
+         [.............................................................................],
+         [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ...]]
+             ↑              ↑            ↑         ↑             ↑            ↑
+        ________________________________________________________________________________
+        | q block 1 | k block 1 | v block 1 | q block 2 | k block 2 | v block 2 | ...
+        --------------------------------------------------------------------------------
+        Note that in the above diagram, the size of each q block will equal q_per_kv
+        times the size of each k and v block.

        Args:
            x: tensor with weights update that will be padded with zeros if necessary
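The body of `zero_pad` itself is not shown in this diff; a minimal sketch of the behaviour its docstring describes (scatter the LoRA update into the interleaved `lora_ind` positions of the output dimension and leave every other position zero) could look like the following, using the MHA indices from the test file below as an example. The function name is hypothetical.

```python
import torch

def zero_pad_sketch(x: torch.Tensor, lora_ind: list, out_features: int) -> torch.Tensor:
    # Illustrative only: place the LoRA update at the interleaved lora_ind positions
    # of the last dimension and keep the remaining (non-adapted) positions at zero.
    result = x.new_zeros(*x.shape[:-1], out_features)
    result[..., lora_ind] = x
    return result

# MHA example from the test file below: 16 LoRA columns scattered into 24 outputs.
lora_ind = [0, 1, 6, 7, 12, 13, 18, 19, 4, 5, 10, 11, 16, 17, 22, 23]
x = torch.ones(3, 5, 16)
padded = zero_pad_sketch(x, lora_ind, 24)
assert padded.shape == (3, 5, 24)
assert padded.sum() == x.sum()  # no values lost, only zeros inserted
```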
37 changes: 33 additions & 4 deletions tests/test_lora.py
@@ -98,35 +98,64 @@ def test_lora_mqa_gqa():
    assert config.n_query_groups == config.n_head
    model = GPT(config)
    attn = model.transformer.h[0].attn.attn
+    for p in attn.linear.parameters():
+        torch.nn.init.zeros_(p)
+    torch.nn.init.ones_(attn.lora_B)
+    lora_ind = [0, 1, 6, 7, 12, 13, 18, 19, 4, 5, 10, 11, 16, 17, 22, 23]
    assert attn.linear.weight.shape == (24, 8)
    assert attn.lora_A.shape == (4, 8)
    assert attn.lora_B.shape == (16, 2)
-    assert attn.lora_ind == [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23]
+    assert attn.lora_ind == lora_ind
    x = torch.randint(0, 8, size=(3, 5, 16), dtype=torch.int64)
    assert attn.zero_pad(x).shape == (3, 5, 24)
+    bsz, ctx_len, in_dim = 2, 30, 8
+    x_in = torch.randn(bsz, ctx_len, in_dim)
+    out = attn(x_in)
+    non_lora_ind = list(set(range(24)).difference(lora_ind))
+    assert torch.count_nonzero(out[:, :, lora_ind]) == bsz * ctx_len * len(lora_ind)
+    assert torch.count_nonzero(out[:, :, non_lora_ind]) == 0

    # MQA
    config.n_query_groups = 1
    model = GPT(config)
    attn = model.transformer.h[0].attn.attn
+    for p in attn.linear.parameters():
+        torch.nn.init.zeros_(p)
+    torch.nn.init.ones_(attn.lora_B)
+    lora_ind = [0, 1, 2, 3, 4, 5, 6, 7, 10, 11]
    assert attn.linear.weight.shape == (12, 8)
    assert attn.lora_A.shape == (4, 8)
    assert attn.lora_B.shape == (10, 2)
-    assert attn.lora_ind == [0, 1, 2, 3, 4, 5, 6, 7, 10, 11]
+    assert attn.lora_ind == lora_ind
    x = torch.randint(0, 8, size=(3, 5, 10), dtype=torch.int64)
    assert attn.zero_pad(x).shape == (3, 5, 12)
+    bsz, ctx_len, in_dim = 2, 30, 8
+    x_in = torch.randn(bsz, ctx_len, in_dim)
+    out = attn(x_in)
+    non_lora_ind = list(set(range(12)).difference(lora_ind))
+    assert torch.count_nonzero(out[:, :, lora_ind]) == bsz * ctx_len * len(lora_ind)
+    assert torch.count_nonzero(out[:, :, non_lora_ind]) == 0

    # GQA
    config.n_query_groups = 2
    model = GPT(config)
    attn = model.transformer.h[0].attn.attn
+    for p in attn.linear.parameters():
+        torch.nn.init.zeros_(p)
+    torch.nn.init.ones_(attn.lora_B)
+    lora_ind = [0, 1, 2, 3, 8, 9, 10, 11, 6, 7, 14, 15]
    assert attn.linear.weight.shape == (16, 8)
    assert attn.lora_A.shape == (4, 8)
    assert attn.lora_B.shape == (12, 2)
-    assert attn.lora_ind == [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15]
+    assert attn.lora_ind == lora_ind
    x = torch.randint(0, 8, size=(3, 5, 12), dtype=torch.int64)
    assert attn.zero_pad(x).shape == (3, 5, 16)

+    bsz, ctx_len, in_dim = 2, 30, 8
+    x_in = torch.randn(bsz, ctx_len, in_dim)
+    out = attn(x_in)
+    non_lora_ind = list(set(range(16)).difference(lora_ind))
+    assert torch.count_nonzero(out[:, :, lora_ind]) == bsz * ctx_len * len(lora_ind)
+    assert torch.count_nonzero(out[:, :, non_lora_ind]) == 0

def test_lora_filter(tmp_path):
    from lit_gpt.lora import GPT, lora_filter
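As a sanity check on the hard-coded GQA expectation above, the new index rule reproduces the expected list. The weight shape (16, 8) with n_embd = 8 suggests n_head = 4 query heads split into 2 groups with head_size = 2; these values are inferred, since the test's Config is outside this hunk.

```python
# Re-derive the GQA lora_ind from the assumed config: n_head=4, n_query_groups=2, head_size=2.
n_head, n_query_groups, head_size = 4, 2, 2
q_per_kv = n_head // n_query_groups                    # 2
total_qkv = q_per_kv + 2                               # 2 Q slots + 1 K + 1 V per group
out_features = n_query_groups * total_qkv * head_size  # 16
ind = range(out_features)

q_ind = [x for x in ind if (x // head_size) % total_qkv < total_qkv - 2]
v_ind = [x for x in ind if (x // head_size) % total_qkv == total_qkv - 1]
assert q_ind + v_ind == [0, 1, 2, 3, 8, 9, 10, 11, 6, 7, 14, 15]
```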
