@@ -44,7 +44,6 @@ def __init__(
44
44
self .norm = norm
45
45
self .is_activation = base_name != "weight"
46
46
47
-
48
47
def calculate_mse_min_max (
49
48
self ,
50
49
observed : Tensor ,
@@ -88,7 +87,7 @@ def calculate_mse_min_max(
88
87
89
88
from compressed_tensors .quantization .utils import generate_gparam
90
89
91
- if (is_fp4 (self .quantization_args )) and global_scale is None :
90
+ if (is_fp4 (self .quantization_args )) and global_scale is None :
92
91
# If the quantization scheme is fp4 and global_scale is still None
93
92
# i.e it has not yet been optimized, then we are should first get
94
93
# the global scale and then optimize the local scales.
@@ -147,7 +146,7 @@ def calculate_updated_min_max(
147
146
reduce_dims : Optional [Tuple [int ]] = None ,
148
147
tensor_id : Optional [Any ] = None ,
149
148
global_scale : Optional [torch .Tensor ] = None ,
150
- is_local : Optional [bool ]= False ,
149
+ is_local : Optional [bool ] = False ,
151
150
) -> Tuple [FloatTensor , IntTensor ]:
152
151
"""
153
152
Updates the mse-clipped min and max values of the observed tensor using
@@ -258,7 +257,6 @@ def reset(self):
258
257
self .min_val = {}
259
258
self .max_val = {}
260
259
261
-
262
260
def calculate_gparam (self , observed : Tensor ) -> torch .Tensor :
263
261
"""
264
262
Generate a global scale using the observed min and max from MSE optimization.
@@ -276,4 +274,5 @@ def calculate_gparam(self, observed: Tensor) -> torch.Tensor:
276
274
)
277
275
278
276
return generate_gparam (
279
- updated_min_val = updated_min_val , updated_max_val = updated_max_val )
277
+ updated_min_val = updated_min_val , updated_max_val = updated_max_val
278
+ )
0 commit comments