Add bnb number of parameters test (#744)
carmocca authored Nov 17, 2023
1 parent 0a202f6 commit 21e5c0e
Showing 8 changed files with 52 additions and 38 deletions.
3 changes: 2 additions & 1 deletion tests/test_adapter.py
@@ -4,9 +4,10 @@
 from unittest.mock import Mock
 
 import torch
-from conftest import RunIf
 from lightning import Fabric
 
+from conftest import RunIf
+
 
 def test_config_identical():
     import lit_gpt.adapter as gpt_adapter
3 changes: 2 additions & 1 deletion tests/test_adapter_v2.py
@@ -3,9 +3,10 @@
 from unittest.mock import Mock
 
 import torch
-from conftest import RunIf
 from lightning import Fabric
 
+from conftest import RunIf
+
 
 def test_config_identical():
     import lit_gpt.adapter_v2 as gpt_adapter
22 changes: 9 additions & 13 deletions tests/test_convert_lit_checkpoint.py
@@ -36,11 +36,10 @@ def test_convert_lit_checkpoint(tmp_path):
 
 @torch.inference_mode()
 def test_against_falcon_40b():
-    from transformers.models.falcon.configuration_falcon import FalconConfig
-    from transformers.models.falcon.modeling_falcon import FalconForCausalLM
-
     from lit_gpt import GPT, Config
     from scripts.convert_lit_checkpoint import copy_weights_falcon as copy_to_theirs
+    from transformers.models.falcon.configuration_falcon import FalconConfig
+    from transformers.models.falcon.modeling_falcon import FalconForCausalLM
 
     ours_config = Config.from_name("falcon-40b", n_layer=2, n_head=8, n_query_groups=4, n_embd=32)
     theirs_config = FalconConfig(
@@ -72,10 +71,9 @@ def test_against_falcon_40b():
 
 @torch.inference_mode()
 def test_against_original_gpt_neox():
-    from transformers import GPTNeoXConfig, GPTNeoXForCausalLM
-
     from lit_gpt import GPT, Config
     from scripts.convert_lit_checkpoint import copy_weights_gpt_neox as copy_to_theirs
+    from transformers import GPTNeoXConfig, GPTNeoXForCausalLM
 
     ours_config = Config(block_size=64, vocab_size=100, n_layer=4, n_head=8, n_embd=16)
     assert ours_config.padded_vocab_size == 512
@@ -116,11 +114,10 @@ def test_against_original_gpt_neox():
     "ours_kwargs", [{"name": "Llama-2-7b-hf"}, {"name": "CodeLlama-7b-hf"}, {"name": "Llama-2-70b-chat-hf"}]
 )
 def test_against_hf_llama2(ours_kwargs):
-    from transformers.models.llama.configuration_llama import LlamaConfig
-    from transformers.models.llama.modeling_llama import LlamaForCausalLM
-
     from lit_gpt import GPT, Config
     from scripts.convert_lit_checkpoint import copy_weights_llama
+    from transformers.models.llama.configuration_llama import LlamaConfig
+    from transformers.models.llama.modeling_llama import LlamaForCausalLM
 
     ours_config = Config.from_name(
         padded_vocab_size=10000, n_layer=2, n_head=8, n_embd=32, intermediate_size=86, **ours_kwargs
@@ -155,11 +152,10 @@ def test_against_hf_llama2(ours_kwargs):
 
 @torch.inference_mode()
 def test_against_original_open_llama_3b():
-    from transformers.models.llama.configuration_llama import LlamaConfig
-    from transformers.models.llama.modeling_llama import LlamaForCausalLM
-
     from lit_gpt import GPT, Config
     from scripts.convert_lit_checkpoint import copy_weights_llama
+    from transformers.models.llama.configuration_llama import LlamaConfig
+    from transformers.models.llama.modeling_llama import LlamaForCausalLM
 
     ours_config = Config.from_name("open_llama_3b", n_layer=2, n_head=8, n_embd=32, intermediate_size=86)
     T = 5
@@ -194,11 +190,11 @@ def test_against_hf_phi():
     if not file_path.is_file():
         urlretrieve(url=url, filename=file_path)
 
-    from original_phi_1_5 import MixFormerSequentialConfig, MixFormerSequentialForCausalLM
-
     from lit_gpt import GPT, Config
     from scripts.convert_lit_checkpoint import copy_weights_phi
 
+    from original_phi_1_5 import MixFormerSequentialConfig, MixFormerSequentialForCausalLM
+
     ours_config = Config.from_name(
         "phi-1_5", padded_vocab_size=10000, n_layer=2, n_head=4, n_embd=256, rotary_percentage=0.5
     )
1 change: 1 addition & 0 deletions tests/test_gptq.py
@@ -1,6 +1,7 @@
 import lightning as L
 import pytest
 import torch
+
 from conftest import RunIf
 
 
6 changes: 3 additions & 3 deletions tests/test_lora.py
@@ -5,9 +5,10 @@
 
 import pytest
 import torch
-from conftest import RunIf
 from lightning import Fabric
 
+from conftest import RunIf
+
 
 def test_lora_layer_replacement():
     from lit_gpt.lora import GPT, Config, LoRALinear
@@ -304,9 +305,8 @@ def test_lora_gpt_query_groups_merge_and_forward_no_exception(n_query_groups, ap
     ],
 )
 def test_lora_qkv_linear_compare_conv1d(n_head, enable_lora):
-    from torch.nn import functional as F
-
     from lit_gpt.lora import LoRAQKVLinear
+    from torch.nn import functional as F
 
     C = 12
     layer = LoRAQKVLinear(C, 3 * C, n_head=n_head, n_query_groups=n_head, r=2, enable_lora=enable_lora)
30 changes: 12 additions & 18 deletions tests/test_model.py
@@ -5,9 +5,10 @@
 
 import pytest
 import torch
-from conftest import RunIf
 from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_2
 
+from conftest import RunIf
+
 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))
