MSE observer for NVFP4 #1840
Open: shubhra wants to merge 21 commits into main from shubhra/mse_nvfp4
+82 −16
Changes shown from 20 of 21 commits.

Commits:
50930c3  MSE support for NVFP4
eba6f56  Add mse support for input activations global scale via MSE
476564e  Remove prints
5d2eaee  Consolidate mse for both weights and activations (global scale) under…
5e51a09  Remove imports that aren't needed
133ac2e  Clean up init
d702c7a  Support for differentiating between activations and weights
4ff715f  Remove prints
20de7e4  Corrected mse implementation for fp4
c17d292  Update check for activation and local scales
ef54001  Update check for activation and local scales
d7b03c8  Change the way we check for if we are doing local scale
c6fef13  Clarify comment in MovingAverageMSEObserver regarding global scale ha…
9e933ce  Change the local scale identification method (anmarques)
4aa7452  Remove unnecessary import
d1fcb3e  Fix a long line error
8ebdffb  Fix format errors
814e9e4  Fix format errors
20a0ecc  Fix format errors
d5729ce  ruff format file
408bdbb  Update src/llmcompressor/observers/mse.py
File: src/llmcompressor/observers/mse.py

@@ -2,7 +2,7 @@
 import torch
 from compressed_tensors.quantization.quant_args import QuantizationArgs
-from compressed_tensors.quantization.utils import calculate_qparams
+from compressed_tensors.quantization.utils import calculate_qparams, is_fp4
 from torch import FloatTensor, IntTensor, Tensor

 from llmcompressor.observers.base import Observer
@@ -13,8 +13,13 @@
 @Observer.register("mse")
 class MovingAverageMSEObserver(Observer):
     """
-    Implements a dynamic quantization observer that sets the scale and
-    zero point based on a moving average of the mse-clipped min and max observed values
+    Implements a dynamic quantization observer that sets the scale and zero
+    point based on a moving average of observed values.
+
+    Behavior:
+    - Weights: global and local scales use MSE-optimized min/max.
+    - Activations: global scale uses MSE-optimized min/max; local scales
+      use plain min–max.
     """

     def __init__(
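The behavior split described in the new docstring can be pinned down as a tiny dispatch rule. A hedged sketch, not part of the PR (the function name and string labels are illustrative only):

```python
def scale_strategy(is_activation: bool, is_local: bool) -> str:
    """Which range-estimation strategy the observer applies, per its docstring."""
    if is_activation and is_local:
        return "min-max"  # activation local scales: plain observed min/max
    return "mse"          # weights (global and local) and activation global scale

# Only the activation/local combination skips the MSE search.
print(scale_strategy(True, True))   # min-max
print(scale_strategy(True, False))  # mse
```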
@@ -25,6 +30,7 @@ def __init__(
         averaging_constant: float = 0.01,
         grid: float = 100.0,
         norm: float = 2.4,
+        base_name: str = "weight",
         **kwargs,
     ):
         super().__init__(quantization_args=quantization_args)
@@ -36,6 +42,7 @@ def __init__(
         self.averaging_constant = averaging_constant
         self.grid = grid
         self.norm = norm
+        self.is_activation = base_name != "weight"

     def calculate_mse_min_max(
         self,

Review comment on the `self.is_activation` line: "if this is the only place we use" (truncated)
@@ -44,15 +51,18 @@ def calculate_mse_min_max(
         global_scale: Optional[torch.Tensor] = None,
     ):
         """
-        Computes the mse-clipped min and max values of the observed tensor by
-        optimizing for quantization error
+        Computes MSE-optimized min and max values for quantization.
+
+        - Used for weights (global and local).
+        - Used for activations only at the global scale (local activations use min–max).

         :param observed: observed tensor to calculate quantization parameters for
         :param reduce_dims: optional tuple of dimensions to reduce along,
             returned values will be shaped (1,) along the reduced dimensions
         :param global_scale: optional scale to further scale local quantization scales
         :return: tuple of min and max values derived from the observed tensor
         """
         from compressed_tensors.quantization.lifecycle import fake_quantize

         if not reduce_dims:
@@ -75,18 +85,38 @@
             shrinked_min_val = p * absolute_min_val
             shrinked_max_val = p * absolute_max_val

+            from compressed_tensors.quantization.utils import generate_gparam
+
+            if (is_fp4(self.quantization_args)) and global_scale is None:
+                # If the quantization scheme is fp4 and global_scale is still None,
+                # i.e. it has not yet been optimized, then we should first get
+                # the global scale and then optimize the local scales.
+                # Local scales are set by the absolute min and max.
+                iteration_global_scale = generate_gparam(
+                    updated_min_val=shrinked_min_val, updated_max_val=shrinked_max_val
+                )
+                iteration_min_val = absolute_min_val
+                iteration_max_val = absolute_max_val
+            else:
+                # Otherwise, we are optimizing local scales and use the shrinked
+                # min and max
+                iteration_min_val = shrinked_min_val
+                iteration_max_val = shrinked_max_val
+                iteration_global_scale = global_scale
+
             candidate_scales, candidate_zero_points = calculate_qparams(
-                min_vals=shrinked_min_val,
-                max_vals=shrinked_max_val,
+                min_vals=iteration_min_val,
+                max_vals=iteration_max_val,
                 quantization_args=self.quantization_args,
-                global_scale=global_scale,
+                global_scale=iteration_global_scale,
             )

             q = fake_quantize(
                 observed,
                 candidate_scales,
                 candidate_zero_points,
                 self.quantization_args,
-                global_scale=global_scale,
+                global_scale=iteration_global_scale,
             )

             q -= observed
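The loop above follows the standard MSE-clipping pattern: shrink the absolute range by a factor p, fake-quantize, and keep the candidate with the lowest p-norm error. A torch-free sketch under stated assumptions (a toy symmetric 8-bit quantizer stands in for fake_quantize, and max_shrink caps the search as the real observer's grid does; all names here are illustrative, not the library API):

```python
def fake_quantize_sym(xs, max_val, n_bits=8):
    """Toy symmetric fake-quantization onto n_bits levels in [-max_val, max_val]."""
    levels = 2 ** (n_bits - 1) - 1
    scale = max_val / levels if max_val > 0 else 1.0
    return [max(-max_val, min(max_val, round(x / scale) * scale)) for x in xs]

def mse_min_max(xs, grid=100.0, norm=2.4, max_shrink=0.20):
    """Search shrink factors p for the clipping range minimizing quantization error."""
    abs_max = max(abs(x) for x in xs)
    best_p, best_err = 1.0, float("inf")
    for i in range(int(max_shrink * grid) + 1):
        p = 1.0 - i / grid                       # candidate shrink: 1.00, 0.99, ...
        q = fake_quantize_sym(xs, p * abs_max)   # quantize with the clipped range
        err = sum(abs(a - b) ** norm for a, b in zip(q, xs))
        if err < best_err:
            best_p, best_err = p, err
    return -best_p * abs_max, best_p * abs_max
```

The returned range is symmetric here for simplicity; the PR's version tracks min and max independently and, for fp4, threads the candidate global scale through each iteration as shown in the diff.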
@@ -116,10 +146,15 @@ def calculate_updated_min_max(
         reduce_dims: Optional[Tuple[int]] = None,
         tensor_id: Optional[Any] = None,
         global_scale: Optional[torch.Tensor] = None,
+        is_local: Optional[bool] = False,
     ) -> Tuple[FloatTensor, IntTensor]:
         """
         Updates the mse-clipped min and max values of the observed tensor using
-        a moving average smoothed by the averaging_constant
+        a moving average smoothed by the averaging_constant.
+
+        - Weights: global and local scales use MSE-optimized values.
+        - Activations: global scale uses MSE-optimized values, local scales use
+          min–max.

         :param observed: observed tensor to calculate quantization parameters for
         :param reduce_dims: optional tuple of dimensions to reduce along,
@@ -130,11 +165,20 @@
         :param global_scale: optional scale to further scale local quantization scales
         :return: updated min and max values derived from the observed value
         """
-        # TODO: will need to be expanded to support fp4 activations;
-        # currently not supported
-        min_val, max_val = self.calculate_mse_min_max(
-            observed, reduce_dims, global_scale=global_scale
-        )
+        # Skip local scale updates for dynamic activations (this will happen at
+        # runtime)
+        if self.is_activation and is_local:
+            # Activation local scales: min–max
+            min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
+            max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
+        else:
+            # Weights, or activation global scale: MSE loop
+            min_val, max_val = self.calculate_mse_min_max(
+                observed, reduce_dims, global_scale=global_scale
+            )

         tensor_id = tensor_id or "default"

         running_min_val = self.min_val.get(tensor_id, None)
         running_max_val = self.max_val.get(tensor_id, None)
@@ -150,7 +194,6 @@
                 max_val - running_max_val
             )

-        tensor_id = tensor_id or "default"
         self.min_val[tensor_id] = updated_min_val
         self.max_val[tensor_id] = updated_max_val
         return updated_min_val, updated_max_val
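The moving-average smoothing referenced throughout is a plain exponential moving average of the per-call min/max, weighted by averaging_constant (0.01 by default). A minimal torch-free sketch, assuming the running value is None before the first observation (the helper name is illustrative):

```python
def ema_update(running, new, averaging_constant=0.01):
    """Exponential moving average used for the running min/max values."""
    if running is None:            # first observation seeds the running value
        return new
    # Nudge the running value toward the new observation.
    return running + averaging_constant * (new - running)

# Each call moves the estimate 1% of the way toward the latest value.
print(ema_update(10.0, 0.0))  # 9.9
```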
@@ -180,13 +223,15 @@ def calculate_qparams(
             tensor_id=tensor_id,
             reduce_dims=reduce_dims,
             global_scale=global_scale,
+            is_local=True,
         )
         scale, zero_point = calculate_qparams(
             min_vals=updated_min_val,
             max_vals=updated_max_val,
             quantization_args=self.quantization_args,
             global_scale=global_scale,
         )
+
         return scale, zero_point

     def get_qparams_along_dim(
@@ -211,3 +256,23 @@ def reset(self):
         super().reset()
         self.min_val = {}
         self.max_val = {}
+
+    def calculate_gparam(self, observed: Tensor) -> torch.Tensor:
+        """
+        Generate a global scale using the observed min and max from MSE optimization.
+
+        - Weights: global scale is computed with standard MSE optimization.
+        - Activations: global scale is computed with dynamic MSE-based scaling.
+
+        :param observed: observed tensor to calculate quantization parameters for
+        :return: updated global scale derived from the observed tensor
+        """
+        from compressed_tensors.quantization.utils import generate_gparam
+
+        updated_min_val, updated_max_val = self.calculate_updated_min_max(
+            observed=observed
+        )
+
+        return generate_gparam(
+            updated_min_val=updated_min_val, updated_max_val=updated_max_val
+        )
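For intuition on what generate_gparam produces: a sketch of a global scale under the NVFP4 convention, where the global scale maps the tensor's absolute maximum onto the product of the FP8-E4M3 max (448, the dtype of per-group scales) and the FP4-E2M1 max (6, the element dtype). This convention is an assumption for illustration; the real generate_gparam lives in compressed_tensors and may differ in detail:

```python
FP8_E4M3_MAX = 448.0  # max representable value of the FP8 scale dtype (assumed)
FP4_E2M1_MAX = 6.0    # max representable value of NVFP4 elements (assumed)

def generate_gparam_sketch(updated_min_val: float, updated_max_val: float) -> float:
    """Global scale so per-group FP8 scales stay in range (illustrative NVFP4 form)."""
    amax = max(abs(updated_min_val), abs(updated_max_val))
    return FP8_E4M3_MAX * FP4_E2M1_MAX / amax
```

Because the MSE search shrinks updated_min_val/updated_max_val below the raw extremes, the resulting global scale is larger than a plain min-max global scale, trading outlier clipping for finer resolution on the bulk of values.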
Review comment: will have to confirm with @shanjiaz if this needs to be added elsewhere, since there are a couple different places Observers are instantiated