diff --git a/tests/test_basic.py b/tests/test_basic.py deleted file mode 100644 index 322fe62..0000000 --- a/tests/test_basic.py +++ /dev/null @@ -1,8 +0,0 @@ -"""Basic test to always pass""" - -from __future__ import annotations - - -def test_always_passes(): - """Simple Test""" - assert True diff --git a/tests/test_math.py b/tests/test_math.py index c0107a3..4c4b897 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -44,42 +44,36 @@ def test_rope(self): ) def test_apply_rope(self): - B, L, H, D = ( - 2, - 4, - 2, - 8, - ) # Batch size, sequence length, number of heads, embedding dimension + B, H, L, D = ( + 1, + 24, + 4336, + 128, + ) theta = 10000 # Inputs np_q = np.random.randn(B, H, L, D).astype(np.float32) np_k = np.random.randn(B, H, L, D).astype(np.float32) - np_v = np.random.randn(B, H, L, D).astype(np.float32) jax_q = jnp.array(np_q, dtype=jnp.float32) jax_k = jnp.array(np_k, dtype=jnp.float32) - jax_v = jnp.array(np_v, dtype=jnp.float32) torch_q = torch.from_numpy(np_q).to(torch.float32) torch_k = torch.from_numpy(np_k).to(torch.float32) - torch_v = torch.from_numpy(np_v).to(torch.float32) np.testing.assert_allclose(np.array(jax_q), torch_q.numpy()) np.testing.assert_allclose(np.array(jax_k), torch_k.numpy()) - np.testing.assert_allclose(np.array(jax_v), torch_v.numpy()) # Position indices (e.g., positions in the sequence) - np_positions = np.repeat(np.expand_dims(np.arange(L), 0), repeats=B, axis=1) - torch_positions = torch.from_numpy(np_positions).to(torch.int32) - jax_positions = jnp.array(np_positions, dtype=jnp.int32) + np_positions = np.random.randn(1, L).astype(np.float32) + torch_positions = torch.from_numpy(np_positions).to(torch.float32) + jax_positions = jnp.array(np_positions, dtype=jnp.float32) np.testing.assert_allclose(np.array(jax_positions), torch_positions.numpy()) - torch_pe = torch_rope(pos=torch_positions, dim=D, theta=theta) - jax_pe = jax_rope( - pos=jax_positions, dim=D, theta=theta - ) # Shape: [B, L, D/2, 2, 2] + torch_pe = torch_rope(pos=torch_positions, dim=(3072 // 24), theta=theta) + jax_pe = jax_rope(pos=jax_positions, dim=(3072 // 24), theta=theta) np.testing.assert_allclose( np.array(jax_pe),