@@ -37,10 +38,9 @@
     ],
 )
 def test_against_gpt_neox_model(rotary_pct, batch_size, n_embd, parallel_residual, device, dtype) -> None:
-    from transformers import GPTNeoXConfig, GPTNeoXForCausalLM
-
     from lit_gpt import GPT, Config
     from scripts.convert_hf_checkpoint import copy_weights_gpt_neox
+    from transformers import GPTNeoXConfig, GPTNeoXForCausalLM
 
     torch.set_default_dtype(dtype)
 
@@ -110,10 +110,9 @@ def test_against_gpt_neox_model(rotary_pct, batch_size, n_embd, parallel_residua
     ],
 )
 def test_against_hf_falcon(kwargs, device, dtype):
-    from transformers.models.falcon import FalconConfig, FalconForCausalLM
-
     from lit_gpt import GPT, Config
     from scripts.convert_hf_checkpoint import copy_weights_falcon
+    from transformers.models.falcon import FalconConfig, FalconForCausalLM
 
     torch.set_default_dtype(dtype)
 
@@ -161,11 +160,10 @@ def test_against_hf_falcon(kwargs, device, dtype):
     ],
 )
 def test_against_original_open_llama_3b(device, dtype):
-    from transformers.models.llama.configuration_llama import LlamaConfig
-    from transformers.models.llama.modeling_llama import LlamaForCausalLM
-
     from lit_gpt import GPT, Config
     from scripts.convert_hf_checkpoint import copy_weights_hf_llama
+    from transformers.models.llama.configuration_llama import LlamaConfig
+    from transformers.models.llama.modeling_llama import LlamaForCausalLM
 
     torch.set_default_dtype(dtype)
 
@@ -217,11 +215,10 @@ def test_against_original_open_llama_3b(device, dtype):
     ],
 )
 def test_against_hf_llama2(ours_kwargs, device, dtype):
