From 6c4d01def77743f864dbb535edff0aafb36875ba Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sat, 5 Jul 2025 17:47:06 +0800 Subject: [PATCH 1/6] add gguf kernel support Signed-off-by: Isotr0py <2037008807@qq.com> --- src/diffusers/quantizers/gguf/utils.py | 84 +++++++++++++++++++++++++- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/import_utils.py | 5 ++ 3 files changed, 89 insertions(+), 1 deletion(-) diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py index 41d351712961..27010408541b 100644 --- a/src/diffusers/quantizers/gguf/utils.py +++ b/src/diffusers/quantizers/gguf/utils.py @@ -17,10 +17,11 @@ from contextlib import nullcontext import gguf +from gguf import GGMLQuantizationType as WeightType import torch import torch.nn as nn -from ...utils import is_accelerate_available +from ...utils import is_accelerate_available, is_kernels_available if is_accelerate_available(): @@ -29,6 +30,76 @@ from accelerate.hooks import add_hook_to_module, remove_hook_from_module +can_use_cuda_kernels = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7 +if can_use_cuda_kernels and is_kernels_available(): + from kernels import get_kernel + ops = get_kernel("Isotr0py/ggml") +else: + ops = None + + +UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16} +STANDARD_QUANT_TYPES = { + WeightType.Q4_0, + WeightType.Q4_1, + WeightType.Q5_0, + WeightType.Q5_1, + WeightType.Q8_0, + WeightType.Q8_1, +} +KQUANT_TYPES = { + WeightType.Q2_K, + WeightType.Q3_K, + WeightType.Q4_K, + WeightType.Q5_K, + WeightType.Q6_K, +} +IMATRIX_QUANT_TYPES = { + WeightType.IQ1_M, + WeightType.IQ1_S, + WeightType.IQ2_XXS, + WeightType.IQ2_XS, + WeightType.IQ2_S, + WeightType.IQ3_XXS, + WeightType.IQ3_S, + WeightType.IQ4_XS, + WeightType.IQ4_NL, +} +# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization. +# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add +# MMQ kernel for I-Matrix quantization. +DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES +MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES +MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES + + +def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, + qweight_type: int) -> torch.Tensor: + # there is no need to call any kernel for fp16/bf16 + if qweight_type in UNQUANTIZED_TYPES: + return x @ qweight.T + # enable MMVQ in contiguous batching with batch_size=1 + if qweight_type in MMVQ_QUANT_TYPES: + y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0]) + # Use MMQ Kernel if it's available (standard + k-quants) + elif qweight_type in MMQ_QUANT_TYPES: + y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0]) + # If there is no available MMQ kernel, fallback to dequantize + elif qweight_type in DEQUANT_TYPES: + block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] + shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size) + weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype) + y = x @ weight.T + else: + # Raise an error if the quantization type is not supported. + # Might be useful if llama.cpp adds a new quantization type. + # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type. 
+ qweight_type = WeightType(qweight_type) + raise NotImplementedError( + f"Unsupported GGUF quantization type: {qweight_type}") + return y + + # Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook def _create_accelerate_new_hook(old_hook): r""" @@ -451,11 +522,22 @@ def __init__( ) -> None: super().__init__(in_features, out_features, bias, device) self.compute_dtype = compute_dtype + self.device = device def forward(self, inputs): + if ops is not None and self.weight.is_cuda and inputs.is_cuda: + return self.forward_cuda(inputs) + return self.forward_native(inputs) + + def forward_native(self, inputs): weight = dequantize_gguf_tensor(self.weight) weight = weight.to(self.compute_dtype) bias = self.bias.to(self.compute_dtype) if self.bias is not None else None output = torch.nn.functional.linear(inputs, weight, bias) return output + + def forward_cuda(self, inputs): + quant_type = self.weight.quant_type + return _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type) + diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 2df05cb8eb36..72f020ec193e 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -76,6 +76,7 @@ is_hpu_available, is_inflect_available, is_invisible_watermark_available, + is_kernels_available, is_k_diffusion_available, is_k_diffusion_version, is_librosa_available, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index f12e9de33172..6174d5b72c32 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -192,6 +192,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b _torch_npu_available, _torch_npu_version = _is_package_available("torch_npu") _transformers_available, _transformers_version = _is_package_available("transformers") _hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub") +_kernels_available, _kernels_version = _is_package_available("kernels") _inflect_available, _inflect_version = _is_package_available("inflect") _unidecode_available, _unidecode_version = _is_package_available("unidecode") _k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion") @@ -274,6 +275,10 @@ def is_accelerate_available(): return _accelerate_available +def is_kernels_available(): + return _kernels_available + + def is_k_diffusion_available(): return _k_diffusion_available From 66bd237bc5fddafa813f6564cf041a539f39429d Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 6 Jul 2025 01:00:01 +0800 Subject: [PATCH 2/6] fix Signed-off-by: Isotr0py <2037008807@qq.com> --- src/diffusers/quantizers/gguf/utils.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py index 27010408541b..03521eadb2b4 100644 --- a/src/diffusers/quantizers/gguf/utils.py +++ b/src/diffusers/quantizers/gguf/utils.py @@ -78,17 +78,21 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, # there is no need to call any kernel for fp16/bf16 if qweight_type in UNQUANTIZED_TYPES: return x @ qweight.T - # enable MMVQ in contiguous batching with batch_size=1 - if qweight_type in MMVQ_QUANT_TYPES: - y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0]) - # Use MMQ Kernel if it's available (standard + k-quants) - elif qweight_type in MMQ_QUANT_TYPES: - y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, 
qweight.shape[0]) + + # TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for + # contiguous batching and inefficient with diffusers' batching, + # so we disabled it now. + + # elif qweight_type in MMVQ_QUANT_TYPES: + # y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0]) + # elif qweight_type in MMQ_QUANT_TYPES: + # y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0]) # If there is no available MMQ kernel, fallback to dequantize + elif qweight_type in DEQUANT_TYPES: block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size) - weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype) + weight = ops.ggml_dequantize(qweight, qweight_type, *shape) y = x @ weight.T else: # Raise an error if the quantization type is not supported. @@ -539,5 +543,10 @@ def forward_native(self, inputs): def forward_cuda(self, inputs): quant_type = self.weight.quant_type - return _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type) + orig_shape = inputs.shape + inputs = inputs.view(-1, orig_shape[-1]) + output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type) + if self.bias is not None: + output = output + self.bias.to(self.compute_dtype) + return output.view(*orig_shape[:-1], -1) From e46571a7aad95b2a4efc10d076740ad260e129fc Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 6 Jul 2025 01:47:13 +0800 Subject: [PATCH 3/6] optimize Signed-off-by: Isotr0py <2037008807@qq.com> --- src/diffusers/quantizers/gguf/utils.py | 68 ++++++++++++-------------- src/diffusers/utils/__init__.py | 2 +- 2 files changed, 32 insertions(+), 38 deletions(-) diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py index 03521eadb2b4..31f6ec3e7321 100644 --- a/src/diffusers/quantizers/gguf/utils.py +++ b/src/diffusers/quantizers/gguf/utils.py @@ -17,7 +17,6 @@ from contextlib import nullcontext import gguf -from gguf import GGMLQuantizationType as WeightType import torch import torch.nn as nn @@ -33,37 +32,37 @@ can_use_cuda_kernels = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7 if can_use_cuda_kernels and is_kernels_available(): from kernels import get_kernel + ops = get_kernel("Isotr0py/ggml") else: ops = None - -UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16} +UNQUANTIZED_TYPES = {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16} STANDARD_QUANT_TYPES = { - WeightType.Q4_0, - WeightType.Q4_1, - WeightType.Q5_0, - WeightType.Q5_1, - WeightType.Q8_0, - WeightType.Q8_1, + gguf.GGMLQuantizationType.Q4_0, + gguf.GGMLQuantizationType.Q4_1, + gguf.GGMLQuantizationType.Q5_0, + gguf.GGMLQuantizationType.Q5_1, + gguf.GGMLQuantizationType.Q8_0, + gguf.GGMLQuantizationType.Q8_1, } KQUANT_TYPES = { - WeightType.Q2_K, - WeightType.Q3_K, - WeightType.Q4_K, - WeightType.Q5_K, - WeightType.Q6_K, + gguf.GGMLQuantizationType.Q2_K, + gguf.GGMLQuantizationType.Q3_K, + gguf.GGMLQuantizationType.Q4_K, + gguf.GGMLQuantizationType.Q5_K, + gguf.GGMLQuantizationType.Q6_K, } IMATRIX_QUANT_TYPES = { - WeightType.IQ1_M, - WeightType.IQ1_S, - WeightType.IQ2_XXS, - WeightType.IQ2_XS, - WeightType.IQ2_S, - WeightType.IQ3_XXS, - WeightType.IQ3_S, - WeightType.IQ4_XS, - WeightType.IQ4_NL, + gguf.GGMLQuantizationType.IQ1_M, + gguf.GGMLQuantizationType.IQ1_S, + gguf.GGMLQuantizationType.IQ2_XXS, + gguf.GGMLQuantizationType.IQ2_XS, + 
gguf.GGMLQuantizationType.IQ2_S, + gguf.GGMLQuantizationType.IQ3_XXS, + gguf.GGMLQuantizationType.IQ3_S, + gguf.GGMLQuantizationType.IQ4_XS, + gguf.GGMLQuantizationType.IQ4_NL, } # TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization. # Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add @@ -73,8 +72,7 @@ MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES -def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, - qweight_type: int) -> torch.Tensor: +def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor: # there is no need to call any kernel for fp16/bf16 if qweight_type in UNQUANTIZED_TYPES: return x @ qweight.T @@ -87,8 +85,8 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, # y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0]) # elif qweight_type in MMQ_QUANT_TYPES: # y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0]) - # If there is no available MMQ kernel, fallback to dequantize + # If there is no available MMQ kernel, fallback to dequantize elif qweight_type in DEQUANT_TYPES: block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size) @@ -98,9 +96,8 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, # Raise an error if the quantization type is not supported. # Might be useful if llama.cpp adds a new quantization type. # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type. - qweight_type = WeightType(qweight_type) - raise NotImplementedError( - f"Unsupported GGUF quantization type: {qweight_type}") + qweight_type = gguf.GGMLQuantizationType(qweight_type) + raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}") return y @@ -528,12 +525,12 @@ def __init__( self.compute_dtype = compute_dtype self.device = device - def forward(self, inputs): + def forward(self, inputs: torch.Tensor): if ops is not None and self.weight.is_cuda and inputs.is_cuda: return self.forward_cuda(inputs) return self.forward_native(inputs) - def forward_native(self, inputs): + def forward_native(self, inputs: torch.Tensor): weight = dequantize_gguf_tensor(self.weight) weight = weight.to(self.compute_dtype) bias = self.bias.to(self.compute_dtype) if self.bias is not None else None @@ -541,12 +538,9 @@ def forward_native(self, inputs): output = torch.nn.functional.linear(inputs, weight, bias) return output - def forward_cuda(self, inputs): + def forward_cuda(self, inputs: torch.Tensor): quant_type = self.weight.quant_type - orig_shape = inputs.shape - inputs = inputs.view(-1, orig_shape[-1]) output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type) if self.bias is not None: - output = output + self.bias.to(self.compute_dtype) - return output.view(*orig_shape[:-1], -1) - + output += self.bias.to(self.compute_dtype) + return output diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 72f020ec193e..72b12badf269 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -76,9 +76,9 @@ is_hpu_available, is_inflect_available, is_invisible_watermark_available, - is_kernels_available, is_k_diffusion_available, is_k_diffusion_version, + is_kernels_available, is_librosa_available, is_matplotlib_available, is_nltk_available, From de1fb4b615b9941e77602d132a36795a6f2d2961 Mon Sep 17 00:00:00 2001 From: DN6 Date: Thu, 24 Jul 2025 08:31:47 +0530 Subject: [PATCH 4/6] 
update --- src/diffusers/quantizers/gguf/utils.py | 8 +++- tests/quantization/gguf/test_gguf.py | 58 ++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 2 deletions(-) diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py index 31f6ec3e7321..edbc60abf54e 100644 --- a/src/diffusers/quantizers/gguf/utils.py +++ b/src/diffusers/quantizers/gguf/utils.py @@ -12,8 +12,8 @@ # # See the License for the specific language governing permissions and # # limitations under the License. - import inspect +import os from contextlib import nullcontext import gguf @@ -29,7 +29,11 @@ from accelerate.hooks import add_hook_to_module, remove_hook_from_module -can_use_cuda_kernels = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7 +can_use_cuda_kernels = ( + os.getenv("DIFFUSERS_GGUF_CUDA_KERNELS", "true").lower() in ["1", "true", "yes"] + and torch.cuda.is_available() + and torch.cuda.get_device_capability()[0] >= 7 +) if can_use_cuda_kernels and is_kernels_available(): from kernels import get_kernel diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py index 0d786de7e78f..aa558b3e82c1 100644 --- a/tests/quantization/gguf/test_gguf.py +++ b/tests/quantization/gguf/test_gguf.py @@ -29,6 +29,7 @@ nightly, numpy_cosine_similarity_distance, require_accelerate, + require_accelerator, require_big_accelerator, require_gguf_version_greater_or_equal, require_peft_backend, @@ -37,11 +38,68 @@ if is_gguf_available(): + import gguf + from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter enable_full_determinism() +@nightly +@require_accelerate +@require_accelerator +@require_gguf_version_greater_or_equal("0.10.0") +class GGUFCudaKernelsTests(unittest.TestCase): + def setUp(self): + gc.collect() + backend_empty_cache(torch_device) + + def tearDown(self): + gc.collect() + backend_empty_cache(torch_device) + + def test_cuda_kernels_vs_native(self): + if torch_device != "cuda": + self.skipTest("CUDA kernels test requires CUDA device") + + from diffusers.quantizers.gguf.utils import GGUFLinear, can_use_cuda_kernels + + if not can_use_cuda_kernels: + self.skipTest("CUDA kernels not available (compute capability < 7 or kernels not installed)") + + test_quant_types = ["Q4_0", "Q4_K"] + test_shape = (1, 64, 512) # batch, seq_len, hidden_dim + compute_dtype = torch.bfloat16 + + for quant_type in test_quant_types: + qtype = getattr(gguf.GGMLQuantizationType, quant_type) + block_size, type_size = gguf.GGML_QUANT_SIZES[qtype] + + in_features, out_features = 512, 512 + total_elements = in_features * out_features + n_blocks = total_elements // block_size + weight_bytes = n_blocks * type_size + + torch.manual_seed(42) + weight_data = torch.randint(0, 256, (weight_bytes,), dtype=torch.uint8, device=torch_device) + weight = GGUFParameter(weight_data, quant_type=qtype) + + x = torch.randn(test_shape, dtype=compute_dtype, device=torch_device) + + linear = GGUFLinear(in_features, out_features, bias=True, compute_dtype=compute_dtype) + linear.weight = weight + linear.bias = nn.Parameter(torch.randn(out_features, dtype=compute_dtype)) + linear = linear.to(torch_device) + + with torch.no_grad(): + output_native = linear.forward_native(x) + output_cuda = linear.forward_cuda(x) + + # Compare outputs + max_diff = torch.abs(output_cuda - output_native).max() + assert max_diff < 1e-4, "GGUF CUDA Kernel Output is different from Native Output" + + @nightly @require_big_accelerator @require_accelerate From 
db94e2b5a773ac65a2db8ed3f03bb3257d6b3f01 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 24 Jul 2025 06:30:12 +0200 Subject: [PATCH 5/6] update --- src/diffusers/quantizers/gguf/utils.py | 6 +++--- tests/quantization/gguf/test_gguf.py | 15 ++++++--------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py index edbc60abf54e..aa6a2818d158 100644 --- a/src/diffusers/quantizers/gguf/utils.py +++ b/src/diffusers/quantizers/gguf/utils.py @@ -91,18 +91,18 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: in # y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0]) # If there is no available MMQ kernel, fallback to dequantize - elif qweight_type in DEQUANT_TYPES: + if qweight_type in DEQUANT_TYPES: block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size) weight = ops.ggml_dequantize(qweight, qweight_type, *shape) - y = x @ weight.T + y = x @ weight.to(x.dtype).T else: # Raise an error if the quantization type is not supported. # Might be useful if llama.cpp adds a new quantization type. # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type. qweight_type = gguf.GGMLQuantizationType(qweight_type) raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}") - return y + return y.as_tensor() # Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py index aa558b3e82c1..a03efdd2be99 100644 --- a/tests/quantization/gguf/test_gguf.py +++ b/tests/quantization/gguf/test_gguf.py @@ -73,15 +73,12 @@ def test_cuda_kernels_vs_native(self): for quant_type in test_quant_types: qtype = getattr(gguf.GGMLQuantizationType, quant_type) - block_size, type_size = gguf.GGML_QUANT_SIZES[qtype] - in_features, out_features = 512, 512 - total_elements = in_features * out_features - n_blocks = total_elements // block_size - weight_bytes = n_blocks * type_size torch.manual_seed(42) - weight_data = torch.randint(0, 256, (weight_bytes,), dtype=torch.uint8, device=torch_device) + float_weight = torch.randn(out_features, in_features, dtype=torch.float32) + quantized_data = gguf.quants.quantize(float_weight.numpy(), qtype) + weight_data = torch.from_numpy(quantized_data).to(device=torch_device) weight = GGUFParameter(weight_data, quant_type=qtype) x = torch.randn(test_shape, dtype=compute_dtype, device=torch_device) @@ -95,9 +92,9 @@ def test_cuda_kernels_vs_native(self): output_native = linear.forward_native(x) output_cuda = linear.forward_cuda(x) - # Compare outputs - max_diff = torch.abs(output_cuda - output_native).max() - assert max_diff < 1e-4, "GGUF CUDA Kernel Output is different from Native Output" + assert torch.allclose(output_native, output_cuda, 1e-2), ( + f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}" + ) @nightly From cb004ad5e6955f0622422b7ce1c13bc20086f201 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 24 Jul 2025 11:03:39 +0200 Subject: [PATCH 6/6] update --- docs/source/en/quantization/gguf.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/source/en/quantization/gguf.md b/docs/source/en/quantization/gguf.md index aec0875c6511..cb4be6712273 100644 --- a/docs/source/en/quantization/gguf.md +++ b/docs/source/en/quantization/gguf.md @@ -53,6 +53,16 @@ image = pipe(prompt, 
generator=torch.manual_seed(0)).images[0]
 image.save("flux-gguf.png")
 ```
+## Using Optimized CUDA Kernels with GGUF
+
+Optimized CUDA kernels can accelerate GGUF-quantized model inference by approximately 10%. This functionality requires a compatible GPU with a CUDA compute capability of at least 7.0 (`torch.cuda.get_device_capability()[0] >= 7`) and the `kernels` library:
+
+```shell
+pip install -U kernels
+```
+
+Once installed, GGUF inference automatically uses the optimized kernels when they are available. Note that the CUDA kernels may introduce minor numerical differences compared to the native dequantization path, potentially causing subtle visual variations in generated images. To disable the CUDA kernels, set the environment variable `DIFFUSERS_GGUF_CUDA_KERNELS=false`.
+
 ## Supported Quantization Types
 
 - BF16
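
A minimal sketch of how to check, before running a pipeline, whether the fused CUDA kernel path introduced by this series is eligible on the current machine. It assumes `diffusers` with these patches applied and the `gguf` package are installed, and it uses only helpers added here: `is_kernels_available` from `diffusers.utils` and the module-level `can_use_cuda_kernels` flag from `diffusers.quantizers.gguf.utils`.

```python
import torch

from diffusers.utils import is_kernels_available
from diffusers.quantizers.gguf.utils import can_use_cuda_kernels

# `can_use_cuda_kernels` is computed once at import time; it also honors the
# DIFFUSERS_GGUF_CUDA_KERNELS environment variable, so export that variable
# before importing diffusers if you want to opt out.
print("kernels package installed:", is_kernels_available())
print("eligible for fused CUDA kernels:", can_use_cuda_kernels)

if torch.cuda.is_available():
    # The fused path requires a compute capability major version of at least 7.
    print("compute capability:", torch.cuda.get_device_capability())
```

When the kernel module is loaded, `GGUFLinear.forward` dispatches to `forward_cuda` only if both the quantized weight and the input tensor are on a CUDA device; in every other case it falls back to `forward_native`, which dequantizes the weight and calls `torch.nn.functional.linear`.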