Apply my suggestions
carmocca committed Feb 5, 2024
1 parent 9c1d6e9 commit df31f15
Showing 1 changed file with 6 additions and 6 deletions.

lit_gpt/lora.py
@@ -263,18 +263,18 @@ def __init__(
  # Compute the indices
  # Indices are needed to properly pad weight updates with zeros in `zero_pad` method.
  q_per_kv = self.n_head // self.n_query_groups
- query_group_size = q_per_kv + 2
- head_dim = out_features // (self.n_query_groups * query_group_size)
+ total_qkv = q_per_kv + 2
+ head_size = out_features // (self.n_query_groups * total_qkv)
  ind = range(out_features)
  self.lora_ind = []
  if enable_q:
-     q_ind = list(filter(lambda x: (x // head_dim) % query_group_size < query_group_size - 2, ind))
+     q_ind = [x for x in ind if (x // head_size) % total_qkv < total_qkv - 2]
      self.lora_ind.extend(q_ind)
  if enable_k:
-     k_ind = list(filter(lambda x: (x // head_dim) % query_group_size == query_group_size - 2, ind))
+     k_ind = [x for x in ind if (x // head_size) % total_qkv == total_qkv - 2]
      self.lora_ind.extend(k_ind)
  if enable_v:
-     v_ind = list(filter(lambda x: (x // head_dim) % query_group_size == query_group_size - 1, ind))
+     v_ind = [x for x in ind if (x // head_size) % total_qkv == total_qkv - 1]
      self.lora_ind.extend(v_ind)
  self.reset_parameters()
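
To see what the rewritten comprehensions select, here is a minimal standalone sketch, assuming a hypothetical toy configuration (4 query heads, 2 query groups, head size 2, so the packed QKV projection has 16 output features); the variable names mirror the diff, but the harness itself is illustrative and not part of lit_gpt:

# Toy reproduction of the index computation above (illustrative only).
n_head = 4
n_query_groups = 2
q_per_kv = n_head // n_query_groups                    # 2 query heads per K/V group
total_qkv = q_per_kv + 2                               # each group packs its Q heads plus one K and one V
head_size = 2
out_features = n_query_groups * total_qkv * head_size  # 16

# (x // head_size) % total_qkv is the slot within a group laid out as [Q, Q, K, V]:
# slots 0..total_qkv-3 hold Q, slot total_qkv-2 holds K, slot total_qkv-1 holds V.
ind = range(out_features)
q_ind = [x for x in ind if (x // head_size) % total_qkv < total_qkv - 2]
k_ind = [x for x in ind if (x // head_size) % total_qkv == total_qkv - 2]
v_ind = [x for x in ind if (x // head_size) % total_qkv == total_qkv - 1]

print(q_ind)  # [0, 1, 2, 3, 8, 9, 10, 11]
print(k_ind)  # [4, 5, 12, 13]
print(v_ind)  # [6, 7, 14, 15]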

@@ -298,7 +298,7 @@ def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
      [Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V],
-     where each Q, K, and V has size head_dim.
+     where each Q, K, and V has size head_size.
      In this case, the previously-described weight update applies separately to each
      individual block, so the update will take the form
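
As a reference for how these indices are consumed, below is a simplified sketch of a zero_pad-style scatter; this is an assumption, not the exact method body from lit_gpt/lora.py, which may differ in details such as index caching or device placement:

import torch

# Simplified stand-in for the `zero_pad` method described above (illustrative only).
def zero_pad(x: torch.Tensor, lora_ind: list, out_features: int) -> torch.Tensor:
    # x carries LoRA updates only for the enabled Q/K/V positions, ordered as lora_ind;
    # every other position in the full QKV projection receives a zero update.
    result = x.new_zeros(*x.shape[:-1], out_features)
    result[..., lora_ind] = x
    return result

# Example: 12 enabled positions scattered into an 18-wide QKV output.
update = torch.ones(2, 12)
padded = zero_pad(update, list(range(6)) + list(range(12, 18)), 18)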
