Merge branch 'main' into mixtral8x22b
Andrei-Aksionov authored Dec 2, 2024
2 parents 191aa34 + 7449dad commit b044fad
Showing 9 changed files with 43 additions and 7 deletions.
README.md: 1 addition & 0 deletions
@@ -137,6 +137,7 @@ Every model is written from scratch to maximize performance and remove layers of
 | Pythia | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | EleutherAI | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) |
 | Qwen2.5 | 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwen2.5/) |
 | Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186) |
+| QwQ | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) |
 | StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
 | StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) |
 | StableLM Zephyr | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
litgpt/config.py: 26 additions & 0 deletions
@@ -2017,4 +2017,30 @@ def norm_class(self) -> Type:
copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
configs.append(copy)

qwq = [
# https://huggingface.co/Qwen/QwQ-32B-Preview/blob/main/config.json
dict(
name="QwQ-32B-Preview",
hf_config=dict(org="Qwen", name="QwQ-32B-Preview"),
block_size=131072,
vocab_size=151643,
padded_vocab_size=152064,
n_layer=64,
n_head=40,
n_embd=5120,
n_query_groups=8,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
attn_bias=True,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=27648,
norm_eps=1e-5,
rope_base=1000000
),
]

configs.extend(qwq)

name_to_config = {config["name"]: config for config in configs}
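
For orientation, here is a minimal sketch of how the new entry resolves once registered. It assumes a local litgpt install; `Config.from_name` is the lookup helper used elsewhere in the codebase, and the printed values simply echo the dict above:

```python
# Minimal sketch (assumes litgpt is importable): resolve the new config by name.
from litgpt.config import Config

config = Config.from_name("QwQ-32B-Preview")
print(config.n_layer, config.n_head, config.n_query_groups)  # 64 40 8
print(config.mlp_class_name, config.intermediate_size)       # LLaMAMLP 27648
```

With `n_head=40` against `n_query_groups=8`, attention is grouped-query: five query heads share each key/value head, as in the upstream Qwen2-style checkpoint.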
litgpt/prompts.py: 9 additions & 2 deletions
@@ -274,8 +274,6 @@ def apply(self, prompt: str, **kwargs: str) -> str:
         return f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
 
 
-
-
 class OLMo(PromptStyle):
     def apply(self, prompt: str, **kwargs: str) -> str:
         return f"<|endoftext|><|user|>\n{prompt}\n<|assistant|>\n"
@@ -287,6 +285,12 @@ def apply(self, prompt: str, **kwargs: str) -> str:
return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"


class QwQ(PromptStyle):
def apply(self, prompt: str, **kwargs: str) -> str:
system_message = "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step."
return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"


# Maps prompt style names to PromptStyle classes
prompt_styles: Dict[str, Type[PromptStyle]] = {
# Dataset-specific prompt styles
@@ -311,6 +315,7 @@ def apply(self, prompt: str, **kwargs: str) -> str:
"llama3": Llama3,
"olmo": OLMo,
"qwen2.5": Qwen2_5,
"qwq": QwQ,
}


@@ -351,6 +356,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
         return OLMo()
     if re.search(r"Qwen2\.5-.*", model_name):
         return Qwen2_5()
+    if re.search(r"QwQ-.*", model_name):
+        return QwQ()
     return Default()


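
Since both the new prompt class and the name-matching rule land in this file, a short sketch of the behavior the diff produces (everything below follows directly from the lines above):

```python
# Sketch: the new style is resolved by name and renders the QwQ system prompt.
from litgpt.prompts import model_name_to_prompt_style

style = model_name_to_prompt_style("QwQ-32B-Preview")  # matched by r"QwQ-.*"
print(style.apply("What is 2 + 2?"))
# <|im_start|>system
# You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.<|im_end|>
# <|im_start|>user
# What is 2 + 2?<|im_end|>
# <|im_start|>assistant
```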
litgpt/scripts/convert_hf_checkpoint.py: 1 addition & 1 deletion
@@ -562,7 +562,7 @@ def convert_hf_checkpoint(
         # holder to reconstitute the split q, k, v
         qkv_weights = {}
         copy_fn = partial(copy_weights_phi, config, qkv_weights)
-    elif model_name.lower().startswith("qwen2.5"):
+    elif model_name.lower().startswith(("qwen2.5", "qwq")):
         # holder to reconstitute the split q, k, v
         qkv_weights = {}
         copy_fn = partial(copy_weights_qwen_2_5, config, qkv_weights)
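
The one-line change works because `str.startswith` accepts a tuple of prefixes, so a single branch now routes both model families through `copy_weights_qwen_2_5`:

```python
# str.startswith with a tuple matches any listed prefix.
for name in ("Qwen2.5-Coder-1.5B", "QwQ-32B-Preview", "phi-2"):
    print(name, name.lower().startswith(("qwen2.5", "qwq")))
# Qwen2.5-Coder-1.5B True
# QwQ-32B-Preview True
# phi-2 False
```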
litgpt/scripts/convert_lit_checkpoint.py: 1 addition & 1 deletion
@@ -387,7 +387,7 @@ def convert_lit_checkpoint(checkpoint_dir: Path, output_dir: Path) -> None:
         copy_fn = partial(copy_weights_gemma_2, config)
     elif config.name.lower().startswith("phi"):
         copy_fn = partial(copy_weights_phi, config)
-    elif config.name.lower().startswith("qwen2.5"):
+    elif config.name.lower().startswith(("qwen2.5", "qwq")):
         copy_fn = partial(copy_weights_qwen_2_5, config)
     elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"):
         untie_weights = "Gemma" in config.name
tests/test_convert_lit_checkpoint.py: 1 addition & 1 deletion
@@ -524,7 +524,7 @@ def test_check_conversion_supported_lora():
     check_conversion_supported(lit_weights=lit_weights)
 
 @torch.inference_mode()
-@pytest.mark.parametrize("model_name", ("Qwen2.5-1.5B", "Qwen2.5-Coder-1.5B"))
+@pytest.mark.parametrize("model_name", ("Qwen2.5-1.5B", "Qwen2.5-Coder-1.5B", "QwQ-32B-Preview"))
 @pytest.mark.parametrize(
     ("device", "dtype"),
     [
tests/test_model.py: 1 addition & 1 deletion
@@ -791,7 +791,7 @@ def test_against_original_gemma_2(model_name, device, dtype):


 @torch.inference_mode()
-@pytest.mark.parametrize("model_name", ("Qwen2.5-1.5B", "Qwen2.5-Coder-1.5B"))
+@pytest.mark.parametrize("model_name", ("Qwen2.5-1.5B", "Qwen2.5-Coder-1.5B", "QwQ-32B-Preview"))
 @pytest.mark.parametrize(
     ("device", "dtype"),
     [
tests/test_tokenizer.py: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ def test_tokenizer_against_hf(config):
     else:
         assert ours.vocab_size == config.vocab_size
 
-    if config.name.startswith(("falcon", "stablecode", "Qwen2.5")):
+    if config.name.startswith(("falcon", "stablecode", "Qwen2.5", "QwQ")):
         # even though their config defines it, it's set as None in HF
         assert isinstance(ours.bos_id, int)
         assert theirs.bos_token_id is None
tutorials/download_model_weights.md: 2 additions & 0 deletions
@@ -36,6 +36,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights.
 | Pythia | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | EleutherAI | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) |
 | Qwen2.5 | 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwen2.5/) |
 | Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186) |
+| QwQ | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) |
 | RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1) |
 | StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
 | StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) |
@@ -192,6 +193,7 @@ Qwen/Qwen2.5-Coder-14B
 Qwen/Qwen2.5-Coder-14B-Instruct
 Qwen/Qwen2.5-Coder-32B
 Qwen/Qwen2.5-Coder-32B-Instruct
+Qwen/QwQ-32B-Preview
 stabilityai/FreeWilly2
 stabilityai/stable-code-3b
 stabilityai/stablecode-completion-alpha-3b
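
With the repo ID on the supported list, the checkpoint can be exercised end to end. A minimal sketch, assuming litgpt's high-level Python API (`LLM.load` fetches the Hugging Face checkpoint on first use and applies the registered prompt style); the 32B model needs substantial GPU memory, so this is illustrative rather than a sizing guide:

```python
# Minimal sketch, assuming litgpt's high-level API (from litgpt import LLM).
from litgpt import LLM

llm = LLM.load("Qwen/QwQ-32B-Preview")  # downloads weights on first call
print(llm.generate("How many r's are in 'strawberry'?", max_new_tokens=128))
```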
