From 3539d277835b31770e220b730d38a153284ca03d Mon Sep 17 00:00:00 2001
From: Dev Khant
Date: Sat, 13 Apr 2024 11:07:03 +0530
Subject: [PATCH 1/7] Add Danube2

---
 README.md                           |  1 +
 litgpt/config.py                    | 26 ++++++++++++++++++++++++++
 litgpt/prompts.py                   |  8 ++++++++
 tutorials/download_model_weights.md |  2 ++
 4 files changed, 37 insertions(+)

diff --git a/README.md b/README.md
index 692525df48..674720c5e9 100644
--- a/README.md
+++ b/README.md
@@ -140,6 +140,7 @@ Use, Finetune, pretrain, deploy over 20+ LLMs ([full list](tutorials/download_mo
 
 | Model | Model size | Author | Reference |
 |----|----|----|----|
+| Danube2 | 1.8B | H2O.ai | [H2O.ai](https://h2o.ai/platform/danube-1-8b/) |
 | CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) |
 | Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) |
 | Dolly | 3B, 7B, 12B | Databricks | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) |

diff --git a/litgpt/config.py b/litgpt/config.py
index 0a4234222d..a5bf97fe0b 100644
--- a/litgpt/config.py
+++ b/litgpt/config.py
@@ -914,6 +914,32 @@ def norm_class(self) -> Type:
 ]
 configs.extend(codegemma)
 
+################
+# H2Oai Danube2
+################
+danube2 = [
+    # https://huggingface.co/h2oai/h2o-danube2-1.8b-chat/blob/main/config.json
+    dict(
+        name="Danube2-1.8b-chat",
+        hf_config=dict(org="h2oai", name="h2o-danube2-1.8b-chat"),
+        vocab_size=32000,
+        n_layer=24,
+        n_head=32,
+        n_embd=2560,
+        block_size=4096, # should be 8192 but sliding_window mechanism is not implemented
+        intermediate_size=6912,
+        padding_multiple=64,
+        norm_eps=1e-05,
+        rope_base=10000,
+        n_query_groups=8,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="LLaMAMLP"
+    )
+]
+configs.extend(danube2)
+
 
 ##########################
 # Stability AI FreeWilly2

diff --git a/litgpt/prompts.py b/litgpt/prompts.py
index 2d989be32b..f7fd59d0a0 100644
--- a/litgpt/prompts.py
+++ b/litgpt/prompts.py
@@ -268,6 +268,11 @@ def apply(self, prompt: str, **kwargs: str) -> str:
 class Gemma(PromptStyle):
     def apply(self, prompt: str, **kwargs: str) -> str:
         return f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
+
+
+class H2Oai(PromptStyle):
+    def apply(self, prompt: str, **kwargs: str) -> str:
+        return f"<|prompt|>{prompt}<|answer|>"
 
 
 # Maps prompt style names to PromptStyle classes
@@ -294,6 +299,7 @@ def apply(self, prompt: str, **kwargs: str) -> str:
     "phi-2": Phi2,
     "tinyllama": TinyLlama,
     "gemma": Gemma,
+    "h2oai": H2Oai
 }
 
 
@@ -332,6 +338,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
         return TinyLlama()
     if re.search(r"(Code)?Gemma.*-it", model_name):
         return Gemma()
+    if re.search("Danube2.*-chat", model_name):
+        return H2Oai
     return Default()
 
 
diff --git a/tutorials/download_model_weights.md b/tutorials/download_model_weights.md
index b91afa5929..127a015835 100644
--- a/tutorials/download_model_weights.md
+++ b/tutorials/download_model_weights.md
@@ -5,6 +5,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights.
 
 | Model | Model size | Reference |
 |----------------------------------------------|------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------|
+| Danube2 by H2O.ai | 1.8B | [H2O.ai](https://h2o.ai/platform/danube-1-8b/) |
 | CodeGemma by Google | 7B | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) |
 | Code Llama by Meta AI | 7B, 13B, 34B, 70B | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) |
 | Dolly by Databricks | 3B, 7B, 12B | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) |
@@ -90,6 +91,7 @@ google/gemma-2b
 google/gemma-2b-it
 google/gemma-7b
 google/gemma-7b-it
+h2oai/h2o-danube2-1.8b-chat
 lmsys/longchat-13b-16k
 lmsys/longchat-7b-16k
 lmsys/vicuna-13b-v1.3

From ea5f9bf3c7abc0b23095350b41c0443057fd5810 Mon Sep 17 00:00:00 2001
From: Dev Khant
Date: Sat, 13 Apr 2024 11:34:48 +0530
Subject: [PATCH 2/7] prompt fix

---
 litgpt/prompts.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/litgpt/prompts.py b/litgpt/prompts.py
index f7fd59d0a0..df38aa3937 100644
--- a/litgpt/prompts.py
+++ b/litgpt/prompts.py
@@ -339,7 +339,7 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
     if re.search(r"(Code)?Gemma.*-it", model_name):
         return Gemma()
     if re.search("Danube2.*-chat", model_name):
-        return H2Oai
+        return H2Oai()
     return Default()

From a6da5a8bce10ef2725e9f620b49dbb65468170e0 Mon Sep 17 00:00:00 2001
From: Andrei-Aksionov
Date: Mon, 22 Apr 2024 13:45:37 +0300
Subject: [PATCH 3/7] Code-style fixes

---
 litgpt/prompts.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/litgpt/prompts.py b/litgpt/prompts.py
index 71ad69f9e5..7d91148b87 100644
--- a/litgpt/prompts.py
+++ b/litgpt/prompts.py
@@ -286,11 +286,11 @@ def apply(self, prompt: str, **kwargs: str) -> str:
 class Gemma(PromptStyle):
     def apply(self, prompt: str, **kwargs: str) -> str:
         return f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
-
+
 
 class H2Oai(PromptStyle):
     def apply(self, prompt: str, **kwargs: str) -> str:
-        return f"<|prompt|>{prompt}<|answer|>"
+        return f"<|prompt|>{prompt}<|answer|>"
 
 
 # Maps prompt style names to PromptStyle classes
@@ -317,7 +317,7 @@ def apply(self, prompt: str, **kwargs: str) -> str:
     "phi-2": Phi2,
     "tinyllama": TinyLlama,
     "gemma": Gemma,
-    "h2oai": H2Oai
+    "h2oai": H2Oai,
 }

From e3e502c3f19a0b86e142c413f562fba14ebb6d1f Mon Sep 17 00:00:00 2001
From: Andrei-Aksionov
Date: Mon, 22 Apr 2024 13:49:31 +0300
Subject: [PATCH 4/7] Add forgotten rotary_percentage

---
 litgpt/config.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/litgpt/config.py b/litgpt/config.py
index 1157e7acb7..ee99af6edb 100644
--- a/litgpt/config.py
+++ b/litgpt/config.py
@@ -836,7 +836,7 @@ def norm_class(self) -> Type:
         copy["name"] = c["name"].format(kind)
         copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
         configs.append(copy)
-
+
 
 ###############
 # Meta LLaMA 3
@@ -976,16 +976,17 @@ def norm_class(self) -> Type:
         n_layer=24,
         n_head=32,
         n_embd=2560,
-        block_size=4096, # should be 8192 but sliding_window mechanism is not implemented
+        block_size=4096,  # should be 8192 but sliding_window mechanism is not implemented
         intermediate_size=6912,
         padding_multiple=64,
         norm_eps=1e-05,
         rope_base=10000,
         n_query_groups=8,
+        rotary_percentage=1.0,
         parallel_residual=False,
         bias=False,
         norm_class_name="RMSNorm",
-        mlp_class_name="LLaMAMLP"
+        mlp_class_name="LLaMAMLP",
     )
 ]
 configs.extend(danube2)

From 1c2d0cc5ccaf0aa24bddaf0a9c47926c31c8318f Mon Sep 17 00:00:00 2001
From: Dev Khant
Date: Tue, 23 Apr 2024 10:25:42 +0530
Subject: [PATCH 5/7] Prompt template change

Co-authored-by: Andrei-Aksionov <58434077+Andrei-Aksionov@users.noreply.github.com>
---
 litgpt/prompts.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/litgpt/prompts.py b/litgpt/prompts.py
index 7d91148b87..a0e515c3f8 100644
--- a/litgpt/prompts.py
+++ b/litgpt/prompts.py
@@ -290,7 +290,7 @@ def apply(self, prompt: str, **kwargs: str) -> str:
 
 class H2Oai(PromptStyle):
     def apply(self, prompt: str, **kwargs: str) -> str:
-        return f"<|prompt|>{prompt}<|answer|>"
+        return f"<|prompt|>{prompt}</s><|answer|>"
 
 
 # Maps prompt style names to PromptStyle classes

From 1f18fe9bf28f53893f12588443770f9ded8efed6 Mon Sep 17 00:00:00 2001
From: Sebastian Raschka
Date: Fri, 3 May 2024 08:47:41 -0500
Subject: [PATCH 6/7] Update tutorials/download_model_weights.md

---
 tutorials/download_model_weights.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tutorials/download_model_weights.md b/tutorials/download_model_weights.md
index fb8a242b69..2775e1e578 100644
--- a/tutorials/download_model_weights.md
+++ b/tutorials/download_model_weights.md
@@ -250,7 +250,7 @@ litgpt download \
 
 &nbsp;
 
-## Finetunes and other model variants
+## Finetunes and Other Model Variants
 
 Sometimes you want to download the weights of a finetune of one of the models listed above. To do this, you need to manually specify the `model_name` associated to the config to use. For example:

From 6571db2529da5a77488c46000ad37afa3046343c Mon Sep 17 00:00:00 2001
From: rasbt
Date: Fri, 3 May 2024 14:56:23 +0000
Subject: [PATCH 7/7] add test

---
 tests/test_model.py | 58 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 58 insertions(+)

diff --git a/tests/test_model.py b/tests/test_model.py
index 7743c4f143..49584aeb87 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -417,6 +417,64 @@ def test_against_hf_mixtral():
     torch.testing.assert_close(ours_y, theirs_y)
 
 
+@torch.inference_mode()
+@pytest.mark.parametrize(
+    ("device", "dtype"),
+    [
+        (torch.device("cpu"), torch.float32),
+        pytest.param(
+            torch.device("cuda"),
+            torch.float16,
+            marks=[
+                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
+                # is slightly different
+                pytest.mark.xfail(raises=AssertionError, strict=False),
+                RunIf(min_cuda_gpus=1),
+            ],
+        ),
+    ],
+)
+def test_against_hf_h2o_danube(device, dtype):
+    torch.set_default_dtype(dtype)
+
+    ours_config = Config.from_name(
+        "Danube2-1.8b-chat",
+        padded_vocab_size=10000,
+        n_layer=2,
+        n_embd=16,
+        n_head=8,
+        n_query_groups=2,
+        intermediate_size=43,
+    )
+    T = 5
+    theirs_config = MistralConfig(
+        vocab_size=ours_config.padded_vocab_size,
+        hidden_size=ours_config.n_embd,
+        num_attention_heads=ours_config.n_head,
+        num_hidden_layers=ours_config.n_layer,
+        intermediate_size=ours_config.intermediate_size,
+        max_position_embeddings=T,
+        rms_norm_eps=ours_config.norm_eps,
+        num_key_value_heads=ours_config.n_query_groups,
+        rope_theta=ours_config.rope_base,
+    )
+    assert ours_config.intermediate_size == theirs_config.intermediate_size
+
+    theirs_model = MistralForCausalLM(theirs_config).to(device)
+    theirs_state_dict = theirs_model.state_dict()
+    state_dict = {}
+    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
+    ours_model = GPT(ours_config).to(device)
+    ours_model.load_state_dict(state_dict)
+
+    # test end to end
+    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
+    assert x.size(1) == T
+    ours_y = ours_model(x)
+    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
+    torch.testing.assert_close(ours_y, theirs_y)
+
+
 @torch.inference_mode()
 @pytest.mark.parametrize(
     ("device", "dtype"),
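
Taken together, the series wires Danube2 through three layers: a model config (PATCH 1/7, completed by the `rotary_percentage` fix in PATCH 4/7), a prompt style (PATCH 1/7, fixed in PATCH 2/7 and PATCH 5/7), and a parity test against Hugging Face's Mistral implementation (PATCH 7/7). As a quick sanity check of the prompt plumbing, a minimal sketch, assuming a litgpt checkout with this series applied; the expected output reflects the final template from PATCH 5/7:

```python
# Sketch: resolve the prompt style added in this series and format a user turn.
# model_name_to_prompt_style matches "Danube2-1.8b-chat" against the
# "Danube2.*-chat" pattern from PATCH 1/7 and returns an H2Oai() instance
# (the missing call parentheses were the bug fixed in PATCH 2/7).
from litgpt.prompts import H2Oai, model_name_to_prompt_style

style = model_name_to_prompt_style("Danube2-1.8b-chat")
assert isinstance(style, H2Oai)

formatted = style.apply("Why is the sky blue?")
print(formatted)  # <|prompt|>Why is the sky blue?</s><|answer|>
```

The weights the config points at are the ones added to the tutorial's download list, `h2oai/h2o-danube2-1.8b-chat`.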