diff --git a/litgpt/config.py b/litgpt/config.py
index 459ca4d560..c302bd8847 100644
--- a/litgpt/config.py
+++ b/litgpt/config.py
@@ -836,7 +836,7 @@ def norm_class(self) -> Type:
         copy["name"] = c["name"].format(kind)
         copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
         configs.append(copy)
-
+
 
 ###############
 # Meta LLaMA 3
@@ -1415,8 +1415,8 @@ def norm_class(self) -> Type:
     ),
     # https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json
    dict(
-        name="phi-3-mini-4k-instruct",
-        hf_config=dict(org="microsoft", name="Phi-3-mini-4k-instruct"),
+        name="Phi-3-mini-4k-instruct",
+        hf_config=dict(org="microsoft", name="Phi-3-mini-4k-instruct"),
         vocab_size=32064,
         padded_vocab_size=32768,
         block_size=4096,
@@ -1425,6 +1425,7 @@ def norm_class(self) -> Type:
         rotary_percentage=1.0,  # Double-check
         # shared_attention_norm=True,  # Double-check
         bias=False,  # Double-check
+        intermediate_size=8192,
         mlp_class_name="LLaMAMLP",  # Double-check
     ),
 ]
diff --git a/litgpt/scripts/convert_hf_checkpoint.py b/litgpt/scripts/convert_hf_checkpoint.py
index 9fdb337ee4..9ff6209978 100644
--- a/litgpt/scripts/convert_hf_checkpoint.py
+++ b/litgpt/scripts/convert_hf_checkpoint.py
@@ -121,6 +121,7 @@ def copy_weights_hf_llama(
         "model.layers.{}.self_attn.q_proj.weight": None,
         "model.layers.{}.self_attn.k_proj.weight": None,
         "model.layers.{}.self_attn.v_proj.weight": None,
+        "model.layers.{}.self_attn.qkv_proj.weight": None,
         "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{l}.attn.proj.weight",
         "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
         "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{l}.norm_2.weight",
@@ -149,6 +150,12 @@ def copy_weights_hf_llama(
     else:
         raise NotImplementedError
 
+    if "phi-3" in config.name.lower():  # Phi-3 fuses gate_proj and up_proj into gate_up_proj
+        weight_map = {
+            key.replace("mlp.up_proj.weight", "mlp.gate_up_proj.weight"): value
+            for key, value in weight_map.items()
+        }
+
     for name, param in hf_weights.items():
         if "model.layers" in name:
             from_name, l = layer_template(name, 2)
@@ -162,6 +169,8 @@ def copy_weights_hf_llama(
                 qkv[1] = param
             elif "v_proj" in name:
                 qkv[2] = param
+            elif "qkv_proj" in name:
+                qkv[0], qkv[1], qkv[2] = param.chunk(3)  # fused [q; k; v]; equal thirds for Phi-3-mini
             to_name = weight_map[from_name]
             if to_name is None:
                 continue
@@ -316,6 +325,9 @@ def convert_hf_checkpoint(
         # holder to reconstitute the split q, k, v
         qkv_weights = {}
         copy_fn = partial(copy_weights_hf_llama, config, qkv_weights)
+    elif "phi-3" in model_name.lower():
+        qkv_weights = {}
+        copy_fn = partial(copy_weights_hf_llama, config, qkv_weights)
     elif "phi" in model_name:
         # holder to reconstitute the split q, k, v
         qkv_weights = {}
diff --git a/tutorials/download_model_weights.md b/tutorials/download_model_weights.md
index 45c9c7d50c..f5d33d0b60 100644
--- a/tutorials/download_model_weights.md
+++ b/tutorials/download_model_weights.md
@@ -112,6 +112,7 @@ meta-llama/Meta-Llama-3-8B
 meta-llama/Meta-Llama-3-8B-Instruct
 microsoft/phi-1_5
 microsoft/phi-2
+microsoft/Phi-3-mini-4k-instruct
 mistralai/Mistral-7B-Instruct-v0.1
 mistralai/Mistral-7B-Instruct-v0.2
 mistralai/Mistral-7B-v0.1
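For reference, the fused layouts the conversion above has to undo: the HF Phi-3 modeling code stacks the attention projections as [q; k; v] in qkv_proj and the MLP projections as [gate; up] in gate_up_proj. Below is a minimal sketch (not part of the patch) of how those fused tensors split back into the per-projection weights litgpt expects; the shapes assume Phi-3-mini-4k-instruct (n_embd=3072, 32 query heads, 32 kv heads, intermediate_size=8192).

# Sketch only, not part of the patch; shapes assume Phi-3-mini-4k-instruct.
import torch

n_embd, intermediate_size = 3072, 8192

# HF key: model.layers.N.self_attn.qkv_proj.weight, stacked as [q; k; v]
qkv_proj = torch.empty(3 * n_embd, n_embd)
q, k, v = qkv_proj.chunk(3)  # split along dim 0 into the three projections

# HF key: model.layers.N.mlp.gate_up_proj.weight, stacked as [gate; up]
gate_up_proj = torch.empty(2 * intermediate_size, n_embd)
gate, up = gate_up_proj.chunk(2)  # gate -> mlp.fc_1.weight, up -> mlp.fc_2.weight in LLaMAMLP

assert q.shape == (n_embd, n_embd) and gate.shape == (intermediate_size, n_embd)

Note that a fixed chunk(3) works here only because Phi-3-mini uses full multi-head attention; for a grouped-query model the three chunks would be unequal (q has n_head * head_size rows, while k and v have n_query_groups * head_size rows each).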