From d56801f7572ca12d471324b77724f9a3415ad845 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 9 Oct 2025 14:45:12 -0500 Subject: [PATCH 1/7] quant mixin resolved config Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 2 +- .../modifiers/quantization/gptq/base.py | 4 +- .../quantization/quantization/base.py | 2 +- .../quantization/quantization/mixin.py | 56 +++++++++++-------- 4 files changed, 37 insertions(+), 27 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 72a7240e3..6bc97b446 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -268,7 +268,7 @@ def on_end(self, state: State, event: Event, **kwargs): self.ended_ = True for _, module in tqdm( - match_named_modules(state.model, self.targets, self.ignore), + match_named_modules(state.model, self.resolved_targets, self.ignore), desc="Calibrating weights", ): update_weight_zp_scale(module) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 09f3e681c..430ee9ae7 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -162,7 +162,9 @@ def on_initialize(self, state: State, **kwargs) -> bool: # prepare module names self._module_names = { m: name - for name, m in match_named_modules(state.model, self.targets, self.ignore) + for name, m in match_named_modules( + state.model, self.resolved_targets, self.ignore + ) } return True diff --git a/src/llmcompressor/modifiers/quantization/quantization/base.py b/src/llmcompressor/modifiers/quantization/quantization/base.py index 424cabcf6..1330d16ad 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/base.py +++ b/src/llmcompressor/modifiers/quantization/quantization/base.py @@ -71,7 +71,7 @@ def on_start(self, state: State, event: Event, **kwargs): QuantizationMixin.start_calibration(self, state.model) named_modules = list( - match_named_modules(state.model, self.targets, self.ignore) + match_named_modules(state.model, self.resolved_targets, self.ignore) ) # TODO: this step can be combined with update_weight_zp_scale # once update_fused_layer_weight_global_scales is removed diff --git a/src/llmcompressor/modifiers/quantization/quantization/mixin.py b/src/llmcompressor/modifiers/quantization/quantization/mixin.py index f37efb56a..7a27ea40e 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/mixin.py +++ b/src/llmcompressor/modifiers/quantization/quantization/mixin.py @@ -62,6 +62,8 @@ class QuantizationMixin(HooksMixin): :param targets: list of layer names to quantize if a scheme is provided. If unset, will contain all targets listed in config_groups. If config_groups is also unset, will default to ["Linear"] (i.e. all Linear layers will be targeted). + This field is not the source of truth for all targets, it must be resolved + with config_groups. Use resolved_targets instead. :param ignore: optional list of module class names or submodule names to not quantize even if they match a target in config_groups. Defaults to empty list. :param scheme: a single quantization scheme to apply to the model. 
This is a @@ -83,12 +85,14 @@ class QuantizationMixin(HooksMixin): """ config_groups: Optional[Dict[str, QuantizationScheme]] = None - targets: Union[str, List[str]] = Field(default_factory=list) + targets: Union[str, List[str]] = Field(default_factory=lambda: ["Linear"]) ignore: List[str] = Field(default_factory=list) scheme: Optional[Union[str, Dict[str, Any]]] = None kv_cache_scheme: Optional[QuantizationArgs] = None _calibration_hooks: Set[RemovableHandle] = PrivateAttr(default_factory=set) + _resolved_config: Optional[QuantizationConfig] = PrivateAttr(None) + _resolved_targets: Optional[List[str]] = PrivateAttr(None) @field_validator("targets", mode="before") def validate_targets(cls, value: Union[str, List[str]]) -> List[str]: @@ -116,27 +120,33 @@ def validate_scheme( return value - @model_validator(mode="after") - def validate_model_after(model: "QuantizationMixin") -> "QuantizationMixin": + @property + def resolved_config(self): """ - - If targets have not been set, aggregate targets from config_groups - into a single unique list - - If targets have still not been found, default to targets=["Linear"] + The quantization config needs to be resolved just once, based on + the scheme and config_groups inputs. """ + if self._resolved_config is None: + self._resolved_config = self.resolve_quantization_config() + return self._resolved_config - if len(model.targets) > 0 and model.config_groups is not None: - raise ValueError("Please specify either `targets` or `config_groups`") - - if len(model.targets) == 0 and model.config_groups is not None: - for config_group in model.config_groups.values(): + @property + def resolved_targets(self): + """ + List of all resolved targets, i.e. all unique targets listed + in the resolved quantization config. + Use this property instead of the targets field, as targets can + also come from config_groups depending on how the recipe is configured. 
+ """ + if self._resolved_targets is None: + targets = [] + for config_group in self.resolved_config.config_groups.items(): for target in config_group.targets: - if target not in model.targets: - model.targets.append(target) - - if len(model.targets) == 0: - model.targets.append("Linear") + if target not in targets: + targets.append(target) + self._resolved_targets = targets - return model + return self._resolved_targets def initialize_quantization(self, model: torch.nn.Module): """ @@ -145,13 +155,11 @@ def initialize_quantization(self, model: torch.nn.Module): :param model: model to attach schemes and observers to """ - # apply scheme and status to model - config = self.resolve_quantization_config() - for _, module in match_named_modules(model, self.targets, self.ignore): + for _, module in match_named_modules(model, self.resolved_targets, self.ignore): reset_quantization_status(module) # reset any previously applied qconfigs - apply_quantization_config(model, config) + apply_quantization_config(model, self.resolved_config) # disable quantization until calibration model.apply(disable_quantization) @@ -164,7 +172,7 @@ def start_calibration(self, model: torch.nn.Module): :param model: model to prepare for calibration """ self._calibration_hooks = self._initialize_hooks(model) - for _, module in match_named_modules(model, self.targets, self.ignore): + for _, module in match_named_modules(model, self.resolved_targets, self.ignore): self._initialize_observers(module) apply_calibration_status(module) @@ -178,7 +186,7 @@ def end_calibration(self, model: torch.nn.Module): :param model: model to end calibration for """ self.remove_hooks(self._calibration_hooks) - for _, module in match_named_modules(model, self.targets, self.ignore): + for _, module in match_named_modules(model, self.resolved_targets, self.ignore): freeze_module_quantization(module) # remove observers model.apply(enable_quantization) # keep quantization enabled @@ -270,7 +278,7 @@ def _initialize_observers(self, module: torch.nn.Module): def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]: hooks = set() - for _, module in match_named_modules(model, self.targets, self.ignore): + for _, module in match_named_modules(model, self.resolved_targets, self.ignore): if not hasattr(module, "quantization_scheme"): continue From 54ede55afc86fda3f1b43d545ce797b5580261f6 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 9 Oct 2025 14:49:37 -0500 Subject: [PATCH 2/7] typo Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/quantization/quantization/mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmcompressor/modifiers/quantization/quantization/mixin.py b/src/llmcompressor/modifiers/quantization/quantization/mixin.py index 7a27ea40e..7e30f04a3 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/mixin.py +++ b/src/llmcompressor/modifiers/quantization/quantization/mixin.py @@ -140,7 +140,7 @@ def resolved_targets(self): """ if self._resolved_targets is None: targets = [] - for config_group in self.resolved_config.config_groups.items(): + for config_group in self.resolved_config.config_groups.values(): for target in config_group.targets: if target not in targets: targets.append(target) From 550f002245c998d3fb38791dd85cf0ef43ba8041 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 9 Oct 2025 15:01:28 -0500 Subject: [PATCH 3/7] missing resolved_targets Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/quantization/gptq/base.py | 4 +++- 1 
file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 430ee9ae7..385de9840 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -178,7 +178,9 @@ def on_start(self, state: State, event: Event, **kwargs): # register gptq hooks added_hook = False - for _, module in match_named_modules(state.model, self.targets, self.ignore): + for _, module in match_named_modules( + state.model, self.resolved_targets, self.ignore + ): if getattr_chain(module, "quantization_scheme.weights", None) is not None: # HACK: previously, embeddings were not quantized because they were not # accessible by the layer compressor. For now, we manually ignore it, From be8bda7c030cce0d7741a99b0e0144a7e0fcef05 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 9 Oct 2025 15:36:02 -0500 Subject: [PATCH 4/7] stylefix Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/quantization/quantization/mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmcompressor/modifiers/quantization/quantization/mixin.py b/src/llmcompressor/modifiers/quantization/quantization/mixin.py index 7e30f04a3..7f4491941 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/mixin.py +++ b/src/llmcompressor/modifiers/quantization/quantization/mixin.py @@ -15,7 +15,7 @@ preset_name_to_scheme, ) from compressed_tensors.utils import match_named_modules -from pydantic import Field, PrivateAttr, field_validator, model_validator +from pydantic import Field, PrivateAttr, field_validator from torch.utils.hooks import RemovableHandle from llmcompressor.modifiers.quantization.calibration import ( From 41aa5102401db063ff546a2692e27a3f960c9ac5 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Fri, 10 Oct 2025 16:15:41 -0500 Subject: [PATCH 5/7] codereview updates Signed-off-by: Brian Dellabetta --- .../quantization/quantization/mixin.py | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/quantization/mixin.py b/src/llmcompressor/modifiers/quantization/quantization/mixin.py index 7f4491941..918308660 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/mixin.py +++ b/src/llmcompressor/modifiers/quantization/quantization/mixin.py @@ -92,7 +92,6 @@ class QuantizationMixin(HooksMixin): _calibration_hooks: Set[RemovableHandle] = PrivateAttr(default_factory=set) _resolved_config: Optional[QuantizationConfig] = PrivateAttr(None) - _resolved_targets: Optional[List[str]] = PrivateAttr(None) @field_validator("targets", mode="before") def validate_targets(cls, value: Union[str, List[str]]) -> List[str]: @@ -121,7 +120,7 @@ def validate_scheme( return value @property - def resolved_config(self): + def resolved_config(self) -> QuantizationConfig: """ The quantization config needs to be resolved just once, based on the scheme and config_groups inputs. @@ -131,22 +130,18 @@ def resolved_config(self): return self._resolved_config @property - def resolved_targets(self): + def resolved_targets(self) -> Set[str]: """ - List of all resolved targets, i.e. all unique targets listed + Set of all resolved targets, i.e. all unique targets listed in the resolved quantization config. Use this property instead of the targets field, as targets can also come from config_groups depending on how the recipe is configured. 
""" - if self._resolved_targets is None: - targets = [] - for config_group in self.resolved_config.config_groups.values(): - for target in config_group.targets: - if target not in targets: - targets.append(target) - self._resolved_targets = targets - - return self._resolved_targets + targets = set() + for config_group in self.resolved_config.config_groups.values(): + for target in config_group.targets: + targets.add(target) + return targets def initialize_quantization(self, model: torch.nn.Module): """ From bded1be2465ff71a16b148ab198dd9179cd134fc Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Mon, 13 Oct 2025 14:01:40 -0500 Subject: [PATCH 6/7] codereview updates Signed-off-by: Brian Dellabetta --- .../modifiers/quantization/quantization/mixin.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/quantization/mixin.py b/src/llmcompressor/modifiers/quantization/quantization/mixin.py index 918308660..719431d88 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/mixin.py +++ b/src/llmcompressor/modifiers/quantization/quantization/mixin.py @@ -62,8 +62,9 @@ class QuantizationMixin(HooksMixin): :param targets: list of layer names to quantize if a scheme is provided. If unset, will contain all targets listed in config_groups. If config_groups is also unset, will default to ["Linear"] (i.e. all Linear layers will be targeted). - This field is not the source of truth for all targets, it must be resolved - with config_groups. Use resolved_targets instead. + This field is not the source of truth for finding all matching target layers + in a model. Additional information can be stored in `config_groups`. Use + self.resolved_targets instead. :param ignore: optional list of module class names or submodule names to not quantize even if they match a target in config_groups. Defaults to empty list. :param scheme: a single quantization scheme to apply to the model. This is a @@ -85,6 +86,9 @@ class QuantizationMixin(HooksMixin): """ config_groups: Optional[Dict[str, QuantizationScheme]] = None + # NOTE: targets is not the sole source of truth for finding all matching target + # layers in a model. Additional information can be stored in `config_groups` + # Use self.resolved_targets as source of truth. 
targets: Union[str, List[str]] = Field(default_factory=lambda: ["Linear"]) ignore: List[str] = Field(default_factory=list) scheme: Optional[Union[str, Dict[str, Any]]] = None From df87c1b079f8e81a8ee416a0c181d0f64c4c1d34 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Mon, 13 Oct 2025 14:28:03 -0500 Subject: [PATCH 7/7] unit test Signed-off-by: Brian Dellabetta --- .../modifiers/quantization/test_base.py | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/tests/llmcompressor/modifiers/quantization/test_base.py b/tests/llmcompressor/modifiers/quantization/test_base.py index b582f9cea..51ea28f13 100644 --- a/tests/llmcompressor/modifiers/quantization/test_base.py +++ b/tests/llmcompressor/modifiers/quantization/test_base.py @@ -159,3 +159,55 @@ def test_serialize_actorder(has_actorder, actorder, exp_actorder): modifier = GPTQModifier(targets=["Linear"], scheme="W8A8") assert modifier.model_dump()["actorder"] == exp_actorder + + +@pytest.mark.parametrize( + "scheme,targets,config_groups,resolved_targets,should_error", + [ + ("W4A16", ["Linear"], None, {"Linear"}, False), + ( + "W4A16", + [r"re:.*q_proj$", r"re:.*k_proj$"], + None, + {r"re:.*q_proj$", r"re:.*k_proj$"}, + False, + ), + ( + None, + ["Linear"], + dict( + group_0=dict( + targets=[r"re:.*q_proj$"], + ), + group_1=dict( + targets=[r"re:.*k_proj$"], + ), + ), + {r"re:.*q_proj$", r"re:.*k_proj$"}, + False, + ), + ( + "W4AA16", + ["Linear"], + dict( + group_0=dict( + targets=[r"re:.*q_proj$"], + ), + ), + set(), + True, + ), + ], +) +def test_resolved_targets( + scheme, targets, config_groups, resolved_targets, should_error +): + if should_error: + with pytest.raises(ValueError): + GPTQModifier(targets=targets, scheme=scheme, config_groups=config_groups) + else: + modifier = GPTQModifier( + targets=targets, scheme=scheme, config_groups=config_groups ) + + assert modifier.resolved_targets == resolved_targets
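
Usage note for reviewers: the behavior this series introduces can be exercised directly, mirroring the third parametrized case in the new unit test. The sketch below is illustrative only; it assumes a checkout with this series applied, and the GPTQModifier import path follows the repo's existing tests.

# Minimal sketch: `resolved_targets` is derived from the lazily resolved
# QuantizationConfig, not from the raw `targets` field alone.
from llmcompressor.modifiers.quantization import GPTQModifier

modifier = GPTQModifier(
    targets=["Linear"],  # raw field; superseded by config_groups targets on resolution
    config_groups=dict(
        group_0=dict(targets=[r"re:.*q_proj$"]),
        group_1=dict(targets=[r"re:.*k_proj$"]),
    ),
)

# The property unions the targets of every group in the resolved config, so
# callers (AWQ, GPTQ, QuantizationModifier) no longer read `targets` directly.
assert modifier.resolved_targets == {r"re:.*q_proj$", r"re:.*k_proj$"}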