diff --git a/litgpt/prompts.py b/litgpt/prompts.py index 8ab5f548fb..96a99073b6 100644 --- a/litgpt/prompts.py +++ b/litgpt/prompts.py @@ -364,7 +364,7 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle: return Qwen2_5() if re.search(r"QwQ-.*", model_name): return QwQ() - if re.search(r"salamandra-.*", model_name): + if re.search(r"salamandra-.*-instruct", model_name): return Salamandra() return Default()