diff --git a/litgpt/config.py b/litgpt/config.py
index c302bd8847..cb8a1de4d8 100644
--- a/litgpt/config.py
+++ b/litgpt/config.py
@@ -54,7 +54,7 @@ class Config:
     shared_attention_norm: bool = False
     norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
     norm_eps: float = 1e-5
-    mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP"
+    mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE", "Phi3MLP"] = "GptNeoxMLP"
     gelu_approximate: str = "none"
     intermediate_size: Optional[int] = None
     rope_condense_ratio: int = 1
@@ -1418,15 +1418,15 @@ def norm_class(self) -> Type:
         name="Phi-3-mini-4k-instruct",
         hf_config=dict(org="microsoft", name="microsoft/Phi-3-mini-4k-instruct"),
         vocab_size=32064,
-        padded_vocab_size=32768,
+        padded_vocab_size=32064,
         block_size=4096,
         n_embd=3072,
         n_layer=32,
-        rotary_percentage=1.0, # Double-check
-        # shared_attention_norm=True, # Double-check
-        bias=False, # Double-check
-        intermediate_size=11008,
-        mlp_class_name="LLaMAMLP", # Double-check
+        rotary_percentage=1.0,
+        bias=False,
+        norm_class_name="RMSNorm",
+        intermediate_size=16384,
+        mlp_class_name="Phi3MLP",
     ),
 ]
 configs.extend(phi)
diff --git a/litgpt/model.py b/litgpt/model.py
index fe71c60b80..515ac18ce5 100644
--- a/litgpt/model.py
+++ b/litgpt/model.py
@@ -298,6 +298,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.proj(x)
 
 
+class Phi3MLP(nn.Module):
+    def __init__(self, config: Config) -> None:
+        super().__init__()
+        self.gate_up_proj = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
+        self.down_proj = nn.Linear(config.intermediate_size//2, config.n_embd, bias=config.bias)
+        self.config = config
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        y = self.gate_up_proj(x)
+        gate, y = y.chunk(2, dim=-1)
+        y = y * torch.nn.functional.silu(gate)
+        return self.down_proj(y)
+
+
 class LLaMAMLP(nn.Module):
     def __init__(self, config: Config) -> None:
         super().__init__()
diff --git a/litgpt/scripts/convert_hf_checkpoint.py b/litgpt/scripts/convert_hf_checkpoint.py
index 9ff6209978..61182bbe39 100644
--- a/litgpt/scripts/convert_hf_checkpoint.py
+++ b/litgpt/scripts/convert_hf_checkpoint.py
@@ -147,15 +147,16 @@ def copy_weights_hf_llama(
                 "model.layers.{}.mlp.down_proj.weight": "transformer.h.{l}.mlp.proj.weight",
             }
         )
+    elif config.mlp_class_name in ("Phi3MLP",):
+        weight_map.update(
+            {
+                "model.layers.{}.mlp.gate_up_proj.weight": "transformer.h.{l}.mlp.gate_up_proj.weight",
+                "model.layers.{}.mlp.down_proj.weight": "transformer.h.{l}.mlp.down_proj.weight",
+            }
+        )
     else:
         raise NotImplementedError
 
-    if "phi-3" in config.name.lower():
-        weight_map = {
-            key.replace("mlp.up_proj.weight", "mlp.gate_up_proj.weight"): value
-            for key, value in weight_map.items()
-        }
-
     for name, param in hf_weights.items():
         if "model.layers" in name:
             from_name, l = layer_template(name, 2)
@@ -163,14 +164,14 @@ def copy_weights_hf_llama(
             if "block_sparse_moe.experts" in name:
                 from_name, e = layer_template(from_name, 5)
             qkv = qkv_weights.setdefault(l, [None, None, None])
-            if "q_proj" in name:
+            if "qkv_proj" in name:
+                state_dict[f"transformer.h.{l}.attn.attn.weight"] = load_param(param, f"layer {l} qkv", dtype)
+            elif "q_proj" in name:
                 qkv[0] = param
             elif "k_proj" in name:
                 qkv[1] = param
             elif "v_proj" in name:
                 qkv[2] = param
-            elif "qkv_proj" in name:
-                qkv[:] = param
             to_name = weight_map[from_name]
             if to_name is None:
                 continue
@@ -186,21 +187,22 @@ def copy_weights_hf_llama(
         state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]
 
     # convert separate q, k, v matrices into an interleaved qkv
-    for i, (q, k, v) in list(qkv_weights.items()):
-        if q is None or k is None or v is None:
-            # split across different .bin files
-            continue
-        q = load_param(q, f"layer {i} q", dtype)
-        k = load_param(k, f"layer {i} k", dtype)
-        v = load_param(v, f"layer {i} v", dtype)
-        q_per_kv = config.n_head // config.n_query_groups
-        qs = torch.split(q, config.head_size * q_per_kv)
-        ks = torch.split(k, config.head_size)
-        vs = torch.split(v, config.head_size)
-        cycled = [t for group in zip(qs, ks, vs) for t in group]
-        qkv = torch.cat(cycled)
-        state_dict[f"transformer.h.{i}.attn.attn.weight"] = qkv
-        del qkv_weights[i]
+    if "qkv_proj" not in name:
+        for i, (q, k, v) in list(qkv_weights.items()):
+            if q is None or k is None or v is None:
+                # split across different .bin files
+                continue
+            q = load_param(q, f"layer {i} q", dtype)
+            k = load_param(k, f"layer {i} k", dtype)
+            v = load_param(v, f"layer {i} v", dtype)
+            q_per_kv = config.n_head // config.n_query_groups
+            qs = torch.split(q, config.head_size * q_per_kv)
+            ks = torch.split(k, config.head_size)
+            vs = torch.split(v, config.head_size)
+            cycled = [t for group in zip(qs, ks, vs) for t in group]
+            qkv = torch.cat(cycled)
+            state_dict[f"transformer.h.{i}.attn.attn.weight"] = qkv
+            del qkv_weights[i]
 
 
 def copy_weights_phi(
@@ -321,13 +323,10 @@ def convert_hf_checkpoint(
 
     if "falcon" in model_name:
         copy_fn = partial(copy_weights_falcon, model_name)
-    elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"):
+    elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE", "Phi3MLP"):
         # holder to reconstitute the split q, k, v
         qkv_weights = {}
         copy_fn = partial(copy_weights_hf_llama, config, qkv_weights)
-    elif "phi-3" in model_name.lower():
-        qkv_weights = {}
-        copy_fn = partial(copy_weights_hf_llama, config, qkv_weights)
     elif "phi" in model_name:
         # holder to reconstitute the split q, k, v
         qkv_weights = {}
@@ -366,4 +365,4 @@ def convert_hf_checkpoint(
 if __name__ == "__main__":
     from jsonargparse import CLI
 
-    CLI(convert_hf_checkpoint)
+    CLI(convert_hf_checkpoint)
\ No newline at end of file