move comment into zero_pad docstring
ebsmothers committed Jan 29, 2024
1 parent 969b51b commit 9c1d6e9
Showing 1 changed file with 22 additions and 32 deletions.
54 changes: 22 additions & 32 deletions lit_gpt/lora.py
@@ -261,38 +261,7 @@ def __init__(
self.scaling = self.lora_alpha / self.r

# Compute the indices
# Indices are needed to properly pad weight updates with zeros. If we want to fine-tune queries and values,
# but not keys, then the weights update should be:
#
# [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],
# [....................................],
# [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]
# ↑ ↑ ↑
# ________________________________________
# | query | key | value |
# ----------------------------------------
# For Llama2's GQA support, Q, K, and V weights are interleaved, so that weights for grouped
# queries are adjacent to their associated key and value weights.
# For example, suppose we have n_head = 12 with 3 query groups.
# Then along the embedding dimension the interleaved weights would look like
#
# [Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V],
#
# where each Q, K, and V has size head_dim.
#
# In this case, the previously-described weight update applies separately to each
# individual block, so the update will take the form
#
# [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ...],
# [.............................................................................],
# [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ...]]
# ↑ ↑ ↑ ↑ ↑ ↑
# ________________________________________________________________________________
# | q block 1 | k block 1 | v block 1 | q block 2 | k block 2 | v block 2 | ...
# --------------------------------------------------------------------------------
# Note that in the above diagram, the size of each q block will equal q_per_kv
# times the size of each k and v block.

# Indices are needed to properly pad weight updates with zeros in the `zero_pad` method.
q_per_kv = self.n_head // self.n_query_groups
query_group_size = q_per_kv + 2
head_dim = out_features // (self.n_query_groups * query_group_size)
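Since the view collapses the rest of this hunk, here is a minimal standalone sketch of how column indices for the interleaved Q/K/V layout described above could be derived from `q_per_kv`, `query_group_size`, and `head_dim`. The helper name `qkv_column_indices` and the `(query, key, value)` flag tuple are illustrative assumptions, not the repository's actual continuation of this code.

```python
import torch

def qkv_column_indices(
    n_head: int,
    n_query_groups: int,
    out_features: int,
    enable_lora: tuple[bool, bool, bool],  # assumed (query, key, value) flag order
) -> torch.Tensor:
    """Sketch: columns of the interleaved QKV projection that receive a LoRA update."""
    q_per_kv = n_head // n_query_groups
    query_group_size = q_per_kv + 2  # q_per_kv query heads + 1 key head + 1 value head
    head_dim = out_features // (n_query_groups * query_group_size)
    enable_q, enable_k, enable_v = enable_lora

    cols = torch.arange(out_features)
    # Position of each column's head within its query group: 0..q_per_kv-1 are queries,
    # q_per_kv is the key head, and q_per_kv + 1 is the value head.
    pos_in_group = (cols // head_dim) % query_group_size
    mask = torch.zeros(out_features, dtype=torch.bool)
    if enable_q:
        mask |= pos_in_group < q_per_kv
    if enable_k:
        mask |= pos_in_group == q_per_kv
    if enable_v:
        mask |= pos_in_group == q_per_kv + 1
    return cols[mask]
```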
@@ -322,6 +291,27 @@ def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
________________________________________
| query | key | value |
----------------------------------------
For Llama2's GQA support, Q, K, and V weights are interleaved, so that weights for grouped
queries are adjacent to their associated key and value weights.
For example, suppose we have n_head = 12 with 3 query groups.
Then along the embedding dimension the interleaved weights would look like
[Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V],
where each Q, K, and V has size head_dim.
In this case, the previously-described weight update applies separately to each
individual block, so the update will take the form
[[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ...],
[.............................................................................],
[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ...]]
↑ ↑ ↑ ↑ ↑ ↑
________________________________________________________________________________
| q block 1 | k block 1 | v block 1 | q block 2 | k block 2 | v block 2 | ...
--------------------------------------------------------------------------------
Note that in the above diagram, the size of each q block will equal q_per_kv
times the size of each k and v block.
Args:
x: tensor with weights update that will be padded with zeros if necessary
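The rest of the new `zero_pad` body is collapsed in this view, so here is a hedged sketch of a zero-padding step that matches the behavior the docstring describes: the LoRA update fills only the enabled columns, and every other column (for example the key columns when only queries and values are fine-tuned) stays zero. The function name, the `lora_ind` argument, and the use of plain advanced indexing are assumptions for illustration rather than the actual implementation.

```python
import torch

def zero_pad_sketch(x: torch.Tensor, lora_ind: torch.Tensor, out_features: int) -> torch.Tensor:
    """Scatter the LoRA update `x` (last dim == len(lora_ind)) into a zero tensor of
    shape (..., out_features); columns not listed in `lora_ind` remain zero."""
    result = x.new_zeros(*x.shape[:-1], out_features)
    result[..., lora_ind.to(x.device)] = x
    return result

# Example with the docstring's shapes: n_head = 12, 3 query groups, head_dim = 4,
# so out_features = 3 * (4 + 2) * 4 = 72; fine-tune queries and values but not keys.
ind = qkv_column_indices(12, 3, 72, enable_lora=(True, False, True))  # from the sketch above
update = torch.randn(2, 8, ind.numel())  # LoRA output for the enabled columns only
padded = zero_pad_sketch(update, ind, out_features=72)  # (2, 8, 72); key columns are all zero
```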
