Commit: OpenCoder series

ysjprojects committed Dec 21, 2024
1 parent 7b26d35 commit eb1a9a5

Showing 5 changed files with 127 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
@@ -132,6 +132,7 @@ Every model is written from scratch to maximize performance and remove layers of
| Mixtral MoE | 8x22B | Mistral AI | [Mistral AI 2024](https://mistral.ai/news/mixtral-8x22b/) |
| OLMo | 1B, 7B | Allen Institute for AI (AI2) | [Groeneveld et al. 2024](https://aclanthology.org/2024.acl-long.841/) |
| OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) |
| OpenCoder | 1.5B, 8B | Infly | [Huang et al. 2024](https://opencoder-llm.github.io/) |
| Phi 1.5 & 2 | 1.3B, 2.7B | Microsoft Research | [Li et al. 2023](https://arxiv.org/abs/2309.05463) |
| Phi 3 | 3.8B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2404.14219) |
| Platypus | 7B, 13B, 70B | Lee et al. | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) |
51 changes: 51 additions & 0 deletions litgpt/config.py
@@ -1665,6 +1665,57 @@ def norm_class(self) -> Type:
)


############
# OpenCoder
############
opencoder = [
    # https://huggingface.co/infly/OpenCoder-8B-Base/blob/main/config.json
    dict(
        name="OpenCoder-8B{}",
        hf_config=dict(org="infly", name="OpenCoder-8B{}"),
        n_embd=4096,
        block_size=8192,
        vocab_size=96640,
        padded_vocab_size=96640,
        n_layer=32,
        n_head=32,
        n_query_groups=8,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        intermediate_size=14336,
        rope_base=500000,
    ),
    # https://huggingface.co/infly/OpenCoder-1.5B-Base/blob/main/config.json
    dict(
        name="OpenCoder-1.5B{}",
        hf_config=dict(org="infly", name="OpenCoder-1.5B{}"),
        n_embd=2240,
        block_size=4096,
        vocab_size=96640,
        padded_vocab_size=96640,
        n_layer=24,
        n_head=14,
        n_query_groups=14,
        rotary_percentage=1.0,
        parallel_residual=False,
        bias=False,
        norm_class_name="RMSNorm",
        mlp_class_name="LLaMAMLP",
        intermediate_size=6144,
        rope_base=10000,
    ),
]
for c in opencoder:
    for kind in ("-Base", "-Instruct"):
        copy = deepcopy(c)
        copy["name"] = c["name"].format(kind)
        copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
        configs.append(copy)


############
# TinyLlama
############
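Once these entries are registered, the loop above generates four names (OpenCoder-1.5B-Base/-Instruct and OpenCoder-8B-Base/-Instruct) that resolve through the existing config lookup. A minimal usage sketch, assuming only the `Config.from_name` API already exercised by the tests in this commit; it is illustrative and not part of the diff:

from litgpt.config import Config

# Resolve one of the newly registered OpenCoder configs and spot-check
# a few of the values set above (sketch, not part of this commit).
config = Config.from_name("OpenCoder-8B-Instruct")
assert config.hf_config["org"] == "infly"
assert config.n_layer == 32 and config.n_query_groups == 8
assert config.padded_vocab_size == 96640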
11 changes: 10 additions & 1 deletion litgpt/prompts.py
@@ -277,7 +277,13 @@ def apply(self, prompt: str, **kwargs: str) -> str:
class OLMo(PromptStyle):
    def apply(self, prompt: str, **kwargs: str) -> str:
        return f"<|endoftext|><|user|>\n{prompt}\n<|assistant|>\n"


class OpenCoder(PromptStyle):
    def apply(self, prompt: str, **kwargs: str) -> str:
        system_message = "You are OpenCoder, created by OpenCoder Team."
        return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"


class Qwen2_5(PromptStyle):
    def apply(self, prompt: str, **kwargs: str) -> str:
@@ -329,6 +335,7 @@ def apply(self, prompt: str, **kwargs: str) -> str:
"gemma": Gemma,
"llama3": Llama3,
"olmo": OLMo,
"opencoder": OpenCoder,
"qwen2.5": Qwen2_5,
"qwen2.5-math": Qwen2_5_Math,
"qwq": QwQ,
@@ -372,6 +379,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
        return Gemma()
    if re.search(r"OLMo.*-hf", model_name):
        return OLMo()
    if re.search(r"OpenCoder.*-Instruct", model_name):
        return OpenCoder()
    if re.search(r"Qwen2\.5-Math-.*", model_name):
        return Qwen2_5_Math()
    if re.search(r"Qwen2\.5-.*", model_name):
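For the instruct checkpoints, the new `OpenCoder.*-Instruct` regex above routes model names to the ChatML-style template defined earlier in this file. A small illustrative sketch (not part of the diff), using only the names introduced here:

from litgpt.prompts import OpenCoder, model_name_to_prompt_style

# Pick the prompt style from the model name, then render a user prompt
# (sketch, not part of this commit).
style = model_name_to_prompt_style("OpenCoder-8B-Instruct")
assert isinstance(style, OpenCoder)

print(style.apply("Write a binary search in Python."))
# Renders as:
# <|im_start|>system
# You are OpenCoder, created by OpenCoder Team.<|im_end|>
# <|im_start|>user
# Write a binary search in Python.<|im_end|>
# <|im_start|>assistant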
60 changes: 60 additions & 0 deletions tests/test_model.py
@@ -973,6 +973,66 @@ def test_against_original_smollm2(model_name, device, dtype):
    torch.testing.assert_close(ours_y, theirs_y)


@torch.inference_mode()
@pytest.mark.parametrize(
    "ours_kwargs",
    [
        {"name": "OpenCoder-1.5B-Base"},
        {"name": "OpenCoder-8B-Instruct"},
    ],
)
@pytest.mark.parametrize(
    ("device", "dtype"),
    [
        (torch.device("cpu"), torch.float32),
        pytest.param(
            torch.device("cuda"),
            torch.float16,
            marks=[
                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
                # is slightly different
                pytest.mark.xfail(raises=AssertionError, strict=False),
                RunIf(min_cuda_gpus=1),
            ],
        ),
    ],
)
def test_against_original_opencoder(ours_kwargs, device, dtype):
    torch.set_default_dtype(dtype)

    ours_config = Config.from_name(
        padded_vocab_size=10000, n_query_groups=2, n_layer=2, n_head=8, n_embd=32, intermediate_size=86, **ours_kwargs
    )
    T = 5
    theirs_config = LlamaConfig(
        vocab_size=ours_config.padded_vocab_size,
        hidden_size=ours_config.n_embd,
        num_attention_heads=ours_config.n_head,
        num_hidden_layers=ours_config.n_layer,
        intermediate_size=ours_config.intermediate_size,
        max_position_embeddings=T,
        rms_norm_eps=ours_config.norm_eps,
        num_key_value_heads=ours_config.n_query_groups,
        rope_theta=ours_config.rope_base,
        attention_bias=ours_config.bias,
    )
    assert ours_config.intermediate_size == theirs_config.intermediate_size

    theirs_model = LlamaForCausalLM(theirs_config).to(device)
    theirs_state_dict = theirs_model.state_dict()
    state_dict = {}
    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
    ours_model = GPT(ours_config).to(device)
    ours_model.load_state_dict(state_dict)

    # test end to end
    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
    assert x.size(1) == T
    ours_y = ours_model(x)
    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
    torch.testing.assert_close(ours_y, theirs_y)


@RunIf(dynamo=True)
@torch.inference_mode()
def test_model_compile():
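Usage note (not part of the diff): the new parametrized cases can be run on their own with standard pytest selection, e.g. `pytest tests/test_model.py -k opencoder`.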
5 changes: 5 additions & 0 deletions tutorials/download_model_weights.md
@@ -31,6 +31,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights.
| Nous-Hermes | 7B, 13B, 70B | NousResearch | [Org page](https://huggingface.co/NousResearch) |
| OLMo | 1B, 7B | Allen Institute for AI (AI2) | [Groeneveld et al. 2024](https://aclanthology.org/2024.acl-long.841/) |
| OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) |
| OpenCoder | 1.5B, 8B | Infly | [Huang et al. 2024](https://opencoder-llm.github.io/) |
| Phi 1.5 & 2 | 1.3B, 2.7B | Microsoft Research | [Li et al. 2023](https://arxiv.org/abs/2309.05463) |
| Phi 3 & 3.5 | 3.8B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2404.14219)
| Platypus | 7B, 13B, 70B | Lee et al. | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) |
@@ -129,6 +130,10 @@ HuggingFaceTB/SmolLM2-360M
HuggingFaceTB/SmolLM2-360M-Instruct
HuggingFaceTB/SmolLM2-1.7B
HuggingFaceTB/SmolLM2-1.7B-Instruct
infly/OpenCoder-1.5B-Base
infly/OpenCoder-1.5B-Instruct
infly/OpenCoder-8B-Base
infly/OpenCoder-8B-Instruct
lmsys/longchat-13b-16k
lmsys/longchat-7b-16k
lmsys/vicuna-13b-v1.3
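With these IDs listed, the checkpoints fit the download-and-generate workflow this tutorial already describes. A hedged sketch, assuming LitGPT's high-level `LLM` Python API (not touched by this commit) accepts the new repository IDs:

from litgpt import LLM

# Sketch only: assumes the standard LitGPT Python API; the repository ID
# comes from the list added to this tutorial. Weights are fetched on first use.
llm = LLM.load("infly/OpenCoder-1.5B-Instruct")
print(llm.generate("Write a function that checks whether a string is a palindrome."))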
