Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Mixtral-8x22B #1845

Merged
merged 6 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ Every model is written from scratch to maximize performance and remove layers of
|----|----|----|----|
| Llama 3, 3.1, 3.2 | 1B, 3B, 8B, 70B, 405B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) |
| Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) |
| Mixtral MoE | 8x7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/) |
| Mixtral MoE | 8x7B, 8x22B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/) |
| Mistral | 7B, 123B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/announcing-mistral-7b/) |
| CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) |
| Gemma 2 | 2B, 9B, 27B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-2-report.pdf) |
Expand Down Expand Up @@ -128,6 +128,7 @@ Every model is written from scratch to maximize performance and remove layers of
| MicroLlama | 300M | Ken Wang | [MicroLlama repo](https://github.com/keeeeenw/MicroLlama) |
| Mixtral MoE | 8x7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/) |
| Mistral | 7B, 123B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/announcing-mistral-7b/) |
| Mixtral MoE | 8x22B | Mistral AI | [Mistral AI 2024](https://mistral.ai/news/mixtral-8x22b/) |
| OLMo | 1B, 7B | Allen Institute for AI (AI2) | [Groeneveld et al. 2024](https://aclanthology.org/2024.acl-long.841/) |
| OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) |
| Phi 1.5 & 2 | 1.3B, 2.7B | Microsoft Research | [Li et al. 2023](https://arxiv.org/abs/2309.05463) |
Expand Down
21 changes: 21 additions & 0 deletions litgpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1520,6 +1520,27 @@ def norm_class(self) -> Type:
n_expert=8,
n_expert_per_token=2,
),
# https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1/blob/main/config.json
dict(
name="Mixtral-8x22B-{}v0.1",
hf_config=dict(org="mistralai", name="Mixtral-8x22B-{}v0.1"),
padded_vocab_size=32768,
block_size=65536,
n_layer=56,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-05,
mlp_class_name="LLaMAMoE",
intermediate_size=16384,
n_head=48,
n_embd=6144,
rope_base=1000000,
n_expert=8,
n_expert_per_token=2,
),
]
for c in mistral:
for kind in ("", "Instruct-"):
Expand Down
5 changes: 3 additions & 2 deletions tests/test_convert_lit_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,10 @@ def test_against_hf_llama2(ours_kwargs):


@torch.inference_mode()
def test_against_mixtral():
@pytest.mark.parametrize("model_name", ("Mixtral-8x7B-Instruct-v0.1", "Mixtral-8x22B-Instruct-v0.1"))
def test_against_mixtral(model_name):
ours_config = Config.from_name(
"Mixtral-8x7B-Instruct-v0.1",
model_name,
padded_vocab_size=10000,
n_layer=2,
n_embd=32,
Expand Down
5 changes: 3 additions & 2 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,11 +512,12 @@ def test_against_mathstral_hf_models(device, dtype):


@torch.inference_mode()
def test_against_hf_mixtral():
@pytest.mark.parametrize("model_name", ("Mixtral-8x7B-Instruct-v0.1", "Mixtral-8x22B-Instruct-v0.1"))
def test_against_hf_mixtral(model_name):
device = torch.device("cpu")
dtype = torch.float32
ours_config = Config.from_name(
"Mixtral-8x7B-Instruct-v0.1",
model_name,
padded_vocab_size=10000,
n_layer=2,
n_embd=32,
Expand Down
3 changes: 3 additions & 0 deletions tutorials/download_model_weights.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights.
| MicroLlama | 300M | Ken Wang | [MicroLlama repo](https://github.com/keeeeenw/MicroLlama)
| Mixtral MoE | 8x7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/) |
| Mistral | 7B, 123B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/announcing-mistral-7b/) |
| Mixtral MoE | 8x22B | Mistral AI | [Mistral AI 2024](https://mistral.ai/news/mixtral-8x22b/) |
| Nous-Hermes | 7B, 13B, 70B | NousResearch | [Org page](https://huggingface.co/NousResearch) |
| OLMo | 1B, 7B | Allen Institute for AI (AI2) | [Groeneveld et al. 2024](https://aclanthology.org/2024.acl-long.841/) |
| OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) |
Expand Down Expand Up @@ -157,6 +158,8 @@ mistralai/Mistral-7B-v0.3
mistralai/Mistral-Large-Instruct-2407
mistralai/Mixtral-8x7B-Instruct-v0.1
mistralai/Mixtral-8x7B-v0.1
mistralai/Mixtral-8x22B-Instruct-v0.1
mistralai/Mixtral-8x22B-v0.1
NousResearch/Nous-Hermes-13b
NousResearch/Nous-Hermes-llama-2-7b
NousResearch/Nous-Hermes-Llama2-13b
Expand Down
Loading