-    from transformers.models.llama.configuration_llama import LlamaConfig
-    from transformers.models.llama.modeling_llama import LlamaForCausalLM
-
     from lit_gpt import GPT, Config
     from scripts.convert_hf_checkpoint import copy_weights_hf_llama
+    from transformers.models.llama.configuration_llama import LlamaConfig
+    from transformers.models.llama.modeling_llama import LlamaForCausalLM
 
     torch.set_default_dtype(dtype)
 
@@ -330,11 +327,10 @@ def test_against_hf_phi(device, dtype):
     ],
 )
 def test_against_hf_mistral(device, dtype):
-    from transformers.models.mistral.configuration_mistral import MistralConfig
-    from transformers.models.mistral.modeling_mistral import MistralForCausalLM
-
     from lit_gpt import GPT, Config
     from scripts.convert_hf_checkpoint import copy_weights_hf_llama
+    from transformers.models.mistral.configuration_mistral import MistralConfig
+    from transformers.models.mistral.modeling_mistral import MistralForCausalLM
 
     torch.set_default_dtype(dtype)
 
@@ -460,9 +456,8 @@ def test_model_kv_cache_amp():
 @pytest.mark.parametrize("config", config_module.configs, ids=[c["name"] for c in config_module.configs])
 @torch.inference_mode()
 def test_sdpa_choice(config):
-    from torch.backends.cuda import SDPBackend
-
     from lit_gpt import GPT
+    from torch.backends.cuda import SDPBackend
 
     torch.set_default_dtype(torch.float16)
 
@@ -505,9 +500,8 @@ def assert_sdpa_uses_flash(original_fn, q, k, v, mask):
 @pytest.mark.parametrize("config", config_module.configs, ids=[c["name"] for c in config_module.configs])
 @torch.inference_mode()
 def test_sdpa_choice_kv_cache(config):
-    from torch.backends.cuda import SDPBackend
-
     from lit_gpt import GPT
+    from torch.backends.cuda import SDPBackend
 
     torch.set_default_dtype(torch.float16)
 
3 changes: 1 addition & 2 deletions tests/test_rope.py
@@ -3,9 +3,8 @@
 
 @torch.inference_mode()
 def test_rope():
-    from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXRotaryEmbedding, apply_rotary_pos_emb
-
     from lit_gpt.model import apply_rope, build_rope_cache
+    from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXRotaryEmbedding, apply_rotary_pos_emb
 
     bs, seq_len, n_head, n_embed = 1, 6, 2, 8
     head_size = n_embed // n_head
22 changes: 22 additions & 0 deletions tests/test_utils.py
@@ -5,6 +5,8 @@
 import pytest
 import torch
 import torch.nn.functional as F
+from lightning import Fabric
+
 from conftest import RunIf
 
 
@@ -135,3 +137,23 @@ def test_num_parameters():
     assert num_parameters(model) == 6
     assert num_parameters(model, requires_grad=True) == 4
     assert num_parameters(model, requires_grad=False) == 2
+
+
+@RunIf(min_cuda_gpus=1)
+@pytest.mark.parametrize("mode", ["nf4", "nf4-dq", "fp4", "fp4-dq", "int8", "int8-training"])
+@pytest.mark.skip("To be fixed")
+def test_num_parameters_bitsandbytes(mode):
+    from lightning.fabric.plugins import BitsandbytesPrecision
+    from lit_gpt import GPT
+    from lit_gpt.utils import num_parameters
+
+    plugin = BitsandbytesPrecision(mode=mode)
+    fabric = Fabric(plugins=plugin, accelerator="cuda", devices=1)
+
+    model = torch.nn.Linear(10, 10)
+    model = fabric.setup(model)
+    assert num_parameters(model) == 110
+
+    with fabric.init_module(empty_init=True):
+        model = GPT.from_name("pythia-70m")
+    assert num_parameters(model) == 70426624
