
Add H2O Danube2 Checkpoint #1282

Merged · 12 commits into Lightning-AI:main · May 3, 2024

Conversation

Dev-Khant
Contributor

@Dev-Khant Dev-Khant commented Apr 13, 2024

This PR adds the Danube2 checkpoint from H2O.ai: https://huggingface.co/h2oai/h2o-danube2-1.8b-chat.
Solves #1261

@Dev-Khant
Contributor Author

Hi @rasbt

When I tried inference here, I could not get the desired response. I derived the prompt style from this code:

import torch
from transformers import pipeline

pipe = pipeline(
    "text-generation",
    model="h2oai/h2o-danube2-1.8b-chat",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# We use the HF Tokenizer chat template to format each message
# https://huggingface.co/docs/transformers/main/en/chat_templating
messages = [
    {"role": "user", "content": "What is 2+2?"},
]
prompt = pipe.tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
res = pipe(
    prompt,
    max_new_tokens=256,
)
print(res[0]["generated_text"])

The above code produces the following response from the model:
<|prompt|>What is 2+2?</s> <|answer|> Two (2) plus two (2) equals four (4).</s>
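
For reference, the prompt style I'm adding to litgpt/prompts.py roughly follows that template (a sketch; the exact class name and how it is registered in this PR may differ):

# Sketch of a LitGPT prompt style matching the HF chat template shown above.
# The class name is illustrative; the actual class added in this PR may differ.
from litgpt.prompts import PromptStyle


class H2ODanube2(PromptStyle):
    def apply(self, prompt: str, **kwargs: str) -> str:
        # Mirrors the template: <|prompt|>...</s><|answer|>
        return f"<|prompt|>{prompt}</s><|answer|>"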

So even when using the same chat prompt template that HF uses in the code above, I'm getting random tokens with LitGPT.

I tried extensively to track down the issue but couldn't find it. Could you please guide me on what the potential issue might be?

@Dev-Khant
Contributor Author

Also, regarding the failing tokenizer test for Danube2: where do I need to make the change? I don't see a relevant field in the config to adjust.

@Dev-Khant Dev-Khant mentioned this pull request Apr 18, 2024
@rasbt
Collaborator

rasbt commented Apr 18, 2024

Hi there, and sorry for the late response, it's been a super intense week. Regarding the tokenizer, usually no modification should be necessary, as it loads the tokenizer from the hub. But maybe there is a special case here ... I have to think about this more ...
Regarding the random numbers you are getting, this could potentially be related to the tokenizer issue.
Otherwise, when I added a checkpoint in the past and got weird results, it was usually because something about the architecture required additional adjustments. The way I debugged this was to add a small model test based on a small tensor, e.g.,

def test_against_hf_phi_2(device, dtype):

And then, starting at the bottom of the model, print the outputs one layer at a time and compare them to the reference implementation to see at which layer the issue appears.
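
For example, a helper along these lines can localize where the two implementations diverge (just a sketch, not part of this PR; it assumes a LitGPT model and an HF reference model that already share copied weights, set up like the existing test_against_hf_* tests, and uses the transformer.h / model.layers attribute names from the two code bases):

import torch


def compare_layer_outputs(ours_model, theirs_model, x):
    # Print the max abs difference between corresponding transformer blocks.
    ours_outs, theirs_outs = [], []

    def make_hook(store):
        def hook(module, args, output):
            # HF decoder layers return a tuple; LitGPT blocks return a tensor
            store.append(output[0] if isinstance(output, tuple) else output)
        return hook

    handles = [b.register_forward_hook(make_hook(ours_outs)) for b in ours_model.transformer.h]
    handles += [b.register_forward_hook(make_hook(theirs_outs)) for b in theirs_model.model.layers]

    with torch.inference_mode():
        ours_model(x)
        theirs_model(x)
        for i, (ours, theirs) in enumerate(zip(ours_outs, theirs_outs)):
            print(f"block {i}: max abs diff = {(ours - theirs).abs().max().item():.3e}")

    for h in handles:
        h.remove()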

Maybe @Andrei-Aksionov has additional tips, as he went through the ordeal with the different Gemma implementations, which all had some minor undocumented quirks.

@Andrei-Aksionov
Collaborator

I'll take a look at it tomorrow.

@Dev-Khant
Contributor Author


Thanks for informing me, @rasbt.

@rasbt changed the title from "Add Danube2" to "Add H2O Danube2 Checkpoint" on Apr 18, 2024
@Andrei-Aksionov
Collaborator

Andrei-Aksionov commented Apr 22, 2024

Hey @Dev-Khant
The config seems to be OK; the only missing part was rotary_percentage, which needs to be 1.0, since it's used in the calculation of rope_n_elem (for the RoPE embeddings):
https://github.com/Lightning-AI/litgpt/blob/main/litgpt/config.py#L92

In Mistral's RoPE implementation they use self.dim (analogous to LitGPT's head_size):
https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L100

That means we need to use 100% of the head size for RoPE, hence rotary_percentage=1.0.
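
For context, this is roughly how that value propagates (a minimal sketch paraphrasing the linked config.py line, not a verbatim copy):

# Minimal sketch of how rotary_percentage feeds into rope_n_elem in LitGPT's Config.
# Paraphrased from the linked litgpt/config.py line; not a verbatim copy.
def compute_rope_n_elem(n_embd: int, n_head: int, rotary_percentage: float) -> int:
    head_size = n_embd // n_head
    # With rotary_percentage=1.0 this equals head_size, i.e. RoPE covers the full
    # head dimension -- which is what HF Mistral does by using self.dim (head_dim).
    return int(rotary_percentage * head_size)


# Illustrative numbers only (not the actual Danube2 dimensions):
assert compute_rope_n_elem(n_embd=2048, n_head=16, rotary_percentage=1.0) == 128
assert compute_rope_n_elem(n_embd=2048, n_head=16, rotary_percentage=0.25) == 32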


So, this little change plus #1328 results in identical output between LitGPT and the HF variant.
You can verify it with a simple test:

# Imports as they would appear in LitGPT's test suite (module paths may differ slightly):
import pytest
import torch
from transformers import AutoConfig, AutoModelForCausalLM

from litgpt import GPT, Config
from litgpt.scripts.convert_hf_checkpoint import copy_weights_hf_llama


@torch.inference_mode()
@pytest.mark.parametrize("model_name", ["h2o-danube2-1.8b-chat"])
@pytest.mark.parametrize(
    ("device", "dtype"),
    [
        (torch.device("cpu"), torch.float32),
    ],
)
def test_against_original_danube(model_name, device, dtype):
    torch.set_default_dtype(dtype)

    T = 5
    ours_config = Config.from_name(model_name, n_layer=2, n_head=16, n_embd=32, intermediate_size=86)
    theirs_config = AutoConfig.from_pretrained(
        "/".join(ours_config.hf_config.values()),
        vocab_size=ours_config.padded_vocab_size,
        hidden_size=ours_config.n_embd,
        head_dim=ours_config.head_size,
        num_attention_heads=ours_config.n_head,
        num_hidden_layers=ours_config.n_layer,
        intermediate_size=ours_config.intermediate_size,
        max_position_embeddings=T,
        rms_norm_eps=ours_config.norm_eps,
        num_key_value_heads=ours_config.n_query_groups,
        rope_theta=ours_config.rope_base,
        attention_bias=ours_config.bias,
    )
    assert ours_config.intermediate_size == theirs_config.intermediate_size

    theirs_model = AutoModelForCausalLM.from_config(theirs_config).to(device)
    theirs_state_dict = theirs_model.state_dict()
    state_dict = {}
    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
    ours_model = GPT(ours_config).to(device)
    ours_model.load_state_dict(state_dict)

    # test end to end
    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
    assert x.size(1) == T
    ours_y = ours_model(x)
    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
    torch.testing.assert_close(ours_y, theirs_y)

The problem that I couldn't solve yet is that, despite the tokenizer output (with fixes from #1328) and the model output being identical to the HF variant, the generation itself produces a somewhat weird response.

I'll revisit it a bit later. Or maybe you can try to find it in the meantime. We can make it a race: who can find it quicker 🏎️.

Update: after I redownloaded the weights, the model started to produce decent output. It generates a few more new tokens than I'd like, but I guess that can be tweaked with generation parameters.
In other words, once #1328 is merged, this PR should be ready.

Review thread on litgpt/prompts.py (outdated, resolved).
Co-authored-by: Andrei-Aksionov <[email protected]>
@Dev-Khant
Contributor Author

Dev-Khant commented Apr 23, 2024


Thanks, @Andrei-Aksionov, for thoroughly going through the PR. I'll run the above code with the changes from #1328 and check the output for both tokenizers.
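
Roughly, the tokenizer comparison I have in mind looks like this (a sketch; it assumes the checkpoint has already been downloaded and converted to the LitGPT layout under an illustrative checkpoints/ path, and BOS/EOS handling may need adjusting via the encode arguments):

# Sketch: compare the HF tokenizer output with the LitGPT tokenizer for the same
# chat-formatted prompt. The checkpoint path below is illustrative.
from pathlib import Path

from transformers import AutoTokenizer

from litgpt.tokenizer import Tokenizer

checkpoint_dir = Path("checkpoints/h2oai/h2o-danube2-1.8b-chat")
prompt = "<|prompt|>What is 2+2?</s><|answer|>"

hf_ids = AutoTokenizer.from_pretrained("h2oai/h2o-danube2-1.8b-chat")(
    prompt, add_special_tokens=False
)["input_ids"]
lit_ids = Tokenizer(checkpoint_dir).encode(prompt).tolist()  # pass bos=False if needed

print("HF :", hf_ids)
print("Lit:", lit_ids)
print("match:", hf_ids == lit_ids)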

Also, prior to #1328, I took the tokens generated by the HF tokenizer and passed them to the LitGPT model, but it still somehow generated random words. So, as you said, we can wait for #1328 to be merged, and in the meantime I'll try again to see what is going wrong with the prediction.

Update: Looks like it works properly with the changes from #1328 and the latest weights. @Andrei-Aksionov, please confirm on your side as well. Thanks!
[Screenshot: LitGPT generation output, 2024-04-23]

@Dev-Khant
Contributor Author

Hi @Andrei-Aksionov, @rasbt. Could you please review the PR now? It's passing the tests, and text generation is working properly.

@Andrei-Aksionov
Collaborator

Hey @Dev-Khant
You don't need my approval, since I'm not a maintainer (though I approved anyway 🙂).

@rasbt Since you are a markdown Jedi, could you take a look at changes like the added empty lines?
From the commit history, I see that I introduced those changes when I merged in main, but apparently they are no longer there (in the main branch), or maybe my markdown auto-formatting did it 🤷‍♂️.

@rasbt
Collaborator

rasbt commented May 3, 2024

Thanks for the ping @Dev-Khant & @Andrei-Aksionov, and thanks so much for this valuable contribution. I'll take a look!

Collaborator

@rasbt rasbt left a comment


LGTM. I'll try it out just to make sure and add a unit test, but this looks great!

@rasbt
Collaborator

rasbt commented May 3, 2024

Just played around with it for a bit and it works great. Thanks again for this great contribution!

@rasbt rasbt merged commit e441c65 into Lightning-AI:main May 3, 2024
9 checks passed
@Dev-Khant
Contributor Author

Thanks @rasbt!

Inline review comment on: def test_against_hf_h2o_danube(device, dtype):
Contributor

I would say that this test could be removed, since the checkpoint uses a model architecture that is already covered by existing tests: https://huggingface.co/h2oai/h2o-danube2-1.8b-chat/blob/main/config.json#L4
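
A quick way to double-check what the checkpoint declares (a small sketch; the expected value comes from the linked config.json line):

# Print the architecture declared in the checkpoint's config.json on the Hub.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("h2oai/h2o-danube2-1.8b-chat")
print(cfg.architectures)  # the linked config.json line lists the Mistral architecture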
