
Commit 5a040c4

xinhe3 authored and xin3he committed

dump avg_bits

Signed-off-by: xinhe3 <xinhe3@habana.ai>
Signed-off-by: He, Xin3 <xin3.he@intel.com>
1 parent bd4bc2c commit 5a040c4

File tree: 2 files changed, +53 -0 lines changed


auto_round/compressors/base.py

Lines changed: 8 additions & 0 deletions

```diff
@@ -95,6 +95,7 @@
     to_device,
     to_dtype,
     unsupport_meta_device,
+    get_avg_bits,
 )
 from auto_round.wrapper import WrapperLinear, WrapperMultiblock, unwrapper_block, unwrapper_layer, wrapper_block

@@ -1690,6 +1691,13 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]:
         # because it may cause the gguf format to not be exported normally.
         self.model = _handle_moe_model(self.model, formats=formats)
         self.has_qlayer_outside_block = self._set_layerwise_config(self.layer_config)
+        average_bits = get_avg_bits(self.model)
+        average_bits_w_lm_head = get_avg_bits(self.model, with_lm_head=True)
+        if average_bits_w_lm_head != average_bits:
+            logger.info(f"The target average bits of blocks in the model (without lm_head): {average_bits:.3f} bits")
+            logger.info(f"The target average bits of the entire model (with lm_head): {average_bits_w_lm_head:.3f} bits")
+        else:
+            logger.info(f"The target average bits of the entire model: {average_bits:.3f} bits")
         if not hasattr(self, "formats"):
             logger.warning("this API is deprecated, please use `quantize_and_save` instead")
         else:
```
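
For intuition on the new log lines: the two values differ whenever the lm_head is kept at a different precision than the quantized blocks. Below is a small back-of-the-envelope sketch of the weighted average being reported; it is not library code, and the parameter counts and bit widths are made up for illustration.

```python
# Hypothetical split: transformer blocks quantized to INT4 with a 16-bit scale
# per group of 128 elements (4 + 16/128 = 4.125 bits/element), lm_head kept at 16-bit.
block_numel = 6_000_000_000    # weight elements inside the blocks (made up)
lm_head_numel = 500_000_000    # weight elements in lm_head (made up)

block_bits = 4.125 * block_numel
lm_head_bits = 16 * lm_head_numel

avg_without_lm_head = block_bits / block_numel
avg_with_lm_head = (block_bits + lm_head_bits) / (block_numel + lm_head_numel)

print(f"without lm_head: {avg_without_lm_head:.3f} bits")  # 4.125 bits
print(f"with lm_head:    {avg_with_lm_head:.3f} bits")     # ~5.038 bits
```

When every supported layer, including lm_head, shares the same config, the two numbers coincide and only the single "entire model" line is logged.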

auto_round/utils.py

Lines changed: 45 additions & 0 deletions

```diff
@@ -2730,3 +2730,48 @@ def is_mllm_model(model_or_path: Union[str, torch.nn.Module]):
             return True
 
     return False
+
+
+def get_avg_bits(module, with_lm_head=False):
+    """
+    Calculates the average number of bits per weight element for supported layers in a given module.
+
+    Iterates through all named modules in the module, accumulating the total number of weight elements
+    and the corresponding bit usage, including additional scale bits for specific data types.
+
+    Args:
+        module: A neural network module containing layers to be analyzed.
+
+    Returns:
+        float: The average number of bits per weight element across all supported layers.
+
+    Note:
+        - Only layers of types specified in SUPPORTED_LAYER_TYPES are considered.
+        - For certain data types ("fp4_v2", "nv_fp4", "mx_fp4", "mx_fp8"), scale bits are added.
+        - For "fp4_v2" and "nv_fp4", an additional 32 global scale bits are included.
+    """
+    all_numel = 0
+    all_bits = 0
+
+    lm_head_name = get_lm_head_name(module)
+    if lm_head_name is None:
+        with_lm_head = False
+    for n, m in module.named_modules():
+        if n == lm_head_name and not with_lm_head:
+            continue
+        if isinstance(m, SUPPORTED_LAYER_TYPES):
+            m_numel = m.weight.numel()
+            all_numel += m_numel
+            w_bits = m.bits * m_numel
+            all_bits += w_bits
+            if m.data_type in ("fp4_v2", "nv_fp", "mx_fp", "nv_fp4", "mx_fp4", "mx_fp8"):
+                scale_bits = 8 * (m_numel // m.group_size)
+                if m.data_type in ("fp4_v2", "nv_fp"):
+                    scale_bits += 32  # global scale bits
+                all_bits += scale_bits
+            else:  # woq
+                scale_bits = 16 * (m_numel // m.group_size)
+                all_bits += scale_bits
+
+    avg_bits = all_bits / all_numel if all_numel > 0 else 0
+    return round(avg_bits, 6)
```
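
To make the per-layer accounting concrete, here is a standalone sketch of the formula the function amortizes over each layer: weight bits per element, plus per-group scale bits divided by the group size, plus an optional one-off global scale. It only mirrors the arithmetic above; the helper name, example group sizes, and element count are illustrative, not part of the library.

```python
def per_element_bits(weight_bits, group_size, scale_bits_per_group,
                     global_scale_bits=0, numel=1_000_000):
    """Average bits/element for one layer: weight bits plus amortized scale bits."""
    groups = numel // group_size
    total_bits = weight_bits * numel + scale_bits_per_group * groups + global_scale_bits
    return total_bits / numel

# Weight-only INT4 with a 16-bit scale per group of 128: 4 + 16/128 = 4.125
print(per_element_bits(4, 128, 16))
# MX-FP4 style, 8-bit shared scale per group of 32: 4 + 8/32 = 4.25
print(per_element_bits(4, 32, 8))
# NV-FP4 / fp4_v2 style, 8-bit scale per group of 16 plus a 32-bit global scale: ~4.5
print(per_element_bits(4, 16, 8, global_scale_bits=32))
```

get_avg_bits then weights these per-layer totals by each layer's element count to produce the single model-level figure that quantize() logs.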
