Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gemma: WTE scaling for Adapter and LoRA #1193

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions litgpt/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def forward(
mask = None

x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
if self.config.scale_embeddings:
x = x * (self.config.n_embd**0.5)
for block in self.transformer.h:
x = block(x, cos, sin, mask, input_pos)
x = self.transformer.ln_f(x)
Expand Down
2 changes: 2 additions & 0 deletions litgpt/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,8 @@ def forward(
mask = None

x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
if self.config.scale_embeddings:
x = x * (self.config.n_embd**0.5)
for block in self.transformer.h:
x = block(x, cos, sin, mask, input_pos)
x = self.transformer.ln_f(x)
Expand Down
43 changes: 43 additions & 0 deletions tests/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision
from lightning.fabric.wrappers import _FabricOptimizer
from torch._dynamo.backends import debugging
from transformers.models.gemma import GemmaConfig, GemmaForCausalLM

import litgpt.adapter as gpt_adapter
import litgpt.finetune.adapter as module
import litgpt.model as gpt
from litgpt.adapter import GPT, Config, adapter_filter
from litgpt.args import EvalArgs, TrainArgs
from litgpt.data import Alpaca
from litgpt.scripts.convert_hf_checkpoint import copy_weights_hf_llama


def test_config_identical():
Expand Down Expand Up @@ -232,3 +234,44 @@ def test_adapter_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca
logs = stdout.getvalue()
assert "of trainable parameters: 168" in logs
assert "of non-trainable parameters: 1,888" in logs


@torch.inference_mode()
@pytest.mark.parametrize("model_name", ["gemma-2b", "gemma-7b"])
def test_against_hf_gemma(model_name):
device = torch.device("cpu")
dtype = torch.float32
T = 5
ours_config = Config.from_name(model_name, n_layer=2, n_head=16, n_embd=32, intermediate_size=86)
theirs_config = GemmaConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
head_dim=ours_config.head_size,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=T,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
attention_bias=ours_config.bias,
tie_word_embeddings=True,
hidden_act="gelu_pytorch_tanh",
)
assert ours_config.intermediate_size == theirs_config.intermediate_size

theirs_model = GemmaForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
# Gemma weights are shipped without `lm_head.weight`
theirs_state_dict.pop("lm_head.weight")
state_dict = {}
copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)

# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)
1 change: 0 additions & 1 deletion tests/test_adapter_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ def test_against_hf_mixtral():


@torch.inference_mode()
@pytest.mark.xfail(raises=AssertionError, match="Tensor-likes are not close")
@pytest.mark.parametrize("model_name", ["gemma-2b", "gemma-7b"])
def test_against_hf_gemma(model_name):
device = torch.device("cpu")
Expand Down
1 change: 0 additions & 1 deletion tests/test_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,6 @@ def test_against_hf_mixtral():


@torch.inference_mode()
@pytest.mark.xfail(raises=AssertionError, match="Tensor-likes are not close")
@pytest.mark.parametrize("model_name", ["gemma-2b", "gemma-7b"])
def test_against_hf_gemma(model_name):
device = torch.device("cpu")
Expand Down
Loading