
Commit

progress
rasbt committed Apr 23, 2024
1 parent 93f3024 commit 1012eaf
Showing 3 changed files with 17 additions and 3 deletions.
7 changes: 4 additions & 3 deletions litgpt/config.py
@@ -836,7 +836,7 @@ def norm_class(self) -> Type:
copy["name"] = c["name"].format(kind)
copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
configs.append(copy)


###############
# Meta LLaMA 3
@@ -1415,8 +1415,8 @@ def norm_class(self) -> Type:
),
# https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json
dict(
name="phi-3-mini-4k-instruct",
hf_config=dict(org="microsoft", name="Phi-3-mini-4k-instruct"),
name="Phi-3-mini-4k-instruct",
hf_config=dict(org="microsoft", name="microsoft/Phi-3-mini-4k-instruct"),
vocab_size=32064,
padded_vocab_size=32768,
block_size=4096,
@@ -1425,6 +1425,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0, # Double-check
# shared_attention_norm=True, # Double-check
bias=False, # Double-check
intermediate_size=11008,
mlp_class_name="LLaMAMLP", # Double-check
),
]
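As a quick sanity check (a minimal sketch, assuming the usual litgpt name lookup via Config.from_name), the new entry should now resolve by its model name:

from litgpt.config import Config

config = Config.from_name("Phi-3-mini-4k-instruct")
print(config.vocab_size)         # 32064, per the entry above
print(config.intermediate_size)  # 11008, per the entry above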
12 changes: 12 additions & 0 deletions litgpt/scripts/convert_hf_checkpoint.py
@@ -121,6 +121,7 @@ def copy_weights_hf_llama(
"model.layers.{}.self_attn.q_proj.weight": None,
"model.layers.{}.self_attn.k_proj.weight": None,
"model.layers.{}.self_attn.v_proj.weight": None,
"model.layers.{}.self_attn.qkv_proj.weight": None,
"model.layers.{}.self_attn.o_proj.weight": "transformer.h.{l}.attn.proj.weight",
"model.layers.{}.self_attn.rotary_emb.inv_freq": None,
"model.layers.{}.post_attention_layernorm.weight": "transformer.h.{l}.norm_2.weight",
@@ -149,6 +150,12 @@ def copy_weights_hf_llama(
else:
raise NotImplementedError

if "phi-3" in config.name.lower():
weight_map = {
key.replace("mlp.up_proj.weight", "mlp.gate_up_proj.weight"): value
for key, value in weight_map.items()
}

for name, param in hf_weights.items():
if "model.layers" in name:
from_name, l = layer_template(name, 2)
@@ -162,6 +169,8 @@ def copy_weights_hf_llama(
qkv[1] = param
elif "v_proj" in name:
qkv[2] = param
elif "qkv_proj" in name:
qkv[:] = param
to_name = weight_map[from_name]
if to_name is None:
continue
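The new qkv_proj branch above handles Phi-3's fused query/key/value projection. As an illustration of what that fused tensor contains (a sketch of the layout, not the code path added in this commit), it could be split back into separate q, k, and v weights; equal q/k/v widths and the Phi-3-mini embedding size of 3072 are assumptions here:

import torch

def split_fused_qkv(qkv_weight: torch.Tensor, n_embd: int = 3072):
    # qkv_weight has shape (3 * n_embd, n_embd): q, k, and v stacked along dim 0
    q, k, v = torch.split(qkv_weight, [n_embd, n_embd, n_embd], dim=0)
    return q, k, v

q, k, v = split_fused_qkv(torch.empty(3 * 3072, 3072))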
@@ -316,6 +325,9 @@ def convert_hf_checkpoint(
# holder to reconstitute the split q, k, v
qkv_weights = {}
copy_fn = partial(copy_weights_hf_llama, config, qkv_weights)
elif "phi-3" in model_name.lower():
qkv_weights = {}
copy_fn = partial(copy_weights_hf_llama, config, qkv_weights)
elif "phi" in model_name:
# holder to reconstitute the split q, k, v
qkv_weights = {}
1 change: 1 addition & 0 deletions tutorials/download_model_weights.md
@@ -112,6 +112,7 @@ meta-llama/Meta-Llama-3-8B
meta-llama/Meta-Llama-3-8B-Instruct
microsoft/phi-1_5
microsoft/phi-2
microsoft/Phi-3-mini-4k-instruct
mistralai/Mistral-7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2
mistralai/Mistral-7B-v0.1
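With the model id listed, the weights could be fetched and run through the updated conversion script. A rough end-to-end sketch; the convert_hf_checkpoint keyword arguments shown (checkpoint_dir, model_name) are assumptions and are not taken from this diff:

from pathlib import Path

from huggingface_hub import snapshot_download
from litgpt.scripts.convert_hf_checkpoint import convert_hf_checkpoint

checkpoint_dir = Path("checkpoints/microsoft/Phi-3-mini-4k-instruct")
snapshot_download(repo_id="microsoft/Phi-3-mini-4k-instruct", local_dir=checkpoint_dir)
convert_hf_checkpoint(checkpoint_dir=checkpoint_dir, model_name="Phi-3-mini-4k-instruct")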
