LoRA: zero_pad speed improvements #770

Merged

Changes from 16 commits

Commits (25)
3a6e637  Use only `index_copy` without any views or reshapes. (Andrei-Aksionov, Sep 18, 2023)
1d591f2  Don't do transpose required for the merge method in each `zero_pad` c… (Andrei-Aksionov, Sep 18, 2023)
3d068cb  `self.lora_ind` as a property (Andrei-Aksionov, Sep 18, 2023)
89d8eaa  Updates for a case with an inference tensor (Andrei-Aksionov, Sep 19, 2023)
6785012  Docstring for a `lora_ind` property. (Andrei-Aksionov, Sep 19, 2023)
8951570  Reassign `self._lora_ind` so it will be recreated outside inference mode (Andrei-Aksionov, Sep 19, 2023)
38d9ef0  Merge branch 'main' into lora_zero_pad_speed_improvements (Andrei-Aksionov, Sep 20, 2023)
eb73d27  Make `lora_ind` property a bit shorter. (Andrei-Aksionov, Sep 22, 2023)
4cc6cb3  Trim comments for `lora_ind` property. (Andrei-Aksionov, Sep 22, 2023)
138acaa  Merge branch 'main' into lora_zero_pad_speed_improvements (Andrei-Aksionov, Oct 9, 2023)
7013fed  If validate is running in no_grad mode there is no need to clone ind (Andrei-Aksionov, Oct 9, 2023)
b494438  Merge branch 'main' into lora_zero_pad_speed_improvements (Andrei-Aksionov, Oct 9, 2023)
865f883  Merge branch 'main' into lora_zero_pad_speed_improvements (Andrei-Aksionov, Oct 10, 2023)
48a401b  Merge branch 'main' into lora_zero_pad_speed_improvements (Andrei-Aksionov, Oct 16, 2023)
5a14afc  Merge branch 'main' into lora_zero_pad_speed_improvements (Andrei-Aksionov, Oct 25, 2023)
c5f079f  Merge branch 'main' into lora_zero_pad_speed_improvements (Andrei-Aksionov, Nov 13, 2023)
4b20c88  Revert "Minor tutorial updates" (Andrei-Aksionov, Nov 21, 2023)
daf0e4e  Revert "Fix typo" (Andrei-Aksionov, Nov 21, 2023)
652d7bd  Revert "Revert "Minor tutorial updates"" (Andrei-Aksionov, Nov 21, 2023)
b16e385  Merge branch 'main' into lora_zero_pad_speed_improvements (Andrei-Aksionov, Nov 23, 2023)
015f868  Undo weirdly appeared typo. (Andrei-Aksionov, Nov 23, 2023)
532a710  Merge branch 'main' into lora_zero_pad_speed_improvements (Andrei-Aksionov, May 6, 2024)
1ef81f3  Add FSDP test with empty_init=True (Andrei-Aksionov, May 6, 2024)
9560bc0  Merge branch 'main' into lora_zero_pad_speed_improvements (Andrei-Aksionov, May 6, 2024)
576de42  Merge branch 'main' into lora_zero_pad_speed_improvements (carmocca, May 6, 2024)
48 changes: 21 additions & 27 deletions lit_gpt/lora.py
@@ -233,28 +233,29 @@ def __init__(
         # https://github.com/cloneofsimo/lora
         self.scaling = self.lora_alpha / self.r
 
-        # Compute the indices
-        # Indices are needed to properly pad weight updates with zeros. If we want to fine-tune queries and values,
-        # but not keys, then the weights update should be:
-        #
-        # [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],
-        #  [....................................],
-        #  [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]
-        #      ↑              ↑            ↑
-        # ________________________________________
-        # | query         | key       | value    |
-        # ----------------------------------------
-        self.lora_ind = []
-        if enable_q:
-            self.lora_ind.extend(range(0, self.linear.in_features))
-        if enable_k:
-            self.lora_ind.extend(range(self.linear.in_features, self.linear.in_features + self.kv_embd_size))
-        if enable_v:
-            self.lora_ind.extend(range(self.linear.in_features + self.kv_embd_size, self.linear.out_features))
         self.reset_parameters()
 
+    @property
+    def lora_ind(self) -> torch.Tensor:
+        """Lazy creation of a buffer with LoRA indices to overcome the limitation when FSDP with meta device is used."""
+        # Indices are needed to properly pad weight updates with zeros.
+        if not hasattr(self, "_lora_ind"):
+            indices = []
+            enable_q, enable_k, enable_v = self.enable_lora
+            in_features, out_features = self.linear.in_features, self.linear.out_features
+            device = self.linear.weight.device
+            if enable_q:
+                indices.append(torch.arange(0, in_features, device=device))
+            if enable_k:
+                indices.append(torch.arange(in_features, in_features + self.kv_embd_size, device=device))
+            if enable_v:
+                indices.append(torch.arange(in_features + self.kv_embd_size, out_features, device=device))
+            self.register_buffer("_lora_ind", torch.cat(indices), persistent=False)
+        return self._lora_ind
+
     def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
-        """Properly pad weight updates with zeros.
+        """Properly pad the last dimension of weight updates with zeros.
 
         If, based on `self.enable_lora`, we want to fine-tune queries and values, but not keys,
         then the weights update should be:
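
Why the property and the lazy `register_buffer`: when the model is instantiated under FSDP with `empty_init=True`, parameters start out on the meta device, so a real index tensor cannot be materialized in `__init__` (see the "Add FSDP test with empty_init=True" commit). Deferring creation to first access sidesteps that. A minimal standalone sketch of the pattern, with an illustrative even-columns index rule standing in for the real query/key/value ranges (the class and attribute names here are hypothetical, not from the PR):

import torch
import torch.nn as nn

class LazyIndexModule(nn.Module):
    """Sketch: defer index-buffer creation until first use so that
    construction on the meta device (e.g. FSDP empty_init) succeeds."""

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=False)

    @property
    def pad_ind(self) -> torch.Tensor:
        if not hasattr(self, "_pad_ind"):
            # By the time this first runs, the weight lives on a real device.
            device = self.linear.weight.device
            ind = torch.arange(0, self.linear.out_features, 2, device=device)
            # persistent=False keeps the helper buffer out of checkpoints.
            self.register_buffer("_pad_ind", ind, persistent=False)
        return self._pad_ind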
@@ -285,15 +286,8 @@ def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
         # Then x has embeddings_size of 256 (2 * 128 as enable_lora only for query and value, not keys) and expected
         # embeddings_size is 384 (self.linear.out_features), so that means that we need to pad from 256 to 384 with zeros, but
         # only for key updates (this is where self.lora_ind comes in handy)
-        # Note: double transpose (in the beginning and in the end) is basically a guard for two-dimensional tensors
-        # for example when we want to merge/unmerge LoRA weights and pretrained weights
-        x = x.transpose(0, 1)
-        result = x.new_zeros((*x.shape[:-1], self.linear.out_features))  # (64, 64, 384)
-        result = result.view(-1, self.linear.out_features)  # (4096, 384)
-        result = result.index_copy(
-            1, torch.tensor(self.lora_ind, device=result.device), x.reshape(-1, sum(self.qkv_shapes))
-        )  # (4096, 256)
-        return result.view((*x.shape[:-1], self.linear.out_features)).transpose(0, 1)  # (64, 64, 384)
+        result = x.new_zeros(*x.shape[:-1], self.linear.out_features)  # (64, 64, 384)
+        return result.index_copy_(dim=-1, index=self.lora_ind, source=x)  # (64, 64, 384)
 
     def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
         """An extension of the `torch.nn.functional.conv1d` function with a logic specific to grouped queries.
@@ -345,7 +339,7 @@ def merge(self) -> None:
                 0
             )  # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128)
             # W = W + delta_W (merge)
-            self.linear.weight.data += self.zero_pad(delta_w * self.scaling)  # (256, 128) after zero_pad (384, 128)
+            self.linear.weight.data += self.zero_pad(delta_w.T * self.scaling).T  # (256, 128) after zero_pad (384, 128)
             self.merged = True
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
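
Taken together, the new `zero_pad` is a single scatter along the last dimension: allocate zeros at the full output width and `index_copy_` the LoRA update into the enabled columns, with no transposes, views, or per-call `torch.tensor(...)` conversion of a Python index list. A self-contained sketch with toy shapes, mirroring the enable-query/value, skip-key case from the tests below (the sizes are illustrative, not the real model's):

import torch

out_features = 24                                  # full q+k+v width
lora_ind = torch.cat([torch.arange(0, 8),          # query columns
                      torch.arange(16, 24)])       # value columns (keys skipped)

x = torch.randn(3, 5, 16)                          # (batch, seq, q+v update width)

result = x.new_zeros(*x.shape[:-1], out_features)  # (3, 5, 24), all zeros
padded = result.index_copy_(-1, lora_ind, x)       # copy x into the enabled columns

assert padded.shape == (3, 5, 24)
assert torch.equal(padded[..., 8:16], torch.zeros(3, 5, 8))  # key block stays zero

Because the padding now acts on the last dimension only, the 2-D weight delta in `merge` is handled by transposing around the call (`self.zero_pad(delta_w.T * self.scaling).T`) instead of guarding with a double transpose inside `zero_pad` itself.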
6 changes: 3 additions & 3 deletions tests/test_lora.py
@@ -92,7 +92,7 @@ def test_lora_mqa_gqa():
     assert attn.linear.weight.shape == (24, 8)
     assert attn.lora_A.shape == (4, 8)
     assert attn.lora_B.shape == (16, 2)
-    assert attn.lora_ind == [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23]
+    assert torch.equal(attn.lora_ind, torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23]))
     x = torch.randint(0, 8, size=(3, 5, 16), dtype=torch.int64)
     assert attn.zero_pad(x).shape == (3, 5, 24)
 
@@ -103,7 +103,7 @@ def test_lora_mqa_gqa():
     assert attn.linear.weight.shape == (12, 8)
     assert attn.lora_A.shape == (4, 8)
     assert attn.lora_B.shape == (10, 2)
-    assert attn.lora_ind == [0, 1, 2, 3, 4, 5, 6, 7, 10, 11]
+    assert torch.equal(attn.lora_ind, torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11]))
     x = torch.randint(0, 8, size=(3, 5, 10), dtype=torch.int64)
     assert attn.zero_pad(x).shape == (3, 5, 12)
 
@@ -114,7 +114,7 @@ def test_lora_mqa_gqa():
     assert attn.linear.weight.shape == (16, 8)
     assert attn.lora_A.shape == (4, 8)
     assert attn.lora_B.shape == (12, 2)
-    assert attn.lora_ind == [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15]
+    assert torch.equal(attn.lora_ind, torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15]))
     x = torch.randint(0, 8, size=(3, 5, 12), dtype=torch.int64)
     assert attn.zero_pad(x).shape == (3, 5, 16)
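
For readers who want to sanity-check the speed claim locally, a rough CPU micro-benchmark of the two strategies could look like the following. This is a sketch, not code from the PR: the old path is paraphrased from the removed lines above, and the shapes are the (64, 64, 256) -> (64, 64, 384) example from the code comments.

import time
import torch

out_features, lora_width = 384, 256
lora_ind_list = list(range(0, 128)) + list(range(256, 384))  # old: plain Python list
lora_ind = torch.tensor(lora_ind_list)                       # new: prebuilt index buffer
x = torch.randn(64, 64, lora_width)

def zero_pad_old(t: torch.Tensor) -> torch.Tensor:
    # Paraphrase of the removed code: transpose, flatten, index_copy, restore.
    t = t.transpose(0, 1)
    result = t.new_zeros((*t.shape[:-1], out_features))
    result = result.view(-1, out_features)
    # Rebuilds the index tensor from the Python list on every call.
    result = result.index_copy(1, torch.tensor(lora_ind_list), t.reshape(-1, lora_width))
    return result.view((*t.shape[:-1], out_features)).transpose(0, 1)

def zero_pad_new(t: torch.Tensor) -> torch.Tensor:
    result = t.new_zeros(*t.shape[:-1], out_features)
    return result.index_copy_(-1, lora_ind, t)

assert torch.equal(zero_pad_old(x), zero_pad_new(x))  # both strategies agree

for fn in (zero_pad_old, zero_pad_new):
    start = time.perf_counter()
    for _ in range(100):
        fn(x)
    print(f"{fn.__name__}: {time.perf_counter() - start:.4f}s")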