Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions auto_round/autoround.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,16 @@ def parse_format_to_list(self, format: str) -> list:
self.scale_dtype = torch.float32
logger.info("change `scale_dtype` to `torch.float32`")

if self.model.dtype != torch.float16 and self.scale_dtype == torch.float16:
only_auto_round = True
for format_ in formats:
if not ("auto_round" in format_ or "fake" in format_):
only_auto_round = False
break
if only_auto_round:
self.scale_dtype = torch.bfloat16
logger.info("change `scale_dtype` to `torch.bfloat16`")

# Adjust format settings based on compatibility
for index in range(len(formats)):
format = formats[index]
Expand Down
8 changes: 2 additions & 6 deletions auto_round/export/export_to_autoround/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,9 @@ def pack_layer(layer_name, model, backend):

if bits != 4:
logger.error("AutoAWQ format only supports 4-bits quantization.")

qlayer = QuantLinear.from_linear(
linear=layer,
w_bit=bits,
group_size=group_size,
init_only=False,
scales=scale,
zeros=zp,
linear=layer, w_bit=bits, group_size=group_size, init_only=False, scales=scale, zeros=zp
)
qlayer.to(device)
set_module(model, layer_name, qlayer)
Expand Down
6 changes: 4 additions & 2 deletions auto_round/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,10 @@ def __init__(
weight_global_scale = calculate_gparam(self.orig_layer.weight, self.orig_layer.group_size)
setattr(self, "weight_global_scale", weight_global_scale)
self.weight_global_scale = self.weight_global_scale.to(self.orig_layer.weight.device)
if hasattr(self.orig_layer, "scale_dtype") and self.orig_layer.scale_dtype == torch.float32:
self.q_scale_thresh = 1e-8
if hasattr(self.orig_layer, "scale_dtype") and (
self.orig_layer.scale_dtype == torch.float32 or self.orig_layer.scale_dtype == torch.bfloat16
):
self.q_scale_thresh = 1e-30
else:
self.q_scale_thresh = 1e-5
self._init_tuning_params_and_quant_func()
Expand Down
12 changes: 7 additions & 5 deletions auto_round_extension/torch/qlinear_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ class QuantLinear(nn.Module):

QUANT_TYPE = "torch"

def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs):
def __init__(
self, bits, group_size, infeatures, outfeatures, bias, trainable=False, weight_dtype=torch.bfloat16, **kwargs
):
super().__init__()
if bits not in [2, 3, 4, 8]:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
Expand Down Expand Up @@ -62,7 +64,7 @@ def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=Fa
),
)
if bias:
self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16))
self.register_buffer("bias", torch.zeros((outfeatures), dtype=weight_dtype))
else:
self.bias = None

Expand All @@ -89,8 +91,8 @@ def post_init(self):
def pack(self, linear, scales, zeros, g_idx=None):
scales_t = scales.t().contiguous()
if linear.bias is not None:
self.bias = linear.bias.clone().half()
self.scales = scales_t.clone().half()
self.bias = linear.bias.clone().to(self.bias.dtype)
self.scales = scales_t.clone().to(self.scales.dtype)
device = "cpu"
if torch.cuda.is_available():
device = "cuda:0"
Expand Down Expand Up @@ -160,7 +162,7 @@ def pack(self, linear, scales, zeros, g_idx=None):

if isinstance(zeros, torch.Tensor):
zeros = zeros.t().contiguous()
zeros = zeros.numpy().astype(np.uint32)
zeros = zeros.to(torch.float16).numpy().astype(np.uint32)
qzeros = torch.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=torch.int32)
i = 0
col = 0
Expand Down
14 changes: 8 additions & 6 deletions auto_round_extension/torch/qlinear_torch_zp.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ class QuantLinear(nn.Module):

QUANT_TYPE = "torch"

def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs):
def __init__(
self, bits, group_size, infeatures, outfeatures, bias, trainable=False, weight_dtype=torch.bfloat16, **kwargs
):
super().__init__()
if bits not in [2, 3, 4, 8]:
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
Expand Down Expand Up @@ -59,11 +61,11 @@ def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=Fa
"scales",
torch.zeros(
(math.ceil(infeatures / self.group_size), outfeatures),
dtype=torch.float16,
dtype=weight_dtype,
),
)
if bias:
self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16))
self.register_buffer("bias", torch.zeros((outfeatures), dtype=weight_dtype))
else:
self.bias = None

Expand All @@ -90,8 +92,8 @@ def post_init(self):
def pack(self, linear, scales, zeros, g_idx=None):
scales_t = scales.t().contiguous()
if linear.bias is not None:
self.bias = linear.bias.clone().half()
self.scales = scales_t.clone().half()
self.bias = linear.bias.clone().to(self.bias.dtype)
self.scales = scales_t.clone().to(self.scales.dtype)
device = "cpu"
if torch.cuda.is_available():
device = "cuda:0"
Expand Down Expand Up @@ -161,7 +163,7 @@ def pack(self, linear, scales, zeros, g_idx=None):

zeros = zeros.t().contiguous()
zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
zeros = zeros.to(torch.float16).numpy().astype(np.uint32)
qzeros = torch.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=torch.int32)
i = 0
col = 0
Expand Down