From 1767797279441b94374cc9088bdf184334689f8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Apr 2024 09:30:44 +0000 Subject: [PATCH 01/14] add config for 8b --- litgpt/config.py | 65 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/litgpt/config.py b/litgpt/config.py index 0a4234222d..906576e6d6 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -836,6 +836,71 @@ def norm_class(self) -> Type: copy["name"] = c["name"].format(kind) copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) configs.append(copy) + + +############### +# Meta LLaMA 3 +############### +llama_3 = [ + # https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/config.json + dict( + name="Llama-3-8B", + hf_config=dict(org="meta-llama", name="Llama-3-8B"), + vocab_size=128256, + padding_multiple=64, + n_layer=32, + n_head=32, + # n_head=64, + # n_embd=8192, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=14336, + rope_base=500000, + ), + # https://huggingface.co/meta-llama/Llama-2-13b-hf/blob/main/config.json + # dict( + # name="Llama-2-13b{}-hf", + # hf_config=dict(org="meta-llama", name="Llama-2-13b{}-hf"), + # vocab_size=32000, + # padding_multiple=64, + # n_layer=40, + # n_head=40, + # n_embd=5120, + # rotary_percentage=1.0, + # parallel_residual=False, + # bias=False, + # norm_class_name="RMSNorm", + # mlp_class_name="LLaMAMLP", + # intermediate_size=13824, + # ), + # # https://huggingface.co/meta-llama/Llama-2-70b-hf/blob/main/config.json + # dict( + # name="Llama-2-70b{}-hf", + # hf_config=dict(org="meta-llama", name="Llama-2-70b{}-hf"), + # vocab_size=32000, + # padding_multiple=64, + # n_layer=80, + # n_head=64, + # n_embd=8192, + # n_query_groups=8, + # rotary_percentage=1.0, + # parallel_residual=False, + # bias=False, + # norm_class_name="RMSNorm", + # mlp_class_name="LLaMAMLP", + # intermediate_size=28672, + # ), +] +# for c in llama_3: +# for kind in ("", "-instruct"): +# copy = deepcopy(c) +# copy["name"] = c["name"].format(kind) +# copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) +# configs.append(copy) ############### From 9a8784e76acb8bffcefd8cceb0b5c5b2dc07d0f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Apr 2024 09:39:11 +0000 Subject: [PATCH 02/14] update configs --- litgpt/config.py | 67 ++++++++++++++++++------------------------------ 1 file changed, 25 insertions(+), 42 deletions(-) diff --git a/litgpt/config.py b/litgpt/config.py index 906576e6d6..02829bd4e8 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -844,13 +844,12 @@ def norm_class(self) -> Type: llama_3 = [ # https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/config.json dict( - name="Llama-3-8B", - hf_config=dict(org="meta-llama", name="Llama-3-8B"), + name="Llama-3-8B{}", + hf_config=dict(org="meta-llama", name="Meta-Llama-3-8B{}"), vocab_size=128256, padding_multiple=64, n_layer=32, n_head=32, - # n_head=64, # n_embd=8192, n_query_groups=8, rotary_percentage=1.0, @@ -861,46 +860,30 @@ def norm_class(self) -> Type: intermediate_size=14336, rope_base=500000, ), - # https://huggingface.co/meta-llama/Llama-2-13b-hf/blob/main/config.json - # dict( - # name="Llama-2-13b{}-hf", - # hf_config=dict(org="meta-llama", name="Llama-2-13b{}-hf"), - # vocab_size=32000, - # padding_multiple=64, - # n_layer=40, - # n_head=40, - # n_embd=5120, - # rotary_percentage=1.0, - # parallel_residual=False, - # bias=False, - # norm_class_name="RMSNorm", - # mlp_class_name="LLaMAMLP", - # intermediate_size=13824, - # ), - # # https://huggingface.co/meta-llama/Llama-2-70b-hf/blob/main/config.json - # dict( - # name="Llama-2-70b{}-hf", - # hf_config=dict(org="meta-llama", name="Llama-2-70b{}-hf"), - # vocab_size=32000, - # padding_multiple=64, - # n_layer=80, - # n_head=64, - # n_embd=8192, - # n_query_groups=8, - # rotary_percentage=1.0, - # parallel_residual=False, - # bias=False, - # norm_class_name="RMSNorm", - # mlp_class_name="LLaMAMLP", - # intermediate_size=28672, - # ), + # https://huggingface.co/meta-llama/Meta-Llama-3-70B/blob/main/config.json + dict( + name="Llama-3-70B{}", + hf_config=dict(org="meta-llama", name="Meta-Llama-3-70B{}"), + vocab_size=128256, + padding_multiple=64, + n_layer=80, + n_head=64, + n_embd=8192, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=28672, + rope_base=500000, + ), ] -# for c in llama_3: -# for kind in ("", "-instruct"): -# copy = deepcopy(c) -# copy["name"] = c["name"].format(kind) -# copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) -# configs.append(copy) +for c in llama_3: + for kind in ("", "-Instruct"): + copy = deepcopy(c) + copy["name"] = c["name"].format(kind) + copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) + configs.append(copy) ############### From 67d0d121401d90c240a5873558dbff137f84f1c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Apr 2024 05:59:51 -0400 Subject: [PATCH 03/14] add test --- tests/test_model.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index 7bc0ccb5b4..0537098342 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -206,7 +206,13 @@ def test_against_original_open_llama_3b(device, dtype): @torch.inference_mode() @pytest.mark.parametrize( "ours_kwargs", - [{"name": "Llama-2-7b-hf"}, {"name": "CodeLlama-7b-hf"}, {"name": "Llama-2-70b-chat-hf", "n_query_groups": 1}], + [ + {"name": "Llama-2-7b-hf"}, + {"name": "CodeLlama-7b-hf"}, + {"name": "Llama-2-70b-chat-hf", "n_query_groups": 1}, + {"name": "Llama-3-8B"}, + {"name": "Llama-3-8B-Instruct"} + ], ) @pytest.mark.parametrize( ("device", "dtype"), @@ -224,7 +230,7 @@ def test_against_original_open_llama_3b(device, dtype): ), ], ) -def test_against_hf_llama2(ours_kwargs, device, dtype): +def test_against_hf_llama_2_and_3(ours_kwargs, device, dtype): torch.set_default_dtype(dtype) ours_config = Config.from_name( From 76137dabc94fc86d7a2d32ed340408dab5a587b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Apr 2024 06:21:12 -0400 Subject: [PATCH 04/14] add prompt --- litgpt/prompts.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/litgpt/prompts.py b/litgpt/prompts.py index d827413913..bbc60e318d 100644 --- a/litgpt/prompts.py +++ b/litgpt/prompts.py @@ -200,6 +200,23 @@ def apply(self, prompt: str, **kwargs: str) -> str: ) +class Llama3(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + return ( + "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n" + "You are a helpful assistant.<|eot_id|>\n" + "<|start_header_id|>user<|end_header_id|>\n\n" + f"{prompt}<|eot_id|>\n" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + + def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]: + return ( + [tokenizer.eos_id], + [tokenizer.token_to_id("<|eot_id|>")], + ) + + class FreeWilly2(PromptStyle): def apply(self, prompt: str, **kwargs: str) -> str: return ( @@ -316,6 +333,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle: return Llama2FunctionCalling() if re.search("Llama-2.*-chat", model_name): return Llama2() + if re.search("Llama-3.*-Instruct", model_name): + return Llama3() if re.search("FreeWilly2", model_name): return FreeWilly2() if re.search("Platypus", model_name): From b5f4098bd755b10b467c125b2d05ba3408f9c222 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Apr 2024 06:25:31 -0400 Subject: [PATCH 05/14] add test --- tests/test_prompts.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_prompts.py b/tests/test_prompts.py index 3250ce4801..20f2c84e0c 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -50,6 +50,8 @@ def test_prompt_style_from_config(): "Llama-2-7b-chat-hf", "Llama-2-13b-chat-hf", "Llama-2-70b-chat-hf", + "Llama-3-8B-Instruct", + "Llama-3-70B-Instruct", "Gemma-2b-it", "Gemma-7b-it", "FreeWilly2", From 0879aaae33643575e1212e73a8b38bd6f1c762c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Apr 2024 06:40:07 -0400 Subject: [PATCH 06/14] optional --- litgpt/prompts.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/litgpt/prompts.py b/litgpt/prompts.py index bbc60e318d..04a0551cd1 100644 --- a/litgpt/prompts.py +++ b/litgpt/prompts.py @@ -202,9 +202,10 @@ def apply(self, prompt: str, **kwargs: str) -> str: class Llama3(PromptStyle): def apply(self, prompt: str, **kwargs: str) -> str: + # https://github.com/meta-llama/llama3/blob/359887376f0aaf30e433f23e25df858d8c2a9833/llama/tokenizer.py#L202-L229 return ( "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n" - "You are a helpful assistant.<|eot_id|>\n" + "You are a helpful assistant.<|eot_id|>\n" # The system prompt is optional "<|start_header_id|>user<|end_header_id|>\n\n" f"{prompt}<|eot_id|>\n" "<|start_header_id|>assistant<|end_header_id|>\n\n" From 53322351fa0886a7d53a7fc2b619fbcaf0d16daf Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 19 Apr 2024 12:42:35 +0200 Subject: [PATCH 07/14] add to table --- README.md | 47 ++++++++++++++++++++++++----------------------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index a98128445c..bec3650532 100644 --- a/README.md +++ b/README.md @@ -60,29 +60,30 @@ LitGPT has 🤯 **custom, from-scratch implementations** of [20+ LLMs](tutorials #### All models -| Model | Model size | Author | Reference | -|----|----|----|----| -| CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) | -| Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) | -| Dolly | 3B, 7B, 12B | Databricks | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) | -| Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae) | -| FreeWilly2 (Stable Beluga 2) | 70B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) | -| Function Calling Llama 2 | 7B | Trelis | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) | -| Gemma | 2B, 7B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) | -| Llama 2 | 7B, 13B, 70B | Meta AI | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) | -| LongChat | 7B, 13B | LMSYS | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) | -| Mistral | 7B | Mistral AI | [Mistral website](https://mistral.ai/) | -| Nous-Hermes | 7B, 13B, 70B | NousResearch | [Org page](https://huggingface.co/NousResearch) | -| OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) | -| Phi | 1.3B, 2.7B | Microsoft Research | [Li et al. 2023](https://arxiv.org/abs/2309.05463) | -| Platypus | 7B, 13B, 70B | Lee et al. | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) | -| Pythia | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | EleutherAI | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) | -| RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1) | -| StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | -| StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) | -| StableLM Zephyr | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | -| TinyLlama | 1.1B | Zhang et al. | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama) | -| Vicuna | 7B, 13B, 33B | LMSYS | [Li et al. 2023](https://lmsys.org/blog/2023-03-30-vicuna/) +| Model | Model size | Author | Reference | +|------------------------------|------------------------------------------|----|------------------------------------------------------------------------------------------------------------------------------| +| CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) | +| Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) | +| Dolly | 3B, 7B, 12B | Databricks | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) | +| Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae) | +| FreeWilly2 (Stable Beluga 2) | 70B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) | +| Function Calling Llama 2 | 7B | Trelis | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) | +| Gemma | 2B, 7B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) | +| Llama 2 | 7B, 13B, 70B | Meta AI | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) | +| Llama 3 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) | +| LongChat | 7B, 13B | LMSYS | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) | +| Mistral | 7B | Mistral AI | [Mistral website](https://mistral.ai/) | +| Nous-Hermes | 7B, 13B, 70B | NousResearch | [Org page](https://huggingface.co/NousResearch) | +| OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) | +| Phi | 1.3B, 2.7B | Microsoft Research | [Li et al. 2023](https://arxiv.org/abs/2309.05463) | +| Platypus | 7B, 13B, 70B | Lee et al. | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) | +| Pythia | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | EleutherAI | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) | +| RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1) | +| StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | +| StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) | +| StableLM Zephyr | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | +| TinyLlama | 1.1B | Zhang et al. | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama) | +| Vicuna | 7B, 13B, 33B | LMSYS | [Li et al. 2023](https://lmsys.org/blog/2023-03-30-vicuna/) From 8ed1502fb4d351d3230dccba30e7570579009f01 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 19 Apr 2024 12:44:01 +0200 Subject: [PATCH 08/14] Revert "add to table" This reverts commit 53322351fa0886a7d53a7fc2b619fbcaf0d16daf. --- README.md | 47 +++++++++++++++++++++++------------------------ 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index bec3650532..a98128445c 100644 --- a/README.md +++ b/README.md @@ -60,30 +60,29 @@ LitGPT has 🤯 **custom, from-scratch implementations** of [20+ LLMs](tutorials #### All models -| Model | Model size | Author | Reference | -|------------------------------|------------------------------------------|----|------------------------------------------------------------------------------------------------------------------------------| -| CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) | -| Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) | -| Dolly | 3B, 7B, 12B | Databricks | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) | -| Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae) | -| FreeWilly2 (Stable Beluga 2) | 70B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) | -| Function Calling Llama 2 | 7B | Trelis | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) | -| Gemma | 2B, 7B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) | -| Llama 2 | 7B, 13B, 70B | Meta AI | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) | -| Llama 3 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) | -| LongChat | 7B, 13B | LMSYS | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) | -| Mistral | 7B | Mistral AI | [Mistral website](https://mistral.ai/) | -| Nous-Hermes | 7B, 13B, 70B | NousResearch | [Org page](https://huggingface.co/NousResearch) | -| OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) | -| Phi | 1.3B, 2.7B | Microsoft Research | [Li et al. 2023](https://arxiv.org/abs/2309.05463) | -| Platypus | 7B, 13B, 70B | Lee et al. | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) | -| Pythia | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | EleutherAI | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) | -| RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1) | -| StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | -| StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) | -| StableLM Zephyr | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | -| TinyLlama | 1.1B | Zhang et al. | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama) | -| Vicuna | 7B, 13B, 33B | LMSYS | [Li et al. 2023](https://lmsys.org/blog/2023-03-30-vicuna/) +| Model | Model size | Author | Reference | +|----|----|----|----| +| CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) | +| Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) | +| Dolly | 3B, 7B, 12B | Databricks | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) | +| Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae) | +| FreeWilly2 (Stable Beluga 2) | 70B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) | +| Function Calling Llama 2 | 7B | Trelis | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) | +| Gemma | 2B, 7B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) | +| Llama 2 | 7B, 13B, 70B | Meta AI | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) | +| LongChat | 7B, 13B | LMSYS | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) | +| Mistral | 7B | Mistral AI | [Mistral website](https://mistral.ai/) | +| Nous-Hermes | 7B, 13B, 70B | NousResearch | [Org page](https://huggingface.co/NousResearch) | +| OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) | +| Phi | 1.3B, 2.7B | Microsoft Research | [Li et al. 2023](https://arxiv.org/abs/2309.05463) | +| Platypus | 7B, 13B, 70B | Lee et al. | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) | +| Pythia | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | EleutherAI | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) | +| RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1) | +| StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | +| StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) | +| StableLM Zephyr | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | +| TinyLlama | 1.1B | Zhang et al. | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama) | +| Vicuna | 7B, 13B, 33B | LMSYS | [Li et al. 2023](https://lmsys.org/blog/2023-03-30-vicuna/) From 6aa55e4fdd0038fea75f99c9597a0fa19f032292 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Apr 2024 06:44:31 -0400 Subject: [PATCH 09/14] add model to table --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index a98128445c..faa78e18ff 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,7 @@ LitGPT has 🤯 **custom, from-scratch implementations** of [20+ LLMs](tutorials | Function Calling Llama 2 | 7B | Trelis | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) | | Gemma | 2B, 7B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) | | Llama 2 | 7B, 13B, 70B | Meta AI | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) | +| Llama 3 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) | | LongChat | 7B, 13B | LMSYS | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) | | Mistral | 7B | Mistral AI | [Mistral website](https://mistral.ai/) | | Nous-Hermes | 7B, 13B, 70B | NousResearch | [Org page](https://huggingface.co/NousResearch) | From 80e97b0d96f41c1686bd5199c22b9de3455d4e97 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 19 Apr 2024 12:48:13 +0200 Subject: [PATCH 10/14] update --- tutorials/download_model_weights.md | 47 +++++++++++++++-------------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/tutorials/download_model_weights.md b/tutorials/download_model_weights.md index d1c320ac33..dab034e7f9 100644 --- a/tutorials/download_model_weights.md +++ b/tutorials/download_model_weights.md @@ -3,29 +3,30 @@ LitGPT supports a variety of LLM architectures with publicly available weights. You can download model weights and access a list of supported models using the LitGPT `download.py` script. -| Model | Model size | Reference | -|----------------------------------------------|------------------------------------------|------------------------------------------------------------------------------------------------------------------------------| -| CodeGemma by Google | 7B | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) | -| Code Llama by Meta AI | 7B, 13B, 34B, 70B | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) | -| Dolly by Databricks | 3B, 7B, 12B | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) | -| Falcon by TII UAE | 7B, 40B, 180B | [TII 2023](https://falconllm.tii.ae) | -| FreeWilly2 (Stable Beluga 2) by Stability AI | 70B | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) | -| Function Calling Llama 2 by Trelis | 7B | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) | -| Gemma by Google | 2B, 7B | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) | -| Llama 2 by Meta AI | 7B, 13B, 70B | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) | -| LongChat by LMSYS | 7B, 13B | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) | -| Mistral and Mixtral by Mistral AI | 7B | [Mistral website](https://mistral.ai/) | -| Nous-Hermes by NousResearch | 7B, 13B, 70B | [Org page](https://huggingface.co/NousResearch) | -| OpenLLaMA by OpenLM Research | 3B, 7B, 13B | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) | -| Phi by Microsoft Research | 1.3B, 2.7B | [Li et al. 2023](https://arxiv.org/abs/2309.05463) | -| Platypus by Lee at el. | 7B, 13B, 70B | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) | -| Pythia by EleutherAI | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) | -| RedPajama-INCITE by Together | 3B, 7B | [Together 2023](https://together.ai/blog/redpajama-models-v1) | -| StableCode by Stability AI | 3B | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | -| StableLM by Stability AI | 3B, 7B | [Stability AI 2023](https://github.com/Stability-AI/StableLM) | -| StableLM Zephyr by Stability AI | 3B | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | -| TinyLlama by Zhang et al. | 1.1B | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama) | -| Vicuna by LMSYS | 7B, 13B, 33B | [Li et al. 2023](https://lmsys.org/blog/2023-03-30-vicuna/) | +| Model | Model size | Reference | +|----------------------------------------------|-----------------------------------------|--------------------------------------------------------------------------------------------------------------------------| +| CodeGemma by Google | 7B | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) | +| Code Llama by Meta AI | 7B, 13B, 34B, 70B | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) | +| Dolly by Databricks | 3B, 7B, 12B | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) | +| Falcon by TII UAE | 7B, 40B, 180B | [TII 2023](https://falconllm.tii.ae) | +| FreeWilly2 (Stable Beluga 2) by Stability AI | 70B | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) | +| Function Calling Llama 2 by Trelis | 7B | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) | +| Gemma by Google | 2B, 7B | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) | +| Llama 2 by Meta AI | 7B, 13B, 70B | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) | +| Llama 3 by Meta AI | 8B, 70B | [Meta AI 2024](https://github.com/meta-llama/llama3) | +| LongChat by LMSYS | 7B, 13B | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) | +| Mistral and Mixtral by Mistral AI | 7B | [Mistral website](https://mistral.ai/) | +| Nous-Hermes by NousResearch | 7B, 13B, 70B | [Org page](https://huggingface.co/NousResearch) | +| OpenLLaMA by OpenLM Research | 3B, 7B, 13B | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) | +| Phi by Microsoft Research | 1.3B, 2.7B | [Li et al. 2023](https://arxiv.org/abs/2309.05463) | +| Platypus by Lee at el. | 7B, 13B, 70B | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) | +| Pythia by EleutherAI | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) | +| RedPajama-INCITE by Together | 3B, 7B | [Together 2023](https://together.ai/blog/redpajama-models-v1) | +| StableCode by Stability AI | 3B | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | +| StableLM by Stability AI | 3B, 7B | [Stability AI 2023](https://github.com/Stability-AI/StableLM) | +| StableLM Zephyr by Stability AI | 3B | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | +| TinyLlama by Zhang et al. | 1.1B | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama) | +| Vicuna by LMSYS | 7B, 13B, 33B | [Li et al. 2023](https://lmsys.org/blog/2023-03-30-vicuna/) | From a202d9addf0de3c0df1ad0b0f173cec6d96c6719 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Fri, 19 Apr 2024 06:52:53 -0400 Subject: [PATCH 11/14] feature llama 3 --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index faa78e18ff..146ef5b7c3 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,7 @@ LitGPT has 🤯 **custom, from-scratch implementations** of [20+ LLMs](tutorials | Model | Model size | Author | Reference | |----|----|----|----| +| Llama 3 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) | | Llama 2 | 7B, 13B, 70B | Meta AI | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) | | Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) | | Mistral | 7B | Mistral AI | [Mistral website](https://mistral.ai/) | From d866b90f3762f04d21aa8fecf84ff5f9c9b33fc0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Apr 2024 08:17:10 -0400 Subject: [PATCH 12/14] Update litgpt/config.py --- litgpt/config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/litgpt/config.py b/litgpt/config.py index 02829bd4e8..43c66f43be 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -850,7 +850,6 @@ def norm_class(self) -> Type: padding_multiple=64, n_layer=32, n_head=32, - # n_embd=8192, n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, From ce1b3620a50106735b4f83fac423cf167526dc99 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 19 Apr 2024 14:26:02 +0200 Subject: [PATCH 13/14] litgpt download --- tutorials/download_model_weights.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tutorials/download_model_weights.md b/tutorials/download_model_weights.md index dab034e7f9..45c9c7d50c 100644 --- a/tutorials/download_model_weights.md +++ b/tutorials/download_model_weights.md @@ -106,6 +106,10 @@ meta-llama/Llama-2-70b-chat-hf meta-llama/Llama-2-70b-hf meta-llama/Llama-2-7b-chat-hf meta-llama/Llama-2-7b-hf +meta-llama/Meta-Llama-3-70B +meta-llama/Meta-Llama-3-70B-Instruct +meta-llama/Meta-Llama-3-8B +meta-llama/Meta-Llama-3-8B-Instruct microsoft/phi-1_5 microsoft/phi-2 mistralai/Mistral-7B-Instruct-v0.1 From 5f49fb77467f5233c653137cb35c09efd11ebdeb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Apr 2024 08:31:43 -0400 Subject: [PATCH 14/14] =?UTF-8?q?set=20n=5Fquery=E2=80=93groups=20for=2070?= =?UTF-8?q?b?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- litgpt/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/litgpt/config.py b/litgpt/config.py index 43c66f43be..97e9f3a69d 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -868,6 +868,7 @@ def norm_class(self) -> Type: n_layer=80, n_head=64, n_embd=8192, + n_query_groups=8, rotary_percentage=1.0, parallel_residual=False, bias=False,