Skip to content

Commit

Permalink
fix styles
Browse files (browse the repository at this point in the history)
  • Loading branch information
lzivan committed Dec 18, 2024
1 parent bbd466d commit a7b80ad
Showing 1 changed file with 8 additions and 7 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -176,10 +176,10 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
asym = model.lm_head_0.qtype == "asym_int4_rtn"
if asym:
weights = [(model.lm_head_0.weight, model.lm_head_0.scale, model.lm_head_0.zero),
- (model.lm_head_1.weight, model.lm_head_1.scale, model.lm_head_1.zero)]
+ (model.lm_head_1.weight, model.lm_head_1.scale, model.lm_head_1.zero)]
else:
weights = [(model.lm_head_0.weight, model.lm_head_0.scale),
- (model.lm_head_1.weight, model.lm_head_1.scale)]
+ (model.lm_head_1.weight, model.lm_head_1.scale)]
else:
# for MiniCPM-1B-sft-bf16
asym = model.lm_head.qtype == "asym_int4_rtn"
Expand All @@ -206,8 +206,8 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
zeros.append(l.zero)
if len(zeros):
weights.append((torch.stack(lm_head_weights, axis=0),
- torch.stack(scales, axis=0),
- torch.stack(zeros, axis=0)))
+ torch.stack(scales, axis=0),
+ torch.stack(zeros, axis=0)))
else:
weights.append((torch.stack(lm_head_weights, axis=0),
torch.stack(scales, axis=0)))
Expand Down Expand Up @@ -249,9 +249,9 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
model.lm_head_1.zero.data.numpy(), ]
else:
if not asym:
- weight_numpy = [model.lm_head.weight.data.numpy(), model.lm_head.scale.data.numpy(), ]
+ weight_numpy = [model.lm_head.weight.data.numpy(), model.lm_head.scale.data.numpy()]
else:
- weight_numpy = [model.lm_head.weight.data.numpy(), model.lm_head.scale.data.numpy(),
+ weight_numpy = [model.lm_head.weight.data.numpy(), model.lm_head.scale.data.numpy(),
model.lm_head.zero.data.numpy()]
else:
weight_numpy = [v.numpy() for v in weights[0]]
Expand Down Expand Up @@ -469,7 +469,8 @@ def convert_fused_minicpm_layer(model, fused_layers, n_splits_linear, n_splits_d
# 6, 7 are past k/v
if not asym:
for idx, (weight, scale) in enumerate(weights):
- bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
+ bin_file = os.path.join(weight_dir,
+ f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
weight.numpy().tofile(bin_file)
bin_file = os.path.join(weight_dir,
f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
Expand Down

0 comments on commit a7b80ad

Please sign in to comment.