Support Llama-3.1-405B (#199)
* Support Llama 3.1 405B

* Update readme
yanboliang authored Sep 4, 2024
1 parent 61c193d commit 8354eba
Showing 3 changed files with 38 additions and 3 deletions.
README.md (2 changes: 2 additions & 0 deletions)
@@ -73,6 +73,7 @@ mistralai/Mistral-7B-v0.1
 mistralai/Mistral-7B-Instruct-v0.1
 mistralai/Mistral-7B-Instruct-v0.2
 meta-llama/Meta-Llama-3-8B
+meta-llama/Meta-Llama-3.1-405B
 ```

 For example, to convert Llama-2-7b-chat-hf
@@ -120,6 +121,7 @@ Benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh
 | Llama-2-70B | Base | 62.50 | 1135.29 |
 | | 8-bit | 80.44 | 752.04 |
 | | 4-bit (G=32) | 90.77 | 548.10 |
+| Llama-3.1-405B | 8-bit | 15.60 | 815.87 |

 ### AMD
 Benchmarks run on one GCD of a MI-250x.
model.py (34 changes: 32 additions & 2 deletions)
@@ -3,6 +3,7 @@

 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+import math
 from dataclasses import dataclass
 from typing import Optional

@@ -29,6 +30,7 @@ class ModelArgs:
     head_dim: int = 64
     rope_base: float = 10000
     norm_eps: float = 1e-5
+    rope_scaling: Optional[dict] = None

     def __post_init__(self):
         if self.n_local_heads == -1:
@@ -68,6 +70,9 @@ def from_name(cls, name: str):

     "llama-3-8b": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000),
     "llama-3-70b": dict(block_size=8192, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000),
+    "llama-3.1-405b": dict(block_size=131072, n_layer=126, n_head=128, n_local_heads=8, dim=16384, intermediate_size=53248, vocab_size=128256, rope_base=500000,
+        rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192),
+    ),
 }

 class KVCache(nn.Module):
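
As a rough sanity check on the new entry, these hyperparameters do land near 405B parameters. The arithmetic below is a back-of-the-envelope sketch only: it assumes untied input/output embeddings and grouped-query attention with n_local_heads key/value heads, and it ignores the small RMSNorm weights.

```python
# Back-of-the-envelope parameter count for the "llama-3.1-405b" config above (illustrative only).
dim, n_layer, n_head, n_local_heads = 16384, 126, 128, 8
intermediate_size, vocab_size = 53248, 128256
head_dim = dim // n_head  # 128

attn = 2 * dim * dim + 2 * dim * (n_local_heads * head_dim)  # wq/wo full-size, wk/wv grouped
mlp = 3 * dim * intermediate_size                            # gate, up, down projections
embeddings = 2 * vocab_size * dim                            # token embedding + output head (untied)

total = n_layer * (attn + mlp) + embeddings
print(f"~{total / 1e9:.0f}B parameters")  # ~406B
```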
@@ -119,7 +124,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
         for b in self.layers:
             b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)

-        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)
+        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype, self.config.rope_scaling)
         self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))

     def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
@@ -230,11 +235,36 @@ def forward(self, x: Tensor) -> Tensor:
         return output * self.weight


+def apply_rope_scaling(freqs: torch.Tensor, rope_scaling: Optional[dict] = None):
+    factor = rope_scaling["factor"]
+    low_freq_factor = rope_scaling["low_freq_factor"]
+    high_freq_factor = rope_scaling["high_freq_factor"]
+    old_context_len = rope_scaling["original_max_position_embeddings"]
+
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
+    new_freqs = []
+    for freq in freqs:
+        wavelen = 2 * math.pi / freq
+        if wavelen < high_freq_wavelen:
+            new_freqs.append(freq)
+        elif wavelen > low_freq_wavelen:
+            new_freqs.append(freq / factor)
+        else:
+            assert low_freq_wavelen != high_freq_wavelen
+            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+            new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
+    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
+
+
 def precompute_freqs_cis(
     seq_len: int, n_elem: int, base: int = 10000,
-    dtype: torch.dtype = torch.bfloat16
+    dtype: torch.dtype = torch.bfloat16,
+    rope_scaling: Optional[dict] = None,
 ) -> Tensor:
     freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
+    if rope_scaling is not None:
+        freqs = apply_rope_scaling(freqs, rope_scaling)
     t = torch.arange(seq_len, device=freqs.device)
     freqs = torch.outer(t, freqs)
     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
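
To make the new RoPE behavior concrete, here is a minimal, self-contained sketch that reproduces the frequency scaling added above and applies it to the base frequencies that precompute_freqs_cis would build for the 405B head dimension (dim // n_head = 128). The constants mirror the llama-3.1-405b config; the printed values are illustrative only.

```python
# Minimal sketch of the Llama-3.1 RoPE frequency scaling added in model.py above.
# Requires only PyTorch; constants mirror the "llama-3.1-405b" config.
import math
import torch

rope_scaling = dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0,
                    original_max_position_embeddings=8192)

def apply_rope_scaling(freqs: torch.Tensor, rope_scaling: dict) -> torch.Tensor:
    factor = rope_scaling["factor"]
    low_freq_factor = rope_scaling["low_freq_factor"]
    high_freq_factor = rope_scaling["high_freq_factor"]
    old_context_len = rope_scaling["original_max_position_embeddings"]

    low_freq_wavelen = old_context_len / low_freq_factor    # above this wavelength: fully rescale
    high_freq_wavelen = old_context_len / high_freq_factor  # below this wavelength: leave untouched
    new_freqs = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:       # high-frequency band: keep as-is
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:      # low-frequency band: divide by the scaling factor
            new_freqs.append(freq / factor)
        else:                                 # transition band: interpolate between the two regimes
            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
            new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)

# Base frequencies exactly as precompute_freqs_cis builds them for
# n_elem = dim // n_head = 128 and rope_base = 500000, then scaled.
n_elem, base = 128, 500000
freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
scaled = apply_rope_scaling(freqs, rope_scaling)
print(scaled.shape)            # torch.Size([64])
print((scaled != freqs).sum()) # how many of the 64 frequency bins were rescaled
```

With factor=8.0 and an original context of 8192, only the lower-frequency components are shrunk, which is how the 3.1 models stretch RoPE toward the 131072-token block_size while leaving short-range positional resolution largely intact.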
scripts/convert_hf_checkpoint.py (5 changes: 4 additions & 1 deletion)
@@ -116,7 +116,10 @@ def permute(w, n_head):
     print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
     torch.save(final_result, checkpoint_dir / "model.pth")
     if 'llama-3' in model_name.lower():
-        original_dir = checkpoint_dir / "original"
+        if 'llama-3.1' in model_name.lower():
+            original_dir = checkpoint_dir / "original" / "mp16"
+        else:
+            original_dir = checkpoint_dir / "original"
         tokenizer_model = original_dir / "tokenizer.model"
         tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model"
         print(f"Copying {tokenizer_model} to {tokenizer_model_tiktoken}")
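
The only conversion change is where the tiktoken tokenizer.model is read from before being copied next to the converted checkpoint. Below is a small standalone sketch of that branch; the helper name and example path are illustrative, and the original/mp16 layout is taken from the diff rather than verified against the hosted checkpoints.

```python
from pathlib import Path

def tiktoken_tokenizer_source(checkpoint_dir: Path, model_name: str) -> Path:
    # Mirrors the branch above: Llama 3.1 checkpoints keep tokenizer.model
    # under original/mp16/, earlier Llama 3 checkpoints under original/.
    if "llama-3.1" in model_name.lower():
        return checkpoint_dir / "original" / "mp16" / "tokenizer.model"
    return checkpoint_dir / "original" / "tokenizer.model"

# Hypothetical local layout:
print(tiktoken_tokenizer_source(Path("checkpoints/meta-llama/Meta-Llama-3.1-405B"),
                                "Meta-Llama-3.1-405B"))
# -> checkpoints/meta-llama/Meta-Llama-3.1-405B/original/mp16/tokenizer.model
```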
