diff --git a/litgpt/prompts.py b/litgpt/prompts.py index ab9cfee594..09b3277c7d 100644 --- a/litgpt/prompts.py +++ b/litgpt/prompts.py @@ -332,7 +332,7 @@ def apply(self, prompt: str, **kwargs: str) -> str: "qwen2.5": Qwen2_5, "qwen2.5-math": Qwen2_5_Math, "qwq": QwQ, - "smollm2": SmolLM2, # SmolLM uses a different template + "smollm2": SmolLM2, "salamandra": Salamandra, } diff --git a/tests/test_model.py b/tests/test_model.py index e3aad3bb0e..89e926d173 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -914,7 +914,7 @@ def test_against_original_salamandra(model_name, device, dtype): @torch.inference_mode() -@pytest.mark.parametrize("model_name", ("SmolLM2-135M", "SmolLM2-1.7B")) +@pytest.mark.parametrize("model_name", ("SmolLM2-135M", "SmolLM2-360M", "SmolLM2-1.7B")) @pytest.mark.parametrize( ("device", "dtype"), [