Introduce buffers to improve Llama 3.2 efficiency (#389)
* Introduce buffers to improve Llama 3.2 efficiency

* update

* update
rasbt authored Oct 6, 2024
1 parent a0c0c76 commit 1eb0b38
Showing 2 changed files with 248 additions and 89 deletions.
174 changes: 134 additions & 40 deletions ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb
@@ -430,6 +430,14 @@
"- In short, the main change in GQA is that each query group needs to be repeated to match the number of heads it is associated with, as implemented below"
]
},
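A minimal, self-contained sketch of this repetition, using hypothetical shapes (`repeat_interleave` is one common way to implement it):

```python
import torch

# Hypothetical example: 6 query heads share 2 key/value groups
b, num_heads, num_kv_groups, seq_len, head_dim = 1, 6, 2, 4, 8
group_size = num_heads // num_kv_groups  # 3 query heads per key/value group

keys = torch.randn(b, num_kv_groups, seq_len, head_dim)
# Repeat each key/value group so it lines up with its query heads
keys_expanded = keys.repeat_interleave(group_size, dim=1)
print(keys_expanded.shape)  # torch.Size([1, 6, 4, 8])
```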
{
"cell_type": "markdown",
"id": "842aa71a-4659-424e-8830-392bd6ae86af",
"metadata": {},
"source": [
"- In addition, we also introduce a `SharedBuffers` class that will allow us to reuse the `mask`, `cos`, and `sin` tensors in the transformer blocks to improve efficiency (this will be crucial when working with models such as Llama 3.1 and 3.2 later, which support up to 131k input tokens)"
]
},
{
"cell_type": "code",
"execution_count": 8,
@@ -441,6 +449,28 @@
"source": [
"import torch.nn as nn\n",
"\n",
"\n",
"############################# NEW #############################\n",
"class SharedBuffers:\n",
" _buffers = {}\n",
"\n",
" @staticmethod\n",
" def get_buffers(context_length, head_dim, rope_base, freq_config, dtype=torch.float32):\n",
" key = (context_length, head_dim, rope_base, tuple(freq_config.values()) if freq_config else freq_config, dtype)\n",
"\n",
" if key not in SharedBuffers._buffers:\n",
" # Create or fetch the buffers\n",
" mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
" cos, sin = precompute_rope_params(head_dim, rope_base, context_length, freq_config)\n",
" if dtype is not None:\n",
" cos = cos.to(dtype)\n",
" sin = sin.to(dtype)\n",
" SharedBuffers._buffers[key] = (mask, cos, sin)\n",
"\n",
" return SharedBuffers._buffers[key]\n",
"############################# NEW #############################\n",
"\n",
"\n",
"class GroupedQueryAttention(nn.Module):\n",
" def __init__(\n",
" self, d_in, d_out, context_length, num_heads,\n",
@@ -469,13 +499,12 @@
" self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n",
" self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)\n",
"\n",
" self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
" cos, sin = precompute_rope_params(\n",
" head_dim=self.head_dim,\n",
" theta_base=rope_base, # NEW\n",
" freq_config=rope_config, # NEW\n",
" context_length=8192\n",
" )\n",
" ############################# NEW #############################\n",
" # Fetch buffers using SharedBuffers\n",
" mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype)\n",
" ############################# NEW #############################\n",
" \n",
" self.register_buffer(\"mask\", mask)\n",
" self.register_buffer(\"cos\", cos)\n",
" self.register_buffer(\"sin\", sin)\n",
"\n",
@@ -907,6 +936,35 @@
"model = Llama3Model(LLAMA3_CONFIG_8B)"
]
},
{
"cell_type": "markdown",
"id": "edea6334-d1fc-427d-9cf2-4af963ff4bfc",
"metadata": {},
"source": [
"- The following is expected to print True to confirm buffers are reused instead of being (wastefully) recreated:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ee9625cc-9afa-4b11-8aab-d536fd170761",
"metadata": {},
"outputs": [],
"source": [
"# Check buffers\n",
"print(model.trf_blocks[0].att.mask is model.trf_blocks[-1].att.mask)\n",
"print(model.trf_blocks[0].att.cos is model.trf_blocks[-1].att.cos)\n",
"print(model.trf_blocks[0].att.sin is model.trf_blocks[-1].att.sin) "
]
},
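To see why reusing these buffers matters at long context lengths, a rough back-of-the-envelope sketch (assuming the float32 mask that `torch.ones` creates by default):

```python
# One causal mask for a 131_072-token context, stored in float32
context_length = 131_072
bytes_per_mask = context_length * context_length * 4  # 4 bytes per float32 element
print(f"{bytes_per_mask / 1e9:.1f} GB per mask")         # ~68.7 GB

# Without sharing, each of the 32 transformer blocks would hold its own copy
print(f"{32 * bytes_per_mask / 1e12:.1f} TB in total")   # ~2.2 TB
```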
{
"cell_type": "markdown",
"id": "8056a521-91a6-440f-8473-591409c3177b",
"metadata": {},
"source": [
"- Let's now also compute the number of trainable parameters:"
]
},
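The cell that follows is collapsed in this diff; a minimal sketch of the standard PyTorch idiom it presumably uses:

```python
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params:,}")
```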
{
"cell_type": "code",
"execution_count": 18,
@@ -2008,16 +2066,16 @@
"}\n",
"\n",
"LLAMA31_CONFIG_8B = {\n",
" \"vocab_size\": 128_256, # Vocabulary size\n",
" \"context_length\": 8192, # Context length\n",
" \"emb_dim\": 4096, # Embedding dimension\n",
" \"n_heads\": 32, # Number of attention heads\n",
" \"n_layers\": 32, # Number of layers\n",
" \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n",
" \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
" \"rope_freq\": { # NEW: RoPE frequency scaling\n",
" \"vocab_size\": 128_256, # Vocabulary size\n",
" \"context_length\": 131_072, # NEW: Larger supported context length\n",
" \"emb_dim\": 4096, # Embedding dimension\n",
" \"n_heads\": 32, # Number of attention heads\n",
" \"n_layers\": 32, # Number of layers\n",
" \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n",
" \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
" \"rope_freq\": { # NEW: RoPE frequency scaling\n",
" \"factor\": 8.0,\n",
" \"low_freq_factor\": 1.0,\n",
" \"high_freq_factor\": 4.0,\n",
@@ -2026,6 +2084,24 @@
"}"
]
},
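As a sketch of what these `rope_freq` entries control: Llama 3.1-style scaling divides RoPE's low-frequency (long-wavelength) components by `factor` and smoothly interpolates the mid-range ones. A minimal version, assuming the published Llama 3.1 formulation (the notebook passes these values to `precompute_rope_params` via `freq_config`):

```python
import math
import torch

def rescale_inv_freq(inv_freq, factor=8.0, low_freq_factor=1.0,
                     high_freq_factor=4.0, original_context_length=8192):
    # Wavelength (in tokens) of each RoPE frequency component
    wavelen = 2 * math.pi / inv_freq
    low_freq_wavelen = original_context_length / low_freq_factor
    high_freq_wavelen = original_context_length / high_freq_factor

    # Long-wavelength (low-frequency) components are slowed down by `factor`
    scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)

    # Mid-range components are linearly interpolated between the two regimes
    smooth = ((original_context_length / wavelen - low_freq_factor)
              / (high_freq_factor - low_freq_factor))
    smoothed = (1 - smooth) * inv_freq / factor + smooth * inv_freq
    is_mid = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
    return torch.where(is_mid, smoothed, scaled)
```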
{
"cell_type": "markdown",
"id": "d81ee464-c112-43b0-9ee8-70df6ac942d0",
"metadata": {},
"source": [
"- Reduce the context length so the model would work fine on a MacBook Air (if you have more RAM, feel free to comment out the lines below):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9bdbe32f-4c96-4e60-8bf4-52b5217df1e6",
"metadata": {},
"outputs": [],
"source": [
"LLAMA32_CONFIG[\"context_length\"] = 8192"
]
},
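Since `SharedBuffers` keys its cache on `context_length`, reducing it before instantiating the model also shrinks the shared mask accordingly, as a quick calculation shows:

```python
# Shared float32 mask at the reduced context length vs. the full 131k context
print(f"{8192 * 8192 * 4 / 1e9:.2f} GB")        # ~0.27 GB
print(f"{131_072 * 131_072 * 4 / 1e9:.1f} GB")  # ~68.7 GB
```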
{
"cell_type": "markdown",
"id": "xa3bpMDtTdBs",
@@ -2338,16 +2414,16 @@
"outputs": [],
"source": [
"LLAMA31_CONFIG_8B = {\n",
" \"vocab_size\": 128_256, # Vocabulary size\n",
" \"context_length\": 8192, # Context length\n",
" \"emb_dim\": 4096, # Embedding dimension\n",
" \"n_heads\": 32, # Number of attention heads\n",
" \"n_layers\": 32, # Number of layers\n",
" \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n",
" \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
" \"rope_freq\": { # RoPE frequency scaling\n",
" \"vocab_size\": 128_256, # Vocabulary size\n",
" \"context_length\": 131_072, # NEW: Larger supported context length\n",
" \"emb_dim\": 4096, # Embedding dimension\n",
" \"n_heads\": 32, # Number of attention heads\n",
" \"n_layers\": 32, # Number of layers\n",
" \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n",
" \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
" \"rope_freq\": { # NEW: RoPE frequency scaling\n",
" \"factor\": 8.0,\n",
" \"low_freq_factor\": 1.0,\n",
" \"high_freq_factor\": 4.0,\n",
@@ -2357,24 +2433,42 @@
"\n",
"\n",
"LLAMA32_CONFIG_1B = {\n",
" \"vocab_size\": 128_256, # Vocabulary size\n",
" \"context_length\": 8192, # Context length\n",
" \"emb_dim\": 2048, # NEW: Half the embedding dimension\n",
" \"n_heads\": 32, # Number of attention heads\n",
" \"n_layers\": 16, # NEW: Half the number of layers\n",
" \"hidden_dim\": 8192, # NEW: Almost half the size of the intermediate dimension in FeedForward\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n",
" \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
" \"rope_freq\": { # RoPE frequency scaling\n",
" \"factor\": 32.0, # NEW: Adjustment of the rescaling factor\n",
" \"vocab_size\": 128_256, # Vocabulary size\n",
" \"context_length\": 131_072, # Context length\n",
" \"emb_dim\": 2048, # NEW: Half the embedding dimension\n",
" \"n_heads\": 32, # Number of attention heads\n",
" \"n_layers\": 16, # NEW: Half the number of layers\n",
" \"hidden_dim\": 8192, # NEW: Almost half the size of the intermediate dimension in FeedForward\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n",
" \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
" \"rope_freq\": { # RoPE frequency scaling\n",
" \"factor\": 32.0, # NEW: Adjustment of the rescaling factor\n",
" \"low_freq_factor\": 1.0,\n",
" \"high_freq_factor\": 4.0,\n",
" \"original_context_length\": 8192,\n",
" }\n",
"}"
]
},
{
"cell_type": "markdown",
"id": "b5cd351b-d883-460d-9cdc-47e15ddb884a",
"metadata": {},
"source": [
"- Reduce the context length so the model would work fine on a MacBook Air (if you have more RAM, feel free to comment out the lines below):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "387456c3-c6a1-46fe-8830-6e00eb46ac13",
"metadata": {},
"outputs": [],
"source": [
"LLAMA32_CONFIG[\"context_length\"] = 8192"
]
},
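With the reduced context length, instantiating the model stays manageable, and the buffer-sharing check from earlier should hold here as well; a brief sketch (assuming the `Llama3Model` class defined earlier accepts this config):

```python
model = Llama3Model(LLAMA32_CONFIG_1B)

# All 16 transformer blocks should reference the same shared tensors
print(model.trf_blocks[0].att.mask is model.trf_blocks[-1].att.mask)  # True
print(model.trf_blocks[0].att.cos is model.trf_blocks[-1].att.cos)    # True
```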
{
"cell_type": "markdown",
"id": "Dl4_0EoJKKYv",
@@ -2593,7 +2687,7 @@
"provenance": []
},
"kernelspec": {
"display_name": "base",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -2607,7 +2701,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.11.4"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {