From ad3816ed64d1a7fce5dafa935c9bbf6fb184ba4b Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Thu, 5 Dec 2024 23:34:48 -0500 Subject: [PATCH 1/3] add Salamandra --- README.md | 1 + litgpt/config.py | 55 ++++++++++++++++++++++++++ litgpt/prompts.py | 8 ++++ tests/test_model.py | 60 +++++++++++++++++++++++++++++ tutorials/download_model_weights.md | 5 +++ 5 files changed, 129 insertions(+) diff --git a/README.md b/README.md index 3856a332ea..ad4468217c 100644 --- a/README.md +++ b/README.md @@ -138,6 +138,7 @@ Every model is written from scratch to maximize performance and remove layers of | Qwen2.5 | 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwen2.5/) | | Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186) | | QwQ | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) | +| Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra) | | StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | | StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) | | StableLM Zephyr | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) | diff --git a/litgpt/config.py b/litgpt/config.py index 684f3f78be..582a1721a1 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -2043,4 +2043,59 @@ def norm_class(self) -> Type: configs.extend(qwq) +############# +# Salamandra +############# + +salamandra = [ + # https://huggingface.co/BSC-LT/salamandra-2b-instruct/blob/main/config.json + dict( + name="salamandra-2b{}", + hf_config=dict(org="BSC-LT", name="salamandra-2b{}"), + block_size=8192, + vocab_size=256000, + padded_vocab_size=256000, + n_layer=24, + n_head=16, + n_embd=2048, + n_query_groups=16, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=5440, + norm_eps=1e-5, + rope_base=10000 + ), + # https://huggingface.co/BSC-LT/salamandra-7b-instruct/blob/main/config.json + dict( + name="salamandra-7b{}", + hf_config=dict(org="BSC-LT", name="salamandra-7b{}"), + block_size=8192, + vocab_size=256000, + padded_vocab_size=256000, + n_layer=32, + n_head=32, + n_embd=4096, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", + intermediate_size=11008, + norm_eps=1e-6, + rope_base=10000 + ), +] + +for c in salamandra: + for kind in ("", "-instruct"): + copy = deepcopy(c) + copy["name"] = c["name"].format(kind) + copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) + configs.append(copy) + + name_to_config = {config["name"]: config for config in configs} diff --git a/litgpt/prompts.py b/litgpt/prompts.py index 5f5fd14494..8ab5f548fb 100644 --- a/litgpt/prompts.py +++ b/litgpt/prompts.py @@ -290,6 +290,11 @@ def apply(self, prompt: str, **kwargs: str) -> str: system_message = "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step." 
return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" +class Salamandra(PromptStyle): + def apply(self, prompt: str, **kwargs: str) -> str: + system_message = "I am Salamandra, an AI language model developed at the Barcelona Supercomputing Centre (BSC) by the Language Technologies Unit. My knowledge base was last updated on August 2023. Today Date: 2024-09-30\nSoy Salamandra, un modelo lingüístico de IA desarrollado en el Barcelona Supercomputing Centre (BSC) por la Language Technologies Unit. Mi base de conocimientos se actualizó por última vez en agosto de 2023.\nSoc Salamandra, un model de llenguatge d'IA desenvolupat al Barcelona Supercomputing Centre (BSC) per la Language Technologies Unit. La meva base de coneixement es va actualitzar per última vegada l'agost de 2023." + return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" + # Maps prompt style names to PromptStyle classes prompt_styles: Dict[str, Type[PromptStyle]] = { @@ -316,6 +321,7 @@ def apply(self, prompt: str, **kwargs: str) -> str: "olmo": OLMo, "qwen2.5": Qwen2_5, "qwq": QwQ, + "salamandra": Salamandra, } @@ -358,6 +364,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle: return Qwen2_5() if re.search(r"QwQ-.*", model_name): return QwQ() + if re.search(r"salamandra-.*", model_name): + return Salamandra() return Default() diff --git a/tests/test_model.py b/tests/test_model.py index 3ca5e80599..48b945a9a5 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -851,6 +851,66 @@ def test_against_original_qwen_2_5(model_name, device, dtype): theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) +@torch.inference_mode() +@pytest.mark.parametrize("model_name", ("salamandra-2b", "salamandra-7b")) +@pytest.mark.parametrize( + ("device", "dtype"), + [ + (torch.device("cpu"), torch.float32), + pytest.param( + torch.device("cuda"), + torch.float16, + marks=[ + # the reference does softmax upscaled to fp32 during attention. 
+                # additionally, the final layernorm input is slightly different
+                pytest.mark.xfail(raises=AssertionError, strict=False),
+                RunIf(min_cuda_gpus=1),
+            ],
+        ),
+    ],
+)
+def test_against_original_salamandra(model_name, device, dtype):
+    torch.set_default_dtype(dtype)
+
+    ours_config = Config.from_name(
+        model_name,
+        padded_vocab_size=10000,
+        n_layer=2,
+        n_head=8,
+        n_embd=32,
+        n_query_groups=2,
+        intermediate_size=86,
+    )
+    T = 5
+    theirs_config = LlamaConfig(
+        vocab_size=ours_config.padded_vocab_size,
+        hidden_size=ours_config.n_embd,
+        num_attention_heads=ours_config.n_head,
+        num_hidden_layers=ours_config.n_layer,
+        intermediate_size=ours_config.intermediate_size,
+        max_position_embeddings=T,
+        rms_norm_eps=ours_config.norm_eps,
+        num_key_value_heads=ours_config.n_query_groups,
+        rope_theta=ours_config.rope_base,
+        attention_bias=ours_config.bias,
+    )
+    assert ours_config.intermediate_size == theirs_config.intermediate_size
+
+    theirs_model = LlamaForCausalLM(theirs_config).to(device)
+    theirs_state_dict = theirs_model.state_dict()
+    state_dict = {}
+    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
+    ours_model = GPT(ours_config).to(device)
+    ours_model.load_state_dict(state_dict)
+
+    # test end to end
+    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
+    assert x.size(1) == T
+    ours_y = ours_model(x)
+    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
+    torch.testing.assert_close(ours_y, theirs_y)
+
+
 @RunIf(dynamo=True)
 @torch.inference_mode()
 def test_model_compile():
diff --git a/tutorials/download_model_weights.md b/tutorials/download_model_weights.md
index 509218ac96..8abb2d544d 100644
--- a/tutorials/download_model_weights.md
+++ b/tutorials/download_model_weights.md
@@ -39,6 +39,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights.
 | QwQ | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) |
 | RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1) |
+| Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra) |
 | StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
 | StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) |
 | StableLM Zephyr | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
 | TinyLlama | 1.1B | Zhang et al. | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama) |
@@ -62,6 +63,10 @@ The output is shown below:
 allenai/OLMo-1B-hf
 allenai/OLMo-7B-hf
 allenai/OLMo-7B-Instruct-hf
+bsc-lt/salamandra-2b
+bsc-lt/salamandra-2b-instruct
+bsc-lt/salamandra-7b
+bsc-lt/salamandra-7b-instruct
 codellama/CodeLlama-13b-hf
 codellama/CodeLlama-13b-Instruct-hf
 codellama/CodeLlama-13b-Python-hf

From 7cfc08677bf2067aff2bde7a8ad820e5aab172a8 Mon Sep 17 00:00:00 2001
From: Yu Shi Jie
Date: Fri, 6 Dec 2024 12:59:31 -0500
Subject: [PATCH 2/3] Salamandra: fix tokenizer -> apply_decoding_fix

The dummy token is different for the Salamandra tokenizers.
---
 litgpt/tokenizer.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/litgpt/tokenizer.py b/litgpt/tokenizer.py
index a81c59aa2d..6018a44734 100644
--- a/litgpt/tokenizer.py
+++ b/litgpt/tokenizer.py
@@ -143,6 +143,9 @@ def decode(self, tensor: torch.Tensor) -> str:
         if len(tokens) == 1 and self.apply_decoding_fix:
             dummy_token_id = 33  # \x1e
             dummy_token = self.processor.decode([dummy_token_id])
+            if dummy_token != "\x1e":
+                dummy_token_id = 165  # \x1e has a different id in the Salamandra tokenizers
+                dummy_token = self.processor.decode([dummy_token_id])
             return self.processor.decode([dummy_token_id] + tokens)[len(dummy_token) :]
         return self.processor.decode(tokens)

From 2f35d56e40500e9bd51b545853b78db8a07fd1ed Mon Sep 17 00:00:00 2001
From: Yu Shi Jie
Date: Fri, 6 Dec 2024 14:31:25 -0500
Subject: [PATCH 3/3] Salamandra: fix chat template

The chat template only applies to the instruct variants.
---
 litgpt/prompts.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/litgpt/prompts.py b/litgpt/prompts.py
index 8ab5f548fb..96a99073b6 100644
--- a/litgpt/prompts.py
+++ b/litgpt/prompts.py
@@ -364,7 +364,7 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
         return Qwen2_5()
     if re.search(r"QwQ-.*", model_name):
         return QwQ()
-    if re.search(r"salamandra-.*", model_name):
+    if re.search(r"salamandra-.*-instruct", model_name):
         return Salamandra()
     return Default()
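
Reviewer note (not part of the patch series): a minimal end-to-end sketch of how the pieces fit together once the three patches are applied. `Config.from_name`, `model_name_to_prompt_style`, `Default`, and `Salamandra` are the entry points touched by the diffs above; the asserted values come from the salamandra-7b config added in PATCH 1/3.

    from litgpt.config import Config
    from litgpt.prompts import Default, Salamandra, model_name_to_prompt_style

    # PATCH 1/3: the ("", "-instruct") loop registers four model variants.
    config = Config.from_name("salamandra-7b-instruct")
    assert config.n_layer == 32 and config.n_query_groups == 8

    # PATCH 3/3: only the -instruct names get the ChatML-style template;
    # the base models fall back to the Default prompt style.
    assert isinstance(model_name_to_prompt_style("salamandra-7b-instruct"), Salamandra)
    assert isinstance(model_name_to_prompt_style("salamandra-7b"), Default)
    prompt = model_name_to_prompt_style("salamandra-7b-instruct").apply("Hola!")
    assert prompt.startswith("<|im_start|>system")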
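
And a self-contained illustration of why PATCH 2/3 probes the dummy token before using it. The `FakeProcessor` and `decode_with_fix` below are hypothetical: the class stands in for a SentencePiece-style processor that drops a token's leading space when it is decoded on its own and that, like the Salamandra tokenizers are assumed to above, does not store "\x1e" at id 33.

    class FakeProcessor:
        # toy vocabulary: id 165 holds "\x1e", as assumed for Salamandra above
        vocab = {33: "A", 165: "\x1e", 500: " world"}

        def decode(self, ids):
            text = "".join(self.vocab[i] for i in ids)
            # single-token decodes lose a leading space
            return text[1:] if len(ids) == 1 and text.startswith(" ") else text

    def decode_with_fix(processor, tokens):
        dummy_token_id = 33  # "\x1e" in Llama-style vocabularies
        dummy_token = processor.decode([dummy_token_id])
        if dummy_token != "\x1e":
            dummy_token_id = 165  # fall back to the Salamandra id
            dummy_token = processor.decode([dummy_token_id])
        # prepend the dummy token, then slice its text off again
        return processor.decode([dummy_token_id] + tokens)[len(dummy_token):]

    assert FakeProcessor().decode([500]) == "world"             # space lost
    assert decode_with_fix(FakeProcessor(), [500]) == " world"  # space kept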