From 76137dabc94fc86d7a2d32ed340408dab5a587b0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Fri, 19 Apr 2024 06:21:12 -0400
Subject: [PATCH] add prompt

---
 litgpt/prompts.py | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/litgpt/prompts.py b/litgpt/prompts.py
index d827413913..bbc60e318d 100644
--- a/litgpt/prompts.py
+++ b/litgpt/prompts.py
@@ -200,6 +200,27 @@ def apply(self, prompt: str, **kwargs: str) -> str:
         )
 
 
+class Llama3(PromptStyle):
+    def apply(self, prompt: str, **kwargs: str) -> str:
+        # Official Meta Llama 3 chat template: each header is followed by "\n\n",
+        # and there is NO newline between <|eot_id|> and the next <|start_header_id|>.
+        return (
+            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
+            "You are a helpful assistant.<|eot_id|>"
+            "<|start_header_id|>user<|end_header_id|>\n\n"
+            f"{prompt}<|eot_id|>"
+            "<|start_header_id|>assistant<|end_header_id|>\n\n"
+        )
+
+    def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]:
+        # Llama 3 Instruct ends assistant turns with <|eot_id|> rather than the
+        # plain EOS token, so generation must stop on either sequence.
+        return (
+            [tokenizer.eos_id],
+            [tokenizer.token_to_id("<|eot_id|>")],
+        )
+
+
 class FreeWilly2(PromptStyle):
     def apply(self, prompt: str, **kwargs: str) -> str:
         return (
@@ -316,6 +337,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
         return Llama2FunctionCalling()
     if re.search("Llama-2.*-chat", model_name):
         return Llama2()
+    if re.search("Llama-3.*-Instruct", model_name):
+        return Llama3()
     if re.search("FreeWilly2", model_name):
         return FreeWilly2()
     if re.search("Platypus", model_name):