Falcon3: added prompts, test case, README
ysjprojects committed Dec 21, 2024
1 parent 2c0d963 commit e038c28
Showing 4 changed files with 81 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
@@ -117,6 +117,7 @@ Every model is written from scratch to maximize performance and remove layers of
| CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) |
| Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) |
| Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae) |
| Falcon 3 | 1B, 3B, 7B, 10B | TII UAE | [TII 2024](https://falconllm.tii.ae/falcon3/index.html) |
| FreeWilly2 (Stable Beluga 2) | 70B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) |
| Function Calling Llama 2 | 7B | Trelis | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) |
| Gemma | 2B, 7B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) |
13 changes: 12 additions & 1 deletion litgpt/prompts.py
@@ -112,6 +112,17 @@ def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]:
)


class Falcon3(PromptStyle):
    def apply(self, prompt: str, **kwargs: str) -> str:
        return f"<|user|>\n{prompt}<|endoftext|>\n<|assistant|>\n"

    def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]:
        return (
            [tokenizer.eos_id],
            [tokenizer.token_to_id("<|endoftext|>")],
        )


class Llama2FunctionCalling(PromptStyle):
    def apply(self, prompt: str, **kwargs: str) -> str:
        # Has to be before the llama config
@@ -345,7 +356,7 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
if re.search("stablecode-instruct", model_name):
return StableCode()
if re.search(r"Falcon3.*-Instruct", model_name):
pass
return Falcon3()
if re.search(r"falcon.*-instruct", model_name):
return Falcon()
if re.search("Llama-2-7b-chat-hf-function-calling-v2", model_name):
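For reference, a minimal sketch of how the new style behaves, assuming `Falcon3` and `model_name_to_prompt_style` are imported from `litgpt.prompts` as defined above:

```python
from litgpt.prompts import Falcon3, model_name_to_prompt_style

# Instruct checkpoints dispatch to the new prompt style by name.
style = model_name_to_prompt_style("Falcon3-7B-Instruct")
assert isinstance(style, Falcon3)

# apply() wraps a user turn in Falcon 3's chat markers:
print(style.apply("What is the capital of France?"))
# <|user|>
# What is the capital of France?<|endoftext|>
# <|assistant|>
```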
59 changes: 59 additions & 0 deletions tests/test_model.py
@@ -972,6 +972,65 @@ def test_against_original_smollm2(model_name, device, dtype):
    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
    torch.testing.assert_close(ours_y, theirs_y)

@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("Falcon3-1B-Base", "Falcon3-7B-Base"))
@pytest.mark.parametrize(
    ("device", "dtype"),
    [
        (torch.device("cpu"), torch.float32),
        pytest.param(
            torch.device("cuda"),
            torch.float16,
            marks=[
                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
                # is slightly different
                pytest.mark.xfail(raises=AssertionError, strict=False),
                RunIf(min_cuda_gpus=1),
            ],
        ),
    ],
)
def test_against_hf_falcon3(model_name, device, dtype):
    torch.set_default_dtype(dtype)

    ours_config = Config.from_name(
        model_name,
        padded_vocab_size=10000,
        n_layer=2,
        n_head=8,
        n_embd=32,
        n_query_groups=2,
        intermediate_size=86,
    )
    T = 5
    theirs_config = LlamaConfig(
        vocab_size=ours_config.padded_vocab_size,
        hidden_size=ours_config.n_embd,
        num_attention_heads=ours_config.n_head,
        num_hidden_layers=ours_config.n_layer,
        intermediate_size=ours_config.intermediate_size,
        max_position_embeddings=T,
        rms_norm_eps=ours_config.norm_eps,
        num_key_value_heads=ours_config.n_query_groups,
        rope_theta=ours_config.rope_base,
        attention_bias=ours_config.bias,
    )
    assert ours_config.intermediate_size == theirs_config.intermediate_size

    theirs_model = LlamaForCausalLM(theirs_config).to(device)
    theirs_state_dict = theirs_model.state_dict()
    state_dict = {}
    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
    ours_model = GPT(ours_config).to(device)
    ours_model.load_state_dict(state_dict)

    # test end to end
    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
    assert x.size(1) == T
    ours_y = ours_model(x)
    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
    torch.testing.assert_close(ours_y, theirs_y)


@RunIf(dynamo=True)
@torch.inference_mode()
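The test follows the repository's usual parity pattern: a scaled-down LitGPT config is mirrored into a Hugging Face `LlamaConfig` (Falcon 3 reuses the Llama architecture, which is why `copy_weights_hf_llama` applies), and the two models must produce matching logits. A sketch of running only this test through pytest's Python entry point; the `-k` selector is an assumption based on the test name above:

```python
import pytest

# Select the new parity test by name; -q keeps the output terse.
exit_code = pytest.main(["tests/test_model.py", "-k", "test_against_hf_falcon3", "-q"])
```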
9 changes: 9 additions & 0 deletions tutorials/download_model_weights.md
@@ -12,6 +12,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights.
| Danube2 | 1.8B | H2O.ai | [H2O.ai](https://h2o.ai/platform/danube-1-8b/) |
| Dolly | 3B, 7B, 12B | Databricks | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) |
| Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae) |
| Falcon 3 | 1B, 3B, 7B, 10B | TII UAE | [TII 2024](https://falconllm.tii.ae/falcon3/index.html) |
| FreeWilly2 (Stable Beluga 2) | 70B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) |
| Function Calling Llama 2 | 7B | Trelis | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) |
| Gemma | 2B, 7B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) |
@@ -232,6 +233,14 @@ tiiuae/falcon-40b
tiiuae/falcon-40b-instruct
tiiuae/falcon-7b
tiiuae/falcon-7b-instruct
tiiuae/Falcon3-1B-Base
tiiuae/Falcon3-1B-Instruct
tiiuae/Falcon3-3B-Base
tiiuae/Falcon3-3B-Instruct
tiiuae/Falcon3-7B-Base
tiiuae/Falcon3-7B-Instruct
tiiuae/Falcon3-10B-Base
tiiuae/Falcon3-10B-Instruct
TinyLlama/TinyLlama-1.1B-Chat-v1.0
TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
togethercomputer/LLaMA-2-7B-32K
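Any of the newly listed repo IDs works with LitGPT's usual entry points. A hedged sketch using the Python API, assuming `litgpt` is installed and that `LLM.load` fetches the checkpoint from Hugging Face when it is not already cached (the CLI route would be `litgpt download tiiuae/Falcon3-1B-Instruct`, as elsewhere in this tutorial):

```python
from litgpt import LLM

# Load one of the new checkpoints; the Instruct variants pick up the
# Falcon3 prompt style registered in litgpt/prompts.py.
llm = LLM.load("tiiuae/Falcon3-1B-Instruct")
print(llm.generate("What is the capital of France?", max_new_tokens=32))
```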
