From 4529239827bc94283545f4faf13ab7d34ea3db70 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 30 Apr 2024 16:34:26 +0200
Subject: [PATCH] Update LoRA test

---
 tests/test_lora.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/test_lora.py b/tests/test_lora.py
index c09d07ee66..d131411d9c 100644
--- a/tests/test_lora.py
+++ b/tests/test_lora.py
@@ -107,7 +107,7 @@ def test_lora_mqa_gqa():
     assert attn.linear.weight.shape == (24, 8)
     assert attn.lora_A.shape == (4, 8)
     assert attn.lora_B.shape == (16, 2)
-    torch.testing.assert_allclose(attn._lora_ind, torch.tensor(lora_ind))
+    assert torch.equal(attn._lora_ind, torch.tensor(lora_ind))
     x = torch.randint(0, 8, size=(3, 5, 16), dtype=torch.int64)
     assert attn.zero_pad(x).shape == (3, 5, 24)
     bsz, ctx_len, in_dim = 2, 30, 8
@@ -128,7 +128,7 @@ def test_lora_mqa_gqa():
     assert attn.linear.weight.shape == (12, 8)
     assert attn.lora_A.shape == (4, 8)
     assert attn.lora_B.shape == (10, 2)
-    torch.testing.assert_allclose(attn._lora_ind, torch.tensor(lora_ind))
+    assert torch.equal(attn._lora_ind, torch.tensor(lora_ind))
     x = torch.randint(0, 8, size=(3, 5, 10), dtype=torch.int64)
     assert attn.zero_pad(x).shape == (3, 5, 12)
     bsz, ctx_len, in_dim = 2, 30, 8
@@ -149,7 +149,7 @@ def test_lora_mqa_gqa():
     assert attn.linear.weight.shape == (16, 8)
     assert attn.lora_A.shape == (4, 8)
     assert attn.lora_B.shape == (12, 2)
-    torch.testing.assert_allclose(attn._lora_ind, torch.tensor(lora_ind))
+    assert torch.equal(attn._lora_ind, torch.tensor(lora_ind))
     x = torch.randint(0, 8, size=(3, 5, 12), dtype=torch.int64)
     assert attn.zero_pad(x).shape == (3, 5, 16)
     bsz, ctx_len, in_dim = 2, 30, 8
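
For context (not part of the patch): a minimal sketch of the assertion swap on an integer index tensor like `_lora_ind`. The `lora_ind` values below are made up for illustration; the real values come from the test. `torch.testing.assert_allclose` is deprecated in favor of `torch.testing.assert_close`, and an approximate tolerance check is unnecessary for integer indices, so the patch uses `torch.equal`, which checks that both tensors have the same size and identical elements.

    import torch

    # Hypothetical stand-in for the `lora_ind` list built in the test.
    lora_ind = [0, 1, 2, 3, 8, 9, 16, 17]
    expected = torch.tensor(lora_ind)
    actual = torch.tensor(lora_ind)  # e.g. what `attn._lora_ind` would hold

    # New check: exact element-wise equality (and matching shape).
    assert torch.equal(actual, expected)

    # Old check, kept here only as a comment: deprecated and meant for
    # approximate floating-point comparisons rather than exact indices.
    # torch.testing.assert_allclose(actual, expected)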