Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: apply_rope somehow works
Browse files Browse the repository at this point in the history
SauravMaheshkar committed Oct 7, 2024
1 parent f1c1f81 commit b41c23b
Showing 2 changed files with 11 additions and 25 deletions.
8 changes: 0 additions & 8 deletions tests/test_basic.py

This file was deleted.

28 changes: 11 additions & 17 deletions tests/test_math.py
Original file line number Diff line number Diff line change
@@ -44,42 +44,36 @@ def test_rope(self):
)

def test_apply_rope(self):
B, L, H, D = (
2,
4,
2,
8,
) # Batch size, sequence length, number of heads, embedding dimension
B, H, L, D = (
1,
24,
4336,
128,
)
theta = 10000

# Inputs
np_q = np.random.randn(B, H, L, D).astype(np.float32)
np_k = np.random.randn(B, H, L, D).astype(np.float32)
np_v = np.random.randn(B, H, L, D).astype(np.float32)

jax_q = jnp.array(np_q, dtype=jnp.float32)
jax_k = jnp.array(np_k, dtype=jnp.float32)
jax_v = jnp.array(np_v, dtype=jnp.float32)

torch_q = torch.from_numpy(np_q).to(torch.float32)
torch_k = torch.from_numpy(np_k).to(torch.float32)
torch_v = torch.from_numpy(np_v).to(torch.float32)

np.testing.assert_allclose(np.array(jax_q), torch_q.numpy())
np.testing.assert_allclose(np.array(jax_k), torch_k.numpy())
np.testing.assert_allclose(np.array(jax_v), torch_v.numpy())

# Position indices (e.g., positions in the sequence)
np_positions = np.repeat(np.expand_dims(np.arange(L), 0), repeats=B, axis=1)
torch_positions = torch.from_numpy(np_positions).to(torch.int32)
jax_positions = jnp.array(np_positions, dtype=jnp.int32)
np_positions = np.random.randn(1, L).astype(np.float32)
torch_positions = torch.from_numpy(np_positions).to(torch.float32)
jax_positions = jnp.array(np_positions, dtype=jnp.float32)

np.testing.assert_allclose(np.array(jax_positions), torch_positions.numpy())

torch_pe = torch_rope(pos=torch_positions, dim=D, theta=theta)
jax_pe = jax_rope(
pos=jax_positions, dim=D, theta=theta
) # Shape: [B, L, D/2, 2, 2]
torch_pe = torch_rope(pos=torch_positions, dim=(3072 // 24), theta=theta)
jax_pe = jax_rope(pos=jax_positions, dim=(3072 // 24), theta=theta)

np.testing.assert_allclose(
np.array(jax_pe),

0 comments on commit b41c23b

Please sign in to comment.