Add Llama 3.2 RoPE to CI (#391)
* add Llama 3.2 RoPE to CI

* update
rasbt authored Oct 8, 2024
1 parent 1eb0b38 commit ec18b6a
Showing 1 changed file with 116 additions and 35 deletions.
151 changes: 116 additions & 35 deletions ch05/07_gpt_to_llama/tests/tests.py
@@ -18,39 +18,45 @@
 @pytest.fixture(scope="module")
 def notebook():
-    def import_definitions_from_notebook(fullname, names):
-        # Get the directory of the current test file
-        current_dir = os.path.dirname(__file__)
-        path = os.path.join(current_dir, "..", fullname + ".ipynb")
-        path = os.path.normpath(path)
+    def import_definitions_from_notebook(notebooks):
+        imported_modules = {}

-        # Load the notebook
-        if not os.path.exists(path):
-            raise FileNotFoundError(f"Notebook file not found at: {path}")
+        for fullname, names in notebooks.items():
+            # Get the directory of the current test file
+            current_dir = os.path.dirname(__file__)
+            path = os.path.join(current_dir, "..", fullname + ".ipynb")
+            path = os.path.normpath(path)

-        with io.open(path, "r", encoding="utf-8") as f:
-            nb = nbformat.read(f, as_version=4)
+            # Load the notebook
+            if not os.path.exists(path):
+                raise FileNotFoundError(f"Notebook file not found at: {path}")

-        # Create a module to store the imported functions and classes
-        mod = types.ModuleType(fullname)
-        sys.modules[fullname] = mod
+            with io.open(path, "r", encoding="utf-8") as f:
+                nb = nbformat.read(f, as_version=4)

-        # Go through the notebook cells and only execute function or class definitions
-        for cell in nb.cells:
-            if cell.cell_type == "code":
-                cell_code = cell.source
-                for name in names:
-                    # Check for function or class definitions
-                    if f"def {name}" in cell_code or f"class {name}" in cell_code:
-                        exec(cell_code, mod.__dict__)
-        return mod
+            # Create a module to store the imported functions and classes
+            mod = types.ModuleType(fullname)
+            sys.modules[fullname] = mod

-    # Specify the notebook name and functions/classes to import
-    fullname = "converting-gpt-to-llama2"
-    names = ["precompute_rope_params", "compute_rope", "SiLU", "RMSNorm"]
+            # Go through the notebook cells and only execute function or class definitions
+            for cell in nb.cells:
+                if cell.cell_type == "code":
+                    cell_code = cell.source
+                    for name in names:
+                        # Check for function or class definitions
+                        if f"def {name}" in cell_code or f"class {name}" in cell_code:
+                            exec(cell_code, mod.__dict__)

-    # Import the required functions and classes from the notebook
-    return import_definitions_from_notebook(fullname, names)
+            imported_modules[fullname] = mod
+
+        return imported_modules
+
+    notebooks = {
+        "converting-gpt-to-llama2": ["SiLU", "RMSNorm", "precompute_rope_params", "compute_rope"],
+        "converting-llama2-to-llama3": ["precompute_rope_params"]
+    }
+
+    return import_definitions_from_notebook(notebooks)


 @pytest.fixture(autouse=True)
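Note: with this change the `notebook` fixture returns a dictionary mapping each notebook name to a module object holding the definitions exec'd from that notebook, instead of a single module. A minimal usage sketch (hypothetical test name, not part of the commit), assuming the fixture above:

    def test_fixture_usage_sketch(notebook):
        # Each key is a notebook filename stem; each value is a types.ModuleType
        llama2_defs = notebook["converting-gpt-to-llama2"]
        llama3_defs = notebook["converting-llama2-to-llama3"]

        # Definitions exec'd from the notebook cells are ordinary module attributes
        assert callable(llama2_defs.compute_rope)
        assert callable(llama3_defs.precompute_rope_params)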
@@ -59,22 +65,25 @@ def set_seed():


 def test_rope_llama2(notebook):
+
+    this_nb = notebook["converting-gpt-to-llama2"]
+
     # Settings
     batch_size = 1
     context_len = 4096
     num_heads = 4
     head_dim = 16

     # Instantiate RoPE parameters
-    cos, sin = notebook.precompute_rope_params(head_dim=head_dim, context_length=context_len)
+    cos, sin = this_nb.precompute_rope_params(head_dim=head_dim, context_length=context_len)

     # Dummy query and key tensors
     queries = torch.randn(batch_size, num_heads, context_len, head_dim)
     keys = torch.randn(batch_size, num_heads, context_len, head_dim)

     # Apply rotary position embeddings
-    queries_rot = notebook.compute_rope(queries, cos, sin)
-    keys_rot = notebook.compute_rope(keys, cos, sin)
+    queries_rot = this_nb.compute_rope(queries, cos, sin)
+    keys_rot = this_nb.compute_rope(keys, cos, sin)

     rot_emb = LlamaRotaryEmbedding(
         dim=head_dim,
@@ -93,6 +102,10 @@ def test_rope_llama2(notebook):


 def test_rope_llama3(notebook):
+
+    nb1 = notebook["converting-gpt-to-llama2"]
+    nb2 = notebook["converting-llama2-to-llama3"]
+
     # Settings
     batch_size = 1
     context_len = 8192
@@ -101,19 +114,20 @@ def test_rope_llama3(notebook):
     theta_base = 50_000

     # Instantiate RoPE parameters
-    cos, sin = notebook.precompute_rope_params(
+    cos, sin = nb2.precompute_rope_params(
         head_dim=head_dim,
         context_length=context_len,
         theta_base=theta_base
     )

     # Dummy query and key tensors
+    torch.manual_seed(123)
     queries = torch.randn(batch_size, num_heads, context_len, head_dim)
     keys = torch.randn(batch_size, num_heads, context_len, head_dim)

     # Apply rotary position embeddings
-    queries_rot = notebook.compute_rope(queries, cos, sin)
-    keys_rot = notebook.compute_rope(keys, cos, sin)
+    queries_rot = nb1.compute_rope(queries, cos, sin)
+    keys_rot = nb1.compute_rope(keys, cos, sin)

     rot_emb = LlamaRotaryEmbedding(
         dim=head_dim,
@@ -131,16 +145,83 @@ def test_rope_llama3(notebook):
     torch.testing.assert_close(queries_rot, ref_queries_rot)


+def test_rope_llama3_12(notebook):
+
+    nb1 = notebook["converting-gpt-to-llama2"]
+    nb2 = notebook["converting-llama2-to-llama3"]
+
+    # Settings
+    batch_size = 1
+    context_len = 8192
+    num_heads = 4
+    head_dim = 16
+    rope_theta = 50_000
+
+    rope_config = {
+        "factor": 8.0,
+        "low_freq_factor": 1.0,
+        "high_freq_factor": 4.0,
+        "original_context_length": 8192,
+    }
+
+    # Instantiate RoPE parameters
+    cos, sin = nb2.precompute_rope_params(
+        head_dim=head_dim,
+        theta_base=rope_theta,
+        context_length=context_len,
+        freq_config=rope_config,
+    )
+
+    # Dummy query and key tensors
+    torch.manual_seed(123)
+    queries = torch.randn(batch_size, num_heads, context_len, head_dim)
+    keys = torch.randn(batch_size, num_heads, context_len, head_dim)
+
+    # Apply rotary position embeddings
+    queries_rot = nb1.compute_rope(queries, cos, sin)
+    keys_rot = nb1.compute_rope(keys, cos, sin)
+
+    hf_rope_params = {
+        "factor": 8.0,
+        "low_freq_factor": 1.0,
+        "high_freq_factor": 4.0,
+        "original_max_position_embeddings": 8192,
+        "rope_type": "llama3"
+    }
+
+    class RoPEConfig:
+        rope_type = "llama3"
+        rope_scaling = hf_rope_params
+        factor = 1.0
+        dim: int = head_dim
+        rope_theta = 50_000
+        max_position_embeddings: int = 8192
+        hidden_size = head_dim * num_heads
+        num_attention_heads = num_heads
+
+    config = RoPEConfig()
+
+    rot_emb = LlamaRotaryEmbedding(config=config)
+    position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
+    ref_cos, ref_sin = rot_emb(queries, position_ids)
+    ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)
+
+    torch.testing.assert_close(sin, ref_sin.squeeze(0))
+    torch.testing.assert_close(cos, ref_cos.squeeze(0))
+    torch.testing.assert_close(keys_rot, ref_keys_rot)
+    torch.testing.assert_close(queries_rot, ref_queries_rot)
+
+
 def test_silu(notebook):
     example_batch = torch.randn(2, 3, 4)
-    silu = notebook.SiLU()
+    silu = notebook["converting-gpt-to-llama2"].SiLU()
     assert torch.allclose(silu(example_batch), torch.nn.functional.silu(example_batch))


 @pytest.mark.skipif(torch.__version__ < "2.4", reason="Requires PyTorch 2.4 or newer")
 def test_rmsnorm(notebook):
     example_batch = torch.randn(2, 3, 4)
-    rms_norm = notebook.RMSNorm(emb_dim=example_batch.shape[-1], eps=1e-5)
+    rms_norm = notebook["converting-gpt-to-llama2"].RMSNorm(emb_dim=example_batch.shape[-1], eps=1e-5)
     rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-5)

     assert torch.allclose(rms_norm(example_batch), rmsnorm_pytorch(example_batch))
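For context on what the new `rope_config` / `rope_scaling` parameters in `test_rope_llama3_12` control: the "llama3" RoPE variant rescales the inverse frequencies before building the cos/sin tables, dividing the low-frequency (long-wavelength) components by `factor` and smoothly blending the mid band between scaled and unscaled values. A rough sketch of that scaling, assuming the notebook's `precompute_rope_params` follows the reference behavior (function and variable names here are illustrative, not taken from the notebook):

    import torch

    def llama3_scaled_inv_freq(head_dim=16, theta_base=50_000, factor=8.0,
                               low_freq_factor=1.0, high_freq_factor=4.0,
                               original_context_length=8192):
        # Standard RoPE inverse frequencies for the even dimensions
        inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2).float() / head_dim))

        # Wavelength (in positions) of each frequency component
        wavelen = 2 * torch.pi / inv_freq
        low_freq_wavelen = original_context_length / low_freq_factor
        high_freq_wavelen = original_context_length / high_freq_factor

        # Long-wavelength (low-frequency) components are slowed down by `factor`
        scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)

        # Mid-band components interpolate smoothly between scaled and unscaled
        smooth = ((original_context_length / wavelen) - low_freq_factor) / (high_freq_factor - low_freq_factor)
        blended = (1 - smooth) * (inv_freq / factor) + smooth * inv_freq
        is_mid = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)

        return torch.where(is_mid, blended, scaled)

The tests above then check that the cos/sin tables built from these frequencies, and the rotated queries/keys, match the outputs of Hugging Face's LlamaRotaryEmbedding and apply_rotary_pos_emb for the same settings.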
