-
Notifications
You must be signed in to change notification settings - Fork 371
Add Int8Tensor for clearer interface #3038
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
08e9095
db23cf3
b861dbc
9383550
2c84ba4
8ddddd3
bd6f58a
b5cb3c8
9a51cae
c53dad0
d300b02
c43a3ec
590e0b7
b3d4f3e
df79aa8
910906b
c61b36e
0a45f90
1251187
844d99d
a844678
2c0389a
bafeb43
7006cae
49a7a89
062f3cc
680cec9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,128 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD 3-Clause license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import unittest | ||
|
|
||
| import torch | ||
| from torch.testing._internal import common_utils | ||
| from torch.testing._internal.common_utils import run_tests | ||
| from torch._inductor.utils import run_and_get_code | ||
|
|
||
| from torchao.quantization.quantize_.common import KernelPreference | ||
| from torchao.quantization.quantize_.workflows.int8.int8_tensor import ( | ||
| Int8Tensor, | ||
| QuantizeTensorToInt8Kwargs, | ||
| ) | ||
| from torchao.quantization.utils import compute_error | ||
| from torchao.testing.utils import TorchAOIntegrationTestCase | ||
|
|
||
|
|
||
| @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") | ||
| class TestInt8Tensor(TorchAOIntegrationTestCase): | ||
| def setUp(self): | ||
| super().setUp() | ||
| torch.manual_seed(42) | ||
| self.weight_fp = torch.randn(4, 3, dtype=torch.float32) | ||
| self.input_fp = torch.randn(2, 3, dtype=torch.float32) | ||
| self.bias = torch.randn(4) | ||
| self.block_size = [4, 3] | ||
|
|
||
| def test_creation_and_attributes(self): | ||
| """Test tensor creation, dtypes, and ranges""" | ||
| tensor = Int8Tensor.from_hp(self.weight_fp, self.block_size) | ||
|
|
||
| self.assertEqual(tensor.shape, (4, 3)) | ||
| self.assertEqual(tensor.qdata.dtype, torch.int8) | ||
| self.assertTrue( | ||
| torch.all(tensor.qdata >= -128) and torch.all(tensor.qdata <= 127) | ||
| ) | ||
|
|
||
| @common_utils.parametrize( | ||
| "kernel_preference", | ||
| [KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM], | ||
| ) | ||
| def test_kernel_preference(self, kernel_preference): | ||
| """Test Int8Tensor with different kernels""" | ||
| tensor = Int8Tensor.from_hp( | ||
| self.weight_fp, self.block_size, kernel_preference=kernel_preference | ||
| ) | ||
|
|
||
| self.assertEqual(tensor.kernel_preference, kernel_preference) | ||
|
|
||
| def test_linear_operations(self): | ||
| """Test fp+int8 and int8+int8 linear ops with quantization error check""" | ||
| weight_q8 = Int8Tensor.from_hp(self.weight_fp, self.block_size) | ||
| input_q8 = Int8Tensor.from_hp(self.input_fp, self.block_size) | ||
|
|
||
| reference = torch.nn.functional.linear(self.input_fp, self.weight_fp, self.bias) | ||
| result_fp = torch.nn.functional.linear(self.input_fp, weight_q8, self.bias) | ||
| result_q8 = torch.nn.functional.linear(input_q8, weight_q8, self.bias) | ||
|
|
||
| self.assertEqual(result_fp.shape, reference.shape) | ||
| self.assertEqual(result_q8.shape, reference.shape) | ||
| self.assertTrue(compute_error(result_fp, reference) > 10) | ||
| self.assertTrue(compute_error(result_q8, reference) > 10) | ||
|
|
||
| def test_dynamic_quantization(self): | ||
| weight_q8_dynamic = Int8Tensor.from_hp( | ||
| self.weight_fp, | ||
| self.block_size, | ||
| act_quant_kwargs=QuantizeTensorToInt8Kwargs(), | ||
| ) | ||
|
|
||
| reference = torch.nn.functional.linear(self.input_fp, self.weight_fp, self.bias) | ||
| result_dynamic = torch.nn.functional.linear( | ||
| self.input_fp, weight_q8_dynamic, self.bias | ||
| ) | ||
|
|
||
| self.assertEqual(result_dynamic.shape, reference.shape) | ||
|
||
|
|
||
| @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") | ||
| def test_expected_kernel_operations(self): | ||
| """Test Int8Tensor with FBGEMM kernels""" | ||
|
|
||
| # Setup model with Int8Tensor | ||
| weight_q8 = Int8Tensor.from_hp( | ||
| self.weight_fp, | ||
| self.block_size, | ||
| kernel_preference=KernelPreference.FBGEMM | ||
| ) | ||
|
|
||
| def model(x): | ||
| return torch.nn.functional.linear(x, weight_q8, self.bias) | ||
|
|
||
| compiled_model = torch.compile(model) | ||
|
|
||
| output, code = run_and_get_code(compiled_model, self.input_fp) | ||
|
|
||
| self.assertEqual(output.shape, (2, 4)) | ||
| self.assertTrue(len(code) > 0, "Should generate some compiled code") | ||
|
|
||
| # Test dequantization kernel | ||
| dequant_output = torch.ops.aten.dequantize.self(weight_q8) | ||
| self.assertEqual(dequant_output.shape, self.weight_fp.shape) | ||
|
|
||
| def test_error_handling_and_dequant(self): | ||
| """Test input validation and dequantization accuracy""" | ||
| # Test 1D tensor validation | ||
| with self.assertRaises((AssertionError, ValueError, RuntimeError)): | ||
| Int8Tensor.from_hp(torch.randn(5), [1]) | ||
|
|
||
| # Test wrong block_size validation | ||
| with self.assertRaises((AssertionError, ValueError, RuntimeError)): | ||
| Int8Tensor.from_hp(self.weight_fp, [1]) | ||
|
|
||
| # Test dequantization with exact values | ||
| test_data = torch.tensor([[1.0, -1.0]], dtype=torch.float32) | ||
| tensor = Int8Tensor.from_hp(test_data, [1, 1]) | ||
|
|
||
| dequantized = torch.ops.aten.dequantize.self(tensor) | ||
| self.assertEqual(dequantized.shape, test_data.shape) | ||
| self.assertLess(torch.abs(dequantized - test_data).max().item(), 0.1) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| run_tests() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -39,7 +39,9 @@ def _choose_quant_func_and_quantize_tensor( | |
| """ | ||
| from torchao.quantization.quantize_.workflows import ( | ||
| Float8Tensor, | ||
| Int8Tensor, | ||
| QuantizeTensorToFloat8Kwargs, | ||
| QuantizeTensorToInt8Kwargs, | ||
| ) | ||
|
|
||
| if isinstance(quant_kwargs, QuantizeTensorToFloat8Kwargs): | ||
|
|
@@ -52,5 +54,11 @@ def _choose_quant_func_and_quantize_tensor( | |
| quant_kwargs.hp_value_ub, | ||
| quant_kwargs.kernel_preference, | ||
| ) | ||
| elif isinstance(quant_kwargs, QuantizeTensorToInt8Kwargs): | ||
| return Int8Tensor.from_hp( | ||
| tensor, | ||
| quant_kwargs.block_size or [1, tensor.shape[-1]], | ||
|
||
| kernel_preference=quant_kwargs.kernel_preference, | ||
| ) | ||
|
|
||
| raise NotImplementedError(f"Quant kwargs not supported: {quant_kwargs}") | ||
Uh oh!
There was an error while loading. Please reload this page.