Phi model update (#876)
Co-authored-by: Carlos Mocholí <[email protected]>
Andrei-Aksionov and carmocca authored Jan 16, 2024
1 parent 1e5afd6 commit 423b1a8
Showing 5 changed files with 171 additions and 91 deletions.
84 changes: 54 additions & 30 deletions scripts/convert_hf_checkpoint.py
@@ -3,6 +3,7 @@
import gc
import json
import sys
from collections import defaultdict
from dataclasses import asdict
from functools import partial
from pathlib import Path
@@ -189,54 +190,75 @@ def copy_weights_phi(

def copy_weights_phi(
config: Config,
qkv_weights: dict,
state_dict: Dict[str, torch.Tensor],
hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
saver: Optional[incremental_save] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
if any(layer_name.startswith("layers.") for layer_name in hf_weights):
if any(layer_name.startswith(("layers.", "transformer.")) for layer_name in hf_weights):
raise ValueError(
"You are using an outdated Phi1.5 checkpoint. "
"Please reload it as described in 'tutorials/download_phi15.md'"
"You are using an outdated Phi checkpoint. Please reload it as described in 'tutorials/download_phi.md'"
)

weight_map = {
"transformer.embd.wte.weight": "transformer.wte.weight",
"transformer.h.{}.ln.bias": "transformer.h.{}.norm_1.bias",
"transformer.h.{}.ln.weight": "transformer.h.{}.norm_1.weight",
"transformer.h.{}.mixer.Wqkv.bias": "transformer.h.{}.attn.attn.bias",
"transformer.h.{}.mixer.Wqkv.weight": "transformer.h.{}.attn.attn.weight",
"transformer.h.{}.mixer.out_proj.bias": "transformer.h.{}.attn.proj.bias",
"transformer.h.{}.mixer.out_proj.weight": "transformer.h.{}.attn.proj.weight",
"transformer.h.{}.mixer.rotary_emb.inv_freq": None,
"transformer.h.{}.mlp.fc1.bias": "transformer.h.{}.mlp.fc.bias",
"transformer.h.{}.mlp.fc1.weight": "transformer.h.{}.mlp.fc.weight",
"transformer.h.{}.mlp.fc2.bias": "transformer.h.{}.mlp.proj.bias",
"transformer.h.{}.mlp.fc2.weight": "transformer.h.{}.mlp.proj.weight",
"lm_head.ln.weight": "transformer.ln_f.weight",
"lm_head.ln.bias": "transformer.ln_f.bias",
"lm_head.linear.weight": "lm_head.weight",
"lm_head.linear.bias": "lm_head.bias",
"model.embed_tokens.weight": "transformer.wte.weight",
"model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight",
"model.layers.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias",
"model.layers.{}.self_attn.q_proj.weight": None,
"model.layers.{}.self_attn.q_proj.bias": None,
"model.layers.{}.self_attn.k_proj.weight": None,
"model.layers.{}.self_attn.k_proj.bias": None,
"model.layers.{}.self_attn.v_proj.weight": None,
"model.layers.{}.self_attn.v_proj.bias": None,
"model.layers.{}.self_attn.dense.weight": "transformer.h.{}.attn.proj.weight",
"model.layers.{}.self_attn.dense.bias": "transformer.h.{}.attn.proj.bias",
"model.layers.{}.mlp.fc1.weight": "transformer.h.{}.mlp.fc.weight",
"model.layers.{}.mlp.fc1.bias": "transformer.h.{}.mlp.fc.bias",
"model.layers.{}.mlp.fc2.weight": "transformer.h.{}.mlp.proj.weight",
"model.layers.{}.mlp.fc2.bias": "transformer.h.{}.mlp.proj.bias",
"model.final_layernorm.weight": "transformer.ln_f.weight",
"model.final_layernorm.bias": "transformer.ln_f.bias",
"lm_head.weight": "lm_head.weight",
"lm_head.bias": "lm_head.bias",
}

for name, param in hf_weights.items():
if name.startswith("transformer.h."):
from_name, number = layer_template(name, 2)
to_name = weight_map[from_name].format(number)
if name.startswith("model.layers."):
from_name, l = layer_template(name, 2)
qkv = qkv_weights.setdefault(l, defaultdict(dict))
if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")):
weight_name, weight_type = from_name.split(".")[-2:]
qkv[weight_type][weight_name] = param
to_name = weight_map[from_name]
if to_name is None:
continue
to_name = to_name.format(l)
else:
to_name = weight_map[name]
param = load_param(param, name, dtype)
if "Wqkv" in name:
q_per_kv = config.n_head // config.n_query_groups
total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value
param = param.view(total_qkv, config.n_query_groups, -1).transpose(0, 1)
param = param.reshape(config.n_embd * 3, -1)
if "bias" in name:
param = param.squeeze()
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param

for i in list(qkv_weights):
for weight_type in list(qkv_weights[i]):
qkv = qkv_weights[i][weight_type]
if len(qkv) != 3:
# split across different .bin files
continue
q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype)
k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype)
v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype)
q_per_kv = config.n_head // config.n_query_groups
qs = torch.split(q, config.head_size * q_per_kv)
ks = torch.split(k, config.head_size)
vs = torch.split(v, config.head_size)
cycled = [t for group in zip(qs, ks, vs) for t in group]
qkv = torch.cat(cycled)
state_dict[f"transformer.h.{i}.attn.attn.{weight_type}"] = qkv
del qkv_weights[i][weight_type]


def layer_template(layer_name: str, idx: int) -> Tuple[str, int]:
split = layer_name.split(".")
@@ -282,7 +304,9 @@ def convert_hf_checkpoint(
qkv_weights = {}
copy_fn = partial(copy_weights_hf_llama, config, qkv_weights)
elif "phi" in model_name:
copy_fn = partial(copy_weights_phi, config)
# holder to reconstitute the split q, k, v
qkv_weights = {}
copy_fn = partial(copy_weights_phi, config, qkv_weights)
else:
copy_fn = copy_weights_gpt_neox

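Note on the reassembly loop at the end of copy_weights_phi above: because a Hugging Face checkpoint may place q_proj, k_proj, and v_proj in different .bin shards, the weights are buffered in qkv_weights and a layer is only fused once all three are present. The fused tensor follows lit-gpt's grouped layout, in which each query group contributes its query rows, then its key rows, then its value rows. A minimal sketch of that interleaving with made-up sizes (the names and shapes below are illustrative only, not taken from the repository):

    import torch

    # Hypothetical dimensions for illustration only.
    n_head, n_query_groups, head_size, n_embd = 4, 2, 8, 32
    q_per_kv = n_head // n_query_groups  # queries that share one key/value pair

    # Stand-ins for the separate HF q_proj / k_proj / v_proj weight matrices.
    q = torch.randn(n_head * head_size, n_embd)
    k = torch.randn(n_query_groups * head_size, n_embd)
    v = torch.randn(n_query_groups * head_size, n_embd)

    qs = torch.split(q, head_size * q_per_kv)  # one query block per group
    ks = torch.split(k, head_size)             # one key block per group
    vs = torch.split(v, head_size)             # one value block per group
    fused = torch.cat([t for group in zip(qs, ks, vs) for t in group])

    # Each group contributes (q_per_kv + 2) * head_size rows to the fused tensor.
    assert fused.shape == ((n_head + 2 * n_query_groups) * head_size, n_embd)

For phi-1_5 and phi-2 the attention is multi-head (n_query_groups == n_head, so q_per_kv == 1), but the conversion code handles the grouped-query case in general.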
62 changes: 37 additions & 25 deletions scripts/convert_lit_checkpoint.py
@@ -166,35 +166,45 @@ def copy_weights_phi(
saver: Optional[incremental_save] = None,
) -> None:
weight_map = {
"transformer.wte.weight": "transformer.embd.wte.weight",
"transformer.h.{}.norm_1.bias": "transformer.h.{}.ln.bias",
"transformer.h.{}.norm_1.weight": "transformer.h.{}.ln.weight",
"transformer.h.{}.attn.attn.bias": "transformer.h.{}.mixer.Wqkv.bias",
"transformer.h.{}.attn.attn.weight": "transformer.h.{}.mixer.Wqkv.weight",
"transformer.h.{}.attn.proj.bias": "transformer.h.{}.mixer.out_proj.bias",
"transformer.h.{}.attn.proj.weight": "transformer.h.{}.mixer.out_proj.weight",
"transformer.h.{}.mlp.fc.bias": "transformer.h.{}.mlp.fc1.bias",
"transformer.h.{}.mlp.fc.weight": "transformer.h.{}.mlp.fc1.weight",
"transformer.h.{}.mlp.proj.bias": "transformer.h.{}.mlp.fc2.bias",
"transformer.h.{}.mlp.proj.weight": "transformer.h.{}.mlp.fc2.weight",
"transformer.ln_f.weight": "lm_head.ln.weight",
"transformer.ln_f.bias": "lm_head.ln.bias",
"lm_head.weight": "lm_head.linear.weight",
"lm_head.bias": "lm_head.linear.bias",
"transformer.wte.weight": "model.embed_tokens.weight",
"transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight",
"transformer.h.{}.norm_1.bias": "model.layers.{}.input_layernorm.bias",
"transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.dense.weight",
"transformer.h.{}.attn.proj.bias": "model.layers.{}.self_attn.dense.bias",
"transformer.h.{}.mlp.fc.weight": "model.layers.{}.mlp.fc1.weight",
"transformer.h.{}.mlp.fc.bias": "model.layers.{}.mlp.fc1.bias",
"transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.fc2.weight",
"transformer.h.{}.mlp.proj.bias": "model.layers.{}.mlp.fc2.bias",
"transformer.ln_f.weight": "model.final_layernorm.weight",
"transformer.ln_f.bias": "model.final_layernorm.bias",
"lm_head.weight": "lm_head.weight",
"lm_head.bias": "lm_head.bias",
}

for name, param in lit_weights.items():
if name.startswith("transformer.h."):
from_name, number = layer_template(name, 2)
to_name = weight_map[from_name].format(number)
if name.endswith((".attn.attn.weight", ".attn.attn.bias")):
from_name, l = layer_template(name, 2)
weight_type = name.split(".")[-1] # weight or bias
q = f"model.layers.{l}.self_attn.q_proj.{weight_type}"
k = f"model.layers.{l}.self_attn.k_proj.{weight_type}"
v = f"model.layers.{l}.self_attn.v_proj.{weight_type}"
qkv = load_param(param, name, None)
qp, kp, vp = qkv_split(qkv, config)
for to_name, param in zip((q, k, v), (qp, kp, vp)):
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param
else:
to_name = weight_map[name]
param = load_param(param, name, None)
if "attn.attn." in name:
param = torch.cat(qkv_split(param, config))
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param
if "transformer.h" in name:
from_name, l = layer_template(name, 2)
to_name = weight_map[from_name]
to_name = to_name.format(l)
else:
to_name = weight_map[name]
param = load_param(param, name, None)
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param


def qkv_split(
@@ -230,6 +240,8 @@ def convert_lit_checkpoint(checkpoint_path: Path, output_path: Path, config_path
copy_fn = partial(copy_weights_falcon, config.name)
elif config._mlp_class in ("LLaMAMLP", "LLaMAMoE"):
copy_fn = partial(copy_weights_llama, config)
elif "phi" in config.name:
copy_fn = partial(copy_weights_phi, config)
else:
copy_fn = copy_weights_gpt_neox

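For the reverse direction, copy_weights_phi in this file relies on the qkv_split helper (its body is collapsed in this view) to break lit-gpt's fused attn tensor back into separate q_proj, k_proj, and v_proj tensors for the HF layout. A sketch of the idea, using the same grouped layout as the conversion above (an illustration under those assumptions, not the repository's exact implementation):

    import torch

    def qkv_split_sketch(fused, n_head, n_query_groups, head_size):
        # Walk the fused tensor group by group, peel off the query, key, and
        # value blocks, then concatenate each kind back together.
        q_per_kv = n_head // n_query_groups
        qs, ks, vs = [], [], []
        for chunk in torch.split(fused, (q_per_kv + 2) * head_size):
            q_blk, k_blk, v_blk = torch.split(
                chunk, [q_per_kv * head_size, head_size, head_size]
            )
            qs.append(q_blk)
            ks.append(k_blk)
            vs.append(v_blk)
        return torch.cat(qs), torch.cat(ks), torch.cat(vs)

    # Example: undoes the fusion from the earlier sketch.
    # q2, k2, v2 = qkv_split_sketch(fused, n_head=4, n_query_groups=2, head_size=8)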
6 changes: 0 additions & 6 deletions scripts/download.py
@@ -51,11 +51,6 @@ def download_from_hub(
elif from_safetensors:
raise ValueError("`--from_safetensors=True` won't have an effect with `--tokenizer_only=True`")

# contains revisions that are known to work without issues
hf_model_revision_map = {
"microsoft/phi-1_5": "24f9ea14df973a49a0d87c16d04df88d90067468",
"microsoft/phi-2": "834565c23f9b28b96ccbeabe614dd906b6db551a",
}
directory = checkpoint_dir / repo_id
snapshot_download(
repo_id,
@@ -64,7 +59,6 @@
resume_download=True,
allow_patterns=download_files,
token=access_token,
revision=hf_model_revision_map.get(repo_id),
)

# convert safetensors to PyTorch binaries
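The change to scripts/download.py drops the hard-coded revision pins for the Phi repositories, so downloads now follow the repositories' default branch. If a pinned snapshot is still wanted, huggingface_hub's snapshot_download accepts a revision argument directly; a hedged example, reusing the hash that was previously pinned for phi-1_5 and a hypothetical destination directory:

    from huggingface_hub import snapshot_download

    snapshot_download(
        "microsoft/phi-1_5",
        revision="24f9ea14df973a49a0d87c16d04df88d90067468",  # branch, tag, or commit hash
        local_dir="checkpoints/microsoft/phi-1_5",            # hypothetical destination
    )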
70 changes: 60 additions & 10 deletions tests/test_convert_lit_checkpoint.py
Expand Up @@ -236,13 +236,13 @@ def test_against_original_open_llama_3b():


@torch.inference_mode()
def test_against_hf_phi():
def test_against_hf_phi_1_5():
workdir = wd / "tests" / "reference_models"
workdir.mkdir(parents=True, exist_ok=True)
file_paths = [workdir / "original_phi_1_5.py", workdir / "configuration_phi.py"]
urls = [
"https://huggingface.co/microsoft/phi-1_5/raw/24f9ea14df973a49a0d87c16d04df88d90067468/modeling_phi.py",
"https://huggingface.co/microsoft/phi-1_5/raw/24f9ea14df973a49a0d87c16d04df88d90067468/configuration_phi.py",
"https://huggingface.co/microsoft/phi-1_5/raw/main/modeling_phi.py",
"https://huggingface.co/microsoft/phi-1_5/raw/main/configuration_phi.py",
]
for file_path, url in zip(file_paths, urls):
if not file_path.is_file():
@@ -258,14 +258,64 @@ def test_against_hf_phi():
)
T = 5
theirs_config = PhiConfig(
n_positions=ours_config.block_size,
n_embd=ours_config.n_embd,
n_head=ours_config.n_head,
n_layer=ours_config.n_layer,
rotary_dim=ours_config.rope_n_elem,
architecture={"block_cls": "parallel", "mixer": {}, "mlp": {"mlp_cls": "mlp"}},
vocab_size=ours_config.padded_vocab_size,
max_position_embeddings=ours_config.block_size,
hidden_size=ours_config.n_embd,
intermediate_size=ours_config.intermediate_size,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
partial_rotary_factor=ours_config.rotary_percentage,
)

ours_model = GPT(ours_config)
ours_state_dict = ours_model.state_dict()
theirs_state_dict = {}
copy_weights_phi(ours_config, theirs_state_dict, ours_state_dict)
theirs_model = PhiForCausalLM(theirs_config)
# strict=False because we don't save the rotary embeddings inv frequency
keys = theirs_model.load_state_dict(theirs_state_dict, strict=False)
assert not keys.unexpected_keys
assert all("inv_freq" in k for k in keys.missing_keys)

# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"]
torch.testing.assert_close(ours_y, theirs_y)


@torch.inference_mode()
def test_against_hf_phi_2():
workdir = wd / "tests" / "reference_models"
workdir.mkdir(parents=True, exist_ok=True)
file_paths = [workdir / "original_phi_2.py", workdir / "configuration_phi.py"]
urls = [
"https://huggingface.co/microsoft/phi-2/raw/main/modeling_phi.py",
"https://huggingface.co/microsoft/phi-2/raw/main/configuration_phi.py",
]
for file_path, url in zip(file_paths, urls):
if not file_path.is_file():
urlretrieve(url=url, filename=file_path)

from lit_gpt import GPT, Config
from scripts.convert_lit_checkpoint import copy_weights_phi
from tests.reference_models.configuration_phi import PhiConfig
from tests.reference_models.original_phi_2 import PhiForCausalLM

ours_config = Config.from_name(
"phi-2", padded_vocab_size=10000, n_layer=2, n_head=4, n_embd=256, rotary_percentage=0.5
)
T = 5
theirs_config = PhiConfig(
vocab_size=ours_config.padded_vocab_size,
max_position_embeddings=ours_config.block_size,
hidden_size=ours_config.n_embd,
intermediate_size=ours_config.intermediate_size,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
partial_rotary_factor=ours_config.rotary_percentage,
)
theirs_config.vocab_size = ours_config.padded_vocab_size

ours_model = GPT(ours_config)
ours_state_dict = ours_model.state_dict()
40 changes: 20 additions & 20 deletions tests/test_model.py
Expand Up @@ -277,8 +277,8 @@ def test_against_hf_phi_1_5(device, dtype):
workdir.mkdir(parents=True, exist_ok=True)
file_paths = [workdir / "original_phi_1_5.py", workdir / "configuration_phi.py"]
urls = [
"https://huggingface.co/microsoft/phi-1_5/raw/24f9ea14df973a49a0d87c16d04df88d90067468/modeling_phi.py",
"https://huggingface.co/microsoft/phi-1_5/raw/24f9ea14df973a49a0d87c16d04df88d90067468/configuration_phi.py",
"https://huggingface.co/microsoft/phi-1_5/raw/main/modeling_phi.py",
"https://huggingface.co/microsoft/phi-1_5/raw/main/configuration_phi.py",
]
for file_path, url in zip(file_paths, urls):
if not file_path.is_file():
@@ -296,20 +296,20 @@ def test_against_hf_phi_1_5(device, dtype):
)
T = 5
theirs_config = PhiConfig(
n_positions=ours_config.block_size,
n_embd=ours_config.n_embd,
n_head=ours_config.n_head,
n_layer=ours_config.n_layer,
rotary_dim=ours_config.rope_n_elem,
architecture={"block_cls": "parallel", "mixer": {}, "mlp": {"mlp_cls": "mlp"}},
vocab_size=ours_config.padded_vocab_size,
max_position_embeddings=ours_config.block_size,
hidden_size=ours_config.n_embd,
intermediate_size=ours_config.intermediate_size,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
partial_rotary_factor=ours_config.rotary_percentage,
torch_dtype=dtype,
)
theirs_config.vocab_size = ours_config.padded_vocab_size

theirs_model = PhiForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_phi(ours_config, state_dict, theirs_state_dict)
copy_weights_phi(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)

@@ -338,8 +338,8 @@ def test_against_hf_phi_2(device, dtype):
workdir.mkdir(parents=True, exist_ok=True)
file_paths = [workdir / "original_phi_2.py", workdir / "configuration_phi.py"]
urls = [
"https://huggingface.co/microsoft/phi-2/raw/834565c23f9b28b96ccbeabe614dd906b6db551a/modeling_phi.py",
"https://huggingface.co/microsoft/phi-2/raw/834565c23f9b28b96ccbeabe614dd906b6db551a/configuration_phi.py",
"https://huggingface.co/microsoft/phi-2/raw/main/modeling_phi.py",
"https://huggingface.co/microsoft/phi-2/raw/main/configuration_phi.py",
]
for file_path, url in zip(file_paths, urls):
if not file_path.is_file():
@@ -357,20 +357,20 @@ def test_against_hf_phi_2(device, dtype):
)
T = 5
theirs_config = PhiConfig(
n_positions=ours_config.block_size,
n_embd=ours_config.n_embd,
n_head=ours_config.n_head,
n_layer=ours_config.n_layer,
rotary_dim=ours_config.rope_n_elem,
architecture={"block_cls": "parallel", "mixer": {}, "mlp": {"mlp_cls": "mlp"}},
vocab_size=ours_config.padded_vocab_size,
max_position_embeddings=ours_config.block_size,
hidden_size=ours_config.n_embd,
intermediate_size=ours_config.intermediate_size,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
partial_rotary_factor=ours_config.rotary_percentage,
torch_dtype=dtype,
)
theirs_config.vocab_size = ours_config.padded_vocab_size

theirs_model = PhiForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_phi(ours_config, state_dict, theirs_state_dict)
copy_weights_phi(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)

