Test base apply quant config #102

Open · wants to merge 11 commits into base: upstream-a564d10af
1 change: 1 addition & 0 deletions src/transformers/__init__.py
@@ -916,6 +916,7 @@
"AqlmConfig",
"AwqConfig",
"BitsAndBytesConfig",
"CompressedTensorsConfig",
"EetqConfig",
"GPTQConfig",
"HqqConfig",
2 changes: 1 addition & 1 deletion src/transformers/modeling_utils.py
@@ -3826,7 +3826,7 @@ def from_pretrained(
dispatch_model(model, **device_map_kwargs)

if hf_quantizer is not None:
hf_quantizer.postprocess_model(model)
hf_quantizer.postprocess_model(model, resolved_archive_file=resolved_archive_file)
model.hf_quantizer = hf_quantizer

if _adapter_model_path is not None:
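For context (not part of the diff): forwarding `resolved_archive_file` here works because the base `HfQuantizer.postprocess_model` passes its keyword arguments straight through to the subclass hook `_process_model_after_weight_loading`. A minimal sketch of that call chain, assuming the base class forwards `**kwargs` as it does in current `transformers`:

```python
# Illustrative sketch of the call chain this change relies on (simplified;
# the real HfQuantizer.postprocess_model lives in src/transformers/quantizers/base.py).
class HfQuantizerBaseSketch:
    def postprocess_model(self, model, **kwargs):
        # kwargs (e.g. resolved_archive_file=...) are handed to the subclass hook
        return self._process_model_after_weight_loading(model, **kwargs)

    def _process_model_after_weight_loading(self, model, **kwargs):
        raise NotImplementedError
```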
4 changes: 4 additions & 0 deletions src/transformers/quantizers/auto.py
@@ -19,6 +19,7 @@
AqlmConfig,
AwqConfig,
BitsAndBytesConfig,
CompressedTensorsConfig,
EetqConfig,
GPTQConfig,
HqqConfig,
@@ -30,6 +31,7 @@
from .quantizer_awq import AwqQuantizer
from .quantizer_bnb_4bit import Bnb4BitHfQuantizer
from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
from .quantizer_compressed_tensors import CompressedTensorsHfQuantizer
from .quantizer_eetq import EetqHfQuantizer
from .quantizer_gptq import GptqHfQuantizer
from .quantizer_hqq import HqqHfQuantizer
@@ -45,6 +47,7 @@
"quanto": QuantoHfQuantizer,
"eetq": EetqHfQuantizer,
"hqq": HqqHfQuantizer,
"compressed-tensors": CompressedTensorsHfQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -56,6 +59,7 @@
"aqlm": AqlmConfig,
"quanto": QuantoConfig,
"hqq": HqqConfig,
"compressed-tensors": CompressedTensorsConfig,
}


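To illustrate what the two new mapping entries provide (a sketch only; the actual dispatch happens inside `AutoHfQuantizer`/`AutoQuantizationConfig` rather than in this hunk): a checkpoint whose quantization config declares `quant_method: "compressed-tensors"` now resolves to the new config and quantizer classes.

```python
# Illustrative lookup sketch, assuming this PR is installed; the dict below is a
# stub standing in for the quantization_config section of a real checkpoint.
from transformers.quantizers.auto import (
    AUTO_QUANTIZATION_CONFIG_MAPPING,
    AUTO_QUANTIZER_MAPPING,
)

quantization_config_dict = {"quant_method": "compressed-tensors"}

quant_method = quantization_config_dict["quant_method"]
config_cls = AUTO_QUANTIZATION_CONFIG_MAPPING[quant_method]  # CompressedTensorsConfig
quantizer_cls = AUTO_QUANTIZER_MAPPING[quant_method]         # CompressedTensorsHfQuantizer
print(config_cls.__name__, quantizer_cls.__name__)
```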
71 changes: 71 additions & 0 deletions src/transformers/quantizers/quantizer_compressed_tensors.py
@@ -0,0 +1,71 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import is_torch_available, logging
from ..utils.quantization_config import QuantizationConfigMixin
from .base import HfQuantizer


if is_torch_available():
import torch

logger = logging.get_logger(__name__)


class CompressedTensorsHfQuantizer(HfQuantizer):
"""
    Quantizer for the compressed_tensors package. Loads and restores models to a quantized state with compressed_tensors.
"""

requires_calibration = False
# requires_parameters_quantization = True
required_packages = ["compressed_tensors"]

def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
super().__init__(quantization_config, **kwargs)

from compressed_tensors.compressors import ModelCompressor

self.compressor = ModelCompressor.from_compression_config(quantization_config)

def validate_environment(self, *args, **kwargs):
# check torch and compressed_tensors are available, let ImportError raise otherwise
import torch # noqa
from compressed_tensors.compressors import ModelCompressor # noqa

def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
if torch_dtype is None:
torch_dtype = torch.float16
elif torch_dtype != torch.float16:
logger.info(
"We suggest you to set `torch_dtype=torch.float16` for better efficiency with compressed_tensors."
)
return torch_dtype

def _process_model_before_weight_loading(self, model, **kwargs):
if self.quantization_config.quantization_config is not None:
from compressed_tensors.quantization import apply_quantization_config

apply_quantization_config(model, self.quantization_config.quantization_config)

def _process_model_after_weight_loading(self, model, resolved_archive_file, **kwargs):
self.compressor.decompress(model_path=resolved_archive_file, model=model)

@property
def is_trainable(self):
return False

@property
def is_serializable(self):
return False
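A rough end-to-end usage sketch of the quantizer added above. The model id is taken from the test file further down and is only an example; any checkpoint whose config declares `quant_method: "compressed-tensors"` is expected to follow the same path via the entry registered in `auto.py`.

```python
# Rough usage sketch based on this diff, not a guaranteed API contract.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Example checkpoint taken from the tests below.
model_id = "nm-testing/tinyllama-oneshot-w8a8-test-static-shape-change-v3"

tokenizer = AutoTokenizer.from_pretrained(model_id)
# _process_model_before_weight_loading applies the quantization scheme to the
# skeleton model, then postprocess_model decompresses the stored tensors.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)

inputs = tokenizer("Paris is the capital of which country?", return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_length=50)
print(tokenizer.decode(output_ids[0]))
```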
77 changes: 75 additions & 2 deletions src/transformers/utils/quantization_config.py
@@ -18,11 +18,14 @@
import importlib.metadata
import json
import os
from dataclasses import dataclass
from dataclasses import asdict, dataclass, is_dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from compressed_tensors.quantization.quant_config import QuantizationStatus
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from packaging import version
from pydantic import BaseModel

from ..utils import is_auto_awq_available, is_hqq_available, is_torch_available, logging

@@ -42,6 +45,7 @@ class QuantizationMethod(str, Enum):
QUANTO = "quanto"
EETQ = "eetq"
HQQ = "hqq"
COMPRESSED_TENSORS = "compressed-tensors"


class AWQLinearVersion(str, Enum):
@@ -67,6 +71,23 @@ class AwqBackendPackingMethod(str, Enum):
LLMAWQ = "llm-awq"


def convert_to_dict(obj):
if is_dataclass(obj):
return asdict(obj)
elif isinstance(obj, BaseModel):
return obj.dict()
elif isinstance(obj, Enum):
return obj.value
elif isinstance(obj, dict):
return {k: convert_to_dict(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [convert_to_dict(i) for i in obj]
elif isinstance(obj, tuple):
return tuple(convert_to_dict(i) for i in obj)
else:
return obj


@dataclass
class QuantizationConfigMixin:
"""
@@ -130,7 +151,7 @@ def to_dict(self) -> Dict[str, Any]:
Serializes this instance to a Python dictionary. Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
return copy.deepcopy(self.__dict__)
return convert_to_dict(copy.deepcopy(self.__dict__))

def __iter__(self):
"""allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
@@ -1038,3 +1059,55 @@ def post_init(self):
accepted_weights = ["int8"]
if self.weights not in accepted_weights:
raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights}")


@dataclass
class CompressedTensorsConfig(QuantizationConfigMixin):
"""
This is a wrapper class that handles compressed-tensors quantization config options.
It is a wrapper around `compressed_tensors.QuantizationConfig`

Args:
weights (`str`, *optional*, defaults to `"int8"`):
The target dtype for the weights. Supported value is only "int8"
modules_to_not_convert (`list`, *optional*, default to `None`):
The list of modules to not quantize, useful for quantizing models that explicitly require to have
some modules left in their original precision.
"""

def __init__(
self,
config_groups: Dict[str, Union["QuantizationScheme", List[str]]] = None,
quant_method: str = "compressed-tensors",
format: str = "dense",
quantization_status: "QuantizationStatus" = "initialized",
global_compression_ratio: Optional[float] = None,
ignore: Optional[List[str]] = None,
sparsity_config: Dict[str, Any] = None,
**kwargs,
):
from compressed_tensors import QuantizationConfig
from compressed_tensors.config import SparsityCompressionConfig

self.quantization_config = None
self.sparsity_config = None

# parse from dict to load nested QuantizationScheme objects
if config_groups:
self.quantization_config = QuantizationConfig.parse_obj(
{
"config_groups": config_groups,
"quant_method": quant_method,
"format": format,
"quantization_status": quantization_status,
"global_compression_ratio": global_compression_ratio,
"ignore": ignore,
}
)

if sparsity_config:
self.sparsity_config = SparsityCompressionConfig.load_from_registry(
sparsity_config.get("format"), **sparsity_config
)

super().__init__(quant_method=QuantizationMethod.COMPRESSED_TENSORS)
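For readers unfamiliar with the compressed-tensors schema, a hypothetical hand-built config might look like the sketch below. The field names inside `config_groups` (`targets`, `weights`, `num_bits`, `type`, `symmetric`, `strategy`) follow the `compressed_tensors` `QuantizationScheme`/`QuantizationArgs` models as understood at the time of this PR; verify them against the installed version.

```python
# Hypothetical sketch: building a CompressedTensorsConfig by hand. The keys
# inside config_groups are assumptions about the compressed_tensors schema.
from transformers import CompressedTensorsConfig

config_groups = {
    "group_0": {
        "targets": ["Linear"],  # quantize every Linear module...
        "weights": {"num_bits": 8, "type": "int", "symmetric": True, "strategy": "tensor"},
    }
}

quant_config = CompressedTensorsConfig(
    config_groups=config_groups,
    format="dense",
    ignore=["lm_head"],  # ...except the output head
)
# to_dict() relies on convert_to_dict above to flatten the nested pydantic objects
print(quant_config.to_dict())
```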
Empty file.
77 changes: 77 additions & 0 deletions tests/quantization/compressed_tensor/test_compressed_tensors.py
@@ -0,0 +1,77 @@

import gc
import unittest

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, CompressedTensorsConfig
from transformers.testing_utils import slow


class CompressedTensorsTest(unittest.TestCase):
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
source_quantized_model_name = "nm-testing/tinyllama-oneshot-w8a8-test-static-shape-change-v3"

prompt = "Paris is the capital of which country?"

def tear_down(self):
gc.collect()
torch.cuda.empty_cache()
gc.collect()

    @classmethod
    def setUpClass(cls):
        """
        Setup quantized model
        """
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.source_quantized_model_name)
        cls.source_quantized_model = AutoModelForCausalLM.from_pretrained(cls.source_quantized_model_name)

        cls.device = cls.source_quantized_model.device
        compression_config = cls.source_quantized_model.config.quantization_config.quantization_config.config_groups

        cls.config = CompressedTensorsConfig(
            config_groups=compression_config,
            sparsity_config=cls.source_quantized_model.config.quantization_config.sparsity_config.dict(),
        )

        # use plain asserts: unittest's assert helpers are instance methods and are not usable in setUpClass
        assert cls.config.sparsity_config is not None, "sparsity_config should not be None"
        assert cls.config.quantization_config is not None, "quantization_config should not be None"

        # apply quantization config to the base model
        cls.quantized_model = AutoModelForCausalLM.from_pretrained(cls.model_name, quantization_config=cls.config)

def test_quantized_model(self):
"""Carry out generation"""
inputs = self.tokenizer(self.prompt, return_tensors="pt").to(self.device)
generated_ids = self.quantized_model.generate(**inputs, max_length=50)
outputs = self.tokenizer.batch_decode(generated_ids)

self.assertIsNotNone(outputs)
self.tear_down()

def test_quantized_model_initialized_scale_zero_point(self):
for _, module in self.quantized_model.named_modules():
scheme = getattr(module, "quantization_scheme", None)
if scheme is not None:
self.assertIsNotNone(module.weight_scale)
self.assertIsNotNone(module.weight_zero_point)

@slow
def test_forward(self):
        batch_size = 1
        context_size = 2048
        # random token ids in [0, 1000)
        tensor1 = (torch.rand(1024) * 1000).long()
        tensor2 = (torch.rand(1024) * 1000).long()

        input_tensor = torch.cat((tensor1, tensor2), dim=0)
        input_tensor = input_tensor.unsqueeze(0)
        with torch.no_grad():
            out = self.quantized_model(input_tensor)
        # a causal LM forward pass returns a ModelOutput; check the logits shape
        self.assertEqual(out.logits.shape[0], batch_size)
        self.assertEqual(out.logits.shape[1], context_size)

self.tear_down()