From 5a040c49aa38d6e87f7ccd0dd379c8951568989b Mon Sep 17 00:00:00 2001
From: xinhe3
Date: Thu, 18 Sep 2025 12:15:55 +0300
Subject: [PATCH 1/4] dump avg_bits

Signed-off-by: xinhe3
Signed-off-by: He, Xin3
---
 auto_round/compressors/base.py |  8 ++++++
 auto_round/utils.py            | 48 ++++++++++++++++++++++++++++++++++
 2 files changed, 56 insertions(+)

diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py
index 9ed346bb8..b3c42b1c8 100644
--- a/auto_round/compressors/base.py
+++ b/auto_round/compressors/base.py
@@ -95,6 +95,7 @@
     to_device,
     to_dtype,
     unsupport_meta_device,
+    get_avg_bits,
 )
 
 from auto_round.wrapper import WrapperLinear, WrapperMultiblock, unwrapper_block, unwrapper_layer, wrapper_block
@@ -1690,6 +1691,13 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]:
         # because it may cause the gguf format to not be exported normally.
         self.model = _handle_moe_model(self.model, formats=formats)
         self.has_qlayer_outside_block = self._set_layerwise_config(self.layer_config)
+        average_bits = get_avg_bits(self.model)
+        average_bits_w_lm_head = get_avg_bits(self.model, with_lm_head=True)
+        if average_bits_w_lm_head != average_bits:
+            logger.info(f"The target average bits of blocks in the model (without lm_head): {average_bits:.3f} bits")
+            logger.info(f"The target average bits of the entire model (with lm_head): {average_bits_w_lm_head:.3f} bits")
+        else:
+            logger.info(f"The target average bits of the entire model: {average_bits:.3f} bits")
         if not hasattr(self, "formats"):
             logger.warning("this API is deprecated, please use `quantize_and_save` instead")
         else:
diff --git a/auto_round/utils.py b/auto_round/utils.py
index 2d98ae121..ee4554efc 100644
--- a/auto_round/utils.py
+++ b/auto_round/utils.py
@@ -2730,3 +2730,51 @@ def is_mllm_model(model_or_path: Union[str, torch.nn.Module]):
             return True
 
     return False
+
+
+def get_avg_bits(module, with_lm_head=False):
+    """
+    Calculates the average number of bits per weight element for supported layers in a given module.
+
+    Iterates through all named submodules, accumulating the total number of weight elements
+    and the corresponding bit usage, including additional scale bits for specific data types.
+
+    Args:
+        module: A neural network module containing layers to be analyzed.
+        with_lm_head: Whether to include the lm_head layer in the statistics. Defaults to False.
+
+    Returns:
+        float: The average number of bits per weight element across all supported layers.
+
+    Note:
+        - Only layers of types specified in SUPPORTED_LAYER_TYPES are considered.
+        - For fp data types ("fp4_v2", "nv_fp", "mx_fp", "nv_fp4", "mx_fp4", "mx_fp8"), 8-bit
+          scales are added per group; other quantized types use 16-bit scales.
+        - For "fp4_v2" and "nv_fp" data types, an additional 32 global scale bits are included.
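+        - Example (illustrative): a 4-bit int layer with group_size 128 stores one
+          16-bit scale per 128 weights, i.e. an average of 4 + 16/128 = 4.125 bits.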
+    """
+    all_numel = 0
+    all_bits = 0
+
+    lm_head_name = get_lm_head_name(module)
+    if lm_head_name is None:
+        with_lm_head = False
+    for n, m in module.named_modules():
+        if n == lm_head_name and not with_lm_head:
+            continue
+        if isinstance(m, SUPPORTED_LAYER_TYPES):
+            m_numel = m.weight.numel()
+            all_numel += m_numel
+            w_bits = m.bits * m_numel
+            all_bits += w_bits
+            if m.data_type in ("fp4_v2", "nv_fp", "mx_fp", "nv_fp4", "mx_fp4", "mx_fp8"):
+                scale_bits = 8 * (m_numel // m.group_size)
+                if m.data_type in ("fp4_v2", "nv_fp"):
+                    scale_bits += 32 # global scale bits
+                all_bits += scale_bits
+            else: # woq
+                scale_bits = 16 * (m_numel // m.group_size)
+                all_bits += scale_bits
+
+    avg_bits = all_bits / all_numel if all_numel > 0 else 0
+    return round(avg_bits, 6)

From 38ab517caeb871aad838d30a604c72837b0623ab Mon Sep 17 00:00:00 2001
From: "He, Xin3"
Date: Thu, 25 Sep 2025 04:30:58 -0400
Subject: [PATCH 2/4] update per review

Signed-off-by: He, Xin3
---
 auto_round/compressors/base.py |  6 ++----
 auto_round/utils.py            | 26 ++++++++++++++++++--------
 2 files changed, 20 insertions(+), 12 deletions(-)

diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py
index b3c42b1c8..c818dc919 100644
--- a/auto_round/compressors/base.py
+++ b/auto_round/compressors/base.py
@@ -1693,11 +1693,9 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]:
         self.has_qlayer_outside_block = self._set_layerwise_config(self.layer_config)
         average_bits = get_avg_bits(self.model)
         average_bits_w_lm_head = get_avg_bits(self.model, with_lm_head=True)
+        logger.info(f"The target average bits: {average_bits:.3f} bits")
         if average_bits_w_lm_head != average_bits:
-            logger.info(f"The target average bits of blocks in the model (without lm_head): {average_bits:.3f} bits")
-            logger.info(f"The target average bits of the entire model (with lm_head): {average_bits_w_lm_head:.3f} bits")
-        else:
-            logger.info(f"The target average bits of the entire model: {average_bits:.3f} bits")
+            logger.debug(f"The target average bits (including lm_head): {average_bits_w_lm_head:.3f} bits")
         if not hasattr(self, "formats"):
             logger.warning("this API is deprecated, please use `quantize_and_save` instead")
         else:
diff --git a/auto_round/utils.py b/auto_round/utils.py
index ee4554efc..922ff2805 100644
--- a/auto_round/utils.py
+++ b/auto_round/utils.py
@@ -2753,6 +2753,16 @@ def get_avg_bits(module, with_lm_head=False):
         - Example (illustrative): a 4-bit int layer with group_size 128 stores one
           16-bit scale per 128 weights, i.e. an average of 4 + 16/128 = 4.125 bits.
""" + def _get_scale_num(bits, group_size, input_features, weight_numel): + if bits >= 16: + return 0 + if group_size == 0: + return 1 + elif group_size == -1: + return input_features + else: + return weight_numel // group_size + all_numel = 0 all_bits = 0 @@ -2760,18 +2770,18 @@ def get_avg_bits(module, with_lm_head=False): if n == lm_head_name and not with_lm_head: continue if isinstance(m, SUPPORTED_LAYER_TYPES): + # get weight bits m_numel = m.weight.numel() all_numel += m_numel w_bits = m.bits * m_numel all_bits += w_bits - if m.data_type in ("fp4_v2", "nv_fp", "mx_fp", "nv_fp4", "mx_fp4", "mx_fp8"): - scale_bits = 8 * (m_numel // m.group_size) - if m.data_type in ("fp4_v2", "nv_fp"): - scale_bits += 32 # global scale bits - all_bits += scale_bits - else: # woq - scale_bits = 16 * (m_numel // m.group_size) - all_bits += scale_bits + # get scale bits + scale_num = _get_scale_num(m.bits, m.group_size, m.weight.shape[-1], m_numel) + bits_per_scale = 16 if m.data_type == "int" else 8 + scale_bits = bits_per_scale * scale_num + if m.data_type in ("fp4_v2", "nv_fp"): + scale_bits += 32 # global scale bits + all_bits += scale_bits avg_bits = all_bits / all_numel if all_numel > 0 else 0 return round(avg_bits, 6) From 440ce637ac95614235b096945984e209a97e0dbd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Sep 2025 08:31:02 +0000 Subject: [PATCH 3/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- auto_round/compressors/base.py | 2 +- auto_round/utils.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index c818dc919..60c5cb50d 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -66,6 +66,7 @@ estimate_tuning_block_mem, find_matching_blocks, flatten_list, + get_avg_bits, get_block_names, get_device_memory, get_fp_layer_names, @@ -95,7 +96,6 @@ to_device, to_dtype, unsupport_meta_device, - get_avg_bits, ) from auto_round.wrapper import WrapperLinear, WrapperMultiblock, unwrapper_block, unwrapper_layer, wrapper_block diff --git a/auto_round/utils.py b/auto_round/utils.py index 922ff2805..33fc8267c 100644 --- a/auto_round/utils.py +++ b/auto_round/utils.py @@ -2750,6 +2750,7 @@ def get_avg_bits(module, with_lm_head=False): - For certain data types ("fp4_v2", "nv_fp4", "mx_fp4", "mx_fp8"), scale bits are added. - For "fp4_v2" and "nv_fp4", an additional 32 global scale bits are included. """ + def _get_scale_num(bits, group_size, input_features, weight_numel): if bits >= 16: return 0 From ab0c3d2cf37b8e8dd8e414a32b01e6f087ffc80c Mon Sep 17 00:00:00 2001 From: "He, Xin3" Date: Thu, 25 Sep 2025 23:51:31 -0400 Subject: [PATCH 4/4] fix per review Signed-off-by: He, Xin3 --- auto_round/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_round/utils.py b/auto_round/utils.py index 2b1209309..688aad8fe 100644 --- a/auto_round/utils.py +++ b/auto_round/utils.py @@ -2795,7 +2795,7 @@ def _get_scale_num(bits, group_size, input_features, weight_numel): for n, m in module.named_modules(): if n == lm_head_name and not with_lm_head: continue - if isinstance(m, SUPPORTED_LAYER_TYPES): + if type(m) in SUPPORTED_LAYER_TYPES: # get weight bits m_numel = m.weight.numel() all_numel += m_numel