Ipex XPU support #608

Merged
merged 13 commits on Nov 27, 2024
18 changes: 12 additions & 6 deletions README.md
@@ -89,7 +89,12 @@ Public tests/papers and ModelCloud's internal tests have shown that GPTQ is on-p

## Platform Requirements

GPTQModel is validated for Linux x86_64 with Nvidia GPUs. Windows WSL2 may work but un-tested.
GPTQModel is validated for Linux x86_64 with the following devices:
| Device | Supported |
| --- | --- |
| NV GPU | ✅ |
| Intel CPU | ✅ |
| Intel GPU | ✅ |

Windows WSL2 may work but is untested.
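
A quick way to confirm which of these devices will actually be picked at runtime is the `get_best_device()` helper this PR exports. A minimal sketch; the availability checks simply mirror the ones used inside the helper:

```py
import torch

from gptqmodel import get_best_device

# get_best_device() prefers CUDA, then Intel XPU, and falls back to CPU.
print("best device:", get_best_device())
print("cuda available:", torch.cuda.is_available())
print("xpu available:", hasattr(torch, "xpu") and torch.xpu.is_available())
```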

## Install

@@ -110,7 +115,7 @@ git clone https://github.com/ModelCloud/GPTQModel.git && cd GPTQModel

# pip: compile and install
# You can install optional modules like auto_round, ipex, vllm, sglang, and bitblas.
# Example: pip install -v --no-build-isolation gptqmodel[vllm,sglang,bitblas,ipex,auto_round]
# Example: pip install -v --no-build-isolation .[vllm,sglang,bitblas,ipex,auto_round]
pip install -v . --no-build-isolation
```

@@ -121,7 +126,7 @@ Below is a basic sample using `GPTQModel` to quantize a llm model and perform po
```py
from datasets import load_dataset
from transformers import AutoTokenizer
from gptqmodel import GPTQModel, QuantizeConfig
from gptqmodel import GPTQModel, QuantizeConfig, get_best_device

model_id = "meta-llama/Llama-3.2-1B-Instruct"
quant_path = "Llama-3.2-1B-Instruct-gptqmodel-4bit"
Expand All @@ -145,7 +150,7 @@ model.quantize(calibration_dataset)

model.save(quant_path)

model = GPTQModel.load(quant_path)
model = GPTQModel.load(quant_path, device=get_best_device())

result = model.generate(
**tokenizer(
@@ -171,8 +176,9 @@ pip install lm-eval[gptqmodel]

### Which kernel is used by default?

* `GPU`: Marlin, Exllama v2, Triton kernels in that order for maximum inference performance. Optional Microsoft/BITBLAS kernel can be toggled.
* `CPU`: Intel/IPEX kernel
* `GPU`: Marlin, Exllama v2, Triton kernels in that order for maximum inference performance. Optional Microsoft/BITBLAS kernel can be toggled.
* `CPU`: Intel/IPEX kernel
* `XPU`: Intel/IPEX kernel
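
To pin a specific kernel instead of relying on this automatic order, the backend can be passed explicitly at load time. A minimal sketch, assuming the `Llama-3.2-1B-Instruct-gptqmodel-4bit` quant path from the example above and that `GPTQModel.load` accepts the `backend` argument used elsewhere in this PR:

```py
from gptqmodel import BACKEND, GPTQModel

# Force the Intel/IPEX kernel, which serves both CPU and (with this PR) XPU.
model = GPTQModel.load("Llama-3.2-1B-Instruct-gptqmodel-4bit", backend=BACKEND.IPEX)
```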

## Citation
```
10 changes: 6 additions & 4 deletions examples/inference/run_with_different_backends.py
@@ -1,13 +1,14 @@
import os
import subprocess
import sys
import torch
from argparse import ArgumentParser

from gptqmodel import BACKEND, GPTQModel, QuantizeConfig, get_backend
from gptqmodel import BACKEND, GPTQModel, QuantizeConfig, get_backend, get_best_device
from transformers import AutoTokenizer

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
pretrained_model_id = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0" # "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
pretrained_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
quantized_model_id = "./TinyLlama/TinyLlama-1.1B-Chat-v1.0-4bit-128g"

def main():
@@ -18,8 +19,9 @@ def main():
args = parser.parse_args()

backend = get_backend(args.backend)

device = 'cpu' if backend == BACKEND.IPEX else 'cuda:0'
device = get_best_device()
if backend == BACKEND.IPEX and device.type == "cuda":
device = torch.device("cpu")

if backend == BACKEND.SGLANG:
subprocess.check_call([sys.executable, "-m", "pip", "install", "vllm>=0.6.2"])
6 changes: 2 additions & 4 deletions examples/quantization/basic_usage.py
@@ -1,7 +1,7 @@
import os

import torch
from gptqmodel import GPTQModel, QuantizeConfig
from gptqmodel import GPTQModel, QuantizeConfig, get_best_device
from transformers import AutoTokenizer

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
@@ -12,8 +12,6 @@

def main():
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_id, use_fast=True)
import pdb
pdb.set_trace()
calibration_dataset = [
tokenizer(
"gptqmodel is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
@@ -52,7 +50,7 @@ def main():
model.save(quantized_model_id, use_safetensors=True)

# load quantized model to the best available device
device = "cuda:0" if torch.cuda.is_available() else "cpu"
device = get_best_device()
model = GPTQModel.load(quantized_model_id, device=device)

# load quantized model to CPU with the IPEX kernel linear layer.
2 changes: 1 addition & 1 deletion gptqmodel/__init__.py
@@ -1,4 +1,4 @@
from .models import GPTQModel
from .models import GPTQModel, get_best_device
from .quantization import BaseQuantizeConfig, QuantizeConfig
from .utils import BACKEND, get_backend
from .version import __version__
1 change: 1 addition & 0 deletions gptqmodel/models/__init__.py
@@ -1,3 +1,4 @@
from ._const import get_best_device
from .auto import MODEL_MAP, GPTQModel
from .base import BaseGPTQModel
from .definitions import *
18 changes: 17 additions & 1 deletion gptqmodel/models/_const.py
@@ -1,14 +1,21 @@
import torch
from enum import Enum

from torch import device

CPU = device("cpu")
CUDA = device("cuda")
CUDA_0 = device("cuda:0")
XPU = device("xpu")
XPU_0 = device("xpu:0")

class DEVICE(Enum):
CPU = "cpu"
CUDA = "cuda"
XPU = "xpu"


def is_torch_support_xpu():
return hasattr(torch, "xpu") and torch.xpu.is_available()


def get_device_by_type(type_value: str):
@@ -17,6 +24,15 @@ def get_device_by_type(type_value: str):
return enum_constant
raise ValueError(f"Invalid type_value str: {type_value}")


def get_best_device():
if torch.cuda.is_available():
return CUDA_0
elif is_torch_support_xpu():
return XPU_0
else:
return CPU

SUPPORTED_MODELS = [
"bloom",
"gptj",
16 changes: 9 additions & 7 deletions gptqmodel/models/base.py
@@ -10,7 +10,7 @@
from packaging import version
from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase, modeling_utils

from ._const import CPU, CUDA_0
from ._const import CPU, CUDA_0, get_best_device
from .loader import ModelLoader
from .writer import QUANT_LOG_DAMP, QUANT_LOG_LAYER, QUANT_LOG_LOSS, QUANT_LOG_MODULE, QUANT_LOG_TIME, ModelWriter
from ..quantization import GPTQ, QuantizeConfig
@@ -194,6 +194,8 @@ def quantize(
f"Unsupported quantization operation for quant method: {self.quantize_config.quant_method}"
)

best_device = get_best_device()

backend = BACKEND.AUTO
if not torch.cuda.is_available():
self.quantize_config.format = FORMAT.IPEX
@@ -270,11 +272,11 @@ def quantize(
device_map = self.hf_device_map
if device_map:
for name, device in device_map.items():
if device == "cpu" and torch.cuda.is_available():
if device == "cpu" and best_device != CPU:
logger.info(f"truly offloading {name} to cpu with hook.")
module = get_module_by_name_suffix(self.model, name)
remove_hook_from_module(module, recurse=True)
accelerate.cpu_offload_with_hook(module, CUDA_0)
accelerate.cpu_offload_with_hook(module, best_device)

calibration_dataset = self._prepare_dataset_for_quantization(calibration_dataset, batch_size, tokenizer,)

@@ -409,8 +411,8 @@ def store_input_hook(_, args, kwargs):
raise ValueError

force_layer_back_to_cpu = False
if get_device(layers[0]) == CPU and torch.cuda.is_available():
layers[0] = layers[0].to(CUDA_0)
if get_device(layers[0]) == CPU and best_device != CPU:
layers[0] = layers[0].to(best_device)
force_layer_back_to_cpu = True

ori_outside_layer_module_devices = {}
@@ -491,8 +493,8 @@ def store_input_hook(_, args, kwargs):
gpu_memorys.append(gpu_memory)
cpu_memorys.append(cpu_memory)
force_layer_back_to_cpu = False
if get_device(layer) == CPU and torch.cuda.is_available():
move_to(layer, CUDA_0)
if get_device(layer) == CPU and best_device != CPU:
move_to(layer, best_device)
force_layer_back_to_cpu = True
cur_layer_device = get_device(layer)
full = find_layers(layer)
15 changes: 9 additions & 6 deletions gptqmodel/models/loader.py
@@ -22,7 +22,7 @@
from ..utils.model import (auto_dtype_from_config, check_requires_version, convert_gptq_v1_to_v2_format,
find_layers, get_checkpoints, get_moe_layer_modules, gptqmodel_post_init, make_quant,
simple_dispatch_model, verify_model_hash, verify_sharded_model_hashes)
from ._const import CPU, DEVICE, SUPPORTED_MODELS
from ._const import get_best_device, is_torch_support_xpu, DEVICE, SUPPORTED_MODELS

logger = setup_logger()

@@ -44,10 +44,10 @@ def from_pretrained(
pass
except Exception as e:
raise ValueError(
f"IPEX is not available: {e}. Please install with `pip install -U intel-extension-for-transformers`."
f"IPEX is not available: {e}. Please install with `pip install -U intel-extension-for-ipex`."
)

model_init_kwargs["device_map"] = "cpu"
model_init_kwargs["device_map"] = "xpu" if is_torch_support_xpu() else "cpu"
torch_dtype = ipex_dtype()

if cls.require_trust_remote_code and not trust_remote_code:
@@ -98,7 +98,10 @@ def skip(*args, **kwargs):
raise TypeError(f"{config.model_type} isn't supported yet.")

if model_init_kwargs.get("device_map") != "cpu":
torch.cuda.empty_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif is_torch_support_xpu():
torch.xpu.empty_cache()

model = cls.loader.from_pretrained(pretrained_model_id_or_path, **model_init_kwargs)

@@ -144,7 +147,7 @@ def from_quantized(
os.environ['VLLM_ATTENTION_BACKEND'] = 'FLASHINFER'

if backend == BACKEND.IPEX:
device = CPU
device = get_best_device()
try:
pass
except Exception as e:
Expand All @@ -157,7 +160,7 @@ def from_quantized(

if backend != BACKEND.IPEX and not torch.cuda.is_available():
raise EnvironmentError(
"Load pretrained model to do quantization requires CUDA gpu. Please set backend=BACKEND.IPEX for cpu only quantization and inference.")
"Load pretrained model to do quantization requires CUDA gpu. Please set backend=BACKEND.IPEX for cpu and xpu quantization and inference.")

"""load quantized model from local disk"""
if cls.require_trust_remote_code and not trust_remote_code:
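The device-aware cache clearing added to the loader above generalizes into a small standalone helper. The sketch below is illustrative only: the helper name is ours, and the `hasattr` guard stands in for `is_torch_support_xpu`, since `torch.xpu` only exists on XPU-enabled PyTorch/IPEX builds.

```py
import torch

def empty_accelerator_cache() -> None:
    # Mirror the loader change: release cached memory on whichever
    # accelerator backend is actually present; do nothing on CPU-only hosts.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()
```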
31 changes: 21 additions & 10 deletions gptqmodel/nn_modules/qlinear/qlinear_ipex.py
@@ -6,7 +6,7 @@
import torch
import torch.nn as nn
import transformers
from gptqmodel.models._const import DEVICE
from gptqmodel.models._const import is_torch_support_xpu, DEVICE
from gptqmodel.nn_modules.qlinear import BaseQuantLinear

from ...utils.logger import setup_logger
@@ -20,7 +20,7 @@
IPEX_AVAILABLE = False
IPEX_ERROR_LOG = None
try:
from intel_extension_for_pytorch.nn.modules.weight_only_quantization import WeightOnlyQuantizedLinear
from intel_extension_for_pytorch.llm.quantization import IPEXWeightOnlyQuantizedLinear
IPEX_AVAILABLE = True
except Exception as e:
    IPEX_ERROR_LOG = e
@@ -30,7 +30,7 @@ def ipex_dtype() -> torch.dtype:
raise ImportError("intel_extension_for_pytorch not installed. "
"Please install via 'pip install intel_extension_for_pytorch'")

return torch.bfloat16
return torch.float16 if is_torch_support_xpu() else torch.bfloat16


def convert_dtype_torch2str(dtype):
Expand All @@ -50,7 +50,7 @@ def convert_dtype_torch2str(dtype):

class IPEXQuantLinear(BaseQuantLinear):
SUPPORTS_BITS = [4]
SUPPORTS_DEVICES = [DEVICE.CPU]
SUPPORTS_DEVICES = [DEVICE.CPU, DEVICE.XPU]

def __init__(
self,
@@ -63,14 +63,16 @@ def __init__(
bias: bool,
kernel_switch_threshold=128,
training=False,
weight_dtype=torch.bfloat16,
weight_dtype=None,
**kwargs,
):
self.sym = False
super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, infeatures=infeatures, outfeatures=outfeatures, **kwargs)

if bits not in [4]:
raise NotImplementedError("Only 4-bits is supported for IPEX.")
if weight_dtype is None:
weight_dtype = torch.float16 if is_torch_support_xpu() else torch.bfloat16

self.infeatures = infeatures
self.outfeatures = outfeatures
@@ -79,6 +81,7 @@ def __init__(
self.maxq = 2**self.bits - 1
self.weight_dtype = weight_dtype
self.asym = True
self.init_ipex = False

self.register_buffer(
"qweight",
@@ -119,11 +122,13 @@ def __init__(

def post_init(self):
self.validate_device(self.qweight.device.type)
assert self.qweight.device.type == "cpu"
assert self.qweight.device.type in ("cpu", "xpu")

def init_ipex_linear(self):
if not self.training and IPEX_AVAILABLE:
self.ipex_linear = WeightOnlyQuantizedLinear.from_weight(self.qweight, self.scales, self.qzeros, \
self.ipex_linear = IPEXWeightOnlyQuantizedLinear.from_weight(self.qweight, self.scales, self.qzeros, \
self.infeatures, self.outfeatures, None, self.bias, \
self.group_size, self.g_idx, 0, 0)
self.group_size, self.g_idx, quant_method=0, dtype=0)

def pack(self, linear, scales, zeros, g_idx=None):
W = linear.weight.data.clone()
@@ -183,8 +188,14 @@ def pack(self, linear, scales, zeros, g_idx=None):
self.qzeros = torch.from_numpy(qzeros)

def forward(self, x: torch.Tensor):
if not self.training and hasattr(self, "ipex_linear"):
return self.ipex_linear(x)
if not self.init_ipex:
self.init_ipex_linear()
self.init_ipex = True

if hasattr(self, "ipex_linear"):
with torch.no_grad():
outputs = self.ipex_linear(x)
return outputs

out_shape = x.shape[:-1] + (self.outfeatures,)
x = x.reshape(-1, x.shape[-1])
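The forward-path change above defers building the fused IPEX linear until the first call, once the quantized buffers sit on their final CPU/XPU device. Below is a stripped-down sketch of that lazy-init pattern; it is illustrative only and wraps a plain `nn.Linear` where the PR calls `IPEXWeightOnlyQuantizedLinear.from_weight(...)`.

```py
import torch
import torch.nn as nn

class LazyKernelLinear(nn.Module):
    """Illustrative stand-in for the PR's pattern: construct the fused
    kernel object on the first forward pass, not in __init__."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.inner = linear
        self.kernel = None  # like self.ipex_linear, built lazily

    def _init_kernel(self) -> None:
        # The PR builds IPEXWeightOnlyQuantizedLinear.from_weight(...) here;
        # this sketch simply reuses the wrapped module.
        self.kernel = self.inner

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.kernel is None:
            self._init_kernel()
        with torch.no_grad():
            return self.kernel(x)
```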
11 changes: 10 additions & 1 deletion gptqmodel/nn_modules/qlinear/qlinear_tritonv2.py
@@ -6,10 +6,15 @@
import transformers

from ...utils.logger import setup_logger
from ..triton_utils.dequant import QuantLinearFunction
from ..triton_utils.mixin import TritonModuleMixin
from . import BaseQuantLinear

triton_import_exception = None
try:
from ..triton_utils.dequant import QuantLinearFunction
except ImportError as e:
triton_import_exception = e

logger = setup_logger()


@@ -27,6 +32,10 @@ class TritonV2QuantLinear(BaseQuantLinear, TritonModuleMixin):
"""

def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures, outfeatures, bias, **kwargs,):
if triton_import_exception is not None:
raise ValueError(
f"Trying to use the triton backend, but could not import the triton with the following error: {triton_import_exception}"
)
super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, infeatures=infeatures, outfeatures=outfeatures, **kwargs)
self.infeatures = infeatures
self.outfeatures = outfeatures
1 change: 1 addition & 0 deletions setup.py
@@ -29,6 +29,7 @@
else:
FORCE_BUILD = False

extensions = []
common_setup_kwargs = {
"version": gptqmodel_version,
"name": "gptqmodel",