diff --git a/auto_round/autoround.py b/auto_round/autoround.py
index 854f3f460..13bd41064 100644
--- a/auto_round/autoround.py
+++ b/auto_round/autoround.py
@@ -604,6 +604,16 @@ def parse_format_to_list(self, format: str) -> list:
             self.scale_dtype = torch.float32
             logger.info("change `scale_dtype` to `torch.float32`")
 
+        if self.model.dtype != torch.float16 and self.scale_dtype == torch.float16:
+            only_auto_round = True
+            for format_ in formats:
+                if not ("auto_round" in format_ or "fake" in format_):
+                    only_auto_round = False
+                    break
+            if only_auto_round:
+                self.scale_dtype = torch.bfloat16
+                logger.info("change `scale_dtype` to `torch.bfloat16`")
+
         # Adjust format settings based on compatibility
         for index in range(len(formats)):
             format = formats[index]
diff --git a/auto_round/export/export_to_autoround/export.py b/auto_round/export/export_to_autoround/export.py
index 22ffe967d..e48653a91 100644
--- a/auto_round/export/export_to_autoround/export.py
+++ b/auto_round/export/export_to_autoround/export.py
@@ -221,13 +221,9 @@ def pack_layer(layer_name, model, backend):
         if bits != 4:
             logger.error("AutoAWQ format only supports 4-bits quantization.")
+
         qlayer = QuantLinear.from_linear(
-            linear=layer,
-            w_bit=bits,
-            group_size=group_size,
-            init_only=False,
-            scales=scale,
-            zeros=zp,
+            linear=layer, w_bit=bits, group_size=group_size, init_only=False, scales=scale, zeros=zp
         )
         qlayer.to(device)
         set_module(model, layer_name, qlayer)
diff --git a/auto_round/wrapper.py b/auto_round/wrapper.py
index 591e5c567..46acf8954 100644
--- a/auto_round/wrapper.py
+++ b/auto_round/wrapper.py
@@ -89,8 +89,10 @@ def __init__(
             weight_global_scale = calculate_gparam(self.orig_layer.weight, self.orig_layer.group_size)
             setattr(self, "weight_global_scale", weight_global_scale)
             self.weight_global_scale = self.weight_global_scale.to(self.orig_layer.weight.device)
-        if hasattr(self.orig_layer, "scale_dtype") and self.orig_layer.scale_dtype == torch.float32:
-            self.q_scale_thresh = 1e-8
+        if hasattr(self.orig_layer, "scale_dtype") and (
+            self.orig_layer.scale_dtype == torch.float32 or self.orig_layer.scale_dtype == torch.bfloat16
+        ):
+            self.q_scale_thresh = 1e-30
         else:
             self.q_scale_thresh = 1e-5
         self._init_tuning_params_and_quant_func()
diff --git a/auto_round_extension/torch/qlinear_torch.py b/auto_round_extension/torch/qlinear_torch.py
index c45e183df..c3b24c821 100644
--- a/auto_round_extension/torch/qlinear_torch.py
+++ b/auto_round_extension/torch/qlinear_torch.py
@@ -30,7 +30,9 @@ class QuantLinear(nn.Module):
     QUANT_TYPE = "torch"
 
-    def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs):
+    def __init__(
+        self, bits, group_size, infeatures, outfeatures, bias, trainable=False, weight_dtype=torch.bfloat16, **kwargs
+    ):
         super().__init__()
         if bits not in [2, 3, 4, 8]:
             raise NotImplementedError("Only 2,3,4,8 bits are supported.")
@@ -62,7 +64,7 @@ def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=Fa
             ),
         )
         if bias:
-            self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16))
+            self.register_buffer("bias", torch.zeros((outfeatures), dtype=weight_dtype))
         else:
             self.bias = None
 
@@ -89,8 +91,8 @@ def post_init(self):
     def pack(self, linear, scales, zeros, g_idx=None):
         scales_t = scales.t().contiguous()
         if linear.bias is not None:
-            self.bias = linear.bias.clone().half()
-        self.scales = scales_t.clone().half()
+            self.bias = linear.bias.clone().to(self.bias.dtype)
+        self.scales = scales_t.clone().to(self.scales.dtype)
         device = "cpu"
         if torch.cuda.is_available():
             device = "cuda:0"
@@ -160,7 +162,7 @@ def pack(self, linear, scales, zeros, g_idx=None):
 
         if isinstance(zeros, torch.Tensor):
             zeros = zeros.t().contiguous()
-            zeros = zeros.numpy().astype(np.uint32)
+            zeros = zeros.to(torch.float16).numpy().astype(np.uint32)
             qzeros = torch.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=torch.int32)
             i = 0
             col = 0
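# Annotation (not part of the patch): a minimal sketch of why pack() above casts
# zero points through float16 before calling .numpy(), and why bias/scales now
# follow the buffer dtype instead of an unconditional .half(). Tensor names here
# are illustrative only, not taken from the repository.
import numpy as np
import torch

# NumPy has no bfloat16 dtype, so Tensor.numpy() raises TypeError on bfloat16
# zero points; routing through float16 is exact for small integer zero points
# and keeps the uint32 packing path working.
zeros = torch.randint(0, 16, (4, 8)).to(torch.bfloat16)
try:
    zeros.numpy()
except TypeError as exc:
    print("direct conversion fails:", exc)
packed_zeros = zeros.to(torch.float16).numpy().astype(np.uint32)
print(packed_zeros.dtype, packed_zeros.shape)  # uint32 (4, 8)

# Casting with .to(buffer.dtype) instead of .half() lets a bfloat16 checkpoint
# keep bfloat16 scales/bias in the packed layer rather than silently downcasting.
scales_buffer = torch.zeros(8, dtype=torch.bfloat16)
new_scales = torch.rand(8, dtype=torch.float32)
scales_buffer = new_scales.clone().to(scales_buffer.dtype)  # mirrors pack()
print(scales_buffer.dtype)  # torch.bfloat16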
diff --git a/auto_round_extension/torch/qlinear_torch_zp.py b/auto_round_extension/torch/qlinear_torch_zp.py
index 2958d249c..eb37c8269 100644
--- a/auto_round_extension/torch/qlinear_torch_zp.py
+++ b/auto_round_extension/torch/qlinear_torch_zp.py
@@ -31,7 +31,9 @@ class QuantLinear(nn.Module):
     QUANT_TYPE = "torch"
 
-    def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs):
+    def __init__(
+        self, bits, group_size, infeatures, outfeatures, bias, trainable=False, weight_dtype=torch.bfloat16, **kwargs
+    ):
         super().__init__()
         if bits not in [2, 3, 4, 8]:
             raise NotImplementedError("Only 2,3,4,8 bits are supported.")
@@ -59,11 +61,11 @@ def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=Fa
             "scales",
             torch.zeros(
                 (math.ceil(infeatures / self.group_size), outfeatures),
-                dtype=torch.float16,
+                dtype=weight_dtype,
             ),
         )
         if bias:
-            self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16))
+            self.register_buffer("bias", torch.zeros((outfeatures), dtype=weight_dtype))
         else:
             self.bias = None
 
@@ -90,8 +92,8 @@ def post_init(self):
     def pack(self, linear, scales, zeros, g_idx=None):
         scales_t = scales.t().contiguous()
         if linear.bias is not None:
-            self.bias = linear.bias.clone().half()
-        self.scales = scales_t.clone().half()
+            self.bias = linear.bias.clone().to(self.bias.dtype)
+        self.scales = scales_t.clone().to(self.scales.dtype)
         device = "cpu"
         if torch.cuda.is_available():
             device = "cuda:0"
@@ -161,7 +163,7 @@ def pack(self, linear, scales, zeros, g_idx=None):
 
         zeros = zeros.t().contiguous()
         zeros -= 1
-        zeros = zeros.numpy().astype(np.uint32)
+        zeros = zeros.to(torch.float16).numpy().astype(np.uint32)
         qzeros = torch.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=torch.int32)
         i = 0
         col = 0
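# Annotation (not part of the patch): a sketch of why the clamp threshold in
# wrapper.py can drop from 1e-8/1e-5 to 1e-30 once scales are kept in float32 or
# bfloat16 (which autoround.py now prefers for non-fp16 models exporting only
# auto_round/fake formats). bfloat16 shares float32's exponent range, so 1e-30
# is representable there, while float16 underflows far above it (its subnormals
# bottom out near 6e-8) and keeps the coarser 1e-5 threshold. clamp_scale is a
# hypothetical helper mirroring the threshold selection, not a repo function.
import torch

print(torch.finfo(torch.float16).tiny)   # ~6.10e-05 (smallest normal)
print(torch.finfo(torch.bfloat16).tiny)  # ~1.18e-38
print(torch.finfo(torch.float32).tiny)   # ~1.18e-38


def clamp_scale(scale: torch.Tensor, scale_dtype: torch.dtype) -> torch.Tensor:
    thresh = 1e-30 if scale_dtype in (torch.float32, torch.bfloat16) else 1e-5
    return torch.clamp(scale, min=thresh)


print(clamp_scale(torch.tensor([0.0, 1e-35, 0.5]), torch.bfloat16))
# entries below the threshold are lifted to 1e-30; 0.5 passes through unchanged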