diff --git a/lit_gpt/lora.py b/lit_gpt/lora.py
index aa0bf84f0c..fed1a11cb1 100644
--- a/lit_gpt/lora.py
+++ b/lit_gpt/lora.py
@@ -263,18 +263,18 @@ def __init__(
         # Compute the indices
         # Indices are needed to properly pad weight updates with zeros in `zero_pad` method.
         q_per_kv = self.n_head // self.n_query_groups
-        query_group_size = q_per_kv + 2
-        head_dim = out_features // (self.n_query_groups * query_group_size)
+        total_qkv = q_per_kv + 2
+        head_size = out_features // (self.n_query_groups * total_qkv)
         ind = range(out_features)
         self.lora_ind = []
         if enable_q:
-            q_ind = list(filter(lambda x: (x // head_dim) % query_group_size < query_group_size - 2, ind))
+            q_ind = [x for x in ind if (x // head_size) % total_qkv < total_qkv - 2]
             self.lora_ind.extend(q_ind)
         if enable_k:
-            k_ind = list(filter(lambda x: (x // head_dim) % query_group_size == query_group_size - 2, ind))
+            k_ind = [x for x in ind if (x // head_size) % total_qkv == total_qkv - 2]
             self.lora_ind.extend(k_ind)
         if enable_v:
-            v_ind = list(filter(lambda x: (x // head_dim) % query_group_size == query_group_size - 1, ind))
+            v_ind = [x for x in ind if (x // head_size) % total_qkv == total_qkv - 1]
             self.lora_ind.extend(v_ind)
         self.reset_parameters()
@@ -298,7 +298,7 @@ def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
         [Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V],
-        where each Q, K, and V has size head_dim.
+        where each Q, K, and V has size head_size.
         In this case, the previously-described weight update applies separately to each individual
         block, so the update will take the form
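
For reference, here is a small standalone sketch of the index computation after the rename. It is not part of the diff; the values `n_head=4`, `n_query_groups=2`, `head_size=2` are made up purely for illustration.

```python
# Hypothetical example values (not from the PR): 4 query heads, 2 query groups,
# head size 2 -> each group holds 2 query heads plus one key and one value head.
n_head, n_query_groups, head_size = 4, 2, 2
q_per_kv = n_head // n_query_groups                     # queries sharing one K/V head: 2
total_qkv = q_per_kv + 2                                # heads per group incl. K and V: 4
out_features = n_query_groups * total_qkv * head_size   # 16

ind = range(out_features)
# Fused QKV layout, group by group: [Q, Q, K, V, Q, Q, K, V], each of size head_size
q_ind = [x for x in ind if (x // head_size) % total_qkv < total_qkv - 2]
k_ind = [x for x in ind if (x // head_size) % total_qkv == total_qkv - 2]
v_ind = [x for x in ind if (x // head_size) % total_qkv == total_qkv - 1]

print(q_ind)  # [0, 1, 2, 3, 8, 9, 10, 11]
print(k_ind)  # [4, 5, 12, 13]
print(v_ind)  # [6, 7, 14, 15]
```

The list comprehensions produce the same indices as the previous `filter(lambda ...)` calls; only the variable names and the iteration style change.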