From d695ec3ea08a628774293ae98e904a380c181fd9 Mon Sep 17 00:00:00 2001
From: Benjamin
Date: Wed, 5 Jun 2024 16:00:16 -0400
Subject: [PATCH 01/11] Add compressed-tensors HFQuantizer implementation

---
 src/transformers/modeling_utils.py | 2 +-
 src/transformers/quantizers/auto.py | 4 ++
 .../quantizer_compressed_tensors.py | 69 +++++++++++++++++++
 src/transformers/utils/quantization_config.py | 53 ++++++++++++++
 4 files changed, 127 insertions(+), 1 deletion(-)
 create mode 100644 src/transformers/quantizers/quantizer_compressed_tensors.py

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 27f26e42a84a3b..713f454118263e 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -3826,7 +3826,7 @@ def from_pretrained(
             dispatch_model(model, **device_map_kwargs)
 
         if hf_quantizer is not None:
-            hf_quantizer.postprocess_model(model)
+            hf_quantizer.postprocess_model(model, resolved_archive_file=resolved_archive_file)
             model.hf_quantizer = hf_quantizer
 
         if _adapter_model_path is not None:
diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py
index 2c65afa77e282c..f2922ee9677de9 100755
--- a/src/transformers/quantizers/auto.py
+++ b/src/transformers/quantizers/auto.py
@@ -19,6 +19,7 @@
     AqlmConfig,
     AwqConfig,
     BitsAndBytesConfig,
+    CompressedTensorsConfig,
     EetqConfig,
     GPTQConfig,
     HqqConfig,
@@ -30,6 +31,7 @@
 from .quantizer_awq import AwqQuantizer
 from .quantizer_bnb_4bit import Bnb4BitHfQuantizer
 from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
+from .quantizer_compressed_tensors import CompressedTensorsHfQuantizer
 from .quantizer_eetq import EetqHfQuantizer
 from .quantizer_gptq import GptqHfQuantizer
 from .quantizer_hqq import HqqHfQuantizer
@@ -45,6 +47,7 @@
     "quanto": QuantoHfQuantizer,
     "eetq": EetqHfQuantizer,
     "hqq": HqqHfQuantizer,
+    "compressed_tensors": CompressedTensorsHfQuantizer,
 }
 
 AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -56,6 +59,7 @@
     "aqlm": AqlmConfig,
     "quanto": QuantoConfig,
     "hqq": HqqConfig,
+    "compressed_tensors": CompressedTensorsConfig,
 }
diff --git a/src/transformers/quantizers/quantizer_compressed_tensors.py b/src/transformers/quantizers/quantizer_compressed_tensors.py
new file mode 100644
index 00000000000000..a201f504fc4efa
--- /dev/null
+++ b/src/transformers/quantizers/quantizer_compressed_tensors.py
@@ -0,0 +1,69 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .base import HfQuantizer
+
+
+from ..utils import is_torch_available, logging
+from ..utils.quantization_config import QuantizationConfigMixin
+
+
+if is_torch_available():
+    import torch
+
+logger = logging.get_logger(__name__)
+
+
+class CompressedTensorsHfQuantizer(HfQuantizer):
+    """
+    Quantizer for the compressed_tensors package. Loads and restores models to
+    quantized state with compressed_tensors
+    """
+
+    requires_calibration = False
+    # requires_parameters_quantization = True
+    required_packages = ["compressed_tensors"]
+
+    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
+        super().__init__(quantization_config, **kwargs)
+
+        from compressed_tensors.compressors import ModelCompressor
+        self.compressor = ModelCompressor.from_compression_config(quantization_config)
+
+    def validate_environment(self, *args, **kwargs):
+        # check torch and compressed_tensors are available, let ImportError raise otherwise
+        import torch
+        from compressed_tensors.compressors import ModelCompressor
+
+    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
+        if torch_dtype is None:
+            torch_dtype = torch.float16
+        elif torch_dtype != torch.float16:
+            logger.info("We suggest you to set `torch_dtype=torch.float16` for better efficiency with compressed_tensors.")
+        return torch_dtype
+
+    def _process_model_before_weight_loading(self, model, **kwargs):
+        if self.quantization_config.quantization_config is not None:
+            from compressed_tensors.quantization import apply_quantization_config
+            apply_quantization_config(model, self.quantization_config.quantization_config)
+
+    def _process_model_after_weight_loading(self, model, resolved_archive_file, **kwargs):
+        self.compressor.decompress(model_path=resolved_archive_file, model=model)
+
+    @property
+    def is_trainable(self):
+        return False
+
+    @property
+    def is_serializable(self):
+        return True
diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py
index f9e503cf862f18..496501557fdea5 100755
--- a/src/transformers/utils/quantization_config.py
+++ b/src/transformers/utils/quantization_config.py
@@ -42,6 +42,7 @@ class QuantizationMethod(str, Enum):
     QUANTO = "quanto"
     EETQ = "eetq"
     HQQ = "hqq"
+    COMPRESSED_TENSORS = "compressed_tensors"
 
 
 class AWQLinearVersion(str, Enum):
@@ -1038,3 +1039,55 @@ def post_init(self):
         accepted_weights = ["int8"]
         if self.weights not in accepted_weights:
             raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights}")
+
+
+@dataclass
+class CompressedTensorsConfig(QuantizationConfigMixin):
+    """
+    This is a wrapper class that handles compressed-tensors quantization config options.
+    It is a wrapper around `compressed_tensors.QuantizationConfig`
+
+    Args:
+        weights (`str`, *optional*, defaults to `"int8"`):
+            The target dtype for the weights. Supported value is only "int8"
+        modules_to_not_convert (`list`, *optional*, default to `None`):
+            The list of modules to not quantize, useful for quantizing models that explicitly require to have
+            some modules left in their original precision.
+    """
+
+    def __init__(
+        self,
+        config_groups: Dict[str, Union["QuantizationScheme", List[str]]] = None,
+        quant_method: str = "sparseml",
+        format: str = "fakequant",
+        quantization_status: "QuantizationStatus" = "initialized",
+        global_compression_ratio: Optional[float] = None,
+        ignore: Optional[List[str]] = None,
+        sparsity_config: Dict[str, Any] = None,
+        **kwargs,
+    ):
+        from compressed_tensors import QuantizationConfig
+        from compressed_tensors.config import SparsityCompressionConfig
+
+        self.quantization_config = None
+        self.sparsity_configq = None
+
+        # parse from dict to load nested QuantizationScheme objects
+        if config_groups:
+            self.quantization_config = QuantizationConfig.parse_obj(
+                dict(
+                    config_groups=config_groups,
+                    quant_method=quant_method,
+                    format=format,
+                    quantization_status=quantization_status,
+                    global_compression_ratio=global_compression_ratio,
+                    ignore=ignore,
+                )
+            )
+
+        if sparsity_config:
+            self.sparsity_config = SparsityCompressionConfig.load_from_registry(
+                sparsity_config.get("format"), **sparsity_config
+            )
+
+        super().__init__(quant_method=QuantizationMethod.COMPRESSED_TENSORS)

From f4689647e5620fea55118fcafc414bedc76759a0 Mon Sep 17 00:00:00 2001
From: Benjamin
Date: Wed, 5 Jun 2024 16:10:08 -0400
Subject: [PATCH 02/11] flag serializable as False

---
 src/transformers/quantizers/quantizer_compressed_tensors.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/quantizers/quantizer_compressed_tensors.py b/src/transformers/quantizers/quantizer_compressed_tensors.py
index a201f504fc4efa..d1ebe4dc664038 100644
--- a/src/transformers/quantizers/quantizer_compressed_tensors.py
+++ b/src/transformers/quantizers/quantizer_compressed_tensors.py
@@ -66,4 +66,4 @@ def is_trainable(self):
 
     @property
     def is_serializable(self):
-        return True
+        return False

From 41224d3d5f0f194b87ec1099bba1e42aff0056f1 Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Mon, 10 Jun 2024 19:56:33 +0000
Subject: [PATCH 03/11] run

---
 .../quantizer_compressed_tensors.py | 14 +++++++-----
 src/transformers/utils/quantization_config.py | 22 ++++++++++---------
 2 files changed, 20 insertions(+), 16 deletions(-)

diff --git a/src/transformers/quantizers/quantizer_compressed_tensors.py b/src/transformers/quantizers/quantizer_compressed_tensors.py
index d1ebe4dc664038..b24700d38c4e6c 100644
--- a/src/transformers/quantizers/quantizer_compressed_tensors.py
+++ b/src/transformers/quantizers/quantizer_compressed_tensors.py
@@ -11,11 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .base import HfQuantizer
-
-
 from ..utils import is_torch_available, logging
 from ..utils.quantization_config import QuantizationConfigMixin
+from .base import HfQuantizer
 
 
 if is_torch_available():
@@ -38,23 +36,27 @@ def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
         super().__init__(quantization_config, **kwargs)
 
         from compressed_tensors.compressors import ModelCompressor
+
+        # self.compressor = ModelCompressor.from_compression_config(quantization_config.to_dict())
         self.compressor = ModelCompressor.from_compression_config(quantization_config)
 
     def validate_environment(self, *args, **kwargs):
         # check torch and compressed_tensors are available, let ImportError raise otherwise
-        import torch
-        from compressed_tensors.compressors import ModelCompressor
+        pass
 
     def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
         if torch_dtype is None:
             torch_dtype = torch.float16
         elif torch_dtype != torch.float16:
-            logger.info("We suggest you to set `torch_dtype=torch.float16` for better efficiency with compressed_tensors.")
+            logger.info(
+                "We suggest you to set `torch_dtype=torch.float16` for better efficiency with compressed_tensors."
+            )
         return torch_dtype
 
     def _process_model_before_weight_loading(self, model, **kwargs):
         if self.quantization_config.quantization_config is not None:
             from compressed_tensors.quantization import apply_quantization_config
+
             apply_quantization_config(model, self.quantization_config.quantization_config)
 
     def _process_model_after_weight_loading(self, model, resolved_archive_file, **kwargs):
diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py
index 496501557fdea5..3031ef6ff3415f 100755
--- a/src/transformers/utils/quantization_config.py
+++ b/src/transformers/utils/quantization_config.py
@@ -22,6 +22,8 @@
 from enum import Enum
 from typing import Any, Dict, List, Optional, Union
 
+from compressed_tensors.quantization.quant_config import QuantizationStatus
+from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from packaging import version
 
 from ..utils import is_auto_awq_available, is_hqq_available, is_torch_available, logging
@@ -1059,7 +1061,7 @@ def __init__(
         self,
         config_groups: Dict[str, Union["QuantizationScheme", List[str]]] = None,
         quant_method: str = "sparseml",
-        format: str = "fakequant",
+        format: str = "dense",  # "fakequant" not in CompressionFormat
         quantization_status: "QuantizationStatus" = "initialized",
         global_compression_ratio: Optional[float] = None,
         ignore: Optional[List[str]] = None,
@@ -1070,19 +1072,19 @@ def __init__(
         from compressed_tensors.config import SparsityCompressionConfig
 
         self.quantization_config = None
-        self.sparsity_configq = None
+        self.sparsity_config = None
 
         # parse from dict to load nested QuantizationScheme objects
         if config_groups:
             self.quantization_config = QuantizationConfig.parse_obj(
-                dict(
-                    config_groups=config_groups,
-                    quant_method=quant_method,
-                    format=format,
-                    quantization_status=quantization_status,
-                    global_compression_ratio=global_compression_ratio,
-                    ignore=ignore,
-                )
+                {
+                    "config_groups": config_groups,
+                    "quant_method": quant_method,
+                    "format": format,
+                    "quantization_status": quantization_status,
+                    "global_compression_ratio": global_compression_ratio,
+                    "ignore": ignore,
+                }
             )
 
         if sparsity_config:

From b61bfb968db36e1f4f1a0f03f8697f7b116b0591 Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Mon, 10 Jun 2024 20:47:38 +0000
Subject: [PATCH 04/11] revive lines deleted by ruff

---
 src/transformers/quantizers/quantizer_compressed_tensors.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/quantizers/quantizer_compressed_tensors.py b/src/transformers/quantizers/quantizer_compressed_tensors.py
index b24700d38c4e6c..3d0d2e00994282 100644
--- a/src/transformers/quantizers/quantizer_compressed_tensors.py
+++ b/src/transformers/quantizers/quantizer_compressed_tensors.py
@@ -37,12 +37,12 @@ def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
 
         from compressed_tensors.compressors import ModelCompressor
 
-        # self.compressor = ModelCompressor.from_compression_config(quantization_config.to_dict())
         self.compressor = ModelCompressor.from_compression_config(quantization_config)
 
     def validate_environment(self, *args, **kwargs):
         # check torch and compressed_tensors are available, let ImportError raise otherwise
-        pass
+        import torch  # noqa
+        from compressed_tensors.compressors import ModelCompressor  # noqa
 
     def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
         if torch_dtype is None:

From ff8f1c5af0be2eb22095025cd12175345cb2f52e Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Tue, 11 Jun 2024 18:19:41 +0000
Subject: [PATCH 05/11] fixes to load+save from sparseml, edit config to quantization_config, and load back

---
 src/transformers/quantizers/auto.py | 2 +-
 .../quantizer_compressed_tensors.py | 2 +-
 src/transformers/utils/quantization_config.py | 22 +++++++++++++++++--
 3 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py
index f2922ee9677de9..5e26ed91dc4000 100755
--- a/src/transformers/quantizers/auto.py
+++ b/src/transformers/quantizers/auto.py
@@ -59,7 +59,7 @@
     "aqlm": AqlmConfig,
     "quanto": QuantoConfig,
     "hqq": HqqConfig,
-    "compressed_tensors": CompressedTensorsConfig,
+    "compressed-tensors": CompressedTensorsConfig,
 }
diff --git a/src/transformers/quantizers/quantizer_compressed_tensors.py b/src/transformers/quantizers/quantizer_compressed_tensors.py
index 3d0d2e00994282..8493fbfd7fa30b 100644
--- a/src/transformers/quantizers/quantizer_compressed_tensors.py
+++ b/src/transformers/quantizers/quantizer_compressed_tensors.py
@@ -68,4 +68,4 @@ def is_trainable(self):
 
     @property
     def is_serializable(self):
-        return False
+        return True
diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py
index 3031ef6ff3415f..f5aa79e70e4272 100755
--- a/src/transformers/utils/quantization_config.py
+++ b/src/transformers/utils/quantization_config.py
@@ -18,13 +18,14 @@
 import importlib.metadata
 import json
 import os
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass, is_dataclass
 from enum import Enum
 from typing import Any, Dict, List, Optional, Union
 
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from packaging import version
+from pydantic import BaseModel
 
 from ..utils import is_auto_awq_available, is_hqq_available, is_torch_available, logging
 
@@ -70,6 +71,23 @@ class AwqBackendPackingMethod(str, Enum):
     LLMAWQ = "llm-awq"
 
 
+def convert_to_dict(obj):
+    if is_dataclass(obj):
+        return asdict(obj)
+    elif isinstance(obj, BaseModel):
+        return obj.dict()
+    elif isinstance(obj, Enum):
+        return obj.value
+    elif isinstance(obj, dict):
+        return {k: convert_to_dict(v) for k, v in obj.items()}
+    elif isinstance(obj, list):
+        return [convert_to_dict(i) for i in obj]
+    elif isinstance(obj, tuple):
+        return tuple(convert_to_dict(i) for i in obj)
+    else:
+        return obj
+
+
 @dataclass
 class QuantizationConfigMixin:
     """
@@ -133,7 +151,7 @@ def to_dict(self) -> Dict[str, Any]:
         Serializes this instance to a Python dictionary. Returns:
             `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
         """
-        return copy.deepcopy(self.__dict__)
+        return convert_to_dict(copy.deepcopy(self.__dict__))
 
     def __iter__(self):
         """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""

From c1cb55debbd174b9208961b5951c99a9472dfebf Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Tue, 11 Jun 2024 18:21:26 +0000
Subject: [PATCH 06/11] address satrat comment

---
 src/transformers/utils/quantization_config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py
index f5aa79e70e4272..4ad61ddb229e81 100755
--- a/src/transformers/utils/quantization_config.py
+++ b/src/transformers/utils/quantization_config.py
@@ -45,7 +45,7 @@ class QuantizationMethod(str, Enum):
     QUANTO = "quanto"
     EETQ = "eetq"
     HQQ = "hqq"
-    COMPRESSED_TENSORS = "compressed_tensors"
+    COMPRESSED_TENSORS = "compressed-tensors"
 
 
 class AWQLinearVersion(str, Enum):

From ef9d3f174dfc5fe8022b71aa56c4cb13fdf3e66f Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Wed, 12 Jun 2024 14:20:01 +0000
Subject: [PATCH 07/11] compressed_tensors to compressed-tensors and revert back is_serializable

---
 src/transformers/quantizers/auto.py | 2 +-
 src/transformers/quantizers/quantizer_compressed_tensors.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py
index 5e26ed91dc4000..13b8f2bd68ddd4 100755
--- a/src/transformers/quantizers/auto.py
+++ b/src/transformers/quantizers/auto.py
@@ -47,7 +47,7 @@
     "quanto": QuantoHfQuantizer,
     "eetq": EetqHfQuantizer,
     "hqq": HqqHfQuantizer,
-    "compressed_tensors": CompressedTensorsHfQuantizer,
+    "compressed-tensors": CompressedTensorsHfQuantizer,
 }
 
 AUTO_QUANTIZATION_CONFIG_MAPPING = {
diff --git a/src/transformers/quantizers/quantizer_compressed_tensors.py b/src/transformers/quantizers/quantizer_compressed_tensors.py
index 8493fbfd7fa30b..3d0d2e00994282 100644
--- a/src/transformers/quantizers/quantizer_compressed_tensors.py
+++ b/src/transformers/quantizers/quantizer_compressed_tensors.py
@@ -68,4 +68,4 @@ def is_trainable(self):
 
     @property
     def is_serializable(self):
-        return True
+        return False

From 117d0504899aac6210c4ec786c3be12087c1a60a Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Wed, 12 Jun 2024 15:53:00 +0000
Subject: [PATCH 08/11] rename quant_method from sparseml to compressed-tensors

---
 src/transformers/utils/quantization_config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py
index 4ad61ddb229e81..c9824f24e6e847 100755
--- a/src/transformers/utils/quantization_config.py
+++ b/src/transformers/utils/quantization_config.py
@@ -1078,7 +1078,7 @@ class CompressedTensorsConfig(QuantizationConfigMixin):
     def __init__(
         self,
         config_groups: Dict[str, Union["QuantizationScheme", List[str]]] = None,
-        quant_method: str = "sparseml",
+        quant_method: str = "compressed-tensors",
         format: str = "dense",  # "fakequant" not in CompressionFormat
         quantization_status: "QuantizationStatus" = "initialized",
         global_compression_ratio: Optional[float] = None,

From 1901c3e51d7801bae4ef9f129b4278978a74afa6 Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Wed, 12 Jun 2024 22:05:50 +0000
Subject: [PATCH 09/11] tests

---
 .../compressed_tensor/__init__.py | 0
 .../test_compressed_tensors.py | 76 +++++++++++++++++++
 2 files changed, 76 insertions(+)
 create mode 100644 tests/quantization/compressed_tensor/__init__.py
 create mode 100644 tests/quantization/compressed_tensor/test_compressed_tensors.py

diff --git a/tests/quantization/compressed_tensor/__init__.py b/tests/quantization/compressed_tensor/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/tests/quantization/compressed_tensor/test_compressed_tensors.py b/tests/quantization/compressed_tensor/test_compressed_tensors.py
new file mode 100644
index 00000000000000..53c294f25cec5d
--- /dev/null
+++ b/tests/quantization/compressed_tensor/test_compressed_tensors.py
@@ -0,0 +1,76 @@
+# from transformers.quantizers.quantizer_compressed_tensors import CompressedTensorsHfQuantizer
+# from transformers.quantizers.quantizer_compressed_tensors import CompressedTensorsHfQuantizer
+
+import gc
+import unittest
+
+import torch
+
+from transformers import AutoModelForCausalLM, AutoTokenizer, CompressedTensorsConfig
+
+
+class CompressedTensorsTest(unittest.TestCase):
+    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+    source_quantized_model_name = "nm-testing/tinyllama-oneshot-w8a8-test-static-shape-change-v3"
+
+    prompt = "Paris is the capital of which country?"
+    # [' Paris is the capital of which country?\n\nA. London\n\nB. New York\n\nC. Paris\n\nD. Tokyo\n\n4. Which country is the capital of the European Union?\n\nA. France\n']
+    expected_response = ""
+
+    def tear_down(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+        gc.collect()
+
+    @classmethod
+    def setUpClass(self):
+        """
+        Setup quantized model
+        """
+        self.tokenizer = AutoTokenizer.from_pretrained(self.source_quantized_model_name)
+        self.source_quantized_model = AutoModelForCausalLM.from_pretrained(self.source_quantized_model_name)
+
+        self.device = self.source_quantized_model.device
+        compression_config = self.source_quantized_model.config.quantization_config.quantization_config.config_groups
+
+        self.config = CompressedTensorsConfig(
+            config_groups=compression_config,
+            sparsity_config=self.source_quantized_model.config.quantization_config.sparsity_config.dict(),
+        )
+
+        self.assertIsNotNone(self.config.sparsity_config, "sparsity_config should not be None")
+        self.assertIsNotNone(self.config.quantization_config, "quantization_config should not be None")
+
+    @unittest.skip("scales not populated")
+    def test_apply_quantization(self):
+        # fails bc state_dict_scale = state_dict[f"{module_name}.{scale_name}"]
+        # KeyError: 'model.layers.0.self_attn.q_proj.weight_scale
+        self.quantization_model = AutoModelForCausalLM.from_pretrained(
+            self.model_name, quantization_config=self.config
+        )
+        # check that the input layers of self.source_quantized_model and self.quantization_model is the same
+
+    def test_quantized_model(self):
+        # test the quantized model, not the original model
+
+        inputs = self.tokenizer(self.prompt, return_tensors="pt").to(self.device)
+        generated_ids = self.source_quantized_model.generate(**inputs, max_length=50)
+        outputs = self.tokenizer.batch_decode(generated_ids)
+
+        self.expected_response = outputs
+        self.assertEqual(outputs, self.expected_response)
+        self.tear_down()
+
+    def test_forward(self):
+        batch_size = context_size = 1024
+        tensor1 = torch.rand(1024).long()
+        tensor2 = torch.rand(1024).long()
+
+        input_tensor = torch.cat((tensor1, tensor2), dim=0)
+        input_tensor = input_tensor.unsqueeze(0)
+        with torch.no_grad():
+            out = self.source_quantized_model(input_tensor)
+        self.assertEqual(out.shape[0], batch_size)
+        self.assertEqual(out.shape[1], context_size)
+
+        self.tear_down()

From 3ca270dfb50f8365eb6c9f92c8f5ca426339c5ca Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Thu, 13 Jun 2024 18:28:14 +0000
Subject: [PATCH 10/11] edit tests

---
 src/transformers/__init__.py | 1 +
 src/transformers/utils/quantization_config.py | 2 +-
 .../test_compressed_tensors.py | 30 ++++++++-----------
 3 files changed, 14 insertions(+), 19 deletions(-)

diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 40b7905bfdbb04..e4b7a227ab942b 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -916,6 +916,7 @@
         "AqlmConfig",
         "AwqConfig",
         "BitsAndBytesConfig",
+        "CompressedTensorsConfig",
         "EetqConfig",
         "GPTQConfig",
         "HqqConfig",
diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py
index c9824f24e6e847..07e88d6394c3c6 100755
--- a/src/transformers/utils/quantization_config.py
+++ b/src/transformers/utils/quantization_config.py
@@ -1079,7 +1079,7 @@ def __init__(
         self,
         config_groups: Dict[str, Union["QuantizationScheme", List[str]]] = None,
         quant_method: str = "compressed-tensors",
-        format: str = "dense",  # "fakequant" not in CompressionFormat
+        format: str = "dense",
         quantization_status: "QuantizationStatus" = "initialized",
         global_compression_ratio: Optional[float] = None,
         ignore: Optional[List[str]] = None,
diff --git a/tests/quantization/compressed_tensor/test_compressed_tensors.py b/tests/quantization/compressed_tensor/test_compressed_tensors.py
index 53c294f25cec5d..0b368f7fd785bc 100644
--- a/tests/quantization/compressed_tensor/test_compressed_tensors.py
+++ b/tests/quantization/compressed_tensor/test_compressed_tensors.py
@@ -7,6 +7,7 @@
 import torch
 
 from transformers import AutoModelForCausalLM, AutoTokenizer, CompressedTensorsConfig
+from transformers.testing_utils import slow
 
 
 class CompressedTensorsTest(unittest.TestCase):
@@ -14,8 +15,6 @@ class CompressedTensorsTest(unittest.TestCase):
     source_quantized_model_name = "nm-testing/tinyllama-oneshot-w8a8-test-static-shape-change-v3"
 
     prompt = "Paris is the capital of which country?"
-    # [' Paris is the capital of which country?\n\nA. London\n\nB. New York\n\nC. Paris\n\nD. Tokyo\n\n4. Which country is the capital of the European Union?\n\nA. France\n']
-    expected_response = ""
 
     def tear_down(self):
         gc.collect()
@@ -41,35 +40,30 @@ def setUpClass(self):
         self.assertIsNotNone(self.config.sparsity_config, "sparsity_config should not be None")
         self.assertIsNotNone(self.config.quantization_config, "quantization_config should not be None")
 
-    @unittest.skip("scales not populated")
-    def test_apply_quantization(self):
-        # fails bc state_dict_scale = state_dict[f"{module_name}.{scale_name}"]
-        # KeyError: 'model.layers.0.self_attn.q_proj.weight_scale
-        self.quantization_model = AutoModelForCausalLM.from_pretrained(
-            self.model_name, quantization_config=self.config
-        )
-        # check that the input layers of self.source_quantized_model and self.quantization_model is the same
+        # apply quantization config to the base model
+        self.quantized_model = AutoModelForCausalLM.from_pretrained(self.model_name, quantization_config=self.config)
 
     def test_quantized_model(self):
-        # test the quantized model, not the original model
-
+        """Carry out generation"""
         inputs = self.tokenizer(self.prompt, return_tensors="pt").to(self.device)
-        generated_ids = self.source_quantized_model.generate(**inputs, max_length=50)
+        generated_ids = self.quantized_model.generate(**inputs, max_length=50)
        outputs = self.tokenizer.batch_decode(generated_ids)
 
-        self.expected_response = outputs
-        self.assertEqual(outputs, self.expected_response)
+        self.assertIsNotNone(outputs)
         self.tear_down()
 
+    @slow
     def test_forward(self):
         batch_size = context_size = 1024
-        tensor1 = torch.rand(1024).long()
-        tensor2 = torch.rand(1024).long()
+        tensor1 = torch.rand(1024) * 1000
+        tensor1 = tensor1.long()
+        tensor2 = torch.rand(1024) * 1000
+        tensor2 = tensor2.long()
 
         input_tensor = torch.cat((tensor1, tensor2), dim=0)
         input_tensor = input_tensor.unsqueeze(0)
         with torch.no_grad():
-            out = self.source_quantized_model(input_tensor)
+            out = self.quantized_model(input_tensor)
         self.assertEqual(out.shape[0], batch_size)
         self.assertEqual(out.shape[1], context_size)

From 823562f882d67cd76987dd58b0f6ee12021bc976 Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Fri, 14 Jun 2024 14:15:54 +0000
Subject: [PATCH 11/11] add test to check scale and zp is populated

---
 .../compressed_tensor/test_compressed_tensors.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/tests/quantization/compressed_tensor/test_compressed_tensors.py b/tests/quantization/compressed_tensor/test_compressed_tensors.py
index 0b368f7fd785bc..3e74a854a3c041 100644
--- a/tests/quantization/compressed_tensor/test_compressed_tensors.py
+++ b/tests/quantization/compressed_tensor/test_compressed_tensors.py
@@ -52,6 +52,13 @@ def test_quantized_model(self):
         self.assertIsNotNone(outputs)
         self.tear_down()
 
+    def test_quantized_model_initialized_scale_zero_point(self):
+        for _, module in self.quantized_model.named_modules():
+            scheme = getattr(module, "quantization_scheme", None)
+            if scheme is not None:
+                self.assertIsNotNone(module.weight_scale)
+                self.assertIsNotNone(module.weight_zero_point)
+
    @slow
     def test_forward(self):
         batch_size = context_size = 1024
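
For reference, a minimal usage sketch of what this series enables is given below. It is modeled on tests/quantization/compressed_tensor/test_compressed_tensors.py: the checkpoint name and prompt come from those tests, while the assumption that this checkpoint's config.json already carries a compressed-tensors quantization_config (so CompressedTensorsHfQuantizer is resolved automatically through AUTO_QUANTIZER_MAPPING and no explicit CompressedTensorsConfig has to be passed) is illustrative rather than something the patches themselves state.

    # Usage sketch based on the tests above; not part of the patch series itself.
    # Assumes the checkpoint's config.json already contains a compressed-tensors
    # quantization_config, so the quantizer is selected automatically at load time.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "nm-testing/tinyllama-oneshot-w8a8-test-static-shape-change-v3"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    inputs = tokenizer("Paris is the capital of which country?", return_tensors="pt")
    print(tokenizer.batch_decode(model.generate(**inputs, max_length=50)))

Alternatively, a CompressedTensorsConfig built from config_groups and sparsity_config can be passed explicitly via quantization_config, as setUpClass in the tests does.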