4 changes: 4 additions & 0 deletions auto_round/compressors/base.py
@@ -1949,6 +1949,7 @@ def _set_layerwise_config(self, layer_config: dict) -> bool:
"""
# Get the names of layers in quantization blocks
supported_types = self.supported_types
regex_config = {}
layers_in_blocks = get_layer_names_in_block(
self.model, supported_types, self.quant_block_list, self.inner_supported_types
)
@@ -1977,6 +1978,7 @@ def _set_layerwise_config(self, layer_config: dict) -> bool:
matched_names.append(layer_name)
if len(matched_names) > 0:
val = layer_config[name]
regex_config[name] = val # keep regex config
layer_config.pop(name)
for match_name in matched_names:
layer_config[match_name] = val
@@ -2067,6 +2069,7 @@ def _set_layerwise_config(self, layer_config: dict) -> bool:
if need_to_quantize_lm_head:
has_qlayer_outside_block = True

self.regex_config = regex_config
# Return whether there are quantized layers outside the blocks
return has_qlayer_outside_block

@@ -3160,6 +3163,7 @@ def save_quantized(
"act_data_type",
"super_bits",
"super_group_size",
"regex_config",
]
if isinstance(self.dataset, str):
serialization_keys.append("dataset")
38 changes: 37 additions & 1 deletion auto_round/export/export_to_autogptq/export.py
@@ -17,6 +17,7 @@
import json
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict

import threadpoolctl as tctl

@@ -48,6 +49,13 @@

import auto_round.export.export_to_autogptq.qlinear_triton
from auto_round.export.utils import save_model

GPTQ_REQUIRED_CONFIG_KEYS = (
"bits",
"group_size",
"sym",
)

from auto_round.logger import logger
from auto_round.utils import (
SUPPORTED_LAYER_TYPES,
@@ -57,7 +65,9 @@
get_autogptq_packing_qlinear,
get_block_names,
get_module,
json_serialize,
set_module,
to_standard_regex,
)

BLOCK_PATTERNS = [ ## copy from transformers optimum
@@ -68,6 +78,31 @@
]


def convert_to_autogptq_dynamic(regex_config: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
"""
Convert AutoRound-style regex_config into AutoGPTQ-style QuantizerConfig.dynamic.

Rules:
- bits < 16 -> quantize -> positive match `+:regex`
- bits == 16 -> skip quantize -> negative match `-:regex`
"""
converted = {}
for name, cfg in regex_config.items():
bits = cfg.get("bits")
regex = to_standard_regex(name)

if bits is None:
continue # ignore invalid entries
elif bits < 16:
converted[f"+:{regex}"] = {"bits": bits}
for key in GPTQ_REQUIRED_CONFIG_KEYS:  # only save the keys GPTQ supports
converted[f"+:{regex}"][key] = regex_config[name][key]
else:
# skip quantization
converted[f"-:{regex}"] = {}
return converted


def pack_layer(name, model, backend, device=None):
if name == "lm_head": ##dese not support lm-head
return
@@ -156,7 +191,8 @@ def save_quantized_as_autogptq(output_dir, inplace=True, backend="auto_gptq:exll
logger.error("auto-gptq format may not support loading this quantized model")
quantization_config["block_name_to_quantize"] = common_prefix
quantization_config.pop("to_quant_block_names", None)

regex_config = quantization_config.pop("regex_config")
quantization_config["dynamic"] = convert_to_autogptq_dynamic(regex_config)
## as layers may already be packed, we need to check layer_config
layer_config = kwargs["layer_config"]
for n, m in model.named_modules():
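To make the dynamic mapping concrete, here is a minimal sketch of how convert_to_autogptq_dynamic is expected to behave; the layer names and settings below are illustrative assumptions, not values taken from this PR.

from auto_round.export.export_to_autogptq.export import convert_to_autogptq_dynamic

# Hypothetical mixed-bits config: "lm_head" stays in 16-bit, "mlp.gate" is quantized to 4-bit.
regex_config = {
    "lm_head": {"bits": 16, "group_size": 128, "sym": True},
    "mlp.gate": {"bits": 4, "group_size": 128, "sym": True},
}
dynamic = convert_to_autogptq_dynamic(regex_config)
# dynamic == {
#     "-:.*lm_head.*": {},  # bits == 16 -> negative match, layer is skipped
#     "+:.*mlp\\.gate.*": {"bits": 4, "group_size": 128, "sym": True},  # bits < 16 -> positive match
# }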
12 changes: 12 additions & 0 deletions auto_round/export/export_to_autoround/export.py
@@ -42,6 +42,7 @@
is_nv_fp,
is_standard_fp,
set_module,
to_standard_regex,
)


@@ -332,6 +333,10 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex
extra_config[layer_name]["data_type"] = layer_config[layer_name]["data_type"]
extra_config[layer_name]["group_size"] = layer_config[layer_name]["group_size"]
extra_config[layer_name]["sym"] = layer_config[layer_name]["sym"]
extra_config[layer_name]["act_bits"] = layer_config[layer_name]["act_bits"]
extra_config[layer_name]["act_data_type"] = layer_config[layer_name]["act_data_type"]
extra_config[layer_name]["act_group_size"] = layer_config[layer_name]["act_group_size"]
extra_config[layer_name]["act_sym"] = layer_config[layer_name]["act_sym"]
elif layer_config[layer_name]["in_blocks"] or (
block_name_to_quantize is not None and check_start_with_block_name(layer_name, block_name_to_quantize)
):
@@ -343,6 +348,13 @@
for key in neq_keys:
if layer_config[layer_name][key] is not None:
extra_config[layer_name][key] = layer_config[layer_name][key]

regex_config = quantization_config.pop("regex_config")
if regex_config is not None:
for name in regex_config.keys():
regex_name = to_standard_regex(name)
extra_config[regex_name] = {k: regex_config[name][k] for k in REQUIRED_CONFIG_KEYS}

if len(extra_config) > 0:
quantization_config["extra_config"] = extra_config
names = list(layer_config.keys())
8 changes: 8 additions & 0 deletions auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py
@@ -39,6 +39,7 @@
is_nv_fp,
set_amax_for_all_moe_layers,
set_module,
to_standard_regex,
)
from auto_round.wrapper import WrapperWALayer

@@ -215,6 +216,13 @@ def save_quantized_as_fp(output_dir, inplace=True, **kwargs):
for key in neq_keys:
if layer_config[layer_name][key] is not None:
extra_config[layer_name][key] = layer_config[layer_name][key]

regex_config = quantization_config.pop("regex_config")
if regex_config is not None:
for name in regex_config.keys():
regex_name = to_standard_regex(name)
extra_config[regex_name] = {k: regex_config[name][k] for k in REQUIRED_CONFIG_KEYS}

if len(extra_config) > 0:
quantization_config["extra_config"] = extra_config
names = list(layer_config.keys())
1 change: 1 addition & 0 deletions auto_round/export/export_to_awq/export.py
@@ -137,6 +137,7 @@ def wrapper(name):
return model

quantization_config = kwargs["serialization_dict"]
quantization_config.pop("regex_config") # as awq do not support mixed bits config saving

if output_dir is None:
return compressed_model
14 changes: 6 additions & 8 deletions auto_round/export/export_to_llmcompressor/export_to_fp.py
@@ -25,6 +25,7 @@
from tqdm import tqdm

from auto_round.export.export_to_autoround.qlinear_fp import QuantLinear
from auto_round.export.export_to_llmcompressor.utils import generate_ignore_regex_list
from auto_round.export.utils import save_model
from auto_round.logger import logger
from auto_round.utils import (
@@ -114,9 +115,8 @@ def pack_layer(name, model, backend, device=None):
scale = layer.scale
global_scale = getattr(layer, "weight_global_scale", None)
input_global_scale = getattr(layer, "input_global_scale", None)
# zero = layer.zp
# zero = layer.zp  # no zeros to handle, as mxfp does not support asym quantization
qlayer.pack(layer, scale, global_scale=global_scale, input_global_scale=input_global_scale, device=device)
## no zeros to handle, as mxfp not support asym quantization
qlayer.to(orig_device)


@@ -155,6 +155,9 @@ def save_quantized_as_fp(output_dir, inplace=True, **kwargs):
device = kwargs.get("device", None)
tokenizer = kwargs.get("tokenizer", None)
processor = kwargs.get("processor", None)
ar_quantization_config = kwargs["serialization_dict"]
regex_config = ar_quantization_config.pop("regex_config")
layer_config = kwargs["layer_config"]
extra_config = {}

if act_bits <= 8:
Expand Down Expand Up @@ -199,12 +202,7 @@ def wrapper(name):
for _ in executor.map(wrapper, names):
pass

# TODO fix the ignore re match issue, compile with fp8 & int8 config
ignore = ["lm_head"]
for layer_name in layer_config:
if layer_config[layer_name]["bits"] > 8: ## find ignore layers
ignore.append(layer_name)
ignore = list(set(ignore))
ignore = generate_ignore_regex_list(regex_config=regex_config, layer_config=layer_config)

# get llm-compressor format config
check_compressed_tensors_supported()
51 changes: 51 additions & 0 deletions auto_round/export/export_to_llmcompressor/utils.py
@@ -0,0 +1,51 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List

from auto_round.utils import matches_any_regex, to_standard_regex


def generate_ignore_regex_list(regex_config: Dict[str, Dict], layer_config: Dict[str, Dict]) -> List[str]:
"""
Generate ignore regex list for llm_compressor based on regex_config and layer_config.

Rules:
1. Any entry in regex_config with bits > 8 is ignored (kept unquantized).
2. Any layer in layer_config with bits > 8 is ignored if not already included.
3. Output regex patterns are normalized for llm_compressor ('re:...' style).

Args:
regex_config (Dict[str, Dict]): dynamic quantization config
layer_config (Dict[str, Dict]): layer-wise quantization config

Returns:
List[str]: List of regex patterns to ignore during quantization.
"""
prefix = "re:"
ignore_regex: List[str] = []

# Step 1: add regex_config entries that stay above 8 bits
for key, cfg in regex_config.items():
bits = cfg.get("bits")
if bits is not None and bits > 8:
ignore_regex.append(prefix + to_standard_regex(key))

# Step 2: add fully named layers from layer_config that stay above 8 bits
for key, cfg in layer_config.items():
bits = cfg.get("bits")
if bits is not None and bits > 8:
ignore_regex.append(key)

return ignore_regex
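A small usage sketch of generate_ignore_regex_list with assumed configs: regex-style entries kept above 8 bits are emitted with the 're:' prefix, while fully named layers are passed through unchanged.

from auto_round.export.export_to_llmcompressor.utils import generate_ignore_regex_list

regex_config = {"mlp.gate": {"bits": 16}}  # hypothetical regex-style entry kept in 16-bit
layer_config = {
    "lm_head": {"bits": 16},                         # not quantized -> ignored
    "model.layers.0.self_attn.q_proj": {"bits": 8},  # 8-bit quantized -> not ignored
}
ignore = generate_ignore_regex_list(regex_config=regex_config, layer_config=layer_config)
# ignore == ["re:.*mlp\\.gate.*", "lm_head"]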
78 changes: 77 additions & 1 deletion auto_round/utils.py
@@ -24,7 +24,7 @@
from enum import Enum
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable, Tuple, Union
from typing import Any, Callable, List, Tuple, Union

import cpuinfo
import torch
@@ -2755,3 +2755,79 @@ def is_mllm_model(model_or_path: Union[str, torch.nn.Module]):
return True

return False


def to_standard_regex(pattern: str) -> str:
"""
Convert a user-specified string into a standardized regex for layer matching.

Rules:
- If the pattern already contains regex tokens ('.*', '^', '$', etc.),
keep them as-is.
- Otherwise, wrap the pattern with `.*` on both sides to allow substring matching.
- Always ensure the returned regex is valid (compilable by re).

Examples:
>>> to_standard_regex("model.embed_tokens")
'.*model\\.embed_tokens.*'
>>> to_standard_regex("mlp.gate")
'.*mlp\\.gate.*'
>>> to_standard_regex("mlp.gate$")
'.*mlp\\.gate$'
>>> to_standard_regex("mlp.*gate")
'.*mlp.*gate.*'
"""
# Heuristic: if pattern contains regex meta characters, assume partial regex
meta_chars = {".*", "^", "$", "|", "(", ")", "[", "]", "?", "+"}
has_regex = any(tok in pattern for tok in meta_chars)
if not has_regex:
# Escape literal dots, etc., and wrap with .* for substring matching
pattern = re.escape(pattern)
regex = f".*{pattern}.*"
else:
# Only escape bare dots that are not already part of regex constructs
# Avoid double escaping .* sequences
tmp = []
i = 0
while i < len(pattern):
if pattern[i] == ".":
if i + 1 < len(pattern) and pattern[i + 1] == "*":
tmp.append(".*") # keep regex token
i += 2
continue
else:
tmp.append("\\.") # escape bare dot
else:
tmp.append(pattern[i])
i += 1
regex = "".join(tmp)
# If no anchors are provided, allow substring matching
if not regex.startswith("^") and not regex.startswith(".*"):
regex = ".*" + regex
if not regex.endswith("$") and not regex.endswith(".*"):
regex = regex + ".*"
# Validate regex
try:
re.compile(regex)
except re.error as e:
raise ValueError(f"Invalid regex generated from pattern '{pattern}': {e}")
return regex


def matches_any_regex(layer_name: str, regex_list: List[str], prefix="re:") -> bool:
"""
Check whether layer_name fully matches any pattern in regex_list; patterns may carry a 're:' prefix, which is stripped before matching.
"""
for pattern in regex_list:
# Remove 're:' prefix for matching
pat = pattern.removeprefix(prefix)
if re.fullmatch(pat, layer_name):
return True
return False
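As a quick, illustrative pairing of the two helpers (the layer names below are made up):

from auto_round.utils import matches_any_regex, to_standard_regex

patterns = ["re:" + to_standard_regex("mlp.gate"), "re:" + to_standard_regex("lm_head")]
# patterns == ["re:.*mlp\\.gate.*", "re:.*lm_head.*"]
matches_any_regex("model.layers.3.mlp.gate", patterns)          # True  (substring-style match)
matches_any_regex("model.layers.3.self_attn.q_proj", patterns)  # False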


def json_serialize(obj: Any):
"""Convert non-JSON-serializable objects into JSON-friendly formats."""
if isinstance(obj, torch.dtype):
return str(obj).split(".")[-1] # e.g., torch.float16 -> "float16"
raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
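json_serialize follows the json.dumps default-hook protocol (return a serializable value or raise TypeError), so it can be passed as default= when dumping a config that contains torch dtypes; a minimal sketch with assumed values:

import json
import torch
from auto_round.utils import json_serialize

json.dumps({"data_type": torch.bfloat16, "bits": 4}, default=json_serialize)
# -> '{"data_type": "bfloat16", "bits": 4}'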