Skip to content

Commit

Permalink
[GPTQ] Iterative Parameter Updating (#863)
Browse files Browse the repository at this point in the history
* Implement iterative parameter updating

Signed-off-by: Kyle Sayers <[email protected]>

* [Bugfix] Use weight parameter of linear layer (#836)

* use weight parameter of linear layer

* add weight attribute check

Signed-off-by: Kyle Sayers <[email protected]>

* [Bugfix] Rename files to remove colons (#846)

* rename files to remove colons

Signed-off-by: Kyle Sayers <[email protected]>

* [Bugfix] Workaround tied tensors bug (#659)

* load offload state dict

* add test

* remove merge duplication

* prepare to fix tie_word_embeddings

* add full tests

* patch second bug

* comment out failing tests, point to next pr

* link to issue

* accomodate offloaded models in test

* add back passing test

* WIP

* add error if not in expected list

* apply style

* update passing failing list

* add shared tensors tests

* clean up

* add comment with link

* make failing tests a todo

* Remove failing tests

* explicitly set safe_serialization

* separate out gpu tests, apply style

---------

Co-authored-by: Kyle Sayers <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>

* only untie word embeddings (#839)

Signed-off-by: Kyle Sayers <[email protected]>

* check for config hidden size (#840)

Signed-off-by: Kyle Sayers <[email protected]>

* Use float32 for Hessian dtype (#847)

* use float32 for hessian dtype

* explicitly set inp dtype as well

* float precision for obcq hessian

Signed-off-by: Kyle Sayers <[email protected]>

* GPTQ: Depreciate non-sequential update option (#762)

* remove from gptq, apply style

* remove instances of sequential_update argument in GPTQ tests

* update examples

* update example tests

* documentation, remove from example

* apply style

* revert back to auto type

* apply style

---------

Co-authored-by: Dipika Sikka <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>

* Typehint nits (#826)

Signed-off-by: Kyle Sayers <[email protected]>

* [ DOC ] Remove version restrictions in W8A8 exmaple (#849)

The latest compressored-tensor 0.8.0 removed some API,
https://github.com/neuralmagic/compressed-tensors/pull/156/files
If installed the older llmcompressor from pip, it would throw the
error like:
```
ImportError: cannot import name 'update_layer_weight_quant_params' from 'compressed_tensors.quantization'
```

Signed-off-by: Kyle Sayers <[email protected]>

* Fix inconsistence (#80)

Use group strategy with 128 group size instead of channel

Co-authored-by: Dipika Sikka <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>

* 2of4

Signed-off-by: Kyle Sayers <[email protected]>

* revert change to unrelated example

Signed-off-by: Kyle Sayers <[email protected]>

* rename test file

Signed-off-by: Kyle Sayers <[email protected]>

* fix fwd func call (#845)

Signed-off-by: Kyle Sayers <[email protected]>

---------

Signed-off-by: Kyle Sayers <[email protected]>
Co-authored-by: Kyle Sayers <[email protected]>
Co-authored-by: Kyle Sayers <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
Co-authored-by: Jincheng Miao <[email protected]>
Co-authored-by: 黄石 <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>

* cover all 3.9-3.12 in commit testing (#864)

Co-authored-by: dhuangnm <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>

* Add marlin-24 recipe/configs for e2e testing (#866)

* add marlin-24 recipe/configs for e2e testing

* update

Signed-off-by: Kyle Sayers <[email protected]>

* [Bugfix] onload during sparsity calculation (#862)

* onload during sparsity calculation

* fix sparsity

---------

Co-authored-by: Dipika <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>

* Fix HFTrainer overloads (#869)

* add missing arguments

Signed-off-by: Kyle Sayers <[email protected]>

* names

Signed-off-by: Kyle Sayers <[email protected]>

* style

Signed-off-by: Kyle Sayers <[email protected]>

* named args all around

Signed-off-by: Kyle Sayers <[email protected]>

---------

Signed-off-by: Kyle Sayers <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>

* Support Model Offloading Tied Tensors Patch (#872)

* update parameter of offloaded modules

Signed-off-by: Kyle Sayers <[email protected]>

* in place function

Signed-off-by: Kyle Sayers <[email protected]>

---------

Signed-off-by: Kyle Sayers <[email protected]>

* add advice about dealing with non-invertable hessians (#875)

Signed-off-by: Kyle Sayers <[email protected]>

* seed commit workflow (#877)

* seed commit workflow

Signed-off-by: andy-neuma <[email protected]>

* tickle

Signed-off-by: andy-neuma <[email protected]>

* let's give it a try

Signed-off-by: andy-neuma <[email protected]>

* whitespace

Signed-off-by: andy-neuma <[email protected]>

* delete unneeded workflow

Signed-off-by: andy-neuma <[email protected]>

* adjust trigger

Signed-off-by: andy-neuma <[email protected]>

---------

Signed-off-by: andy-neuma <[email protected]>
Co-authored-by: andy-neuma <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>

* [Observer Restructure]: Add Observers; Add `calibration` and `frozen` steps to `QuantizationModifier` (#837)

* update functioon

* wip

* clean-up; fix imports

* clean-up

* more clean-up

* bug fix

* update for kvcache

* get kv_cache to work

* docstring

* fix comment

* fix condition for dynamic

* update

* update tests

* add observer tests

* add flake8 skip

* apply updated mse fixes

* fix import

* Update src/llmcompressor/modifiers/quantization/calibration.py

Co-authored-by: Kyle Sayers <[email protected]>

* Update src/llmcompressor/modifiers/quantization/calibration.py

Co-authored-by: Kyle Sayers <[email protected]>

* PR comments

* clean-up

* move hook check to observer call

* update

* separate out calibration step

---------

Co-authored-by: Kyle Sayers <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>

* WIP, observer

Signed-off-by: Kyle Sayers <[email protected]>

* use minmax observer

Signed-off-by: Kyle Sayers <[email protected]>

* Bugfix get observer from name (#883)

Signed-off-by: Rahul Tuli <[email protected]>

* BugFix: Fix Sparsity Reload Testing (#882)

* fix

* fix remaining test cases

* add comments

* fix

Signed-off-by: Kyle Sayers <[email protected]>

* Use custom unique test names for e2e tests (#892)

* Include `testconfig_path` in parsed config data

Signed-off-by: Domenic Barbuzzi <[email protected]>

* Use custom unique names for e2e tests

Signed-off-by: Domenic Barbuzzi <[email protected]>

---------

Signed-off-by: Domenic Barbuzzi <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>

* Revert "Use custom unique test names for e2e tests (#892)" (#893)

This reverts commit 10facf2.

Signed-off-by: Kyle Sayers <[email protected]>

* Move config["testconfig_path"] assignment (#895)

* Use custom unique test names for e2e tests (#892)

* Include `testconfig_path` in parsed config data

Signed-off-by: Domenic Barbuzzi <[email protected]>

* Use custom unique names for e2e tests

Signed-off-by: Domenic Barbuzzi <[email protected]>

---------

Signed-off-by: Domenic Barbuzzi <[email protected]>

* Revert "Use custom unique test names for e2e tests (#892)" (#893)

This reverts commit 10facf2.

Signed-off-by: Domenic Barbuzzi <[email protected]>

* Move config["testconfig_path"] assignment

Signed-off-by: Domenic Barbuzzi <[email protected]>

* Use a function name generator for e2e test names

Signed-off-by: Domenic Barbuzzi <[email protected]>

---------

Signed-off-by: Domenic Barbuzzi <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>

* cap accelerate version to avoid bug (#897)

Signed-off-by: Kyle Sayers <[email protected]>

* Fix observing offloaded weight (#896)

* load weight within onloading

Signed-off-by: Kyle Sayers <[email protected]>

* remove moving activation to execution device, since this is already done since activation calibration always happens within forward pass

Signed-off-by: Kyle Sayers <[email protected]>

---------

Signed-off-by: Kyle Sayers <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>

* Update image in README.md (#861)

Co-authored-by: Dipika Sikka <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>

* use user-specified observer

Signed-off-by: Kyle Sayers <[email protected]>

---------

Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: andy-neuma <[email protected]>
Signed-off-by: Rahul Tuli <[email protected]>
Signed-off-by: Domenic Barbuzzi <[email protected]>
Co-authored-by: Kyle Sayers <[email protected]>
Co-authored-by: Kyle Sayers <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
Co-authored-by: Jincheng Miao <[email protected]>
Co-authored-by: 黄石 <[email protected]>
Co-authored-by: dhuangnm <[email protected]>
Co-authored-by: dhuangnm <[email protected]>
Co-authored-by: Andy Linfoot <[email protected]>
Co-authored-by: andy-neuma <[email protected]>
Co-authored-by: Rahul Tuli <[email protected]>
Co-authored-by: Domenic Barbuzzi <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>
  • Loading branch information
13 people committed Nov 21, 2024
1 parent 5f6f568 commit fa61cf6
Showing 1 changed file with 36 additions and 29 deletions.
65 changes: 36 additions & 29 deletions src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
import time
from typing import Tuple

from compressed_tensors.quantization import (
ActivationOrdering,
QuantizationArgs,
QuantizationStrategy,
)
from compressed_tensors.quantization import ActivationOrdering, QuantizationStrategy
from compressed_tensors.quantization.lifecycle.forward import fake_quantize

from llmcompressor.modifiers.utils import SPARSITY_THRESHOLD
Expand Down Expand Up @@ -100,20 +96,27 @@ def compress(
diagonal norm
"""
args_loc = "quantization_scheme.weights"
weight_quant_args = getattr_chain(self.layer, args_loc, None)
if weight_quant_args is None:
quant_args = getattr_chain(self.layer, args_loc, None)
if quant_args is None:
logger.debug(f"Skipping unquantized layer {self.name}...")
return

if is_module_offloaded(self.layer):
self.layer._hf_hook.pre_forward(self.layer)

strategy = weight_quant_args.strategy
actorder = weight_quant_args.actorder
strategy = quant_args.strategy
actorder = quant_args.actorder
final_shape = self.layer.weight.shape
final_dtype = self.layer.weight.dtype
W = self.layer.weight.data.clone()

# create observer for calculating quantization parameters
observer = Observer.load_from_registry(
quant_args.observer,
quantization_args=quant_args,
averaging_constant=1.0, # ignore moving average
)

# standardize shape and dtype
if isinstance(self.layer, nn.Conv2d):
W = W.flatten(1)
Expand All @@ -127,26 +130,28 @@ def compress(
# mapping from column index to group index
g_idx = (
torch.arange(self.columns, device=W.device, dtype=torch.int)
// weight_quant_args.group_size
// quant_args.group_size
)

if actorder == ActivationOrdering.GROUP:
# permute by activation order first, then update groups
W, self.H, perm = self._apply_activation_ordering(W, self.H)
self._update_quantization_parameters(weight_quant_args, W)
scale, zero_point = observer(W, g_idx=None)

# use identity g_idx (invert permutation later)

elif actorder == ActivationOrdering.WEIGHT:
# update groups first, then permute by activation order
self._update_quantization_parameters(weight_quant_args, W)
scale, zero_point = observer(W, g_idx=None)
W, self.H, perm = self._apply_activation_ordering(W, self.H)

# permute g_idx to maintain identity mapping after unpermutation
g_idx = g_idx[perm]

scale = self.layer.weight_scale
zero_point = self.layer.weight_zero_point
else:
scale, zero_point = observer(W, g_idx=None)
else:
scale, zero_point = observer(W, g_idx=None)

# sparsity mask
sparsity = tensor_sparsity(W)
Expand Down Expand Up @@ -212,16 +217,28 @@ def compress(
q,
scale[:, 0],
zero_point[:, 0],
weight_quant_args,
quant_args,
)
elif strategy == QuantizationStrategy.GROUP:
# get the group index for the current column
column_idx = i1 + i
group_index = g_idx[column_idx]

# update quantization parameters to reflect changes
# resulting from previous blocks
if (
actorder != ActivationOrdering.WEIGHT
and column_idx % quant_args.group_size == 0
):
_scale, _zero_point = observer.get_qparams_along_dim(
W[:, g_idx == group_index], dim=0
)
scale[:, group_index] = _scale[:, 0]
zero_point[:, group_index] = _zero_point[:, 0]

# Since we're only applying quantization to a slice, this
# ends up being a channelwise application
altered_qargs = copy(weight_quant_args)
altered_qargs = copy(quant_args)
altered_qargs.strategy = QuantizationStrategy.CHANNEL
q = fake_quantize(
q,
Expand Down Expand Up @@ -279,6 +296,9 @@ def compress(
W.transpose_(0, 1)
W = W.reshape(final_shape).to(final_dtype)

update_parameter_data(self.layer, scale, "weight_scale")
update_parameter_data(self.layer, zero_point, "weight_zero_point")

# This is a bit hacky, but FSDP updates only work if we change
# the weight in place, clone() or direct assignment won't work
self.layer.weight -= self.layer.weight
Expand All @@ -296,19 +316,6 @@ def free(self):
delattr(self, "H")
super().free()

def _update_quantization_parameters(self, args: QuantizationArgs, W: torch.Tensor):
"""
Update layer quantization parameters with potentially permuted weight
:param args: quantization arguments
:param W: weight to calculate quantization parameters from
"""
observer = args.get_observer()
observer = Observer.load_from_registry(observer, quantization_args=args)
_scale, _zero_point = observer(W, g_idx=None)
update_parameter_data(self.layer, _scale, "weight_scale")
update_parameter_data(self.layer, _zero_point, "weight_zero_point")

def _apply_activation_ordering(
self, W: torch.Tensor, H: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
Expand Down

0 comments on commit fa61cf6

Please sign in to comment.