diff --git a/README.md b/README.md
index 6f438267..26f6ece4 100644
--- a/README.md
+++ b/README.md
@@ -89,7 +89,15 @@ Public tests/papers and ModelCloud's internal tests have shown that GPTQ is on-p
 
 ## Platform Requirements
 
-GPTQModel is validated for Linux x86_64 with Nvidia GPUs. Windows WSL2 may work but un-tested.
+GPTQModel is validated for Linux x86_64 with the following devices:
+
+| Device     | Supported |
+|------------|-----------|
+| NVIDIA GPU | ✅        |
+| Intel CPU  | ✅        |
+| Intel GPU  | ✅        |
+
+Windows WSL2 may work but is untested.
 
 ## Install
 
@@ -110,7 +118,7 @@ git clone https://github.com/ModelCloud/GPTQModel.git && cd GPTQModel
 
 # pip: compile and install
 # You can install optional modules like autoround, ipex, vllm, sglang, bitblas, and ipex.
-# Example: pip install -v --no-build-isolation gptqmodel[vllm,sglang,bitblas,ipex,auto_round]
+# Example: pip install -v --no-build-isolation .[vllm,sglang,bitblas,ipex,auto_round]
 pip install -v . --no-build-isolation
 ```
 
@@ -121,7 +129,7 @@ Below is a basic sample using `GPTQModel` to quantize a llm model and perform po
 ```py
 from datasets import load_dataset
 from transformers import AutoTokenizer
-from gptqmodel import GPTQModel, QuantizeConfig
+from gptqmodel import GPTQModel, QuantizeConfig, get_best_device
 
 model_id = "meta-llama/Llama-3.2-1B-Instruct"
 quant_path = "Llama-3.2-1B-Instruct-gptqmodel-4bit"
@@ -145,7 +153,7 @@ model.quantize(calibration_dataset)
 
 model.save(quant_path)
 
-model = GPTQModel.load(quant_path)
+model = GPTQModel.load(quant_path, device=get_best_device())
 
 result = model.generate(
   **tokenizer(
@@ -171,8 +179,9 @@ pip install lm-eval[gptqmodel]
 
 ### Which kernel is used by default?
 
-* `GPU`: Marlin, Exllama v2, Triton kernels in that order for maximum inference performance. Optional Microsoft/BITBLAS kernel can be toggled.
-* `CPU`: Intel/IPEX kernel
+* `GPU`: Marlin, Exllama v2, Triton kernels in that order for maximum inference performance. Optional Microsoft/BITBLAS kernel can be toggled.
+* `CPU`: Intel/IPEX kernel +* `XPU`: Intel/IPEX kernel ## Citation ``` diff --git a/examples/inference/run_with_different_backends.py b/examples/inference/run_with_different_backends.py index 02686a81..92e80bbd 100644 --- a/examples/inference/run_with_different_backends.py +++ b/examples/inference/run_with_different_backends.py @@ -1,13 +1,14 @@ import os import subprocess import sys +import torch from argparse import ArgumentParser -from gptqmodel import BACKEND, GPTQModel, QuantizeConfig, get_backend +from gptqmodel import BACKEND, GPTQModel, QuantizeConfig, get_backend, get_best_device from transformers import AutoTokenizer os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" -pretrained_model_id = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0" # "TinyLlama/TinyLlama-1.1B-Chat-v1.0" +pretrained_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" quantized_model_id = "./TinyLlama/TinyLlama-1.1B-Chat-v1.0-4bit-128g" def main(): @@ -18,8 +19,9 @@ def main(): args = parser.parse_args() backend = get_backend(args.backend) - - device = 'cpu' if backend == BACKEND.IPEX else 'cuda:0' + device = get_best_device() + if backend == BACKEND.IPEX and device.type == "cuda": + device = torch.device("cpu") if backend == BACKEND.SGLANG: subprocess.check_call([sys.executable, "-m", "pip", "install", "vllm>=0.6.2"]) diff --git a/examples/quantization/basic_usage.py b/examples/quantization/basic_usage.py index bc95e9a8..5b2c03fc 100644 --- a/examples/quantization/basic_usage.py +++ b/examples/quantization/basic_usage.py @@ -1,7 +1,7 @@ import os import torch -from gptqmodel import GPTQModel, QuantizeConfig +from gptqmodel import GPTQModel, QuantizeConfig, get_best_device from transformers import AutoTokenizer os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" @@ -12,8 +12,6 @@ def main(): tokenizer = AutoTokenizer.from_pretrained(pretrained_model_id, use_fast=True) - import pdb - pdb.set_trace() calibration_dataset = [ tokenizer( "gptqmodel is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm." @@ -52,7 +50,7 @@ def main(): model.save(quantized_model_id, use_safetensors=True) # load quantized model to the first GPU - device = "cuda:0" if torch.cuda.is_available() else "cpu" + device = get_best_device() model = GPTQModel.load(quantized_model_id, device=device) # load quantized model to CPU with IPEX kernel linear. 
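As a companion to the example updates above, here is a minimal usage sketch of the new `get_best_device()` helper together with backend selection. It is illustrative only and not part of the diff: the checkpoint path and prompt are placeholders reused from the examples and tests, whether a tokenizer is present at that path is an assumption, and the IPEX choice mirrors the CPU/XPU fallback logic added to `run_with_different_backends.py`.

```py
# Illustrative sketch only: placeholder checkpoint path and prompt.
from transformers import AutoTokenizer

from gptqmodel import BACKEND, GPTQModel, get_best_device

quantized_model_id = "./TinyLlama/TinyLlama-1.1B-Chat-v1.0-4bit-128g"  # placeholder path

device = get_best_device()  # cuda:0 -> xpu:0 -> cpu, per gptqmodel/models/_const.py

# Use the IPEX kernel on CPU/XPU; on CUDA, let GPTQModel pick its default kernels.
backend = BACKEND.IPEX if device.type in ("cpu", "xpu") else BACKEND.AUTO

tokenizer = AutoTokenizer.from_pretrained(quantized_model_id)
model = GPTQModel.load(quantized_model_id, device=device, backend=backend)

inp = tokenizer("I am in Paris and", return_tensors="pt").to(device)
print(tokenizer.decode(model.generate(**inp)[0]))
```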
diff --git a/gptqmodel/__init__.py b/gptqmodel/__init__.py index 9e50a59c..0f31aa46 100644 --- a/gptqmodel/__init__.py +++ b/gptqmodel/__init__.py @@ -1,4 +1,4 @@ -from .models import GPTQModel +from .models import GPTQModel, get_best_device from .quantization import BaseQuantizeConfig, QuantizeConfig from .utils import BACKEND, get_backend from .version import __version__ diff --git a/gptqmodel/models/__init__.py b/gptqmodel/models/__init__.py index bc2952b5..6c71cf9c 100644 --- a/gptqmodel/models/__init__.py +++ b/gptqmodel/models/__init__.py @@ -1,3 +1,4 @@ +from ._const import get_best_device from .auto import MODEL_MAP, GPTQModel from .base import BaseGPTQModel from .definitions import * diff --git a/gptqmodel/models/_const.py b/gptqmodel/models/_const.py index 8c637135..c7fd15a9 100644 --- a/gptqmodel/models/_const.py +++ b/gptqmodel/models/_const.py @@ -1,14 +1,21 @@ +import torch from enum import Enum - from torch import device CPU = device("cpu") CUDA = device("cuda") CUDA_0 = device("cuda:0") +XPU = device("xpu") +XPU_0 = device("xpu:0") class DEVICE(Enum): CPU = "cpu" CUDA = "cuda" + XPU = "xpu" + + +def is_torch_support_xpu(): + return hasattr(torch, "xpu") and torch.xpu.is_available() def get_device_by_type(type_value: str): @@ -17,6 +24,15 @@ def get_device_by_type(type_value: str): return enum_constant raise ValueError(f"Invalid type_value str: {type_value}") + +def get_best_device(): + if torch.cuda.is_available(): + return CUDA_0 + elif is_torch_support_xpu(): + return XPU_0 + else: + return CPU + SUPPORTED_MODELS = [ "bloom", "gptj", diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index 743be768..aeed1453 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -10,7 +10,7 @@ from packaging import version from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase, modeling_utils -from ._const import CPU, CUDA_0 +from ._const import CPU, CUDA_0, get_best_device from .loader import ModelLoader from .writer import QUANT_LOG_DAMP, QUANT_LOG_LAYER, QUANT_LOG_LOSS, QUANT_LOG_MODULE, QUANT_LOG_TIME, ModelWriter from ..quantization import GPTQ, QuantizeConfig @@ -194,6 +194,8 @@ def quantize( f"Unsupported quantization operation for quant method: {self.quantize_config.quant_method}" ) + best_device = get_best_device() + backend = BACKEND.AUTO if not torch.cuda.is_available(): self.quantize_config.format = FORMAT.IPEX @@ -270,11 +272,11 @@ def quantize( device_map = self.hf_device_map if device_map: for name, device in device_map.items(): - if device == "cpu" and torch.cuda.is_available(): + if device == "cpu" and best_device != CPU: logger.info(f"truly offloading {name} to cpu with hook.") module = get_module_by_name_suffix(self.model, name) remove_hook_from_module(module, recurse=True) - accelerate.cpu_offload_with_hook(module, CUDA_0) + accelerate.cpu_offload_with_hook(module, best_device) calibration_dataset = self._prepare_dataset_for_quantization(calibration_dataset, batch_size, tokenizer,) @@ -409,8 +411,8 @@ def store_input_hook(_, args, kwargs): raise ValueError force_layer_back_to_cpu = False - if get_device(layers[0]) == CPU and torch.cuda.is_available(): - layers[0] = layers[0].to(CUDA_0) + if get_device(layers[0]) == CPU and best_device != CPU: + layers[0] = layers[0].to(best_device) force_layer_back_to_cpu = True ori_outside_layer_module_devices = {} @@ -491,8 +493,8 @@ def store_input_hook(_, args, kwargs): gpu_memorys.append(gpu_memory) cpu_memorys.append(cpu_memory) force_layer_back_to_cpu = False - if 
get_device(layer) == CPU and torch.cuda.is_available():
-                move_to(layer, CUDA_0)
+            if get_device(layer) == CPU and best_device != CPU:
+                move_to(layer, best_device)
                 force_layer_back_to_cpu = True
             cur_layer_device = get_device(layer)
             full = find_layers(layer)
diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py
index ac09a73e..82e00a81 100644
--- a/gptqmodel/models/loader.py
+++ b/gptqmodel/models/loader.py
@@ -22,7 +22,7 @@ from ..utils.model import (auto_dtype_from_config, check_requires_version, convert_gptq_v1_to_v2_format,
                            find_layers, get_checkpoints, get_moe_layer_modules, gptqmodel_post_init, make_quant,
                            simple_dispatch_model, verify_model_hash, verify_sharded_model_hashes)
-from ._const import CPU, DEVICE, SUPPORTED_MODELS
+from ._const import get_best_device, is_torch_support_xpu, DEVICE, SUPPORTED_MODELS
 
 logger = setup_logger()
 
@@ -44,10 +44,10 @@ def from_pretrained(
                 pass
             except Exception as e:
                 raise ValueError(
-                    f"IPEX is not available: {e}. Please install with `pip install -U intel-extension-for-transformers`."
+                    f"IPEX is not available: {e}. Please install with `pip install -U intel-extension-for-pytorch`."
                 )
 
-            model_init_kwargs["device_map"] = "cpu"
+            model_init_kwargs["device_map"] = "xpu" if is_torch_support_xpu() else "cpu"
             torch_dtype = ipex_dtype()
 
         if cls.require_trust_remote_code and not trust_remote_code:
@@ -98,7 +98,10 @@ def skip(*args, **kwargs):
             raise TypeError(f"{config.model_type} isn't supported yet.")
 
         if model_init_kwargs.get("cpu") != "cpu":
-            torch.cuda.empty_cache()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            elif is_torch_support_xpu():
+                torch.xpu.empty_cache()
 
         model = cls.loader.from_pretrained(pretrained_model_id_or_path, **model_init_kwargs)
 
@@ -144,7 +147,7 @@ def from_quantized(
             os.environ['VLLM_ATTENTION_BACKEND'] = 'FLASHINFER'
 
         if backend == BACKEND.IPEX:
-            device = CPU
+            device = get_best_device()
             try:
                 pass
             except Exception as e:
@@ -157,7 +160,7 @@ def from_quantized(
 
         if backend != BACKEND.IPEX and not torch.cuda.is_available():
             raise EnvironmentError(
-                "Load pretrained model to do quantization requires CUDA gpu. Please set backend=BACKEND.IPEX for cpu only quantization and inference.")
+                "Loading a pretrained model for quantization requires a CUDA GPU. Please set backend=BACKEND.IPEX for CPU and XPU quantization and inference.")
 
         """load quantized model from local disk"""
         if cls.require_trust_remote_code and not trust_remote_code:
diff --git a/gptqmodel/nn_modules/qlinear/qlinear_ipex.py b/gptqmodel/nn_modules/qlinear/qlinear_ipex.py
index ff853725..cd623222 100644
--- a/gptqmodel/nn_modules/qlinear/qlinear_ipex.py
+++ b/gptqmodel/nn_modules/qlinear/qlinear_ipex.py
@@ -6,7 +6,7 @@ import torch
 import torch.nn as nn
 import transformers
 
-from gptqmodel.models._const import DEVICE
+from gptqmodel.models._const import is_torch_support_xpu, DEVICE
 from gptqmodel.nn_modules.qlinear import BaseQuantLinear
 
 from ...utils.logger import setup_logger
@@ -20,7 +20,7 @@ IPEX_AVAILABLE = False
 IPEX_ERROR_LOG = None
 try:
-    from intel_extension_for_pytorch.nn.modules.weight_only_quantization import WeightOnlyQuantizedLinear
+    from intel_extension_for_pytorch.llm.quantization import IPEXWeightOnlyQuantizedLinear
     IPEX_AVAILABLE = True
 except Exception:
     IPEX_ERROR_LOG = Exception
 
@@ -30,7 +30,7 @@ def ipex_dtype() -> torch.dtype:
         raise ImportError("intel_extension_for_pytorch not installed. 
" "Please install via 'pip install intel_extension_for_pytorch'") - return torch.bfloat16 + return torch.float16 if is_torch_support_xpu() else torch.bfloat16 def convert_dtype_torch2str(dtype): @@ -50,7 +50,7 @@ def convert_dtype_torch2str(dtype): class IPEXQuantLinear(BaseQuantLinear): SUPPORTS_BITS = [4] - SUPPORTS_DEVICES = [DEVICE.CPU] + SUPPORTS_DEVICES = [DEVICE.CPU, DEVICE.XPU] def __init__( self, @@ -63,7 +63,7 @@ def __init__( bias: bool, kernel_switch_threshold=128, training=False, - weight_dtype=torch.bfloat16, + weight_dtype=None, **kwargs, ): self.sym = False @@ -71,6 +71,8 @@ def __init__( if bits not in [4]: raise NotImplementedError("Only 4-bits is supported for IPEX.") + if weight_dtype is None: + weight_dtype = torch.float16 if is_torch_support_xpu() else torch.bfloat16 self.infeatures = infeatures self.outfeatures = outfeatures @@ -79,6 +81,7 @@ def __init__( self.maxq = 2**self.bits - 1 self.weight_dtype = weight_dtype self.asym = True + self.init_ipex = False self.register_buffer( "qweight", @@ -119,11 +122,13 @@ def __init__( def post_init(self): self.validate_device(self.qweight.device.type) - assert self.qweight.device.type == "cpu" + assert self.qweight.device.type in ("cpu", "xpu") + + def init_ipex_linear(self): if not self.training and IPEX_AVAILABLE: - self.ipex_linear = WeightOnlyQuantizedLinear.from_weight(self.qweight, self.scales, self.qzeros, \ + self.ipex_linear = IPEXWeightOnlyQuantizedLinear.from_weight(self.qweight, self.scales, self.qzeros, \ self.infeatures, self.outfeatures, None, self.bias, \ - self.group_size, self.g_idx, 0, 0) + self.group_size, self.g_idx, quant_method=0, dtype=0) def pack(self, linear, scales, zeros, g_idx=None): W = linear.weight.data.clone() @@ -183,8 +188,14 @@ def pack(self, linear, scales, zeros, g_idx=None): self.qzeros = torch.from_numpy(qzeros) def forward(self, x: torch.Tensor): - if not self.training and hasattr(self, "ipex_linear"): - return self.ipex_linear(x) + if not self.init_ipex: + self.init_ipex_linear() + self.init_ipex = True + + if hasattr(self, "ipex_linear"): + with torch.no_grad(): + outputs = self.ipex_linear(x) + return outputs out_shape = x.shape[:-1] + (self.outfeatures,) x = x.reshape(-1, x.shape[-1]) diff --git a/gptqmodel/nn_modules/qlinear/qlinear_tritonv2.py b/gptqmodel/nn_modules/qlinear/qlinear_tritonv2.py index 77736e94..17a1136c 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_tritonv2.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_tritonv2.py @@ -6,10 +6,15 @@ import transformers from ...utils.logger import setup_logger -from ..triton_utils.dequant import QuantLinearFunction from ..triton_utils.mixin import TritonModuleMixin from . 
import BaseQuantLinear +triton_import_exception = None +try: + from ..triton_utils.dequant import QuantLinearFunction +except ImportError as e: + triton_import_exception = e + logger = setup_logger() @@ -27,6 +32,10 @@ class TritonV2QuantLinear(BaseQuantLinear, TritonModuleMixin): """ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures, outfeatures, bias, **kwargs,): + if triton_import_exception is not None: + raise ValueError( + f"Trying to use the triton backend, but could not import the triton with the following error: {triton_import_exception}" + ) super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, infeatures=infeatures, outfeatures=outfeatures, **kwargs) self.infeatures = infeatures self.outfeatures = outfeatures diff --git a/setup.py b/setup.py index d088e4d2..a0aecbeb 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,7 @@ else: FORCE_BUILD = False +extensions = [] common_setup_kwargs = { "version": gptqmodel_version, "name": "gptqmodel", diff --git a/tests/test_quant_formats.py b/tests/test_quant_formats.py index 053fbc6f..66b1e1e2 100644 --- a/tests/test_quant_formats.py +++ b/tests/test_quant_formats.py @@ -12,7 +12,7 @@ import torch.cuda # noqa: E402 from datasets import load_dataset # noqa: E402 -from gptqmodel import BACKEND, GPTQModel, __version__ # noqa: E402 +from gptqmodel import BACKEND, GPTQModel, get_best_device, __version__ # noqa: E402 from gptqmodel.quantization import FORMAT, QUANT_CONFIG_FILENAME, QUANT_METHOD # noqa: E402 from gptqmodel.quantization.config import (META_FIELD_QUANTIZER, META_QUANTIZER_GPTQMODEL, # noqa: E402 AutoRoundQuantizeConfig, QuantizeConfig) @@ -79,7 +79,7 @@ def test_quantize(self, method: QUANT_METHOD, backend: BACKEND, sym: bool, forma model = GPTQModel.load( tmpdirname, - device="cuda:0" if backend != BACKEND.IPEX else "cpu", + device="cuda:0" if backend != BACKEND.IPEX else get_best_device(), backend=backend, ) @@ -106,7 +106,7 @@ def test_quantize(self, method: QUANT_METHOD, backend: BACKEND, sym: bool, forma model = GPTQModel.load( tmpdirname, - device="cuda:0" if backend != BACKEND.IPEX else "cpu", + device="cuda:0" if backend != BACKEND.IPEX else get_best_device(), quantize_config=compat_quantize_config, ) assert isinstance(model.quantize_config, QuantizeConfig) diff --git a/tests/test_save_loaded_quantized_model.py b/tests/test_save_loaded_quantized_model.py index 8c7266ea..edb7e1d6 100644 --- a/tests/test_save_loaded_quantized_model.py +++ b/tests/test_save_loaded_quantized_model.py @@ -7,7 +7,7 @@ import unittest # noqa: E402 import torch # noqa: E402 -from gptqmodel import BACKEND, GPTQModel # noqa: E402 +from gptqmodel import BACKEND, GPTQModel, get_best_device # noqa: E402 from parameterized import parameterized # noqa: E402 from transformers import AutoTokenizer # noqa: E402 @@ -26,7 +26,7 @@ class TestSave(unittest.TestCase): ) def test_save(self, backend): prompt = "I am in Paris and" - device = torch.device("cuda:0") if backend != BACKEND.IPEX else torch.device("cpu") + device = torch.device("cuda:0") if backend != BACKEND.IPEX else get_best_device() tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) inp = tokenizer(prompt, return_tensors="pt").to(device)
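For readers of the `qlinear_ipex.py` changes above: the fused IPEX linear is now built lazily on the first `forward()` call (guarded by the new `self.init_ipex` flag) instead of in `post_init()`, so it is constructed only after the module has reached its final CPU or XPU placement. The following is a simplified, dependency-free sketch of that pattern; `LazyFusedLinear` and `_init_fused` are illustrative names, not GPTQModel APIs, and the real code builds `IPEXWeightOnlyQuantizedLinear.from_weight(...)` at this point.

```py
# Simplified stand-in for the lazy-initialization pattern used in IPEXQuantLinear.
import torch


class LazyFusedLinear(torch.nn.Module):
    def __init__(self, weight: torch.Tensor, bias=None):
        super().__init__()
        self.weight = torch.nn.Parameter(weight, requires_grad=False)
        self.bias = bias
        self.initialized = False  # mirrors the new `self.init_ipex` flag
        self.fused = None

    def _init_fused(self):
        # Deferred until the first forward(), i.e. after .to("cpu") / .to("xpu") has run.
        self.fused = lambda x: torch.nn.functional.linear(x, self.weight, self.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.initialized:
            self._init_fused()
            self.initialized = True
        with torch.no_grad():
            return self.fused(x)
```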
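The `qlinear_tritonv2.py` change applies the same deferred-failure idea to the optional Triton dependency: the import error is captured once at module import time and only surfaced when a Triton kernel is actually constructed. A generic sketch of that guard follows; `TritonBackedLinear` is an illustrative class name, not part of GPTQModel.

```py
# Generic sketch of the deferred-import guard used for the Triton kernel.
triton_import_exception = None
try:
    import triton  # optional dependency; may be missing on CPU/XPU-only installs
except ImportError as e:
    triton_import_exception = e


class TritonBackedLinear:
    def __init__(self):
        # Surface the failure only when the Triton backend is actually requested.
        if triton_import_exception is not None:
            raise ValueError(
                f"Trying to use the Triton backend, but triton could not be imported: {triton_import_exception}"
            )
```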