1 change: 1 addition & 0 deletions src/llmcompressor/modifiers/quantization/calibration.py
@@ -71,6 +71,7 @@ def initialize_observer(
observer = Observer.load_from_registry(
quantization_args.observer,
quantization_args=quantization_args,
base_name=base_name,
averaging_constant=observer_kwargs.get(
"averaging_constant", DEFAULT_AVERAGING_CONSTANT
),
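For orientation, here is a minimal sketch of what this registry call resolves to for the MSE observer once base_name is forwarded. The QuantizationArgs values below are illustrative assumptions, not taken from this PR.

from compressed_tensors.quantization.quant_args import QuantizationArgs
from llmcompressor.observers.base import Observer

# Illustrative arguments; real values come from the module's quantization scheme.
quantization_args = QuantizationArgs(num_bits=8, type="int", symmetric=True, observer="mse")

observer = Observer.load_from_registry(
    quantization_args.observer,   # "mse"
    quantization_args=quantization_args,
    base_name="input",            # "weight" for weight observers, an activation name otherwise
    averaging_constant=0.01,      # matches the constructor default
)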
97 changes: 81 additions & 16 deletions src/llmcompressor/observers/mse.py
@@ -2,7 +2,7 @@

import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.utils import calculate_qparams
from compressed_tensors.quantization.utils import calculate_qparams, is_fp4
from torch import FloatTensor, IntTensor, Tensor

from llmcompressor.observers.base import Observer
@@ -13,8 +13,13 @@
@Observer.register("mse")
class MovingAverageMSEObserver(Observer):
"""
Implements a dynamic quantization observer that sets the scale and
zero point based on a moving average of the mse-clipped min and max observed values
Implements a dynamic quantization observer that sets the scale and zero
point based on a moving average of observed values.

Behavior:
- Weights: global and local scales use MSE-optimized min/max.
- Activations: global scale uses MSE-optimized min/max; local scales
use plain min–max.
"""

def __init__(
@@ -25,6 +30,7 @@ def __init__(
averaging_constant: float = 0.01,
grid: float = 100.0,
norm: float = 2.4,
base_name: str = "weight",
Reviewer comment (Collaborator): will have to confirm with @shanjiaz if this needs to be added elsewhere, since there are a couple different places Observers are instantiated

**kwargs,
):
super().__init__(quantization_args=quantization_args)
@@ -36,6 +42,7 @@ def __init__(
self.averaging_constant = averaging_constant
self.grid = grid
self.norm = norm
self.is_activation = base_name != "weight"
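To make the effect of the new flag concrete, a small sketch (assuming quantization_args is already defined; the constructor calls are illustrative):

weight_obs = MovingAverageMSEObserver(quantization_args=quantization_args, base_name="weight")
assert weight_obs.is_activation is False  # MSE search for both global and local scales

act_obs = MovingAverageMSEObserver(quantization_args=quantization_args, base_name="input")
assert act_obs.is_activation is True  # local scales fall back to plain min-max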
Reviewer comment (Collaborator): if this is the only place we use base_name, it might just be better to expose is_activation: bool = False to the constructor instead of base_name

def calculate_mse_min_max(
self,
@@ -44,15 +51,18 @@
global_scale: Optional[torch.Tensor] = None,
):
"""
Computes the mse-clipped min and max values of the observed tensor by
optimizing for quantization error
Computes MSE-optimized min and max values for quantization.

- Used for weights (global and local).
- Used for activations only at the global scale (local activations use min–max).

:param observed: observed tensor to calculate quantization parameters for
:param reduce_dims: optional tuple of dimensions to reduce along,
returned values will be shaped (1,) along the reduced dimensions
:param global_scale: optional scale to further scale local quantization scales
:return: tuple of min and max values derived from the observed tensor
"""

from compressed_tensors.quantization.lifecycle import fake_quantize

if not reduce_dims:
@@ -75,18 +85,38 @@
shrinked_min_val = p * absolute_min_val
shrinked_max_val = p * absolute_max_val

from compressed_tensors.quantization.utils import generate_gparam

if is_fp4(self.quantization_args) and global_scale is None:
# If the quantization scheme is fp4 and global_scale is still None,
# i.e. it has not yet been optimized, first compute the global scale
# and then optimize the local scales.
# Local scales are set from the absolute min and max.
iteration_global_scale = generate_gparam(
updated_min_val=shrinked_min_val, updated_max_val=shrinked_max_val
)
iteration_min_val = absolute_min_val
iteration_max_val = absolute_max_val
else:
# Otherwise, we are optimizing the local scales and use the shrinked
# min and max values
iteration_min_val = shrinked_min_val
iteration_max_val = shrinked_max_val
iteration_global_scale = global_scale

candidate_scales, candidate_zero_points = calculate_qparams(
min_vals=shrinked_min_val,
max_vals=shrinked_max_val,
min_vals=iteration_min_val,
max_vals=iteration_max_val,
quantization_args=self.quantization_args,
global_scale=global_scale,
global_scale=iteration_global_scale,
)

q = fake_quantize(
observed,
candidate_scales,
candidate_zero_points,
self.quantization_args,
global_scale=global_scale,
global_scale=iteration_global_scale,
)

q -= observed
@@ -116,10 +146,15 @@ def calculate_updated_min_max(
reduce_dims: Optional[Tuple[int]] = None,
tensor_id: Optional[Any] = None,
global_scale: Optional[torch.Tensor] = None,
is_local: bool = False,
) -> Tuple[FloatTensor, IntTensor]:
"""
Updates the mse-clipped min and max values of the observed tensor using
a moving average smoothed by the averaging_constant
a moving average smoothed by the averaging_constant.

- Weights: global and local scales use MSE-optimized values.
- Activations: global scale uses MSE-optimized values, local scales use
min–max.

:param observed: observed tensor to calculate quantization parameters for
:param reduce_dims: optional tuple of dimensions to reduce along,
@@ -130,11 +165,20 @@
:param global_scale: optional scale to further scale local quantization scales
:param is_local: whether the local scales are being updated; activation
local scales then fall back to plain min–max instead of the MSE search
:return: updated min and max values derived from the observed value
"""
# TODO: will need to be expanded to support fp4 activations;
# currently not supported
min_val, max_val = self.calculate_mse_min_max(
observed, reduce_dims, global_scale=global_scale
)

# For activation local scales, skip the MSE optimization (for dynamic
# activations this happens at runtime) and track plain min–max instead
if self.is_activation and is_local:
# Activation local scales: plain min–max
min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
else:
# Weights, or activations global: MSE loop
min_val, max_val = self.calculate_mse_min_max(
observed, reduce_dims, global_scale=global_scale
)

tensor_id = tensor_id or "default"

running_min_val = self.min_val.get(tensor_id, None)
running_max_val = self.max_val.get(tensor_id, None)
@@ -150,7 +194,6 @@
max_val - running_max_val
)

tensor_id = tensor_id or "default"
self.min_val[tensor_id] = updated_min_val
self.max_val[tensor_id] = updated_max_val
return updated_min_val, updated_max_val
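The running statistics above follow a standard exponential moving average. A self-contained sketch of the update applied once a running value exists, with averaging_constant playing the role of c:

import torch

def ema_update(running: torch.Tensor, observed: torch.Tensor, c: float = 0.01) -> torch.Tensor:
    # updated = running + c * (observed - running), matching the update above
    return running + c * (observed - running)

running_max = torch.tensor([3.0])
new_max = torch.tensor([5.0])
print(ema_update(running_max, new_max))  # tensor([3.0200])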
@@ -180,13 +223,15 @@ def calculate_qparams(
tensor_id=tensor_id,
reduce_dims=reduce_dims,
global_scale=global_scale,
is_local=True,
)
scale, zero_point = calculate_qparams(
min_vals=updated_min_val,
max_vals=updated_max_val,
quantization_args=self.quantization_args,
global_scale=global_scale,
)

return scale, zero_point

def get_qparams_along_dim(
@@ -211,3 +256,23 @@ def reset(self):
super().reset()
self.min_val = {}
self.max_val = {}

def calculate_gparam(self, observed: Tensor) -> torch.Tensor:
"""
Generate a global scale using the observed min and max from MSE optimization.

- Weights: global scale is computed with standard MSE optimization.
- Activations: global scale is computed with dynamic MSE-based scaling.

:param observed: observed tensor to calculate quantization parameters for
:return: updated global scale derived from the observed tensor
"""
from compressed_tensors.quantization.utils import generate_gparam

updated_min_val, updated_max_val = self.calculate_updated_min_max(
observed=observed
)

return generate_gparam(
updated_min_val=updated_min_val, updated_max_val=updated_max_val
)
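Putting the pieces together, a hedged sketch of the intended calling order for an fp4 weight: calculate_gparam runs the MSE search once to produce the global scale, and calculate_qparams then conditions the per-group scales on it. The observer construction and tensor shape are assumptions for illustration only.

import torch

# Assumed setup: observer is a MovingAverageMSEObserver built with
# base_name="weight" and fp4 quantization_args (illustrative, not from this PR).
weight = torch.randn(128, 128)

# 1) Global scale: MSE-optimized min/max fed through generate_gparam.
global_scale = observer.calculate_gparam(weight)

# 2) Per-group scales and zero points, conditioned on the global scale.
scale, zero_point = observer.calculate_qparams(weight, global_scale=global_scale)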