2 changes: 1 addition & 1 deletion src/llmcompressor/modifiers/awq/base.py

@@ -268,7 +268,7 @@ def on_end(self, state: State, event: Event, **kwargs):
         self.ended_ = True
 
         for _, module in tqdm(
-            match_named_modules(state.model, self.targets, self.ignore),
+            match_named_modules(state.model, self.resolved_targets, self.ignore),
             desc="Calibrating weights",
         ):
             update_weight_zp_scale(module)

8 changes: 6 additions & 2 deletions src/llmcompressor/modifiers/quantization/gptq/base.py

@@ -162,7 +162,9 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         # prepare module names
         self._module_names = {
             m: name
-            for name, m in match_named_modules(state.model, self.targets, self.ignore)
+            for name, m in match_named_modules(
+                state.model, self.resolved_targets, self.ignore
+            )
         }
 
         return True

@@ -176,7 +178,9 @@ def on_start(self, state: State, event: Event, **kwargs):
 
         # register gptq hooks
         added_hook = False
-        for _, module in match_named_modules(state.model, self.targets, self.ignore):
+        for _, module in match_named_modules(
+            state.model, self.resolved_targets, self.ignore
+        ):
             if getattr_chain(module, "quantization_scheme.weights", None) is not None:
                 # HACK: previously, embeddings were not quantized because they were not
                 # accessible by the layer compressor. For now, we manually ignore it,

@@ -71,7 +71,7 @@ def on_start(self, state: State, event: Event, **kwargs):
         QuantizationMixin.start_calibration(self, state.model)
 
         named_modules = list(
-            match_named_modules(state.model, self.targets, self.ignore)
+            match_named_modules(state.model, self.resolved_targets, self.ignore)
         )
         # TODO: this step can be combined with update_weight_zp_scale
         # once update_fused_layer_weight_global_scales is removed

57 changes: 30 additions & 27 deletions src/llmcompressor/modifiers/quantization/quantization/mixin.py

@@ -15,7 +15,7 @@
     preset_name_to_scheme,
 )
 from compressed_tensors.utils import match_named_modules
-from pydantic import Field, PrivateAttr, field_validator, model_validator
+from pydantic import Field, PrivateAttr, field_validator
 from torch.utils.hooks import RemovableHandle
 
 from llmcompressor.modifiers.quantization.calibration import (

@@ -62,6 +62,8 @@ class QuantizationMixin(HooksMixin):
     :param targets: list of layer names to quantize if a scheme is provided. If unset,
         will contain all targets listed in config_groups. If config_groups is also
         unset, will default to ["Linear"] (i.e. all Linear layers will be targeted).
+        This field is not the source of truth for all targets, it must be resolved
+        with config_groups. Use resolved_targets instead.
     :param ignore: optional list of module class names or submodule names to not
         quantize even if they match a target in config_groups. Defaults to empty list.
     :param scheme: a single quantization scheme to apply to the model. This is a

@@ -83,12 +85,13 @@
     """
 
     config_groups: Optional[Dict[str, QuantizationScheme]] = None
-    targets: Union[str, List[str]] = Field(default_factory=list)
+    targets: Union[str, List[str]] = Field(default_factory=lambda: ["Linear"])
Review thread on this line:

Collaborator Author: This returns to the default behavior; the setting to ["Linear"] was part of the validation layer that is now removed.

Collaborator: Is this going to conflict with plans to use AWQ without quantization targets? You could also potentially set the AWQ subclass default to be None.

Collaborator: Consider adding a comment encouraging use of resolved_targets.

Collaborator Author: Regarding AWQ, I think we can get by with this default, but allowing the type to be Optional so that a user can set it to None. I wanted to do that with an example in a follow-up PR.
Regarding a comment, I did so in the docstring. If changing the field name to unresolved_targets is preferable, I can do that as well. WDYT? It can still be serialized as targets by setting alias="targets" in the pydantic Field inputs.
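The alias idea mentioned in the last comment can be sketched with plain pydantic (v2 assumed). This is an illustration only, not part of the diff; `ExampleMixin` and the `unresolved_targets` rename are hypothetical:

```python
# Hypothetical sketch of the alias="targets" idea from the thread above (pydantic v2).
from typing import List

from pydantic import BaseModel, ConfigDict, Field


class ExampleMixin(BaseModel):
    # allow constructing with either the field name or its alias
    model_config = ConfigDict(populate_by_name=True)

    # hypothetical rename; alias="targets" keeps the serialized recipe key unchanged
    unresolved_targets: List[str] = Field(
        default_factory=lambda: ["Linear"], alias="targets"
    )


m = ExampleMixin(targets=["Linear", "Embedding"])  # old key still parses
print(m.unresolved_targets)         # ['Linear', 'Embedding']
print(m.model_dump(by_alias=True))  # {'targets': ['Linear', 'Embedding']}
```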

     ignore: List[str] = Field(default_factory=list)
     scheme: Optional[Union[str, Dict[str, Any]]] = None
     kv_cache_scheme: Optional[QuantizationArgs] = None
 
     _calibration_hooks: Set[RemovableHandle] = PrivateAttr(default_factory=set)
+    _resolved_config: Optional[QuantizationConfig] = PrivateAttr(None)
 
     @field_validator("targets", mode="before")
     def validate_targets(cls, value: Union[str, List[str]]) -> List[str]:

@@ -116,27 +119,29 @@ def validate_scheme(
 
         return value
 
-    @model_validator(mode="after")
-    def validate_model_after(model: "QuantizationMixin") -> "QuantizationMixin":
+    @property
+    def resolved_config(self) -> QuantizationConfig:
         """
-        - If targets have not been set, aggregate targets from config_groups
-            into a single unique list
-        - If targets have still not been found, default to targets=["Linear"]
+        Quantization config needs to be resolved just once based on
+        scheme and config_groups inputs.
         """
+        if self._resolved_config is None:
+            self._resolved_config = self.resolve_quantization_config()
+        return self._resolved_config
 
-        if len(model.targets) > 0 and model.config_groups is not None:
-            raise ValueError("Please specify either `targets` or `config_groups`")
-
-        if len(model.targets) == 0 and model.config_groups is not None:
-            for config_group in model.config_groups.values():
-                for target in config_group.targets:
-                    if target not in model.targets:
-                        model.targets.append(target)
-
-        if len(model.targets) == 0:
-            model.targets.append("Linear")
-
-        return model
+    @property
+    def resolved_targets(self) -> Set[str]:
+        """
+        Set of all resolved targets, i.e. all unique targets listed
+        in resolved quantization config.
+        Use this property instead of the targets field, as targets can
+        also come from config_groups depending on how recipe is configured.
+        """
+        targets = set()
+        for config_group in self.resolved_config.config_groups.values():
+            for target in config_group.targets:
+                targets.add(target)
+        return targets
 
     def initialize_quantization(self, model: torch.nn.Module):
         """

@@ -145,13 +150,11 @@ def initialize_quantization(self, model: torch.nn.Module):
 
         :param model: model to attach schemes and observers to
         """
-        # apply scheme and status to model
-        config = self.resolve_quantization_config()
 
-        for _, module in match_named_modules(model, self.targets, self.ignore):
+        for _, module in match_named_modules(model, self.resolved_targets, self.ignore):
             reset_quantization_status(module)  # reset any previously applied qconfigs
 
-        apply_quantization_config(model, config)
+        apply_quantization_config(model, self.resolved_config)
 
         # disable quantization until calibration
         model.apply(disable_quantization)

@@ -164,7 +167,7 @@ def start_calibration(self, model: torch.nn.Module):
         :param model: model to prepare for calibration
         """
         self._calibration_hooks = self._initialize_hooks(model)
-        for _, module in match_named_modules(model, self.targets, self.ignore):
+        for _, module in match_named_modules(model, self.resolved_targets, self.ignore):
             self._initialize_observers(module)
             apply_calibration_status(module)
 

@@ -178,7 +181,7 @@ def end_calibration(self, model: torch.nn.Module):
         :param model: model to end calibration for
         """
         self.remove_hooks(self._calibration_hooks)
-        for _, module in match_named_modules(model, self.targets, self.ignore):
+        for _, module in match_named_modules(model, self.resolved_targets, self.ignore):
             freeze_module_quantization(module)  # remove observers
 
         model.apply(enable_quantization)  # keep quantization enabled

@@ -270,7 +273,7 @@ def _initialize_observers(self, module: torch.nn.Module):
 
     def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
         hooks = set()
-        for _, module in match_named_modules(model, self.targets, self.ignore):
+        for _, module in match_named_modules(model, self.resolved_targets, self.ignore):
             if not hasattr(module, "quantization_scheme"):
                 continue
 
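For readers tracing the new behavior: a standalone sketch (not the library code; `SimpleScheme` and `SimpleConfig` are hypothetical stand-ins for compressed-tensors' `QuantizationScheme` and `QuantizationConfig`) of the target aggregation that the new `resolved_targets` property performs over the resolved config, whose result is what `match_named_modules` receives in the modifiers above:

```python
# Standalone sketch of the aggregation performed by the new resolved_targets
# property; SimpleScheme/SimpleConfig are hypothetical stand-ins for the
# compressed-tensors QuantizationScheme/QuantizationConfig types.
from dataclasses import dataclass, field
from typing import Dict, List, Set


@dataclass
class SimpleScheme:
    targets: List[str]


@dataclass
class SimpleConfig:
    config_groups: Dict[str, SimpleScheme] = field(default_factory=dict)


def resolved_targets(config: SimpleConfig) -> Set[str]:
    # union of all targets listed across the resolved config's groups
    targets: Set[str] = set()
    for group in config.config_groups.values():
        targets.update(group.targets)
    return targets


config = SimpleConfig(
    config_groups={
        "group_0": SimpleScheme(targets=["Linear"]),
        "group_1": SimpleScheme(targets=["Embedding", "Linear"]),
    }
)
print(resolved_targets(config))  # {'Linear', 'Embedding'} (set order may vary)
```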