Allow Shortcutting Min-max Observer #887

Open · wants to merge 8 commits into base: main
Changes from 7 commits
5 changes: 2 additions & 3 deletions src/llmcompressor/modifiers/quantization/calibration.py
@@ -51,10 +51,9 @@ def initialize_observer(

     quantization_args = getattr(quantization_scheme, arg_name, None)
     # dont need observers for dynamic
-    if quantization_args and not quantization_args.dynamic:
-        observer = quantization_args.get_observer()
+    if quantization_args is not None and not quantization_args.dynamic:
         observer = Observer.load_from_registry(
-            observer, quantization_args=quantization_args
+            quantization_args.observer, quantization_args=quantization_args
Collaborator:

I think we should be consistent in how we're fetching the observer - either use the get_observer method or remove it and do it how you're doing it here.

Collaborator (Author):

Personally in favor of removing the get_observer now that the observer refactor work is done.

#939

         )
         module.register_module(f"{base_name}_observer", observer)
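For readers following the thread above, here is a minimal sketch of the two observer-fetching styles under discussion. It assumes the API visible in this diff (a QuantizationArgs with an observer field and a get_observer() helper); the constructor arguments are illustrative, not taken from this PR.

    from compressed_tensors.quantization.quant_args import QuantizationArgs
    from llmcompressor.observers.base import Observer

    # illustrative arguments; "minmax" is the observer touched by this PR
    quantization_args = QuantizationArgs(num_bits=8, observer="minmax")

    # style 1: resolve the observer name through the get_observer() helper
    observer = Observer.load_from_registry(
        quantization_args.get_observer(), quantization_args=quantization_args
    )

    # style 2 (what this PR does): read the observer attribute directly
    observer = Observer.load_from_registry(
        quantization_args.observer, quantization_args=quantization_args
    )

Both styles construct the same observer as long as get_observer simply returns the configured observer name.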
25 changes: 16 additions & 9 deletions src/llmcompressor/observers/min_max.py
@@ -17,18 +17,18 @@
 import torch
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.utils import calculate_qparams
-from torch import FloatTensor, IntTensor, Tensor

 from llmcompressor.observers.base import Observer

-__all__ = ["MovingAverageMinMaxObserver"]
+__all__ = ["MinMaxObserver"]


 @Observer.register("minmax")
-class MovingAverageMinMaxObserver(Observer):
+class MinMaxObserver(Observer):
     """
-    Implements a dynamic quantization observer that sets the scale and
-    zero point based on a moving average of the overall min and max observed values
+    Implements a quantization observer that calculates scale and zero point based on the
+    minimum and maximum values of the tensor being observed. If averaging_constant is
+    specified, then the scales are updated using a moving average
     """

     def __init__(
@@ -42,13 +42,13 @@ def __init__(

     def calculate_qparams(
         self,
-        observed: Tensor,
+        observed: torch.Tensor,
         reduce_dims: Optional[Tuple[int]] = None,
         tensor_id: Optional[Any] = None,
-    ) -> Tuple[FloatTensor, IntTensor]:
+    ) -> Tuple[torch.FloatTensor, torch.IntTensor]:
         """
         Updates the observed min and max using a moving average smoothed by the
-        averaging_constant
+        averaging_constant. Set the averaging_constant to 1.0 to disable averaging.

         :param observed: observed tensor to calculate quantization parameters for
         :param reduce_dims: optional tuple of dimensions to reduce along,
@@ -66,6 +66,10 @@ def calculate_qparams(
         min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
         max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)

+        # early stopping, save some computation and memory
+        if self.averaging_constant == 1.0:
+            return calculate_qparams(min_val, max_val, self.quantization_args)
+
         running_min_val = self.min_val.get(tensor_id, None)
         running_max_val = self.max_val.get(tensor_id, None)

@@ -88,8 +92,11 @@
         )

     def get_qparams_along_dim(
-        self, observed, dim: int, tensor_id: Optional[Any] = None
+        self, observed: torch.Tensor, dim: int, tensor_id: Optional[Any] = None
     ):
+        """
+        Calculate quantization parameters along the specified dimension
+        """
         reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
         return self.calculate_qparams(
             observed, reduce_dims=reduce_dims, tensor_id=tensor_id
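For context on the shortcut added above, a standalone sketch (not code from this PR or the library): with averaging constant c, a moving-average min/max observer smooths its running statistics as running = running + c * (current - running), so c = 1.0 makes the running value collapse to the current batch's value and no running state is needed, which is what the new early return exploits.

    import torch

    # illustrative values; names are not taken from the library
    c = 0.01
    running_min = torch.tensor(-1.0)   # running minimum from earlier batches
    current_min = torch.tensor(-2.0)   # minimum observed in the current batch

    # smoothed update: move a fraction c of the way toward the new observation
    smoothed = running_min + c * (current_min - running_min)   # tensor(-1.0100)

    # with c == 1.0 the update collapses to the current value itself, which is
    # why averaging_constant == 1.0 can return calculate_qparams(min_val, max_val, ...)
    # right away instead of tracking running min/max per tensor_id
    assert torch.allclose(running_min + 1.0 * (current_min - running_min), current_min)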