Skip to content

Commit

Permalink
Update LoRA test (#1376)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli authored Apr 30, 2024
1 parent 9683600 commit 0ce1ca4
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/test_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def test_lora_mqa_gqa():
assert attn.linear.weight.shape == (24, 8)
assert attn.lora_A.shape == (4, 8)
assert attn.lora_B.shape == (16, 2)
torch.testing.assert_allclose(attn._lora_ind, torch.tensor(lora_ind))
assert torch.equal(attn._lora_ind, torch.tensor(lora_ind))
x = torch.randint(0, 8, size=(3, 5, 16), dtype=torch.int64)
assert attn.zero_pad(x).shape == (3, 5, 24)
bsz, ctx_len, in_dim = 2, 30, 8
Expand All @@ -128,7 +128,7 @@ def test_lora_mqa_gqa():
assert attn.linear.weight.shape == (12, 8)
assert attn.lora_A.shape == (4, 8)
assert attn.lora_B.shape == (10, 2)
torch.testing.assert_allclose(attn._lora_ind, torch.tensor(lora_ind))
assert torch.equal(attn._lora_ind, torch.tensor(lora_ind))
x = torch.randint(0, 8, size=(3, 5, 10), dtype=torch.int64)
assert attn.zero_pad(x).shape == (3, 5, 12)
bsz, ctx_len, in_dim = 2, 30, 8
Expand All @@ -149,7 +149,7 @@ def test_lora_mqa_gqa():
assert attn.linear.weight.shape == (16, 8)
assert attn.lora_A.shape == (4, 8)
assert attn.lora_B.shape == (12, 2)
torch.testing.assert_allclose(attn._lora_ind, torch.tensor(lora_ind))
assert torch.equal(attn._lora_ind, torch.tensor(lora_ind))
x = torch.randint(0, 8, size=(3, 5, 12), dtype=torch.int64)
assert attn.zero_pad(x).shape == (3, 5, 16)
bsz, ctx_len, in_dim = 2, 30, 8
Expand Down

0 comments on commit 0ce1ca4

Please sign in to comment.