Add support for llama 3.1 8B/70B (#200)
* Add support for llama 3.1 8B/70B

* Update 4 GPU perf numbers
yanboliang authored Sep 13, 2024
1 parent 8354eba commit c9f683e
Showing 3 changed files with 23 additions and 7 deletions.
22 changes: 16 additions & 6 deletions README.md
@@ -73,6 +73,8 @@ mistralai/Mistral-7B-v0.1
mistralai/Mistral-7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2
meta-llama/Meta-Llama-3-8B
+meta-llama/Meta-Llama-3.1-8B
+meta-llama/Meta-Llama-3.1-70B
meta-llama/Meta-Llama-3.1-405B
```
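
For reference, each repo id above is resolved to a config entry in model.py by name matching in `ModelArgs.from_name` (shown further down in this diff). A minimal sketch of that lookup, assuming gpt-fast's usual substring matching; the printed value reflects the 3.1 configs added below:

```python
from model import ModelArgs

# from_name() falls back to substring-matching the checkpoint name against
# the keys of transformer_configs, so a repo id such as
# "meta-llama/Meta-Llama-3.1-8B" selects the new "llama-3.1-8b" entry.
args = ModelArgs.from_name("Meta-Llama-3.1-8B")
print(args.block_size)  # 131072 for the Llama 3.1 configs
```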

@@ -93,8 +95,10 @@ Benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh
| Llama-2-70B | Base | OOM ||
| | 8-bit | 19.13 | 1322.58 |
| | 4-bit (G=32) | 25.25 | 1097.66 |
-| Llama-3-8B | Base | 94.25 | 1411.95 |
-| | 8-bit | 139.55 | 1047.23 |
+| Llama-3.1-8B | Base | 93.89 | 1410.76 |
+| | 8-bit | 137.64 | 1030.89 |
+| Llama-3.1-70B | Base | OOM ||
+| | 8-bit | 18.04 | 1253.78 |

### Speculative Sampling
[Verifier: Llama-70B (int4), Draft: Llama-7B (int4)](./scripts/speculate_70B_int4.sh): 48.4 tok/s
@@ -110,17 +114,23 @@ Benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh
| | 2 | 21.32 | 1481.87 |
| | 4 | 38.01 | 1340.76 |
| | 8 | 62.50 | 1135.29 |
-| Llama-3-8B | 1 | 94.19 | 1411.76 |
-| | 2 | 150.48 | 1208.80 |
-| | 4 | 219.77 | 991.63 |
-| | 8 | 274.65 | 768.55 |
+| Llama-3.1-8B | 1 | 93.83 | 1408.37 |
+| | 2 | 149.10 | 1197.32 |
+| | 4 | 217.21 | 986.32 |
+| | 8 | 276.01 | 772.60 |
+| Llama-3.1-70B | 1 | OOM | |
+| | 2 | 16.03 | 1130.81 |
+| | 4 | 37.45 | 1360.53 |
+| | 8 | 58.78 | 1129.61 |

### Tensor Parallelism + Quantization
| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) |
| -------- | ------- | ------ | ------ |
| Llama-2-70B | Base | 62.50 | 1135.29 |
| | 8-bit | 80.44 | 752.04 |
| | 4-bit (G=32) | 90.77 | 548.10 |
+| Llama-3.1-70B | Base | 58.78 | 1129.61 |
+| | 8-bit | 75.58 | 726.57 |
| Llama-3.1-405B | 8-bit | 15.60 | 815.87 |

### AMD
6 changes: 6 additions & 0 deletions model.py
@@ -70,6 +70,12 @@ def from_name(cls, name: str):

"llama-3-8b": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000),
"llama-3-70b": dict(block_size=8192, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000),
"llama-3.1-8b": dict(block_size=131072, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000,
rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192),
),
"llama-3.1-70b": dict(block_size=131072, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000,
rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192),
),
"llama-3.1-405b": dict(block_size=131072, n_layer=126, n_head=128, n_local_heads=8, dim=16384, intermediate_size=53248, vocab_size=128256, rope_base=500000,
rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192),
),
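Besides the 131072-token block size, what distinguishes the new 3.1 entries from the Llama 3 ones is the `rope_scaling` dict. Those four values drive the frequency-dependent RoPE rescaling from the Llama 3.1 release: high-frequency components are kept as-is, low-frequency components are divided by `factor`, and the band in between is linearly interpolated. A minimal sketch of that rule, following the reference formula published with Llama 3.1; the function name and tensor layout here are illustrative, not part of this diff:

```python
import math
import torch

def apply_rope_scaling(freqs: torch.Tensor, factor: float = 8.0,
                       low_freq_factor: float = 1.0, high_freq_factor: float = 4.0,
                       original_max_position_embeddings: int = 8192) -> torch.Tensor:
    # Wavelength cutoffs derived from the original 8192-token context window.
    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
    high_freq_wavelen = original_max_position_embeddings / high_freq_factor
    new_freqs = []
    for freq in freqs.tolist():
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            new_freqs.append(freq)           # high frequency: unchanged
        elif wavelen > low_freq_wavelen:
            new_freqs.append(freq / factor)  # low frequency: fully scaled
        else:
            # Smooth linear interpolation between the two regimes.
            smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor)
            new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
```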
2 changes: 1 addition & 1 deletion scripts/convert_hf_checkpoint.py
@@ -116,7 +116,7 @@ def permute(w, n_head):
print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
torch.save(final_result, checkpoint_dir / "model.pth")
if 'llama-3' in model_name.lower():
if 'llama-3.1' in model_name.lower():
if 'llama-3.1-405b' in model_name.lower():
original_dir = checkpoint_dir / "original" / "mp16"
else:
original_dir = checkpoint_dir / "original"
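
For context on what this branch feeds: the Llama 3.x downloads from Hugging Face ship a tiktoken-style tokenizer.model inside the original/ folder (nested one level deeper, under original/mp16, for the 405B release per this diff), and the conversion script copies it next to model.pth. A sketch of that step under those assumptions; the function wrapper is illustrative:

```python
import shutil
from pathlib import Path

def copy_tiktoken_tokenizer(checkpoint_dir: Path, model_name: str) -> None:
    # Mirrors the branch above: only the 405B download nests its original
    # weights (and tokenizer.model) under original/mp16.
    if 'llama-3' in model_name.lower():
        if 'llama-3.1-405b' in model_name.lower():
            original_dir = checkpoint_dir / "original" / "mp16"
        else:
            original_dir = checkpoint_dir / "original"
        shutil.copy(original_dir / "tokenizer.model",
                    checkpoint_dir / "tokenizer.model")
```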

2 comments on commit c9f683e


@wfloveiu commented on c9f683e, Oct 9, 2024


Why is there no "original" directory for the "llama-3.1-8b" model, and how can I generate the "tokenizer.model"?

@deafTim commented


> Why is there no "original" directory for the "llama-3.1-8b" model, and how can I generate the "tokenizer.model"?

@wfloveiu Maybe you can use the tokenizer from one of the other models.
