diff --git a/README.md b/README.md index e12368fb7d..117f96a6cd 100644 --- a/README.md +++ b/README.md @@ -132,6 +132,7 @@ Every model is written from scratch to maximize performance and remove layers of | Mixtral MoE | 8x22B | Mistral AI | [Mistral AI 2024](https://mistral.ai/news/mixtral-8x22b/) | | OLMo | 1B, 7B | Allen Institute for AI (AI2) | [Groeneveld et al. 2024](https://aclanthology.org/2024.acl-long.841/) | | OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) | +| OpenCoder | 1.5B, 8B | Infly | [Huang et al. 2024](https://opencoder-llm.github.io/) | | Phi 1.5 & 2 | 1.3B, 2.7B | Microsoft Research | [Li et al. 2023](https://arxiv.org/abs/2309.05463) | | Phi 3 | 3.8B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2404.14219) | | Platypus | 7B, 13B, 70B | Lee et al. | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) | diff --git a/litgpt/config.py b/litgpt/config.py index 475f017e50..45a93913f1 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -1665,6 +1665,57 @@ def norm_class(self) -> Type: ) +############ +# OpenCoder +############ +opencoder = [ + # https://huggingface.co/infly/OpenCoder-8B-Base/blob/main/config.json + dict( + name="OpenCoder-8B{}", + hf_config=dict(org="infly", name="OpenCoder-8B{}"), + n_embd=4096, + block_size=8192, + vocab_size=96640, + padded_vocab_size=96640, + n_layer=32, + n_head=32, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=14336, + rope_base=500000, + ), + # https://huggingface.co/infly/OpenCoder-1.5B-Base/blob/main/config.json + dict( + name="OpenCoder-1.5B{}", + hf_config=dict(org="infly", name="OpenCoder-1.5B{}"), + n_embd=2240, + block_size=4096, + vocab_size=96640, + padded_vocab_size=96640, + n_layer=24, + n_head=14, + n_query_groups=14, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=6144, + rope_base=10000, + ) +] +for c in opencoder: + for kind in ("-Base", "-Instruct"): + copy = deepcopy(c) + copy["name"] = c["name"].format(kind) + copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) + configs.append(copy) + + ############ # TinyLlama ############ diff --git a/litgpt/prompts.py b/litgpt/prompts.py index 09b3277c7d..4e916c4aa2 100644 --- a/litgpt/prompts.py +++ b/litgpt/prompts.py @@ -277,7 +277,13 @@ def apply(self, prompt: str, **kwargs: str) -> str: class OLMo(PromptStyle): def apply(self, prompt: str, **kwargs: str) -> str: return f"<|endoftext|><|user|>\n{prompt}\n<|assistant|>\n" - + + +class OpenCoder(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + system_message = "You are OpenCoder, created by OpenCoder Team." 
+ return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" + class Qwen2_5(PromptStyle): def apply(self, prompt: str, **kwargs: str) -> str: @@ -329,6 +335,7 @@ def apply(self, prompt: str, **kwargs: str) -> str: "gemma": Gemma, "llama3": Llama3, "olmo": OLMo, + "opencoder": OpenCoder, "qwen2.5": Qwen2_5, "qwen2.5-math": Qwen2_5_Math, "qwq": QwQ, @@ -372,6 +379,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle: return Gemma() if re.search(r"OLMo.*-hf", model_name): return OLMo() + if re.search(r"OpenCoder.*-Instruct", model_name): + return OpenCoder() if re.search(r"Qwen2\.5-Math-.*", model_name): return Qwen2_5_Math() if re.search(r"Qwen2\.5-.*", model_name): diff --git a/tests/test_model.py b/tests/test_model.py index 89e926d173..96d4e28793 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -973,6 +973,66 @@ def test_against_original_smollm2(model_name, device, dtype): torch.testing.assert_close(ours_y, theirs_y) +@torch.inference_mode() +@pytest.mark.parametrize( + "ours_kwargs", + [ + {"name": "OpenCoder-1.5B-Base"}, + {"name": "OpenCoder-8B-Instruct"}, + ], +) +@pytest.mark.parametrize( + ("device", "dtype"), + [ + (torch.device("cpu"), torch.float32), + pytest.param( + torch.device("cuda"), + torch.float16, + marks=[ + # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input + # is slightly different + pytest.mark.xfail(raises=AssertionError, strict=False), + RunIf(min_cuda_gpus=1), + ], + ), + ], +) +def test_against_original_opencoder(ours_kwargs, device, dtype): + torch.set_default_dtype(dtype) + + ours_config = Config.from_name( + padded_vocab_size=10000, n_query_groups=2, n_layer=2, n_head=8, n_embd=32, intermediate_size=86, **ours_kwargs + ) + T = 5 + theirs_config = LlamaConfig( + vocab_size=ours_config.padded_vocab_size, + hidden_size=ours_config.n_embd, + num_attention_heads=ours_config.n_head, + num_hidden_layers=ours_config.n_layer, + intermediate_size=ours_config.intermediate_size, + max_position_embeddings=T, + rms_norm_eps=ours_config.norm_eps, + num_key_value_heads=ours_config.n_query_groups, + rope_theta=ours_config.rope_base, + attention_bias=ours_config.bias, + ) + assert ours_config.intermediate_size == theirs_config.intermediate_size + + theirs_model = LlamaForCausalLM(theirs_config).to(device) + theirs_state_dict = theirs_model.state_dict() + state_dict = {} + copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) + ours_model = GPT(ours_config).to(device) + ours_model.load_state_dict(state_dict) + + # test end to end + x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) + assert x.size(1) == T + ours_y = ours_model(x) + theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float + torch.testing.assert_close(ours_y, theirs_y) + + @RunIf(dynamo=True) @torch.inference_mode() def test_model_compile(): diff --git a/tutorials/download_model_weights.md b/tutorials/download_model_weights.md index 876db1916a..ff01770e2b 100644 --- a/tutorials/download_model_weights.md +++ b/tutorials/download_model_weights.md @@ -31,6 +31,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights. | Nous-Hermes | 7B, 13B, 70B | NousResearch | [Org page](https://huggingface.co/NousResearch) | | OLMo | 1B, 7B | Allen Institute for AI (AI2) | [Groeneveld et al. 
2024](https://aclanthology.org/2024.acl-long.841/) |
 | OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) |
+| OpenCoder | 1.5B, 8B | Infly | [Huang et al. 2024](https://opencoder-llm.github.io/) |
 | Phi 1.5 & 2 | 1.3B, 2.7B | Microsoft Research | [Li et al. 2023](https://arxiv.org/abs/2309.05463) |
 | Phi 3 & 3.5 | 3.8B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2404.14219) |
 | Platypus | 7B, 13B, 70B | Lee et al. | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) |
@@ -129,6 +130,10 @@ HuggingFaceTB/SmolLM2-360M
 HuggingFaceTB/SmolLM2-360M-Instruct
 HuggingFaceTB/SmolLM2-1.7B
 HuggingFaceTB/SmolLM2-1.7B-Instruct
+infly/OpenCoder-1.5B-Base
+infly/OpenCoder-1.5B-Instruct
+infly/OpenCoder-8B-Base
+infly/OpenCoder-8B-Instruct
 lmsys/longchat-13b-16k
 lmsys/longchat-7b-16k
 lmsys/vicuna-13b-v1.3
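
A quick way to exercise the new pieces locally is sketched below. It is not part of the diff itself; it assumes a litgpt checkout with this patch installed and only touches APIs the patch adds or extends: `Config.from_name` lookups for the four generated OpenCoder names, and the `opencoder` prompt-style mapping in `litgpt/prompts.py`.

```python
# Sanity-check sketch for this patch (assumes litgpt is installed from a checkout
# that includes these changes; adjust the imports if your layout differs).
from litgpt.config import Config
from litgpt.prompts import model_name_to_prompt_style

# Every generated OpenCoder variant should resolve to a registered config.
for name in (
    "OpenCoder-1.5B-Base",
    "OpenCoder-1.5B-Instruct",
    "OpenCoder-8B-Base",
    "OpenCoder-8B-Instruct",
):
    cfg = Config.from_name(name)
    print(name, cfg.n_layer, cfg.n_head, cfg.n_embd, cfg.padded_vocab_size)

# Instruct checkpoints should map to the new ChatML-style OpenCoder template.
style = model_name_to_prompt_style("OpenCoder-8B-Instruct")
print(style.apply("Write a function that merges two sorted lists."))
```

The loop at the end of the `litgpt/config.py` hunk expands each `{}` placeholder into `-Base` and `-Instruct` variants, which is why all four names above are expected to resolve.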