From 76bebc0d6a26af85d1c985ad338a2b783a24c03f Mon Sep 17 00:00:00 2001 From: rnwang04 Date: Thu, 19 Dec 2024 16:19:14 +0800 Subject: [PATCH 1/4] be compatible with auto-round --- .../npu_pipeline_model/convert_pipeline.py | 7 +- .../transformers/npu_pipeline_model/qwen.py | 119 +++++++++++------- 2 files changed, 81 insertions(+), 45 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py index 41c35b095e6..42c15e104c1 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py @@ -205,9 +205,10 @@ def convert_llm(model: torch.nn.Module, # do not split mlp down_proj for Qwen2-7B & sym_int8 n_splits_down_proj = 1 else: - n_splits_down_proj = 2 if (model.config.intermediate_size == 18944 or - os.environ.get("IPEX_LLM_NPU_MTL", "0") == "1" or - os.environ.get("IPEX_LLM_NPU_ARL", "0") == "1") else 1 + # n_splits_down_proj = 2 if (model.config.intermediate_size == 18944 or + # os.environ.get("IPEX_LLM_NPU_MTL", "0") == "1" or + # os.environ.get("IPEX_LLM_NPU_ARL", "0") == "1") else 1 + n_splits_down_proj = 1 # for auto-round test else: n_splits_linear = model.config.hidden_size // group_size n_splits_down_proj = model.config.intermediate_size // group_size diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py index bb8003f06a7..581a5eb75e3 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py @@ -134,27 +134,42 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj, mlp_layer = curr_layer.mlp weights = [] - for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, - attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, - mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list, - mlp_layer.down_proj_dq_list]: - l_weights = [] - scales = [] - zeros = [] - for l in layer_list: - l_weights.append(l.weight) - scales.append(l.scale) - if l.zero is not None: - zeros.append(l.zero) - if len(zeros): - weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0), - torch.stack(zeros, axis=0))) - else: - weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) + if hasattr(attn_layer, "q_proj_dq_list"): + for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, + attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, + mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list, + mlp_layer.down_proj_dq_list]: + l_weights = [] + scales = [] + zeros = [] + for l in layer_list: + l_weights.append(l.weight) + scales.append(l.scale) + if l.zero is not None: + zeros.append(l.zero) + if len(zeros): + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0), + torch.stack(zeros, axis=0))) + else: + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) + else: + for layer in [attn_layer.q_proj, attn_layer.k_proj, + attn_layer.v_proj, attn_layer.o_proj, + mlp_layer.gate_proj, mlp_layer.up_proj, + mlp_layer.down_proj]: + if layer.zero is not None: + weights.append((layer.weight, layer.scale, layer.zero)) + else: + weights.append((layer.weight, layer.scale)) - q_bias = attn_layer.q_proj_dq_list.q_proj_dq_0.bias.to(torch.float16) - k_bias = 
attn_layer.k_proj_dq_list.k_proj_dq_0.bias.to(torch.float16) - v_bias = attn_layer.v_proj_dq_list.v_proj_dq_0.bias.to(torch.float16) + if hasattr(attn_layer, "q_proj_dq_list"): + q_bias = attn_layer.q_proj_dq_list.q_proj_dq_0.bias.to(torch.float16) + k_bias = attn_layer.k_proj_dq_list.k_proj_dq_0.bias.to(torch.float16) + v_bias = attn_layer.v_proj_dq_list.v_proj_dq_0.bias.to(torch.float16) + else: + q_bias = attn_layer.q_proj.bias.to(torch.float16) + k_bias = attn_layer.k_proj.bias.to(torch.float16) + v_bias = attn_layer.v_proj.bias.to(torch.float16) cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16) layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16) @@ -263,31 +278,46 @@ def convert_fused_qwen_layer(model, fused_layers, n_splits_linear, n_splits_down k_biases = [] v_biases = [] layer_indexs = range(layer_start, layer_end) - n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list) - n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list) + if hasattr(model.model.layers[0].mlp, "gate_proj_dq_list"): + n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list) + n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list) + else: + n_splits_linear = 1 + n_splits_down_proj = 1 + for layer_idx in layer_indexs: curr_layer = model.model.layers[layer_idx] attn_layer = curr_layer.self_attn mlp_layer = curr_layer.mlp weights = [] - for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, - attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, - mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list, - mlp_layer.down_proj_dq_list]: - l_weights = [] - scales = [] - zeros = [] - for l in layer_list: - l_weights.append(l.weight) - scales.append(l.scale) - if l.zero is not None: - zeros.append(l.zero) - if len(zeros): - weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0), - torch.stack(zeros, axis=0))) - else: - weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) + if hasattr(attn_layer, "q_proj_dq_list"): + for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list, + attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list, + mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list, + mlp_layer.down_proj_dq_list]: + l_weights = [] + scales = [] + zeros = [] + for l in layer_list: + l_weights.append(l.weight) + scales.append(l.scale) + if l.zero is not None: + zeros.append(l.zero) + if len(zeros): + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0), + torch.stack(zeros, axis=0))) + else: + weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) + else: + for layer in [attn_layer.q_proj, attn_layer.k_proj, + attn_layer.v_proj, attn_layer.o_proj, + mlp_layer.gate_proj, mlp_layer.up_proj, + mlp_layer.down_proj]: + if layer.zero is not None: + weights.append((layer.weight, layer.scale, layer.zero)) + else: + weights.append((layer.weight, layer.scale)) cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16) @@ -297,9 +327,14 @@ def convert_fused_qwen_layer(model, fused_layers, n_splits_linear, n_splits_down layer_weights.extend(weights) input_layer_norm_weights.append(layer_norm_0) post_attn_layernorm_weights.append(layer_norm_1) - q_biases.append(attn_layer.q_proj_dq_list.q_proj_dq_0.bias.to(torch.float16)) - 
k_biases.append(attn_layer.k_proj_dq_list.k_proj_dq_0.bias.to(torch.float16)) - v_biases.append(attn_layer.v_proj_dq_list.v_proj_dq_0.bias.to(torch.float16)) + if hasattr(attn_layer, "q_proj_dq_list"): + q_biases.append(attn_layer.q_proj_dq_list.q_proj_dq_0.bias.to(torch.float16)) + k_biases.append(attn_layer.k_proj_dq_list.k_proj_dq_0.bias.to(torch.float16)) + v_biases.append(attn_layer.v_proj_dq_list.v_proj_dq_0.bias.to(torch.float16)) + else: + q_biases.append(attn_layer.q_proj.bias.to(torch.float16)) + k_biases.append(attn_layer.k_proj.bias.to(torch.float16)) + v_biases.append(attn_layer.v_proj.bias.to(torch.float16)) # save weight input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin") From c37a180ed63ca803dc731ae97e4efaef7584e36b Mon Sep 17 00:00:00 2001 From: rnwang04 Date: Tue, 24 Dec 2024 21:49:22 +0800 Subject: [PATCH 2/4] initial auto-round support integration --- .../npu_models/auto_round_patch.py | 455 ++++++++++++++++++ .../transformers/npu_models/common.py | 8 + .../npu_models/convert_auto_round_model.py | 184 +++++++ .../npu_pipeline_model/convert_pipeline.py | 11 +- 4 files changed, 654 insertions(+), 4 deletions(-) create mode 100644 python/llm/src/ipex_llm/transformers/npu_models/auto_round_patch.py create mode 100644 python/llm/src/ipex_llm/transformers/npu_models/convert_auto_round_model.py diff --git a/python/llm/src/ipex_llm/transformers/npu_models/auto_round_patch.py b/python/llm/src/ipex_llm/transformers/npu_models/auto_round_patch.py new file mode 100644 index 00000000000..141055c8cd5 --- /dev/null +++ b/python/llm/src/ipex_llm/transformers/npu_models/auto_round_patch.py @@ -0,0 +1,455 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Some parts of this file is adapted from +# https://github.com/intel/auto-round/blob/main/auto_round/auto_quantizer.py +# and +# https://github.com/intel/auto-round/blob/main/auto_round/backend.py +# which is licensed under Apache License 2.0: +# +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import torch +import importlib +import torch.nn as nn +from transformers.utils.versions import require_version +from transformers.pytorch_utils import Conv1D +from logging import getLogger +from typing import Union + +logger = getLogger(__name__) + +import auto_round + +def check_compatible(backend_name, device, bits, group_size, sym, packing_format, in_features, out_features, + check_requirements=True): + """Checks if the given configuration is compatible with the specified backend. + + Args: + backend_name (str): The name of the backend to check compatibility for. + device (str): The device on which the backend operates (e.g., 'cuda', 'cpu'). + bits (int): The bit-width of the quantization (e.g., 2, 4, 8). + group_size (Optional[int]): The size of the quantization group. Can be None if + not required by the backend. + sym (bool): Whether symmetric quantization is required (True for symmetric). + packing_format (str): The packing format used by the backend (e.g., 'triton'). + in_features (int): The number of input features for the model layer. + out_features (int): The number of output features for the model layer. + check_requirements (bool): Whether check the requirement + + Returns: + bool: True if the configuration is compatible with the backend, False otherwise. + + Raises: + KeyError: If the backend_name is not found in BackendInfos. + + Compatibility checks: + - Device must match one of the backend's supported devices. + - Bit-width must be supported by the backend. + - If group_size is required by the backend, it must match. + - Symmetric or asymmetric quantization must be supported. + - If the packing format matches exactly, all feature checks must pass. + - If the packing format does not match, it must be convertible. + """ + backend = auto_round.backend.BackendInfos[backend_name] + + # Check if device is supported by the backend + if not device in backend.device: + return False + + # Check if bit-width is supported + if bits not in backend.bits: + return False + + # Check if group_size is valid (if required by backend) + if backend.group_size is not None and group_size not in backend.group_size: + return False + + # Check if symmetric/asymmetric quantization is supported + if sym not in backend.sym: + return False + + # Check packing format and apply feature checks + if packing_format == backend.packing_format: + for check in backend.feature_checks: + if not check(in_features, out_features): + return False + + # Check if the format is convertible when packing formats differ + if packing_format != backend.packing_format and packing_format not in backend.convertable_format: + return False + + if check_requirements and backend.requirements is not None: + for requirement in backend.requirements: + if isinstance(requirement, str): + try: + require_version(requirement) + except ImportError: + return False + else: + res, _ = requirement() + return res + + return True + + +def get_layer_backend(device, backend, orig_backend, bits, group_size, sym, in_features, out_features): + """Selects the most suitable backend for the layer based on compatibility and priority. + + This function first checks if the specified backend supports the layer with the provided configuration. + If not, it iterates through other available backends, + checking compatibility and returning the one with the highest priority. + + Args: + device (str): + The device on which the layer will run, e.g., 'cpu', 'cuda'. + backend (str): + The target backend to be used for this layer. 
+ orig_backend (str): + The original backend from which packing format information is retrieved. + bits (int): + The number of bits used for quantization. + group_size (int): + The group size for quantization. + sym (bool): + Whether symmetric quantization is enabled. + in_features (int): + The number of input features for the layer. + out_features (int): + The number of output features for the layer. + + Returns: + str: + The selected backend that is compatible with the layer configuration. + + Raises: + AssertionError: + If the specified backend is not supported. + ValueError: + If no compatible backend is found for the given layer configuration. + """ + # Check if the provided backend is in BackendInfos + assert backend in auto_round.backend.BackendInfos.keys(), \ + f"Unsupported backend {backend}, please set it to `auto` to try automatic selection" + + packing_format = auto_round.backend.BackendInfos[orig_backend].packing_format + + # Check if the provided backend supports the layer configuration + if check_compatible(backend, device, bits, group_size, sym, packing_format, in_features, out_features): + return backend + + # Find and store other compatible backends + supported_backends = [] + for key in auto_round.backend.BackendInfos.keys(): + if key == backend: + continue + if check_compatible(key, device, bits, group_size, sym, packing_format, in_features, out_features): + supported_backends.append(key) + + # Raise an error if no compatible backends are found + if len(supported_backends) == 0: + supported_backends_need_package = [] + for key in auto_round.backend.BackendInfos.keys(): + if check_compatible(key, device, bits, group_size, sym, packing_format, in_features, out_features, + check_requirements=False): + supported_backends_need_package.append(key) + + if len(supported_backends_need_package) > 0: + supported_backends_need_package = sorted(supported_backends_need_package, + key=lambda support_backend: auto_round.backend.BackendInfos[support_backend].priority, + reverse=True) + backend_info = auto_round.backend.BackendInfos[supported_backends_need_package[0]] + # ipex-llm change start + # logger.error("please install all the following packages to support inference") + for requirement in backend_info.requirements: + if isinstance(requirement, str) and not requirement.startswith("intel-extension-for-"): + try: + require_version(requirement) + except ImportError: + logger.error(f"pip install {requirement}") + elif not requirement.startswith("intel-extension-for-"): + str_info = requirement()[1] + logger.error(str_info) + if not requirement.startswith("intel-extension-for-"): + exit(-1) + + # raise ValueError(f"None of the backends support this layer") + # ipex-llm change end + + # Sort the compatible backends by priority and return the one with the highest priority + supported_backends = sorted(supported_backends, key=lambda support_backend: auto_round.backend.BackendInfos[support_backend].priority, + reverse=True) + + # ipex-llm change start + try: + return supported_backends[0] + except: + return "ipex_gptq" + # ipex-llm change end + +import auto_round.backend +auto_round.backend.get_layer_backend = get_layer_backend +auto_round.backend.check_compatible = check_compatible + +importlib.reload(auto_round.backend) + +from auto_round.utils import (get_block_names, get_module, set_module, + get_multimodal_block_names, find_matching_blocks) + +def cpu_post_init(self, model): + return model + + +def convert_model(self, model: nn.Module): + """Converts the given model to an AutoRound 
model by replacing its layers with quantized layers. + + This method extracts the quantization configuration from the model and adjusts its layers + according to the specified quantization parameters. It supports different backends and + ensures that the model's data type is compatible with the selected hardware. + + Args: + model (nn.Module): + The model to be converted into an AutoRound model. + + Returns: + nn.Module: + The converted AutoRound model with quantized layers. + + Raises: + ValueError: + If the quantization backend is not specified in the configuration. + """ + + from auto_round.utils import get_layer_names_in_block + + quantization_config = model.config.quantization_config + if not hasattr(quantization_config, "target_backend"): + quantization_config.target_backend = quantization_config.backend + + target_device = self.detect_device(quantization_config.target_backend, quantization_config.backend) + self.target_device = target_device + + if hasattr(quantization_config, "backend"): # pragma: no cover + if ("hpu" == target_device or "cpu" == target_device) and model.dtype != torch.bfloat16: + # ipex-llm code change start + # model = model.to(torch.bfloat16) + model = model.to(torch.float16) + # ipex-llm code change end + else: + if model.dtype != torch.float16: + model = model.to(torch.float16) + + bits = quantization_config.bits + group_size = quantization_config.group_size + data_type = quantization_config.data_type if hasattr(quantization_config, + "data_type") else "int" # pragma: no cover + sym = quantization_config.sym + to_quant_block_names = quantization_config.to_quant_block_names if hasattr(quantization_config, + "to_quant_block_names") else None + quant_block_list = quantization_config.quant_block_list if hasattr(quantization_config, + "quant_block_list") else None + if to_quant_block_names is None: # TODO check compatibility + all_blocks = get_block_names(model) + else: + all_blocks = get_multimodal_block_names(model, quant_vision=True) + if quant_block_list is None: + quant_block_list = find_matching_blocks(model, all_blocks, to_quant_block_names) + layer_names = get_layer_names_in_block(model, quant_block_list=quant_block_list) + + extra_config = {} + if hasattr(quantization_config, "extra_config"): + extra_config = quantization_config.extra_config + + layer_names += extra_config.keys() + layer_names = list(set(layer_names)) + + layer_configs = {} + for layer_name in layer_names: + layer_configs[layer_name] = {} + if layer_name not in extra_config: + layer_configs[layer_name]["bits"] = bits + layer_configs[layer_name]["group_size"] = group_size + layer_configs[layer_name]["data_type"] = data_type + layer_configs[layer_name]["sym"] = sym + layer_configs[layer_name]["clip"] = False + else: + layer_configs[layer_name]["bits"] = extra_config[layer_name].get("bits", bits) + layer_configs[layer_name]["group_size"] = extra_config[layer_name].get("group_size", group_size) + layer_configs[layer_name]["data_type"] = extra_config[layer_name].get("data_type", data_type) + layer_configs[layer_name]["sym"] = extra_config[layer_name].get("sym", sym) + layer_configs[layer_name]["clip"] = extra_config[layer_name].get("clip", False) + + if hasattr(quantization_config, "backend"): # pragma: no cover + backend = quantization_config.backend + elif 'gptq' in quantization_config.quant_method: # pragma: no cover + backend = 'gptq' + else: # pragma: no cover + raise ValueError("Quantization backend must be specified.") + + self._replace_by_quant_layers(model, layer_configs, 
quantization_config.target_backend, target_device, backend) + return model + + +def get_device(obj: Union[torch.Tensor, nn.Module]): + if isinstance(obj, torch.Tensor): + return obj.device + return next(obj.parameters()).device + + +def _replace_by_quant_layers(self, module: nn.Module, layer_configs, target_backend, target_device, orig_backend): + """Replaces linear layers in the given module with quantized layers. + + This method iterates over the specified layer configurations and replaces + the original layers in the module with instances of `QuantLinear`. It handles + various layer types and ensures that the correct quantization parameters are applied. + + Args: + module (nn.Module): + The module containing layers to be quantized. + layer_configs (dict): + A dictionary containing configuration for each layer's quantization. + target_backend (str): + The backend to use for quantization, which includes device and format information. + target_device (str): + The device on which the model will run (e.g., 'cuda', 'cpu', 'hpu'). + orig_backend (str): + The original backend of the packing. + + Raises: + AssertionError: + If any condition related to backend or quantization configuration is not met. + """ + # ipex-llm code change start + from auto_round.backend import dynamic_import_inference_linear + # ipex-llm code change end + + def remove_device_str(s, device_str): + if s and s.startswith(device_str): + return s[len(device_str):].lstrip(":") + return s + + if "auto" == target_backend.split(':')[0]: + target_backend = target_backend[4:] # Remove 'auto' + if len(target_backend) >= 1 and target_backend[0] == ":": + target_backend = target_backend[1:] + + # Remove device info from target_backend + target_backend = remove_device_str(target_backend, "cpu") + target_backend = remove_device_str(target_backend, "hpu") + target_backend = remove_device_str(target_backend, "cuda") + orig_backend = self.find_backend(orig_backend) + + if target_backend == "": + target_backend = orig_backend + + self.need_marlin_repacking = False + + for layer_name in layer_configs.keys(): + config = layer_configs[layer_name] + bits = config["bits"] + group_size = config["group_size"] + data_type = config["data_type"] + sym = config["sym"] + clip = config["clip"] + + if not (bits <= 8): + continue + + layer = get_module(module, layer_name) + if isinstance(layer, nn.Linear): + in_features = layer.in_features + out_features = layer.out_features + elif isinstance(layer, nn.Conv2d): # Not supported currently + in_features = layer.in_channels + out_features = layer.out_channels + elif isinstance(layer, Conv1D): # TODO: Needs verification + in_features = layer.weight.shape[0] + out_features = layer.weight.shape[1] + else: + continue + + if "marlin" in target_backend and "marlin" not in orig_backend: + # Need to repack + assert sym == True, "Marlin only supports symmetric quantization" + assert target_device == "cuda", "Marlin only supports CUDA device" + assert not "awq" in orig_backend, "Marlin does not support repacking from AWQ format" + self.need_marlin_repacking = True + # Using original backend to load the layer then replace + layer_backend = orig_backend + else: + target_backend = self.find_backend(target_backend) # TODO: Move out if have supported marlin + layer_backend = get_layer_backend( + target_device, target_backend, orig_backend, bits, group_size, sym, in_features, out_features + ) + if "gptq" in layer_backend and "exllamav2" in layer_backend: + try: + from exllamav2_kernels import gemm_half_q_half, make_q_matrix 
# pylint: disable=E0611 + except: + logger.warning_once( + "For better inference performance, please install exllamav2 kernel " + "via `pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@b8b4127`") + + QuantLinear = dynamic_import_inference_linear(layer_backend, bits, group_size, sym) + + layer_device = get_device(layer) + + bias = layer.bias is not None + if "awq" in layer_backend: + new_layer = QuantLinear.from_linear( # pylint: disable=E1123 + layer, + bits, + group_size, + init_only=True + ) + else: + try: + new_layer = QuantLinear( # pylint: disable=E1123 + bits, + group_size, + in_features, + out_features, + bias, + weight_dtype=layer.weight.dtype, + clip=clip + ) + except: + new_layer = QuantLinear( # pylint: disable=E1123 + bits, + group_size, + in_features, + out_features, + bias, + weight_dtype=layer.weight.dtype, + ) + + new_layer.device = layer_device + set_module(module, layer_name, new_layer) + +auto_round.auto_quantizer.AutoRoundQuantizer.cpu_post_init = cpu_post_init +auto_round.auto_quantizer.AutoRoundQuantizer._replace_by_quant_layers = _replace_by_quant_layers +auto_round.auto_quantizer.AutoRoundQuantizer.convert_model = convert_model diff --git a/python/llm/src/ipex_llm/transformers/npu_models/common.py b/python/llm/src/ipex_llm/transformers/npu_models/common.py index 4bf492cbe0d..201ab610922 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/common.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/common.py @@ -105,3 +105,11 @@ def split_linears(module: torch.nn.Module, n_splits_hidden_size=2, n_splits_down n_splits=n_splits_hidden_size, load=load)) delattr(module, name) + + +def is_auto_round_model(model: torch.nn.Module): + if hasattr(model, "quantization_config"): + quant_config = getattr(model.config, "quantization_config", None) + if quant_config is not None and quant_config.quant_method == "intel/auto-round": + return True + return False diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert_auto_round_model.py b/python/llm/src/ipex_llm/transformers/npu_models/convert_auto_round_model.py new file mode 100644 index 00000000000..8d38a1f1c2d --- /dev/null +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert_auto_round_model.py @@ -0,0 +1,184 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import auto_round_patch +import torch +import os +from ipex_llm.utils.common.log4Error import invalidInputError +from ipex_llm.transformers.npu_models.lm_head import SlicedLMHead +from ipex_llm.transformers.npu_models.convert import module_optimization +from ipex_llm.transformers.npu_models.linear import QuantizedLinear +from ipex_llm.utils.common.log4Error import invalidInputError + + +def unpack_auto_round_layer(layer, qtype="sym_int4_rtn"): + n, m = layer.infeatures, layer.outfeatures + weight = layer.qweight.to("cpu") + scale = layer.scales.to("cpu") + zeros = layer.qzeros.to("cpu") # np.int32, 1 x m // 4 + bits = layer.bits + + scale = scale.t().contiguous() + + int_weight = torch.zeros((n, m), dtype=torch.uint8) + num = 32 // bits + + for i in range(0, n // num): + for j in range(0, num): + int_weight[i * num + j, :] = (( weight[i, :] >> (j * bits) ) & 0x0000000F ).to(torch.uint8) + + int_weight = (int_weight - 8).to(torch.int8) # n, m + qweights = int_weight.t().contiguous() # m, n + + # if we want to transform it to our NPU format, uncomment below code + qweights = qweights.reshape(m, -1 , 2) # m * n/2 * 2 + low_bit, high_bit = qweights.split(1, dim=-1) + high_bit = high_bit.squeeze().view(torch.int8) + low_bit = low_bit.squeeze().view(torch.int8) + high_bit = high_bit << 4 + low_bit = low_bit & 0x0f + qweights = high_bit | low_bit + + if qtype == "sym_int4_rtn" or qtype == "sym_int8_rtn": + zero = None + elif qtype == "asym_int4_rtn": + zero = zeros.view(torch.int32) + int_zero = torch.zeros((1, m), dtype=torch.uint8) + num = 32 // bits + + for i in range(0, m // num): + for j in range(0, num): + int_zero[:, i * num + j] = (( zero[:, i] >> (j * bits) ) & 0x0000000F ).to(torch.uint8) + + zero = int_zero.to(torch.int8) + zero = zero.t().contiguous() # m, 1 + zero = zero.to(torch.float32) * -1 * scale + zero += 8 * scale + else: + invalidInputError(False, + f"unpack_auto_round_layer does not support qtype {qtype}.") + return qweights.view(torch.uint8), scale.to(torch.float16), zero.to(torch.float16) + + +@module_optimization +def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert, + group_size, imatrix): + from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype + from ipex_llm.ggml.quantize import ggml_tensor_qtype + iqtype = ggml_tensor_qtype[qtype] + if layer.__class__.__name__ == "QuantLinear": + from auto_round_extension.ipex.qlinear_ipex_gptq import QuantLinear + if isinstance(layer, QuantLinear): + # auto-round's QuantLinear + qweights, scale, zero = unpack_auto_round_layer(layer, qtype=qtype) + return QuantizedLinear(qweights, scale, zero, layer.bias, + group_size=group_size, qtype=qtype) + elif isinstance(layer, torch.nn.Linear) and not hasattr(layer, "qtype"): + enable_scale_search = (os.environ.get("IPEX_LLM_NPU_QUANTIZATION_OPT", "0") != "0" or + os.environ.get("IPEX_LLM_NPU_QUANTIZATION_HQQ", "0") != "0") + qweights, scale = ggml_convert_qtype(layer.weight.data.to(torch.float32), + iqtype, device=device, + enable_scale_search=enable_scale_search, + imatrix=imatrix) + zero = None + # split scale to scale & zero + if qtype == "asym_int4_rtn": + scale, zero = torch.split(scale, scale.shape[0] // 2) + return QuantizedLinear(qweights, scale, zero, layer.bias, + group_size=group_size, qtype=qtype) + + +def convert_auto_round_model_to_npu_model(model, save_directory, max_context_len = 1024, max_prompt_len = 960, + transpose_value_cache = True, fuse_layers = None, mixed_precision = False, + inter_pp = None, intra_pp = None, 
optimize_model=True): + quant_config = getattr(model.config, "quantization_config", None) + if quant_config is None and quant_config.quant_method != "intel/auto-round": + exit(-1) + + bits = quant_config.bits + group_size = quant_config.group_size + sym = quant_config.sym + + if sym and bits == 4 : + qtype = "sym_int4_rtn" + elif not sym and bits == 4: + qtype = "asym_int4_rtn" + elif sym and bits == 4: + qtype = "sym_int8_rtn" + else: + invalidInputError(False, + "Invalid dtype.") + + if group_size == -1: + quantization_group_size = 0 + else: + quantization_group_size = group_size + + if model.config.model_type == "qwen2": + # for Qwen2-7B-Insturct and MiniCPM-V 2.6, divide lm_head into 14 parts + if model.config.hidden_size == 3584 and (model.config.vocab_size == 152064 or + model.config.vocab_size == 151666): + # Do not split lm_head and use sym_int8 instead when mixed_precison is True + if quantization_group_size == 0: + # Do not split lm_head and use sym_int8 instead when mixed_precison is True + is_split = (not mixed_precision) and qtype in ["sym_int4_rtn", "asym_int4_rtn"] + split_num = 14 if is_split else 1 + new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num, + bias=model.lm_head.bias, use_split=True, + group_size=quantization_group_size, + asym=((qtype == "asym_int4_rtn") and + (not mixed_precision))) + del model.lm_head + model.lm_head = new_lm_head + + replace_with_QuantizedLinear(model, qtype, "cpu", [], + quantization_group_size, None) + + from intel_npu_acceleration_library.compiler import create_npu_kernels + create_npu_kernels(model) + model = model.eval() + model.config.update({"mixed_precision": mixed_precision}) + model.config.update({"group_size": quantization_group_size}) + model.config.update({"asym": qtype == "asym_int4_rtn"}) + model.config.update({"bigdl_transformers_low_bit": qtype}) + model.config.update({"optimize_model": optimize_model}) + + if (not hasattr(model, 'llm') and + model.config.model_type in ["qwen2", "llama", "minicpm"]): + from ipex_llm.transformers.npu_models.convert import optimize_llm_single_process + optimize_llm_single_process( + model, + kv_len=max_context_len - 1, + max_prompt_len=max_prompt_len, + transpose_value_cache=transpose_value_cache, + group_size=quantization_group_size, + qtype=qtype, + save_directory=save_directory, + fuse_layers=fuse_layers + ) + else: + from ipex_llm.transformers.npu_models.convert_mp import optimize_llm + optimize_llm( + model, + max_context_len=max_context_len - 1, + max_prompt_len=max_prompt_len, + inter_pp=inter_pp, + intra_pp=intra_pp, + transpose_value_cache=transpose_value_cache, + group_size=quantization_group_size + ) + + return model diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py index 42c15e104c1..ec85deac34e 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py @@ -31,6 +31,7 @@ import numpy as np from ipex_llm.transformers.npu_models.lm_head import SlicedLMHead from multiprocessing import Pool +from ipex_llm.transformers.npu_models.common import is_auto_round_model def generate( @@ -205,10 +206,12 @@ def convert_llm(model: torch.nn.Module, # do not split mlp down_proj for Qwen2-7B & sym_int8 n_splits_down_proj = 1 else: - # n_splits_down_proj = 2 if (model.config.intermediate_size == 18944 or - # os.environ.get("IPEX_LLM_NPU_MTL", "0") == "1" or - 
# os.environ.get("IPEX_LLM_NPU_ARL", "0") == "1") else 1 - n_splits_down_proj = 1 # for auto-round test + if is_auto_round_model(model): + n_splits_down_proj = 1 # for auto-round + else: + n_splits_down_proj = 2 if (model.config.intermediate_size == 18944 or + os.environ.get("IPEX_LLM_NPU_MTL", "0") == "1" or + os.environ.get("IPEX_LLM_NPU_ARL", "0") == "1") else 1 else: n_splits_linear = model.config.hidden_size // group_size n_splits_down_proj = model.config.intermediate_size // group_size From 50ace72a99cc707fc9c4f955f066a758aebf657e Mon Sep 17 00:00:00 2001 From: rnwang04 Date: Tue, 24 Dec 2024 21:57:52 +0800 Subject: [PATCH 3/4] fix style --- .../transformers/npu_models/common.py | 2 +- .../npu_models/convert_auto_round_model.py | 33 ++++++++++--------- .../npu_pipeline_model/convert_pipeline.py | 2 +- .../transformers/npu_pipeline_model/qwen.py | 3 +- 4 files changed, 21 insertions(+), 19 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/common.py b/python/llm/src/ipex_llm/transformers/npu_models/common.py index 201ab610922..7fba5771b2d 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/common.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/common.py @@ -111,5 +111,5 @@ def is_auto_round_model(model: torch.nn.Module): if hasattr(model, "quantization_config"): quant_config = getattr(model.config, "quantization_config", None) if quant_config is not None and quant_config.quant_method == "intel/auto-round": - return True + return True return False diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert_auto_round_model.py b/python/llm/src/ipex_llm/transformers/npu_models/convert_auto_round_model.py index 8d38a1f1c2d..c0935310384 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert_auto_round_model.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert_auto_round_model.py @@ -28,7 +28,7 @@ def unpack_auto_round_layer(layer, qtype="sym_int4_rtn"): n, m = layer.infeatures, layer.outfeatures weight = layer.qweight.to("cpu") scale = layer.scales.to("cpu") - zeros = layer.qzeros.to("cpu") # np.int32, 1 x m // 4 + zeros = layer.qzeros.to("cpu") # np.int32, 1 x m // 4 bits = layer.bits scale = scale.t().contiguous() @@ -38,13 +38,13 @@ def unpack_auto_round_layer(layer, qtype="sym_int4_rtn"): for i in range(0, n // num): for j in range(0, num): - int_weight[i * num + j, :] = (( weight[i, :] >> (j * bits) ) & 0x0000000F ).to(torch.uint8) + int_weight[i*num + j, :] = ((weight[i, :] >> (j*bits)) & 0x0000000F).to(torch.uint8) - int_weight = (int_weight - 8).to(torch.int8) # n, m - qweights = int_weight.t().contiguous() # m, n + int_weight = (int_weight - 8).to(torch.int8) # n, m + qweights = int_weight.t().contiguous() # m, n # if we want to transform it to our NPU format, uncomment below code - qweights = qweights.reshape(m, -1 , 2) # m * n/2 * 2 + qweights = qweights.reshape(m, -1, 2) # m * n/2 * 2 low_bit, high_bit = qweights.split(1, dim=-1) high_bit = high_bit.squeeze().view(torch.int8) low_bit = low_bit.squeeze().view(torch.int8) @@ -61,16 +61,16 @@ def unpack_auto_round_layer(layer, qtype="sym_int4_rtn"): for i in range(0, m // num): for j in range(0, num): - int_zero[:, i * num + j] = (( zero[:, i] >> (j * bits) ) & 0x0000000F ).to(torch.uint8) + int_zero[:, i*num + j] = ((zero[:, i] >> (j*bits)) & 0x0000000F).to(torch.uint8) zero = int_zero.to(torch.int8) - zero = zero.t().contiguous() # m, 1 + zero = zero.t().contiguous() # m, 1 zero = zero.to(torch.float32) * -1 * scale zero += 8 * scale else: 
invalidInputError(False, f"unpack_auto_round_layer does not support qtype {qtype}.") - return qweights.view(torch.uint8), scale.to(torch.float16), zero.to(torch.float16) + return qweights.view(torch.uint8), scale.to(torch.float16), zero.to(torch.float16) @module_optimization @@ -85,7 +85,7 @@ def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert, # auto-round's QuantLinear qweights, scale, zero = unpack_auto_round_layer(layer, qtype=qtype) return QuantizedLinear(qweights, scale, zero, layer.bias, - group_size=group_size, qtype=qtype) + group_size=group_size, qtype=qtype) elif isinstance(layer, torch.nn.Linear) and not hasattr(layer, "qtype"): enable_scale_search = (os.environ.get("IPEX_LLM_NPU_QUANTIZATION_OPT", "0") != "0" or os.environ.get("IPEX_LLM_NPU_QUANTIZATION_HQQ", "0") != "0") @@ -101,9 +101,10 @@ def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert, group_size=group_size, qtype=qtype) -def convert_auto_round_model_to_npu_model(model, save_directory, max_context_len = 1024, max_prompt_len = 960, - transpose_value_cache = True, fuse_layers = None, mixed_precision = False, - inter_pp = None, intra_pp = None, optimize_model=True): +def convert_auto_round_model_to_npu_model(model, save_directory, max_context_len=1024, + max_prompt_len=960, transpose_value_cache=True, + fuse_layers=None, mixed_precision=False, + inter_pp=None, intra_pp=None, optimize_model=True): quant_config = getattr(model.config, "quantization_config", None) if quant_config is None and quant_config.quant_method != "intel/auto-round": exit(-1) @@ -112,16 +113,16 @@ def convert_auto_round_model_to_npu_model(model, save_directory, max_context_len group_size = quant_config.group_size sym = quant_config.sym - if sym and bits == 4 : + if sym and bits == 4: qtype = "sym_int4_rtn" elif not sym and bits == 4: qtype = "asym_int4_rtn" - elif sym and bits == 4: + elif sym and bits == 4: qtype = "sym_int8_rtn" else: invalidInputError(False, - "Invalid dtype.") - + "Invalid dtype.") + if group_size == -1: quantization_group_size = 0 else: diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py index ec85deac34e..101eb67900d 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py @@ -207,7 +207,7 @@ def convert_llm(model: torch.nn.Module, n_splits_down_proj = 1 else: if is_auto_round_model(model): - n_splits_down_proj = 1 # for auto-round + n_splits_down_proj = 1 # for auto-round else: n_splits_down_proj = 2 if (model.config.intermediate_size == 18944 or os.environ.get("IPEX_LLM_NPU_MTL", "0") == "1" or diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py index 581a5eb75e3..6bdad7537fd 100644 --- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py +++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py @@ -308,7 +308,8 @@ def convert_fused_qwen_layer(model, fused_layers, n_splits_linear, n_splits_down weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0), torch.stack(zeros, axis=0))) else: - weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))) + weights.append((torch.stack(l_weights, axis=0), + torch.stack(scales, axis=0))) else: for layer in [attn_layer.q_proj, attn_layer.k_proj, 
attn_layer.v_proj, attn_layer.o_proj, From 7abcd3cd8211ace5784863eeba4387b335fdd39c Mon Sep 17 00:00:00 2001 From: rnwang04 Date: Tue, 24 Dec 2024 22:08:25 +0800 Subject: [PATCH 4/4] fix style --- .../npu_models/auto_round_patch.py | 103 +++++++++++------- 1 file changed, 63 insertions(+), 40 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/auto_round_patch.py b/python/llm/src/ipex_llm/transformers/npu_models/auto_round_patch.py index 141055c8cd5..801868a3b3f 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/auto_round_patch.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/auto_round_patch.py @@ -41,13 +41,15 @@ from transformers.pytorch_utils import Conv1D from logging import getLogger from typing import Union +from ipex_llm.utils.common import invalidInputError logger = getLogger(__name__) import auto_round -def check_compatible(backend_name, device, bits, group_size, sym, packing_format, in_features, out_features, - check_requirements=True): + +def check_compatible(backend_name, device, bits, group_size, sym, packing_format, + in_features, out_features, check_requirements=True): """Checks if the given configuration is compatible with the specified backend. Args: @@ -79,7 +81,7 @@ def check_compatible(backend_name, device, bits, group_size, sym, packing_format backend = auto_round.backend.BackendInfos[backend_name] # Check if device is supported by the backend - if not device in backend.device: + if device not in backend.device: return False # Check if bit-width is supported @@ -101,7 +103,8 @@ def check_compatible(backend_name, device, bits, group_size, sym, packing_format return False # Check if the format is convertible when packing formats differ - if packing_format != backend.packing_format and packing_format not in backend.convertable_format: + if packing_format != backend.packing_format and \ + packing_format not in backend.convertable_format: return False if check_requirements and backend.requirements is not None: @@ -118,11 +121,12 @@ def check_compatible(backend_name, device, bits, group_size, sym, packing_format return True -def get_layer_backend(device, backend, orig_backend, bits, group_size, sym, in_features, out_features): +def get_layer_backend(device, backend, orig_backend, bits, group_size, sym, + in_features, out_features): """Selects the most suitable backend for the layer based on compatibility and priority. - This function first checks if the specified backend supports the layer with the provided configuration. - If not, it iterates through other available backends, + This function first checks if the specified backend supports the layer with the + provided configuration. If not, it iterates through other available backends, checking compatibility and returning the one with the highest priority. Args: @@ -154,13 +158,15 @@ def get_layer_backend(device, backend, orig_backend, bits, group_size, sym, in_f If no compatible backend is found for the given layer configuration. 
""" # Check if the provided backend is in BackendInfos - assert backend in auto_round.backend.BackendInfos.keys(), \ - f"Unsupported backend {backend}, please set it to `auto` to try automatic selection" + invalidInputError(backend in auto_round.backend.BackendInfos.keys(), + f"Unsupported backend {backend}, " + "please set it to `auto` to try automatic selection") packing_format = auto_round.backend.BackendInfos[orig_backend].packing_format # Check if the provided backend supports the layer configuration - if check_compatible(backend, device, bits, group_size, sym, packing_format, in_features, out_features): + if check_compatible(backend, device, bits, group_size, sym, packing_format, + in_features, out_features): return backend # Find and store other compatible backends @@ -168,26 +174,30 @@ def get_layer_backend(device, backend, orig_backend, bits, group_size, sym, in_f for key in auto_round.backend.BackendInfos.keys(): if key == backend: continue - if check_compatible(key, device, bits, group_size, sym, packing_format, in_features, out_features): + if check_compatible(key, device, bits, group_size, sym, packing_format, + in_features, out_features): supported_backends.append(key) # Raise an error if no compatible backends are found if len(supported_backends) == 0: supported_backends_need_package = [] for key in auto_round.backend.BackendInfos.keys(): - if check_compatible(key, device, bits, group_size, sym, packing_format, in_features, out_features, + if check_compatible(key, device, bits, group_size, sym, packing_format, + in_features, out_features, check_requirements=False): supported_backends_need_package.append(key) if len(supported_backends_need_package) > 0: - supported_backends_need_package = sorted(supported_backends_need_package, - key=lambda support_backend: auto_round.backend.BackendInfos[support_backend].priority, - reverse=True) + supported_backends_need_package = sorted( + supported_backends_need_package, + key=lambda support_backend: + auto_round.backend.BackendInfos[support_backend].priority, + reverse=True) backend_info = auto_round.backend.BackendInfos[supported_backends_need_package[0]] # ipex-llm change start - # logger.error("please install all the following packages to support inference") for requirement in backend_info.requirements: - if isinstance(requirement, str) and not requirement.startswith("intel-extension-for-"): + if isinstance(requirement, str) and \ + not requirement.startswith("intel-extension-for-"): try: require_version(requirement) except ImportError: @@ -196,13 +206,14 @@ def get_layer_backend(device, backend, orig_backend, bits, group_size, sym, in_f str_info = requirement()[1] logger.error(str_info) if not requirement.startswith("intel-extension-for-"): - exit(-1) - - # raise ValueError(f"None of the backends support this layer") - # ipex-llm change end + invalidInputError(False, + f"exit for missing requirement {requirement}") + # ipex-llm change end # Sort the compatible backends by priority and return the one with the highest priority - supported_backends = sorted(supported_backends, key=lambda support_backend: auto_round.backend.BackendInfos[support_backend].priority, + supported_backends = sorted(supported_backends, + key=lambda support_backend: + auto_round.backend.BackendInfos[support_backend].priority, reverse=True) # ipex-llm change start @@ -214,13 +225,14 @@ def get_layer_backend(device, backend, orig_backend, bits, group_size, sym, in_f import auto_round.backend auto_round.backend.get_layer_backend = get_layer_backend 
-auto_round.backend.check_compatible = check_compatible +auto_round.backend.check_compatible = check_compatible importlib.reload(auto_round.backend) from auto_round.utils import (get_block_names, get_module, set_module, get_multimodal_block_names, find_matching_blocks) + def cpu_post_init(self, model): return model @@ -251,7 +263,8 @@ def convert_model(self, model: nn.Module): if not hasattr(quantization_config, "target_backend"): quantization_config.target_backend = quantization_config.backend - target_device = self.detect_device(quantization_config.target_backend, quantization_config.backend) + target_device = self.detect_device(quantization_config.target_backend, + quantization_config.backend) self.target_device = target_device if hasattr(quantization_config, "backend"): # pragma: no cover @@ -267,13 +280,15 @@ def convert_model(self, model: nn.Module): bits = quantization_config.bits group_size = quantization_config.group_size data_type = quantization_config.data_type if hasattr(quantization_config, - "data_type") else "int" # pragma: no cover + "data_type") else "int" # pragma: no cover sym = quantization_config.sym - to_quant_block_names = quantization_config.to_quant_block_names if hasattr(quantization_config, - "to_quant_block_names") else None + if hasattr(quantization_config, "to_quant_block_names"): + to_quant_block_names = quantization_config.to_quant_block_names + else: + to_quant_block_names = None quant_block_list = quantization_config.quant_block_list if hasattr(quantization_config, - "quant_block_list") else None - if to_quant_block_names is None: # TODO check compatibility + "quant_block_list") else None + if to_quant_block_names is None: # TODO check compatibility all_blocks = get_block_names(model) else: all_blocks = get_multimodal_block_names(model, quant_vision=True) @@ -299,8 +314,10 @@ def convert_model(self, model: nn.Module): layer_configs[layer_name]["clip"] = False else: layer_configs[layer_name]["bits"] = extra_config[layer_name].get("bits", bits) - layer_configs[layer_name]["group_size"] = extra_config[layer_name].get("group_size", group_size) - layer_configs[layer_name]["data_type"] = extra_config[layer_name].get("data_type", data_type) + layer_configs[layer_name]["group_size"] = extra_config[layer_name].get("group_size", + group_size) + layer_configs[layer_name]["data_type"] = extra_config[layer_name].get("data_type", + data_type) layer_configs[layer_name]["sym"] = extra_config[layer_name].get("sym", sym) layer_configs[layer_name]["clip"] = extra_config[layer_name].get("clip", False) @@ -309,9 +326,10 @@ def convert_model(self, model: nn.Module): elif 'gptq' in quantization_config.quant_method: # pragma: no cover backend = 'gptq' else: # pragma: no cover - raise ValueError("Quantization backend must be specified.") + invalidInputError(False, "Quantization backend must be specified.") - self._replace_by_quant_layers(model, layer_configs, quantization_config.target_backend, target_device, backend) + self._replace_by_quant_layers(model, layer_configs, quantization_config.target_backend, + target_device, backend) return model @@ -321,7 +339,8 @@ def get_device(obj: Union[torch.Tensor, nn.Module]): return next(obj.parameters()).device -def _replace_by_quant_layers(self, module: nn.Module, layer_configs, target_backend, target_device, orig_backend): +def _replace_by_quant_layers(self, module: nn.Module, layer_configs, target_backend, + target_device, orig_backend): """Replaces linear layers in the given module with quantized layers. 
This method iterates over the specified layer configurations and replaces @@ -395,20 +414,24 @@ def remove_device_str(s, device_str): if "marlin" in target_backend and "marlin" not in orig_backend: # Need to repack - assert sym == True, "Marlin only supports symmetric quantization" - assert target_device == "cuda", "Marlin only supports CUDA device" - assert not "awq" in orig_backend, "Marlin does not support repacking from AWQ format" + invalidInputError(sym, + "Marlin only supports symmetric quantization") + invalidInputError(target_device == "cuda", + "Marlin only supports CUDA device") + invalidInputError("awq" not in orig_backend, + "Marlin does not support repacking from AWQ format") self.need_marlin_repacking = True # Using original backend to load the layer then replace layer_backend = orig_backend else: - target_backend = self.find_backend(target_backend) # TODO: Move out if have supported marlin + target_backend = self.find_backend(target_backend) layer_backend = get_layer_backend( - target_device, target_backend, orig_backend, bits, group_size, sym, in_features, out_features + target_device, target_backend, orig_backend, bits, group_size, + sym, in_features, out_features ) if "gptq" in layer_backend and "exllamav2" in layer_backend: try: - from exllamav2_kernels import gemm_half_q_half, make_q_matrix # pylint: disable=E0611 + from exllamav2_kernels import gemm_half_q_half, make_q_matrix except: logger.warning_once( "For better inference performance, please install exllamav2 kernel "
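
The weight-gathering branch introduced in qwen.py (patches 1 and 3) is duplicated verbatim in convert_qwen_layer and convert_fused_qwen_layer. A minimal sketch of how that duplicated logic could be factored into a single helper is shown below; the attribute names (q_proj_dq_list, weight, scale, zero) come from the diff, while the helper name itself is hypothetical.

import torch

def gather_qkv_mlp_weights(attn_layer, mlp_layer):
    # Collect (weight, scale[, zero]) tuples for one decoder layer, covering both
    # the split ipex-llm layout (*_dq_list sub-layers) and the plain auto-round
    # layout where q_proj/k_proj/... already expose weight/scale/zero tensors.
    weights = []
    if hasattr(attn_layer, "q_proj_dq_list"):
        for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
                           attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
                           mlp_layer.down_proj_dq_list]:
            l_weights = [l.weight for l in layer_list]
            scales = [l.scale for l in layer_list]
            zeros = [l.zero for l in layer_list if l.zero is not None]
            if zeros:
                weights.append((torch.stack(l_weights, dim=0),
                                torch.stack(scales, dim=0),
                                torch.stack(zeros, dim=0)))
            else:
                weights.append((torch.stack(l_weights, dim=0),
                                torch.stack(scales, dim=0)))
    else:
        for layer in [attn_layer.q_proj, attn_layer.k_proj, attn_layer.v_proj,
                      attn_layer.o_proj, mlp_layer.gate_proj, mlp_layer.up_proj,
                      mlp_layer.down_proj]:
            if layer.zero is not None:
                weights.append((layer.weight, layer.scale, layer.zero))
            else:
                weights.append((layer.weight, layer.scale))
    return weights

Both conversion paths could then call weights = gather_qkv_mlp_weights(attn_layer, mlp_layer) instead of repeating the branch.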
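
unpack_auto_round_layer (patch 2) converts GPTQ-style packing, eight unsigned 4-bit values per int32 along the input dimension, into the NPU layout of two signed 4-bit values per byte along the output dimension. The self-contained round trip below illustrates that transformation on a toy tensor; the shapes and values are invented for illustration, and the snippet covers only the sym_int4 path (no zero points).

import torch

bits = 4
num = 32 // bits                      # eight 4-bit values per int32
n, m = 8, 2                           # toy in_features / out_features

# Reference unsigned 4-bit weights and their GPTQ-style int32 packing.
torch.manual_seed(0)
int_weight_ref = torch.randint(0, 16, (n, m), dtype=torch.uint8)
packed = torch.zeros((n // num, m), dtype=torch.int32)
for i in range(n // num):
    for j in range(num):
        packed[i, :] |= int_weight_ref[i * num + j, :].to(torch.int32) << (j * bits)

# Unpack, mirroring the loop in unpack_auto_round_layer.
int_weight = torch.zeros((n, m), dtype=torch.uint8)
for i in range(n // num):
    for j in range(num):
        int_weight[i * num + j, :] = ((packed[i, :] >> (j * bits)) & 0x0F).to(torch.uint8)
assert torch.equal(int_weight, int_weight_ref)

# Shift to the signed range and repack two int4 values per byte (NPU layout).
signed = (int_weight.to(torch.int16) - 8).to(torch.int8)   # n x m, values in [-8, 7]
pairs = signed.t().contiguous().reshape(m, -1, 2)          # m x n/2 x 2
low, high = pairs.split(1, dim=-1)
npu_qweights = ((high.squeeze(-1) << 4) | (low.squeeze(-1) & 0x0F)).view(torch.uint8)
print(npu_qweights.shape)                                  # m x n/2, two weights per byte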
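
For the converter added in patch 2, a hedged usage sketch follows. The checkpoint path is hypothetical, the call assumes an intel/auto-round 4-bit model with the auto_round package installed, and the bits/sym to qtype mapping in the helper states what the converter presumably intends (symmetric 8-bit mapping to sym_int8_rtn) as an assumption rather than a reading of the diff.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from ipex_llm.transformers.npu_models.convert_auto_round_model import (
    convert_auto_round_model_to_npu_model,
)


def pick_npu_qtype(bits, sym):
    # Assumed mapping: 4-bit symmetric -> sym_int4_rtn, 4-bit asymmetric ->
    # asym_int4_rtn, 8-bit symmetric -> sym_int8_rtn; anything else is rejected.
    if bits == 4:
        return "sym_int4_rtn" if sym else "asym_int4_rtn"
    if bits == 8 and sym:
        return "sym_int8_rtn"
    raise ValueError(f"unsupported auto-round config: bits={bits}, sym={sym}")


model_path = "path/to/Qwen2-7B-Instruct-auto-round-int4"   # hypothetical
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             torch_dtype=torch.float16,
                                             trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

quant_config = model.config.quantization_config
print("NPU qtype:", pick_npu_qtype(quant_config.bits, quant_config.sym))

model = convert_auto_round_model_to_npu_model(model,
                                              save_directory="qwen2-npu-converted",
                                              max_context_len=1024,
                                              max_prompt_len=960,
                                              transpose_value_cache=True)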