Commit 581e27f
weight loading works
rasbt committed Apr 23, 2024
1 parent 1012eaf commit 581e27f
Showing 3 changed files with 49 additions and 36 deletions.
14 changes: 7 additions & 7 deletions litgpt/config.py
@@ -54,7 +54,7 @@ class Config:
     shared_attention_norm: bool = False
     norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
     norm_eps: float = 1e-5
-    mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP"
+    mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE", "Phi3MLP"] = "GptNeoxMLP"
     gelu_approximate: str = "none"
     intermediate_size: Optional[int] = None
     rope_condense_ratio: int = 1
@@ -1418,15 +1418,15 @@ def norm_class(self) -> Type:
         name="Phi-3-mini-4k-instruct",
         hf_config=dict(org="microsoft", name="microsoft/Phi-3-mini-4k-instruct"),
         vocab_size=32064,
-        padded_vocab_size=32768,
+        padded_vocab_size=32064,
         block_size=4096,
         n_embd=3072,
         n_layer=32,
-        rotary_percentage=1.0, # Double-check
-        # shared_attention_norm=True, # Double-check
-        bias=False, # Double-check
-        intermediate_size=11008,
-        mlp_class_name="LLaMAMLP", # Double-check
+        rotary_percentage=1.0,
+        bias=False,
+        norm_class_name="RMSNorm",
+        intermediate_size=16384,
+        mlp_class_name="Phi3MLP",
     ),
 ]
 configs.extend(phi)
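Note (not part of the diff): the new values line up with the Hugging Face checkpoint's own config. Below is a minimal cross-check sketch, assuming network access and a transformers version that knows the Phi-3 architecture (older versions need trust_remote_code=True); the factor of two on intermediate_size reflects that gate_up_proj fuses the gate and up projections into a single matrix.

# Quick cross-check of the Phi-3-mini values used above (not part of the commit).
from transformers import AutoConfig

hf_cfg = AutoConfig.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
print(hf_cfg.vocab_size)               # 32064 -> vocab_size / padded_vocab_size
print(hf_cfg.max_position_embeddings)  # 4096  -> block_size
print(hf_cfg.hidden_size)              # 3072  -> n_embd
print(hf_cfg.num_hidden_layers)        # 32    -> n_layer
print(2 * hf_cfg.intermediate_size)    # 16384 -> intermediate_size (gate + up fused)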
14 changes: 14 additions & 0 deletions litgpt/model.py
@@ -298,6 +298,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.proj(x)
 
 
+class Phi3MLP(nn.Module):
+    def __init__(self, config: Config) -> None:
+        super().__init__()
+        self.gate_up_proj = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
+        self.down_proj = nn.Linear(config.intermediate_size//2, config.n_embd, bias=config.bias)
+        self.config = config
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        y = self.gate_up_proj(x)
+        gate, y = y.chunk(2, dim=-1)
+        y = y * torch.nn.functional.silu(gate)
+        return self.down_proj(y)
+
+
 class LLaMAMLP(nn.Module):
     def __init__(self, config: Config) -> None:
         super().__init__()
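Note (not part of the diff): the fused gate_up_proj above is a SwiGLU MLP with the gate and up projections stacked into one weight matrix; the first half of the output (and of the weight rows) acts as the gate, the second half as the up projection. A small standalone sketch with toy sizes illustrating the equivalence:

# Equivalence of the fused Phi3MLP projection to two separate linears (toy sizes;
# Phi-3-mini uses n_embd=3072 and intermediate_size=16384).
import torch
import torch.nn as nn
import torch.nn.functional as F

n_embd, intermediate_size = 8, 16
fused = nn.Linear(n_embd, intermediate_size, bias=False)
x = torch.randn(2, n_embd)

# Fused path, mirroring Phi3MLP.forward
gate, up = fused(x).chunk(2, dim=-1)
y_fused = up * F.silu(gate)

# Split path: the first half of the fused weight rows is the gate projection, the second half the up projection
w_gate, w_up = fused.weight.chunk(2, dim=0)
y_split = F.linear(x, w_up) * F.silu(F.linear(x, w_gate))

print(torch.allclose(y_fused, y_split))  # True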
57 changes: 28 additions & 29 deletions litgpt/scripts/convert_hf_checkpoint.py
@@ -147,30 +147,31 @@ def copy_weights_hf_llama(
                 "model.layers.{}.mlp.down_proj.weight": "transformer.h.{l}.mlp.proj.weight",
             }
         )
+    elif config.mlp_class_name in ("Phi3MLP",):
+        weight_map.update(
+            {
+                "model.layers.{}.mlp.gate_up_proj.weight": "transformer.h.{l}.mlp.gate_up_proj.weight",
+                "model.layers.{}.mlp.down_proj.weight": "transformer.h.{l}.mlp.down_proj.weight",
+            }
+        )
     else:
         raise NotImplementedError
 
-    if "phi-3" in config.name.lower():
-        weight_map = {
-            key.replace("mlp.up_proj.weight", "mlp.gate_up_proj.weight"): value
-            for key, value in weight_map.items()
-        }
-
     for name, param in hf_weights.items():
         if "model.layers" in name:
             from_name, l = layer_template(name, 2)
             e = None
             if "block_sparse_moe.experts" in name:
                 from_name, e = layer_template(from_name, 5)
             qkv = qkv_weights.setdefault(l, [None, None, None])
-            if "q_proj" in name:
+            if "qkv_proj" in name:
+                state_dict[f"transformer.h.{l}.attn.attn.weight"] = load_param(param, f"layer {l} qkv", dtype)
+            elif "q_proj" in name:
                 qkv[0] = param
             elif "k_proj" in name:
                 qkv[1] = param
             elif "v_proj" in name:
                 qkv[2] = param
-            elif "qkv_proj" in name:
-                qkv[:] = param
             to_name = weight_map[from_name]
             if to_name is None:
                 continue
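Aside (not from the repo): the two new weight_map entries follow the same placeholder convention as the rest of the table, with "{}" on the Hugging Face side and "{l}" on the litgpt side both standing for the layer index. A tiny illustration of what they expand to for one layer:

# Expanding the new Phi-3 weight_map entries for an example layer index (illustration only).
weight_map = {
    "model.layers.{}.mlp.gate_up_proj.weight": "transformer.h.{l}.mlp.gate_up_proj.weight",
    "model.layers.{}.mlp.down_proj.weight": "transformer.h.{l}.mlp.down_proj.weight",
}
layer_idx = 7  # hypothetical layer index
for hf_name, lit_name in weight_map.items():
    print(hf_name.format(layer_idx), "->", lit_name.format(l=layer_idx))
# model.layers.7.mlp.gate_up_proj.weight -> transformer.h.7.mlp.gate_up_proj.weight
# model.layers.7.mlp.down_proj.weight -> transformer.h.7.mlp.down_proj.weight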
@@ -186,21 +187,22 @@
         state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]
 
     # convert separate q, k, v matrices into an interleaved qkv
-    for i, (q, k, v) in list(qkv_weights.items()):
-        if q is None or k is None or v is None:
-            # split across different .bin files
-            continue
-        q = load_param(q, f"layer {i} q", dtype)
-        k = load_param(k, f"layer {i} k", dtype)
-        v = load_param(v, f"layer {i} v", dtype)
-        q_per_kv = config.n_head // config.n_query_groups
-        qs = torch.split(q, config.head_size * q_per_kv)
-        ks = torch.split(k, config.head_size)
-        vs = torch.split(v, config.head_size)
-        cycled = [t for group in zip(qs, ks, vs) for t in group]
-        qkv = torch.cat(cycled)
-        state_dict[f"transformer.h.{i}.attn.attn.weight"] = qkv
-        del qkv_weights[i]
+    if "qkv_proj" not in name:
+        for i, (q, k, v) in list(qkv_weights.items()):
+            if q is None or k is None or v is None:
+                # split across different .bin files
+                continue
+            q = load_param(q, f"layer {i} q", dtype)
+            k = load_param(k, f"layer {i} k", dtype)
+            v = load_param(v, f"layer {i} v", dtype)
+            q_per_kv = config.n_head // config.n_query_groups
+            qs = torch.split(q, config.head_size * q_per_kv)
+            ks = torch.split(k, config.head_size)
+            vs = torch.split(v, config.head_size)
+            cycled = [t for group in zip(qs, ks, vs) for t in group]
+            qkv = torch.cat(cycled)
+            state_dict[f"transformer.h.{i}.attn.attn.weight"] = qkv
+            del qkv_weights[i]
 
 
 def copy_weights_phi(
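Aside (not from the repo): the loop above converts separate q_proj/k_proj/v_proj matrices into the interleaved qkv layout litgpt expects, grouping rows as [q..., k, v] per query group rather than [all q, all k, all v]; Phi-3 checkpoints instead ship a single fused qkv_proj tensor, which the new branch earlier in this function copies into transformer.h.{l}.attn.attn.weight directly. A toy sketch of the interleaving with made-up numbers:

# Interleaved qkv layout built from split q, k, v (toy shapes: head_size=1, n_head=n_query_groups=2).
import torch

head_size, n_head, n_query_groups = 1, 2, 2
q = torch.tensor([[1.0], [2.0]])  # one row per attention head
k = torch.tensor([[3.0], [4.0]])
v = torch.tensor([[5.0], [6.0]])

q_per_kv = n_head // n_query_groups
qs = torch.split(q, head_size * q_per_kv)
ks = torch.split(k, head_size)
vs = torch.split(v, head_size)
cycled = [t for group in zip(qs, ks, vs) for t in group]
qkv = torch.cat(cycled)
print(qkv.squeeze(-1).tolist())  # [1.0, 3.0, 5.0, 2.0, 4.0, 6.0] -- q, k, v interleaved per group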
@@ -321,13 +323,10 @@ def convert_hf_checkpoint(
 
     if "falcon" in model_name:
         copy_fn = partial(copy_weights_falcon, model_name)
-    elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"):
+    elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE", "Phi3MLP"):
         # holder to reconstitute the split q, k, v
         qkv_weights = {}
         copy_fn = partial(copy_weights_hf_llama, config, qkv_weights)
-    elif "phi-3" in model_name.lower():
-        qkv_weights = {}
-        copy_fn = partial(copy_weights_hf_llama, config, qkv_weights)
     elif "phi" in model_name:
         # holder to reconstitute the split q, k, v
         qkv_weights = {}
@@ -366,4 +365,4 @@
 if __name__ == "__main__":
     from jsonargparse import CLI
 
-    CLI(convert_hf_checkpoint)
+    CLI(convert_hf_checkpoint)
