Skip to content

Commit

Permalink
Added ChatML inheritance for better typing compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
ysjprojects committed Dec 22, 2024
1 parent b72cc20 commit 2dc0be8
Showing 1 changed file with 33 additions and 20 deletions.
53 changes: 33 additions & 20 deletions litgpt/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,18 +280,31 @@ def apply(self, prompt: str, **kwargs: str) -> str:


class ChatML(PromptStyle):
def __init__(self, model_name: str):
self.model_name = model_name
self.system_messages = {
"qwen2.5": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
"qwen2.5-math": "Please reason step by step, and put your final answer within \\boxed{}.",
"qwq": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.",
"smollm2": "You are a helpful AI assistant named SmolLM, trained by Hugging Face",
"salamandra": "I am Salamandra, an AI language model developed at the Barcelona Supercomputing Centre (BSC) by the Language Technologies Unit. My knowledge base was last updated on August 2023. Today Date: 2024-09-30\nSoy Salamandra, un modelo lingüístico de IA desarrollado en el Barcelona Supercomputing Centre (BSC) por la Language Technologies Unit. Mi base de conocimientos se actualizó por última vez en agosto de 2023.\nSoc Salamandra, un model de llenguatge d'IA desenvolupat al Barcelona Supercomputing Centre (BSC) per la Language Technologies Unit."
}
def __init__(self, system_message: str):
self.system_message = system_message

def apply(self, prompt: str, **kwargs: str) -> str:
return f"<|im_start|>system\n{self.system_messages[self.model_name]}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
return f"<|im_start|>system\n{self.system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"

class Qwen2_5(ChatML):
def __init__(self):
super().__init__("You are Qwen, created by Alibaba Cloud. You are a helpful assistant.")

class Qwen2_5_Math(ChatML):
def __init__(self):
super().__init__("Please reason step by step, and put your final answer within \\boxed{}.")

class QwQ(ChatML):
def __init__(self):
super().__init__("You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.")

class SmolLM2(ChatML):
def __init__(self):
super().__init__("You are a helpful AI assistant named SmolLM, trained by Hugging Face")

class Salamandra(ChatML):
def __init__(self):
super().__init__("I am Salamandra, an AI language model developed at the Barcelona Supercomputing Centre (BSC) by the Language Technologies Unit. My knowledge base was last updated on August 2023. Today Date: 2024-09-30\nSoy Salamandra, un modelo lingüístico de IA desarrollado en el Barcelona Supercomputing Centre (BSC) por la Language Technologies Unit. Mi base de conocimientos se actualizó por última vez en agosto de 2023.\nSoc Salamandra, un model de llenguatge d'IA desenvolupat al Barcelona Supercomputing Centre (BSC) per la Language Technologies Unit.")


# Maps prompt style names to PromptStyle classes
Expand All @@ -317,11 +330,11 @@ def apply(self, prompt: str, **kwargs: str) -> str:
"gemma": Gemma,
"llama3": Llama3,
"olmo": OLMo,
"qwen2.5": lambda: ChatML("qwen2.5"),
"qwen2.5-math": lambda: ChatML("qwen2.5-math"),
"qwq": lambda: ChatML("qwq"),
"smollm2": lambda: ChatML("smollm2"),
"salamandra": lambda: ChatML("salamandra"),
"qwen2.5": Qwen2_5,
"qwen2.5-math": Qwen2_5_Math,
"qwq": QwQ,
"smollm2": SmolLM2,
"salamandra": Salamandra,
}


Expand Down Expand Up @@ -361,15 +374,15 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
if re.search(r"OLMo.*-hf", model_name):
return OLMo()
if re.search(r"Qwen2\.5-Math-.*", model_name):
return ChatML("qwen2.5-math")
return Qwen2_5_Math()
if re.search(r"Qwen2\.5-.*", model_name):
return ChatML("qwen2.5")
return Qwen2_5()
if re.search(r"QwQ-.*", model_name):
return ChatML("qwq")
return QwQ()
if re.search(r"SmolLM2.*-Instruct", model_name):
return ChatML("smollm2")
return SmolLM2()
if re.search(r"salamandra-.*-instruct", model_name):
return ChatML("salamandra")
return Salamandra()
return Default()


Expand Down

0 comments on commit 2dc0be8

Please sign in to comment.