@@ -2730,3 +2730,53 @@ def is_mllm_model(model_or_path: Union[str, torch.nn.Module]):
return True
return False
+
+
+def get_avg_bits(module, with_lm_head=False):
+    """
+    Calculates the average number of bits per weight element for supported layers in a given module.
+
+    Iterates through all named modules in the module, accumulating the total number of weight elements
+    and the corresponding bit usage, including additional scale bits for specific data types.
+
+    Args:
+        module: A neural network module containing layers to be analyzed.
+        with_lm_head (bool): Whether to include the LM head layer in the calculation.
+            Forced to False when no LM head is found. Defaults to False.
+
+    Returns:
+        float: The average number of bits per weight element across all supported layers,
+            rounded to 6 decimal places.
+
+    Note:
+        - Only layers of types specified in SUPPORTED_LAYER_TYPES are considered.
+        - For the data types "fp4_v2", "nv_fp", "mx_fp", "nv_fp4", "mx_fp4", and "mx_fp8",
+          8-bit per-group scale bits are added; all other (weight-only) data types use
+          16-bit per-group scales.
+        - For "fp4_v2" and "nv_fp", an additional 32 global scale bits are included.
+    """
+    all_numel = 0
+    all_bits = 0
+
+    lm_head_name = get_lm_head_name(module)
+    if lm_head_name is None:
+        with_lm_head = False
+    for n, m in module.named_modules():
+        if n == lm_head_name and not with_lm_head:
+            continue
+        if isinstance(m, SUPPORTED_LAYER_TYPES):
+            m_numel = m.weight.numel()
+            all_numel += m_numel
+            w_bits = m.bits * m_numel  # bits taken by the quantized weights themselves
+            all_bits += w_bits
+            if m.data_type in ("fp4_v2", "nv_fp", "mx_fp", "nv_fp4", "mx_fp4", "mx_fp8"):
+                scale_bits = 8 * (m_numel // m.group_size)  # one 8-bit scale per group
+                if m.data_type in ("fp4_v2", "nv_fp"):
+                    scale_bits += 32  # global scale bits
+                all_bits += scale_bits
+            else:  # weight-only quantization: one 16-bit scale per group
+                scale_bits = 16 * (m_numel // m.group_size)
+                all_bits += scale_bits
+
+    avg_bits = all_bits / all_numel if all_numel > 0 else 0
+    return round(avg_bits, 6)
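
To make the bit accounting concrete, here is a minimal sketch of the weight-only branch. `FakeQuantLinear` is a hypothetical stand-in, not part of this change; it only mimics the attributes (`weight`, `bits`, `data_type`, `group_size`) that `get_avg_bits` reads from layers in `SUPPORTED_LAYER_TYPES`. A 4-bit layer with `group_size=128` and 16-bit per-group scales should average 4 + 16/128 = 4.125 bits per element:

```python
import torch

# Hypothetical stand-in for a quantized layer (not part of this change): it only
# mimics the attributes get_avg_bits reads (weight, bits, data_type, group_size).
class FakeQuantLinear(torch.nn.Module):
    def __init__(self, out_features, in_features, bits, data_type, group_size):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features))
        self.bits = bits
        self.data_type = data_type
        self.group_size = group_size

# 4-bit weight-only layer, 16-bit scales, one scale per group of 128 weights.
layer = FakeQuantLinear(256, 256, bits=4, data_type="int", group_size=128)
m_numel = layer.weight.numel()                   # 65536 weight elements
scale_bits = 16 * (m_numel // layer.group_size)  # 512 groups * 16 bits each
avg = (layer.bits * m_numel + scale_bits) / m_numel
print(avg)  # 4.125
```

For the "fp4_v2"/"nv_fp" branch the same arithmetic applies with 8-bit per-group scales plus a flat 32 bits for the global scale, which becomes negligible as the layer grows.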