From 1a2ab128763016c038c73db3a18e2dd32d501020 Mon Sep 17 00:00:00 2001
From: Zijie Li
Date: Tue, 17 Dec 2024 21:55:35 -0500
Subject: [PATCH] [NPU] support asym_int4 for minicpm (#12567)

---
 .../transformers/npu_models/minicpm_mp.py    |  48 ++++--
 .../npu_pipeline_model/minicpm.py            | 138 ++++++++++++++----
 2 files changed, 146 insertions(+), 40 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py
index bc0df95111e..2a34ae19547 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py
@@ -81,7 +81,8 @@ def __init__(
         num_hidden_layers,
         n_splits_linear: int = 1,
         n_splits_down_proj: int = 1,
-        group_size: int = 0
+        group_size: int = 0,
+        asym: bool = False,
     ):
         super().__init__(max_seq_len=max_seq_len,
                          transpose_value=transpose_value,
@@ -90,7 +91,8 @@ def __init__(
                          device=device,
                          n_splits_linear=n_splits_linear,
                          n_splits_down_proj=n_splits_down_proj,
-                         group_size=group_size)
+                         group_size=group_size,
+                         asym=asym)
         self.max_seq_len = max_seq_len
         self.intermediate_size = intermediate_size
         self.dtype = dtype
@@ -272,7 +274,8 @@ def __init__(
         do_print: bool = False,
         n_splits_linear: int = 1,
         n_splits_down_proj: int = 1,
-        group_size: int = 0
+        group_size: int = 0,
+        asym: bool = False,
     ):
         super().__init__()

@@ -280,8 +283,10 @@ def __init__(

         op_parameters = []
         for w in parameters:
-            if isinstance(w, tuple):  # from QuantizedLinear
+            if isinstance(w, tuple) and not asym:  # from QuantizedLinear
                 op_parameters.append((w[0].numpy(), w[1].numpy()))
+            elif isinstance(w, tuple) and asym:  # from QuantizedLinear
+                op_parameters.append((w[0].numpy(), w[1].numpy(), w[2].numpy()))
             elif w.dtype in [torch.int8, torch.uint8]:  # QuantizedLinear weight
                 op_parameters.append(w.numpy())
             elif isinstance(w, np.ndarray):  # scale
@@ -336,7 +341,8 @@ def __init__(
                 dtype=np_dtype,
                 n_splits_linear=n_splits_linear,
                 n_splits_down_proj=n_splits_down_proj,
-                group_size=group_size
+                group_size=group_size,
+                asym=asym,
             )
             self.backend_decoders.append(decoder)

@@ -414,7 +420,8 @@ def __init__(
         transpose_value: bool = False,
         n_splits_linear: int = 1,
         n_splits_down_proj: int = 1,
-        group_size: int = 0
+        group_size: int = 0,
+        asym: bool = False,
     ):
         super().__init__()
         self.op_parameters = parameters
@@ -447,7 +454,8 @@ def __init__(
             dtype=np_dtype,
             n_splits_linear=n_splits_linear,
             n_splits_down_proj=n_splits_down_proj,
-            group_size=group_size
+            group_size=group_size,
+            asym=asym,
         )
         self.layer_norm_0 = layer_norm_0
         self.layer_norm_1 = layer_norm_1
@@ -534,6 +542,7 @@ def run_decode(
     layer_indexs = range(layer_start, layer_end)
     n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
     n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
+    asym = getattr(model.config, "asym", False)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
@@ -546,10 +555,17 @@ def run_decode(
                            mlp_layer.down_proj_dq_list]:
             l_weights = []
             scales = []
+            zeros = []
             for l in layer_list:
                 l_weights.append(l.weight)
                 scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+                if l.zero is not None:
+                    zeros.append(l.zero)
+            if len(zeros):
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
+                                torch.stack(zeros, axis=0)))
+            else:
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
@@ -580,7 +596,8 @@ def run_decode(
         do_print=False,
         n_splits_linear=n_splits_linear,
         n_splits_down_proj=n_splits_down_proj,
-        group_size=group_size
+        group_size=group_size,
+        asym=asym,
     )

     dist.barrier()
@@ -753,6 +770,7 @@ def run_prefill(
     layer_indexs = range(layer_start, layer_end)
     n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
     n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
+    asym = getattr(model.config, "asym", False)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
@@ -765,10 +783,17 @@ def run_prefill(
                            mlp_layer.down_proj_dq_list]:
             l_weights = []
             scales = []
+            zeros = []
             for l in layer_list:
                 l_weights.append(l.weight)
                 scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+                if l.zero is not None:
+                    zeros.append(l.zero)
+            if len(zeros):
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
+                                torch.stack(zeros, axis=0)))
+            else:
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
@@ -793,7 +818,8 @@ def run_prefill(
             transpose_value=transpose_value_cache,
             n_splits_linear=n_splits_linear,
             n_splits_down_proj=n_splits_down_proj,
-            group_size=group_size
+            group_size=group_size,
+            asym=asym
         )

         layer_weights.extend(weights)

diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py
index 1893db9c963..9e89584a035 100644
--- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py
@@ -105,6 +105,7 @@ def __init__(
         profile: bool = False,
         device: str = "NPU",
         n_splits: int = 1,
+        asym: bool = False
     ):
         super().__init__(max_seq_len=max_seq_len,
                          transpose_value=transpose_value,
@@ -134,11 +135,13 @@ def __init__(
             # for MiniCPM-2B-sft-bf16
             hidden_states_1 = self.linear(
                 hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype,
-                n_splits=n_splits, scale_factor=(n_splits == 1)
+                n_splits=n_splits, scale_factor=(n_splits == 1),
+                asym=asym
             )
             hidden_states_2 = self.linear(
                 hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype,
-                n_splits=n_splits, scale_factor=(n_splits == 1)
+                n_splits=n_splits, scale_factor=(n_splits == 1),
+                asym=asym
             )

             hidden_states_2 = self.slice(hidden_states_2, begin=[0, 0, 0], end=[1, 1, 49313])
@@ -147,7 +150,8 @@ def __init__(
             # for MiniCPM-1B-sft-bf16
             hidden_states = self.linear(
                 hidden_states, self.vocab_size, self.hidden_size, bias=False,
-                wt_dtype=self.dtype, n_splits=n_splits, scale_factor=(n_splits == 1)
+                wt_dtype=self.dtype, n_splits=n_splits, scale_factor=(n_splits == 1),
+                asym=asym
             )

         # define outputs
@@ -165,28 +169,48 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
     rms_norm_eps = model.config.rms_norm_eps
     vocab_size = model.config.vocab_size
     model_norm = model.model.norm
+    asym = getattr(model.config, "asym", False)
     if n_splits_linear == 1:
         if vocab_size == 122753:
             # for MiniCPM-2B-sft-bf16
-            weights = [(model.lm_head_0.weight, model.lm_head_0.scale),
-                       (model.lm_head_1.weight, model.lm_head_1.scale)]
+            asym = model.lm_head_0.qtype == "asym_int4_rtn"
+            if asym:
+                weights = [(model.lm_head_0.weight, model.lm_head_0.scale, model.lm_head_0.zero),
+                           (model.lm_head_1.weight, model.lm_head_1.scale, model.lm_head_1.zero)]
+            else:
+                weights = [(model.lm_head_0.weight, model.lm_head_0.scale),
+                           (model.lm_head_1.weight, model.lm_head_1.scale)]
         else:
             # for MiniCPM-1B-sft-bf16
-            weights = [(model.lm_head.weight, model.lm_head.scale)]
+            asym = model.lm_head.qtype == "asym_int4_rtn"
+            if asym:
+                weights = [(model.lm_head.weight, model.lm_head.scale, model.lm_head.zero)]
+            else:
+                weights = [(model.lm_head.weight, model.lm_head.scale)]
     else:
         weights = []
         if vocab_size == 122753:
+            asym = model.lm_head_0.lm_heads[0].qtype == "asym_int4_rtn"
             lm_head_list = [model.lm_head_0.lm_heads, model.lm_head_1.lm_heads]
         else:
+            asym = model.lm_head.lm_heads[0].qtype == "asym_int4_rtn"
             lm_head_list = [model.lm_head.lm_heads]
         for lh in lm_head_list:
             lm_head_weights = []
             scales = []
+            zeros = []
             for l in lh:
                 lm_head_weights.append(l.weight)
                 scales.append(l.scale)
-            weights.append((torch.stack(lm_head_weights, axis=0),
-                            torch.stack(scales, axis=0)))
+                if l.zero is not None:
+                    zeros.append(l.zero)
+            if len(zeros):
+                weights.append((torch.stack(lm_head_weights, axis=0),
+                                torch.stack(scales, axis=0),
+                                torch.stack(zeros, axis=0)))
+            else:
+                weights.append((torch.stack(lm_head_weights, axis=0),
+                                torch.stack(scales, axis=0)))
     if isinstance(weights[0], tuple):
         np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
     else:  # FP16 Linear
@@ -202,7 +226,8 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
         dtype=np_dtype,
         model_norm_weight=model_norm.weight.to(torch.float16),
         vocab_size=vocab_size,
-        n_splits=n_splits_linear
+        n_splits=n_splits_linear,
+        asym=asym
     )
     last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir,
                                                         True, True)
@@ -210,12 +235,24 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
     # save weights bins files
     if n_splits_linear == 1:
         if vocab_size == 122753:
-            weight_numpy = [model.lm_head_0.weight.data.numpy(),
-                            model.lm_head_0.scale.data.numpy(),
-                            model.lm_head_1.weight.data.numpy(),
-                            model.lm_head_1.scale.data.numpy(), ]
+            if not asym:
+                weight_numpy = [model.lm_head_0.weight.data.numpy(),
+                                model.lm_head_0.scale.data.numpy(),
+                                model.lm_head_1.weight.data.numpy(),
+                                model.lm_head_1.scale.data.numpy(), ]
+            else:
+                weight_numpy = [model.lm_head_0.weight.data.numpy(),
+                                model.lm_head_0.scale.data.numpy(),
+                                model.lm_head_0.zero.data.numpy(),
+                                model.lm_head_1.weight.data.numpy(),
+                                model.lm_head_1.scale.data.numpy(),
+                                model.lm_head_1.zero.data.numpy(), ]
         else:
-            weight_numpy = [model.lm_head.weight.data.numpy(), model.lm_head.scale.data.numpy(), ]
+            if not asym:
+                weight_numpy = [model.lm_head.weight.data.numpy(), model.lm_head.scale.data.numpy()]
+            else:
+                weight_numpy = [model.lm_head.weight.data.numpy(), model.lm_head.scale.data.numpy(),
+                                model.lm_head.zero.data.numpy()]
     else:
         weight_numpy = [v.numpy() for v in weights[0]]
     if vocab_size == 122753:
@@ -266,6 +303,7 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
     rms_norm_eps = model.config.rms_norm_eps
     num_hidden_layers = model.config.num_hidden_layers
     scale_depth = model.model.config.scale_depth
+    asym = getattr(model.config, "asym", False)
     from ipex_llm.transformers.npu_models.minicpm_mp import LowBitMinicpmMultiDecoderlayer

     curr_layer = model.model.layers[layer_idx]
@@ -279,10 +317,17 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
                        mlp_layer.down_proj_dq_list]:
         l_weights = []
         scales = []
+        zeros = []
         for l in layer_list:
             l_weights.append(l.weight)
             scales.append(l.scale)
-        weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+            if l.zero is not None:
+                zeros.append(l.zero)
+        if len(zeros):
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
+                            torch.stack(zeros, axis=0)))
+        else:
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

     cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
     cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
@@ -321,7 +366,8 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
         dtype=np_dtype,
         n_splits_linear=n_splits_linear,
         n_splits_down_proj=n_splits_down_proj,
-        group_size=group_size
+        group_size=group_size,
+        asym=asym
     )
     rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,
                                                         decoder_name,
@@ -337,11 +383,23 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
     layer_norm_0.data.numpy().tofile(input_lm_bin_file)
     layer_norm_1.data.numpy().tofile(post_lm_bin_file)
     st_idx = 7
-    for idx, (weight, scale) in enumerate(weights):
-        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
-        weight.numpy().tofile(bin_file)
-        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
-        scale.numpy().tofile(bin_file)
+    if not asym:
+        for idx, (weight, scale) in enumerate(weights):
+            bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
+            weight.numpy().tofile(bin_file)
+            bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
+            scale.numpy().tofile(bin_file)
+    else:
+        for idx, (weight, scale, zero) in enumerate(weights):
+            bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*3}.bin")
+            weight.numpy().tofile(bin_file)
+            bin_file = os.path.join(weight_dir,
+                                    f"model_{layer_idx}_input_{st_idx+idx*3+1}.bin")
+            scale.numpy().tofile(bin_file)
+            bin_file = os.path.join(weight_dir,
+                                    f"model_{layer_idx}_input_{st_idx+idx*3+2}.bin")
+            zero.numpy().tofile(bin_file)
+
     del single_decoder
@@ -357,6 +415,7 @@ def convert_fused_minicpm_layer(model, fused_layers, n_splits_linear, n_splits_d
     scale_depth = model.model.config.scale_depth
     layer_num = len(model.model.layers)
     fused_layer_num = layer_num // fused_layers
+    asym = getattr(model.config, "asym", False)
     from ipex_llm.transformers.npu_models.minicpm_mp import LowBitMinicpmMultiDecoderlayer

     for i in range(fused_layers):
@@ -380,10 +439,17 @@ def convert_fused_minicpm_layer(model, fused_layers, n_splits_linear, n_splits_d
                                mlp_layer.down_proj_dq_list]:
                 l_weights = []
                 scales = []
+                zeros = []
                 for l in layer_list:
                     l_weights.append(l.weight)
                     scales.append(l.scale)
-                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+                    if l.zero is not None:
+                        zeros.append(l.zero)
+                if len(zeros):
+                    weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
+                                    torch.stack(zeros, axis=0)))
+                else:
+                    weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

             cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
             cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
@@ -401,12 +467,25 @@ def convert_fused_minicpm_layer(model, fused_layers, n_splits_linear, n_splits_d
             layer_norm_1.data.numpy().tofile(post_lm_bin_file)
             st_idx = 5
             # 6, 7 are past k/v
-            for idx, (weight, scale) in enumerate(weights):
-                bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
-                weight.numpy().tofile(bin_file)
-                bin_file = os.path.join(weight_dir,
-                                        f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
-                scale.numpy().tofile(bin_file)
+            if not asym:
+                for idx, (weight, scale) in enumerate(weights):
+                    bin_file = os.path.join(weight_dir,
+                                            f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
+                    weight.numpy().tofile(bin_file)
+                    bin_file = os.path.join(weight_dir,
+                                            f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
+                    scale.numpy().tofile(bin_file)
+            else:
+                for idx, (weight, scale, zero) in enumerate(weights):
+                    bin_file = os.path.join(weight_dir,
+                                            f"model_{layer_idx}_input_{st_idx+idx*3}.bin")
+                    weight.numpy().tofile(bin_file)
+                    bin_file = os.path.join(weight_dir,
+                                            f"model_{layer_idx}_input_{st_idx+idx*3+1}.bin")
+                    scale.numpy().tofile(bin_file)
+                    bin_file = os.path.join(weight_dir,
+                                            f"model_{layer_idx}_input_{st_idx+idx*3+2}.bin")
+                    zero.numpy().tofile(bin_file)

         if isinstance(weights[0], tuple):
             np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
@@ -432,7 +511,8 @@ def convert_fused_minicpm_layer(model, fused_layers, n_splits_linear, n_splits_d
             dtype=np_dtype,
             n_splits_linear=n_splits_linear,
             n_splits_down_proj=n_splits_down_proj,
-            group_size=group_size
+            group_size=group_size,
+            asym=asym
         )

         update_names_of_IR_and_export_blob(fused_decoder, f"decoder_layer_{i}",
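
Illustrative sketch (not part of the patch above; the helper names pack_split_weights and save_layer_bins are hypothetical and do not exist in ipex-llm): both files follow one convention for asym_int4. Every QuantizedLinear split contributes a weight and a scale, plus a zero point when the layer was quantized asymmetrically, so the stacked tuple grows from (weight, scale) to (weight, scale, zero) and the per-layer bin files move from a stride of 2 to a stride of 3 starting at st_idx. A minimal standalone rendering of that convention, assuming split objects expose .weight, .scale and .zero as the QuantizedLinear splits in the patch do:

import os
from types import SimpleNamespace

import torch


def pack_split_weights(splits, asym=False):
    # Mirrors the stacking loops in the patch: collect each split's weight and
    # scale, and also its zero point when asym quantization produced one.
    l_weights, scales, zeros = [], [], []
    for l in splits:
        l_weights.append(l.weight)
        scales.append(l.scale)
        if asym and getattr(l, "zero", None) is not None:
            zeros.append(l.zero)
    if zeros:
        return (torch.stack(l_weights, dim=0), torch.stack(scales, dim=0),
                torch.stack(zeros, dim=0))
    return torch.stack(l_weights, dim=0), torch.stack(scales, dim=0)


def save_layer_bins(weights, weight_dir, layer_idx, st_idx):
    # Mirrors the bin-file layout in the patch: sym tuples (weight, scale) take
    # a stride of 2, asym tuples (weight, scale, zero) a stride of 3, starting
    # at st_idx for the given layer.
    for idx, tensors in enumerate(weights):
        stride = len(tensors)  # 2 for sym_int4, 3 for asym_int4
        for offset, t in enumerate(tensors):
            bin_file = os.path.join(
                weight_dir, f"model_{layer_idx}_input_{st_idx + idx * stride + offset}.bin")
            t.numpy().tofile(bin_file)


if __name__ == "__main__":
    # Two fake 2-way splits of an asym-quantized linear, for shape checking only.
    splits = [SimpleNamespace(weight=torch.zeros(64, 32, dtype=torch.uint8),
                              scale=torch.ones(64, dtype=torch.float16),
                              zero=torch.zeros(64, dtype=torch.float16))
              for _ in range(2)]
    packed = pack_split_weights(splits, asym=True)
    print([t.shape for t in packed])  # three stacked tensors: weight, scale, zero

With that layout, a consumer that knows only asym vs. sym can recover every tensor deterministically from the bin-file index, which is why the patch threads a single asym flag through the decoder, LM-head, and conversion paths.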