GemmaMLP: add missing approximation for LoRA and AdapterV2 variants (#…
Andrei-Aksionov authored and rasbt committed Apr 3, 2024
1 parent 98b288d commit f5602ed
Showing 4 changed files with 94 additions and 4 deletions.
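
Background for the change: Hugging Face ships Gemma with hidden_act="gelu_pytorch_tanh" (visible in the test configs below), and litgpt stores the matching setting in config.gelu_approximate. The LoRA and AdapterV2 subclasses of GemmaMLP were calling torch.nn.functional.gelu without that argument, so they fell back to exact GELU and their activations drifted slightly from the reference model. A minimal sketch of the difference between the two activations (illustrative only, not part of this commit):

# Illustrative sketch, not litgpt code: compares exact GELU with the tanh
# approximation that Gemma checkpoints expect.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1024)
exact = F.gelu(x)                        # default: approximate="none"
approx = F.gelu(x, approximate="tanh")   # matches "gelu_pytorch_tanh"
print((exact - approx).abs().max())      # small but nonzero difference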
litgpt/adapter_v2.py: 4 changes (3 additions & 1 deletion)
@@ -181,6 +181,8 @@ def __init__(self, config: Config) -> None:
         self.fc_2 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)
         self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias)

+        self.config = config
+
     def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
         """For compatibility with base checkpoints."""
         mapping = {
@@ -199,7 +201,7 @@ class GemmaMLP(LLaMAMLP):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x_fc_1 = self.fc_1(x)
         x_fc_2 = self.fc_2(x)
-        x = torch.nn.functional.gelu(x_fc_1) * x_fc_2
+        x = torch.nn.functional.gelu(x_fc_1, approximate=self.config.gelu_approximate) * x_fc_2
         return self.proj(x)


litgpt/lora.py: 4 changes (3 additions & 1 deletion)
@@ -686,6 +686,8 @@ def __init__(self, config: Config) -> None:
             lora_dropout=config.lora_dropout,
         )

+        self.config = config
+
     def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
         """For compatibility with base checkpoints."""
         mapping = {
@@ -704,7 +706,7 @@ class GemmaMLP(LLaMAMLP):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x_fc_1 = self.fc_1(x)
         x_fc_2 = self.fc_2(x)
-        x = torch.nn.functional.gelu(x_fc_1) * x_fc_2
+        x = torch.nn.functional.gelu(x_fc_1, approximate=self.config.gelu_approximate) * x_fc_2
         return self.proj(x)


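Both patched forwards now implement the gated (GeGLU-style) MLP with the approximation mode read from the stored config, which is why `self.config = config` is added to `__init__`. The standalone sketch below (hypothetical class name, plain nn.Linear layers standing in for the AdapterV2Linear/LoRA layers above) shows the computation:

# Standalone sketch, not litgpt code.
import torch
import torch.nn as nn
import torch.nn.functional as F

class GemmaMLPSketch(nn.Module):
    def __init__(self, n_embd: int, intermediate_size: int, gelu_approximate: str = "tanh") -> None:
        super().__init__()
        self.fc_1 = nn.Linear(n_embd, intermediate_size)
        self.fc_2 = nn.Linear(n_embd, intermediate_size)
        self.proj = nn.Linear(intermediate_size, n_embd)
        self.gelu_approximate = gelu_approximate  # stands in for self.config.gelu_approximate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # GeGLU: the GELU-activated branch gates the second linear branch
        x = F.gelu(self.fc_1(x), approximate=self.gelu_approximate) * self.fc_2(x)
        return self.proj(x)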
tests/test_adapter_v2.py: 47 changes (45 additions & 2 deletions)
@@ -13,12 +13,13 @@
 from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision
 from lightning.fabric.wrappers import _FabricOptimizer
 from torch._dynamo.backends import debugging
+from transformers.models.gemma import GemmaConfig, GemmaForCausalLM
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM

 import litgpt.config as config_module
 import litgpt.finetune.adapter_v2 as module
-from litgpt.adapter_v2 import GPT, Config, adapter_filter
+from litgpt.adapter_v2 import GPT as AdapterV2GPT
+from litgpt.adapter_v2 import Config, adapter_filter
 from litgpt.args import EvalArgs, TrainArgs
 from litgpt.data import Alpaca
 from litgpt.model import GPT as BaseGPT
@@ -195,7 +196,7 @@ def test_against_hf_mixtral():
     theirs_state_dict = theirs_model.state_dict()
     state_dict = {}
     copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
-    ours_model = GPT(ours_config).to(device)
+    ours_model = AdapterV2GPT(ours_config).to(device)
     # strict=False because missing keys due to adapter weights not contained in state dict
     ours_model.load_state_dict(state_dict, strict=False)

@@ -207,6 +208,48 @@ def test_against_hf_mixtral():
     torch.testing.assert_close(ours_y, theirs_y)


+@torch.inference_mode()
+@pytest.mark.xfail(raises=AssertionError, match="Tensor-likes are not close")
+@pytest.mark.parametrize("model_name", ["gemma-2b", "gemma-7b"])
+def test_against_hf_gemma(model_name):
+    device = torch.device("cpu")
+    dtype = torch.float32
+    T = 5
+    ours_config = Config.from_name(model_name, n_layer=2, n_head=16, n_embd=32, intermediate_size=86)
+    theirs_config = GemmaConfig(
+        vocab_size=ours_config.padded_vocab_size,
+        hidden_size=ours_config.n_embd,
+        head_dim=ours_config.head_size,
+        num_attention_heads=ours_config.n_head,
+        num_hidden_layers=ours_config.n_layer,
+        intermediate_size=ours_config.intermediate_size,
+        max_position_embeddings=T,
+        rms_norm_eps=ours_config.norm_eps,
+        num_key_value_heads=ours_config.n_query_groups,
+        rope_theta=ours_config.rope_base,
+        attention_bias=ours_config.bias,
+        tie_word_embeddings=True,
+        hidden_act="gelu_pytorch_tanh",
+    )
+    assert ours_config.intermediate_size == theirs_config.intermediate_size
+
+    theirs_model = GemmaForCausalLM(theirs_config).to(device)
+    theirs_state_dict = theirs_model.state_dict()
+    # Gemma weights are shipped without `lm_head.weight`
+    theirs_state_dict.pop("lm_head.weight")
+    state_dict = {}
+    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
+    ours_model = AdapterV2GPT(ours_config).to(device)
+    ours_model.load_state_dict(state_dict, strict=False)
+
+    # test end to end
+    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
+    assert x.size(1) == T
+    ours_y = ours_model(x)
+    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
+    torch.testing.assert_close(ours_y, theirs_y)
+
+
 @RunIf(min_cuda_gpus=1)
 def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_path):
     if not _BITSANDBYTES_AVAILABLE:
tests/test_lora.py: 43 changes (43 additions & 0 deletions)
@@ -15,6 +15,7 @@
 from lightning.fabric.wrappers import _FabricOptimizer
 from torch._dynamo.backends import debugging
 from torch.nn import functional as F
+from transformers.models.gemma import GemmaConfig, GemmaForCausalLM
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM

 import litgpt.config as config_module
@@ -554,6 +555,48 @@ def test_against_hf_mixtral():
     torch.testing.assert_close(ours_y, theirs_y)


+@torch.inference_mode()
+@pytest.mark.xfail(raises=AssertionError, match="Tensor-likes are not close")
+@pytest.mark.parametrize("model_name", ["gemma-2b", "gemma-7b"])
+def test_against_hf_gemma(model_name):
+    device = torch.device("cpu")
+    dtype = torch.float32
+    T = 5
+    ours_config = Config.from_name(model_name, n_layer=2, n_head=16, n_embd=32, intermediate_size=86)
+    theirs_config = GemmaConfig(
+        vocab_size=ours_config.padded_vocab_size,
+        hidden_size=ours_config.n_embd,
+        head_dim=ours_config.head_size,
+        num_attention_heads=ours_config.n_head,
+        num_hidden_layers=ours_config.n_layer,
+        intermediate_size=ours_config.intermediate_size,
+        max_position_embeddings=T,
+        rms_norm_eps=ours_config.norm_eps,
+        num_key_value_heads=ours_config.n_query_groups,
+        rope_theta=ours_config.rope_base,
+        attention_bias=ours_config.bias,
+        tie_word_embeddings=True,
+        hidden_act="gelu_pytorch_tanh",
+    )
+    assert ours_config.intermediate_size == theirs_config.intermediate_size
+
+    theirs_model = GemmaForCausalLM(theirs_config).to(device)
+    theirs_state_dict = theirs_model.state_dict()
+    # Gemma weights are shipped without `lm_head.weight`
+    theirs_state_dict.pop("lm_head.weight")
+    state_dict = {}
+    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
+    ours_model = LoRAGPT(ours_config).to(device)
+    ours_model.load_state_dict(state_dict)
+
+    # test end to end
+    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
+    assert x.size(1) == T
+    ours_y = ours_model(x)
+    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
+    torch.testing.assert_close(ours_y, theirs_y)
+
+
 @RunIf(min_cuda_gpus=1)
 def test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_path):
     if not _BITSANDBYTES_AVAILABLE:
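The new parity tests can be selected with pytest's -k filter, e.g. "pytest tests/test_adapter_v2.py tests/test_lora.py -k test_against_hf_gemma" (this assumes an installed transformers version that provides GemmaForCausalLM; note that both tests are marked xfail on the "Tensor-likes are not close" assertion).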
