[NPU] support asym_int4 for minicpm (#12567)
lzivan authored Dec 18, 2024
1 parent 6e801bc commit 1a2ab12
Showing 2 changed files with 146 additions and 40 deletions.
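The commit threads an `asym` flag from the model config down through the NPU decoder builders so that asymmetrically quantized int4 (asym_int4) weights, which carry a per-group zero point in addition to a per-group scale, can be packed and forwarded alongside the existing symmetric path. A minimal sketch of the difference under one common asym_int4 convention; the exact kernel math is not part of this diff, and all names below are illustrative only:

```python
import torch

def dequant_sym_int4(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Symmetric int4: signed values in [-8, 7], reconstructed as q * scale.
    return q.to(torch.float32) * scale

def dequant_asym_int4(q: torch.Tensor, scale: torch.Tensor, zero: torch.Tensor) -> torch.Tensor:
    # Asymmetric int4: unsigned values in [0, 15]; a per-group zero point shifts
    # the range before scaling, which is why asym layers ship a third tensor.
    return (q.to(torch.float32) - zero) * scale

# Toy per-group example (group size 4, values made up):
q = torch.tensor([[3, 7, 0, 15]], dtype=torch.uint8)
scale = torch.tensor([[0.1]])
zero = torch.tensor([[8.0]])
print(dequant_sym_int4(q, scale))         # sym path: (weight, scale)
print(dequant_asym_int4(q, scale, zero))  # asym path: (weight, scale, zero)
```

In the diff below this maps to two-element (weight, scale) tuples for the symmetric case and (weight, scale, zero) triples for the asymmetric case.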
python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py (48 changes: 37 additions & 11 deletions)
@@ -81,7 +81,8 @@ def __init__(
         num_hidden_layers,
         n_splits_linear: int = 1,
         n_splits_down_proj: int = 1,
-        group_size: int = 0
+        group_size: int = 0,
+        asym: bool = False,
     ):
         super().__init__(max_seq_len=max_seq_len,
                          transpose_value=transpose_value,
@@ -90,7 +91,8 @@ def __init__(
                          device=device,
                          n_splits_linear=n_splits_linear,
                          n_splits_down_proj=n_splits_down_proj,
-                         group_size=group_size)
+                         group_size=group_size,
+                         asym=asym)
         self.max_seq_len = max_seq_len
         self.intermediate_size = intermediate_size
         self.dtype = dtype
@@ -272,16 +274,19 @@ def __init__(
         do_print: bool = False,
         n_splits_linear: int = 1,
         n_splits_down_proj: int = 1,
-        group_size: int = 0
+        group_size: int = 0,
+        asym: bool = False,
     ):
         super().__init__()

         self.do_print = do_print

         op_parameters = []
         for w in parameters:
-            if isinstance(w, tuple):  # from QuantizedLinear
+            if isinstance(w, tuple) and not asym:  # from QuantizedLinear
                 op_parameters.append((w[0].numpy(), w[1].numpy()))
+            elif isinstance(w, tuple) and asym:  # from QuantizedLinear
+                op_parameters.append((w[0].numpy(), w[1].numpy(), w[2].numpy()))
             elif w.dtype in [torch.int8, torch.uint8]:  # QuantizedLinear weight
                 op_parameters.append(w.numpy())
             elif isinstance(w, np.ndarray):  # scale
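A small standalone illustration of the packing branch above: symmetric QuantizedLinear splits arrive as (weight, scale) tuples, asymmetric ones as (weight, scale, zero), and both are converted to numpy before being handed to the NPU backend. The helper and tensors here are made-up placeholders, not the real ipex-llm classes or weights:

```python
import numpy as np
import torch

def pack_parameters(parameters, asym: bool):
    # Mirrors the constructor logic sketched in the diff: tuples come from
    # QuantizedLinear layers, raw int8/uint8 tensors are quantized weights,
    # and plain numpy arrays (e.g. scales) pass through unchanged.
    op_parameters = []
    for w in parameters:
        if isinstance(w, tuple) and not asym:       # (weight, scale)
            op_parameters.append((w[0].numpy(), w[1].numpy()))
        elif isinstance(w, tuple) and asym:          # (weight, scale, zero)
            op_parameters.append((w[0].numpy(), w[1].numpy(), w[2].numpy()))
        elif w.dtype in [torch.int8, torch.uint8]:   # quantized weight tensor
            op_parameters.append(w.numpy())
        elif isinstance(w, np.ndarray):              # scale
            op_parameters.append(w)
    return op_parameters

# Asym case: each packed entry carries a zero-point array as its third element.
fake = [(torch.randint(0, 16, (8, 4), dtype=torch.uint8),
         torch.rand(8, 1), torch.rand(8, 1))]
print(len(pack_parameters(fake, asym=True)[0]))  # -> 3
```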
@@ -336,7 +341,8 @@ def __init__(
                 dtype=np_dtype,
                 n_splits_linear=n_splits_linear,
                 n_splits_down_proj=n_splits_down_proj,
-                group_size=group_size
+                group_size=group_size,
+                asym=asym,
             )
             self.backend_decoders.append(decoder)

@@ -414,7 +420,8 @@ def __init__(
         transpose_value: bool = False,
         n_splits_linear: int = 1,
         n_splits_down_proj: int = 1,
-        group_size: int = 0
+        group_size: int = 0,
+        asym: bool = False,
     ):
         super().__init__()
         self.op_parameters = parameters
@@ -447,7 +454,8 @@ def __init__(
             dtype=np_dtype,
             n_splits_linear=n_splits_linear,
             n_splits_down_proj=n_splits_down_proj,
-            group_size=group_size
+            group_size=group_size,
+            asym=asym,
         )
         self.layer_norm_0 = layer_norm_0
         self.layer_norm_1 = layer_norm_1
@@ -534,6 +542,7 @@ def run_decode(
     layer_indexs = range(layer_start, layer_end)
     n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
     n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
+    asym = getattr(model.config, "asym", False)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
@@ -546,10 +555,17 @@
                            mlp_layer.down_proj_dq_list]:
             l_weights = []
             scales = []
+            zeros = []
             for l in layer_list:
                 l_weights.append(l.weight)
                 scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+                if l.zero is not None:
+                    zeros.append(l.zero)
+            if len(zeros):
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
+                                torch.stack(zeros, axis=0)))
+            else:
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
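Both run_decode and run_prefill gather the per-split quantized weights the same way: stack weights and scales across the split DQ linears, and append a stacked zero-point tensor only when the splits actually carry one. A standalone sketch of that gathering step, using a hypothetical stand-in for the DQ linear objects (not the ipex-llm classes):

```python
import torch

class FakeDQLinear:
    """Hypothetical stand-in for one quantized linear split."""
    def __init__(self, weight, scale, zero=None):
        self.weight, self.scale, self.zero = weight, scale, zero

def gather_split_weights(layer_list):
    # Stack weights/scales (and zeros, if present) across the n_splits splits.
    l_weights, scales, zeros = [], [], []
    for l in layer_list:
        l_weights.append(l.weight)
        scales.append(l.scale)
        if l.zero is not None:
            zeros.append(l.zero)
    if zeros:  # asym_int4: include the stacked zero points as a third element
        return (torch.stack(l_weights, dim=0), torch.stack(scales, dim=0),
                torch.stack(zeros, dim=0))
    return (torch.stack(l_weights, dim=0), torch.stack(scales, dim=0))

# Usage: two splits of one projection, asym case (shapes are made up).
splits = [FakeDQLinear(torch.randint(0, 16, (8, 4), dtype=torch.uint8),
                       torch.rand(8, 1), torch.rand(8, 1)) for _ in range(2)]
packed = gather_split_weights(splits)
print(len(packed))  # -> 3, i.e. (weights, scales, zeros)
```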
@@ -580,7 +596,8 @@ def run_decode(
         do_print=False,
         n_splits_linear=n_splits_linear,
         n_splits_down_proj=n_splits_down_proj,
-        group_size=group_size
+        group_size=group_size,
+        asym=asym,
     )

     dist.barrier()
@@ -753,6 +770,7 @@ def run_prefill(
     layer_indexs = range(layer_start, layer_end)
     n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
     n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
+    asym = getattr(model.config, "asym", False)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
@@ -765,10 +783,17 @@
                            mlp_layer.down_proj_dq_list]:
             l_weights = []
             scales = []
+            zeros = []
             for l in layer_list:
                 l_weights.append(l.weight)
                 scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+                if l.zero is not None:
+                    zeros.append(l.zero)
+            if len(zeros):
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
+                                torch.stack(zeros, axis=0)))
+            else:
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
@@ -793,7 +818,8 @@ def run_prefill(
             transpose_value=transpose_value_cache,
             n_splits_linear=n_splits_linear,
             n_splits_down_proj=n_splits_down_proj,
-            group_size=group_size
+            group_size=group_size,
+            asym=asym
         )

         layer_weights.extend(weights)