From c7ecb1efcdd333c8be50393a6c5f1eb77f0d6085 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Fri, 7 Nov 2025 11:51:47 -0800 Subject: [PATCH 01/22] Add torch2.9 in regression tests --- .github/workflows/regression_test.yml | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index cc474ff9e7..456822e4f2 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -77,6 +77,12 @@ jobs: gpu-arch-type: "cuda" gpu-arch-version: "12.6" dev-requirements-overrides: "" + - name: CUDA 2.9 + runs-on: linux.g5.12xlarge.nvidia.gpu + torch-spec: 'torch==2.9.0' + gpu-arch-type: "cuda" + gpu-arch-version: "12.6" + dev-requirements-overrides: "" - name: CPU 2.6 runs-on: linux.4xlarge @@ -96,6 +102,12 @@ jobs: gpu-arch-type: "cpu" gpu-arch-version: "" dev-requirements-overrides: "" + - name: CPU 2.9 + runs-on: linux.4xlarge + torch-spec: 'torch==2.9.0 --index-url https://download.pytorch.org/whl/cpu' + gpu-arch-type: "cpu" + gpu-arch-version: "" + dev-requirements-overrides: "" uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: From e9f94ba9f56ee0778546c2943e680017b089e6e8 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Wed, 12 Nov 2025 14:03:01 -0800 Subject: [PATCH 02/22] Update torch version to 2.9.1 in regression tests --- .github/workflows/regression_test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index 456822e4f2..278b276ada 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -79,7 +79,7 @@ jobs: dev-requirements-overrides: "" - name: CUDA 2.9 runs-on: linux.g5.12xlarge.nvidia.gpu - torch-spec: 'torch==2.9.0' + torch-spec: 'torch==2.9.1' gpu-arch-type: "cuda" gpu-arch-version: "12.6" dev-requirements-overrides: "" @@ -104,7 +104,7 @@ jobs: dev-requirements-overrides: "" - name: CPU 2.9 runs-on: linux.4xlarge - torch-spec: 'torch==2.9.0 --index-url https://download.pytorch.org/whl/cpu' + torch-spec: 'torch==2.9.1 --index-url https://download.pytorch.org/whl/cpu' gpu-arch-type: "cpu" gpu-arch-version: "" dev-requirements-overrides: "" From 886f0a6cb5654a4ceb20924eb7748ca85ddcac1d Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Wed, 12 Nov 2025 14:04:06 -0800 Subject: [PATCH 03/22] Update torch version from 2.7.0 to 2.7.1 --- .github/workflows/regression_test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index 278b276ada..46928b30cf 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -67,7 +67,7 @@ jobs: dev-requirements-overrides: "" - name: CUDA 2.7 runs-on: linux.g5.12xlarge.nvidia.gpu - torch-spec: 'torch==2.7.0' + torch-spec: 'torch==2.7.1' gpu-arch-type: "cuda" gpu-arch-version: "12.6" dev-requirements-overrides: "" @@ -92,7 +92,7 @@ jobs: dev-requirements-overrides: "" - name: CPU 2.7 runs-on: linux.4xlarge - torch-spec: 'torch==2.7.0 --index-url https://download.pytorch.org/whl/cpu' + torch-spec: 'torch==2.7.1 --index-url https://download.pytorch.org/whl/cpu' gpu-arch-type: "cpu" gpu-arch-version: "" dev-requirements-overrides: "" From 1a9a13f1c4399a41159fc3d9e419512e643b992e Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Fri, 7 Nov 2025 13:52:55 -0800 Subject: [PATCH 04/22] Move dyn_int8_act_int4_wei_cpu_layout to prototype/dtypes (#3299) --- 
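For downstream code the practical effect of this move is an import-path change; a minimal
migration sketch, assuming this patch is installed (the old location keeps working through the
compatibility stub below, but emits a DeprecationWarning when that stub module is first imported):

    # Deprecated location: still resolves via the stub, warns on first import
    from torchao.dtypes import Int8DynamicActInt4WeightCPULayout

    # Preferred location going forward
    from torchao.prototype.dtypes import Int8DynamicActInt4WeightCPULayout  # noqa: F811
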
docs/source/api_ref_dtypes.rst | 1 + test/dtypes/test_uintx.py | 39 +++ test/integration/test_integration.py | 27 -- test/sparsity/test_sparse_api.py | 27 -- torchao/dtypes/__init__.py | 2 +- torchao/dtypes/affine_quantized_tensor_ops.py | 8 +- .../uintx/dyn_int8_act_int4_wei_cpu_layout.py | 326 +----------------- torchao/prototype/dtypes/__init__.py | 7 +- torchao/prototype/dtypes/uintx/__init__.py | 2 + .../uintx/dyn_int8_act_int4_wei_cpu_layout.py | 318 +++++++++++++++++ 10 files changed, 388 insertions(+), 369 deletions(-) create mode 100644 torchao/prototype/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py diff --git a/docs/source/api_ref_dtypes.rst b/docs/source/api_ref_dtypes.rst index e347dfd2e3..5c73d275eb 100644 --- a/docs/source/api_ref_dtypes.rst +++ b/docs/source/api_ref_dtypes.rst @@ -52,6 +52,7 @@ Prototype BlockSparseLayout CutlassInt4PackedLayout + Int8DynamicActInt4WeightCPULayout .. _NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring diff --git a/test/dtypes/test_uintx.py b/test/dtypes/test_uintx.py index cb0c88b21c..5d54a80753 100644 --- a/test/dtypes/test_uintx.py +++ b/test/dtypes/test_uintx.py @@ -3,6 +3,9 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +import sys +import warnings + import pytest import torch @@ -165,3 +168,39 @@ def test_uintx_model_size(dtype): quantize_(linear[0], UIntXWeightOnlyConfig(dtype)) quantized_size = get_model_size_in_bytes(linear) assert bf16_size * _dtype_to_ratio[dtype] == quantized_size + + +def test_uintx_api_deprecation(): + """ + Test that deprecated uintx APIs trigger deprecation warnings on import. + TODO: Remove this test once the deprecated APIs have been removed. + """ + deprecated_apis = [ + ( + "Int8DynamicActInt4WeightCPULayout", + "torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout", + ), + ("CutlassInt4PackedLayout", "torchao.dtypes.uintx.cutlass_int4_packed_layout"), + ("BlockSparseLayout", "torchao.dtypes.uintx.block_sparse_layout"), + ] + + for api_name, module_path in deprecated_apis: + # Clear the cache to force re-importing and trigger the warning again + modules_to_clear = [module_path, "torchao.dtypes"] + for mod in modules_to_clear: + if mod in sys.modules: + del sys.modules[mod] + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") # Ensure all warnings are captured + + # Dynamically import the deprecated API + exec(f"from torchao.dtypes import {api_name}") + + assert any( + issubclass(warning.category, DeprecationWarning) + and api_name in str(warning.message) + for warning in w + ), ( + f"Expected deprecation warning for {api_name}, got: {[str(warning.message) for warning in w]}" + ) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 50c2eebe81..70da622c73 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -1948,32 +1948,5 @@ def test_benchmark_model_cpu(self): assert self.run_benchmark_model("cpu") is not None -# TODO: Remove this test once the deprecated API has been removed -def test_cutlass_int4_packed_layout_deprecated(): - import sys - import warnings - - # We need to clear the cache to force re-importing and trigger the warning again. 
- modules_to_clear = [ - "torchao.dtypes.uintx.cutlass_int4_packed_layout", - "torchao.dtypes", - ] - for mod in modules_to_clear: - if mod in sys.modules: - del sys.modules[mod] - - with warnings.catch_warnings(record=True) as w: - from torchao.dtypes import CutlassInt4PackedLayout # noqa: F401 - - warnings.simplefilter("always") # Ensure all warnings are captured - assert any( - issubclass(warning.category, DeprecationWarning) - and "CutlassInt4PackedLayout" in str(warning.message) - for warning in w - ), ( - f"Expected deprecation warning for CutlassInt4PackedLayout, got: {[str(warning.message) for warning in w]}" - ) - - if __name__ == "__main__": unittest.main() diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py index c9d41a98a9..66cd032a9a 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -267,33 +267,6 @@ def test_sparse(self, compile): torch.testing.assert_close(reference, sparse_result, rtol=1e-1, atol=1e-1) - # TODO: Remove this test once the deprecated API has been removed - def test_sparse_deprecated(self): - import sys - import warnings - - # We need to clear the cache to force re-importing and trigger the warning again. - modules_to_clear = [ - "torchao.dtypes.uintx.block_sparse_layout", - "torchao.dtypes", - ] - for mod in modules_to_clear: - if mod in sys.modules: - del sys.modules[mod] - - with warnings.catch_warnings(record=True) as w: - from torchao.dtypes import BlockSparseLayout # noqa: F401 - - warnings.simplefilter("always") # Ensure all warnings are captured - self.assertTrue( - any( - issubclass(warning.category, DeprecationWarning) - and "BlockSparseLayout" in str(warning.message) - for warning in w - ), - f"Expected deprecation warning for BlockSparseLayout, got: {[str(w.message) for w in w]}", - ) - common_utils.instantiate_parametrized_tests(TestSemiStructuredSparse) common_utils.instantiate_parametrized_tests(TestQuantSemiSparse) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 252498bc97..354692e794 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -16,7 +16,6 @@ from .uintx import ( Int4CPULayout, Int4XPULayout, - Int8DynamicActInt4WeightCPULayout, MarlinQQQLayout, MarlinQQQTensor, MarlinSparseLayout, @@ -29,6 +28,7 @@ ) from .uintx.block_sparse_layout import BlockSparseLayout from .uintx.cutlass_int4_packed_layout import CutlassInt4PackedLayout +from .uintx.dyn_int8_act_int4_wei_cpu_layout import Int8DynamicActInt4WeightCPULayout from .utils import ( Layout, PlainLayout, diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index e46809059e..3816f9bf1f 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -25,10 +25,6 @@ _linear_f16_bf16_act_floatx_weight_check, _linear_f16_bf16_act_floatx_weight_impl, ) -from torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout import ( - _linear_int8_act_int4_weight_cpu_check, - _linear_int8_act_int4_weight_cpu_impl, -) from torchao.dtypes.uintx.gemlite_layout import ( _linear_fp_act_int4_weight_gemlite_check, _linear_fp_act_int4_weight_gemlite_impl, @@ -94,6 +90,10 @@ _linear_int8_act_int4_weight_cutlass_check, _linear_int8_act_int4_weight_cutlass_impl, ) +from torchao.prototype.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout import ( + _linear_int8_act_int4_weight_cpu_check, + _linear_int8_act_int4_weight_cpu_impl, +) from torchao.quantization.quant_primitives import ( ZeroPointDomain, 
_dequantize_affine_no_zero_point, diff --git a/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py b/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py index 8d0cfaddeb..d66f70e2ee 100644 --- a/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py +++ b/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py @@ -3,317 +3,25 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from dataclasses import dataclass -from typing import Tuple -import torch -from torch.utils._python_dispatch import ( - return_and_correct_aliasing, -) +# Backward compatibility stub - imports from the new location +import warnings -from torchao.dtypes.affine_quantized_tensor import ( - AffineQuantizedTensor, - register_layout, +warnings.warn( + "Importing from torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout is deprecated. " + "Please use 'from torchao.prototype.dtypes import Int8DynamicActInt4WeightCPULayout' instead. " + "This import path will be removed in a future release of torchao. " + "See https://github.com/pytorch/ao/issues/2752 for more details.", + DeprecationWarning, + stacklevel=2, ) -from torchao.dtypes.utils import Layout, PlainLayout, is_device -from torchao.utils import torch_version_at_least -from .int4_cpu_layout import ( - Int4CPUAQTTensorImpl, - _is_float, +from torchao.prototype.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout import ( # noqa: F401 + DA8W4CPUAQTTensorImpl, # noqa: F401 + Int8DynamicActInt4WeightCPULayout, # noqa: F401 + _aqt_is_int8, # noqa: F401 + _aqt_is_uint4, # noqa: F401 + _aqt_is_uint8, # noqa: F401 + _linear_int8_act_int4_weight_cpu_check, # noqa: F401 + _linear_int8_act_int4_weight_cpu_impl, # noqa: F401 ) - -aten = torch.ops.aten - - -@dataclass(frozen=True) -class Int8DynamicActInt4WeightCPULayout(Layout): - """Layout class for da8w4 CPU layout for affine quantized tensor""" - - pass - - -@register_layout(Int8DynamicActInt4WeightCPULayout) -class DA8W4CPUAQTTensorImpl(Int4CPUAQTTensorImpl): - """TensorImpl for da8w4 CPU layout for affine quantized tensor - It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 2-d tensor of - dimension: [n][k / 2] (uint8 dtype) - It is similar to Int4CPUAQTTensorImpl but with a different memory layout of weight data - fields: - packed_weight (torch.Tensor): the 2-d packed tensor in a Int4 CPU layout - scales (torch.Tensor): the scales Tensor used to map between floating point tensor to quantized tensor - qzeros (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor - """ - - def __new__( - cls, - packed_weight: torch.Tensor, - scales: torch.Tensor, - qzeros: torch.Tensor, - compensation: torch.Tensor, - transposed: bool, - _layout: Layout, - ): - kwargs = {} - kwargs["device"] = packed_weight.device - kwargs["layout"] = ( - kwargs.get("layout") - if kwargs.get("layout", False) - else packed_weight.layout - ) - kwargs["dtype"] = packed_weight.dtype - kwargs["requires_grad"] = False - shape = packed_weight.shape - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - packed_weight: torch.Tensor, - scales: torch.Tensor, - qzeros: torch.Tensor, - compensation: torch.Tensor, - transposed: bool, - _layout: Layout, - ): - self.packed_weight = packed_weight - self.scales = scales - self.qzeros = qzeros - self.compensation = compensation - self.transposed = transposed - self._layout = _layout - - def 
__tensor_flatten__(self): - return ["packed_weight", "scales", "qzeros", "compensation"], [ - self.transposed, - self._layout, - ] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - packed_weight, scales, qzeros, compensation = ( - tensor_data_dict["packed_weight"], - tensor_data_dict["scales"], - tensor_data_dict["qzeros"], - tensor_data_dict["compensation"], - ) - ( - transposed, - _layout, - ) = tensor_attributes - return cls(packed_weight, scales, qzeros, compensation, transposed, _layout) - - @classmethod - def from_plain( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - _layout: Layout, - ): - assert isinstance(_layout, Int8DynamicActInt4WeightCPULayout) - assert int_data.dtype == torch.uint8, "DA8W4 CPU: expects uint8 weight" - assert int_data.shape[1] % 2 == 0, "DA8W4 CPU: expects even number of columns" - if scale.dim() == 1: - scale.unsqueeze_(-1) - scale = scale.to(torch.float) - if zero_point.dim() == 1: - zero_point.unsqueeze_(-1) - - # Pack weight from [N, K] to [N / block_n, K / block_k, block_k, block_n]. - # Pack the inner blocks [block_k, block_n] to VNNI layout if AMX is available. - # Pack scales/qzeros from [N, num_groups] to [N / block_n, num_groups, block_n]. - # Compensation shape = [N / block_n, K / block_k, block_n]. - weight_int4, scales, qzeros, compensation = ( - torch.ops.torchao.da8w4_linear_prepack_cpu(int_data, scale, zero_point) - ) - return cls(weight_int4, scales, qzeros, compensation, False, _layout) - - def _apply_fn_to_data(self, fn): - return self.__class__( - fn(self.packed_weight), - fn(self.scales), - fn(self.qzeros), - fn(self.compensation), - self.transposed, - self._layout, - ) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - if func is aten.t.default: - """we don't need to repack the weight and just rely on external - shape being changed and record the status of transpose/no-transpose - """ - transposed = DA8W4CPUAQTTensorImpl( - args[0].packed_weight, - args[0].scales, - args[0].qzeros, - args[0].compensation, - not args[0].transposed, - args[0]._layout, - ) - return return_and_correct_aliasing(func, args, kwargs, transposed) - else: - return super().__torch_dispatch__(func, types, args, kwargs) - - __torch_function__ = torch._C._disabled_torch_function_impl - - @property - def block_size(self): - assert len(self.packed_weight.shape) == 2 - weight_shape = self.packed_weight.shape - N = weight_shape[0] - K = weight_shape[1] * 2 - groups = self.scales.numel() // N - group_size = K // groups - return (1, group_size) - - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # Unpack weight by linear(eye(K), packed_weight).t() - packed_w_shape = self.packed_weight.shape - if len(packed_w_shape) == 4: - K = packed_w_shape[1] * packed_w_shape[2] - else: - K = packed_w_shape[1] - x = torch.eye(K).to(torch.uint8) - x_scale = torch.ones(K).float() - x_qzero = torch.zeros(K).to(torch.int32) - w_scale = torch.ones_like(self.scales).float() - w_qzero = torch.zeros_like(self.qzeros).to(torch.int8) - plain_weight = torch.ops.torchao.da8w4_linear_cpu.default( - x, - x_scale, - x_qzero, - self.packed_weight, - w_scale, - w_qzero, - self.compensation, - None, # bias - torch.float, # out_dtype - ) - plain_weight = plain_weight.t().contiguous() - plain_weight = plain_weight.to(torch.int8) - - if self.scales.dim() == 2: - assert self.qzeros.dim() == 2 - plain_scales = 
self.scales - plain_qzeros = self.qzeros - else: - assert self.scales.dim() == 3 and self.qzeros.dim() == 3 - packed_shape = self.scales.shape # [Nc, G, block_n] - plain_scales = ( - self.scales.permute([0, 2, 1]).contiguous().view([-1, packed_shape[1]]) - ) - plain_qzeros = ( - self.qzeros.permute([0, 2, 1]).contiguous().view([-1, packed_shape[1]]) - ) - - return plain_weight, plain_scales, plain_qzeros - - -def _aqt_is_uint8(aqt): - """Check if an AffineQuantizedTensor is uint8 quantized Tensor""" - return ( - aqt.tensor_impl.dtype == torch.uint8 - and aqt.quant_min == 0 - and aqt.quant_max == 255 - ) - - -def _aqt_is_int8(aqt): - """Check if an AffineQuantizedTensor is uint8 quantized Tensor""" - return ( - aqt.tensor_impl.dtype == torch.int8 - and aqt.quant_min == -127 - and aqt.quant_max == 127 - ) - - -def _aqt_is_uint4(aqt): - """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" - return ( - aqt.tensor_impl.dtype == torch.uint8 - and aqt.quant_min == 0 - and aqt.quant_max == 15 - ) - - -def _linear_int8_act_int4_weight_cpu_check(input_tensor, weight_tensor, bias): - return ( - torch_version_at_least("2.7.0") - and is_device(input_tensor.device.type, "cpu") - and is_device(weight_tensor.device.type, "cpu") - and (bias is None or is_device(bias.device.type, "cpu")) - and isinstance(input_tensor, AffineQuantizedTensor) - and (_aqt_is_uint8(input_tensor) or _aqt_is_int8(input_tensor)) - and _is_float(input_tensor.dtype) - and isinstance(input_tensor._layout, PlainLayout) - and isinstance(weight_tensor, AffineQuantizedTensor) - and _aqt_is_uint4(weight_tensor) - and _is_float(weight_tensor.dtype) - and isinstance(weight_tensor._layout, Int8DynamicActInt4WeightCPULayout) - ) - - -def _linear_int8_act_int4_weight_cpu_impl(input_tensor, weight_tensor, bias): - assert torch_version_at_least("2.7.0"), ( - f"Requires PyTorch version at least 2.7, but got: {torch.__version__}" - ) - if _aqt_is_int8(input_tensor): - assert torch_version_at_least("2.8.0"), ( - f"Requires PyTorch version at least 2.8, but got: {torch.__version__}" - ) - assert is_device(input_tensor.device.type, "cpu"), ( - f"For CPU device only but got: {input_tensor.device}" - ) - assert weight_tensor.block_size[0] == 1, ( - f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" - ) - assert input_tensor.shape[-1] == weight_tensor.shape[1], ( - f"need input_tensor shape: {input_tensor.shape} final" - f"dim to match weight_tensor shape: {weight_tensor.shape} second dim " - ) - - act_mat = input_tensor - act = act_mat.tensor_impl.int_data - act_scales = act_mat.tensor_impl.scale - act_qzeros = act_mat.tensor_impl.zero_point - - packed_weight = weight_tensor.tensor_impl.packed_weight - wei_scales = weight_tensor.tensor_impl.scales - wei_qzeros = weight_tensor.tensor_impl.qzeros - compensation = weight_tensor.tensor_impl.compensation - - orig_act_size = act_mat.size() - orig_dtype = act_mat.dtype - - # reshape to 2D - act = act.reshape(-1, act.shape[-1]) - - y = torch.ops.torchao.da8w4_linear_cpu.default( - act.contiguous(), - act_scales, - act_qzeros, - packed_weight, - wei_scales, - wei_qzeros, - compensation, - bias.float() if bias is not None else bias, # requires bias to be float - orig_dtype, # out_dtype - ) - - # remove out_feature padding - orig_out_features = weight_tensor.shape[-2] - y = y[:, :orig_out_features] - y = y.reshape(*orig_act_size[:-1], orig_out_features) - - return y.to(orig_dtype) - - -# Register the concat linear fusion pass -# from ...prototype.inductor.fx_passes import 
register_da8w4_concat_linear_cpu_pass - -# register_da8w4_concat_linear_cpu_pass() diff --git a/torchao/prototype/dtypes/__init__.py b/torchao/prototype/dtypes/__init__.py index 25f139d583..52a5aec425 100644 --- a/torchao/prototype/dtypes/__init__.py +++ b/torchao/prototype/dtypes/__init__.py @@ -4,9 +4,14 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from .uintx import BlockSparseLayout, CutlassInt4PackedLayout +from .uintx import ( + BlockSparseLayout, + CutlassInt4PackedLayout, + Int8DynamicActInt4WeightCPULayout, +) __all__ = [ "BlockSparseLayout", "CutlassInt4PackedLayout", + "Int8DynamicActInt4WeightCPULayout", ] diff --git a/torchao/prototype/dtypes/uintx/__init__.py b/torchao/prototype/dtypes/uintx/__init__.py index 53edddb8ac..89c1f3f810 100644 --- a/torchao/prototype/dtypes/uintx/__init__.py +++ b/torchao/prototype/dtypes/uintx/__init__.py @@ -6,8 +6,10 @@ from .block_sparse_layout import BlockSparseLayout from .cutlass_int4_packed_layout import CutlassInt4PackedLayout +from .dyn_int8_act_int4_wei_cpu_layout import Int8DynamicActInt4WeightCPULayout __all__ = [ "BlockSparseLayout", "CutlassInt4PackedLayout", + "Int8DynamicActInt4WeightCPULayout", ] diff --git a/torchao/prototype/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py b/torchao/prototype/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py new file mode 100644 index 0000000000..24cc02e358 --- /dev/null +++ b/torchao/prototype/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py @@ -0,0 +1,318 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
+from dataclasses import dataclass +from typing import Tuple + +import torch +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.dtypes.uintx.int4_cpu_layout import ( + Int4CPUAQTTensorImpl, + _is_float, +) +from torchao.dtypes.utils import Layout, PlainLayout, is_device +from torchao.utils import torch_version_at_least + +aten = torch.ops.aten + + +@dataclass(frozen=True) +class Int8DynamicActInt4WeightCPULayout(Layout): + """Layout class for da8w4 CPU layout for affine quantized tensor""" + + pass + + +@register_layout(Int8DynamicActInt4WeightCPULayout) +class DA8W4CPUAQTTensorImpl(Int4CPUAQTTensorImpl): + """TensorImpl for da8w4 CPU layout for affine quantized tensor + It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 2-d tensor of + dimension: [n][k / 2] (uint8 dtype) + It is similar to Int4CPUAQTTensorImpl but with a different memory layout of weight data + fields: + packed_weight (torch.Tensor): the 2-d packed tensor in a Int4 CPU layout + scales (torch.Tensor): the scales Tensor used to map between floating point tensor to quantized tensor + qzeros (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor + """ + + def __new__( + cls, + packed_weight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + compensation: torch.Tensor, + transposed: bool, + _layout: Layout, + ): + kwargs = {} + kwargs["device"] = packed_weight.device + kwargs["layout"] = ( + kwargs.get("layout") + if kwargs.get("layout", False) + else packed_weight.layout + ) + kwargs["dtype"] = packed_weight.dtype + kwargs["requires_grad"] = False + shape = packed_weight.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + packed_weight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + compensation: torch.Tensor, + transposed: bool, + _layout: Layout, + ): + self.packed_weight = packed_weight + self.scales = scales + self.qzeros = qzeros + self.compensation = compensation + self.transposed = transposed + self._layout = _layout + + def __tensor_flatten__(self): + return ["packed_weight", "scales", "qzeros", "compensation"], [ + self.transposed, + self._layout, + ] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + packed_weight, scales, qzeros, compensation = ( + tensor_data_dict["packed_weight"], + tensor_data_dict["scales"], + tensor_data_dict["qzeros"], + tensor_data_dict["compensation"], + ) + ( + transposed, + _layout, + ) = tensor_attributes + return cls(packed_weight, scales, qzeros, compensation, transposed, _layout) + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + _layout: Layout, + ): + assert isinstance(_layout, Int8DynamicActInt4WeightCPULayout) + assert int_data.dtype == torch.uint8, "DA8W4 CPU: expects uint8 weight" + assert int_data.shape[1] % 2 == 0, "DA8W4 CPU: expects even number of columns" + if scale.dim() == 1: + scale.unsqueeze_(-1) + scale = scale.to(torch.float) + if zero_point.dim() == 1: + zero_point.unsqueeze_(-1) + + # Pack weight from [N, K] to [N / block_n, K / block_k, block_k, block_n]. + # Pack the inner blocks [block_k, block_n] to VNNI layout if AMX is available. 
+ # Pack scales/qzeros from [N, num_groups] to [N / block_n, num_groups, block_n]. + # Compensation shape = [N / block_n, K / block_k, block_n]. + weight_int4, scales, qzeros, compensation = ( + torch.ops.torchao.da8w4_linear_prepack_cpu(int_data, scale, zero_point) + ) + return cls(weight_int4, scales, qzeros, compensation, False, _layout) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.packed_weight), + fn(self.scales), + fn(self.qzeros), + fn(self.compensation), + self.transposed, + self._layout, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + if func is aten.t.default: + """we don't need to repack the weight and just rely on external + shape being changed and record the status of transpose/no-transpose + """ + transposed = DA8W4CPUAQTTensorImpl( + args[0].packed_weight, + args[0].scales, + args[0].qzeros, + args[0].compensation, + not args[0].transposed, + args[0]._layout, + ) + return return_and_correct_aliasing(func, args, kwargs, transposed) + else: + return super().__torch_dispatch__(func, types, args, kwargs) + + __torch_function__ = torch._C._disabled_torch_function_impl + + @property + def block_size(self): + assert len(self.packed_weight.shape) == 2 + weight_shape = self.packed_weight.shape + N = weight_shape[0] + K = weight_shape[1] * 2 + groups = self.scales.numel() // N + group_size = K // groups + return (1, group_size) + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # Unpack weight by linear(eye(K), packed_weight).t() + packed_w_shape = self.packed_weight.shape + if len(packed_w_shape) == 4: + K = packed_w_shape[1] * packed_w_shape[2] + else: + K = packed_w_shape[1] + x = torch.eye(K).to(torch.uint8) + x_scale = torch.ones(K).float() + x_qzero = torch.zeros(K).to(torch.int32) + w_scale = torch.ones_like(self.scales).float() + w_qzero = torch.zeros_like(self.qzeros).to(torch.int8) + plain_weight = torch.ops.torchao.da8w4_linear_cpu.default( + x, + x_scale, + x_qzero, + self.packed_weight, + w_scale, + w_qzero, + self.compensation, + None, # bias + torch.float, # out_dtype + ) + plain_weight = plain_weight.t().contiguous() + plain_weight = plain_weight.to(torch.int8) + + if self.scales.dim() == 2: + assert self.qzeros.dim() == 2 + plain_scales = self.scales + plain_qzeros = self.qzeros + else: + assert self.scales.dim() == 3 and self.qzeros.dim() == 3 + packed_shape = self.scales.shape # [Nc, G, block_n] + plain_scales = ( + self.scales.permute([0, 2, 1]).contiguous().view([-1, packed_shape[1]]) + ) + plain_qzeros = ( + self.qzeros.permute([0, 2, 1]).contiguous().view([-1, packed_shape[1]]) + ) + + return plain_weight, plain_scales, plain_qzeros + + +def _aqt_is_uint8(aqt): + """Check if an AffineQuantizedTensor is uint8 quantized Tensor""" + return ( + aqt.tensor_impl.dtype == torch.uint8 + and aqt.quant_min == 0 + and aqt.quant_max == 255 + ) + + +def _aqt_is_int8(aqt): + """Check if an AffineQuantizedTensor is uint8 quantized Tensor""" + return ( + aqt.tensor_impl.dtype == torch.int8 + and aqt.quant_min == -127 + and aqt.quant_max == 127 + ) + + +def _aqt_is_uint4(aqt): + """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" + return ( + aqt.tensor_impl.dtype == torch.uint8 + and aqt.quant_min == 0 + and aqt.quant_max == 15 + ) + + +def _linear_int8_act_int4_weight_cpu_check(input_tensor, weight_tensor, bias): + return ( + torch_version_at_least("2.7.0") + and is_device(input_tensor.device.type, "cpu") + and 
is_device(weight_tensor.device.type, "cpu")
+        and (bias is None or is_device(bias.device.type, "cpu"))
+        and isinstance(input_tensor, AffineQuantizedTensor)
+        and (_aqt_is_uint8(input_tensor) or _aqt_is_int8(input_tensor))
+        and _is_float(input_tensor.dtype)
+        and isinstance(input_tensor._layout, PlainLayout)
+        and isinstance(weight_tensor, AffineQuantizedTensor)
+        and _aqt_is_uint4(weight_tensor)
+        and _is_float(weight_tensor.dtype)
+        and isinstance(weight_tensor._layout, Int8DynamicActInt4WeightCPULayout)
+    )
+
+
+def _linear_int8_act_int4_weight_cpu_impl(input_tensor, weight_tensor, bias):
+    assert torch_version_at_least("2.7.0"), (
+        f"Requires PyTorch version at least 2.7, but got: {torch.__version__}"
+    )
+    if _aqt_is_int8(input_tensor):
+        assert torch_version_at_least("2.8.0"), (
+            f"Requires PyTorch version at least 2.8, but got: {torch.__version__}"
+        )
+    assert is_device(input_tensor.device.type, "cpu"), (
+        f"For CPU device only but got: {input_tensor.device}"
+    )
+    assert weight_tensor.block_size[0] == 1, (
+        f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
+    )
+    assert input_tensor.shape[-1] == weight_tensor.shape[1], (
+        f"need input_tensor shape: {input_tensor.shape} final "
+        f"dim to match weight_tensor shape: {weight_tensor.shape} second dim "
+    )
+
+    act_mat = input_tensor
+    act = act_mat.tensor_impl.int_data
+    act_scales = act_mat.tensor_impl.scale
+    act_qzeros = act_mat.tensor_impl.zero_point
+
+    packed_weight = weight_tensor.tensor_impl.packed_weight
+    wei_scales = weight_tensor.tensor_impl.scales
+    wei_qzeros = weight_tensor.tensor_impl.qzeros
+    compensation = weight_tensor.tensor_impl.compensation
+
+    orig_act_size = act_mat.size()
+    orig_dtype = act_mat.dtype
+
+    # reshape to 2D
+    act = act.reshape(-1, act.shape[-1])
+
+    y = torch.ops.torchao.da8w4_linear_cpu.default(
+        act.contiguous(),
+        act_scales,
+        act_qzeros,
+        packed_weight,
+        wei_scales,
+        wei_qzeros,
+        compensation,
+        bias.float() if bias is not None else bias,  # requires bias to be float
+        orig_dtype,  # out_dtype
+    )
+
+    # remove out_feature padding
+    orig_out_features = weight_tensor.shape[-2]
+    y = y[:, :orig_out_features]
+    y = y.reshape(*orig_act_size[:-1], orig_out_features)
+
+    return y.to(orig_dtype)
+
+
+# Register the concat linear fusion pass
+# from ...prototype.inductor.fx_passes import register_da8w4_concat_linear_cpu_pass
+
+# register_da8w4_concat_linear_cpu_pass()

From 677ed0cabf30e639fde4054338737a28f06473ff Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Fri, 7 Nov 2025 14:12:52 -0800
Subject: [PATCH 05/22] Skip quantization when channels_out / channels_in are not multiples of 16 (#3309)

Summary: The underlying fbgemm conv3d kernel for float8 only supports cases where channels_out and channels_in are both multiples of 16, so for now we skip quantization for shapes that don't satisfy this requirement; support can be expanded later by padding if needed.

Test Plan: python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_fp8_conv_skip_quant
---
 .../workflows/float8/test_float8_tensor.py | 85 ++++++++++++++++---
 torchao/quantization/quant_api.py | 7 ++
 2 files changed, 78 insertions(+), 14 deletions(-)

diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
index be5f2361c3..1b91875359 100644
--- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
+++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -17,6 +17,7 @@ from
torchao.quantization import ( Float8DynamicActivationFloat8WeightConfig, + Float8Tensor, Float8WeightOnlyConfig, Granularity, PerBlock, @@ -25,7 +26,6 @@ quantize_, ) from torchao.quantization.quantize_.common import KernelPreference -from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor from torchao.quantization.utils import compute_error from torchao.testing.utils import TorchAOIntegrationTestCase from torchao.utils import ( @@ -329,14 +329,13 @@ def _test_fp8_matmul_model( @unittest.skipIf( not is_sm_at_least_100(), "Requires GPU with compute capability >= 10.0" ) + @unittest.skipIf( + not _is_fbgemm_gpu_genai_available(), + "Requires fbgemm_gpu_genai to be installed", + ) @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) @common_utils.parametrize("compile", [True, False]) - @common_utils.parametrize("granularity", [PerTensor()]) @common_utils.parametrize("inference_mode", [True, False]) - @common_utils.parametrize( - "kernel_preference", - [KernelPreference.AUTO], - ) # only test for 3D conv for now # Inputs are (N, C_in, C_out, D, H, W) @common_utils.parametrize( @@ -349,19 +348,14 @@ def test_fp8_conv_variants( self, dtype: torch.dtype, compile: bool, - granularity, inference_mode: bool, kernel_preference: KernelPreference, sizes: Tuple, ): - if (not _is_fbgemm_gpu_genai_available()) or (not is_sm_at_least_100()): - return unittest.skip( - "Requires fbgemm_gpu_genai and sm version >= 10.0 to run " - "fbgemm kernel preference test" - ) - - dim = 3 + granularity = PerTensor() + kernel_preference = KernelPreference.AUTO N, C_in, C_out, D, H, W = sizes + dim = 3 kernel_size = 3 # Note: this is channel last memory format @@ -404,6 +398,69 @@ def test_fp8_conv_variants( f"Quantization error is too high got a SQNR of {error}" ) + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf( + not is_sm_at_least_100(), "Requires GPU with compute capability >= 10.0" + ) + @unittest.skipIf( + not _is_fbgemm_gpu_genai_available(), + "Requires fbgemm_gpu_genai to be installed", + ) + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) + # only test for 3D conv for now + # Inputs are (N, C_in, C_out, D, H, W) + @common_utils.parametrize( + "sizes", + [ + (4, 12, 64, 32, 32, 32), + (4, 16, 12, 32, 32, 32), + ], + ) + def test_fp8_conv_skip_quant( + self, + dtype: torch.dtype, + sizes: Tuple, + ): + """Some shapes are not supported so we won't quantize the module + Specifically, we skip quantization when C_in or C_out is not a multiple of 16 + """ + granularity = PerTensor() + kernel_preference = KernelPreference.AUTO + N, C_in, C_out, D, H, W = sizes + dim = 3 + kernel_size = 3 + + # Note: this is channel last memory format + input_tensor = torch.randn(N, C_in, D, H, W, dtype=dtype, device="cuda") + input_tensor = input_tensor.to(memory_format=torch.channels_last_3d) + # Create a linear layer with bfloat16 dtype + model = ToyConvModel( + dim, + C_in, + C_out, + kernel_size, + bias=False, + padding=0, + dtype=dtype, + device="cuda", + ).eval() + + quantized_model = copy.deepcopy(model) + + config = Float8DynamicActivationFloat8WeightConfig( + granularity=granularity, + kernel_preference=kernel_preference, + ) + + _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d) + + quantize_(quantized_model, config, filter_fn=_is_conv3d) + assert not isinstance(quantized_model.conv.weight, Float8Tensor) + + output_original = model(input_tensor) + output_quantized = quantized_model(input_tensor) + 
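+        # Quantization was skipped above (the weight is not a Float8Tensor), so the
+        # "quantized" model still runs the original high-precision conv; its output
+        # should therefore match the reference directly rather than within an SQNR
+        # bound as in test_fp8_conv_variants.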
self.assertEqual(output_original, output_quantized)
+
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @unittest.skipIf(
         not is_sm_at_least_90(),
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index bca3a7cb3e..e3a75bbb3e 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -1821,6 +1821,13 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
         assert isinstance(activation_granularity, PerTensor) and isinstance(
             weight_granularity, PerTensor
         ), "5D tensor only supports per tensor activation and weight quantization"
+
+        # weight dim: (C_out, C_in, K1, K2, K3)
+        # skip quantization when either C_out or C_in
+        # is not a multiple of 16
+        if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0:
+            return weight
+
     elif not _fp8_mm_compat(weight):
         # TODO(future PR): this should really throw an exception instead of silently
         # not doing what the user asked

From 02ecbb786650983a105373c807f83802b50c97aa Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Fri, 7 Nov 2025 17:14:55 -0800
Subject: [PATCH 06/22] [mxfp8 moe training][BE] add docs showing equivalent convergence to bf16 at scale (#3312)

---
 docs/static/mxfp8_with_loss.png          | Bin 0 -> 47012 bytes
 torchao/prototype/moe_training/README.md |  24 ++++++++++++++++++++++-
 2 files changed, 23 insertions(+), 1 deletion(-)
 create mode 100644 docs/static/mxfp8_with_loss.png

diff --git a/docs/static/mxfp8_with_loss.png b/docs/static/mxfp8_with_loss.png
new file mode 100644
index 0000000000000000000000000000000000000000..47e2967aed57357bfa6ef2ab8d830f8cba7973f7
GIT binary patch
literal 47012
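The early exit added to quant_api.py in PATCH 05/22 above reduces to a divisibility check on
the first two dimensions of the 5D conv3d weight. A hedged sketch, using an illustrative helper
name that is not part of torchao:

    import torch

    def _fp8_conv3d_shape_supported(weight: torch.Tensor) -> bool:
        # weight layout is (C_out, C_in, K1, K2, K3); the fbgemm float8 conv3d kernel
        # requires both channel dimensions to be multiples of 16
        c_out, c_in = weight.shape[0], weight.shape[1]
        return c_out % 16 == 0 and c_in % 16 == 0

    assert not _fp8_conv3d_shape_supported(torch.empty(12, 64, 3, 3, 3))  # C_out=12, skipped
    assert _fp8_conv3d_shape_supported(torch.empty(64, 16, 3, 3, 3))      # quantized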
z^SckPnj+aRs_1~qvFGa+P>E1Hx9#LbFy(WK-{qmv*c0Zm|2X?Sz~~7*W<*)MJ!LGb zNlh(Yb!BG^k09*=V)XI^3m@Hy>eYb*ZaiH>ZdxAx&A;01i1>?KYd{(Diiu0F^s}uk zBFnKQtd&&{)6;)NcvQ3blE1T)NXUD)k&n|!Jz=uPK*AWF&>5j9X7Bv@!8a;39F1rJd|BC3m zd&}VofZO5%FTj0#yk)adFN155go$UyiD=&a{P7toOHsF-k1HZ~-Tn))eNtHBqXPu& z+i_T8A2(K?#NErfwv(*` zEW?i`>ag?vGlyexSlkzn&piB1XZP$Y-Mw`LXOEAkn}d#iIgpo=+YEh-dG{z_L3Wop zXA{H&jX&_U?P){nwQu-y8BFumM$yyOjo$s{DaG1KT1OJ{>m5BQu!oCvtvli41NR#w zIDOw^+2-AoHm^O^Kx0iEqQHksBNPNU76Tv3aj{nsR@Mp_k(aGcN?Ae~P@R_R??9v~ zcFp6t?7{2_xi=e(qZm&dh~^5~4BA&(a~+8b=nq%U`nPM(x6j{KwL?!=+_+(~4*4$A zx}IX#{Z3q$eaFY|=2ea2Y%0mx|*}%j67MxM|o&N1)}7j-qJrf zI0d-&$$0Sdxr{|r6}S4360jMLZhYLW_8TZ$XdFY_WnUl5B<_5PaSSSBq~JyE!6VJipRIrX{>{XlESnbz-1OT=dmWl# z{PgAyzPU1ssEQ(b!27ARq~*e2i^Fl1QS-|eF7GLtWR1}b(dgoi$Ge+bhy2dYT$If+ z?wHGKUf$7kPO>7r%1!Jy4?|-^J-e6dt;F-r3=}K3c;f!c;V~r6WNA>u8vG#98`y^K z+d++)_xW}j_v2Dy-+*G8#%G-&?`V8Z?+qF#*+|`CyX+=r$Cf9}n{0FQcE*=GV{W%s zq!-u>6}6ZvywgFucfToTqdi!LE5jUW+*owIVsCXA zKXMR4Z3l7}MAeZ;LBfPG=)>C2p&4CU^l2KLsM;0>w6sn$#_1Aev*9iI_Q#yG^QT&Z zD5!BL{bhVG!?=7u4Y-(qkXz1;X2591mYVyy^ENGekf@5?^VC_5Ny_I!9TFO}0xvuTY$lyqX?bJ}6cejESP)Y96D!7qLIF-h3{0MFxe zwNJ+?x91x0gaCNh0=8Ujcw%ElaH$?=WoB|=*iCbAtIkZASLdw7(Km9u0N6-laC0#oE@hBF%up$3`6b!Sw? z9|?Zk0`~t{03EL%)Q9-P{rS%4dcqh&KV$^C=3jf={AX*%-d*9E9>m$?>7;> zPkRykmsyYGYuB@7g7&NWx!x~`Y`(iiFy(=D;hY7`f5yhf%7=_I=$OBBstAzwI_nR{ zR&0!NgN!%&QIygJ5?F6vGP}gC0^4szPj~c;9#K7)-TB!&+q0Ki$Fl2;VdwRqDPeqC z+c=e@0k0c|c3{iJX!i5Pz(<}s%d!3%G6e;NySXMu)05>zJ8)x6>O3ZJuZ6t@|`{4o>?=7s4q6Bjmy>!CfoG}F5nsb=zWl3 zb>H{g1+wE$mMG*O7SHud93JLue>!lnRLKJ6hoRN@;aAQ*UqEp`P|q?f8$%_|A1D5@ zubRzeg=y5W4t1E9=I_^dDQMTQrk3UgbDALEA*WW0XzXVlSq|&%1JMihCb&uTO6GGv zGTN^$MwVYT0k!5sE#*3Gv4Y%IoHo7v1^0*0vh9kwe5k-U1cFSdBdL2;r z^9^Oc+}!e1Z{4^ka4RZF;JE&ivq9RtR>jU3Kcuj^IWo+Oy|CR6uW_^5AMEj^;~c5N z-d+hH6gIG%6Je8Sx2!or;#i(nR~D^PXX$j(8%Buim3iFCRtC|Qq9WfduLJe-D5tUq zzO!HeWs6hvad@o{q1|;Q?H?U&&AzMp8{8`>Dx{0Qq7Q0mNijvNLvOx)gV4Ou zahg|Cli+SU@48=R+0k|6;qAV>7cQMXUpCbZM)h2$L2~!s)V&2y+CP2S2XTo<${)m6suhS z*lE=)Ak>MKroq_`BQVPU^(+1oWDsf(F88vh&Em7qFE9U5N09wdM~4VTuhF5Xv^47J z3GPdUp|Xt)V>C!Ct2e(_QzG^f6D44wt3lOj68g_7Q8J z)TVJ0D)x{EkY_2lH3m8Gkwgvy#rHKrd=#o~*?IW*YJE@1pTq)pxK)fK%h! zoe#H4P)`79$_$#jbadVCx`bL=%M~GAJum|slj#YgD_2=E0DWqHabc!E9D@X^0_1V& zIZY(5&nq*b`8=kr)_)_GVbrfwdx( zp>_EG<~B7w_LdPR=ARN+!d$ULo)!HI`)hyAIflSm-pX z;mhWv^i#9*x)eV&e*Q2c*A%E6^gtf6hw(|otWa;O5`qs#Eo%QJ8H@@IMPB}aDtlT6^duQEG;)xA5h{#f^ zu=fA?e(O!;dosoLrI``jHldnNu;Ez;Vb3?V9w9u` zg2upqKl&wc?i<|zKt z_l>u1KRyt)m4!)fdN>(e>|{r0Xer%hv8($G(R27oEEztasi@L_32=to6;S@R5%&4B zXKK6_zr&69Xv!z90}0*ijVQsxy&)3Og+%2UOLhVl2Pi?LDN?@65+y#uxeJRu<-PB< zR9Puah3{xc^xja$GXG7qZ#19hE*whZc|N=D$(0a%Fpt|;N1`z!z0AcqlD-nn zt~^z`?`+`p;hKSLt=ul7@i2|*B5n|UaY}V3?{G(!?%k<(K`H1Gm*d^$1>Dx1QZWIN zx4jX2s>wu>TckjyI%SBFoTP)j~#&u zGlr|mM{+G?Cdy1Vn>axw);Am_400bA^j3IH+nO^+6{-%wruFosD$HZsP^eA4rN{@$ z{v}${=ikoAH_SX!j^3wHED|N76l2v7o<*pqA}$S72ODCB%&pof6k{EG)iK-9@6&!Q zd3cw+t!m#S5fm#MWvBCTpB8}d^ZNdkwzVvEV!_6cs(I-(BN@VRH(Ea#;W|qc9$Y;x zTvB16&l4UTS{5a2NEe=mM=KGn-xX!Sz#^rdb7&TzgIV`o>P|TXDtuu<6K_{!q3WTJ z{;z(??#{NxT2lF@Auvh4Y>u8BgnKMQL?@c{WopK_(M=O~WL=Er6VT-laHk0pwE4ljZg?7j{wv%vsdQ2rLRx z_+#oMMhlX#B;W#(_JXA273aZFR<<%Wa|ii!$ZUhl$(}w=Y%vSXt0JzWk#6xvTwwYf zaZz$pzyx-e;4Tr_#s6{V( z%zfZ_jx!DRc6U~nVvwvb-;z7*iYMFg%>Yg&rB#R!@v>W!&uOM-;7WczIR z^jX?Z3DQ{j=T1_UHevr@lV=t!SbFxe=$ zAec<-lHhyJ$GZ^ulUiO}AQ=ML#Mdrlw_LmB*(3ygp!z%J`In&pFQFM1&nktN`?bEz;izmkuOEri^*FUD21HB?LkNkf=6H+y8I)R)S9%GVS zvm>ul#;wf(Qep)Ofw8#9+(MHFzV(&&(q&bCGJV1kD!iRYzT@sI&nN8PY%)XqF#Th? 
zU#{g;2g;{X4zeul*B-casByCF7|BnQ#TLRx*U{Sl9379Y9Zf`N){XKr&lMVY|?+3 zHOfSbH28ys3-?^AbMPFdiv8S(Di2SZ4)~~T7&J3$pSGIpdL5Vtm`Hr}51#b%%2{hLq9iALh%gS!a1)oA)8 zvPc{@^PUg$XRhg0m~7uC;$% z1q$oM$u|6bh%)w%n*1WN_zB@>xyYskXEc_ZQ#d7m)egR-w;638{sAIX3hg}-nurEu zGVu$miWfCoj0CJT0rMO!H^JZkH$vUcvcJ#VYJSKo1WRTk$SkWd)DNQIZ!jfETkREB z{%TGzY-J=dqTu81auWYq+1}|XgAZie(%7lv3L9m)-(EVzz{+0-kYBA5Broi!iM@7 zv1QIkGr&ooE}5%52VqU0$l6;lyFqN(2Nb6gv z=8mhI{)(ezl^fB`d$8&kYKK-GeFA#zZ!&og7-5Ni|B42(n~^Ysrl_l%de!_@eW&fY zwR|!5uK17~6v?tWj^!9;yRv&!iaxz^BG=zcKU$@o9>C5g&1RBnYq>_IfEvP6)0Hb# z-`0)>z53dM{ErRj(-W{#MjXyO z!D#;JM#xjdu3kP?^jX5D8I2iv#`dX%=m5Q44H^jUVk!?^5HLs%5iv zc~o7bz9HxFK$jC^AYZ-3&)HB=_2g*FibRj&UG^884cn5>URa4(9A`8g9JR+^9rVC| z=1yqd0=IELMYoi0jDx}V`}xlG()JGe4H4xTs>!twvGbHqR#*YrRwu|3fDo7qh z)6_%(q*jB@o?}*%ov~~ ziYp;*Z7y%+z*UQPn6^z%rdazeKd^O=&gT1-BnKs; zYu}ECXVT zSkmW5syhtoX}yF_D|^nYDDMCj5XYZh-7s2~#7Zzyt&pG?qQrw9Y8Cs-b;Q7Z4i>HK}xV&Am zKBqsW;a{1NGn{~S{Lk7OVO>Fdc?^4$h?`JlutcX24&N9G@-2-0@ASE9(QAd2EA20o@>;s~+4|mX&6v9Q(~+S>gcX@|3VHGb zu1pRkwY4&!J>2aN;uBNLi)kBtE6Xcp9kwUAah^=~yYNTK65i}5tLk{TI$?vl#Y`@K z&p1<;&vcg~E)1yYEaXWet&I%cVoC9;>Ud|tsKD8NG6H^HT+M#xMKVRT6q2=>M2;@- zvW!))1ldomeRGa-auyTb#mDPFQR4_>wq{HjuJC`5i!!Z_C@JMI zjenkbD=MPglm#MA>4V?>LA%~c_zNB@OftzVLunKX#u%Mz2{P*h?yY&MY)t90KV1tR z7NLqADbIgDufn`eJ#N2SM@xB9Wou!o9#5iDu{K=Ncpha_B(gH0oNIA0HqGn>nFcrK zFCl+}@GxLuJ+lb=v%)H#543D&s}t{l9gnsjGAz;+CUczL=phy7ao|B!U1z5YKOkEp zalhkF9s1zTgP^}A=B#&s&?ctkZ~86fk}qfM>@&EvcqHI_O#~Kl$31f?t?by)5oCIT z0%z|6-r>)`^p-`Jk<@JT{)VCP!Y@jDZlK@nR4rz0WnpH z;y>6)DJCcO#l^dmqIa(vFhR!O|Ip)1>T$F*b`mZ+nj@MN-0KHJoRz85aig3lpXKdp zsUccwtoG~+6)k<MA#(g^lDBkwLYu7e1QCKL*EC< z8{L3{6)f2=-CWnv3z^V}-NZB(@R^mfle;T8H&jt;sz=%12@1~P0^v66D6r<%ofmKc zj*UdGd;SYg{2@OhI^mD8ODoGlA)O4!!(5ZY=;|6G2Ajc(C?3H(%SMdll9H0Nr7^Jf z&9JRf7FSSjFWg2H0<%^{QGI=S@RRQDz&aLHF|JE4xqm(#Gh^QHQB`j4w(%QFG(vinOD(|!J7^g2z!{NZd`wnXEJKErQCn$y! 
zJE-R5%+w<^?#R(ZGdk1ze$(aIt)KM8{tHnODxBsv5(7%Q-9hAEIZYzW~U)R&A76DXHqEejVB%lye|ObxNi`&T}~DSVy7J=XVhASeD@Bb z)QA)ky{;WOu1VpNR8#&d=cHchZjTB$U5|+v29-|eS`}xw8>Gd8*r3VtSjCmPq7pAr z(hv+5!NeXqS&f^Y*F=Sj0!}HH)T}_!Uc-o{zO^iBPz?uq-$>1b|0O5M-(%4~2xK`1 zqNBO%7-7VO*@f__=2M`jO$~+hW=sY-glLW<1UG(MRiX4H2{K9Jf1#jrIhnOb$;;59 zITvZ|{O_}y^Trl7>*IF!6bea?5U?^UT#1Hw3CX(up~97@7iblUsX^Bh!l`>WGpZ`Z zUsg8)S`{aIf5bl()+tB%ECzod?4)PJ-oIAnYc((4mhI+W>+4Zlf~5iwH=r+iFhbS= z<{X8Kt{in!22~!K5+aq&Rnp_Da<~pAOZbXTRr&f-FElotMQX$=`K){u?RbltOqHNC zHH}McE&wf@G>MbI`x^)fEMc-&%k8RPL4dy4Ws|yj4X(>zfkk`X&G+J1ji+irEK(o$ z>e{+nRmasTGWl2FYs4o&r&)NVl!su5s^k`iJ^1tq{=n@b0_T~`**M{(bDCH`C$^V2 z+(fO(9hvz{<@Rr8wmxTlbSaF7JQ%#<#Lnr~R-e6tlcK`HMvVdJzISvaw#Bt^epgf4jxGDaR{1@)z(sd&)|W3PnMirYZ~D$sKQ607 zpv6X2TT-WO*;7+_R*Z7Drsi?7qXEaw#QaTfsPQF)F$sS=S~&cRx;FQWMe(l`mv6#j zPARm#s6~0!NR5xHdZY^`lYAu*oFyl%p8{Ujp^7w)YjVC>wRE6Z;zfjHwAy17gN;wcnJ9 zsjdCGelcrP;VA`sE(Cq_QE!J@188-DxnFA7u>iDm@%`HF!(pXNfjpfZ4>7=E{n8t) zkkqer@x>xz33fY8LhbL&gq*SU5`dNF(JRYBV!uKKEWanzX zK6IONsehaUC<^g5ACC`f8ug#IynC}%#b(4T8{jL zS|2^V8d>{9D2Qm-7r5r0NjF^W)TYuwz8_9zdcC<$kEUm-5UfTJq>iTRge0xvQUEVD6U!Jc^` zhy?pcLYWIoR`MsSk?UA*v)k}ouMnJSeU#DR?Xt9w6=_JmcGuJl)4$bqLL4GH-1om% z)gE}aCKQU_PB5bN(QT!UhJ^2@AFK19QTFGm`5hjm z^56TRDlHStSZc7Rc`dbm*5I4qQe(m#rJqTh$>M8>{56K`>E~=e?v?^gQ&j0Vw$YQh zDC0ve?k7BO$l`)E`sc)?TvW-d)4FF65iD~;RQgpH8o)9Rwg!yZ*;JaI%rfl+l7hfC znZIDhmQAs|6D3@V4YH?b=lsok_q19K2$G1rkR&bhgXIJPojOcE8 zY_Y=*!hwPw72cy1N%~D1PkRJBj}ev8V0CSaS!pe?@4P-#oxh2o`{eAq^ado7c9%Ne zL5{_v;>j&zNTec5J+!Lg#uYx@id_6 z%vSDzV@StWD}+VeHBgz%Lp<$yv$avr^V9ET=7%2UO|J5l$gN)a_aP|Iu-m-pg;u59 z#y)AMOmnzI%guDee&&3~U4>WD1V1|mQwbGlCG%@>j^6vwU87ho|JD4kSgCSLBAN~D z0{XY!_krhe!ieY}uV!ES+P&~(IMV(w2g3Hr<~s-m5$)#ph8cGM`aqA{KPMW(r8GMA z)wGOUiU>c*=8|({5VuMve3*HF-=R@!2ZdPkgAw{&DmXB+5tcnUNL?-b`?axj8}g$3 zzr>Ec0(>!6Q(Mj!=vqJ->_mKs`7~P}WBx3(tbMws`~lIy6RqZY+_dRL@;)R!GPo?N z;vTxLXx^`IDq3j69Fg-Panp-GK64_Llo2UKFASzRXCBrJDk-KEaG)vFFwsf9&pxOj z6d)aDUOu*=?z9o~nC7c~^3P#f#5i{2Y9yn-xB#5RSc z4)DMij28R$maUtGMAmB^dBQrou|xY^(i>3Ax7?C|Si!GwFDREBicFk+VRwULgG?k5 zB!{#dm-;TC4%D)8V&B~?hCpm;?B0{FbcS{4)Fo%309beF8*MKVn7|1|?2S8i*MI#& z&OF|NUBuhU<$>DCa7PGpziJRvyI6QPXL{-cled9?nB~Xr>I1_xIRL?BUGnDluWups z*=N1C1xat+sATT4%l786$q14paHaP1JJsKUb)>NTvIc?5e7$k`5x`zuQLVI*bHt~{>|H^y#$2KnCC3;Bk zwC|^ZpIPw+fCGz}Jtxx$Xak-Xw z8|$seCRcePm*#sV^Lr!8x93ZFUV6;kasG;h4`WzVSX9`v7iAu)Wu7AA4KJqrP14 zpr_eyD5WVZejX#poD;gF>bq*&idzduc_3|>Gr+HaCER9l?@C624hx|=9H`V?cx85Q zG65)26=gjf=bx47YbjxgKYAAg(5rd;l3;iB)0K+Kjuv=Q8pehme)Gp(TyU`(_I`$L1jqFdL3Eb zPA{g)g$X*9v##eRKme=F0vRf5AE*?zu7Qy=<=qY|%3e!wKe8puCzHAx!hr8nUYO z`aQ#FV_QvHXwsvXWaIP_w@Obwe+bRV#nR~81)_7oO8-2&6n()jWAZzsjMN(hE*D=G zm~A2JBBwY=BP;J6>FZbHJ+Qp4-=lTE~*?-A7sgKx84#jGRFbhNMy zuCXRE&w`oNQsd5u(r^r(X@&;JLdOFrhAG~=4S$rL7m+sw68BnlA{YGR;Vnywj&--% z{$h?1G4-=v#_U@bP3s9j8?gW-H-MbDMHHw$@) z<0xKs@v*d)sEX6gM{ToWFR65)%t%MCgzkDos5%Qac`LN(VoH)~k`hbL@oblDWBOTz zzVSf)$CEl#ZfF`nVHm?yJdbeYQ#IDY64s!P+C`f%Z*pq!%6^eHHHEA#qn)XyN5{&x z=Q7aC8=q*DCn}r@p|k60%L=0a?hXCS=<-q|yk53x8Y>W|-}~L95a#(gE`f2Yo8UeO z#!j=da3xda%1u@KvM14fIi0U_DlZ&?mXZ~hl|>BKydxxw=?8RR9sck%iOt%6mLyIb zWQ96Ek(*YpB9lEOS+h%7*YZV;ZIxL%Kq^smg|=1mx>&R6X_BJf7ZY<_dj|(X4~Hh{ zViTCHO}Nxm2r=R2a4ZxwN!Ipe2(A<70(X9EG?N}U)X9hTbu8=?@`D_`-Wgj?7WeG{ zR>+!Eo<`m0Jl@ull@VSJB{#0z%njasf6tswUTPlWVseqchVR?v1|1b)Ub|*MSQMAC z8&}W;G8fte<1|!z-OM#in9o%@734s5;IJHP{BFg8JwmTF{W^vx>b8%lml0+JzEUN$l*}LndR%&QwY&y8 zaq+-Vtgwp+b&uYx8)&NrlqsabrSIxvmh)#!olAMSjtbL0edxBuCXunUu-nGEu~)>X z(g;%PA}Vp?f-M&GQj_?NOq+2wmo>tA5pGiiD$1cDT*P>ud13O^pqM4-zYAHFbW`w_ 
z(3m=p9j}bb`GDy1P|;su_SO;6Zk~*TJ83Y+!>SnIPQ57ZS#+aldGW!GCGoo~o_&Rbwb)R>+NYuH^RmFc4u$d`&_QYw!tx|xs*PvI}kRLAM%YsWc6IZ8K0 zc7Ho98=J|NK6afz&0;FPevu{lW zV)cxjt$B()oa*OO%(K@vL4RXS%Ggxs=tY`4?hot7#-hZ?S@BkDf{@B7lpDjPzMj** zZr@iA-r5Msb!7=l`0&((e-Rb7&12k)z}X?S;4U~tf-s%}d@2@y zAfAk)FpU}Z8!|6i(SKT{DJpdi+Ss}<1nr7|_Nq-RPbsJfC^x(3xYbCr!kOG%Zc59klOim6TNXiruij9K9|yk9(m`xSo}{-WueucX29)A1L4Hi3|-L{m1QS5l?v!cG$` zXc`IHIM$y#C?m&utCkJ`)-&Rs-?Q*JE)H%D;Z`dp75k6bt>}*LsTT=Uj@&#LwC~nN zqaZzLeo3|>J`SIJMG4gMo$&LQ^{)qXp@Vn(g4G-g5tvp1+V z-z#Rz*_g|D90h?Um!_n-_9yu{c#KbLkAR=#xJ?~=eEhp}tS&{iwA~TKdFygj+~|*9 zryJfdxuPUL9gV+r_p}j!OwpIy{o|~8hS$#CKw7+V{mJY_{O5q%%={T5A|F3}qDF3O zBLt+UA7rckZ!N$UXe9!sAD5FYS6qsw_QuJC#EOMQL4tg60Y;k(Qdp(QmDxfp~u7|8_xwH&7osGrBNK>m14$?Kdp2B*bh?*LE zC4+CtIe0`AY!z*r=Ot}i$PM`N(CG(aVleatToX?_0uyk7Sh&#x&CZvY)a-+IYaN$( zsED2ISI3EtBP00uE`IL&sL-eV&Q=l2IY%vaTl}_>kg+_RAGVyuo~EtvqQ$T%uDQ7R zM?AYOF+0W|i3mvt5o6nKDNHKXK^@3tsq{K#0v<;kvOcZeuPbOXMbo3Bqvu1rkTcNz zv^f`s{>zOJ8k*qC)leGg0%Yi(oBr+rZ2RL!EG+h2v>8_Lw})^aP_7s#tgb{L;L3?s z%K73>9lCFSwtHI+?)vdODMWTwJVBFcMnEg006lLAKBVt#`)O;$QH}=3@Af&me}faI zv(+u^{n_97MfUwcd(qBN9(m2p1MjElGjcNu9F5a8nT8UikOf>11a?8$y}Bc3f6&;B zI(ljfB1)@NFM)C5C29;Vk0(-@S;T_wPYha>5cE2(M`1zF8=BK>@v#>WCd>m8Qp@WZ zium2CuUmc7`eT(AXT9^lXkQx(=4Bk znZ(ehC`>IUQ>G}vos|bhQb%BF7#RWgae-c|{F*{c*-B|$6%-@@niK`+gDp;tA)4Ls zTyrn;WOJ&XNh#6#2EL>r9u7V+3A*WS=tBFx?H@7bUSl}z8S}fB^4wC0ngnX?ez*4w znL_By6H_j(NigRMx@Ixk+&nt8E~?23LbFm$7$Hbc&o;O8h%JclwYIgVBvR$z6l~cz zJn<&q*Ze*=N0&dzU}!@U+UoXQnybc3)=I$tg3x^*FTxiyK2mmw05t_>LNKBs{Qe$M zWPI~(QmRlx9+WK<>l*7PHa^)G@vBPF_2K_8b(K+3{Nb7skj|w;P+B^rL=cef?j@wV zOF+6rT1r4bng!`tx?8$Cq`TqH^1t_->jym^W`6VL^X3c_9F)SlZ}2Yoy>DN;x-wq6 z*KqL&t9r%^4MitXTbMX~zga-Q$5-8O^1M8-&QoG>`q>yME{p>7jj0V$2Z~G}3`42N zKR8K*Mmhv~3fFr)uGR$nu6klqQ_^Wbku6ge&I1UM)rlW)%gQSZ>SVLS!oqqfEuxJ5 zw?$Zey(zvxPWmdkm$1V=(vXWfg@uG5pXOnP_!6z8ueYt!OEjTIV_vxYsdm^|fycVo z$3~0_@dDt2n4A>e$jK>UcYH1+(dwX-2T-57)l*s;DG3DyKnH*k^77G}W2^#YjDK6t z&oP36I{3F!47Gj&q7C+mISK#dvuxqx6w%#{ytP#+C-=4b5izKPd!it`D+lnL`*nMf zwc~8NHOMB}3k^||mXvb2t-r3WfbZb*OBp`lWLo@|1+E1r_6%>MtG7jTl;tb^(1=`I(dbd zZ((j}*RNKlYtLMp#setSsU86nbJsz6XalEf%)$d@#2A0~lzQ4%2!Ck!zKUD+G%@ELEtSB|63vn+PMDPs8iE-pvbtep6mSq*#+<8uSDzk z>B7F`>FWdbWe(d*v@OmjHl@j+idLM}f})aRU0F-zFOja3%V z85dWTkmqR>eY5ZR1@`IU72EiXk|84pbJY}pEURbbY|eK};*yfx{BGQXb9&8=p=7m( z-{oGzLLYz?N}9Df(gQy=4J3IH_syNDtQ0#81HmiPp-`OL4423{vd_nX5Zaf`{xepjsL3DIkLZ}>zl6?<;;d7#{CvnTb(4`VN{mN7TtuK&TM)ntk7ViWNtyx-* zz@_@r^*WFpO@s)|mw|BBMB8S5IBGIq4f2#2@S<7E)=|?t5!EMfpFh5CO4(8+XKe^J=IX?2=cCZiQXkyu6nlUsS)JS3KRvw&D3Cc!YD( z#FN1VOf)0AlJsh)rk(PJZX&v9ur##j&;D(%-HG_&;6%0>dLl1E5|5^@8O9>*orhlYNBNH}tTi1`Dd6p_DR3Bw;H=k@<0T9UJ zjFnDRR)qsdElz)gxT9}fDoKwh+@tAkc-HJ{p0yt=Ftb=aDL%_N?{`uY`NpA7GKIqb zhTOmI70c+34Y53bri{NcI0Nge|PKEjt=IYBPx~j%$+=ACZ zY|y6u1>&HqkjRVyTbt=3udL7u6+8FF4BG-%nvxS=t+&io4_u%zepcmJ{gjuwLlM9+ z`;Ohj>CS+&gZ=#A!5FhQH*ZJ+gM-P>(9tolF#n43>!q@5_zM~cJ~8>@>x+!6jG`JE zw)DprjI__oUY+@7k<}iRro3UK668zHySjnvGJ!bFTeUC#KBzn?y_mmG)Omcj?J<6r zeiELS*7SmBTkn$G;Do{YXA=dg@%zTi--HaMF?Ez@es)!iDoMIxe+`8SA^Ld5VVD)OzRpq>GP$hYV$Nm3fsSEwj<>x?!H6NLS22}jKWK?E?T4zXMG zeGkr$b1qN>5Yv&A^dAt+w~7a*;p%+1;T1DQ=(E_q_i-O*xR1nJe%|OlXjJ;a5EB#a ze9=w8;2Vj3Tw?!)t)*dbbQBfNwsc63-}{Ft6D}nRfp1n??NOXUL6V$E)$9u(3GBPj zT?KpK%F>X~gQ7X16|L(|Q@QWJDooAw&PT7FcA@wFKBqU+7|~H0bQR4bx7S;zJo)(+ z8oqDo3i635ec|)uWS&}aT8_^C)W+7^B!!F{5Kob!1=GOnC0f`0mw~DPF6n(Q?ETjRQcPW3qV!ltj(5=QmRfqM)7&)& zaT3bONJ-2-@4VP`sdSoO8D;OFWxKxoC-0SY@z#5&9~v92$jQ&$J|uWm=l!I-c?mqt zzNNW4TK7t3#L}v7kB{?R_X3*OD*#z>o^CXuo7hyo{bwy=fz++o2!{*atp!L=jc;h+ z4wndw?}o1D*nk(M+% zAP5rvN#yn#{hvx9VODH)k{YL&w=P%d+o8le{d%`Zz#SpwxzUDCTjD^DkxZ_cNu)3` 
zcruw+tb`_UiyXq<{mrtZ;L1~UmjbM~~NaDU-DTp2l%$Pv|s=F|DB zi$d;c`^C+!8GgW_vwS_nlac@+@&u}x#4*TA0P7ug`&}tT4>ms?ops5nN2@W3$*P9L z)PP9I_8?Q7TwIY~;c@Kk>(*LkI z@aRsxyUGaZ_f`q;;);$Ir;W^s#tI@KZX- zAt9-$Y3on-{(fE12Tka&=4L{x7Qm)a=5=)W-8R^zuqdcA6gWhotuvhjN9lqc0f!1P zkRY|oeXEZnOL>zPChoRviE_UwzLYq1| zGb^um>HH(ZoIHZE%dI#7a;NjVnfA>T8J8il@i=`}gnA$__eRem+oDnw@tMhhkG883F(5-@PLyCU3t&q@%Y$MSyiN(SB0rZMqCqmdkZCAd5t#LARH$>2D=7uYG z6$Szz&5{s`+4cl#17#LuZazN8!}FJ{nSw*BnGbqG{y#o*%VaBQRkr>lm;Uva@q4($ zn=g#9uj0a1Jt^U)R8m`$tq)79H=I3TwI zM8BjzFEkCjjL~>JKOuUp`bp;pSK{4kyqEv>+Npn%>s51G0KEVl(d$F+U*A8S*3+Dw zp_2AZgFG`|fzW$Mt3gCx+mv48WmgMyWCtzY@ArrTwlvPi-uXD} zOKhE*p7{n(+g1)s^f>(Edfo+YVPv7<;maQ1kqQH-Z}U{I#VhCzJIm=}&h|7I!2)Q9dz`rV|?#mi9 zrL0_^bR>eUDBg|A+Z&4Rs&U@CEnd~TKXLHcEdWV>%bT=}0>;16nwls;8zB2io9Zvu zD&G;%sfdbI934-1rOAL3E~)8zU}3Q>lM&y}aZbLo_83I3*X$Q*U%p6#iV*P;Xrfh2 z{MI^79bsh=R1_dPZSwzC&StpYU#%;&iz!>QRe44rBc{Vz`yQgC=O*;tF8e(zQKq0F z13MpIPfLv34=(wI`lpn~0AnDSezy5tpmqLI?>#4FF%mm-bl;iK8iYT~*EM_u+A!FD zkKo5a4FqSQ+rY6nqPp=GI#?WrN#Et>-J*Vm1Px{P)x(?Btp8U)l*|x$3QuJ6p|FOV>ZEdNZz8K-f@Y`ZV3*2IO2oY;0S_ExFkF zw@;lK1~ffnRrBFT&OFY4C3gs;BpvuGXX3{sgGA}q9KF06ULy3gjj>JVPMZc#j{f^YgAijL55 z>1u=_C#S$6z$4EYxAEScDAS^rwLLZy>&-HFKhHPRa;mx+hs_z#5PA>;5KNQ}K1y;}4 z>$Z&P?^bH-O+sGx7S37(s3Sxq4r^lt7F|g)ph!>NF-#!qZSGk?^9#WCd+ zwCZ~z8f5D->-q;AR)!s2@+Y9HB!=Z9^|pDLa{B`h4MD%^yXJiX4a2L-?YcS0>RDyt z;HQI|8kPzz=SFR6yyzgJ>=fPl?J@vtdKE6|d@VKR2yCkdtRJwvH>0PLvW%zk5`ls{FJd83Vb^|M&5*{GH*!UgOkBi5oQ9gJSxYf8HZV2xPcnDom%W$j%j@ zJfIEBoZnzr&uiZY+7$9(V#ukdO`bAFF)F$e z*z^w90##Ms`Ohr7_K!5|Je4Ow_^{_YZ2mSW?s-)yQsdq@%anAH+c3vGx?c=9UoZMBUc=6OaPu?3D(v{@N+vwzBqXscjSx-{!g z4L{xN7eND!M+b@cW*pHSS(fL^S(&U1-7?)W}daz1V@CW1D z3I_WPP;8HON|uF(={XO*K;qkz31b;1awE*n65}@Psm-cY2ahQ{{@P!Qsm%H_9C=E- zpq^S}XbOFUfH7fWT!|10nC;1a}#msgX1EicNkoy;yFIsaAcs z)(1mjk^Aa3N@tWYwc*ASv0vMhPer69G>q{G)c{Ls`Wp!|%4;mC&a<5ZieJs8snXmk zH#>>$yvy(;k*5?a#6f2dm=wt{Zf}{1B$|D4?4w@cMdO7dA&Fg=RxrPy6|;;E(Z$oN zlVLYDm@`wo{LWqhcmG9sTe_3RG%`fG9+PAXGl`HN)&UJAf&hE2ZwlB`YWL89TBGHN zEi%Sev`IwEI8Mb3XyWo&W7NT|F_v6}4Mw#LeS=l^yhB!o>8ytq#_Z<1kDyiZzJ$Pt zJZmndvx0w~@!EcqYdpmd@y)F`NgvcvaLTsmW$m;d&Q$s$NvedX85gfLNYJg5d)^ES z2#v7TH;p6wnjY4I>+%q;VaMts_ki$Gz$~aYpOct;m8fhI{)6qy#4bGIP-hWh_F)#r zTpiwmqKxI#f9+T5v{PKHxCh`k`GL2eN9G8=OV@*wj{?|*W>rB%Hu4k05LWB@5(No3Zy&N z89A~=y#X^BS`cXOf1C z=2aM&(Z^kUn*4qnN^dHz&(G{VICzIqjto{#C--2rQHH=L93)vza(g@*31C`8z`lxL z*F^~PK=~%_87&Q&O084VC_v*YEHOrXzlPK^EM_+fioTfUaj>M4Tj6$R8Z%j}OT3L5c^*q9Fq^S4&J43{-lW$y;nKPZX3QUJV4#;$K#Lkyd;w3t z8HZ19q*N0P{nsVE$*^el&gdG~&I)Nfx?Io&F#I^zvmym0YXS4TedH-4gsRtmzwiH9 zm9v)%Ssbhn9BvnJfu@Y2yF5U5pd%;;-F}|h?H~!b1JtvL$9x8|X0cIcD%0cQ7v_?s z$)v`OT~E(#X~BY9zG5!!EzWfTV^mdX`#`mBI(oOyw5(Io(+w&EN5Zs~vpMuxLkmy0dgWcKvCNnl&R2BsfapMEHIZO1tDAq3%U7o5TBk$J z=Vh6{P%RAFT$PB< zx#Iv|(blsjLKKp4t&(UBPEc?l3PE_6iMq0@N7E->95^~C&Z9CZn~gX#OgS%-Q1D{9 z0wYg#BF0^jn~H=PghmqStnE0yF_fn=OtxzduX?Uhg3KY+HY<9Y7B`T>m%N~D6*QJh zJQR;^qtZ|g%1jH0*!#|^M7ULT?_PELCawB>ZfMUJ9%M}{rhYOV=*bsrnZ$JGiOFfm zH{I4N#6W854cuGv*X>;^k#98A(L+~k5ztjX1re{E@OKmE#G&rG zd~bP#pPSbKW>M0**7yl39+TRUbMI+U$!o0g8jetmCDblF_@$JVLl{A=Y#{+%#T{>9 zhJ|@DK3-r3Wm}^E-OxUHj{%j(vWn5xv)Stx3*yS7?qI>>EOOWI>(Vm2avC;DFJ$3k z-yG4n2{JK&Y?%j_GPS!^ZTi>X`8GLr=QJ6t0d4owi#%qz5%Zhnr*i+5_F4Sr+K=+~g zON9Qk_0;hr^~Lm}QcGoj#tNQ1jjBnC{C)$$3I+>`CI2gOpve9BnzW7WHnD4AXo%$T z0eUwLjIPH=l??0s4?L`&K2QDda2TRE-UkiFi2FdjhGEpisIh)B4&LI8gsR>)3x(^x zJJO!QjmYyhYdSwsKY>^av^UeqX`!fF@>t7XkVa0=-)8F-nidx&dCa%wI2pbE@EVVj zGMYtLo|`ngknya)?QL^YO;7Y>?Y9v#M`zX1N<+tBn0pm!R%hrJ=Gmy{{?=4D?T452 zkn@B_zKN|^1G9<*q6uB*{4+x*4uY(L7B=&c%NNMYi}E$dCJ#U7U9*gAeVN}l%lgNX*J-g_O20b$!ann5dnr%dQnsjVm4%b zSQ|9i3jC>3){d5^ygBZy``6ZxYqLp 
z3Za3ggql4UgX=2q5)mWlF5?pZZx`ScIu}!Ho`|Jf{%%_)vSz@|pu63xE?;4A#<^EP z-1l!tlp+v{H(X1T1U%HW37G%n<`wzc2=pq!y|OkdLO_$rQ?|#Gn@9(~aAHfS)nwV- zt|eQNof_OC=uvOOSfRuS3vT-?#VD_2tBA}tS&(_EGN)hia_i2KwIiEEmKrm@AlwvI zWy*776ZCtrnbBy?MYRcbn_4??I*;o&LOV+_X>nbL!yl(bH8aHcg(!t<>I4Lqe&A0v1U)0;g8b+J_33i|!8u(BnrCSalih>oDoH#K zsQP|9n#(A*=95<3#jJ!-w%?yyr@x1Bf7Wb;cU5U8EPF|dqm3wD7(X*Gj$ZCP%#&`h zZN?(Qqe_?A%`#}YZVSU(9)s$!#bP@5e!E_)DK2)NO<@hI;rgnE4-jIlx)8S#W+My9 zY}Q-^aoELEQ#woMylp|1q#x89wNdy{Wh1u$J1 z)YM0FFb6i2OhqsJeap~f^LC_It-O@ssaSY(A>GybzDw!l6M$GJn;QEu?nsfp2*A_wM%b|_QWsWQ9q#r8dES#84#HWxpE3nZr0 zUEkgnKZZBSxaQ~FAK8|i3oLj>)*3@7nMixOzHe+-b&~R#raUV!GZ0-njw8|1*%hau z?(A~U->$usLqo)FfD;6Z$p#j2MIGHHb!~Thg8YD>fB}`MX7Y#TlQrAl5}pS+P9A{RsrbY%|K>BAb*N~yqCdyTb=F_x+v_`zqg);^lLih;%09A;*hJVuBnLm4cH;%ZMy zR&y?;<|*M|EehnBNC!<1q)~0WHS0qd>++b1Bi9=1|55a|h{*P+LFMVX+1Z4W-p02dcYoP?D!`r?r~m%IeA zswu3zqVZOB)%^X>t!fgze)@0O4w|}OSD*{>$PE$~*jPHd24=oAG-)R82_F&>P<5B1 zYM(HIdo*97)ywdT!-zF4*DhSNxfr>qL6^z?z29n^<0#G3+KxehZqN7rzh{2{3{MK@ zt7uw7*hbYx53+1MTB10tN!PnnLvmnFyXV)g3XDZ<*{Og*6p<53zvu{fT$!YR@-r z17wxBEg77B>L^%ur3jTk3=U0h-{X7r;XhF-=nauKdQIEDT{57JFX`~0j3XMEw9$aw zdnp7c;FcbfHbLwA&kLMkHYMT%-6n(QW3g2lXKRh-v5X=Gh!5V1w%>|;fyuoW3j%QB zuBiFDQstb`a#4|E5$Wvj3PN=TAxs9dtw~jl#RNq=A3*rJ%zUCDf7|Is3@ zhIWx;5<0v_bVfi_x##cnEH~T1i9tvZ^^9kpQVSIh68xhqClXmE1n%>0H-A>|4w!-E zOSO*me@gsd+o)}QFl^1Y=LxcYo2mbBo)!mbl{{`xP$A#V_H7|86eItt#$9 zBpZ>{uzt9=ml(F=aI{u`;TOikWq{i62PtyBS1p+feRTa}{3(eu;}C0_F=fE_eL_K}UB=GVH{NtPaG8T{*GU z_9Qm^C{t_hKpSSqCiX^qE6{-64t zzsWRBp8bb)hbm*JQS$G1BX5&kZS=ACoGzU%vv-IbhjAmj;P@Auzx4Gv=m&9_V^-@$ z*0SYWE;$*y<78INxRPovhdgck=zj)$(`lv=m*RpZGUZz}Ykf{oH*p~KPhfvdl|JUw z(Cbz3S+Coc)m6GvIvz>&ie z&(Ce5T16JAfN8vUYQFd5yxE??XS1m6y!!CWpgU%i5)A%Cd_SMZ1(^RO4ZjblCT1KSh6K`h@~ywt%1eq(ROuml%W5s9qp;UQ*?1 zO-r79G_0GfzxeZA7iKALr9i-GEF-Y8<^<)rf*`LiuiYIVu8d9Cn?60K@nQQalc^D> z(2JTLPS7eM3H7+w7wXKfExI^JXR#ivUINr7MGVAB{`8zwSUoN@h z|Jk~rntL#(dZV4agi_lQkX<-yu?1Q##d3tS_x@Fp&MtYz&V4MqV1XK$?5VVuSv9r? 
z-8B>;Qpp?v&~U`(usJKV=~mTP3DC&}LJow(mo<48ix6`YXOSiV_l!>Rxo%ZteZ;M+ zu34S0|2bf{v|%HZ{_jSLg=THqqEq+RF*)u4+e4an!V-2npHhZKLp-0!k5sc@_;Yp{ zfZ5xbn&H3tFf?EeCz8>=c;4_2U& zd(sEn$f%F0&K`^O$Nzqu9Qi5unUj~OJ3ND-pau*KOXr5{_`$7_YZ>?OkW}LcY*yuS zLbd1f4=ak@mzIE57dd~Zf>pp!7@D630K8 zr{WMi^(m&MeCFM?8g~{=4NbjE%K@G)vy=RZ^`sr)@|LRCqM+8p2*x9geL(|Kkr&G0 z#V?iLZ-P*XZ03j89ags`&Z-Eg^k3S~pY}?@%wc25Mv_s@baA9ka`Ut0$*J2*Rt)}Q ztn@Z{j8)+V36Z4NUTcVQIA0fYU{_1j)F%RLZAf|ona3~gq+ZBX_#d}bO5Y&Sd$qTp zVY!}Zk5c38!RyH{o=3?mo=Dg&ua6`D5j>lX;8d9nUn4Dl=qDaLfhXs{T@Ez9MWR zYK8`XCO`E1;Aa;#6PyJWQfzycyNz{q+gDDIZk`Z3TZH?GJu0c-vHfElLdZ@qw0)+R z*B=2U*vjdNnzTG~->$Mgt$nQk#q(o9BbLX~*k0ps!)xZKg39OAa#jRLvEFU7PucAZ z1Noho{uCk+j<5?Qm%!JjYff}zJR(0(S^a;OZVuR@n;FJRSo#A;eoBYGQ|hiM?FGFk zL-nlFKhkH>%}ZXa9aQS8qE&;rKoxaA1xb4vuhAW1Iv*5MyKw}a0!+C>+7bil%{ZRr zXl-knpCcz|4=q%g3X~?eJ7f#vKB9NsyCP`0C3Nm+CSWH9JF<>ycs()|Rn^-SJ8pHds{CM~=;dY;4Y17~IuuXl7UXs1rgh1!neCdi0(ad-~u+b+7t%MWC zWiV|$Fm11g^gGKQD82W@??W`Y6`#IJY`;>`@qBiOwBFLw%?HUu@j?Qt<~q1Qxbjkz zyN;*Ko!~6oks+Ms6YuW3%x~&S&ll_dsu?Pg$*$o7w}p6Xr>G%b_BPH?-Rr7Mjk{fF zMlH5T7}k7IWJH}kiGx+CVfzGWV|E%Okr(E2n|&#AD3?J`$ZUVbrq@S)Km8Nscy_jE<%~jxCtZ7;rsix=717cO_A!~xClP4S&E-?xj3dS+MVp`ph=$|s0SLH0 z4s)H?n9u5J3k_5BCw>xXO3`e1>z<(oI%XEk2Go|T3USYfpr zSW}Dv3`s3{K|Xc%*>XPA`k<9qe07#$;X+iY={#0~^`1O=;Db^pZ%t@V%3EEu=Rn!` zzMjB}mah*G!Qa$=|BmTaTtSZ?+9+FFRP0aZ(6thjmRP|va1)m(8j(90K4rDL!LvG(K%;V8e=j3Ac)WQ6Fak50NGWhc$e0 z8DVTv!8Q^VReT(`57)cFd>llVQzgoWV8;KRA5cz#DW>U-D!x$(QQ`+%qVDewV%Lq% zDO%rwNwYxil}PU9QgVt1+TVp>-|cdalXEnBrA8f?a}hG5!^kLC?lvYYwbp zy~VD$(iR2Y44@;e26ifdom2`k&THY^^R1q!4s!o2n{^So^$P+MMqoA`GR!eYc#>E=*Lt5xgd=ym`nHdpYF%LVIK@idlZ26y52EQt8ai&6#U8RT(GtIP!d1 za;B7^M{tYsBv^DpIyoV;Np`p1B0q@_Y!>c$ zl~Dz(xmoocRBtK+*yTJWP{JW06nmK<6?w$9;{Q+9FRwEKN`tHg+>)87XP=wrD@^Z^p^V#GeupyHZ))|(yv zm+}<+U#Q`hHM!|w&l}*Uw!^&9%Rc>h8e`4ZSn;jnMU1XHl@ss+h{)3|r!S)T{XM;C z=asLJJ><!mHi`R`#QMDc z)Y1~)U{M<;+^F!lgUx;r%AOheX-XUdN4D?^*4~leP0l4(Wf{(JmuuACrlyRAd>c7o zds?#q=IUeP39EnETJm8_!IDY8^!i2r?rwZPLka*JzmlT+TVWtRugbaDb3!5Kp4Cl7 z)z16JXpeIP`A~J)Kk~Mb$~bFs-QwMe!o!;hp5&MbWN(<)H-1Y0XNH1hPyZ*4r+eiN zK686)YJ4yQB^G%518m1Co2?ZQWuV+ch=gTLR7-UA0B6m|nlo3L?)-;XO4j$s?z!?f zve0^putQ)#h+9Yj+c^8_m!wBBcEMnQ6JMKY#}}}<5l!RQVw;%Sl+u}~`tt@emVU9; z<@ut;W}$4DjNDpQ>EV5H>J=@~9-^3!SNcN*<#$T3Ir(eaYpLPULI>ZusCA_Y$;_ zP{Z|+C~$0ZtmQv7ZCIe|h#QWvsEQC)R# zV3wtmD~JDfW}hInMK8g^CS2tiVVrZFS}G zOidi;uDoeSR;h97(}{GAL5=ydJvQfxb2ZF{%1MF5Cqm=CJ6)Fw89$6MhXkov9^o}Z zV1($F3thPlt3Q_VnXzn47A>VS#C+BzGJWLf(xqBQrcxn3>Y)@0wOzsH!Tgpy^XOWT@ACvOpJLWdr zF)30)c#m1~OHtyjCJcyS6qA)SFeenwxq8#~{{))jO@w>!SZs-Hh@U^xp-H^M5?jG< z+R62}Qq56p3I2>G{>4)GOY`w`idx5-XF?gw3boz;D|SrDxYW`A=o1@twpSC1UA@e? z!{W%eVpND-!502if5p@`xopCE&qKF}I@2{0fS!~jy^9Frxsr9lPQxvvKTY#-w>5z0 z{cA}ul3r0LQS=>gPp;iD>k0M~-|1Zp{zUaReiy8GtImle$roal5I&uF*zKnvpzBb= z395u2tmJQnv+d(o-S%%+8xT_gkAx{=8s$qe2(^3_xKnM)R|u1U7$D^+%W>DMD?Ycd zH46=z^lN}}vGR)l=NQ)n1%LZ~0IkQYWneH%+#WR7BIYumjk90Tk-j{1E zB&pvE>+Llx=6MR{)3s``F$I5@2g|etZ<9wOeY5(ahIkkRrF|zL&5@SJE~p>^{gOf6 z8FPW7Ssc=1gUN~>{q$JtN-L%M#G|cGBP%**$zQbyM zNT(6?aVGJ zYidql-p%1EUpMQx27UniuhV}{W_3j%FyFT`ZLTd+JxlJn(CK}r3dj^A@q!qhKXm)! 
z@jp(b#`=39A*j61IZ5!me~1gRC@JlBH&S9wMwC-k2hjEn!8X z`*-cJmKcbpIT!>XX3N1YfJs8yijVnJBgy}p99VFy+bt(=z;8vX?t>f0$Pi=QYmV}*N-O0IoQRp1K_;#7fX7#99IoN#xsZ7vTNtq z41X?+Ub25lCCPPol8v|$=>m37DBw_F$;&e70`^O?h zdSOo8tOjh80lMgIa4_c9y`8T_wh>1!%YGq~EvAp0PxBOVAHA%4bxam0@gs*Q*aL5V zJo?iVI1tYHm|P;K9E*WK7yz|i&@8NL1(+qJ`dnP&E!Vc*HAx8?FQ@Gt!{a2hfPEn{ zgS0yf_)-FviiaN=5fVSO*EUxZc)5+dftWHSBXEJ-$YG2>zl8)}HYBzHy8+>%0IEn(g#+7zDyH1BuKlXb~mB)^L1mLCY{LIDdo5!A+VrUSTC^C6IM3^Ox=>=ii z!s}tJCMON?XFq5P2F4D<0#Vhnur3r})=OVGQo1x$jj1a(u(_Mwno%c8e-rfz$gv1p zct9$J0;deed(cA_{f^m9UH9s|5q z*nhZmm;vpcf+-wQp?Ft^ywY`oeZXxcw~V!z|tA81+hz8tIp9h#Ss1R=nKnSE)e~UiEOe$Umzke&QhkJ4J89nupdNyi&!(855|5K5IN4Cn9W0 zylgEk-2<^iD@Gq_Mdng`-jk?x#oMDqSYKxWjm9TeB#X2S>PbZ#wz>nKe3ubqm=?dn z5-$v6B&>*Z{(W%gcqvjF&!m4(K<#0Zip{9Whu!LkRPP=OULhRZn$u&JhLoC4UyVd^!&aZGi3)mEsL^UjT{d8 z31qrOK|HPjQlCOWFnM8yd`%-nT4mKPrJdb^{t0xELh+N5BG5;pUV^u!7=_{&Q@;w`-UaJij_5=d=v?y9Js2Ro_$7~8$o*?P!Mipag(`>B?}fi3 zRw+*lBxQ|z_(R`G#R2q{fH%F~)~LMlu^E@E|xGj~Sv z)cHz*$As%ii4onumgjb|+ z=&wI4C-p8r9ml@spqU z)bCiGLj72kIlpfz#5okn?dJ@$@&42zesRMRds4)%{sk(4jX>&>aML9ak}-IE&RAC@pZsHgcu1m?qUh@S za;}gkqF`NYjoc5`un+57-V8k^6LwKJIj8{@Ds;SppKiOgCwi6Tj_lMLFgM?@gJv9Y z=#!#y6nV&@@^$C4Kav`Fz$X#$b+D({TsTw#f7^c7+Sz%tvlET{BVT7cFxYgNwcoe# zM^9SB$o<0W_+ZO5bo@W^1|MSe+jN4+SWF`SIvl6*Ly*)4;E72W6G-f=|sb5s!S6;kl0KOCL< z(R_|^E9ZMipeQ1#kyvr1;25#<%OITF6f#;P%pz|Trb#RV3@*MHu45(P`%)O$KCZZ` z3}QIEJVka%f2^*`(inKEbSFAi+?ZO`XJYJZI$VmH$xviHwkRC*tonG103g$h0n@Ps zRi~UHXa!M1!<>W?#%F%YI6V5cXG>LdNJ;y|s876wci4-oY*Er$7}`KNH9J%Vn3*B4 zu~KEQe%q~+KJSqg*$ptLQ&4#U|1(u^Vb1I>eNi#{4Dq)dC0`~T7Qwt1Y6KTCefQw# ztF=U$r3K7%t8lGSiM-O}g$w*nsijv<4pAmGY5dCHMN2*=Zja4l{S2Y(Mta3nrTV%N zS(+YCS;KMCL}s!iSDyo0^DV+B*{!o^k)PG?+WGdF>OTUat-&i@H4hjfv(sPauaU z9{TL;pC|GZ9Ti$aT=6&bv2~eOYq3|b4uji#8Teu%=yTsqR7LP@Vz!37nj14>;RdNh zwI!j&`NR!Q)%p)pPfM$Lh^7xk!La1R~wrhQ1qnlZ6o zugNFc*E=-abE&8&43OJbw$Lv4^rTPTjpE$l?t6ZO!a=Rn`~0ZYPQQ%VU2UZc;i~bx z{-OQKK~-~eO#0VJc<)U0P3Vd>-<#gWR{ZKfGH(vG(?fT9IA6;-YhxHMlQ3)SN{~9gyNdl% z$SDi|9Io)9bp?J)*6cp~th3DL!~Q=L}YkMv4dl=5gF@ zxj^raLD#8wld!VBy#8|((Oc8|JAeu8ptiXSlJ%oPqcI#x5#bvW9#PY~8u$|J=Cv#f zv^Sh#0Q!B^4T(wge5(UtvBdZ>Rz#g>+cQuD_PTR)-!JewBhO29bg~eo9(0S7HB+vK zpV;SjTO+TQh#3u3UlG2=^BWN-Xkw|BND8Pb#|5Ody}GTz=gc&Ot9MN-=!$7&7I>Jk z^Qh1N;C-}WmwYEd{1ZxTnyTx+H>;=Eq_A5A-X`D8}rwD-uP*F2> zVt+W3_OoqdD$HXZrUDEwnVWC(`<-GeKMhZN9t`F3dtLPN#qf1cK!a3j7{y4y%KhB2 zM7F>;1@=17q^t>Vs68g=C+k_)QnB>@r>U=iiYr*Q4HjGi1ef6Mgux}aYk=VH?oROF zAp{E^f(Cb&0S0#sLy*DU-Tuja_x=A_vsh;hoSvRK-CbR~cJDGrv3n28A-!YBtETUp zS7_Ezh24LgorKfs%>5pkO#tyCGxH0I8^1V#O{nI@%n%LG?Jh~M%|+^`(~1TnX|fe) z2xxIfwFrcg05NoK-}Zz0Um+*VG<`xdRy0gT^V5cH$cGHOF>7o6AL*JCHcnu=J4q&R z-HEBpN<;$OY_D-F#K8RBLcHiX%_wqLcUBO>)9naQV=ZtbYf%TqnkAeL3PUje1e|fB z*47uUC6$mhbAB%ec<#k;UdW90SYHu$MD08b8LsQ8H&@_sk>-~MlU>8-GRpEYC);H? z(@r_v1v^Hh2EES*c16WIKR}pY<@#I<(>$k_|2c*a<_U-%h$5AmK_7tkuYf0LF8C?uU_A#W_0D$}0v zW1&SjrkA4vzDa`>XF;vz{m+-Sp#_P*pY0=) z7CeTFsYOMtC^Qpi(LI&)Dl)70XE(JkS@(3zZY9prTxboaYSfb`oYpnpUrL|bpXOg! 
z&Sf+L)(UTV^B(%rEXyC;DX~wa8s<|dhI6aXv762dHQ_x;OVmo&C0~G^FVHbzU+qp^+oW%$ST4 zSaxTm(D2o+biTX*9yh1XqL}xQw`TGy{{sPFvZV(M!5nHb-?Y>70i0W1R>6hXpi6!J zr?whKg5qML1t#WbK0${RU$I|n$QHs;M|b!`+i%%$q|q@PzayT>ln#&8sxD8_W`6{z z3%z=>m6l7hJ}q|XxqYL{-Q)4*Va($NtC3Px@wEp#_oC~7`@0sU5fBF^r5bM?b3-$& zuDyCD)3SqLHg&cNoErPb+tWqY$k(q=g>T%CnzipyVkYw5q3DOvC68~LMK`d)_Ya3u z=D19dEPW4EV!wwrI{#5uDZ=xMs+p4>edI0E1``XuN|2mZ{3>$GT1bK@k&-dVDS-ys zUYlTF6Hx!W(pQ?rtpiin0UagYwH0Ah|uJb+qCfWimS~Z#O2apGNl0~!Vv?*@>+(> zK6P)wkH!aIQrmPkAR;JXO>w=a}Bj@p_KNkp6T3Gs7He7W368~Ro7zAObqvu+K9HDxX%x=m%`Q?;3c zOoFuuwMi3^+UmHSkygEWF2Y)6yBc{D(d-U~z1$nLHZB7)#+wzL zr_Cdv_ik}L3E5$dMp~^MbXiW{3H#y&O7?@+Zmu>qP_o>0@9 z9Dho7iUL?6F=-HmJqFj&Oj^v;Hd_Q7)&RT7L}wz{OlCz=>{&#Mn(nt5U?^R^>?TqN zIn9#S6+t*8HqOra7mHh~Czi^Iog{SX)8kFERe-jRMx zKFK8)V2Zl~lP#%^u{)6dgc38eSREN}eU@|orm>=pB1vIg)hH}EIN_(4p>X8V1DW0K z`k+7?XT-}oM^DxIVc=_#0Dgr8lyuoni9gsjU9K^Y-4taVLg^eSEsbiXwa2mR&Fn3I zP`T-8cbDY@&dIjVi}@dNtvhJqc*jWRub2TVN%MVd{xhem$U>FTgd~B}$y(7#+;rH2 z+|Pn5ZY+zOF^LnIYG<>U6Z6KNUwDJ}iSE^)7tBIvJ-@i7(D!hezsz4~SW)UB3fr8X z-4tDOu?AhYa4I{D5fsU!vhm~90*Ty7_Z7?H0Yf$6N|^39PN?*k*IjX#fgb*1wA}o4 zAQYn)S>}hXh?$P~>e?^f=O4|2GKNGrzNWWJUH9d#JtC)Hi;gZLEla(AI>gR1uinb| zL=UgKcunhHcCl#QMSV4E{B?gN@_^XtH{6*|+-QpGpkV-|m{?!GZ0A5D57r-+T(f(x z{EH_=DnA|n()WmoGF}(0;yw&k#uWq0@v>1q!6>2KR_LHM68n$Zyad1!or!?CNoJ_} zsk6vcBU>){r-oyR7IeC?6*cYn50Xfm$%JN;0fS&Yg(DLNCl3gPa7TTJPw52f!)Lb) zPK-f%T>eVgiuzX!&vN(4#<)IATC5S)2$%UM--wqdQbNLLowD-nx8YkdbMrNQV-f>M z#X#cRC};aW^2%yhwaYA)fOE)QA5PbuPgh?A?kBx>j}MDU5t5U`bwnZjuHPPVi#f7Q z0-$ejFFvhP2Tu`jT`5C19$H+Md5oMBwZPIhbhr3T0Q>Zm_L zyCOP(ItqCTyf*j zpCIL=m{hU6I0|_5)1Ld@eWG{d+eyxOW~f??=-=xul|EN%Qe1X#)J?$s&#jcCg&REo_bTX7s zwVe=Df`9NsXmv~T1JymBdWIenre1)oUgF%iMozTV{W0T)z!;lVf zdpmnfQxfpR9Tw{`Z5WBjn;pVq5mdZOpHvSOFFYm3jsW*^+pxKk!Vuq=haQg7f$b^P$@~yO ze^tOF%msv#g#zD=A$0@`da(KFq~l-n^8;qA5>#SR*4~uzvLN!mnKy8V=zeKS%lsBx^g;UAky3^C~>g1XJHe`YV+*@&B}5!;dGiUCmKH%4YfZ>(w_dC2Isp-3s8^@t?!LVN$^T98={{#mCfESh%fhpMSljFgVpo&`xUe}u>bg&mCgdG}M6B*(hP zNgrU!4H#nw?HA&(kIci<(=d3mIk)wqKce%q&>d~DWGE6VPWYG<4q{wZTX)|qhR9#T z!$m%E!IYnR6{Zz*1hMP3k%tiVtSzgizR zX7oPhm~OQ_s!Jt)OUF}^|I{SLmm_uLW3lJYu6N}vI-&eLL@7~$QMZBEz${pjCln>O zEtp2xw1dKjQNy#~`iO>aB{z{^HZYE3z#k(IK)CSvqztK}CQX#$0~)(9TJjm48dFT3 zbHSjs4cjQEfr2s-G)d$k6+^JL;&(4$d8bE={p6SKODyvp`Vt6FpWr=8BPBI~YJR3D z#c!LueckTSgpldkC4affIZ=9pvv*f+V{^6WC8>c{cW0q4xwa^1RUOZK4C7hmCpm`g z!FPB?Ma6fJg`BOvQl8JSQkq0lAexDv`N`zSI(oqIHO_q|D$@{sN*CCFp%|l<8~Ueo zTp=}6tFOE11*n%t=;9W8GnA%SvT=}Y@s}KvPc4MVk>aQr^e~NWjH*hot|)QuHs0_~ z@Lu5I%(@U*b>B$XYq3hVbILzLh+Rasqq_QUyJ<%2F{eeJiEH+_>@q008{pTiWlL2I zB%3@XgIKGf1Vg{j$~kwY-I8D!dsg)5Pg`#&$oo8qg*= zRz0khw|{i^56N$-x)MG^dUMPX-*N&r%TQHGIQ0&`AjH7Dgi|xnwQvadhAk>8v3ymc z^;`I2*3@9F16$10%L!nM-_!&`U(M;okLXk;e;E*X@0J>%PpV2T5PncDQn7mZ4?+j< zR+fpKj%NLvDmks@YXKJR9i1y>o42~Dk7`+d^B1qj_8Q+S`^q%UMZ}HgzqLH>#Sn?p zj!|k%RwCjSs*f#AfODIJ|6?GeqT8F%lHh;MxP2&uXoQwIWX_Dv78=n{JH6(o-Z6}n zHBOaf5wmz#z-lBC5>NLosTBP+nkGAL5wyTjJpwgZ0GuH4?XTqb?V@V;jpiVvddc)* zT*Zd0UGvzlP zL^~Bh$B_}E2Q zT9@a+dmU7eVl3)hhm2A_FiQ|kHQfkR#^dD-#yGp##*cZ=ZF!Z#_jhkVVHYJ=4jhJsq80&ZyRQKL8*p6T{_!OErE9H~eZe`-)ozna zRmAm+0bg$tP1ItIg)@j)#FLntJl*;kkc(D5TvqqFcq#mLM6YC_&ez8h z2VyuKCV5mMged2QGOxgBnNdovj*HrFkj^tOhV`_FK5NfzTcg-phJtBBjA>@t@QU55 znI0MOq~_Sh^UUM{m79T0x}R^iI1w?l1xi}+D&bHh)$&N;O#=J>Fu$eyZcw`3T3DgI zY2U>5;;Fc})9}-k4bSK}Jd$4dr+D%2WGS^6C4vS!a>kDH;5 zDF%0u{BqP8ad-{Ai>4ojzg;> zn>eEz(AxPErhO7(KcdO3{F?3}Oxf$6E!*w?#5lHbyH(<3n!=gda~y=mu^{hrL1q>7 z4M5HR(4=Ce*qgc(PKNJ98@j@Y?sls;{P7)KSBCg^Dke5xK409CUnOq zkGmlFjP9Pfe5xQ3t10+-e<^aNf$d&{P4@SVk=dQE>8lV*<EO=fRZ#I$Dg5Fx+lQ70tvYY%R 
zx*o)qzdkC{&;0~Ghi4rx0F&4Sh1&VPam4h31}Rl$>aunoz9H%kr_vf}E>#-Ifi~IO zV0d}qe{-A?_-)rd>Q>h_@0J$mi2_%#FA+`J`JrN9RW`VAoMoNS5ofm@|B`rmDQ3I3;`ah959TL@?vDum8Ie2U4%7dQLxYLEcmM3L#KxaIF9qI6U zT7_r)fls6q?lzCQHm}%<9-yG(jlBB+SHUZGH#O4!vR*toMmYE@IqxSK!H`sFfPQL( zM)XdO;AnCbHrt(aKIekUx`Bv<+Eq<<)LKU{&GHjQtGBhs!+fQ0n=w8RVH^w1DPYFF z=`lmsteBllkfe5{%p!|G{!pmV<7$8#eM?Tp<{$$+%*u_VV?e9u)hj}3c_|4kC8rQR zBQ|}1no9>7vQxGodbZz7etWtpqm5){Y1)0CL~}hfTHf|IY)-v$G4NQp-kLjyeRVPp zGHT7%)#vbY!kDYWdu#9OeeE!Sb(E-G&=97;`ey-}FY5wi`yLP3%JGmfRCf5XI#lyp zBdgCaF5HpG7HIUdySsd{Tn<_1qfKi{_Tl+b29zy0hxovn&L%UyafAc&!)l)#M0ucJ z__bIHG>-b*vQ6N@JkGLUoxhtykePG`2EJ7Acmi(fPIV1NpQie3l^*dy^I9J@oVgP| zsiG;EB+l66EXg_l5ls*UC7C!&i!}Pt%~7j>TtP==!|?u#!U<<@*@#tVetl{t>ApgsJ z33dN1&^f#p1;>W5~pnSKPyJ~2c;u|jBu zk0CdHQ02h|*;G(DFqyr?)hPO9ex0Aq*=(sS)9T&h>pso=h`mlx7y~DE7~bj=p;@ zO}23p0;?}gH*1@%G^Vj~qy$<7sN=z3lr7|s1>)3IHq(tKepX0Q>KlFeX0a5>5uSi6 ztEVQV@_@AlfO8&^3bsiW!g}4m$5ivjI}$Z>2JYW_OsM#rNAWW!w2GqjzTn;OPM}PX zOL}5RS*aV-?HqA@Q&(GDWc4PvM{6(d1ec`B7mtcGF^(+hlK8UV-uWw#R&6HaL}&yY z`<<47ePcBPBnxUa{?tMqZHAmh?|(H!u4Fff7?Fmd*@?ps{I)N;kJRA+B@YNZ4m
  • _Oc}O2=c&ccEeaTRz_>rxd$ktm_}(ten54x4oW8qb?}- zqor+ldb1oj0h?z|pQ$oely;~D;g;3<0#GhzV9IG-0c};Fp7Hv;=LHD;%z`(XBTfby zyYV@i{w_c9078}1m)#lH;Y;EUOY9@#z^xAbJUt1Ai!CDUc?myWnCSAv&Bc*`X?Lq5 zn%{Wz?rg=75##xmZaH=UWto3_(mVQQ?qYP7S&dCIvo$6=vzv4g{YZq?y6VguYTqpT zLa9`SAePaa*%1;`thUs{I+}+SKyJoVqV1}TvL|e3!V%XVhPMAqa4=Z`%7TD@$iXVK z%DBHea*;SvL;~8aWphhmmEWpW!c9M;V$PM(EDf?L5RsH)mT22&_B|Ul>w_m>O+zwC z+LqE%OVl7g<38DKK%VtgD|46ji``@iJD%>UshMYwV(fjRF(?eqOKn!?FVm_oHUJ8i z6dp;~^pC|giVx=9%DwAdx<(-T{C!C$oSY~9(Om(QM@D>?GIGum2Zlu){jB)#_z2>1 zV}LnQtRcShpk@fIsWZO?(2^~9=IjBR_?$L**gk{HUSPWxp#-;VgCmZgml?tKDMXaU z>k83tldZgT9a4MSVzfNXI$qq>pX)=!kVD&kO>E?vrOT?kOQ*J=isK`{yictH-0coZ z$1(h^HsU;hAttfRQLyr0JM+STf;rP*_jG6P@Yh$rBDrBQkdT!ml7idBww z<2%b+$Uw2hNB>w@>r`TgF006H8Onha4eLsP>AHlDv@U}I2$uvsq957tt!efbh@J70 z5Bi2LR%chn%D`__Ha<6MZ=p1Q&@!N0D$6eLo@0trq;9P~q?j7{`2*b&C|Kmaio9BM zm$x}X4liZ5|25n)5z^7mR?6sl68b(nWR5vm>Z#~Pss<%cP%Yj8fB3Mpx5s+_1tT6T zWoJ?%;$2Q*>cIfh1P+IicD6jQg;wU85i=6zd?Fg0q(b9s6~!K5F;86juy%W zaU4>WMiI?SM5h-Cv?IAu>)VSj)XTy54L$qMw#NKroBOVSke3x!ie>K}m^NO?mfgG* zwTiX1Xv`MNZyNoWLQZ>ivXfb6ELre&(G(=b4e7HKHx)@N+c3-wwc_oNg@$%r6C*mh zI%%=SG(t8<+^PqvW)C|x4PN7!K%x)nj?h8 z{6R!Ef=I>8pyefd6D zXxXcc`?zS?Umc>&exow{MHJfV(N&Lf`ff>XDM)0*a0PO3m$}4$^Ek(xw$1!kgS1X! zpMERoj$04v(Gb)=b>d6&-5+FGuG8Y}Hc>v62pX$oKBQ(Vow6zk;-RRchts%4F zj$hrgqA~yEKx6xG#Rjr7uRHE4jrW)^E_6$QpHX|Mnp$X@7IB_;>yjgmKbikw>Dd11 z`2X*YSA0;rdd`*=%T3R`<52sM-n*P%KSg?oa#^|-Avn?!{ z52#!vvb;G>c(8JBDs_Aa7cyhI@UjTzKs|3{|K{s|Pbm*ec&b+U2C*^6#%5!yU=~joLw)q#f~{axKR7ofWHH zjHD2Jh5tlV0FW3T-GQK4g8bhw>tCexT$3sTGqPb>{1V#V(aL4ILCy9!*R7{WM6x$y z%=N*X6SCZs1hxX+vH3+AJjXem*n}T{bEdbnUHH=LZQ6n4CzGR_FlzkDqT_o+nh7G~DNJ70`9O6lLFHU!Av^fToyMI9{J= zn}VlED$=_2_(nvW``u28;$$+TtZd1&71npRo_q>cFJu^RoQ8F)+(=K>Q_L?MA6%O~ zvWyBSom2RhP#n7`JngSH?!h^RXQECH_8sF#l`qnZYN7-;7kU6g1Uo(m$WCm|;1&Cx zpHdC&-@k?3mQJ{kMo5+w@jn6SqzdD3P=p>X(2GuL%}+R`X!j}%~LWj>H3 z$B`D8Kc0hv``krBIX5SJ5QJ|oHw@Z+Qgp!^-M!eR@kp|et9YdX!;utb=dK4hLUyD4 z;PCNNX+wW$^cu=!nWy zvW?Z?3uN3H{uH~Es^BAA3X$r>5gq5XeYt8JgzYhaoVqm!N^m;@AcuitMK9k;Yg<`@yhw!E(gk0wa|#__a$0ofCBm1u;OZ79gk|E9u2_-AXKtFR&a8$ zaS|MLcL(2fm#A}R{4%oY;_eZ8lNdN`{IqpLM4q{(^KwQ4m@=4}qgq*fgZyKqmM&{b zdQOJp@r)iv=9k-`S7^dIk=gMo4casqmUI|ix*l2Onq81J2DEnH?hQpVEkWHL=Ms(ctE=Vc z$*<uib8xH=VjVV<*5beSZ7BY(?@2&kgiwvn1VtV z9^PIqyGc-cCD%>tSOiD zrW;e~4)t5Lz_2kDT!K2cX@B}XF9*S!wG+)y`eDF_c@?cp6mC=DbFll+es}#ab_9kL zOl|iHHT~=Tw-GBJSPg?0%ATCBQ97C7tgnYyBJ{vq=yMcb(EiXffV~|V-u9Dbb$muu zRE0e=zp5$Ig8Glc$t3odeqjhb{oR7`V=4>`Yn_!qL|Ij3j+i1ZUvp~G7yBDZ#Mi?PlXe(XM$Y!-?^kA2*`^1e;W zeTm%36$#add57|)e}k1Z_4a~P)^~rJ)T_+|j|dEzsZV;=IXtt$!z1hG{H!H6W1_L!K6v)#Z^A z{8${kQObbze#|C3Ua{9fnMs;CW3#(6s@Y?UWL~d+V{O6i@zS~r7*r1wBt}iw_&pz* zcbXG=#f>>iwZ0ueU+*65n+0{sac5XG>e^6%gg~m+);Yi?Fu!_o*@zK;(2_=8&&Vhf zc@BmnBXYPqJBqQV^(`dFNL0YJXz!85A%Fxcs6svayT%ybxFyf!l}C~So$qr7Lk0^$ z{;|31G1xN4|J+sd`I1)xN3_}>=QvIq`PlgJT2=UAa{~kGQo@kO^X%O7^=rQ&i*jt3z2P+2jP&T4VM*=%eS?;h>- zc^W=;TfeNBc0L>t0>-^tTP&WA0>N32=j>cs?=S+a>lQzRLVQ3XR-s{7+!(a#vre4j zzm58Qxg<2U+S>uSiweGtda3ga(}>~Z#g9Sv*o0<65!Y4D0&^EugAx!) 
ztO8SJaCBK1`uZWWi_mn#q(a($>~LAwx&2A>{p!x)U2OaF^#h=Te?Iu!(SEgCN5wX6V>hyII*#@tgPSv@U$lbMsZ3|UPFcp?D*>~!p>5do}JB*HUbkIXPZAH zWtPaPZa@jeS?IV>vtj|;qlZ?bHbaBhQNR!y3`ta~_6U^FSKj0w+MHVveZliRda(zB z%5%1o{D!3phAkcIKC}+(aSiGlbGD1HCMzSOJAE8Jb{6F{lA}LMk?|ApL7lZ6piWZr zDZQnDg1R3qdiHIrm4OlQBW%wOp1GxS_ca^*_6ZBqi6HE$#}Px%XHT*F8WaL>t4mvX zLes;W5k))qf)u@vf9t2ubQa0#1#PRAgFv~ZJA!ZcM4Rm0MT`G6b_DiYtrS@xsVby& zrZ~2EZpeAk&H6|F`j_yKXlC=?Z|5NSM-2MU%V{LrPCyDc&G4V}9SXQi*PaOMO|8on#Rx^s@vZ*N5Vg@9; zljLY$0g~c<8)p+DF9bc{#u8E5{XCL+d1~8p|C4(Z=Z(@_iES(ORUG;n5#qH(Y|JO<%!nw8^ALcz)uzFoY_G$jt zcJ#m2V>?Lf(Fp9b8UOoU=%lzq_WxKKP_>`_V)DODfYW6DpLsiZssBSh_-BaI!n+r^ YX^ecKoFA3KuYezUX%(r;4<@1i2dkb=i~s-t literal 0 HcmV?d00001 diff --git a/torchao/prototype/moe_training/README.md b/torchao/prototype/moe_training/README.md index 553e50f90d..befa53cc00 100644 --- a/torchao/prototype/moe_training/README.md +++ b/torchao/prototype/moe_training/README.md @@ -6,7 +6,7 @@ This prototype provides: - Using MXFP8 on a B200 GPU, this provides: - **~1.4x - 1.8x speedups** over bfloat16 `torch._grouped_mm` for Llama4 Scout shapes - **~1.19 - 1.6x speedups** over bfloat16 `torch._grouped_mm` for DeepSeekV3 671b shapes - + - These benchmarks use `seq_len=8192`, `local_batch_size=16` (so `total_M = 8192 * 16 = 131,072`). We recommend using a large `total_M` dim to maximize speedup. See [benchmarks](#microbenchmarks) for more details. 2. [TorchTitan](https://github.com/pytorch/torchtitan/tree/main) integration: pretrain DeepSeekV3/Llama4 with MXFP8 grouped GEMMs by adding the flag to your training command: `--model.converters="quantize.grouped_mm.mx" --quantize.grouped_mm.mx.fqns="experts"` @@ -14,6 +14,28 @@ This prototype provides: 3. Model conversion API to swap all `torch._grouped_mm` ops in your model definition to use torchao `_quantize_then_scaled_grouped_mm` under the hood (see [example](#model-conversion-api-example-end-to-end-training) below). +## Equivalent convergence to bfloat16 training baseline + +Training runs on 64 node GB200 cluster with TorchTitan Llama4 Scout show that MXFP8 MoE training has equivalent convergence to bfloat16 training baseline. Infact, after 3,000 steps it finishes with slightly *lower* loss than bfloat16! This is consistent with our scaling experiments with [MXFP8 training for dense models](https://pytorch.org/blog/accelerating-2k-scale-pre-training-up-to-1-28x-with-torchao-mxfp8-and-torchtitan-on-crusoe-b200-cluster/). + +Image + +Training and model configurations for this run: +- Model: Llama4 Scout +- Dataset: C4 +- Sequence length: 8192 +- Local batch size: 1 +- Learning rate: 1e-4 +- LR scheduler warmup steps: 2000 +- Parallelisms (64 nodes of 4 devices each = 256 chips): + - FSDP=256 (on attention layers, shared experts, dense layer FFNs) and 256/4=64 (on routed experts) + - EP=16 (on routed experts) +- Activation checkpointing mode: `none` (ideally this should use selective per op AC but there was a bug at the time preventing us from using it). 
+- `torch.compile` enabled +- `mxfp8` applied to routed experts computation (grouped GEMMs) +- `mxfp8` applied to all linear layers except: `output`, `router.gate`, `attention.wk`, `attention.wv` (Wk and Wv too small to benefit from mxfp8) + + ## Table of Contents - [Examples](#examples) From e4ecec02b81c05169aab4a688bee56afbc312212 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Fri, 7 Nov 2025 17:50:35 -0800 Subject: [PATCH 07/22] Move marlin_qqq_tensor to prototype/dtypes (#3307) --- benchmarks/microbenchmarks/utils.py | 2 +- docs/source/api_ref_dtypes.rst | 4 +- test/dtypes/test_uintx.py | 1 + test/quantization/test_marlin_qqq.py | 2 +- torchao/_models/llama/generate.py | 2 +- torchao/dtypes/__init__.py | 8 +- torchao/dtypes/affine_quantized_tensor_ops.py | 8 +- torchao/dtypes/uintx/marlin_qqq_tensor.py | 359 +----------------- torchao/prototype/dtypes/__init__.py | 6 + torchao/prototype/dtypes/uintx/__init__.py | 8 + .../dtypes/uintx/marlin_qqq_tensor.py | 351 +++++++++++++++++ 11 files changed, 397 insertions(+), 354 deletions(-) create mode 100644 torchao/prototype/dtypes/uintx/marlin_qqq_tensor.py diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py index d7300a6a81..2c6a443a86 100644 --- a/benchmarks/microbenchmarks/utils.py +++ b/benchmarks/microbenchmarks/utils.py @@ -218,7 +218,7 @@ def string_to_config( ) if "marlin" in quantization: if "qqq" in quantization: - from torchao.dtypes import MarlinQQQLayout + from torchao.prototype.dtypes import MarlinQQQLayout return Int8DynamicActivationInt4WeightConfig( group_size=128, diff --git a/docs/source/api_ref_dtypes.rst b/docs/source/api_ref_dtypes.rst index 5c73d275eb..58ad4ee8a4 100644 --- a/docs/source/api_ref_dtypes.rst +++ b/docs/source/api_ref_dtypes.rst @@ -23,8 +23,6 @@ Layouts and Tensor Subclasses FloatxTensorCoreLayout MarlinSparseLayout UintxLayout - MarlinQQQTensor - MarlinQQQLayout Int4CPULayout CutlassSemiSparseLayout @@ -53,6 +51,8 @@ Prototype BlockSparseLayout CutlassInt4PackedLayout Int8DynamicActInt4WeightCPULayout + MarlinQQQTensor + MarlinQQQLayout .. 
_NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring diff --git a/test/dtypes/test_uintx.py b/test/dtypes/test_uintx.py index 5d54a80753..0878dfed4d 100644 --- a/test/dtypes/test_uintx.py +++ b/test/dtypes/test_uintx.py @@ -182,6 +182,7 @@ def test_uintx_api_deprecation(): ), ("CutlassInt4PackedLayout", "torchao.dtypes.uintx.cutlass_int4_packed_layout"), ("BlockSparseLayout", "torchao.dtypes.uintx.block_sparse_layout"), + ("MarlinQQQLayout", "torchao.dtypes.uintx.marlin_qqq_tensor"), ] for api_name, module_path in deprecated_apis: diff --git a/test/quantization/test_marlin_qqq.py b/test/quantization/test_marlin_qqq.py index e0733520ff..6f0f0d69ba 100644 --- a/test/quantization/test_marlin_qqq.py +++ b/test/quantization/test_marlin_qqq.py @@ -10,7 +10,7 @@ from torch import nn from torch.testing._internal.common_utils import TestCase, run_tests -from torchao.dtypes import MarlinQQQLayout +from torchao.prototype.dtypes import MarlinQQQLayout from torchao.quantization.marlin_qqq import ( pack_to_marlin_qqq, unpack_from_marlin_qqq, diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index da1b848bcb..fc3d371139 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -460,7 +460,7 @@ def ffn_or_attn_only(mod, fqn): ) if "marlin" in quantization: if "qqq" in quantization: - from torchao.dtypes import MarlinQQQLayout + from torchao.prototype.dtypes import MarlinQQQLayout quantize_( model, diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 354692e794..4c83de7ddd 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -16,19 +16,21 @@ from .uintx import ( Int4CPULayout, Int4XPULayout, - MarlinQQQLayout, - MarlinQQQTensor, MarlinSparseLayout, PackedLinearInt8DynamicActivationIntxWeightLayout, QDQLayout, SemiSparseLayout, TensorCoreTiledLayout, UintxLayout, - to_marlinqqq_quantized_intx, ) from .uintx.block_sparse_layout import BlockSparseLayout from .uintx.cutlass_int4_packed_layout import CutlassInt4PackedLayout from .uintx.dyn_int8_act_int4_wei_cpu_layout import Int8DynamicActInt4WeightCPULayout +from .uintx.marlin_qqq_tensor import ( + MarlinQQQLayout, + MarlinQQQTensor, + to_marlinqqq_quantized_intx, +) from .utils import ( Layout, PlainLayout, diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 3816f9bf1f..21f13729dd 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -39,10 +39,6 @@ _linear_fp_act_uint4_weight_int8_zero_check, _linear_fp_act_uint4_weight_int8_zero_impl, ) -from torchao.dtypes.uintx.marlin_qqq_tensor import ( - _linear_int8_act_int4_weight_marlin_qqq_check, - _linear_int8_act_int4_weight_marlin_qqq_impl, -) from torchao.dtypes.uintx.marlin_sparse_layout import ( _linear_fp_act_int4_weight_sparse_marlin_check, _linear_fp_act_int4_weight_sparse_marlin_impl, @@ -94,6 +90,10 @@ _linear_int8_act_int4_weight_cpu_check, _linear_int8_act_int4_weight_cpu_impl, ) +from torchao.prototype.dtypes.uintx.marlin_qqq_tensor import ( + _linear_int8_act_int4_weight_marlin_qqq_check, + _linear_int8_act_int4_weight_marlin_qqq_impl, +) from torchao.quantization.quant_primitives import ( ZeroPointDomain, _dequantize_affine_no_zero_point, diff --git a/torchao/dtypes/uintx/marlin_qqq_tensor.py b/torchao/dtypes/uintx/marlin_qqq_tensor.py index 04066a6c65..19d16a1e9f 100644 --- a/torchao/dtypes/uintx/marlin_qqq_tensor.py +++ 
b/torchao/dtypes/uintx/marlin_qqq_tensor.py @@ -3,349 +3,24 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -import logging -import math -from dataclasses import dataclass -from typing import Optional, Tuple -import torch -from torch.utils._python_dispatch import ( - return_and_correct_aliasing, -) +# Backward compatibility stub - imports from the new location +import warnings -from torchao.dtypes.affine_quantized_tensor import ( - AffineQuantizedTensor, - get_tensor_impl_constructor, - register_layout, -) -from torchao.dtypes.uintx.plain_layout import ( - _aqt_is_int8_reduced_range, +warnings.warn( + "Importing from torchao.dtypes.uintx.marlin_qqq_tensor is deprecated. " + "Please use 'from torchao.prototype.dtypes import MarlinQQQLayout, MarlinQQQTensor' instead. " + "This import path will be removed in a future release of torchao. " + "See https://github.com/pytorch/ao/issues/2752 for more details.", + DeprecationWarning, + stacklevel=2, ) -from torchao.dtypes.utils import AQTTensorImpl, Layout -from torchao.quantization.quant_primitives import ( - ZeroPointDomain, - _choose_qparams_and_quantize_affine_qqq, - _dequantize_affine_qqq, -) - -logger = logging.getLogger(__name__) - -aten = torch.ops.aten - - -class MarlinQQQTensor(AffineQuantizedTensor): - """MarlinQQQ quantized tensor subclass which inherits AffineQuantizedTensor class. - - To see what happens during _choose_qparams_and_quantize_affine_qqq, quantization and dequantization for marlin qqq quantization, - please checkout https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py - and check the two quant primitive ops: _choose_qparams_and_quantize_affine_qqq and _dequantize_affine_qqq - """ - - def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: - if output_dtype is None: - output_dtype = self.dtype - - int_data, s_group, s_channel = self.tensor_impl.get_plain() - nbits = int(math.log2(self.quant_max - self.quant_min + 1)) - group_size = max(self.block_size) - return _dequantize_affine_qqq( - int_data, s_group, s_channel, nbits, group_size, output_dtype - ) - - @classmethod - def from_hp_to_intx( - cls, - input_float: torch.Tensor, - block_size: Tuple[int, ...], - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, - zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, - _layout: Optional[Layout] = None, - ): - """Converts a floating point tensor to a Marlin QQQ quantized tensor.""" - if zero_point_domain is None: - raise ValueError("Please use ZeroPointDomain.NONE instead of None") - original_shape = input_float.shape - input_float = _layout.pre_process(input_float) - nbits = int(math.log2(quant_max - quant_min + 1)) - group_size = max(block_size) - data, s_group, s_channel, _ = _choose_qparams_and_quantize_affine_qqq( - input_float, nbits, group_size - ) - tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) - tensor_impl = tensor_impl_ctr(data, s_group, s_channel, _layout) - return cls( - tensor_impl, - block_size, - original_shape, - quant_min, - quant_max, - zero_point_domain, - dtype=input_float.dtype, - ) - - -@dataclass(frozen=True) -class MarlinQQQLayout(Layout): - """MarlinQQQLayout is a layout class for Marlin QQQ quantization.""" - - pass - - -@register_layout(MarlinQQQLayout) -class MarlinQQQAQTTensorImpl(AQTTensorImpl): - """ - TensorImpl storage class for sparse_qqq layout for affine quantized tensor. 
- - Can only be used with 4 bits quantization for now. - - Original marlin documentation and information: - https://github.com/IST-DASLab/marlin/tree/master - - Marlin qqq information: - https://github.com/HandH1998/QQQ/tree/main - https://arxiv.org/pdf/2406.09904 - - fields: - original_shape (torch.Size): the original shape of the tensor. used to unpack the tensor to the original shape - group_size (int): the group size used to pack the tensor - num_bits (int): the number of bits used to quantize the tensor - """ - - @staticmethod - def __new__( - cls, - int_data: torch.Tensor, - s_group: torch.Tensor, - s_channel: torch.Tensor, - _layout: Layout, - original_shape: torch.Size, - group_size: int, - num_bits: int, - ): - kwargs = {} - kwargs["device"] = int_data.device - kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout - ) - kwargs["dtype"] = int_data.dtype - kwargs["requires_grad"] = False - shape = int_data.shape - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - int_data: torch.Tensor, - s_group: torch.Tensor, - s_channel: torch.Tensor, - _layout: Layout, - original_shape: torch.Size, - group_size: int, - num_bits: int, - ): - self.int_data = int_data - self.s_group = s_group - self.s_channel = s_channel - self._layout = _layout - self.original_shape = original_shape - self.group_size = group_size - self.num_bits = num_bits - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - if func is aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - raise NotImplementedError( - f"MarlinQQQAQTTensorImpl dispatch: attempting to run {func}, this is not supported" - ) - - def __tensor_flatten__(self): - return ["int_data", "s_group", "s_channel"], [ - self._layout, - self.original_shape, - self.group_size, - self.num_bits, - ] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - int_data = tensor_data_dict["int_data"] - s_group = tensor_data_dict["s_group"] - s_channel = tensor_data_dict["s_channel"] - _layout, original_shape, group_size, num_bits = tensor_attributes - return cls( - int_data, s_group, s_channel, _layout, original_shape, group_size, num_bits - ) - - def get_plain(self): - from torchao.quantization.marlin_qqq import ( - unpack_from_marlin_qqq, - ) - int_data_expanded, s_group_expanded, s_channel_expanded = ( - unpack_from_marlin_qqq( - self.int_data, - self.s_group, - self.s_channel, - self.original_shape, - self.num_bits, - self.group_size, - ) - ) - int_data_expanded_t = int_data_expanded.t() - s_group_expanded_t = s_group_expanded.t() - s_channel_expanded_t = s_channel_expanded.t() - return int_data_expanded_t, s_group_expanded_t, s_channel_expanded_t - - @classmethod - def from_plain( - cls, - int_data: torch.Tensor, - s_group: torch.Tensor, - s_channel: torch.Tensor, - _layout: Layout, - ): - from torchao.quantization.marlin_qqq import ( - const, - pack_to_marlin_qqq, - ) - - assert isinstance(_layout, MarlinQQQLayout) - - # Linear layers are (in_features, out_features) but the int_data that is reaching this point - # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code. 
- q_w = int_data.t() - s_group_t = s_group.t() - s_channel_t = s_channel.t() - - if not torch.cuda.get_device_capability()[0] >= 8: - raise ValueError( - f"Can not use Marlin QQQ int4*int8 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel." - ) - - if q_w.dtype != torch.int32: - raise ValueError("Only `torch.int32` weights are supported.") - - in_features, out_features = q_w.shape - # (thread_k, thread_n) - thread_config = [(64, 256), (128, 128), (128, 64), (64, 128)] - if not any( - [ - in_features % thread_k == 0 and out_features % thread_n == 0 - for thread_k, thread_n in thread_config - ] - ): - raise ValueError( - "Not supported `in_features`: {} and `out_features`: {}.".format( - in_features, out_features - ) - ) - - num_bits = 4 if torch.max(q_w) - torch.min(q_w) < 16 else -1 - if num_bits not in [4]: - raise ValueError(f"Only {[4]} bits are supported, got {num_bits}.") - - if s_group.numel() == 0: - group_size = -1 - else: - group_size = in_features // s_group_t.shape[0] - assert group_size <= in_features, ( - "Group size must be less than or equal to in_features." - ) - - if group_size not in const.SUPPORTED_GROUP_SIZES: - raise ValueError( - f"Only {const.SUPPORTED_GROUP_SIZES} group sizes are supported, got {group_size}." - ) - - # Compress quantized weight to marlin format - marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = pack_to_marlin_qqq( - q_w, s_group_t, s_channel_t, num_bits, group_size - ) - - return cls( - marlin_qqq_q_w, - marlin_qqq_s_group, - marlin_qqq_s_channel, - _layout, - q_w.shape, - group_size, - num_bits, - ) - - def get_layout(self) -> Layout: - return self._layout - - def _apply_fn_to_data(self, fn): - self.int_data = fn(self.int_data) - self.s_group = fn(self.s_group) - self.s_channel = fn(self.s_channel) - return self - - -def _linear_int8_act_int4_weight_marlin_qqq_check(input_tensor, weight_tensor, bias): - return ( - isinstance(input_tensor, AffineQuantizedTensor) - and _aqt_is_int8_reduced_range(input_tensor) - and input_tensor.dtype == torch.float16 - and input_tensor.tensor_impl.scale.dtype == torch.float32 - and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1 - and isinstance(weight_tensor, AffineQuantizedTensor) - and weight_tensor.tensor_impl.dtype == torch.int32 - and len(weight_tensor.shape) == 2 - and isinstance(weight_tensor._layout, MarlinQQQLayout) - ) - - -def _linear_int8_act_int4_weight_marlin_qqq_impl(input_tensor, weight_tensor, bias): - from torchao.ops import marlin_qqq_gemm - from torchao.quantization.marlin_qqq import marlin_qqq_workspace - - assert isinstance(input_tensor, AffineQuantizedTensor) - assert isinstance(weight_tensor, AffineQuantizedTensor) - - input = input_tensor.tensor_impl.int_data - input_scale = input_tensor.tensor_impl.scale - - w_int4 = weight_tensor.tensor_impl.int_data - s_group = weight_tensor.tensor_impl.s_group - s_channel = weight_tensor.tensor_impl.s_channel - original_shape = weight_tensor.tensor_impl.original_shape - - # Folds batch dimension into the first dimension - input_2d = input.view(-1, input.shape[-1]) - input_scale = input_scale.view(1, -1) - - size_m = input_2d.shape[0] - size_n = s_channel.shape[1] - size_k = input_2d.shape[1] - workspace_qqq = marlin_qqq_workspace(original_shape[1]) - - out = marlin_qqq_gemm( - input_2d, - w_int4, - input_scale, - s_channel, - s_group, - workspace_qqq, - size_m, - size_n, - size_k, - ) - - # Unfold the batch dimension - out = 
out.reshape(input.shape[:-1] + (s_channel.shape[1],)) - - if bias is not None: - out += bias.to(out.dtype) - return out - - -to_marlinqqq_quantized_intx = MarlinQQQTensor.from_hp_to_intx +from torchao.prototype.dtypes.uintx.marlin_qqq_tensor import ( # noqa: F401 + MarlinQQQAQTTensorImpl, # noqa: F401 + MarlinQQQLayout, # noqa: F401 + MarlinQQQTensor, # noqa: F401 + _linear_int8_act_int4_weight_marlin_qqq_check, # noqa: F401 + _linear_int8_act_int4_weight_marlin_qqq_impl, # noqa: F401 + to_marlinqqq_quantized_intx, # noqa: F401 +) diff --git a/torchao/prototype/dtypes/__init__.py b/torchao/prototype/dtypes/__init__.py index 52a5aec425..294c7d0b15 100644 --- a/torchao/prototype/dtypes/__init__.py +++ b/torchao/prototype/dtypes/__init__.py @@ -8,10 +8,16 @@ BlockSparseLayout, CutlassInt4PackedLayout, Int8DynamicActInt4WeightCPULayout, + MarlinQQQLayout, + MarlinQQQTensor, + to_marlinqqq_quantized_intx, ) __all__ = [ "BlockSparseLayout", "CutlassInt4PackedLayout", "Int8DynamicActInt4WeightCPULayout", + "MarlinQQQLayout", + "MarlinQQQTensor", + "to_marlinqqq_quantized_intx", ] diff --git a/torchao/prototype/dtypes/uintx/__init__.py b/torchao/prototype/dtypes/uintx/__init__.py index 89c1f3f810..cd333a90e9 100644 --- a/torchao/prototype/dtypes/uintx/__init__.py +++ b/torchao/prototype/dtypes/uintx/__init__.py @@ -7,9 +7,17 @@ from .block_sparse_layout import BlockSparseLayout from .cutlass_int4_packed_layout import CutlassInt4PackedLayout from .dyn_int8_act_int4_wei_cpu_layout import Int8DynamicActInt4WeightCPULayout +from .marlin_qqq_tensor import ( + MarlinQQQLayout, + MarlinQQQTensor, + to_marlinqqq_quantized_intx, +) __all__ = [ "BlockSparseLayout", "CutlassInt4PackedLayout", "Int8DynamicActInt4WeightCPULayout", + "MarlinQQQLayout", + "MarlinQQQTensor", + "to_marlinqqq_quantized_intx", ] diff --git a/torchao/prototype/dtypes/uintx/marlin_qqq_tensor.py b/torchao/prototype/dtypes/uintx/marlin_qqq_tensor.py new file mode 100644 index 0000000000..04066a6c65 --- /dev/null +++ b/torchao/prototype/dtypes/uintx/marlin_qqq_tensor.py @@ -0,0 +1,351 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +import logging +import math +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + get_tensor_impl_constructor, + register_layout, +) +from torchao.dtypes.uintx.plain_layout import ( + _aqt_is_int8_reduced_range, +) +from torchao.dtypes.utils import AQTTensorImpl, Layout +from torchao.quantization.quant_primitives import ( + ZeroPointDomain, + _choose_qparams_and_quantize_affine_qqq, + _dequantize_affine_qqq, +) + +logger = logging.getLogger(__name__) + +aten = torch.ops.aten + + +class MarlinQQQTensor(AffineQuantizedTensor): + """MarlinQQQ quantized tensor subclass which inherits AffineQuantizedTensor class. 
+ + To see what happens during _choose_qparams_and_quantize_affine_qqq, quantization and dequantization for marlin qqq quantization, + please checkout https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py + and check the two quant primitive ops: _choose_qparams_and_quantize_affine_qqq and _dequantize_affine_qqq + """ + + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: + if output_dtype is None: + output_dtype = self.dtype + + int_data, s_group, s_channel = self.tensor_impl.get_plain() + nbits = int(math.log2(self.quant_max - self.quant_min + 1)) + group_size = max(self.block_size) + return _dequantize_affine_qqq( + int_data, s_group, s_channel, nbits, group_size, output_dtype + ) + + @classmethod + def from_hp_to_intx( + cls, + input_float: torch.Tensor, + block_size: Tuple[int, ...], + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + _layout: Optional[Layout] = None, + ): + """Converts a floating point tensor to a Marlin QQQ quantized tensor.""" + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") + original_shape = input_float.shape + input_float = _layout.pre_process(input_float) + nbits = int(math.log2(quant_max - quant_min + 1)) + group_size = max(block_size) + data, s_group, s_channel, _ = _choose_qparams_and_quantize_affine_qqq( + input_float, nbits, group_size + ) + tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) + tensor_impl = tensor_impl_ctr(data, s_group, s_channel, _layout) + return cls( + tensor_impl, + block_size, + original_shape, + quant_min, + quant_max, + zero_point_domain, + dtype=input_float.dtype, + ) + + +@dataclass(frozen=True) +class MarlinQQQLayout(Layout): + """MarlinQQQLayout is a layout class for Marlin QQQ quantization.""" + + pass + + +@register_layout(MarlinQQQLayout) +class MarlinQQQAQTTensorImpl(AQTTensorImpl): + """ + TensorImpl storage class for sparse_qqq layout for affine quantized tensor. + + Can only be used with 4 bits quantization for now. + + Original marlin documentation and information: + https://github.com/IST-DASLab/marlin/tree/master + + Marlin qqq information: + https://github.com/HandH1998/QQQ/tree/main + https://arxiv.org/pdf/2406.09904 + + fields: + original_shape (torch.Size): the original shape of the tensor. 
used to unpack the tensor to the original shape + group_size (int): the group size used to pack the tensor + num_bits (int): the number of bits used to quantize the tensor + """ + + @staticmethod + def __new__( + cls, + int_data: torch.Tensor, + s_group: torch.Tensor, + s_channel: torch.Tensor, + _layout: Layout, + original_shape: torch.Size, + group_size: int, + num_bits: int, + ): + kwargs = {} + kwargs["device"] = int_data.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout + ) + kwargs["dtype"] = int_data.dtype + kwargs["requires_grad"] = False + shape = int_data.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + int_data: torch.Tensor, + s_group: torch.Tensor, + s_channel: torch.Tensor, + _layout: Layout, + original_shape: torch.Size, + group_size: int, + num_bits: int, + ): + self.int_data = int_data + self.s_group = s_group + self.s_channel = s_channel + self._layout = _layout + self.original_shape = original_shape + self.group_size = group_size + self.num_bits = num_bits + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + raise NotImplementedError( + f"MarlinQQQAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + def __tensor_flatten__(self): + return ["int_data", "s_group", "s_channel"], [ + self._layout, + self.original_shape, + self.group_size, + self.num_bits, + ] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + int_data = tensor_data_dict["int_data"] + s_group = tensor_data_dict["s_group"] + s_channel = tensor_data_dict["s_channel"] + _layout, original_shape, group_size, num_bits = tensor_attributes + return cls( + int_data, s_group, s_channel, _layout, original_shape, group_size, num_bits + ) + + def get_plain(self): + from torchao.quantization.marlin_qqq import ( + unpack_from_marlin_qqq, + ) + + int_data_expanded, s_group_expanded, s_channel_expanded = ( + unpack_from_marlin_qqq( + self.int_data, + self.s_group, + self.s_channel, + self.original_shape, + self.num_bits, + self.group_size, + ) + ) + int_data_expanded_t = int_data_expanded.t() + s_group_expanded_t = s_group_expanded.t() + s_channel_expanded_t = s_channel_expanded.t() + return int_data_expanded_t, s_group_expanded_t, s_channel_expanded_t + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + s_group: torch.Tensor, + s_channel: torch.Tensor, + _layout: Layout, + ): + from torchao.quantization.marlin_qqq import ( + const, + pack_to_marlin_qqq, + ) + + assert isinstance(_layout, MarlinQQQLayout) + + # Linear layers are (in_features, out_features) but the int_data that is reaching this point + # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code. + q_w = int_data.t() + s_group_t = s_group.t() + s_channel_t = s_channel.t() + + if not torch.cuda.get_device_capability()[0] >= 8: + raise ValueError( + f"Can not use Marlin QQQ int4*int8 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel." 
+ ) + + if q_w.dtype != torch.int32: + raise ValueError("Only `torch.int32` weights are supported.") + + in_features, out_features = q_w.shape + # (thread_k, thread_n) + thread_config = [(64, 256), (128, 128), (128, 64), (64, 128)] + if not any( + [ + in_features % thread_k == 0 and out_features % thread_n == 0 + for thread_k, thread_n in thread_config + ] + ): + raise ValueError( + "Not supported `in_features`: {} and `out_features`: {}.".format( + in_features, out_features + ) + ) + + num_bits = 4 if torch.max(q_w) - torch.min(q_w) < 16 else -1 + if num_bits not in [4]: + raise ValueError(f"Only {[4]} bits are supported, got {num_bits}.") + + if s_group.numel() == 0: + group_size = -1 + else: + group_size = in_features // s_group_t.shape[0] + assert group_size <= in_features, ( + "Group size must be less than or equal to in_features." + ) + + if group_size not in const.SUPPORTED_GROUP_SIZES: + raise ValueError( + f"Only {const.SUPPORTED_GROUP_SIZES} group sizes are supported, got {group_size}." + ) + + # Compress quantized weight to marlin format + marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = pack_to_marlin_qqq( + q_w, s_group_t, s_channel_t, num_bits, group_size + ) + + return cls( + marlin_qqq_q_w, + marlin_qqq_s_group, + marlin_qqq_s_channel, + _layout, + q_w.shape, + group_size, + num_bits, + ) + + def get_layout(self) -> Layout: + return self._layout + + def _apply_fn_to_data(self, fn): + self.int_data = fn(self.int_data) + self.s_group = fn(self.s_group) + self.s_channel = fn(self.s_channel) + return self + + +def _linear_int8_act_int4_weight_marlin_qqq_check(input_tensor, weight_tensor, bias): + return ( + isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_int8_reduced_range(input_tensor) + and input_tensor.dtype == torch.float16 + and input_tensor.tensor_impl.scale.dtype == torch.float32 + and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1 + and isinstance(weight_tensor, AffineQuantizedTensor) + and weight_tensor.tensor_impl.dtype == torch.int32 + and len(weight_tensor.shape) == 2 + and isinstance(weight_tensor._layout, MarlinQQQLayout) + ) + + +def _linear_int8_act_int4_weight_marlin_qqq_impl(input_tensor, weight_tensor, bias): + from torchao.ops import marlin_qqq_gemm + from torchao.quantization.marlin_qqq import marlin_qqq_workspace + + assert isinstance(input_tensor, AffineQuantizedTensor) + assert isinstance(weight_tensor, AffineQuantizedTensor) + + input = input_tensor.tensor_impl.int_data + input_scale = input_tensor.tensor_impl.scale + + w_int4 = weight_tensor.tensor_impl.int_data + s_group = weight_tensor.tensor_impl.s_group + s_channel = weight_tensor.tensor_impl.s_channel + original_shape = weight_tensor.tensor_impl.original_shape + + # Folds batch dimension into the first dimension + input_2d = input.view(-1, input.shape[-1]) + input_scale = input_scale.view(1, -1) + + size_m = input_2d.shape[0] + size_n = s_channel.shape[1] + size_k = input_2d.shape[1] + workspace_qqq = marlin_qqq_workspace(original_shape[1]) + + out = marlin_qqq_gemm( + input_2d, + w_int4, + input_scale, + s_channel, + s_group, + workspace_qqq, + size_m, + size_n, + size_k, + ) + + # Unfold the batch dimension + out = out.reshape(input.shape[:-1] + (s_channel.shape[1],)) + + if bias is not None: + out += bias.to(out.dtype) + return out + + +to_marlinqqq_quantized_intx = MarlinQQQTensor.from_hp_to_intx From 865583b774166289185ffcd02f4f701c79389357 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Mon, 10 Nov 2025 08:21:38 -0500 Subject: [PATCH 
08/22] Enable `PerRow(axis)` to support axes other than `-1` (#3303) * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] --- .../workflows/float8/test_float8_tensor.py | 65 +++++++++++++++++++ test/quantization/test_quant_primitives.py | 25 +++++++ torchao/quantization/granularity.py | 23 ++++--- .../workflows/float8/float8_tensor.py | 4 +- torchao/quantization/utils.py | 6 +- torchao/testing/utils.py | 12 ++-- 6 files changed, 120 insertions(+), 15 deletions(-) diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py index 1b91875359..4bc106a60f 100644 --- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py +++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py @@ -15,6 +15,7 @@ from torch.testing._internal import common_utils from torch.testing._internal.common_utils import run_tests +from torchao.core.config import config_from_dict, config_to_dict from torchao.quantization import ( Float8DynamicActivationFloat8WeightConfig, Float8Tensor, @@ -634,6 +635,44 @@ def forward(self, x): sqnr = compute_error(original, quantized) self.assertTrue(sqnr > 20) + @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+") + @unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai") + def test_bmm_weight_in_bkn_layout(self): + # Tests rowwise quantization of a 3d weight stored with shape (B, K, N) + # and contigous with that shape. Since the `K` dimension is not last, we + # need to specify granularity with `PerRow(1)`. 
+ + # only support per row quantization + granularity = [PerRow(), PerRow(1)] + config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity) + + class Model(torch.nn.Module): + def __init__(self, weight): + super().__init__() + self.weight = weight + + def forward(self, x): + return torch.bmm(x, self.weight) + + dtype = torch.bfloat16 + device = "cuda" + + B, M, K, N = 10, 32, 128, 256 + + input = torch.randn(B, M, K, dtype=dtype, device=device) + weight = torch.randn(B, K, N, dtype=dtype, device=device) + m = Model(weight).eval() + original = m(input) + quantize_(m, config, filter_fn=lambda x, fqn: True) + + assert m.weight.scale.shape == (B, 1, N), ( + f"unexpected scale shape {m.weight.scale.shape}" + ) + + quantized = m(input) + sqnr = compute_error(original, quantized) + self.assertTrue(sqnr > 20) + @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) @common_utils.parametrize( "sizes", @@ -1007,6 +1046,32 @@ def test_transpose(self): self.assertEqual(x_fp8.block_size, (1, 512), atol=0, rtol=0) self.assertEqual(x_fp8_t.block_size, (512, 1), atol=0, rtol=0) + def test_per_row_config_before_dim(self): + """ + Test that loading a serialized config of `PerRow` before the `dim` + argument was introduced works properly + """ + + # create a config with PerRow granularity + config = Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + ) + + # serialize it + config_ser = config_to_dict(config) + + # manually modify the serialized config to match v1 + # reference: https://gist.github.com/vkuzo/d347c4f8b8121819483d2d31e79f7335 + del config_ser["_data"]["granularity"][0]["_data"]["dim"] + del config_ser["_data"]["granularity"][1]["_data"]["dim"] + assert len(config_ser["_data"]["granularity"][0]["_data"]) == 0 + assert len(config_ser["_data"]["granularity"][1]["_data"]) == 0 + + # load the modified version, verify that granularity is as expected + config_deser = config_from_dict(config_ser) + assert config_deser.granularity[0].dim == -1 + assert config_deser.granularity[1].dim == -1 + common_utils.instantiate_parametrized_tests(TestFloat8Tensor) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 5f7895b4ea..cc6b7fff91 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -10,6 +10,7 @@ import torch +from torchao.quantization.granularity import PerRow from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, @@ -27,6 +28,7 @@ # TODO: remove test for utils? from torchao.quantization.utils import ( _quantize_activation_per_token_absmax, + get_block_size, get_group_qparams_symmetric, groupwise_affine_dequantize_tensor_from_qparams, groupwise_affine_quantize_tensor_from_qparams, @@ -844,6 +846,29 @@ def test_float8_blockwise_scaling(self): torch.testing.assert_close(scale, ref_scale, atol=0, rtol=0) torch.testing.assert_close(data.float(), ref_data.float(), atol=0, rtol=0) + def test_float8_rowwise_scaling_3d_weight_axis_1(self): + """ + Test scaling a weight with shape (B, K, N) and row-major memory layout + across the K dimension. 
+ """ + + B, K, N = 8, 16, 32 + hp_tensor = torch.randn(B, K, N, dtype=torch.float) + + granularity = PerRow(1) + block_size = get_block_size(hp_tensor.shape, granularity) + scale = _choose_scale_float8( + hp_tensor, + float8_dtype=torch.float8_e4m3fn, + block_size=block_size, + hp_value_lb=None, + hp_value_ub=None, + ) + data = _quantize_affine_float8(hp_tensor, scale, torch.float8_e4m3fn) + + assert scale.shape == (B, 1, N) + assert data.shape == (B, K, N) + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/granularity.py b/torchao/quantization/granularity.py index d83032d7be..97d9c07b6f 100644 --- a/torchao/quantization/granularity.py +++ b/torchao/quantization/granularity.py @@ -39,12 +39,14 @@ class PerAxis(Granularity): This granularity type calculates different quantization parameters along a specified axis of the tensor. - For example if the input tensor is shape [8, 16] and axis=0, then - the quantization parameters are calculated for each row of the tensor. - Giving a total of 8 quantization parameters. + Examples: + * input_tensor shape [A, B], axis 0 -> scale_shape [A, 1] + * input_tensor shape [A, B], axis 1 -> scale_shape [1, B] + * input_tensor shape [A, B, C], axis 1 -> scale_shape [1, B, 1] Attributes: - axis (int): The axis along which reduction is performed. + axis (int): The axis which is kept, reduction is performed across all + the other axes """ axis: int @@ -76,12 +78,17 @@ class PerRow(Granularity): """ Represents row-wise granularity in quantization. - This is a special case of per-axis quantization and is unique to Float8 matmuls - where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight - is quantized with a block_size of (1, weight.shape[1]). + Examples: + * input_tensor shape [A, B], dim 0 -> scale_shape [1, B] + * input_tensor shape [A, B], dim 1 -> scale_shape [A, 1] + * input_tensor shape [A, B], dim -1 -> scale_shape [A, 1] + * input_tensor shape [A, B, C], dim 1 -> scale_shape [A, 1, C] + + Attributes: + dim (int): The dim which is reduced across, all other dims are kept """ - pass + dim: int = -1 @dataclass(frozen=True) diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py index a9c7af34b3..abb9ddc1f9 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py @@ -179,6 +179,8 @@ def from_hp( and _is_fbgemm_gpu_genai_available() and is_sm_at_least_90() and isinstance(granularity, PerRow) + # fbgemm path only supports quantizing along the last dim + and granularity.dim in (-1, len(hp_tensor.shape) - 1) and float8_dtype == torch.float8_e4m3fn and hp_value_lb is None ): @@ -475,7 +477,7 @@ def _(func, types, args, kwargs): res = torch.ops.fbgemm.f8f8bf16_rowwise_batched( a_data, - b_data.transpose(-2, -1), + b_data.transpose(-2, -1).contiguous(), a_scale, b_scale.transpose(-2, -1), b_scale, diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index db9a5149c3..1a0375f3d2 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -723,8 +723,12 @@ def get_block_size( f"Not all shapes in input shape {input_shape} are divisible by block size {block_size}" ) return block_size - elif isinstance(granularity, (PerRow, PerToken)): + elif isinstance(granularity, PerToken): return (1,) * (len(input_shape) - 1) + (input_shape[-1],) + elif isinstance(granularity, PerRow): + block_size = 
[1] * len(input_shape) + block_size[granularity.dim] = input_shape[granularity.dim] + return tuple(block_size) elif isinstance(granularity, PerGroup): assert input_shape[-1] % granularity.group_size == 0, ( f"Last dimension of input {input_shape[-1]} is not divisible by group size {granularity.group_size}" diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index a1dc40fdd3..10315d45f5 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -444,7 +444,9 @@ def _test_slice_and_copy_similar_to_vllm(self, config: AOBaseConfig): dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16) # making the weight different dummy_l.weight = torch.nn.Parameter( - dummy_l.weight + 2 * torch.randn(1024, 1024, device=device, dtype=dtype), + dummy_l.weight + + 1.0 + + 2 * torch.randn(1024, 1024, device=device, dtype=dtype), requires_grad=False, ) quantize_(dummy_l, config) @@ -456,15 +458,15 @@ def _test_slice_and_copy_similar_to_vllm(self, config: AOBaseConfig): param = l.weight param_data = param.data param_data = param_data.narrow(output_dim, start_idx, shard_size) - orig_value = param_data.qdata[0][0] + orig_values = param_data.qdata[0] loaded_weight = dummy_l.weight loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) - # making sure param.data.qdata[0][0] is not the same as loaded_weight.qdata[0][0] - assert not torch.equal(orig_value, loaded_weight.qdata[0][0]) + # making sure param.data.qdata[0] is not the same as loaded_weight.qdata[0] + assert not torch.equal(orig_values, loaded_weight.qdata[0]) param_data.copy_(loaded_weight) # making sure param.data is updated to loaded_weight - assert torch.equal(param_data.qdata[0][0], loaded_weight.qdata[0][0]) + assert torch.equal(param_data.qdata[0], loaded_weight.qdata[0]) if hasattr(param_data, "scale"): assert torch.equal(param_data.scale, loaded_weight.scale) if hasattr(param_data, "zero_point"): From 2c109431bffa1d00315ccfdcb967804dcca29abe Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Mon, 10 Nov 2025 11:47:34 -0500 Subject: [PATCH 09/22] Remove old TORCH_VERSION variables (#3146) * Remove config functions like `int4_weight_only` **Summary:** As a follow-up to https://github.com/pytorch/ao/pull/2994, this commit removes all quantization functions that were used as configs. These functions were deprecated in 0.14.0 and will be removed in the next release, 0.15.0. **Test Plan:** CI [ghstack-poisoned] * Remove old TORCH_VERSION variables **Summary:** As a follow-up to https://github.com/pytorch/ao/pull/2719, which deprecated these variables in 0.13.0, we remove them now in the next release 0.15.0. **Test Plan:** CI [ghstack-poisoned] * Update base for Update on "Remove old TORCH_VERSION variables" **Summary:** As a follow-up to https://github.com/pytorch/ao/pull/2719, which deprecated these variables in 0.13.0, we remove them now in the next release 0.15.0. **Test Plan:** CI [ghstack-poisoned] * Update base for Update on "Remove old TORCH_VERSION variables" **Summary:** As a follow-up to https://github.com/pytorch/ao/pull/2719, which deprecated these variables in 0.13.0, we remove them now in the next release 0.15.0. 
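For reference, a minimal migration sketch for downstream code (not part of this patch): it assumes a caller that previously branched on one of the removed flags, and it uses the `torch_version_at_least` helper that remains in `torchao.utils` (its signature is visible in the hunk below).

```python
# Hedged sketch: replace the removed module-level TORCH_VERSION_* flags with a
# call to torch_version_at_least, which takes a version string and compares it
# against torch.__version__.
from torchao.utils import torch_version_at_least

# Previously (removed in 0.15.0):
#     from torchao.utils import TORCH_VERSION_AT_LEAST_2_6
#     if TORCH_VERSION_AT_LEAST_2_6:
#         use_new_path = True
if torch_version_at_least("2.6.0"):
    use_new_path = True  # hypothetical flag, used only for illustration
else:
    use_new_path = False
```
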
**Test Plan:** CI [ghstack-poisoned] --- test/test_utils.py | 50 ----------------------------------- torchao/utils.py | 66 ---------------------------------------------- 2 files changed, 116 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index b46d600053..0e77388f13 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. import unittest -import warnings from unittest.mock import patch import torch @@ -37,55 +36,6 @@ def test_torch_version_at_least(self): f"Failed for torch.__version__={torch_version}, comparing with {compare_version}", ) - def test_torch_version_deprecation(self): - """ - Test that TORCH_VERSION_AT_LEAST* and TORCH_VERSION_AFTER* - trigger deprecation warnings on use, not on import. - """ - # Reset deprecation warning state, otherwise we won't log warnings here - warnings.resetwarnings() - - # Importing and referencing should not trigger deprecation warning - with warnings.catch_warnings(record=True) as _warnings: - from torchao.utils import ( - TORCH_VERSION_AFTER_2_2, - TORCH_VERSION_AFTER_2_3, - TORCH_VERSION_AFTER_2_4, - TORCH_VERSION_AFTER_2_5, - TORCH_VERSION_AT_LEAST_2_2, - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, - TORCH_VERSION_AT_LEAST_2_7, - TORCH_VERSION_AT_LEAST_2_8, - ) - - deprecated_api_to_name = [ - (TORCH_VERSION_AT_LEAST_2_8, "TORCH_VERSION_AT_LEAST_2_8"), - (TORCH_VERSION_AT_LEAST_2_7, "TORCH_VERSION_AT_LEAST_2_7"), - (TORCH_VERSION_AT_LEAST_2_6, "TORCH_VERSION_AT_LEAST_2_6"), - (TORCH_VERSION_AT_LEAST_2_5, "TORCH_VERSION_AT_LEAST_2_5"), - (TORCH_VERSION_AT_LEAST_2_4, "TORCH_VERSION_AT_LEAST_2_4"), - (TORCH_VERSION_AT_LEAST_2_3, "TORCH_VERSION_AT_LEAST_2_3"), - (TORCH_VERSION_AT_LEAST_2_2, "TORCH_VERSION_AT_LEAST_2_2"), - (TORCH_VERSION_AFTER_2_5, "TORCH_VERSION_AFTER_2_5"), - (TORCH_VERSION_AFTER_2_4, "TORCH_VERSION_AFTER_2_4"), - (TORCH_VERSION_AFTER_2_3, "TORCH_VERSION_AFTER_2_3"), - (TORCH_VERSION_AFTER_2_2, "TORCH_VERSION_AFTER_2_2"), - ] - self.assertEqual(len(_warnings), 0) - - # Accessing the boolean value should trigger deprecation warning - with warnings.catch_warnings(record=True) as _warnings: - for api, name in deprecated_api_to_name: - num_warnings_before = len(_warnings) - if api: - pass - regex = f"{name} is deprecated and will be removed" - self.assertEqual(len(_warnings), num_warnings_before + 1) - self.assertIn(regex, str(_warnings[-1].message)) - class TestTorchAOBaseTensor(unittest.TestCase): def test_print_arg_types(self): diff --git a/torchao/utils.py b/torchao/utils.py index 26191e2482..e123dfe891 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -35,17 +35,6 @@ "is_sm_at_least_100", "is_package_at_least", "DummyModule", - # Deprecated - "TORCH_VERSION_AT_LEAST_2_2", - "TORCH_VERSION_AT_LEAST_2_3", - "TORCH_VERSION_AT_LEAST_2_4", - "TORCH_VERSION_AT_LEAST_2_5", - "TORCH_VERSION_AT_LEAST_2_6", - "TORCH_VERSION_AT_LEAST_2_7", - "TORCH_VERSION_AFTER_2_2", - "TORCH_VERSION_AFTER_2_3", - "TORCH_VERSION_AFTER_2_4", - "TORCH_VERSION_AFTER_2_5", ] @@ -379,61 +368,6 @@ def torch_version_at_least(min_version): return parse_version(torch.__version__) >= parse_version(min_version) -def _deprecated_torch_version_at_least(version_str: str) -> str: - """ - Wrapper for existing TORCH_VERSION_AT_LEAST* variables that will log - a deprecation warning if the variable is used. 
- """ - version_str_var_name = "_".join(version_str.split(".")[:2]) - deprecation_msg = f"TORCH_VERSION_AT_LEAST_{version_str_var_name} is deprecated and will be removed in torchao 0.14.0" - return _BoolDeprecationWrapper( - torch_version_at_least(version_str), - deprecation_msg, - ) - - -def _deprecated_torch_version_after(version_str: str) -> str: - """ - Wrapper for existing TORCH_VERSION_AFTER* variables that will log - a deprecation warning if the variable is used. - """ - bool_value = is_fbcode() or version("torch") >= version_str - version_str_var_name = "_".join(version_str.split(".")[:2]) - deprecation_msg = f"TORCH_VERSION_AFTER_{version_str_var_name} is deprecated and will be removed in torchao 0.14.0" - return _BoolDeprecationWrapper(bool_value, deprecation_msg) - - -class _BoolDeprecationWrapper: - """ - A deprecation wrapper that logs a warning when the given bool value is accessed. - """ - - def __init__(self, bool_value: bool, msg: str): - self.bool_value = bool_value - self.msg = msg - - def __bool__(self): - warnings.warn(self.msg) - return self.bool_value - - def __eq__(self, other): - return bool(self) == bool(other) - - -# Deprecated, use `torch_version_at_least` directly instead -TORCH_VERSION_AT_LEAST_2_8 = _deprecated_torch_version_at_least("2.8.0") -TORCH_VERSION_AT_LEAST_2_7 = _deprecated_torch_version_at_least("2.7.0") -TORCH_VERSION_AT_LEAST_2_6 = _deprecated_torch_version_at_least("2.6.0") -TORCH_VERSION_AT_LEAST_2_5 = _deprecated_torch_version_at_least("2.5.0") -TORCH_VERSION_AT_LEAST_2_4 = _deprecated_torch_version_at_least("2.4.0") -TORCH_VERSION_AT_LEAST_2_3 = _deprecated_torch_version_at_least("2.3.0") -TORCH_VERSION_AT_LEAST_2_2 = _deprecated_torch_version_at_least("2.2.0") -TORCH_VERSION_AFTER_2_5 = _deprecated_torch_version_after("2.5.0.dev") -TORCH_VERSION_AFTER_2_4 = _deprecated_torch_version_after("2.4.0.dev") -TORCH_VERSION_AFTER_2_3 = _deprecated_torch_version_after("2.3.0.dev") -TORCH_VERSION_AFTER_2_2 = _deprecated_torch_version_after("2.2.0.dev") - - class _ConfigDeprecationWrapper: """ A deprecation wrapper that directs users from a deprecated "config function" From 36e8d0b7dc906b2799a20b4ed7d64d3f0f0bd95c Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 10 Nov 2025 15:34:28 -0800 Subject: [PATCH 10/22] Add per tensor fp8 conv2d support (#3315) Summary: Add fp8 conv2d support, using the same conv3d kernels, by setting the D dimension to 1. 1. unsqueeze both input and weight in dim 2 ( the D dimension) 2. call fp8 conv3d op from fbgemm `torch.ops.fbgemm.f8f8bf16_conv` 3. 
assert D dimension shape to be 1 and call sequeeze at dim 2: res.squeeze(2) to remove the D dimension Test Plan: python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_unsqueeze_conv2d_weight python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_fp8_conv_variants --- .../workflows/float8/test_float8_tensor.py | 153 ++++++++++++++---- torchao/quantization/quant_api.py | 14 +- .../workflows/float8/float8_tensor.py | 97 ++++++++++- 3 files changed, 218 insertions(+), 46 deletions(-) diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py index 4bc106a60f..df11b71e66 100644 --- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py +++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py @@ -87,6 +87,8 @@ def __init__( ) if dim == 3: self.conv = self.conv.to(memory_format=torch.channels_last_3d) + elif dim == 2: + self.conv = self.conv.to(memory_format=torch.channels_last) def forward(self, x): return self.conv(x) @@ -337,12 +339,14 @@ def _test_fp8_matmul_model( @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) @common_utils.parametrize("compile", [True, False]) @common_utils.parametrize("inference_mode", [True, False]) - # only test for 3D conv for now - # Inputs are (N, C_in, C_out, D, H, W) + # test for 2D/3D conv + # Inputs are (N, C_in, C_out, (D, H, W) or + # (N, C_in, C_out, (H, W) @common_utils.parametrize( "sizes", [ - (4, 16, 64, 32, 32, 32), + (4, 16, 64, (32, 32, 32)), + (4, 16, 64, (32, 32)), ], ) def test_fp8_conv_variants( @@ -350,20 +354,28 @@ def test_fp8_conv_variants( dtype: torch.dtype, compile: bool, inference_mode: bool, - kernel_preference: KernelPreference, sizes: Tuple, ): + torch.compiler.reset() granularity = PerTensor() kernel_preference = KernelPreference.AUTO - N, C_in, C_out, D, H, W = sizes - dim = 3 + + N, C_in, C_out, spatial_dims = sizes + dim = len(spatial_dims) + convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d} + assert dim in convs, f"Unsupported dim: {dim}" + conv_class = convs[dim] + kernel_size = 3 # Note: this is channel last memory format - input_tensor = torch.randn(N, C_in, D, H, W, dtype=dtype, device="cuda") - input_tensor = input_tensor.to(memory_format=torch.channels_last_3d) + input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device="cuda") + if dim == 3: + input_tensor = input_tensor.to(memory_format=torch.channels_last_3d) + else: + assert dim == 2 + input_tensor = input_tensor.to(memory_format=torch.channels_last) - # Create a linear layer with bfloat16 dtype model = ToyConvModel( dim, C_in, @@ -382,9 +394,9 @@ def test_fp8_conv_variants( kernel_preference=kernel_preference, ) - _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d) + _is_conv = lambda m, fqn: isinstance(m, conv_class) - quantize_(quantized_model, config, filter_fn=_is_conv3d) + quantize_(quantized_model, config, filter_fn=_is_conv) if compile: quantized_model = torch.compile(quantized_model, fullgraph=True) @@ -408,13 +420,16 @@ def test_fp8_conv_variants( "Requires fbgemm_gpu_genai to be installed", ) @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) - # only test for 3D conv for now - # Inputs are (N, C_in, C_out, D, H, W) + # test for 2D/3D conv + # Inputs are (N, C_in, C_out, (D, H, W) or + # (N, C_in, C_out, (H, W) @common_utils.parametrize( "sizes", [ - (4, 12, 64, 32, 32, 32), - (4, 16, 12, 32, 32, 32), + (4, 12, 64, 
(32, 32, 32)), + (4, 16, 12, (32, 32, 32)), + (4, 12, 64, (32, 32)), + (4, 16, 12, (32, 32)), ], ) def test_fp8_conv_skip_quant( @@ -427,14 +442,23 @@ def test_fp8_conv_skip_quant( """ granularity = PerTensor() kernel_preference = KernelPreference.AUTO - N, C_in, C_out, D, H, W = sizes - dim = 3 + + N, C_in, C_out, spatial_dims = sizes + + dim = len(spatial_dims) + convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d} + assert dim in convs, f"Unsupported dim: {dim}" + conv_class = convs[dim] + kernel_size = 3 # Note: this is channel last memory format - input_tensor = torch.randn(N, C_in, D, H, W, dtype=dtype, device="cuda") - input_tensor = input_tensor.to(memory_format=torch.channels_last_3d) - # Create a linear layer with bfloat16 dtype + input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device="cuda") + if dim == 3: + input_tensor = input_tensor.to(memory_format=torch.channels_last_3d) + else: + input_tensor = input_tensor.to(memory_format=torch.channels_last) + model = ToyConvModel( dim, C_in, @@ -453,9 +477,9 @@ def test_fp8_conv_skip_quant( kernel_preference=kernel_preference, ) - _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d) + _is_conv = lambda m, fqn: isinstance(m, conv_class) - quantize_(quantized_model, config, filter_fn=_is_conv3d) + quantize_(quantized_model, config, filter_fn=_is_conv) assert not isinstance(quantized_model.conv.weight, Float8Tensor) output_original = model(input_tensor) @@ -832,7 +856,6 @@ def test_index_select(self): ], ) def test_unsqueeze_operation(self, granularity, sizes): - """Test aten.unsqueeze.default operation on Float8Tensor""" config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity) dtype = torch.bfloat16 device = "cuda" @@ -845,7 +868,7 @@ def test_unsqueeze_operation(self, granularity, sizes): original_weight = linear.weight original_shape = original_weight.shape - # Test unsqueeze operation at dim=0 (only supported dimension) + # Test unsqueeze operation at dim=0 unsqueezed_weight = original_weight.unsqueeze(0) # Verify the unsqueezed tensor has correct shape @@ -887,22 +910,84 @@ def test_unsqueeze_operation(self, granularity, sizes): self.assertEqual(unsqueezed_dequant, expected_dequant) - @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) - def test_unsqueeze_error_cases(self, granularity): - """Test error cases for aten.unsqueeze.default operation""" + def test_unsqueeze_conv2d_weight(self): + granularity = PerTensor() config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity) dtype = torch.bfloat16 device = "cuda" + N, C_in, C_out, spatial_dims = 4, 16, 64, (32, 32) + dim = len(spatial_dims) + kernel_size = 3 - # Create a linear layer and quantize it - linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device) - quantize_(linear, config) + input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device=device) + input_tensor = input_tensor.to(memory_format=torch.channels_last) + model = ToyConvModel( + dim, + C_in, + C_out, + kernel_size, + bias=False, + padding=0, + dtype=dtype, + device=device, + ).eval() + + quantized_model = copy.deepcopy(model) + + config = Float8DynamicActivationFloat8WeightConfig( + granularity=granularity, + ) + + _is_conv = lambda m, fqn: isinstance(m, torch.nn.Conv2d) - weight = linear.weight + quantize_(quantized_model, config, filter_fn=_is_conv) - # Test that unsqueezing on unsupported dimensions raises an error - with self.assertRaisesRegex(AssertionError, "Only dim == 0 is supported"): - 
weight.unsqueeze(1) # dim=1 should not be supported + original_weight = quantized_model.conv.weight + original_shape = original_weight.shape + + # Test unsqueeze operation at dim=2 + unsqueezed_weight = original_weight.unsqueeze(2) + + # Verify the unsqueezed tensor has correct shape + original_shape_list = list(original_shape) + expected_shape = original_shape_list[:2] + [1] + original_shape_list[2:] + scale_shape_list = list(original_weight.scale.shape) + expected_scale_shape = scale_shape_list[:2] + [1] + scale_shape_list[2:] + + self.assertEqual(unsqueezed_weight.shape, torch.Size(expected_shape)) + # Verify qdata and scale shapes + expected_qdata_shape = expected_shape + + self.assertEqual( + unsqueezed_weight.qdata.shape, torch.Size(expected_qdata_shape) + ) + self.assertEqual( + unsqueezed_weight.scale.shape, torch.Size(expected_scale_shape) + ) + + # Verify block_size is correctly updated + expected_block_size = [] + for i in range(len(expected_shape)): + expected_block_size.append(expected_shape[i] // expected_scale_shape[i]) + + self.assertEqual(unsqueezed_weight.block_size, expected_block_size) + + # Test that metadata is preserved + self.assertEqual(unsqueezed_weight.mm_config, original_weight.mm_config) + self.assertEqual( + unsqueezed_weight.act_quant_kwargs, original_weight.act_quant_kwargs + ) + self.assertEqual( + unsqueezed_weight.kernel_preference, original_weight.kernel_preference + ) + self.assertEqual(unsqueezed_weight.dtype, original_weight.dtype) + + # Test numerical correctness + original_dequant = original_weight.dequantize() + unsqueezed_dequant = unsqueezed_weight.dequantize() + expected_dequant = original_dequant.unsqueeze(2) + + self.assertEqual(unsqueezed_dequant, expected_dequant) @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) @common_utils.parametrize("slice_dim", [0, 1, 2]) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index e3a75bbb3e..09c2edcd9f 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1816,13 +1816,19 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): _check_hardware_support(granularity) activation_granularity, weight_granularity = granularity - if weight.dim() == 5: - # weights for conv3d + # Note: right now we assume it's weights of conv2d and conv3d purely based + # on the dimension of weight, currently there is no conflict with linear 2d + # and moe weights 3d + # if we need to support conv1d, which also has 3d weight, we may have to + # pass around the module as well to distinguish between conv1d and 3d moe weight + if weight.dim() in [4, 5]: + # weights for conv2d or 3d assert isinstance(activation_granularity, PerTensor) and isinstance( weight_granularity, PerTensor - ), "5D tensor only supports per tensor activation and weight quantization" + ), "4D/5D tensor only supports per tensor activation and weight quantization" - # weight dim: (C_out, C_in, K1, K2, K3) + # conv3d weight dim: (C_out, C_in, K1, K2, K3) + # conv2d weight dim: (C_out, C_in, K1, K2) # skip quantization when either C_out or C_in # is not a multiple of 16 if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0: diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py index abb9ddc1f9..733d7a17a5 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py @@ -539,6 
+539,7 @@ def _quantize_and_scaled_conv3d( # move C_in to last dim # after permute: (C_out, K1, K2, K3, C_in) + weight_qdata = weight_tensor.qdata.permute([0, 2, 3, 4, 1]) assert act_qdata.is_contiguous() and weight_qdata.is_contiguous(), ( @@ -574,10 +575,71 @@ def _(func, types, args, kwargs): groups, ) = args assert not transposed, "transposed conv is not supported currently" - assert tuple(output_padding) == (0, 0, 0), ( - f"Only (0, 0, 0) is supported for `output_padding`, got: f{output_padding}" - ) + dim = len(output_padding) + assert dim in [2, 3], "Only 2d or 3d convs are supported" assert groups == 1, f"Only 1 is supported for `groups`, got: {groups}" + + if dim == 2: + assert input_tensor.is_contiguous( + memory_format=torch.channels_last + ) and weight_tensor.qdata.is_contiguous(memory_format=torch.channels_last), ( + "Please make sure both activation and weights are in the `channels_last` memory_format" + ) + # (N, C, H, W) --> (N, C, 1, H, W) + input_tensor = input_tensor.unsqueeze(2) + weight_tensor = weight_tensor.unsqueeze(2) + assert tuple(output_padding) == (0, 0), ( + f"Only (0, 0) is supported for `output_padding`, got: f{output_padding}" + ) + padding = [0, *padding] + stride = [1, *stride] + dilation = [1, *dilation] + res = _quantize_and_scaled_conv3d( + input_tensor, + weight_tensor, + bias, + stride, + padding, + dilation, + ) + assert res.shape[2] == 1 + res = res.squeeze(2) + return res + else: + assert input_tensor.is_contiguous( + memory_format=torch.channels_last_3d + ) and weight_tensor.qdata.is_contiguous(memory_format=torch.channels_last_3d), ( + "Please make sure both activation and weights are in the `channels_last_3d` memory_format" + ) + assert tuple(output_padding) == (0, 0, 0), ( + f"Only (0, 0, 0) is supported for `output_padding`, got: f{output_padding}" + ) + return _quantize_and_scaled_conv3d( + input_tensor, + weight_tensor, + bias, + stride, + padding, + dilation, + ) + + +@implements(aten.conv3d.default) +def _(func, types, args, kwargs): + ( + input_tensor, + weight_tensor, + bias, + stride, + padding, + dilation, + groups, + ) = fill_defaults(args, 7, [None, [1, 1, 1], [0, 0, 0], [1, 1, 1], 1]) + assert input_tensor.is_contiguous( + memory_format=torch.channels_last_3d + ) and weight_tensor.qdata.is_contiguous(memory_format=torch.channels_last_3d), ( + "Please make sure both activation and weights are in the `channels_last_3d` memory_format" + ) return _quantize_and_scaled_conv3d( input_tensor, weight_tensor, @@ -588,7 +650,7 @@ def _(func, types, args, kwargs): ) -@implements(aten.conv3d.default) +@implements(aten.conv2d.default) def _(func, types, args, kwargs): ( input_tensor, @@ -598,9 +660,26 @@ def _(func, types, args, kwargs): padding, dilation, groups, - ) = fill_defaults(args, 7, [None, [1, 1, 1], [0, 0, 0], [1, 1, 1], 1]) - assert groups == 1, f"Only 1 is supported for `groups`, got: {groups}" - return _quantize_and_scaled_conv3d( + ) = fill_defaults(args, 7, [None, [1, 1], [0, 0], [1, 1], 1]) + # (N, C, H, W) --> (N, C, 1, H, W) + # memory_format of both tensors should be torch.channels_last + # and it should be preserved with unsqueeze(2) (becoming torch.channels_last_3d) + assert input_tensor.is_contiguous( + memory_format=torch.channels_last + ) and weight_tensor.qdata.is_contiguous(memory_format=torch.channels_last), ( + "Please make sure both activation and weights are in the `channels_last` memory_format" + ) + input_tensor = input_tensor.unsqueeze(2) + weight_tensor = weight_tensor.unsqueeze(2) + + assert 
input_tensor.is_contiguous( + memory_format=torch.channels_last_3d + ) and weight_tensor.qdata.is_contiguous(memory_format=torch.channels_last_3d) + + padding = [0, *padding] + stride = [1, *stride] + dilation = [1, *dilation] + res = _quantize_and_scaled_conv3d( input_tensor, weight_tensor, bias, @@ -608,6 +687,9 @@ def _(func, types, args, kwargs): padding, dilation, ) + assert res.shape[2] == 1 + res = res.squeeze(2) + return res @implements(aten.slice.Tensor) @@ -839,7 +921,6 @@ def _(func, types, args, kwargs): @implements(aten.unsqueeze.default) def _(func, types, args, kwargs): self, dim = args - assert dim == 0, f"Only dim == 0 is supported, got: {dim}" qdata = self.qdata.unsqueeze(dim=dim) scale = self.scale.unsqueeze(dim=dim) block_size = [] From bab6ce5ed17b8e18126a473f0d2fe7e90faea788 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Mon, 10 Nov 2025 16:16:16 -0800 Subject: [PATCH 11/22] Pin pytest==8.4.2 (#3321) Signed-off-by: Huy Do --- dev-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 600d5001cf..ef00257bb7 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,5 +1,5 @@ # Test utilities -pytest +pytest==8.4.2 unittest-xml-reporting parameterized packaging From 8bce9b1af0b4fae84314525a89c03f96a64d589e Mon Sep 17 00:00:00 2001 From: namgyu-youn Date: Wed, 12 Nov 2025 03:01:57 +0900 Subject: [PATCH 12/22] Update common used toy linear model (#3275) * build common used toy linear model Co-authored-by: Jerry Zhang * update model to use direct input * revert unit test skip --- test/sparsity/test_fast_sparse_training.py | 18 +----- torchao/testing/model_architectures.py | 68 ++++++++++++++++++++-- 2 files changed, 66 insertions(+), 20 deletions(-) diff --git a/test/sparsity/test_fast_sparse_training.py b/test/sparsity/test_fast_sparse_training.py index 424306f897..a9f57bb5a5 100644 --- a/test/sparsity/test_fast_sparse_training.py +++ b/test/sparsity/test_fast_sparse_training.py @@ -15,22 +15,10 @@ swap_linear_with_semi_sparse_linear, swap_semi_sparse_linear_with_linear, ) +from torchao.testing.model_architectures import ToyTwoLinearModel from torchao.utils import is_fbcode -class ToyModel(nn.Module): - def __init__(self): - super().__init__() - self.linear1 = nn.Linear(128, 256, bias=False) - self.linear2 = nn.Linear(256, 128, bias=False) - - def forward(self, x): - x = self.linear1(x) - x = torch.nn.functional.relu(x) - x = self.linear2(x) - return x - - class TestRuntimeSemiStructuredSparsity(TestCase): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(is_fbcode(), "broken in fbcode") @@ -41,7 +29,7 @@ def test_runtime_weight_sparsification(self): input = torch.rand((128, 128)).half().cuda() grad = torch.rand((128, 128)).half().cuda() - model = ToyModel().half().cuda() + model = ToyTwoLinearModel(128, 256, 128, device="cuda", dtype=torch.float16) model_c = copy.deepcopy(model) for name, mod in model.named_modules(): @@ -89,7 +77,7 @@ def test_runtime_weight_sparsification_compile(self): input = torch.rand((128, 128)).half().cuda() grad = torch.rand((128, 128)).half().cuda() - model = ToyModel().half().cuda() + model = ToyTwoLinearModel(128, 256, 128, device="cuda", dtype=torch.float16) model_c = copy.deepcopy(model) for name, mod in model.named_modules(): diff --git a/torchao/testing/model_architectures.py b/torchao/testing/model_architectures.py index 8f41a8464c..4100a3cd76 100644 --- a/torchao/testing/model_architectures.py +++ 
b/torchao/testing/model_architectures.py @@ -11,14 +11,72 @@ import torch.nn.functional as F +class ToySingleLinearModel(torch.nn.Module): + def __init__( + self, + input_dim, + output_dim, + dtype, + device, + has_bias=False, + ): + super().__init__() + self.dtype = dtype + self.device = device + self.linear1 = torch.nn.Linear( + input_dim, output_dim, bias=has_bias, dtype=dtype, device=device + ) + + def example_inputs(self, batch_size=1): + return ( + torch.randn( + batch_size, + self.linear1.in_features, + dtype=self.dtype, + device=self.device, + ), + ) + + def forward(self, x): + x = self.linear1(x) + return x + + # TODO: Refactor torchao and tests to use these models -class ToyLinearModel(torch.nn.Module): - def __init__(self, k=64, n=32, dtype=torch.bfloat16): +class ToyTwoLinearModel(torch.nn.Module): + def __init__( + self, + input_dim, + hidden_dim, + output_dim, + dtype, + device, + has_bias=False, + ): super().__init__() - self.linear1 = torch.nn.Linear(k, n, bias=False).to(dtype) + self.dtype = dtype + self.device = device + self.linear1 = torch.nn.Linear( + input_dim, hidden_dim, bias=has_bias, dtype=dtype, device=device + ) + self.linear2 = torch.nn.Linear( + hidden_dim, output_dim, bias=has_bias, dtype=dtype, device=device + ) + + # Note: Tiny-GEMM kernel only uses BF16 inputs + def example_inputs(self, batch_size=1): + return ( + torch.randn( + batch_size, + self.linear1.in_features, + dtype=self.dtype, + device=self.device, + ), + ) def forward(self, x): x = self.linear1(x) + x = self.linear2(x) return x @@ -179,8 +237,8 @@ def create_model_and_input_data( m, k, n (int): dimensions of the model and input data """ if model_type == "linear": - model = ToyLinearModel(k, n, high_precision_dtype).to(device) - input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype) + model = ToySingleLinearModel(k, n, device=device, dtype=high_precision_dtype) + input_data = model.example_inputs(batch_size=m)[0] elif "ln_linear" in model_type: # Extract activation type from model_type string match = re.search(r"ln_linear_?(\w+)?", model_type) From 4a102c25ecb954a62383a9f7b193becc29553991 Mon Sep 17 00:00:00 2001 From: Andrey Talman Date: Tue, 11 Nov 2025 18:58:51 -0500 Subject: [PATCH 13/22] Use conda libgcc-ng 11.2 (#3327) * Remove devtoolset install * Update regression_test.yml * Update regression_test.yml --- .github/workflows/regression_test.yml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index 46928b30cf..149a7b07da 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -117,11 +117,8 @@ jobs: gpu-arch-version: ${{ matrix.gpu-arch-version }} submodules: recursive script: | - conda create -n venv python=3.10 -y + conda create -n venv python=3.10 libgcc-ng=11.2.0 libstdcxx-ng=11.2.0 -y conda activate venv - echo "::group::Install newer objcopy that supports --set-section-alignment" - dnf install -y gcc-toolset-10-binutils - export PATH=/opt/rh/gcc-toolset-10/root/usr/bin/:$PATH python -m pip install --upgrade pip pip install ${{ matrix.torch-spec }} sed -i '${{ matrix.dev-requirements-overrides }}' dev-requirements.txt From 5c3e652e1a0fe0483f6b761774cc74608050677b Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Tue, 11 Nov 2025 18:43:15 -0800 Subject: [PATCH 14/22] Move gemlite layout to prototype/dtypes (#3313) --- torchao/dtypes/affine_quantized_tensor_ops.py | 8 +- torchao/dtypes/uintx/gemlite_layout.py | 461 
+----------------- torchao/prototype/dtypes/__init__.py | 2 + torchao/prototype/dtypes/uintx/__init__.py | 2 + .../prototype/dtypes/uintx/gemlite_layout.py | 452 +++++++++++++++++ torchao/quantization/autoquant.py | 2 +- torchao/quantization/quant_api.py | 2 +- 7 files changed, 480 insertions(+), 449 deletions(-) create mode 100644 torchao/prototype/dtypes/uintx/gemlite_layout.py diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 21f13729dd..6c7216ab12 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -25,10 +25,6 @@ _linear_f16_bf16_act_floatx_weight_check, _linear_f16_bf16_act_floatx_weight_impl, ) -from torchao.dtypes.uintx.gemlite_layout import ( - _linear_fp_act_int4_weight_gemlite_check, - _linear_fp_act_int4_weight_gemlite_impl, -) from torchao.dtypes.uintx.int4_cpu_layout import ( _linear_fp_act_uint4_weight_cpu_check, _linear_fp_act_uint4_weight_cpu_impl, @@ -90,6 +86,10 @@ _linear_int8_act_int4_weight_cpu_check, _linear_int8_act_int4_weight_cpu_impl, ) +from torchao.prototype.dtypes.uintx.gemlite_layout import ( + _linear_fp_act_int4_weight_gemlite_check, + _linear_fp_act_int4_weight_gemlite_impl, +) from torchao.prototype.dtypes.uintx.marlin_qqq_tensor import ( _linear_int8_act_int4_weight_marlin_qqq_check, _linear_int8_act_int4_weight_marlin_qqq_impl, diff --git a/torchao/dtypes/uintx/gemlite_layout.py b/torchao/dtypes/uintx/gemlite_layout.py index 8a8f2309c9..c75c7fe1b1 100644 --- a/torchao/dtypes/uintx/gemlite_layout.py +++ b/torchao/dtypes/uintx/gemlite_layout.py @@ -3,450 +3,25 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from dataclasses import dataclass -from typing import Dict, Optional, Tuple -import torch -from torch.utils._python_dispatch import ( - is_traceable_wrapper_subclass, - return_and_correct_aliasing, -) +# Backward compatibility stub - imports from the new location +import warnings -from torchao.dtypes.affine_quantized_tensor import ( - AffineQuantizedTensor, - register_layout, +warnings.warn( + "Importing from torchao.dtypes.uintx.gemlite_layout is deprecated. " + "Please use 'from torchao.prototype.dtypes import GemlitePackedLayout' instead. " + "This import path will be removed in a future release of torchao. 
" + "See https://github.com/pytorch/ao/issues/2752 for more details.", + DeprecationWarning, + stacklevel=2, ) -from torchao.dtypes.uintx.tensor_core_tiled_layout import TensorCoreTiledAQTTensorImpl -from torchao.dtypes.utils import Layout -from torchao.utils import fill_defaults - -try: - import gemlite -except: - gemlite = None - -aten = torch.ops.aten - - -def _same_metadata( - self: "GemliteAQTTensorImpl", - src: "GemliteAQTTensorImpl", -) -> bool: - kwargs_match = len(self.gemlite_kwargs) == len(src.gemlite_kwargs) - for k, v in self.gemlite_kwargs.items(): - if k in [ - "in_features", - "out_features", - "packing_bitwidth", - "elements_per_sample", - ]: - kwargs_match = kwargs_match and (v == src.gemlite_kwargs[k]) - - return ( - isinstance(self, GemliteAQTTensorImpl) - and isinstance(src, GemliteAQTTensorImpl) - and self.shape == src.shape - and self.packed_weight.shape == src.packed_weight.shape - and self.scale.shape == src.scale.shape - and self.zero_point.shape == src.zero_point.shape - and kwargs_match - and type(self._layout) == type(src._layout) - ) - - -def get_gemlite_quant_kwargs(bit_width, group_size, dtype): - from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain - - kwargs = {} - if bit_width != 8: - kwargs["mapping_type"] = MappingType.ASYMMETRIC - kwargs["block_size"] = (1, group_size) - kwargs["target_dtype"] = torch.uint8 - kwargs["eps"] = 1e-6 - kwargs["quant_min"] = 0 - kwargs["quant_max"] = (2**bit_width) - 1 - kwargs["eps"] = 1e-6 - kwargs["zero_point_dtype"] = dtype - kwargs["zero_point_domain"] = ZeroPointDomain.FLOAT - elif bit_width == 8: - kwargs["mapping_type"] = MappingType.SYMMETRIC - kwargs["block_size"] = (1, group_size) - kwargs["target_dtype"] = torch.int8 - kwargs["quant_min"] = -128 - kwargs["quant_max"] = 127 - kwargs["eps"] = 1e-5 - kwargs["zero_point_dtype"] = None - kwargs["zero_point_domain"] = ZeroPointDomain.NONE - return kwargs - - -def get_gemlite_aqt_kwargs( - weight, - group_size=64, - bit_width=4, - packing_bitwidth=None, - mode="weight_only", - use_hqq=True, -): - if gemlite is None: - raise ImportError( - "Unable to import 'gemlite'. Please ensure it is installed correctly. 
You can install it with: pip install gemlite" - ) - - assert bit_width in [ - 4, - 8, - ], f"gemlite only works with bit_width 4,8 but got {bit_width}" - - assert weight.dtype in [torch.float16, torch.bfloat16], ( - f"gemlite only works with dtype torch.float16 or torch.bfloat16 but got {weight.dtype}" - ) - assert group_size in [32, 64, 128, 256, 512, 1024, None] - assert group_size is None or bit_width != 8, ( - "gemlite only works with group_size=None for bit_width=8" - ) - assert packing_bitwidth in [8, 16, 32, None], ( - f"Invalid packing bitwidth, got {packing_bitwidth}" - ) - - assert mode in ["weight_only", "dynamic"], ( - f"Invalid mode: should be either weight_only or dynamic, got {mode}" - ) - - out_features, in_features = weight.shape - group_size = in_features if group_size is None else group_size - - aqt_kwargs = get_gemlite_quant_kwargs(bit_width, group_size, weight.dtype) - aqt_kwargs["_layout"] = GemlitePackedLayout( - group_size=group_size, - bit_width=bit_width, - packing_bitwidth=packing_bitwidth, - mode=mode, - ) - aqt_kwargs["use_hqq"] = use_hqq - return aqt_kwargs - - -@dataclass(frozen=True) -class GemlitePackedLayout(Layout): - group_size: Optional[int] = 128 - bit_width: int = 4 - packing_bitwidth: Optional[int] = None - mode: Optional[str] = "weight_only" - - -@register_layout(GemlitePackedLayout) -class GemliteAQTTensorImpl(TensorCoreTiledAQTTensorImpl): - def __new__( - cls, - packed_weight: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - gemlite_kwargs: Dict, - _layout: Layout, - ): - kwargs = {} - kwargs["device"] = packed_weight.device - kwargs["layout"] = ( - kwargs.get("layout") - if kwargs.get("layout", False) - else packed_weight.layout - ) - kwargs["dtype"] = packed_weight.dtype - kwargs["requires_grad"] = False - shape = packed_weight.shape - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - packed_weight: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - gemlite_kwargs: Dict, - _layout: Layout, - ): - self.packed_weight = packed_weight - self.scale = scale - self.zero_point = zero_point - self.gemlite_kwargs = gemlite_kwargs - self._layout = _layout - - def __tensor_flatten__(self): - return ["packed_weight", "scale", "zero_point"], [ - self._layout, - self.gemlite_kwargs, - ] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - packed_weight, scale, zero_point = ( - tensor_data_dict["packed_weight"], - tensor_data_dict["scale"], - tensor_data_dict["zero_point"], - ) - _layout, gemlite_kwargs = tensor_attributes - return cls(packed_weight, scale, zero_point, gemlite_kwargs, _layout) - - @classmethod - def from_plain( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: Optional[torch.Tensor], - _layout: Layout, - ): - assert isinstance(_layout, GemlitePackedLayout), ( - f"GemliteAQTTensorImpl only works with GemliteLinearTriton but got {_layout}" - ) - device = int_data.device - if device.type != "cuda": - int_data = ( - int_data.cuda() - ) # We need int_data on cuda device because of Triton packing - - group_size, bit_width = _layout.group_size, _layout.bit_width - out_features, in_features = int_data.shape - packing_bitwidth = _layout.packing_bitwidth - mode = _layout.mode - - if bit_width == 8 and group_size == in_features: - processor = ( - gemlite.helper.A8W8_int8_dynamic - if mode == "dynamic" - else gemlite.helper.A16W8 - ) - gemlite_linear = 
processor(device=int_data.device).from_weights( - int_data, scales=scale, bias=None - ) - else: - processor = ( - gemlite.helper.A8Wn_dynamic - if mode == "dynamic" - else gemlite.helper.A16Wn - ) - gemlite_linear = processor( - device=int_data.device, packing_bitwidth=packing_bitwidth - ).from_weights( - int_data, scale, zero_point, bit_width, group_size, bias=None - ) - - meta_args = gemlite_linear.get_meta_args() - gemlite_kwargs = { - "in_features": in_features, - "out_features": out_features, - "packing_bitwidth": packing_bitwidth, - "data_contiguous": gemlite_linear.data_contiguous, - "elements_per_sample": gemlite_linear.elements_per_sample, - "W_group_mode": gemlite_linear.W_group_mode, - "meta_args": meta_args, - } - - packed_weight, scale, zero_point = gemlite_linear.get_tensor_args() - packed_weight = packed_weight.to(device) - if zero_point is None: - zero_point = torch.tensor( - [[]], device=packed_weight.device, dtype=torch.int32 - ) - - return cls(packed_weight, scale, zero_point, gemlite_kwargs, _layout) - - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - device = kwargs["device"] - return self.__class__( - self.packed_weight.to(device), - self.scale.to(device), - self.zero_point.to(device), - self.gemlite_kwargs, - self._layout, - ) - def _apply_fn_to_data(self, fn): - return self.__class__( - fn(self.packed_weight), - fn(self.scale), - fn(self.zero_point), - self.gemlite_kwargs, - self._layout, - ) - - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - device = self.packed_weight.device - int_data = ( - ( - gemlite.bitpack.unpack_over_rows( - self.packed_weight.cuda(), - W_nbits=self._layout.bit_width, - num_output_rows=self.gemlite_kwargs["in_features"], - dtype=torch.uint8, - ) - ) - .to(device) - .t() - ) - - # Preserve col-row major layout - if self.gemlite_kwargs["data_contiguous"]: - int_data = int_data.contiguous() - - # Handle FMA mode: W_q * s + z -> (W_q - z) * s - if self.gemlite_kwargs["W_group_mode"] == 4: - scale_min_val = 1e-8 - scale = self.scale.clone().float() - scale[torch.logical_and(scale >= 0, scale.abs() <= scale_min_val)] = ( - scale_min_val - ) - scale[ - torch.logical_and(scale < 0, scale.abs() <= scale_min_val) - ] = -scale_min_val - zero_point = (-self.zero_point.float() / scale).clamp_(-100, 100) - zero_point = zero_point.to(self.scale.dtype) - else: - zero_point = self.zero_point - - scale = self.scale.t().contiguous() - zero_point = zero_point.t().contiguous() - - return int_data, scale, zero_point - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - # we don't handle transpose operations and just ignore them. In practice the only - # reason a transpsoe should occur is because the functional linear - # op can decompose into e.g. 
transpose + addmm so since we want - # to use the gemlite matmul kernel, which expects teh weight to be passed in as is, - # we ignore the transpose - if func is aten.detach.default or func is aten.t.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - if func is aten.clone.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - - if func is aten.slice.Tensor: - self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) - assert step == 1, "Only step == 1 is supported in slicing right now" - - if dim in [0, 1]: - # data in self is transposed, meaning forward() performs x @ W_deq not x @ W_deq.T - dim = 1 - dim - packed_weight = self.packed_weight - scale = self.scale - zero_point = self.zero_point - - gemlite_kwargs = self.gemlite_kwargs.copy() - orig_shape = [ - gemlite_kwargs["in_features"], - gemlite_kwargs["out_features"], - ] - elements_per_sample = gemlite_kwargs["elements_per_sample"] - data_len = orig_shape[dim] - scale_len = scale.shape[dim] - ratio = data_len / scale_len - start_scale = int(start / ratio) - end_scale = int(end / ratio) - - # For packing only the K dimension. This should be flipped for N-dim packing. - div = elements_per_sample if dim == 0 else 1 - packed_weight = aten.slice.Tensor( - packed_weight, dim, start // div, end // div, step - ) - - # Update in_features/out_features - gemlite_kwargs["in_features"] = ( - packed_weight.shape[0] * elements_per_sample - ) - gemlite_kwargs["out_features"] = packed_weight.shape[1] - - scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step) - if zero_point is not None and zero_point.numel() > 0: - zero_point = aten.slice.Tensor( - zero_point, dim, start_scale, end_scale, step - ) - else: - zero_point = None - - sliced = GemliteAQTTensorImpl( - packed_weight, scale, zero_point, gemlite_kwargs, self._layout - ) - return return_and_correct_aliasing(func, args, kwargs, sliced) - - else: - raise NotImplementedError( - f"GemliteAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" - ) - - elif func is aten.copy_.default: - self = args[0] - src = args[1] - - # Handle zero_point = None with symmetric quant - if self.zero_point is None: - self.zero_point = torch.tensor( - [[]], device=self.packed_weight.device, dtype=torch.int32 - ) - - if src.zero_point is None: - src.zero_point = torch.tensor( - [[]], device=src.packed_weight.device, dtype=torch.int32 - ) - - if _same_metadata(self, src): - self_tensors = self.__tensor_flatten__()[0] - for tensor_name in self_tensors: - getattr(self, tensor_name).copy_(getattr(src, tensor_name)) - for key in self.gemlite_kwargs: - self.gemlite_kwargs[key] = src.gemlite_kwargs[key] - return - raise ValueError( - f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}" - ) - - raise NotImplementedError( - f"GemliteAQTTensorImpl dispatch: attempting to run {func}, this is not supported" - ) - - __torch_function__ = torch._C._disabled_torch_function_impl - - def get_layout(self) -> Layout: - return self._layout - - @property - def block_size(self): - return (1, self._layout.group_size) - - -def _linear_fp_act_int4_weight_gemlite_impl(input_tensor, weight_tensor, bias=None): - if hasattr(weight_tensor, "tensor_impl"): - weight_impl = weight_tensor.tensor_impl - else: - weight_impl = weight_tensor - - return gemlite.core.forward_functional( - x=input_tensor, - bias=bias, - tensor_args=( - 
weight_impl.packed_weight, - weight_impl.scale, - weight_impl.zero_point, - ), - meta_args=weight_impl.gemlite_kwargs["meta_args"], - ) - - -def _linear_fp_act_int4_weight_gemlite_check(input_tensor, weight_tensor, bias): - return ( - # input is native fp16 tensor - not is_traceable_wrapper_subclass(input_tensor) - # and input_tensor.dtype in [torch.float16, torch.bfloat16] - # weight is gemlite layout - and isinstance(weight_tensor, AffineQuantizedTensor) - and isinstance(weight_tensor._layout, GemlitePackedLayout) - ) +from torchao.prototype.dtypes.uintx.gemlite_layout import ( # noqa: F401 + GemliteAQTTensorImpl, # noqa: F401 + GemlitePackedLayout, # noqa: F401 + _linear_fp_act_int4_weight_gemlite_check, # noqa: F401 + _linear_fp_act_int4_weight_gemlite_impl, # noqa: F401 + _same_metadata, # noqa: F401 + get_gemlite_aqt_kwargs, # noqa: F401 + get_gemlite_quant_kwargs, # noqa: F401 +) diff --git a/torchao/prototype/dtypes/__init__.py b/torchao/prototype/dtypes/__init__.py index 294c7d0b15..7ad78dbed6 100644 --- a/torchao/prototype/dtypes/__init__.py +++ b/torchao/prototype/dtypes/__init__.py @@ -7,6 +7,7 @@ from .uintx import ( BlockSparseLayout, CutlassInt4PackedLayout, + GemlitePackedLayout, Int8DynamicActInt4WeightCPULayout, MarlinQQQLayout, MarlinQQQTensor, @@ -20,4 +21,5 @@ "MarlinQQQLayout", "MarlinQQQTensor", "to_marlinqqq_quantized_intx", + "GemlitePackedLayout", ] diff --git a/torchao/prototype/dtypes/uintx/__init__.py b/torchao/prototype/dtypes/uintx/__init__.py index cd333a90e9..56b1eed50a 100644 --- a/torchao/prototype/dtypes/uintx/__init__.py +++ b/torchao/prototype/dtypes/uintx/__init__.py @@ -7,6 +7,7 @@ from .block_sparse_layout import BlockSparseLayout from .cutlass_int4_packed_layout import CutlassInt4PackedLayout from .dyn_int8_act_int4_wei_cpu_layout import Int8DynamicActInt4WeightCPULayout +from .gemlite_layout import GemlitePackedLayout from .marlin_qqq_tensor import ( MarlinQQQLayout, MarlinQQQTensor, @@ -20,4 +21,5 @@ "MarlinQQQLayout", "MarlinQQQTensor", "to_marlinqqq_quantized_intx", + "GemlitePackedLayout", ] diff --git a/torchao/prototype/dtypes/uintx/gemlite_layout.py b/torchao/prototype/dtypes/uintx/gemlite_layout.py new file mode 100644 index 0000000000..8a8f2309c9 --- /dev/null +++ b/torchao/prototype/dtypes/uintx/gemlite_layout.py @@ -0,0 +1,452 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
+from dataclasses import dataclass +from typing import Dict, Optional, Tuple + +import torch +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.dtypes.uintx.tensor_core_tiled_layout import TensorCoreTiledAQTTensorImpl +from torchao.dtypes.utils import Layout +from torchao.utils import fill_defaults + +try: + import gemlite +except: + gemlite = None + +aten = torch.ops.aten + + +def _same_metadata( + self: "GemliteAQTTensorImpl", + src: "GemliteAQTTensorImpl", +) -> bool: + kwargs_match = len(self.gemlite_kwargs) == len(src.gemlite_kwargs) + for k, v in self.gemlite_kwargs.items(): + if k in [ + "in_features", + "out_features", + "packing_bitwidth", + "elements_per_sample", + ]: + kwargs_match = kwargs_match and (v == src.gemlite_kwargs[k]) + + return ( + isinstance(self, GemliteAQTTensorImpl) + and isinstance(src, GemliteAQTTensorImpl) + and self.shape == src.shape + and self.packed_weight.shape == src.packed_weight.shape + and self.scale.shape == src.scale.shape + and self.zero_point.shape == src.zero_point.shape + and kwargs_match + and type(self._layout) == type(src._layout) + ) + + +def get_gemlite_quant_kwargs(bit_width, group_size, dtype): + from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain + + kwargs = {} + if bit_width != 8: + kwargs["mapping_type"] = MappingType.ASYMMETRIC + kwargs["block_size"] = (1, group_size) + kwargs["target_dtype"] = torch.uint8 + kwargs["eps"] = 1e-6 + kwargs["quant_min"] = 0 + kwargs["quant_max"] = (2**bit_width) - 1 + kwargs["eps"] = 1e-6 + kwargs["zero_point_dtype"] = dtype + kwargs["zero_point_domain"] = ZeroPointDomain.FLOAT + elif bit_width == 8: + kwargs["mapping_type"] = MappingType.SYMMETRIC + kwargs["block_size"] = (1, group_size) + kwargs["target_dtype"] = torch.int8 + kwargs["quant_min"] = -128 + kwargs["quant_max"] = 127 + kwargs["eps"] = 1e-5 + kwargs["zero_point_dtype"] = None + kwargs["zero_point_domain"] = ZeroPointDomain.NONE + return kwargs + + +def get_gemlite_aqt_kwargs( + weight, + group_size=64, + bit_width=4, + packing_bitwidth=None, + mode="weight_only", + use_hqq=True, +): + if gemlite is None: + raise ImportError( + "Unable to import 'gemlite'. Please ensure it is installed correctly. 
You can install it with: pip install gemlite" + ) + + assert bit_width in [ + 4, + 8, + ], f"gemlite only works with bit_width 4,8 but got {bit_width}" + + assert weight.dtype in [torch.float16, torch.bfloat16], ( + f"gemlite only works with dtype torch.float16 or torch.bfloat16 but got {weight.dtype}" + ) + assert group_size in [32, 64, 128, 256, 512, 1024, None] + assert group_size is None or bit_width != 8, ( + "gemlite only works with group_size=None for bit_width=8" + ) + assert packing_bitwidth in [8, 16, 32, None], ( + f"Invalid packing bitwidth, got {packing_bitwidth}" + ) + + assert mode in ["weight_only", "dynamic"], ( + f"Invalid mode: should be either weight_only or dynamic, got {mode}" + ) + + out_features, in_features = weight.shape + group_size = in_features if group_size is None else group_size + + aqt_kwargs = get_gemlite_quant_kwargs(bit_width, group_size, weight.dtype) + aqt_kwargs["_layout"] = GemlitePackedLayout( + group_size=group_size, + bit_width=bit_width, + packing_bitwidth=packing_bitwidth, + mode=mode, + ) + aqt_kwargs["use_hqq"] = use_hqq + return aqt_kwargs + + +@dataclass(frozen=True) +class GemlitePackedLayout(Layout): + group_size: Optional[int] = 128 + bit_width: int = 4 + packing_bitwidth: Optional[int] = None + mode: Optional[str] = "weight_only" + + +@register_layout(GemlitePackedLayout) +class GemliteAQTTensorImpl(TensorCoreTiledAQTTensorImpl): + def __new__( + cls, + packed_weight: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + gemlite_kwargs: Dict, + _layout: Layout, + ): + kwargs = {} + kwargs["device"] = packed_weight.device + kwargs["layout"] = ( + kwargs.get("layout") + if kwargs.get("layout", False) + else packed_weight.layout + ) + kwargs["dtype"] = packed_weight.dtype + kwargs["requires_grad"] = False + shape = packed_weight.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + packed_weight: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + gemlite_kwargs: Dict, + _layout: Layout, + ): + self.packed_weight = packed_weight + self.scale = scale + self.zero_point = zero_point + self.gemlite_kwargs = gemlite_kwargs + self._layout = _layout + + def __tensor_flatten__(self): + return ["packed_weight", "scale", "zero_point"], [ + self._layout, + self.gemlite_kwargs, + ] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + packed_weight, scale, zero_point = ( + tensor_data_dict["packed_weight"], + tensor_data_dict["scale"], + tensor_data_dict["zero_point"], + ) + _layout, gemlite_kwargs = tensor_attributes + return cls(packed_weight, scale, zero_point, gemlite_kwargs, _layout) + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + _layout: Layout, + ): + assert isinstance(_layout, GemlitePackedLayout), ( + f"GemliteAQTTensorImpl only works with GemliteLinearTriton but got {_layout}" + ) + device = int_data.device + if device.type != "cuda": + int_data = ( + int_data.cuda() + ) # We need int_data on cuda device because of Triton packing + + group_size, bit_width = _layout.group_size, _layout.bit_width + out_features, in_features = int_data.shape + packing_bitwidth = _layout.packing_bitwidth + mode = _layout.mode + + if bit_width == 8 and group_size == in_features: + processor = ( + gemlite.helper.A8W8_int8_dynamic + if mode == "dynamic" + else gemlite.helper.A16W8 + ) + gemlite_linear = 
processor(device=int_data.device).from_weights( + int_data, scales=scale, bias=None + ) + else: + processor = ( + gemlite.helper.A8Wn_dynamic + if mode == "dynamic" + else gemlite.helper.A16Wn + ) + gemlite_linear = processor( + device=int_data.device, packing_bitwidth=packing_bitwidth + ).from_weights( + int_data, scale, zero_point, bit_width, group_size, bias=None + ) + + meta_args = gemlite_linear.get_meta_args() + gemlite_kwargs = { + "in_features": in_features, + "out_features": out_features, + "packing_bitwidth": packing_bitwidth, + "data_contiguous": gemlite_linear.data_contiguous, + "elements_per_sample": gemlite_linear.elements_per_sample, + "W_group_mode": gemlite_linear.W_group_mode, + "meta_args": meta_args, + } + + packed_weight, scale, zero_point = gemlite_linear.get_tensor_args() + packed_weight = packed_weight.to(device) + if zero_point is None: + zero_point = torch.tensor( + [[]], device=packed_weight.device, dtype=torch.int32 + ) + + return cls(packed_weight, scale, zero_point, gemlite_kwargs, _layout) + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + device = kwargs["device"] + return self.__class__( + self.packed_weight.to(device), + self.scale.to(device), + self.zero_point.to(device), + self.gemlite_kwargs, + self._layout, + ) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.packed_weight), + fn(self.scale), + fn(self.zero_point), + self.gemlite_kwargs, + self._layout, + ) + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + device = self.packed_weight.device + int_data = ( + ( + gemlite.bitpack.unpack_over_rows( + self.packed_weight.cuda(), + W_nbits=self._layout.bit_width, + num_output_rows=self.gemlite_kwargs["in_features"], + dtype=torch.uint8, + ) + ) + .to(device) + .t() + ) + + # Preserve col-row major layout + if self.gemlite_kwargs["data_contiguous"]: + int_data = int_data.contiguous() + + # Handle FMA mode: W_q * s + z -> (W_q - z) * s + if self.gemlite_kwargs["W_group_mode"] == 4: + scale_min_val = 1e-8 + scale = self.scale.clone().float() + scale[torch.logical_and(scale >= 0, scale.abs() <= scale_min_val)] = ( + scale_min_val + ) + scale[ + torch.logical_and(scale < 0, scale.abs() <= scale_min_val) + ] = -scale_min_val + zero_point = (-self.zero_point.float() / scale).clamp_(-100, 100) + zero_point = zero_point.to(self.scale.dtype) + else: + zero_point = self.zero_point + + scale = self.scale.t().contiguous() + zero_point = zero_point.t().contiguous() + + return int_data, scale, zero_point + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + # we don't handle transpose operations and just ignore them. In practice the only + # reason a transpsoe should occur is because the functional linear + # op can decompose into e.g. 
transpose + addmm so since we want + # to use the gemlite matmul kernel, which expects teh weight to be passed in as is, + # we ignore the transpose + if func is aten.detach.default or func is aten.t.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + if func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + if func is aten.slice.Tensor: + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + assert step == 1, "Only step == 1 is supported in slicing right now" + + if dim in [0, 1]: + # data in self is transposed, meaning forward() performs x @ W_deq not x @ W_deq.T + dim = 1 - dim + packed_weight = self.packed_weight + scale = self.scale + zero_point = self.zero_point + + gemlite_kwargs = self.gemlite_kwargs.copy() + orig_shape = [ + gemlite_kwargs["in_features"], + gemlite_kwargs["out_features"], + ] + elements_per_sample = gemlite_kwargs["elements_per_sample"] + data_len = orig_shape[dim] + scale_len = scale.shape[dim] + ratio = data_len / scale_len + start_scale = int(start / ratio) + end_scale = int(end / ratio) + + # For packing only the K dimension. This should be flipped for N-dim packing. + div = elements_per_sample if dim == 0 else 1 + packed_weight = aten.slice.Tensor( + packed_weight, dim, start // div, end // div, step + ) + + # Update in_features/out_features + gemlite_kwargs["in_features"] = ( + packed_weight.shape[0] * elements_per_sample + ) + gemlite_kwargs["out_features"] = packed_weight.shape[1] + + scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step) + if zero_point is not None and zero_point.numel() > 0: + zero_point = aten.slice.Tensor( + zero_point, dim, start_scale, end_scale, step + ) + else: + zero_point = None + + sliced = GemliteAQTTensorImpl( + packed_weight, scale, zero_point, gemlite_kwargs, self._layout + ) + return return_and_correct_aliasing(func, args, kwargs, sliced) + + else: + raise NotImplementedError( + f"GemliteAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" + ) + + elif func is aten.copy_.default: + self = args[0] + src = args[1] + + # Handle zero_point = None with symmetric quant + if self.zero_point is None: + self.zero_point = torch.tensor( + [[]], device=self.packed_weight.device, dtype=torch.int32 + ) + + if src.zero_point is None: + src.zero_point = torch.tensor( + [[]], device=src.packed_weight.device, dtype=torch.int32 + ) + + if _same_metadata(self, src): + self_tensors = self.__tensor_flatten__()[0] + for tensor_name in self_tensors: + getattr(self, tensor_name).copy_(getattr(src, tensor_name)) + for key in self.gemlite_kwargs: + self.gemlite_kwargs[key] = src.gemlite_kwargs[key] + return + raise ValueError( + f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}" + ) + + raise NotImplementedError( + f"GemliteAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + __torch_function__ = torch._C._disabled_torch_function_impl + + def get_layout(self) -> Layout: + return self._layout + + @property + def block_size(self): + return (1, self._layout.group_size) + + +def _linear_fp_act_int4_weight_gemlite_impl(input_tensor, weight_tensor, bias=None): + if hasattr(weight_tensor, "tensor_impl"): + weight_impl = weight_tensor.tensor_impl + else: + weight_impl = weight_tensor + + return gemlite.core.forward_functional( + x=input_tensor, + bias=bias, + tensor_args=( + 
weight_impl.packed_weight, + weight_impl.scale, + weight_impl.zero_point, + ), + meta_args=weight_impl.gemlite_kwargs["meta_args"], + ) + + +def _linear_fp_act_int4_weight_gemlite_check(input_tensor, weight_tensor, bias): + return ( + # input is native fp16 tensor + not is_traceable_wrapper_subclass(input_tensor) + # and input_tensor.dtype in [torch.float16, torch.bfloat16] + # weight is gemlite layout + and isinstance(weight_tensor, AffineQuantizedTensor) + and isinstance(weight_tensor._layout, GemlitePackedLayout) + ) diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index c72e18a923..884c96559a 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -724,7 +724,7 @@ class AQGemliteInt4G32WeightOnlyQuantizedLinearWeight( @classmethod def from_float(cls, weight): from torchao.dtypes import to_affine_quantized_intx - from torchao.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs + from torchao.prototype.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs if weight.dtype != torch.float16: weight = weight.to(torch.float16) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 09c2edcd9f..c29382b658 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1055,7 +1055,7 @@ def _gemlite_uintx_weight_only_transform( weight = module.weight - from torchao.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs + from torchao.prototype.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs use_hqq = True if bit_width == 4 else False new_weight = to_affine_quantized_intx( From 7213f817592ba6187fa446a66708b74211656f07 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Tue, 11 Nov 2025 21:30:23 -0800 Subject: [PATCH 15/22] Move uintx_layout to prototype/dtypes (#3316) --- docs/source/api_ref_dtypes.rst | 2 +- test/dtypes/test_uintx.py | 3 +- torchao/dtypes/__init__.py | 2 +- torchao/dtypes/uintx/uintx_layout.py | 260 ++---------------- torchao/prototype/autoround/core.py | 2 +- torchao/prototype/autoround/eval_autoround.py | 4 +- torchao/prototype/dtypes/__init__.py | 8 + torchao/prototype/dtypes/uintx/__init__.py | 10 + .../prototype/dtypes/uintx/uintx_layout.py | 251 +++++++++++++++++ .../codebook/codebook_quantized_tensor.py | 2 +- 10 files changed, 295 insertions(+), 249 deletions(-) create mode 100644 torchao/prototype/dtypes/uintx/uintx_layout.py diff --git a/docs/source/api_ref_dtypes.rst b/docs/source/api_ref_dtypes.rst index 58ad4ee8a4..3997b444b3 100644 --- a/docs/source/api_ref_dtypes.rst +++ b/docs/source/api_ref_dtypes.rst @@ -22,7 +22,6 @@ Layouts and Tensor Subclasses FloatxTensor FloatxTensorCoreLayout MarlinSparseLayout - UintxLayout Int4CPULayout CutlassSemiSparseLayout @@ -53,6 +52,7 @@ Prototype Int8DynamicActInt4WeightCPULayout MarlinQQQTensor MarlinQQQLayout + UintxLayout .. 
_NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring diff --git a/test/dtypes/test_uintx.py b/test/dtypes/test_uintx.py index 0878dfed4d..3172381a3a 100644 --- a/test/dtypes/test_uintx.py +++ b/test/dtypes/test_uintx.py @@ -9,7 +9,7 @@ import pytest import torch -from torchao.dtypes.uintx.uintx_layout import to_uintx +from torchao.prototype.dtypes.uintx.uintx_layout import to_uintx from torchao.quantization.quant_api import UIntXWeightOnlyConfig, quantize_ from torchao.quantization.quant_primitives import ( MappingType, @@ -183,6 +183,7 @@ def test_uintx_api_deprecation(): ("CutlassInt4PackedLayout", "torchao.dtypes.uintx.cutlass_int4_packed_layout"), ("BlockSparseLayout", "torchao.dtypes.uintx.block_sparse_layout"), ("MarlinQQQLayout", "torchao.dtypes.uintx.marlin_qqq_tensor"), + ("UintxLayout", "torchao.dtypes.uintx.uintx_layout"), ] for api_name, module_path in deprecated_apis: diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 4c83de7ddd..43c140908a 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -21,7 +21,6 @@ QDQLayout, SemiSparseLayout, TensorCoreTiledLayout, - UintxLayout, ) from .uintx.block_sparse_layout import BlockSparseLayout from .uintx.cutlass_int4_packed_layout import CutlassInt4PackedLayout @@ -31,6 +30,7 @@ MarlinQQQTensor, to_marlinqqq_quantized_intx, ) +from .uintx.uintx_layout import UintxLayout from .utils import ( Layout, PlainLayout, diff --git a/torchao/dtypes/uintx/uintx_layout.py b/torchao/dtypes/uintx/uintx_layout.py index 3180e9f2c9..dfd93249d6 100644 --- a/torchao/dtypes/uintx/uintx_layout.py +++ b/torchao/dtypes/uintx/uintx_layout.py @@ -3,250 +3,24 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from dataclasses import dataclass -from typing import List, Tuple -import torch -from torch.utils._python_dispatch import return_and_correct_aliasing +# Backward compatibility stub - imports from the new location +import warnings -from torchao.dtypes.affine_quantized_tensor import register_layout -from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl -from torchao.dtypes.utils import ( - Layout, +warnings.warn( + "Importing from torchao.dtypes.uintx.uintx_layout is deprecated. " + "Please use 'from torchao.prototype.dtypes import UintxLayout, UintxTensor' instead. " + "This import path will be removed in a future release of torchao. 
" + "See https://github.com/pytorch/ao/issues/2752 for more details.", + DeprecationWarning, + stacklevel=2, ) -from torchao.utils import TorchAOBaseTensor -from .bitpacking import pack, unpack - -aten = torch.ops.aten - -# Note: Uintx does not work for torch 2.3 and below -_DTYPE_TO_BIT_WIDTH = {} -_BIT_WIDTH_TO_DTYPE = {} - -_DTYPE_TO_BIT_WIDTH = { - torch.uint1: 1, - torch.uint2: 2, - torch.uint3: 3, - torch.uint4: 4, - torch.uint5: 5, - torch.uint6: 6, - torch.uint7: 7, -} - -_BIT_WIDTH_TO_DTYPE = {v: k for k, v in _DTYPE_TO_BIT_WIDTH.items()} - - -class UintxTensor(TorchAOBaseTensor): - """ - Splits int data into packed shards based on bit size - fields: - int4_shard (torch.Tensor): 4 bit packed shard - int2_shard (torch.Tensor): 2 bit packed shard - int1_shard (torch.Tensor): 1 bit packed shard - bit_width (int): number of bits for each element - pack_dim: (int) dimension to pack along - """ - - bits_to_shard = { - 1: ["int1_shard"], - 2: ["int2_shard"], - 3: ["int2_shard", "int1_shard"], - 4: ["int4_shard"], - 5: ["int4_shard", "int1_shard"], - 6: ["int4_shard", "int2_shard"], - 7: ["int4_shard", "int2_shard", "int1_shard"], - } - - def __new__( - cls, - shards: List[torch.Tensor], - packed_shape: List[int], - bit_width: int, - pack_dim: int = -1, - ): - kwargs = {"device": shards[0].device} - kwargs["device"] = shards[0].device - kwargs["layout"] = shards[0].layout - kwargs["requires_grad"] = False - kwargs["dtype"] = torch.uint8 - return torch.Tensor._make_wrapper_subclass(cls, packed_shape, **kwargs) - - def __init__( - self, - shards: List[torch.Tensor], - packed_shape: List[int], - bit_width: int, - pack_dim: int = -1, - ): - for i, attrib in enumerate(self.bits_to_shard[bit_width]): - setattr(self, attrib, shards[i]) - - self.packed_shape = packed_shape - self.bit_width = bit_width - self.pack_dim = pack_dim - - def get_shards(self): - return [getattr(self, i) for i in self.__class__.bits_to_shard[self.bit_width]] - - def __repr__(self): - return f"Int{self.bit_width}Tensor(shape = {self.packed_shape}, data = {unpack(self.get_shards(), self.bit_width, dim=self.pack_dim)})" - - def __tensor_flatten__(self): - return self.__class__.bits_to_shard[self.bit_width], [ - self.packed_shape, - self.bit_width, - self.pack_dim, - ] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - shards = list(tensor_data_dict.values()) - packed_shape, bit_width, pack_dim = tensor_attributes - return cls(shards, packed_shape, bit_width, pack_dim) - - def get_plain(self): - return unpack(self.get_shards(), self.bit_width, dim=self.pack_dim) - - # temporary until kernels on packed tensors are created - def apply_transformation(self, fn): - og = self.get_plain() - new = fn(og) - dtype = _BIT_WIDTH_TO_DTYPE[self.bit_width] - return self.from_uint8(new, dtype, self.pack_dim) - - # temporary until kernels on packed tensors are created - def apply_fn_to_shards(self, fn): - new_shards = [fn(shard) for shard in self.get_shards()] - return self.__class__( - new_shards, self.packed_shape, self.bit_width, self.pack_dim - ) - - @classmethod - def from_uint8(cls, int_data: torch.Tensor, dtype: torch.dtype, pack_dim: int = -1): - assert dtype in _DTYPE_TO_BIT_WIDTH.keys(), ( - "Expected dtype to be one of {_DTYPE_TO_BIT_WIDTH.keys()}" - ) - bit_width = _DTYPE_TO_BIT_WIDTH[dtype] - shards = pack(int_data, bit_width, dim=pack_dim) - shape = list(int_data.shape) - shape[pack_dim] = shape[pack_dim] * bit_width // 8 - return cls(shards, int_data.shape, 
bit_width, pack_dim) - - def _get_to_kwargs(self, *args, **kwargs): - device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) - device = self.device if device is None else device - dtype = self.dtype if dtype is None else dtype - memory_format = ( - memory_format if memory_format is not None else torch.preserve_format - ) - kwargs = { - "device": device, - "dtype": dtype, - "memory_format": memory_format, - } - return kwargs - - def to(self, *args, **kwargs): - if "copy" in kwargs: - return super().to(*args, **kwargs) - kwargs = self._get_to_kwargs(*args, **kwargs) - if "device" in kwargs: - return self.__class__( - list(shard.to(kwargs["device"]) for shard in self.get_shards()), - self.packed_shape, - self.bit_width, - self.pack_dim, - ) - return super().to(*args, **kwargs) - - -implements = UintxTensor.implements - - -@implements(aten.detach.default) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0].apply_fn_to_shards(torch.detach) - ) - - -@implements(aten.view.default) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0].apply_transformation(lambda x: x.view(*args[1:])) - ) - - -@implements(aten._to_copy.default) -def _(func, types, args, kwargs): - return return_and_correct_aliasing(func, args, kwargs, args[0]) - - -@implements(aten.sub.Tensor) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, - args, - kwargs, - args[0].apply_transformation(lambda x: (x - args[1]).to(torch.uint8)), - ) - - -@implements(aten.mul.Tensor) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, - args, - kwargs, - args[0].apply_transformation(lambda x: (x * args[1]).to(torch.uint8)), - ) - - -# quantization api integrations -to_uintx = UintxTensor.from_uint8 - - -@dataclass(frozen=True) -class UintxLayout(Layout): - """A layout class for Uintx tensors, which are tensors with elements packed into - smaller bit-widths than the standard 8-bit byte. This layout is used to define - how the data is stored and processed in UintxTensor objects. - - Attributes: - dtype (torch.dtype): The data type of the tensor elements, which determines - the bit-width used for packing. - pack_dim (int): The dimension along which the data is packed. Default is -1, - which indicates the last dimension. 
- """ - - dtype: torch.dtype - pack_dim: int = -1 - - def post_process( - self, - input: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - block_size: Tuple[int, ...], - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - return to_uintx(input, self.dtype, self.pack_dim), scale, zero_point - - -@register_layout(UintxLayout) -class UintxAQTTensorImpl(PlainAQTTensorImpl): - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - return self.int_data.get_plain(), self.scale, self.zero_point - - @classmethod - def from_plain( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - _layout: Layout, - ): - assert isinstance(_layout, UintxLayout) - return cls(int_data, scale, zero_point, _layout) +from torchao.prototype.dtypes.uintx.uintx_layout import ( # noqa: F401 + _BIT_WIDTH_TO_DTYPE, # noqa: F401 + _DTYPE_TO_BIT_WIDTH, # noqa: F401 + UintxAQTTensorImpl, # noqa: F401 + UintxLayout, # noqa: F401 + UintxTensor, # noqa: F401 + to_uintx, # noqa: F401 +) diff --git a/torchao/prototype/autoround/core.py b/torchao/prototype/autoround/core.py index 859e1cfe02..159fcb3c3d 100644 --- a/torchao/prototype/autoround/core.py +++ b/torchao/prototype/autoround/core.py @@ -189,7 +189,7 @@ def to_uintx_weight(input_float): quant_min = 0 quant_max = _auto_round_config.bits**2 - 1 block_size = (1, observed_linear.group_size) - from torchao.dtypes.uintx.uintx import ( + from torchao.prototype.dtypes.uintx.uintx_layout import ( _BIT_WIDTH_TO_DTYPE, UintxLayout, ) diff --git a/torchao/prototype/autoround/eval_autoround.py b/torchao/prototype/autoround/eval_autoround.py index 4846f919cc..4f6850be88 100644 --- a/torchao/prototype/autoround/eval_autoround.py +++ b/torchao/prototype/autoround/eval_autoround.py @@ -111,7 +111,9 @@ def main(args): ) elif args.uintx: msg += f" (uintx {args.bits} bits)" - from torchao.dtypes.uintx.uintx import _BIT_WIDTH_TO_DTYPE + from torchao.prototype.dtypes.uintx.uintx_layout import ( + _BIT_WIDTH_TO_DTYPE, + ) from torchao.quantization.quant_api import ( UIntXWeightOnlyConfig, quantize_, diff --git a/torchao/prototype/dtypes/__init__.py b/torchao/prototype/dtypes/__init__.py index 7ad78dbed6..88fe73ab76 100644 --- a/torchao/prototype/dtypes/__init__.py +++ b/torchao/prototype/dtypes/__init__.py @@ -11,7 +11,11 @@ Int8DynamicActInt4WeightCPULayout, MarlinQQQLayout, MarlinQQQTensor, + UintxAQTTensorImpl, + UintxLayout, + UintxTensor, to_marlinqqq_quantized_intx, + to_uintx, ) __all__ = [ @@ -22,4 +26,8 @@ "MarlinQQQTensor", "to_marlinqqq_quantized_intx", "GemlitePackedLayout", + "UintxLayout", + "UintxTensor", + "UintxAQTTensorImpl", + "to_uintx", ] diff --git a/torchao/prototype/dtypes/uintx/__init__.py b/torchao/prototype/dtypes/uintx/__init__.py index 56b1eed50a..2b6372d748 100644 --- a/torchao/prototype/dtypes/uintx/__init__.py +++ b/torchao/prototype/dtypes/uintx/__init__.py @@ -13,6 +13,12 @@ MarlinQQQTensor, to_marlinqqq_quantized_intx, ) +from .uintx_layout import ( + UintxAQTTensorImpl, + UintxLayout, + UintxTensor, + to_uintx, +) __all__ = [ "BlockSparseLayout", @@ -22,4 +28,8 @@ "MarlinQQQTensor", "to_marlinqqq_quantized_intx", "GemlitePackedLayout", + "UintxLayout", + "UintxTensor", + "UintxAQTTensorImpl", + "to_uintx", ] diff --git a/torchao/prototype/dtypes/uintx/uintx_layout.py b/torchao/prototype/dtypes/uintx/uintx_layout.py new file mode 100644 index 0000000000..ce9ce836e7 --- /dev/null +++ b/torchao/prototype/dtypes/uintx/uintx_layout.py @@ -0,0 +1,251 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +from dataclasses import dataclass +from typing import List, Tuple + +import torch +from torch.utils._python_dispatch import return_and_correct_aliasing + +from torchao.dtypes.affine_quantized_tensor import register_layout +from torchao.dtypes.uintx.bitpacking import pack, unpack +from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl +from torchao.dtypes.utils import ( + Layout, +) +from torchao.utils import TorchAOBaseTensor + +aten = torch.ops.aten + +# Note: Uintx does not work for torch 2.3 and below +_DTYPE_TO_BIT_WIDTH = {} +_BIT_WIDTH_TO_DTYPE = {} + +_DTYPE_TO_BIT_WIDTH = { + torch.uint1: 1, + torch.uint2: 2, + torch.uint3: 3, + torch.uint4: 4, + torch.uint5: 5, + torch.uint6: 6, + torch.uint7: 7, +} + +_BIT_WIDTH_TO_DTYPE = {v: k for k, v in _DTYPE_TO_BIT_WIDTH.items()} + + +class UintxTensor(TorchAOBaseTensor): + """ + Splits int data into packed shards based on bit size + fields: + int4_shard (torch.Tensor): 4 bit packed shard + int2_shard (torch.Tensor): 2 bit packed shard + int1_shard (torch.Tensor): 1 bit packed shard + bit_width (int): number of bits for each element + pack_dim: (int) dimension to pack along + """ + + bits_to_shard = { + 1: ["int1_shard"], + 2: ["int2_shard"], + 3: ["int2_shard", "int1_shard"], + 4: ["int4_shard"], + 5: ["int4_shard", "int1_shard"], + 6: ["int4_shard", "int2_shard"], + 7: ["int4_shard", "int2_shard", "int1_shard"], + } + + def __new__( + cls, + shards: List[torch.Tensor], + packed_shape: List[int], + bit_width: int, + pack_dim: int = -1, + ): + kwargs = {"device": shards[0].device} + kwargs["device"] = shards[0].device + kwargs["layout"] = shards[0].layout + kwargs["requires_grad"] = False + kwargs["dtype"] = torch.uint8 + return torch.Tensor._make_wrapper_subclass(cls, packed_shape, **kwargs) + + def __init__( + self, + shards: List[torch.Tensor], + packed_shape: List[int], + bit_width: int, + pack_dim: int = -1, + ): + for i, attrib in enumerate(self.bits_to_shard[bit_width]): + setattr(self, attrib, shards[i]) + + self.packed_shape = packed_shape + self.bit_width = bit_width + self.pack_dim = pack_dim + + def get_shards(self): + return [getattr(self, i) for i in self.__class__.bits_to_shard[self.bit_width]] + + def __repr__(self): + return f"Int{self.bit_width}Tensor(shape = {self.packed_shape}, data = {unpack(self.get_shards(), self.bit_width, dim=self.pack_dim)})" + + def __tensor_flatten__(self): + return self.__class__.bits_to_shard[self.bit_width], [ + self.packed_shape, + self.bit_width, + self.pack_dim, + ] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + shards = list(tensor_data_dict.values()) + packed_shape, bit_width, pack_dim = tensor_attributes + return cls(shards, packed_shape, bit_width, pack_dim) + + def get_plain(self): + return unpack(self.get_shards(), self.bit_width, dim=self.pack_dim) + + # temporary until kernels on packed tensors are created + def apply_transformation(self, fn): + og = self.get_plain() + new = fn(og) + dtype = _BIT_WIDTH_TO_DTYPE[self.bit_width] + return self.from_uint8(new, dtype, self.pack_dim) + + # temporary until kernels on packed tensors are created + def apply_fn_to_shards(self, fn): + new_shards = [fn(shard) for shard in self.get_shards()] + return self.__class__( + new_shards, self.packed_shape, self.bit_width, self.pack_dim + ) + + 
@classmethod + def from_uint8(cls, int_data: torch.Tensor, dtype: torch.dtype, pack_dim: int = -1): + assert dtype in _DTYPE_TO_BIT_WIDTH.keys(), ( + "Expected dtype to be one of {_DTYPE_TO_BIT_WIDTH.keys()}" + ) + bit_width = _DTYPE_TO_BIT_WIDTH[dtype] + shards = pack(int_data, bit_width, dim=pack_dim) + shape = list(int_data.shape) + shape[pack_dim] = shape[pack_dim] * bit_width // 8 + return cls(shards, int_data.shape, bit_width, pack_dim) + + def _get_to_kwargs(self, *args, **kwargs): + device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) + device = self.device if device is None else device + dtype = self.dtype if dtype is None else dtype + memory_format = ( + memory_format if memory_format is not None else torch.preserve_format + ) + kwargs = { + "device": device, + "dtype": dtype, + "memory_format": memory_format, + } + return kwargs + + def to(self, *args, **kwargs): + if "copy" in kwargs: + return super().to(*args, **kwargs) + kwargs = self._get_to_kwargs(*args, **kwargs) + if "device" in kwargs: + return self.__class__( + list(shard.to(kwargs["device"]) for shard in self.get_shards()), + self.packed_shape, + self.bit_width, + self.pack_dim, + ) + return super().to(*args, **kwargs) + + +implements = UintxTensor.implements + + +@implements(aten.detach.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0].apply_fn_to_shards(torch.detach) + ) + + +@implements(aten.view.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0].apply_transformation(lambda x: x.view(*args[1:])) + ) + + +@implements(aten._to_copy.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing(func, args, kwargs, args[0]) + + +@implements(aten.sub.Tensor) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0].apply_transformation(lambda x: (x - args[1]).to(torch.uint8)), + ) + + +@implements(aten.mul.Tensor) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0].apply_transformation(lambda x: (x * args[1]).to(torch.uint8)), + ) + + +# quantization api integrations +to_uintx = UintxTensor.from_uint8 + + +@dataclass(frozen=True) +class UintxLayout(Layout): + """A layout class for Uintx tensors, which are tensors with elements packed into + smaller bit-widths than the standard 8-bit byte. This layout is used to define + how the data is stored and processed in UintxTensor objects. + + Attributes: + dtype (torch.dtype): The data type of the tensor elements, which determines + the bit-width used for packing. + pack_dim (int): The dimension along which the data is packed. Default is -1, + which indicates the last dimension. 
+ """ + + dtype: torch.dtype + pack_dim: int = -1 + + def post_process( + self, + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + block_size: Tuple[int, ...], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return to_uintx(input, self.dtype, self.pack_dim), scale, zero_point + + +@register_layout(UintxLayout) +class UintxAQTTensorImpl(PlainAQTTensorImpl): + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return self.int_data.get_plain(), self.scale, self.zero_point + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + _layout: Layout, + ): + assert isinstance(_layout, UintxLayout) + return cls(int_data, scale, zero_point, _layout) diff --git a/torchao/prototype/quantization/codebook/codebook_quantized_tensor.py b/torchao/prototype/quantization/codebook/codebook_quantized_tensor.py index e16a339e82..9c3ef1e9b0 100644 --- a/torchao/prototype/quantization/codebook/codebook_quantized_tensor.py +++ b/torchao/prototype/quantization/codebook/codebook_quantized_tensor.py @@ -9,7 +9,7 @@ import torch from torchao.core.config import AOBaseConfig -from torchao.dtypes.uintx.uintx_layout import _DTYPE_TO_BIT_WIDTH, UintxTensor +from torchao.prototype.dtypes.uintx.uintx_layout import _DTYPE_TO_BIT_WIDTH, UintxTensor from torchao.prototype.quantization.codebook.codebook_ops import ( choose_qparams_codebook, dequantize_codebook, From 726607d5f116b07182de90d49e81e5eae3656e81 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Wed, 12 Nov 2025 12:34:11 -0500 Subject: [PATCH 16/22] Add __str__ to FqnToConfig to make printing more readable (#3323) * Adds __str__ to FqnToConfig to make printing more readable Summary: att, adds `__str__` method to `FqnToConfig` so that printing is more legible. For some config: ```python config = FqnToConfig({ "model.layers.fig.1.1": Float8DynamicActivationFloat8WeightConfig( granularity=PerRow(), ), "model.layers.fig.1.3": Float8DynamicActivationFloat8WeightConfig( granularity=PerRow(), ), "model.layers.fig.8.3": Float8DynamicActivationFloat8WeightConfig( granularity=PerRow(), ), }) ``` the output will be: ``` FqnToConfig({ 'model.layers.fig.1.1': Float8DynamicActivationFloat8WeightConfig(activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn, granularity=[PerRow(dim=-1), PerRow(dim=-1)], mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), activation_value_lb=None, activation_value_ub=None, kernel_preference=, set_inductor_config=True, version=2), 'model.layers.fig.1.3': Float8DynamicActivationFloat8WeightConfig(activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn, granularity=[PerRow(dim=-1), PerRow(dim=-1)], mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), activation_value_lb=None, activation_value_ub=None, kernel_preference=, set_inductor_config=True, version=2), 'model.layers.fig.8.3': Float8DynamicActivationFloat8WeightConfig(activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn, granularity=[PerRow(dim=-1), PerRow(dim=-1)], mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), activation_value_lb=None, activation_value_ub=None, kernel_preference=, set_inductor_config=True, version=2), }) ``` also adds in a test so that you cannot specify both fqn_to_config and module_fqn_to_config unless they are both equal. 
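For illustration, here is a minimal sketch of how the new printing and validation are expected to behave. The import path and the example fqn are assumptions for the sketch, not something taken from this diff:

```python
# Sketch only: assumes FqnToConfig and Float8WeightOnlyConfig are importable
# from torchao.quantization.quant_api, as used in the test added below.
from torchao.quantization.quant_api import Float8WeightOnlyConfig, FqnToConfig

# Printing now goes through the new __str__: one "'fqn': config," entry per line.
config = FqnToConfig({"model.layers.fig.1.1": Float8WeightOnlyConfig()})
print(config)

# Specifying both dicts with different contents raises ValueError in __post_init__.
try:
    FqnToConfig(
        fqn_to_config={"test": Float8WeightOnlyConfig()},
        module_fqn_to_config={"test2": Float8WeightOnlyConfig()},
    )
except ValueError as e:
    print(e)
```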
Test Plan: ``` pytest test/quantization/test_quant_api.py -k test_fqn_config_module_config_and_fqn_config_both_specified ``` Reviewers: Subscribers: Tasks: Tags: * fix ruff check --- test/quantization/test_quant_api.py | 7 +++++++ torchao/quantization/quant_api.py | 21 +++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 506cec9dea..e530babdb9 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -1178,6 +1178,13 @@ def __init__(self): assert isinstance(m.nested.linear.weight, AffineQuantizedTensor) assert isinstance(m.linear1.weight, AffineQuantizedTensor) + def test_fqn_config_module_config_and_fqn_config_both_specified(self): + with self.assertRaises(ValueError): + FqnToConfig( + fqn_to_config={"test": Float8WeightOnlyConfig()}, + module_fqn_to_config={"test2": Float8WeightOnlyConfig()}, + ) + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index c29382b658..f8602fa66c 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -2466,6 +2466,15 @@ class FqnToConfig(AOBaseConfig): def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.FqnToConfig") + if ( + len(self.fqn_to_config) > 0 + and len(self.module_fqn_to_config) > 0 + and self.fqn_to_config != self.module_fqn_to_config + ): + raise ValueError( + "`fqn_to_config` and `module_fqn_to_config` are both specified and are not equal!" + ) + # This code handles BC compatibility with `ModuleFqnToConfig`. It ensures that `self.module_fqn_to_config` and `self.fqn_to_config` share the same object. if len(self.module_fqn_to_config) > 0 and len(self.fqn_to_config) == 0: self.fqn_to_config = self.module_fqn_to_config @@ -2479,6 +2488,18 @@ def __post_init__(self): "Config Deprecation: _default is deprecated and will no longer be supported in a future release. Please see https://github.com/pytorch/ao/issues/3229 for more details." ) + def __str__(self): + return "\n".join( + [ + "FqnToConfig({", + *( + f" '{key}':\n {value}," + for key, value in self.fqn_to_config.items() + ), + "})", + ] + ) + # maintain BC ModuleFqnToConfig = FqnToConfig From 42fc6bdb48292e24c65a40107a6a7eb81131cd9e Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 12 Nov 2025 10:03:09 -0800 Subject: [PATCH 17/22] Add support for e2e benchmark for conv2d/conv3d (#3329) Summary: att, we added this to float8_inference_roofline to reuse code but we haven't enabled the roofline feature. For now we just need the e2e speedup time for single conv2d/conv3d against bf16 to understand the speedup expecatation Also added B200 hardware spec. 
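For reference, the new conv3d path in the benchmark boils down to roughly the sketch below. The batch size, channel counts, spatial sizes, and kernel size are made up for illustration, and actually running it needs a CUDA device plus the fbgemm kernels noted in the test plan:

```python
# Illustrative sketch of the conv3d branch added to float8_inference_roofline.py.
import copy

import torch
import torch.nn as nn

from torchao.quantization.quant_api import (
    Float8DynamicActivationFloat8WeightConfig,
    PerTensor,
    quantize_,
)

K_val, N_val, kernel_size = 64, 128, 3  # in/out channels and kernel size (made up)
m_orig = nn.Sequential(nn.Conv3d(K_val, N_val, kernel_size, bias=False)).to(
    memory_format=torch.channels_last_3d
)
m_orig = m_orig.cuda().bfloat16()
x = torch.randn(4, K_val, 8, 32, 32, dtype=torch.bfloat16, device="cuda").to(
    memory_format=torch.channels_last_3d
)

# Quantize only the Conv3d modules, mirroring the benchmark's filter_fn.
m_fp8_dyn = copy.deepcopy(m_orig)
quantize_(
    m_fp8_dyn,
    Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()),
    filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Conv3d),
)
m_fp8_dyn = torch.compile(m_fp8_dyn)
out = m_fp8_dyn(x)  # e2e time of this call is compared against the bf16 baseline
```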
Test Plan: python $SCRIPT_PATH $OUTPUT_FILE \ --recipe_name $RECIPE_NAME \ --shape_gen_name $SHAPE_GEN_NAME \ --M $M --K $K --N $N \ --D $D --H $H --W $W \ --kernel_size $kernel_size \ --op_name conv3d This doesn't run yet because OSS fbgemm can't be installed in the B200 machine Reviewers: Subscribers: Tasks: Tags: Co-authored-by: jerryzh --- .../float8/float8_inference_roofline.py | 226 +++++++++++++----- torchao/testing/training/roofline_utils.py | 16 ++ 2 files changed, 177 insertions(+), 65 deletions(-) diff --git a/benchmarks/float8/float8_inference_roofline.py b/benchmarks/float8/float8_inference_roofline.py index ea28d3236e..f5fa75cfb9 100644 --- a/benchmarks/float8/float8_inference_roofline.py +++ b/benchmarks/float8/float8_inference_roofline.py @@ -50,6 +50,7 @@ from torchao.quantization.quant_api import ( Float8DynamicActivationFloat8WeightConfig, PerRow, + PerTensor, quantize_, ) from torchao.quantization.quantize_.common import KernelPreference @@ -179,6 +180,11 @@ def run( n_limit: Optional[int] = None, save_profile_traces: bool = False, enable_fusion_modeling: bool = False, + op_name: str = "linear", + D: Optional[int] = None, + H: Optional[int] = None, + W: Optional[int] = None, + kernel_size: Optional[int] = None, ): """ Args: @@ -189,7 +195,29 @@ def run( * `n_limit (optional)`: if specified, only runs `n_limit` iterations # `save_profile_traces (optional)`: if True, saves profiling traces # `enable_fusion_modeling`: if True, models activation -> gemm instead of just gemm + # `op_name`: linear, conv2d or conv3d, decides which op to benchmark + # `D`, `H`, `W`: spatial dimensiosn for conv3d / conv2d + # `kernel_size`: kernel_size for conv3d / conv2d """ + _SUPPORTED_OPS = ["linear", "conv2d", "conv3d"] + assert op_name in _SUPPORTED_OPS, ( + f"Unsupported op: {op_name}, supported are: {_SUPPORTED_OPS}" + ) + if op_name == "conv2d": + assert H is not None and W is not None, ( + "Expected D, H, W to be specified for conv2d" + ) + assert kernel_size is not None, ( + "Expected kernel_size to be specified for conv2d" + ) + elif op_name == "conv3d": + assert D is not None and H is not None and W is not None, ( + "Expected D, H, W to be specified for conv3d" + ) + assert kernel_size is not None, ( + "Expected kernel_size to be specified for conv3d" + ) + config_table = [ ["GPU", torch.cuda.get_device_name(0)], ["torch version", torch.__version__], @@ -198,7 +226,10 @@ def run( ["do_benchmarks", do_benchmarks], ["shape_gen_name", shape_gen_name], ["enable_fusion_modeling", enable_fusion_modeling], + ["op_name", op_name], ["MKN", f"{M} {K} {N}"], + ["DHW", f"{D} {H} {W}"], + ["kernel_size", kernel_size], ] print(tabulate(config_table, headers=["Parameter", "Value"], tablefmt="simple")) @@ -207,33 +238,45 @@ def run( M, K, N = sympy.symbols("M K N") - fp8_ovhd_time_sympy = get_inference_float8_mem_sympy( - M, - K, - N, - recipe_name, - # TODO(future): also enable fusion modeling here - ) - bf16_gemm_time_sympy = get_inference_gemm_time_sympy(M, K, N, torch.bfloat16, None) - - if recipe_name and recipe_name.startswith(("nvfp4", "mxfp4")): - fp8_gemm_time_sympy = get_inference_gemm_time_sympy( - M, K, N, torch.float4_e2m1fn_x2, recipe_name + if op_name == "linear": + fp8_ovhd_time_sympy = get_inference_float8_mem_sympy( + M, + K, + N, + recipe_name, + # TODO(future): also enable fusion modeling here ) - else: - gemm_recipe_name = "mxfp8" if recipe_name.startswith("mxfp8") else None - fp8_gemm_time_sympy = get_inference_gemm_time_sympy( - M, K, N, torch.float8_e4m3fn, gemm_recipe_name 
+ bf16_gemm_time_sympy = get_inference_gemm_time_sympy( + M, K, N, torch.bfloat16, None ) - print("bf16_gemm_time_sympy", bf16_gemm_time_sympy) - print("fp8_gemm_time_sympy", fp8_gemm_time_sympy) - print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy) - print() + if recipe_name and recipe_name.startswith(("nvfp4", "mxfp4")): + fp8_gemm_time_sympy = get_inference_gemm_time_sympy( + M, K, N, torch.float4_e2m1fn_x2, recipe_name + ) + else: + gemm_recipe_name = "mxfp8" if recipe_name.startswith("mxfp8") else None + fp8_gemm_time_sympy = get_inference_gemm_time_sympy( + M, K, N, torch.float8_e4m3fn, gemm_recipe_name + ) + print("bf16_gemm_time_sympy", bf16_gemm_time_sympy) + print("fp8_gemm_time_sympy", fp8_gemm_time_sympy) + print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy) + print() + else: + # TODO: enable roofline analysis for conv + pass + + # Note: roofline for conv2d/conv3d is not added yet, so most of the + # things for conv2d/conv3d we'll left out for now headers = [ "fwd_M", "fwd_K", "fwd_N", + "D", + "H", + "W", + "kernel_size", # roofline - gemm time (fwd + bwd, 3 gemms) "r_bf16_gemm_s", "r_fp8_gemm_s", @@ -258,6 +301,7 @@ def run( "rb_bf16_gemm_ratio", "rb_fp8_gemm_ratio", ] + results = [] name_to_shapes = get_name_to_shapes_iter(shape_gen_name, user_M, user_K, user_N) @@ -266,54 +310,93 @@ def run( if n_limit is not None and idx >= n_limit: break - # use roofline model to estimate gemm time - # note: cast from sympy.core.numbers.Float to float to make pandas formatting work - r_bf16_gemm_time_s = float( - bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val) - ) - r_fp8_gemm_time_s = float( - fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val) - ) - - # if enabled, also measured observed gemm time - b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0 - rb_bf16_gemm_ratio = -1 - rb_fp8_gemm_ratio = -1 + if op_name == "linear": + # use roofline model to estimate gemm time + # note: cast from sympy.core.numbers.Float to float to make pandas formatting work + r_bf16_gemm_time_s = float( + bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val) + ) + r_fp8_gemm_time_s = float( + fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val) + ) - if do_benchmarks: - # TODO(future): make the bf16 gemm times exactly match the e2e - # benchmarks, there is a slight deviation, probably related to gemm - # operand memory formats/transpositions below not exactly matching - # what PyTorch core is doing for `torch.mm` - # input @ weight_t = output - bf16_g1, f8_g1 = get_gemm_times( - M_val, - K_val, - N_val, - True, - recipe_name, + # note: cast from sympy.core.numbers.Float to float to make pandas formatting work + r_fp8_ovhd_time_s = float( + fp8_ovhd_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val) ) - b_bf16_gemm_time_s = bf16_g1 - b_fp8_gemm_time_s = f8_g1 - rb_bf16_gemm_ratio = r_bf16_gemm_time_s / b_bf16_gemm_time_s - rb_fp8_gemm_ratio = r_fp8_gemm_time_s / b_fp8_gemm_time_s - - # note: cast from sympy.core.numbers.Float to float to make pandas formatting work - r_fp8_ovhd_time_s = float( - fp8_ovhd_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val) - ) + r_fp8_gemm_and_ovhd_s = r_fp8_gemm_time_s + r_fp8_ovhd_time_s + r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s) + + # if enabled, also measured observed gemm time + b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0 + rb_bf16_gemm_ratio = -1 + rb_fp8_gemm_ratio = -1 + + if do_benchmarks: + # TODO(future): make the bf16 gemm times exactly match the e2e + # benchmarks, there is a 
slight deviation, probably related to gemm + # operand memory formats/transpositions below not exactly matching + # what PyTorch core is doing for `torch.mm` + # input @ weight_t = output + bf16_g1, f8_g1 = get_gemm_times( + M_val, + K_val, + N_val, + True, + recipe_name, + ) + b_bf16_gemm_time_s = bf16_g1 + b_fp8_gemm_time_s = f8_g1 + rb_bf16_gemm_ratio = r_bf16_gemm_time_s / b_bf16_gemm_time_s + rb_fp8_gemm_ratio = r_fp8_gemm_time_s / b_fp8_gemm_time_s + + else: + # roofline analysis for conv2d/conv3d are not added yet + r_bf16_gemm_time_s = None + r_fp8_gemm_time_s = None + + r_fp8_ovhd_time_s = None + r_fp8_gemm_and_ovhd_s = None + r_speedup = None + + # real gemm benchmark time, also not added yet + # if enabled, also measured observed gemm time + b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0 + # gemm roofline ratio achieved in real benchmark + rb_bf16_gemm_ratio = -1 + rb_fp8_gemm_ratio = -1 b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0 if do_benchmarks: # create the model - if not enable_fusion_modeling: - m_orig = nn.Sequential(nn.Linear(K_val, N_val, bias=False)) + if op_name == "conv2d": + m_orig = nn.Sequential( + nn.Conv2d(K_val, N_val, kernel_size, bias=False) + ).to(memory_format=torch.channels_last) + elif op_name == "conv3d": + m_orig = nn.Sequential( + nn.Conv3d(K_val, N_val, kernel_size, bias=False) + ).to(memory_format=torch.channels_last_3d) else: - m_orig = nn.Sequential(nn.ReLU(), nn.Linear(K_val, N_val, bias=False)) + if not enable_fusion_modeling: + m_orig = nn.Sequential(nn.Linear(K_val, N_val, bias=False)) + else: + m_orig = nn.Sequential( + nn.ReLU(), nn.Linear(K_val, N_val, bias=False) + ) m_orig = m_orig.cuda().bfloat16() - x = torch.randn( - M_val, K_val, dtype=torch.bfloat16, device="cuda" - ).requires_grad_() + if op_name == "conv2d": + x = torch.randn( + M_val, K_val, H, W, dtype=torch.bfloat16, device="cuda" + ).to(memory_format=torch.channels_last) + elif op_name == "conv3d": + x = torch.randn( + M_val, K_val, D, H, W, dtype=torch.bfloat16, device="cuda" + ).to(memory_format=torch.channels_last_3d) + else: + x = torch.randn( + M_val, K_val, dtype=torch.bfloat16, device="cuda" + ).requires_grad_() # get the bf16 gpu kernel time torch._dynamo.reset() @@ -327,7 +410,11 @@ def run( # get the float8 dynamic scaling gpu kernel time torch._dynamo.reset() - if recipe_name == "rowwise": + if recipe_name == "tensorwise": + config = Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor(), + ) + elif recipe_name == "rowwise": config = Float8DynamicActivationFloat8WeightConfig( granularity=PerRow(), # for now, use TORCH. 
In the future might be interesting @@ -355,7 +442,14 @@ def run( assert False, "unsupported" m_fp8_dyn = copy.deepcopy(m_orig) - quantize_(m_fp8_dyn, config) + if op_name == "linear": + quantize_(m_fp8_dyn, config) + elif op_name == "conv2d": + _is_conv2d = lambda m, fqn: isinstance(m, torch.nn.Conv2d) + quantize_(m_fp8_dyn, config, filter_fn=_is_conv2d) + else: + _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d) + quantize_(m_fp8_dyn, config, filter_fn=_is_conv3d) m_fp8_dyn = torch.compile(m_fp8_dyn) @@ -364,20 +458,22 @@ def run( fp8_trace_filename = f"{outfile}_{M_val}_{K_val}_{N_val}_fp8.json" b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x, fp8_trace_filename) - r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s) - results.append( [ M_val, K_val, N_val, + D, + H, + W, + kernel_size, # roofline - gemm r_bf16_gemm_time_s, r_fp8_gemm_time_s, # roofline - fp8 overhead r_fp8_ovhd_time_s, # roofline - gemm + overhead, and speedup - r_fp8_gemm_time_s + r_fp8_ovhd_time_s, + r_fp8_gemm_and_ovhd_s, r_speedup, # benchmarks - gemm b_bf16_gemm_time_s, diff --git a/torchao/testing/training/roofline_utils.py b/torchao/testing/training/roofline_utils.py index e391a4d44b..bf234b3717 100644 --- a/torchao/testing/training/roofline_utils.py +++ b/torchao/testing/training/roofline_utils.py @@ -43,6 +43,22 @@ # TODO(future): measure once we have the hardware "pct_achievable_mem_bw": 0.92, }, + "NVIDIA GB200": { + # https://resources.nvidia.com/en-us-blackwell-architecture, page 19, + # divide by 2 because no sparsity + "bf16_peak_tops": 2.25e15, + "fp8_peak_tops": 4.5e15, + "fp4_peak_tops": 9.0e15, + # https://resources.nvidia.com/en-us-blackwell-architecture, page 20 + # 8.0 TB per second + "peak_mem_bw_bytes_sec": 8.0e12, + # for now, copy over from H100 + # TODO(future): measure once we have the hardware + "pct_achievable_gemm_tops": 0.78, + # for now, copy over from H100 + # TODO(future): measure once we have the hardware + "pct_achievable_mem_bw": 0.92, + }, "AMD Instinct MI300X": { # https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/data-sheets/amd-instinct-mi300x-data-sheet.pdf, page 1, "bf16_peak_tops": 1307e12, From 8c375689e5236238d84242bc5c241b3772c59251 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Wed, 12 Nov 2025 10:14:39 -0800 Subject: [PATCH 18/22] Move floatx_tensor_core_layout to prototype/dtypes (#3317) --- benchmarks/benchmark_fp6.py | 2 +- docs/source/api_ref_dtypes.rst | 2 +- test/dtypes/test_floatx.py | 12 +- torchao/dtypes/affine_quantized_tensor.py | 5 +- torchao/dtypes/affine_quantized_tensor_ops.py | 8 +- .../floatx/floatx_tensor_core_layout.py | 679 +----------------- torchao/prototype/dtypes/__init__.py | 2 + torchao/prototype/dtypes/floatx/__init__.py | 17 + .../floatx/floatx_tensor_core_layout.py | 666 +++++++++++++++++ torchao/quantization/quant_api.py | 2 +- 10 files changed, 723 insertions(+), 672 deletions(-) create mode 100644 torchao/prototype/dtypes/floatx/__init__.py create mode 100644 torchao/prototype/dtypes/floatx/floatx_tensor_core_layout.py diff --git a/benchmarks/benchmark_fp6.py b/benchmarks/benchmark_fp6.py index c22eba9e1a..4aac4b952f 100644 --- a/benchmarks/benchmark_fp6.py +++ b/benchmarks/benchmark_fp6.py @@ -9,7 +9,7 @@ from tqdm import tqdm from torchao.dtypes import to_affine_quantized_fpx -from torchao.dtypes.floatx import FloatxTensorCoreLayout +from torchao.prototype.dtypes.floatx import FloatxTensorCoreLayout from torchao.utils import benchmark_torch_function_in_microseconds diff --git 
a/docs/source/api_ref_dtypes.rst b/docs/source/api_ref_dtypes.rst index 3997b444b3..826c16fe19 100644 --- a/docs/source/api_ref_dtypes.rst +++ b/docs/source/api_ref_dtypes.rst @@ -20,7 +20,6 @@ Layouts and Tensor Subclasses TensorCoreTiledLayout Float8Layout FloatxTensor - FloatxTensorCoreLayout MarlinSparseLayout Int4CPULayout CutlassSemiSparseLayout @@ -52,6 +51,7 @@ Prototype Int8DynamicActInt4WeightCPULayout MarlinQQQTensor MarlinQQQLayout + FloatxTensorCoreLayout UintxLayout .. diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py index ab4a13d24c..a3dd4d19e3 100644 --- a/test/dtypes/test_floatx.py +++ b/test/dtypes/test_floatx.py @@ -14,20 +14,20 @@ run_tests, ) -from torchao.dtypes.floatx import ( +from torchao.prototype.custom_fp_utils import ( + _f32_to_floatx_unpacked, + _floatx_unpacked_to_f32, +) +from torchao.prototype.dtypes.floatx import ( FloatxTensorCoreLayout, from_scaled_tc_floatx, to_scaled_tc_floatx, ) -from torchao.dtypes.floatx.floatx_tensor_core_layout import ( +from torchao.prototype.dtypes.floatx.floatx_tensor_core_layout import ( FloatxTensorCoreAQTTensorImpl, _pack_tc_floatx, _pack_tc_fp6, ) -from torchao.prototype.custom_fp_utils import ( - _f32_to_floatx_unpacked, - _floatx_unpacked_to_f32, -) from torchao.quantization import ( FPXWeightOnlyConfig, quantize_, diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 0d7ed8d9e2..3303bd5267 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -136,7 +136,8 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor if output_dtype is None: output_dtype = self.dtype - from torchao.dtypes.floatx import Float8Layout, FloatxTensorCoreLayout + from torchao.dtypes.floatx import Float8Layout + from torchao.prototype.dtypes.floatx import FloatxTensorCoreLayout if isinstance(self._layout, FloatxTensorCoreLayout): int_data, scale = self.tensor_impl.get_plain() @@ -539,7 +540,7 @@ def from_hp_to_fpx( _layout: Layout, ): """Create a floatx AffineQuantizedTensor from a high precision tensor. 
Floatx is represented as ebits and mbits, and supports the representation of float1-float7.""" - from torchao.dtypes.floatx import FloatxTensorCoreLayout + from torchao.prototype.dtypes.floatx import FloatxTensorCoreLayout assert isinstance(_layout, FloatxTensorCoreLayout), ( f"Only FloatxTensorCoreLayout is supported for floatx, got {_layout}" diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 6c7216ab12..730d33d2c6 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -21,10 +21,6 @@ _linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl, ) -from torchao.dtypes.floatx.floatx_tensor_core_layout import ( - _linear_f16_bf16_act_floatx_weight_check, - _linear_f16_bf16_act_floatx_weight_impl, -) from torchao.dtypes.uintx.int4_cpu_layout import ( _linear_fp_act_uint4_weight_cpu_check, _linear_fp_act_uint4_weight_cpu_impl, @@ -72,6 +68,10 @@ _linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl, ) +from torchao.prototype.dtypes.floatx.floatx_tensor_core_layout import ( + _linear_f16_bf16_act_floatx_weight_check, + _linear_f16_bf16_act_floatx_weight_impl, +) from torchao.prototype.dtypes.uintx.block_sparse_layout import ( _linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl, diff --git a/torchao/dtypes/floatx/floatx_tensor_core_layout.py b/torchao/dtypes/floatx/floatx_tensor_core_layout.py index c7fb1e1a7c..7f96564458 100644 --- a/torchao/dtypes/floatx/floatx_tensor_core_layout.py +++ b/torchao/dtypes/floatx/floatx_tensor_core_layout.py @@ -3,664 +3,29 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from dataclasses import dataclass -from functools import reduce -from typing import Optional, Tuple -import torch -from torch import Tensor -from torch.utils._python_dispatch import ( - is_traceable_wrapper_subclass, - return_and_correct_aliasing, -) +# Backward compatibility stub - imports from the new location +import warnings -from torchao.dtypes.affine_quantized_tensor import ( - AffineQuantizedTensor, - register_layout, -) -from torchao.dtypes.utils import ( - AQTTensorImpl, - Layout, -) -from torchao.prototype.custom_fp_utils import ( - _f32_to_floatx_unpacked, - _floatx_unpacked_to_f32, - _n_ones, +warnings.warn( + "Importing from torchao.dtypes.floatx.floatx_tensor_core_layout is deprecated. " + "Please use 'from torchao.prototype.dtypes.floatx.floatx_tensor_core_layout import ...' instead. " + "This import path will be removed in a future torchao release. " + "Please check issue: https://github.com/pytorch/ao/issues/2752 for more details. 
", + DeprecationWarning, + stacklevel=2, ) -aten = torch.ops.aten -_ONES_TABLE = [_n_ones(i) for i in range(8)] - - -def _pack(x: Tensor, n_bits: int) -> Tensor: - return reduce( - torch.bitwise_or, - [ - x[..., i :: (8 // n_bits)] << (8 - (i + 1) * n_bits) - for i in range(8 // n_bits) - ], - ) - - -def _unpack(x: Tensor, n_bits: int) -> Tensor: - return torch.stack( - [ - (x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) - for i in range(8 // n_bits) - ], - dim=-1, - ).flatten(-2) - - -# https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h#L87-L116 -def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor: - # the original code unpacks/packs the values from/to uint32 while we unpack/pack the values from/to uint8 - # thus, we need to reverse byte order within a uint32 word. - x = x.reshape(-1, 4).flip(1) - - x = _unpack(x, n_bits) - x = x.view(-1, 4 * (8 // n_bits)) - - if not undo: - bit_order = { - 1: [ - 1, - 5, - 9, - 13, - 17, - 21, - 25, - 29, - 3, - 7, - 11, - 15, - 19, - 23, - 27, - 31, - 0, - 4, - 8, - 12, - 16, - 20, - 24, - 28, - 2, - 6, - 10, - 14, - 18, - 22, - 26, - 30, - ], - 2: [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14], - 4: [1, 5, 3, 7, 0, 4, 2, 6], - }[n_bits] - - else: - # this is inverse of the above, obtained by running - # [v.index(i) for i in range(len(v))] - bit_order = { - 1: [ - 16, - 0, - 24, - 8, - 17, - 1, - 25, - 9, - 18, - 2, - 26, - 10, - 19, - 3, - 27, - 11, - 20, - 4, - 28, - 12, - 21, - 5, - 29, - 13, - 22, - 6, - 30, - 14, - 23, - 7, - 31, - 15, - ], - 2: [8, 0, 12, 4, 9, 1, 13, 5, 10, 2, 14, 6, 11, 3, 15, 7], - 4: [4, 0, 6, 2, 5, 1, 7, 3], - }[n_bits] - - x = x[:, bit_order] - x = _pack(x, n_bits) - - # reverse byte order within a uint32 word again. 
- x = x.reshape(-1, 4).flip(1) - return x.flatten() - - -# this is a literal adaptation of FP6-LLM ahead-of-time bit-level pre-packing -# https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h -def _pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: - assert tensor.ndim == 2, tensor.dtype == torch.uint8 - M, N = tensor.shape - assert (M % 64 == 0) and (N % 64 == 0) - - # Pass 1 from original code - tensor = tensor.view(M // 64, 4, 2, 8, N // 16, 2, 8) - tensor = tensor.permute(0, 4, 1, 5, 2, 3, 6) - tensor = tensor.reshape(-1, 32, 2) - tensor = tensor.permute(1, 0, 2) - tensor = tensor.flatten() - - used_bits = 0 - fragments = [] - - for y in [1, 2, 4]: - if nbits & y: - mask = (1 << y) - 1 - tensor_ybit = (tensor >> (nbits - used_bits - y)) & mask - tensor_ybit = _pack(tensor_ybit, y) - - tensor_ybit = ( - tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2) - ) # Pass 2 from original code - tensor_ybit = _bit_interleave( - tensor_ybit.flatten(), y - ) # Pass 3 from original code - fragments.append(tensor_ybit) - used_bits += y - - return torch.cat(fragments, dim=0).view(M, -1) - - -# more optimized version of _pack_tc_floatx() for FP6 by merging ops -def _pack_tc_fp6(tensor: Tensor) -> Tensor: - assert tensor.ndim == 2, tensor.dtype == torch.uint8 - M, N = tensor.shape - assert (M % 64 == 0) and (N % 64 == 0) - - tensor = tensor.view(M // 64, 2, 2, 2, 8, N // 16, 2, 8) - tensor = tensor.flip(3) - - tensor_2bit = (tensor >> 4) & 0b11 - tensor_2bit = tensor_2bit.permute(0, 5, 1, 4, 7, 3, 2, 6) - tensor_2bit = _pack(tensor_2bit.flatten(), 2) - - tensor_4bit = tensor & 0b1111 - tensor_4bit = tensor_4bit.permute(0, 5, 1, 2, 4, 7, 3, 6) - tensor_4bit = _pack(tensor_4bit.flatten(), 4) - - return torch.cat([tensor_2bit, tensor_4bit], dim=0).view(M, -1) - - -# currently only optimize for TC-FP6 packing -def pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: - if nbits == 6: - return _pack_tc_fp6(tensor) - return _pack_tc_floatx(tensor, nbits) - - -def to_scaled_tc_floatx( - tensor: Tensor, ebits: int, mbits: int -) -> Tuple[Tensor, Tensor]: - # _n_ones() is not compatible with torch.compile() due to << operator - # https://github.com/pytorch/pytorch/issues/119152 - # exp_bias = _n_ones(ebits - 1) - # max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits)) - - # workaround: global lookup table - exp_bias = _ONES_TABLE[ebits - 1] - max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * ( - _ONES_TABLE[mbits + 1] / (2**mbits) - ) - - dtype = tensor.dtype - tensor = tensor.float() - scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal - tensor_floatx = _f32_to_floatx_unpacked(tensor / scale.view(-1, 1), ebits, mbits) - tensor_tc_floatx = pack_tc_floatx(tensor_floatx, 1 + ebits + mbits) - return tensor_tc_floatx, scale.to(dtype) - - -# inverse of _pack_tc_floatx() -def _unpack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: - assert tensor.ndim == 2 and tensor.dtype == torch.uint8 - M = tensor.shape[0] - size = tensor.numel() - tensor = tensor.flatten() - offset = 0 - used_bits = 0 - - tensor_floatx = None - - for y in [1, 2, 4]: - if nbits & y: - size_ybit = size // nbits * y - tensor_ybit = tensor[offset : offset + size_ybit] - offset += size_ybit - - tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) # undo Pass 3 - tensor_ybit = ( - tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2) - ) # undo Pass 2 - - tensor_ybit = _unpack(tensor_ybit.flatten(), y) - tensor_ybit = tensor_ybit << 
(nbits - used_bits - y) - used_bits += y - - if tensor_floatx is None: - tensor_floatx = tensor_ybit - else: - tensor_floatx |= tensor_ybit - - # undo Pass 1 - tensor_floatx = tensor_floatx.view(32, -1, 2).permute(1, 0, 2) - tensor_floatx = tensor_floatx.reshape(M // 64, -1, 4, 2, 2, 8, 8) - tensor_floatx = tensor_floatx.permute(0, 2, 4, 5, 1, 3, 6) - tensor_floatx = tensor_floatx.reshape(M, -1) - return tensor_floatx - - -# more optimized version of _unpack_tc_floatx() for FP6 by merging ops -# inverse of _unpack_tc_fp6() -def _unpack_tc_fp6(tensor: Tensor) -> Tensor: - assert tensor.ndim == 2 and tensor.dtype == torch.uint8 - M = tensor.shape[0] - N = tensor.shape[1] // 3 * 4 - assert (M % 64 == 0) and (N % 64 == 0) - size_2bit = M * N // 4 - size_4bit = M * N // 2 - tensor = tensor.view(-1) - assert tensor.numel() == size_2bit + size_4bit - - tensor_2bit, tensor_4bit = tensor.split([size_2bit, size_4bit]) - - tensor_2bit = _unpack(tensor_2bit, 2) - tensor_2bit = tensor_2bit.view(M // 64, N // 16, 2, 8, 8, 2, 2, 2) - tensor_2bit = tensor_2bit.permute(0, 2, 6, 5, 3, 1, 7, 4) - - tensor_4bit = _unpack(tensor_4bit, 4) - tensor_4bit = tensor_4bit.view(M // 64, N // 16, 2, 2, 8, 8, 2, 2) - tensor_4bit = tensor_4bit.permute(0, 2, 3, 6, 4, 1, 7, 5) - - tensor_fp6 = (tensor_2bit << 4) | tensor_4bit - tensor_fp6 = tensor_fp6.flip(3).reshape(M, N) - return tensor_fp6 - - -def unpack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: - if nbits == 6: - return _unpack_tc_fp6(tensor) - return _unpack_tc_floatx(tensor, nbits) - - -def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Tensor: - floatx_unpacked = unpack_tc_floatx(tensor, 1 + ebits + mbits) - tensor = _floatx_unpacked_to_f32(floatx_unpacked, ebits, mbits) - if scale is not None: - tensor = tensor * scale.float().view(-1, 1) - return tensor - - -# https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py -_SPLIT_K_MAP = [ - { # tokens: [1, 64] - 3072: 18, - 4096: 13, - 5120: 10, - 6144: 9, - 8192: 6, - 10240: 5, - 14336: 7, - 28672: 7, - 57344: 7, - }, - { # tokens: [65:128] - 3072: 9, - 4096: 6, - 5120: 5, - 6144: 9, - 8192: 3, - 10240: 5, - 14336: 7, - 28672: 7, - 57344: 6, - }, - { # tokens: [129:192] - 3072: 6, - 4096: 4, - 5120: 7, - 6144: 3, - 8192: 2, - 10240: 5, - 14336: 5, - 28672: 5, - 57344: 4, - }, - { # tokens: [193:256] - 3072: 9, - 4096: 3, - 5120: 5, - 6144: 2, - 8192: 5, - 10240: 4, - 14336: 8, - 28672: 6, - 57344: 4, - }, - { # tokens: [257:320] - 3072: 7, - 4096: 5, - 5120: 2, - 6144: 5, - 8192: 4, - 10240: 1, - 14336: 3, - 28672: 3, - 57344: 4, - }, - { # tokens: [321:384] - 3072: 3, - 4096: 2, - 5120: 5, - 6144: 3, - 8192: 1, - 10240: 8, - 14336: 3, - 28672: 4, - 57344: 3, - }, - { # tokens: [385:448] - 3072: 5, - 4096: 7, - 5120: 3, - 6144: 5, - 8192: 7, - 10240: 3, - 14336: 1, - 28672: 1, - 57344: 3, - }, - { # tokens: [449:512] - 3072: 2, - 4096: 5, - 5120: 4, - 6144: 1, - 8192: 5, - 10240: 2, - 14336: 6, - 28672: 4, - 57344: 1, - }, - { # tokens: [513:576] - 3072: 2, - 4096: 3, - 5120: 1, - 6144: 1, - 8192: 3, - 10240: 3, - 14336: 3, - 28672: 1, - 57344: 1, - }, - { # tokens: [577:640] - 3072: 5, - 4096: 4, - 5120: 1, - 6144: 4, - 8192: 2, - 10240: 1, - 14336: 1, - 28672: 1, - 57344: 1, - }, - { # tokens: [641:704] - 3072: 3, - 4096: 1, - 5120: 2, - 6144: 2, - 8192: 1, - 10240: 2, - 14336: 1, - 28672: 1, - 57344: 1, - }, - { # tokens: [705:768] - 3072: 3, - 4096: 1, - 5120: 3, - 6144: 2, - 8192: 
1, - 10240: 1, - 14336: 1, - 28672: 1, - 57344: 1, - }, -] - - -# quantization api integrations -@dataclass(frozen=True) -class FloatxTensorCoreLayout(Layout): - """FloatxTensorCoreLayout is a data class that defines the layout for a tensor with a specific number of exponent bits (ebits) and mantissa bits (mbits). - This layout is used in the context of quantization and packing of tensors optimized for TensorCore operations. - """ - - ebits: int - mbits: int - - -@register_layout(FloatxTensorCoreLayout) -class FloatxTensorCoreAQTTensorImpl(AQTTensorImpl): - """FloatxTensorCoreAQTTensorImpl represents a Tensor with dtype floatx(ebits=a, mbits=b), - it has a internal tensor field of "packed_floatx_data", which is packed from the - uint8 unpacked data (the output of `_quantize_affine_floatx` operator) - - The packing is optimized for TensorCore, from the fp6-llm paper: https://arxiv.org/abs/2401.14112 - github repo: https://github.com/usyd-fsalab/fp6_llm, now renamed to quant-llm - - At a high level packing is done by grouping bits into 1 bit fragments (shards), 2 bit fragments and - 4 bit fragments each fragments are packed separately and concatenated together. - For example for 6 bit dtype, we can extract the first 4 bits for all elements and pack them together - in a fragment, and extract the last 2 bits for all elements and pack them into fragment, in the end - we concatenate the fragments together. - - If original Tensor shape is (M, N), and the data is in nbit, the shape of the packed data will be - (M, N // 8 * nbit) - - FloatxTensorCoreAQTTensorImpl.from_plain takes an unpacked uint8 floatx Tensor of shape (M, N), with format of - (zero padding bits + sign bit + exponent bits + mantissa bits), e.g. 00SEEEMM for fp6_e3_m2 - it will then pack the weight and instantiate the FloatxTensorCoreAQTTensorImpl tensor - FloatxTensorCoreAQTTensorImpl.__init__() takes a packed floatx Tensor of shape (M, N // 8 * nbit) - """ - - def __new__( - cls, - packed_floatx_data: torch.Tensor, - scale: torch.Tensor, - _layout: Layout, - ): - assert packed_floatx_data.ndim == 2 - assert packed_floatx_data.dtype == torch.uint8 - shape = ( - packed_floatx_data.shape[0], - packed_floatx_data.shape[1] // (1 + _layout.ebits + _layout.mbits) * 8, - ) - kwargs = {} - kwargs["device"] = packed_floatx_data.device - kwargs["layout"] = ( - kwargs.get("layout") - if kwargs.get("layout", False) - else packed_floatx_data.layout - ) - kwargs["dtype"] = packed_floatx_data.dtype - kwargs["requires_grad"] = False - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - packed_floatx_data: torch.Tensor, - scale: torch.Tensor, - _layout: Layout, - ): - self.packed_floatx_data = packed_floatx_data - self.scale = scale - self._layout = _layout - - def __tensor_flatten__(self): - return ["packed_floatx_data", "scale"], [self._layout] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - packed_floatx_data, scale = ( - tensor_data_dict["packed_floatx_data"], - tensor_data_dict["scale"], - ) - (_layout,) = tensor_attributes - return cls(packed_floatx_data, scale, _layout) - - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor]: - unpacked_floatx_data = unpack_tc_floatx( - self.packed_floatx_data, 1 + self._layout.ebits + self._layout.mbits - ) - return unpacked_floatx_data, self.scale - - @classmethod - def from_plain( - cls, - unpacked_floatx_data: torch.Tensor, - scale: torch.Tensor, - 
zero_point: Optional[torch.Tensor], - _layout: Layout, - ): - """ - Format for `unpacked_floatx_data` will be: - zero padding bits | sign bit | exponent bits | mantissa bits - - For example for fp6_e3_m2, the format will be: `00SEEEMM`, where S is sign bit, E is exponent - bit, M is mantissa bit - """ - assert isinstance(_layout, FloatxTensorCoreLayout) - packed_floatx_data = pack_tc_floatx( - unpacked_floatx_data, 1 + _layout.ebits + _layout.mbits - ) - return cls(packed_floatx_data, scale, _layout) - - def __repr__(self): - unpacked_floatx_data, scale = self.get_plain() - _layout = self.get_layout() - return f"{self.__class__.__name__}(unpacked_floatx_data={unpacked_floatx_data}, scale={scale}, _layout={_layout})" - - def _apply_fn_to_data(self, fn): - return self.__class__( - fn(self.packed_floatx_data), - fn(self.scale), - self._layout, - ) - - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - device = kwargs.pop("device") - return self.__class__( - self.packed_floatx_data.to(device), - self.scale.to(device), - self._layout, - ) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - if func is aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - elif func is aten.clone.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - elif func is aten._to_copy.default: - return return_and_correct_aliasing( - func, - args, - kwargs, - args[0]._apply_fn_to_data( - lambda x: x.to(device=kwargs.pop("device", None)) - ), - ) - - raise NotImplementedError( - f"FloatxTensorCoreAQTTensorImpl dispatch: attempting to run {func}, this is not supported" - ) - - __torch_function__ = torch._C._disabled_torch_function_impl - - def get_layout(self) -> Layout: - return self._layout - - -def _linear_f16_bf16_act_floatx_weight_check(input_tensor, weight_tensor, bias): - from torchao.dtypes.floatx import FloatxTensorCoreLayout - - return ( - # input is native float32 tensor - not is_traceable_wrapper_subclass(input_tensor) - and input_tensor.is_floating_point() - and input_tensor.dtype in (torch.float16, torch.bfloat16) - and - # weight is floatx Tensor - isinstance(weight_tensor, AffineQuantizedTensor) - and isinstance(weight_tensor._layout, FloatxTensorCoreLayout) - and ( - # weight is using fp6 quantization - (weight_tensor._layout.ebits == 3 and weight_tensor._layout.mbits == 2) - or (weight_tensor._layout.ebits == 2 and weight_tensor._layout.mbits == 3) - or - # weight is using fp5 quantization - (weight_tensor._layout.ebits == 2 and weight_tensor._layout.mbits == 2) - or (weight_tensor._layout.ebits == 3 and weight_tensor._layout.mbits == 1) - ) - ) - - -def _linear_f16_bf16_act_floatx_weight_impl(input_tensor, weight_tensor, bias): - from torchao.ops import quant_llm_linear - - act = input_tensor - weight = weight_tensor - - out_dim, in_dim = weight.shape - act_reshaped = act.view(-1, in_dim) - - # https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py - bsize = act_reshaped.shape[0] - splitK = _SPLIT_K_MAP[(bsize - 1) // 64].get(out_dim, 1) if bsize <= 768 else 1 - - out = quant_llm_linear( - weight._layout.ebits, - weight._layout.mbits, - act_reshaped, - weight.tensor_impl.packed_floatx_data, - weight.tensor_impl.scale, - splitK=splitK, - ) - - if bias is not None: - out += bias - - 
return out.view(*act.shape[:-1], out_dim).to(act.dtype) +# Re-export all public symbols from the new location for backward compatibility +from torchao.prototype.dtypes.floatx.floatx_tensor_core_layout import ( # noqa: F401 + FloatxTensorCoreAQTTensorImpl, # noqa: F401 + FloatxTensorCoreLayout, # noqa: F401 + _linear_f16_bf16_act_floatx_weight_check, # noqa: F401 + _linear_f16_bf16_act_floatx_weight_impl, # noqa: F401 + _pack_tc_floatx, # noqa: F401 + _pack_tc_fp6, # noqa: F401 + from_scaled_tc_floatx, # noqa: F401 + pack_tc_floatx, # noqa: F401 + to_scaled_tc_floatx, # noqa: F401 + unpack_tc_floatx, # noqa: F401 +) diff --git a/torchao/prototype/dtypes/__init__.py b/torchao/prototype/dtypes/__init__.py index 88fe73ab76..bfb82fdd60 100644 --- a/torchao/prototype/dtypes/__init__.py +++ b/torchao/prototype/dtypes/__init__.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +from .floatx import FloatxTensorCoreLayout from .uintx import ( BlockSparseLayout, CutlassInt4PackedLayout, @@ -26,6 +27,7 @@ "MarlinQQQTensor", "to_marlinqqq_quantized_intx", "GemlitePackedLayout", + "FloatxTensorCoreLayout", "UintxLayout", "UintxTensor", "UintxAQTTensorImpl", diff --git a/torchao/prototype/dtypes/floatx/__init__.py b/torchao/prototype/dtypes/floatx/__init__.py new file mode 100644 index 0000000000..edd045f8a9 --- /dev/null +++ b/torchao/prototype/dtypes/floatx/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from .floatx_tensor_core_layout import ( + FloatxTensorCoreLayout, + from_scaled_tc_floatx, + to_scaled_tc_floatx, +) + +__all__ = [ + "FloatxTensorCoreLayout", + "to_scaled_tc_floatx", + "from_scaled_tc_floatx", +] diff --git a/torchao/prototype/dtypes/floatx/floatx_tensor_core_layout.py b/torchao/prototype/dtypes/floatx/floatx_tensor_core_layout.py new file mode 100644 index 0000000000..c7fb1e1a7c --- /dev/null +++ b/torchao/prototype/dtypes/floatx/floatx_tensor_core_layout.py @@ -0,0 +1,666 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
+from dataclasses import dataclass +from functools import reduce +from typing import Optional, Tuple + +import torch +from torch import Tensor +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.dtypes.utils import ( + AQTTensorImpl, + Layout, +) +from torchao.prototype.custom_fp_utils import ( + _f32_to_floatx_unpacked, + _floatx_unpacked_to_f32, + _n_ones, +) + +aten = torch.ops.aten +_ONES_TABLE = [_n_ones(i) for i in range(8)] + + +def _pack(x: Tensor, n_bits: int) -> Tensor: + return reduce( + torch.bitwise_or, + [ + x[..., i :: (8 // n_bits)] << (8 - (i + 1) * n_bits) + for i in range(8 // n_bits) + ], + ) + + +def _unpack(x: Tensor, n_bits: int) -> Tensor: + return torch.stack( + [ + (x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) + for i in range(8 // n_bits) + ], + dim=-1, + ).flatten(-2) + + +# https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h#L87-L116 +def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor: + # the original code unpacks/packs the values from/to uint32 while we unpack/pack the values from/to uint8 + # thus, we need to reverse byte order within a uint32 word. + x = x.reshape(-1, 4).flip(1) + + x = _unpack(x, n_bits) + x = x.view(-1, 4 * (8 // n_bits)) + + if not undo: + bit_order = { + 1: [ + 1, + 5, + 9, + 13, + 17, + 21, + 25, + 29, + 3, + 7, + 11, + 15, + 19, + 23, + 27, + 31, + 0, + 4, + 8, + 12, + 16, + 20, + 24, + 28, + 2, + 6, + 10, + 14, + 18, + 22, + 26, + 30, + ], + 2: [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14], + 4: [1, 5, 3, 7, 0, 4, 2, 6], + }[n_bits] + + else: + # this is inverse of the above, obtained by running + # [v.index(i) for i in range(len(v))] + bit_order = { + 1: [ + 16, + 0, + 24, + 8, + 17, + 1, + 25, + 9, + 18, + 2, + 26, + 10, + 19, + 3, + 27, + 11, + 20, + 4, + 28, + 12, + 21, + 5, + 29, + 13, + 22, + 6, + 30, + 14, + 23, + 7, + 31, + 15, + ], + 2: [8, 0, 12, 4, 9, 1, 13, 5, 10, 2, 14, 6, 11, 3, 15, 7], + 4: [4, 0, 6, 2, 5, 1, 7, 3], + }[n_bits] + + x = x[:, bit_order] + x = _pack(x, n_bits) + + # reverse byte order within a uint32 word again. 
+ x = x.reshape(-1, 4).flip(1) + return x.flatten() + + +# this is a literal adaptation of FP6-LLM ahead-of-time bit-level pre-packing +# https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h +def _pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: + assert tensor.ndim == 2, tensor.dtype == torch.uint8 + M, N = tensor.shape + assert (M % 64 == 0) and (N % 64 == 0) + + # Pass 1 from original code + tensor = tensor.view(M // 64, 4, 2, 8, N // 16, 2, 8) + tensor = tensor.permute(0, 4, 1, 5, 2, 3, 6) + tensor = tensor.reshape(-1, 32, 2) + tensor = tensor.permute(1, 0, 2) + tensor = tensor.flatten() + + used_bits = 0 + fragments = [] + + for y in [1, 2, 4]: + if nbits & y: + mask = (1 << y) - 1 + tensor_ybit = (tensor >> (nbits - used_bits - y)) & mask + tensor_ybit = _pack(tensor_ybit, y) + + tensor_ybit = ( + tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2) + ) # Pass 2 from original code + tensor_ybit = _bit_interleave( + tensor_ybit.flatten(), y + ) # Pass 3 from original code + fragments.append(tensor_ybit) + used_bits += y + + return torch.cat(fragments, dim=0).view(M, -1) + + +# more optimized version of _pack_tc_floatx() for FP6 by merging ops +def _pack_tc_fp6(tensor: Tensor) -> Tensor: + assert tensor.ndim == 2, tensor.dtype == torch.uint8 + M, N = tensor.shape + assert (M % 64 == 0) and (N % 64 == 0) + + tensor = tensor.view(M // 64, 2, 2, 2, 8, N // 16, 2, 8) + tensor = tensor.flip(3) + + tensor_2bit = (tensor >> 4) & 0b11 + tensor_2bit = tensor_2bit.permute(0, 5, 1, 4, 7, 3, 2, 6) + tensor_2bit = _pack(tensor_2bit.flatten(), 2) + + tensor_4bit = tensor & 0b1111 + tensor_4bit = tensor_4bit.permute(0, 5, 1, 2, 4, 7, 3, 6) + tensor_4bit = _pack(tensor_4bit.flatten(), 4) + + return torch.cat([tensor_2bit, tensor_4bit], dim=0).view(M, -1) + + +# currently only optimize for TC-FP6 packing +def pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: + if nbits == 6: + return _pack_tc_fp6(tensor) + return _pack_tc_floatx(tensor, nbits) + + +def to_scaled_tc_floatx( + tensor: Tensor, ebits: int, mbits: int +) -> Tuple[Tensor, Tensor]: + # _n_ones() is not compatible with torch.compile() due to << operator + # https://github.com/pytorch/pytorch/issues/119152 + # exp_bias = _n_ones(ebits - 1) + # max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits)) + + # workaround: global lookup table + exp_bias = _ONES_TABLE[ebits - 1] + max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * ( + _ONES_TABLE[mbits + 1] / (2**mbits) + ) + + dtype = tensor.dtype + tensor = tensor.float() + scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal + tensor_floatx = _f32_to_floatx_unpacked(tensor / scale.view(-1, 1), ebits, mbits) + tensor_tc_floatx = pack_tc_floatx(tensor_floatx, 1 + ebits + mbits) + return tensor_tc_floatx, scale.to(dtype) + + +# inverse of _pack_tc_floatx() +def _unpack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: + assert tensor.ndim == 2 and tensor.dtype == torch.uint8 + M = tensor.shape[0] + size = tensor.numel() + tensor = tensor.flatten() + offset = 0 + used_bits = 0 + + tensor_floatx = None + + for y in [1, 2, 4]: + if nbits & y: + size_ybit = size // nbits * y + tensor_ybit = tensor[offset : offset + size_ybit] + offset += size_ybit + + tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) # undo Pass 3 + tensor_ybit = ( + tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2) + ) # undo Pass 2 + + tensor_ybit = _unpack(tensor_ybit.flatten(), y) + tensor_ybit = tensor_ybit << 
(nbits - used_bits - y) + used_bits += y + + if tensor_floatx is None: + tensor_floatx = tensor_ybit + else: + tensor_floatx |= tensor_ybit + + # undo Pass 1 + tensor_floatx = tensor_floatx.view(32, -1, 2).permute(1, 0, 2) + tensor_floatx = tensor_floatx.reshape(M // 64, -1, 4, 2, 2, 8, 8) + tensor_floatx = tensor_floatx.permute(0, 2, 4, 5, 1, 3, 6) + tensor_floatx = tensor_floatx.reshape(M, -1) + return tensor_floatx + + +# more optimized version of _unpack_tc_floatx() for FP6 by merging ops +# inverse of _unpack_tc_fp6() +def _unpack_tc_fp6(tensor: Tensor) -> Tensor: + assert tensor.ndim == 2 and tensor.dtype == torch.uint8 + M = tensor.shape[0] + N = tensor.shape[1] // 3 * 4 + assert (M % 64 == 0) and (N % 64 == 0) + size_2bit = M * N // 4 + size_4bit = M * N // 2 + tensor = tensor.view(-1) + assert tensor.numel() == size_2bit + size_4bit + + tensor_2bit, tensor_4bit = tensor.split([size_2bit, size_4bit]) + + tensor_2bit = _unpack(tensor_2bit, 2) + tensor_2bit = tensor_2bit.view(M // 64, N // 16, 2, 8, 8, 2, 2, 2) + tensor_2bit = tensor_2bit.permute(0, 2, 6, 5, 3, 1, 7, 4) + + tensor_4bit = _unpack(tensor_4bit, 4) + tensor_4bit = tensor_4bit.view(M // 64, N // 16, 2, 2, 8, 8, 2, 2) + tensor_4bit = tensor_4bit.permute(0, 2, 3, 6, 4, 1, 7, 5) + + tensor_fp6 = (tensor_2bit << 4) | tensor_4bit + tensor_fp6 = tensor_fp6.flip(3).reshape(M, N) + return tensor_fp6 + + +def unpack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: + if nbits == 6: + return _unpack_tc_fp6(tensor) + return _unpack_tc_floatx(tensor, nbits) + + +def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Tensor: + floatx_unpacked = unpack_tc_floatx(tensor, 1 + ebits + mbits) + tensor = _floatx_unpacked_to_f32(floatx_unpacked, ebits, mbits) + if scale is not None: + tensor = tensor * scale.float().view(-1, 1) + return tensor + + +# https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py +_SPLIT_K_MAP = [ + { # tokens: [1, 64] + 3072: 18, + 4096: 13, + 5120: 10, + 6144: 9, + 8192: 6, + 10240: 5, + 14336: 7, + 28672: 7, + 57344: 7, + }, + { # tokens: [65:128] + 3072: 9, + 4096: 6, + 5120: 5, + 6144: 9, + 8192: 3, + 10240: 5, + 14336: 7, + 28672: 7, + 57344: 6, + }, + { # tokens: [129:192] + 3072: 6, + 4096: 4, + 5120: 7, + 6144: 3, + 8192: 2, + 10240: 5, + 14336: 5, + 28672: 5, + 57344: 4, + }, + { # tokens: [193:256] + 3072: 9, + 4096: 3, + 5120: 5, + 6144: 2, + 8192: 5, + 10240: 4, + 14336: 8, + 28672: 6, + 57344: 4, + }, + { # tokens: [257:320] + 3072: 7, + 4096: 5, + 5120: 2, + 6144: 5, + 8192: 4, + 10240: 1, + 14336: 3, + 28672: 3, + 57344: 4, + }, + { # tokens: [321:384] + 3072: 3, + 4096: 2, + 5120: 5, + 6144: 3, + 8192: 1, + 10240: 8, + 14336: 3, + 28672: 4, + 57344: 3, + }, + { # tokens: [385:448] + 3072: 5, + 4096: 7, + 5120: 3, + 6144: 5, + 8192: 7, + 10240: 3, + 14336: 1, + 28672: 1, + 57344: 3, + }, + { # tokens: [449:512] + 3072: 2, + 4096: 5, + 5120: 4, + 6144: 1, + 8192: 5, + 10240: 2, + 14336: 6, + 28672: 4, + 57344: 1, + }, + { # tokens: [513:576] + 3072: 2, + 4096: 3, + 5120: 1, + 6144: 1, + 8192: 3, + 10240: 3, + 14336: 3, + 28672: 1, + 57344: 1, + }, + { # tokens: [577:640] + 3072: 5, + 4096: 4, + 5120: 1, + 6144: 4, + 8192: 2, + 10240: 1, + 14336: 1, + 28672: 1, + 57344: 1, + }, + { # tokens: [641:704] + 3072: 3, + 4096: 1, + 5120: 2, + 6144: 2, + 8192: 1, + 10240: 2, + 14336: 1, + 28672: 1, + 57344: 1, + }, + { # tokens: [705:768] + 3072: 3, + 4096: 1, + 5120: 3, + 6144: 2, + 8192: 
1, + 10240: 1, + 14336: 1, + 28672: 1, + 57344: 1, + }, +] + + +# quantization api integrations +@dataclass(frozen=True) +class FloatxTensorCoreLayout(Layout): + """FloatxTensorCoreLayout is a data class that defines the layout for a tensor with a specific number of exponent bits (ebits) and mantissa bits (mbits). + This layout is used in the context of quantization and packing of tensors optimized for TensorCore operations. + """ + + ebits: int + mbits: int + + +@register_layout(FloatxTensorCoreLayout) +class FloatxTensorCoreAQTTensorImpl(AQTTensorImpl): + """FloatxTensorCoreAQTTensorImpl represents a Tensor with dtype floatx(ebits=a, mbits=b), + it has a internal tensor field of "packed_floatx_data", which is packed from the + uint8 unpacked data (the output of `_quantize_affine_floatx` operator) + + The packing is optimized for TensorCore, from the fp6-llm paper: https://arxiv.org/abs/2401.14112 + github repo: https://github.com/usyd-fsalab/fp6_llm, now renamed to quant-llm + + At a high level packing is done by grouping bits into 1 bit fragments (shards), 2 bit fragments and + 4 bit fragments each fragments are packed separately and concatenated together. + For example for 6 bit dtype, we can extract the first 4 bits for all elements and pack them together + in a fragment, and extract the last 2 bits for all elements and pack them into fragment, in the end + we concatenate the fragments together. + + If original Tensor shape is (M, N), and the data is in nbit, the shape of the packed data will be + (M, N // 8 * nbit) + + FloatxTensorCoreAQTTensorImpl.from_plain takes an unpacked uint8 floatx Tensor of shape (M, N), with format of + (zero padding bits + sign bit + exponent bits + mantissa bits), e.g. 00SEEEMM for fp6_e3_m2 + it will then pack the weight and instantiate the FloatxTensorCoreAQTTensorImpl tensor + FloatxTensorCoreAQTTensorImpl.__init__() takes a packed floatx Tensor of shape (M, N // 8 * nbit) + """ + + def __new__( + cls, + packed_floatx_data: torch.Tensor, + scale: torch.Tensor, + _layout: Layout, + ): + assert packed_floatx_data.ndim == 2 + assert packed_floatx_data.dtype == torch.uint8 + shape = ( + packed_floatx_data.shape[0], + packed_floatx_data.shape[1] // (1 + _layout.ebits + _layout.mbits) * 8, + ) + kwargs = {} + kwargs["device"] = packed_floatx_data.device + kwargs["layout"] = ( + kwargs.get("layout") + if kwargs.get("layout", False) + else packed_floatx_data.layout + ) + kwargs["dtype"] = packed_floatx_data.dtype + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + packed_floatx_data: torch.Tensor, + scale: torch.Tensor, + _layout: Layout, + ): + self.packed_floatx_data = packed_floatx_data + self.scale = scale + self._layout = _layout + + def __tensor_flatten__(self): + return ["packed_floatx_data", "scale"], [self._layout] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + packed_floatx_data, scale = ( + tensor_data_dict["packed_floatx_data"], + tensor_data_dict["scale"], + ) + (_layout,) = tensor_attributes + return cls(packed_floatx_data, scale, _layout) + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor]: + unpacked_floatx_data = unpack_tc_floatx( + self.packed_floatx_data, 1 + self._layout.ebits + self._layout.mbits + ) + return unpacked_floatx_data, self.scale + + @classmethod + def from_plain( + cls, + unpacked_floatx_data: torch.Tensor, + scale: torch.Tensor, + 
zero_point: Optional[torch.Tensor], + _layout: Layout, + ): + """ + Format for `unpacked_floatx_data` will be: + zero padding bits | sign bit | exponent bits | mantissa bits + + For example for fp6_e3_m2, the format will be: `00SEEEMM`, where S is sign bit, E is exponent + bit, M is mantissa bit + """ + assert isinstance(_layout, FloatxTensorCoreLayout) + packed_floatx_data = pack_tc_floatx( + unpacked_floatx_data, 1 + _layout.ebits + _layout.mbits + ) + return cls(packed_floatx_data, scale, _layout) + + def __repr__(self): + unpacked_floatx_data, scale = self.get_plain() + _layout = self.get_layout() + return f"{self.__class__.__name__}(unpacked_floatx_data={unpacked_floatx_data}, scale={scale}, _layout={_layout})" + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.packed_floatx_data), + fn(self.scale), + self._layout, + ) + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + device = kwargs.pop("device") + return self.__class__( + self.packed_floatx_data.to(device), + self.scale.to(device), + self._layout, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + elif func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + elif func is aten._to_copy.default: + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0]._apply_fn_to_data( + lambda x: x.to(device=kwargs.pop("device", None)) + ), + ) + + raise NotImplementedError( + f"FloatxTensorCoreAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + __torch_function__ = torch._C._disabled_torch_function_impl + + def get_layout(self) -> Layout: + return self._layout + + +def _linear_f16_bf16_act_floatx_weight_check(input_tensor, weight_tensor, bias): + from torchao.dtypes.floatx import FloatxTensorCoreLayout + + return ( + # input is native float32 tensor + not is_traceable_wrapper_subclass(input_tensor) + and input_tensor.is_floating_point() + and input_tensor.dtype in (torch.float16, torch.bfloat16) + and + # weight is floatx Tensor + isinstance(weight_tensor, AffineQuantizedTensor) + and isinstance(weight_tensor._layout, FloatxTensorCoreLayout) + and ( + # weight is using fp6 quantization + (weight_tensor._layout.ebits == 3 and weight_tensor._layout.mbits == 2) + or (weight_tensor._layout.ebits == 2 and weight_tensor._layout.mbits == 3) + or + # weight is using fp5 quantization + (weight_tensor._layout.ebits == 2 and weight_tensor._layout.mbits == 2) + or (weight_tensor._layout.ebits == 3 and weight_tensor._layout.mbits == 1) + ) + ) + + +def _linear_f16_bf16_act_floatx_weight_impl(input_tensor, weight_tensor, bias): + from torchao.ops import quant_llm_linear + + act = input_tensor + weight = weight_tensor + + out_dim, in_dim = weight.shape + act_reshaped = act.view(-1, in_dim) + + # https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py + bsize = act_reshaped.shape[0] + splitK = _SPLIT_K_MAP[(bsize - 1) // 64].get(out_dim, 1) if bsize <= 768 else 1 + + out = quant_llm_linear( + weight._layout.ebits, + weight._layout.mbits, + act_reshaped, + weight.tensor_impl.packed_floatx_data, + weight.tensor_impl.scale, + splitK=splitK, + ) + + if bias is not None: + out += bias + + 
return out.view(*act.shape[:-1], out_dim).to(act.dtype) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index f8602fa66c..83af9068ae 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -2401,7 +2401,7 @@ def _fpx_weight_only_transform( module = _unwrap_float8_linear(module) from torchao.dtypes import to_affine_quantized_fpx - from torchao.dtypes.floatx import FloatxTensorCoreLayout + from torchao.prototype.dtypes.floatx import FloatxTensorCoreLayout assert weight.dim() == 2, f"floatx only works for 2-d Tensor, got: {weight.dim()}" out_dim, in_dim = weight.shape From d7b537b0293798daec58bb98a10d988f3bebc2d5 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Wed, 12 Nov 2025 17:53:34 -0800 Subject: [PATCH 19/22] Use conda libgcc-ng 11.2 for nightly tests (#3326) --- .github/workflows/regression_test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index 149a7b07da..575aca6df0 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -45,7 +45,7 @@ jobs: gpu-arch-version: ${{ matrix.gpu-arch-version }} submodules: recursive script: | - conda create -n venv python=3.10 -y + conda create -n venv python=3.10 libgcc-ng=11.2.0 libstdcxx-ng=11.2.0 -y conda activate venv python -m pip install --upgrade pip pip install ${{ matrix.torch-spec }} @@ -117,7 +117,7 @@ jobs: gpu-arch-version: ${{ matrix.gpu-arch-version }} submodules: recursive script: | - conda create -n venv python=3.10 libgcc-ng=11.2.0 libstdcxx-ng=11.2.0 -y + conda create -n venv python=3.10 libgcc-ng=11.2.0 libstdcxx-ng=11.2.0 -y conda activate venv python -m pip install --upgrade pip pip install ${{ matrix.torch-spec }} From 9ba0a3f487f230bba4b56fd978a1a29fca7e70a2 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Thu, 13 Nov 2025 18:13:42 +0000 Subject: [PATCH 20/22] Fix tests --- torchao/quantization/pt2e/utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/torchao/quantization/pt2e/utils.py b/torchao/quantization/pt2e/utils.py index 7ff1dbc619..df92d485b9 100644 --- a/torchao/quantization/pt2e/utils.py +++ b/torchao/quantization/pt2e/utils.py @@ -859,6 +859,15 @@ def _get_aten_graph_module_for_pattern( ): aten_pattern.graph.erase_node(node) # type: ignore[operator, union-attr] + # PyTorch 2.9+ adds _guards_fn nodes to exported graphs. + # These should not be part of pattern matching, so remove them. 
+ for node in list(aten_pattern.graph.nodes): # type: ignore[union-attr] + if node.op == "call_module" and node.name == "_guards_fn": + aten_pattern.graph.erase_node(node) # type: ignore[operator, union-attr] + # Also remove the _guards_fn module from the graph module if it exists + if hasattr(aten_pattern, "_guards_fn"): + delattr(aten_pattern, "_guards_fn") + aten_pattern.graph.eliminate_dead_code() # type: ignore[operator, union-attr] aten_pattern.recompile() # type: ignore[operator] From 38848060580baa1bce0d3632e31d18417f55b677 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Fri, 14 Nov 2025 22:15:01 +0000 Subject: [PATCH 21/22] Add a condition to run only if torch 2.9 --- torchao/quantization/pt2e/utils.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/torchao/quantization/pt2e/utils.py b/torchao/quantization/pt2e/utils.py index df92d485b9..333e8ffc00 100644 --- a/torchao/quantization/pt2e/utils.py +++ b/torchao/quantization/pt2e/utils.py @@ -859,14 +859,15 @@ def _get_aten_graph_module_for_pattern( ): aten_pattern.graph.erase_node(node) # type: ignore[operator, union-attr] - # PyTorch 2.9+ adds _guards_fn nodes to exported graphs. - # These should not be part of pattern matching, so remove them. - for node in list(aten_pattern.graph.nodes): # type: ignore[union-attr] - if node.op == "call_module" and node.name == "_guards_fn": - aten_pattern.graph.erase_node(node) # type: ignore[operator, union-attr] - # Also remove the _guards_fn module from the graph module if it exists - if hasattr(aten_pattern, "_guards_fn"): - delattr(aten_pattern, "_guards_fn") + if torch.__version__.startswith("2.9"): + # PyTorch 2.9 adds _guards_fn nodes to exported graphs. + # These have errors only on torch 2.9 and 2.9.0 + for node in list(aten_pattern.graph.nodes): # type: ignore[union-attr] + if node.op == "call_module" and node.name == "_guards_fn": + aten_pattern.graph.erase_node(node) # type: ignore[operator, union-attr] + # Also remove the _guards_fn module from the graph module if it exists + if hasattr(aten_pattern, "_guards_fn"): + delattr(aten_pattern, "_guards_fn") aten_pattern.graph.eliminate_dead_code() # type: ignore[operator, union-attr] aten_pattern.recompile() # type: ignore[operator] From a543b2a7132d2c9cc18d95454845a67190c31da9 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Fri, 14 Nov 2025 14:17:10 -0800 Subject: [PATCH 22/22] Update utils.py --- torchao/quantization/pt2e/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/quantization/pt2e/utils.py b/torchao/quantization/pt2e/utils.py index 333e8ffc00..f3cbffa430 100644 --- a/torchao/quantization/pt2e/utils.py +++ b/torchao/quantization/pt2e/utils.py @@ -861,7 +861,7 @@ def _get_aten_graph_module_for_pattern( if torch.__version__.startswith("2.9"): # PyTorch 2.9 adds _guards_fn nodes to exported graphs. - # These have errors only on torch 2.9 and 2.9.0 + # These have errors only on torch 2.9.0 and 2.9.1 for node in list(aten_pattern.graph.nodes): # type: ignore[union-attr] if node.op == "call_module" and node.name == "_guards_fn": aten_pattern.graph.erase_node(node) # type: ignore[operator, union-attr]
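
The `_guards_fn` handling that patches 20 through 22 converge on reads, taken together, like a single cleanup step. Below is a minimal sketch of that logic, assuming it operates on a torch.fx GraphModule produced by export; the helper name `_strip_guards_fn` is hypothetical, while the body mirrors the patched code.

# Sketch only: the version-gated _guards_fn cleanup from the patches above,
# refactored into a hypothetical standalone helper.
import torch
from torch.fx import GraphModule


def _strip_guards_fn(gm: GraphModule) -> GraphModule:
    # torch 2.9.x export inserts a `_guards_fn` call_module node that should
    # not take part in pattern matching, so drop it before matching.
    if torch.__version__.startswith("2.9"):
        for node in list(gm.graph.nodes):
            if node.op == "call_module" and node.name == "_guards_fn":
                gm.graph.erase_node(node)
        # Also drop the backing submodule attribute, if present.
        if hasattr(gm, "_guards_fn"):
            delattr(gm, "_guards_fn")
        gm.graph.eliminate_dead_code()
        gm.recompile()
    return gm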
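
For the floatx relocation earlier in this series, a minimal round-trip sketch of the helpers at their new prototype path follows; the weight `w` and the fp6 choice (ebits=3, mbits=2) are illustrative, and per the packing asserts both dimensions must be multiples of 64.

# Sketch only: pack a weight with the relocated tensor-core floatx helpers,
# then reconstruct an fp32 approximation to gauge the quantization error.
import torch

from torchao.prototype.dtypes.floatx import (
    from_scaled_tc_floatx,
    to_scaled_tc_floatx,
)

w = torch.randn(256, 256, dtype=torch.bfloat16)  # both dims multiples of 64

# Returns the TensorCore-ordered uint8 payload plus a per-output-row scale.
packed, scale = to_scaled_tc_floatx(w, ebits=3, mbits=2)

# Dequantize back to fp32 and compare against the original weight.
w_approx = from_scaled_tc_floatx(packed, 3, 2, scale=scale)
print((w.float() - w_approx).abs().max())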