Skip to content

Commit a7cfa3d

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 5a040c4 commit a7cfa3d

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

auto_round/compressors/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
estimate_tuning_block_mem,
6767
find_matching_blocks,
6868
flatten_list,
69+
get_avg_bits,
6970
get_block_names,
7071
get_device_memory,
7172
get_fp_layer_names,
@@ -95,7 +96,6 @@
9596
to_device,
9697
to_dtype,
9798
unsupport_meta_device,
98-
get_avg_bits,
9999
)
100100
from auto_round.wrapper import WrapperLinear, WrapperMultiblock, unwrapper_block, unwrapper_layer, wrapper_block
101101

@@ -1695,7 +1695,9 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]:
16951695
average_bits_w_lm_head = get_avg_bits(self.model, with_lm_head=True)
16961696
if average_bits_w_lm_head != average_bits:
16971697
logger.info(f"The target average bits of blocks in the model (without lm_head): {average_bits:.3f} bits")
1698-
logger.info(f"The target average bits of the entire model (with lm_head): {average_bits_w_lm_head:.3f} bits")
1698+
logger.info(
1699+
f"The target average bits of the entire model (with lm_head): {average_bits_w_lm_head:.3f} bits"
1700+
)
16991701
else:
17001702
logger.info(f"The target average bits of the entire model: {average_bits:.3f} bits")
17011703
if not hasattr(self, "formats"):

auto_round/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2769,7 +2769,7 @@ def get_avg_bits(module, with_lm_head=False):
27692769
if m.data_type in ("fp4_v2", "nv_fp"):
27702770
scale_bits += 32 # global scale bits
27712771
all_bits += scale_bits
2772-
else: # woq
2772+
else: # woq
27732773
scale_bits = 16 * (m_numel // m.group_size)
27742774
all_bits += scale_bits
27752775

0 commit comments

Comments
 (0)