From 8354eba80e7dd58627a47ce4f75ce7dc62427a28 Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Wed, 4 Sep 2024 16:31:29 -0700
Subject: [PATCH] Support Llama-3.1-405B (#199)

* Support Llama 3.1 405B

* Update readme
---
 README.md                        |  2 ++
 model.py                         | 34 ++++++++++++++++++++++++++++++--
 scripts/convert_hf_checkpoint.py |  5 ++++-
 3 files changed, 38 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index 8ae96eb..c462b72 100644
--- a/README.md
+++ b/README.md
@@ -73,6 +73,7 @@ mistralai/Mistral-7B-v0.1
 mistralai/Mistral-7B-Instruct-v0.1
 mistralai/Mistral-7B-Instruct-v0.2
 meta-llama/Meta-Llama-3-8B
+meta-llama/Meta-Llama-3.1-405B
 ```
 
 For example, to convert Llama-2-7b-chat-hf
@@ -120,6 +121,7 @@ Benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh
 | Llama-2-70B | Base         | 62.50 | 1135.29 |
 |             | 8-bit        | 80.44 | 752.04  |
 |             | 4-bit (G=32) | 90.77 | 548.10  |
+| Llama-3.1-405B | 8-bit     | 15.60 | 815.87  |
 
 ### AMD
 Benchmarks run on one GCD of a MI-250x.
diff --git a/model.py b/model.py
index b89a19a..6799206 100644
--- a/model.py
+++ b/model.py
@@ -3,6 +3,7 @@
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+import math
 from dataclasses import dataclass
 from typing import Optional
 
@@ -29,6 +30,7 @@ class ModelArgs:
     head_dim: int = 64
     rope_base: float = 10000
     norm_eps: float = 1e-5
+    rope_scaling: Optional[dict] = None
 
     def __post_init__(self):
         if self.n_local_heads == -1:
@@ -68,6 +70,9 @@ def from_name(cls, name: str):
 
     "llama-3-8b": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000),
     "llama-3-70b": dict(block_size=8192, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000),
+    "llama-3.1-405b": dict(block_size=131072, n_layer=126, n_head=128, n_local_heads=8, dim=16384, intermediate_size=53248, vocab_size=128256, rope_base=500000,
+        rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192),
+    ),
 }
 
 class KVCache(nn.Module):
@@ -119,7 +124,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
         for b in self.layers:
             b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)
 
-        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)
+        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype, self.config.rope_scaling)
         self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
 
     def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
@@ -230,11 +235,36 @@ def forward(self, x: Tensor) -> Tensor:
         return output * self.weight
 
 
+def apply_rope_scaling(freqs: torch.Tensor, rope_scaling: Optional[dict] = None):
+    factor = rope_scaling["factor"]
+    low_freq_factor = rope_scaling["low_freq_factor"]
+    high_freq_factor = rope_scaling["high_freq_factor"]
+    old_context_len = rope_scaling["original_max_position_embeddings"]
+
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
+    new_freqs = []
+    for freq in freqs:
+        wavelen = 2 * math.pi / freq
+        if wavelen < high_freq_wavelen:
+            new_freqs.append(freq)
+        elif wavelen > low_freq_wavelen:
+            new_freqs.append(freq / factor)
+        else:
+            assert low_freq_wavelen != high_freq_wavelen
+            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+            new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
+    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
+
+
 def precompute_freqs_cis(
     seq_len: int, n_elem: int, base: int = 10000,
-    dtype: torch.dtype = torch.bfloat16
+    dtype: torch.dtype = torch.bfloat16,
+    rope_scaling: Optional[dict] = None,
 ) -> Tensor:
     freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
+    if rope_scaling is not None:
+        freqs = apply_rope_scaling(freqs, rope_scaling)
     t = torch.arange(seq_len, device=freqs.device)
     freqs = torch.outer(t, freqs)
     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py
index d3a64d9..f08eaba 100644
--- a/scripts/convert_hf_checkpoint.py
+++ b/scripts/convert_hf_checkpoint.py
@@ -116,7 +116,10 @@ def permute(w, n_head):
     print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
     torch.save(final_result, checkpoint_dir / "model.pth")
     if 'llama-3' in model_name.lower():
-        original_dir = checkpoint_dir / "original"
+        if 'llama-3.1' in model_name.lower():
+            original_dir = checkpoint_dir / "original" / "mp16"
+        else:
+            original_dir = checkpoint_dir / "original"
         tokenizer_model = original_dir / "tokenizer.model"
         tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model"
         print(f"Copying {tokenizer_model} to {tokenizer_model_tiktoken}")
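
For reference, below is a minimal standalone sketch (not part of the patch) of what the new `rope_scaling` entry does to the precomputed RoPE inverse frequencies. The `rope_scaling` values are taken from the new `llama-3.1-405b` config above; the head dimension of 128 follows from `dim=16384` and `n_head=128`; the helper name `scale_freqs` is hypothetical and simply mirrors the per-frequency rule in `apply_rope_scaling`.

```python
# Sketch of the Llama-3.1 RoPE scaling rule added in this patch.
# Assumed from the new config: head_dim = 16384 // 128 = 128, rope_base = 500000.
import math
import torch

rope_scaling = dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0,
                    original_max_position_embeddings=8192)

def scale_freqs(freqs: torch.Tensor, cfg: dict) -> torch.Tensor:
    # Same rule as apply_rope_scaling in model.py: short wavelengths are kept,
    # long wavelengths are divided by `factor`, and the band in between is
    # linearly interpolated via `smooth`.
    low_wl = cfg["original_max_position_embeddings"] / cfg["low_freq_factor"]
    high_wl = cfg["original_max_position_embeddings"] / cfg["high_freq_factor"]
    out = []
    for f in freqs:
        wl = 2 * math.pi / f
        if wl < high_wl:
            out.append(f)
        elif wl > low_wl:
            out.append(f / cfg["factor"])
        else:
            smooth = (cfg["original_max_position_embeddings"] / wl - cfg["low_freq_factor"]) / (
                cfg["high_freq_factor"] - cfg["low_freq_factor"])
            out.append((1 - smooth) * f / cfg["factor"] + smooth * f)
    return torch.tensor(out, dtype=freqs.dtype)

n_elem, base = 128, 500000  # head_dim and rope_base for llama-3.1-405b
freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
scaled = scale_freqs(freqs, rope_scaling)
print(scaled[0] / freqs[0])    # highest frequency (shortest wavelength): kept, ratio ~1.0
print(scaled[-1] / freqs[-1])  # lowest frequency (longest wavelength): divided by factor, ratio ~1/8
```

In effect, only the low-frequency (long-wavelength) components are stretched by `factor=8.0`, which is how the context window grows from the original 8192 positions to the new `block_size=131072` while leaving the high-frequency components untouched.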