diff --git a/lit_gpt/lora.py b/lit_gpt/lora.py
index aad485aba2..aa0bf84f0c 100644
--- a/lit_gpt/lora.py
+++ b/lit_gpt/lora.py
@@ -261,38 +261,7 @@ def __init__(
         self.scaling = self.lora_alpha / self.r
 
         # Compute the indices
-        # Indices are needed to properly pad weight updates with zeros. If we want to fine-tune queries and values,
-        # but not keys, then the weights update should be:
-        #
-        # [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],
-        #  [....................................],
-        #  [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]
-        #      ↑              ↑            ↑
-        # ________________________________________
-        # | query         | key       | value    |
-        # ----------------------------------------
-        # For Llama2's GQA support, Q, K, and V weights are interleaved, so that weights for grouped
-        # queries are adjacent to their associated key and value weights.
-        # For example, suppose we have n_head = 12 with 3 query groups.
-        # Then along the embedding dimension the interleaved weights would look like
-        #
-        # [Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V],
-        #
-        # where each Q, K, and V has size head_dim.
-        #
-        # In this case, the previously-described weight update applies separately to each
-        # individual block, so the update will take the form
-        #
-        # [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ...],
-        #  [.............................................................................],
-        #  [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ...]]
-        #   ↑              ↑            ↑          ↑              ↑            ↑
-        # ________________________________________________________________________________
-        # | q block 1  | k block 1  | v block 1 | q block 2  | k block 2  | v block 2 | ...
-        # --------------------------------------------------------------------------------
-        # Note that in the above diagram, the size of each q block will equal q_per_kv
-        # times the size of each k and v block.
-
+        # Indices are needed to properly pad weight updates with zeros in the `zero_pad` method.
         q_per_kv = self.n_head // self.n_query_groups
         query_group_size = q_per_kv + 2
         head_dim = out_features // (self.n_query_groups * query_group_size)
@@ -322,6 +291,27 @@ def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
         ________________________________________
         | query         | key       | value    |
         ----------------------------------------
+        For Llama2's GQA support, Q, K, and V weights are interleaved, so that weights for grouped
+        queries are adjacent to their associated key and value weights.
+        For example, suppose we have n_head = 12 with 3 query groups.
+        Then along the embedding dimension the interleaved weights would look like
+
+        [Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V],
+
+        where each Q, K, and V has size head_dim.
+
+        In this case, the previously-described weight update applies separately to each
+        individual block, so the update will take the form
+
+        [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ...],
+         [.............................................................................],
+         [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ...]]
+          ↑              ↑            ↑          ↑              ↑            ↑
+        ________________________________________________________________________________
+        | q block 1  | k block 1  | v block 1 | q block 2  | k block 2  | v block 2 | ...
+        --------------------------------------------------------------------------------
+        Note that in the above diagram, the size of each q block will equal q_per_kv
+        times the size of each k and v block.
 
         Args:
             x: tensor with weights update that will be padded with zeros if necessary
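
For reference, here is a minimal standalone sketch of the index-and-zero-pad idea the relocated comment describes: which output columns of the fused QKV weight receive a LoRA update when only some of Q/K/V are enabled, and how the update is scattered into a zero tensor. The helper names `build_lora_indices` and `zero_pad_update` are illustration-only and not part of lit_gpt; the module's own `zero_pad` keeps equivalent precomputed indices internally. The sketch assumes the per-group interleaved layout [Q * q_per_kv, K, V] shown in the docstring above.

import torch

def build_lora_indices(n_head: int, n_query_groups: int, head_dim: int,
                       enable_q: bool, enable_k: bool, enable_v: bool) -> torch.Tensor:
    """Return the output-feature indices that receive a LoRA update.

    The fused QKV weight is laid out per query group as [Q * q_per_kv, K, V],
    where each Q, K, and V block spans `head_dim` features.
    """
    q_per_kv = n_head // n_query_groups
    group_size = q_per_kv + 2  # q_per_kv query heads + 1 key head + 1 value head
    ind = []
    for g in range(n_query_groups):
        base = g * group_size * head_dim
        if enable_q:
            ind.extend(range(base, base + q_per_kv * head_dim))
        if enable_k:
            ind.extend(range(base + q_per_kv * head_dim, base + (q_per_kv + 1) * head_dim))
        if enable_v:
            ind.extend(range(base + (q_per_kv + 1) * head_dim, base + group_size * head_dim))
    return torch.tensor(ind)

def zero_pad_update(x: torch.Tensor, lora_ind: torch.Tensor, out_features: int) -> torch.Tensor:
    """Scatter the LoRA update `x` (last dim == len(lora_ind)) into a zero tensor of width out_features."""
    result = x.new_zeros(*x.shape[:-1], out_features)
    result[..., lora_ind] = x
    return result

# Example from the docstring: n_head = 12 with 3 query groups; fine-tune Q and V but not K.
n_head, n_query_groups, head_dim = 12, 3, 4
q_per_kv = n_head // n_query_groups
out_features = n_query_groups * (q_per_kv + 2) * head_dim  # fused QKV width = 72
lora_ind = build_lora_indices(n_head, n_query_groups, head_dim, True, False, True)

delta_w = torch.randn(2, lora_ind.numel())           # LoRA update for the enabled blocks only
padded = zero_pad_update(delta_w, lora_ind, out_features)
assert padded.shape == (2, out_features)
assert torch.equal(padded[..., lora_ind], delta_w)    # K columns remain zero

The example shows why the indices are precomputed in `__init__`: they depend only on the head layout and the enabled Q/K/V flags, so `zero_pad` can reuse them on every forward pass instead of rebuilding the interleaved pattern.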