From f711c72a4f686ee52fb696b40c6d9161db97bc90 Mon Sep 17 00:00:00 2001 From: Jianyu Huang Date: Tue, 11 Nov 2025 11:21:11 -0800 Subject: [PATCH] Support fp16 for cutlass grouped GEMM (#5111) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5111 X-link: https://github.com/facebookresearch/FBGEMM/pull/2116 Support FP16 grouped GEMM: * `torch.ops.fbgemm.bf16bf16bf16_grouped_stacked` (fprop) * `torch.ops.fbgemm.bf16bf16bf16_grouped_grad` (dgrad) * `torch.ops.fbgemm.bf16bf16bf16_grouped_wgrad` (wgrad) Reviewed By: jwfromm Differential Revision: D86718224 --- .../bf16bf16bf16_grouped.cu | 26 ++++-- ...6bf16bf16_grouped_128_128_128_1_1_1_9_f.cu | 13 ++- ...6bf16bf16_grouped_128_128_128_1_1_1_9_t.cu | 13 ++- ...6bf16bf16_grouped_128_128_128_1_2_1_9_f.cu | 13 ++- ...6bf16bf16_grouped_128_128_128_1_2_1_9_t.cu | 13 ++- ...6bf16bf16_grouped_128_128_128_1_4_1_9_t.cu | 13 ++- ...6bf16bf16_grouped_128_128_128_2_1_1_9_f.cu | 13 ++- ...6bf16bf16_grouped_128_128_128_2_1_1_9_t.cu | 13 ++- ...6bf16bf16_grouped_128_128_128_2_2_1_9_t.cu | 13 ++- ...6bf16bf16_grouped_128_128_128_2_4_1_9_t.cu | 13 ++- ...6bf16bf16_grouped_128_128_128_4_2_1_9_t.cu | 13 ++- ...6bf16bf16_grouped_128_128_128_4_4_1_9_t.cu | 13 ++- ...6bf16bf16_grouped_128_256_128_1_1_1_9_f.cu | 13 ++- ...6bf16bf16_grouped_128_256_128_1_2_1_9_f.cu | 13 ++- ...6bf16bf16_grouped_128_256_128_2_1_1_9_f.cu | 13 ++- ...16bf16bf16_grouped_128_32_128_1_1_1_9_f.cu | 13 ++- ...16bf16bf16_grouped_128_32_128_1_2_1_9_f.cu | 13 ++- ...16bf16bf16_grouped_128_32_128_1_4_1_9_f.cu | 13 ++- ...16bf16bf16_grouped_128_32_128_2_1_1_9_f.cu | 13 ++- ...16bf16bf16_grouped_128_32_128_2_2_1_9_f.cu | 13 ++- ...16bf16bf16_grouped_128_32_128_4_1_1_9_f.cu | 13 ++- ...16bf16bf16_grouped_128_64_128_1_1_1_9_f.cu | 13 ++- ...16bf16bf16_grouped_128_64_128_1_4_1_9_f.cu | 13 ++- ...16bf16bf16_grouped_128_64_128_2_2_1_9_f.cu | 13 ++- ...16bf16bf16_grouped_128_64_128_4_1_1_9_f.cu | 13 ++- ...16bf16bf16_grouped_128_64_128_4_2_1_9_f.cu | 13 ++- 
...6bf16bf16_grouped_256_128_128_1_1_1_9_f.cu | 13 ++- ...6bf16bf16_grouped_256_128_128_1_2_1_9_f.cu | 13 ++- ...16bf16bf16_grouped_256_32_128_1_1_1_9_f.cu | 13 ++- ...16bf16bf16_grouped_256_32_128_1_2_1_9_f.cu | 13 ++- ...16bf16bf16_grouped_256_32_128_2_1_1_9_f.cu | 13 ++- ...16bf16bf16_grouped_256_32_128_4_2_1_9_f.cu | 13 ++- ...16bf16bf16_grouped_256_64_128_1_1_1_9_f.cu | 13 ++- ...16bf16bf16_grouped_256_64_128_2_1_1_9_f.cu | 13 ++- .../bf16bf16bf16_grouped_common.cuh | 91 ++++++++++++++++++- .../bf16bf16bf16_grouped_grad.cu | 5 +- .../bf16bf16bf16_grouped_wgrad.cu | 8 +- .../gen_ai/test/quantize/quantize_test.py | 53 +++++++++++ 38 files changed, 496 insertions(+), 116 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu index 3f464ed601..4f2ebfc82b 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu @@ -718,21 +718,35 @@ Kernel_bf16bf16bf16_grouped get_kernel_via_tuning( return kernel; } -// BF16 grouped cutlass kernel dispatch. +// BF16/FP16 grouped cutlass kernel dispatch. 
template at::Tensor dispatch_bf16_grouped_kernel( int G, int total_M, int N, int K, - InputType X, // BF16 - InputType W, // BF16 + InputType X, // BF16 or FP16 + InputType W, // BF16 or FP16 at::Tensor output, int sm_count, std::optional zero_start_index_M = std::nullopt, std::optional M_sizes = std::nullopt) { const int arch = getDeviceArch(); + // Get dtype from input + at::ScalarType dtype; + if constexpr (std::is_same_v) { + dtype = X[0].scalar_type(); + } else { + dtype = X.scalar_type(); + } + + // Validate dtype is supported + TORCH_CHECK( + dtype == at::kBFloat16 || dtype == at::kHalf, + "Only BFloat16 and Float16 dtypes are supported, got ", + dtype); + // Select kernel to run via heuristics or tuning. auto kernel = [&]() { if (std::getenv("FBGEMM_AUTOTUNE_ENABLE")) { @@ -778,7 +792,7 @@ OutputType _bf16bf16bf16_grouped(at::TensorList X, at::TensorList W) { total_output_size += output_size; output_sizes.push_back(output_size); } - Y = at::empty(total_output_size, X[0].options().dtype(at::kBFloat16)); + Y = at::empty(total_output_size, X[0].options()); int64_t sm_count = getSMCount(Y.device().index(), std::nullopt); @@ -830,7 +844,7 @@ at::Tensor bf16bf16bf16_grouped_stacked( if (out.has_value()) { Y = out.value(); } else { - Y = at::empty(total_M * N, X.options().dtype(at::kBFloat16)); + Y = at::empty(total_M * N, X.options()); } // Early exit for empty inputs. 
@@ -859,7 +873,7 @@ at::Tensor bf16bf16bf16_grouped_dynamic( int64_t K = W.size(2); int64_t total_output_size = G * M * N; at::Tensor Y; - Y = at::zeros(total_output_size, X.options().dtype(at::kBFloat16)); + Y = at::zeros(total_output_size, X.options()); int64_t sm_count = getSMCount(Y.device().index(), std::nullopt); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_1_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_1_1_9_f.cu index fed8a11bd4..c7c20bc392 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_1_1_9_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_1_1_9_f.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 128, + 128, + 128, + 1, + 1, + 1, + false>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_9_f( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 128, 128, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_1_1_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_1_1_9_t.cu index 74473eece9..4efa91e70d 100644 --- 
a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_1_1_9_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_1_1_9_t.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_9_t( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 128, + 128, + 128, + 1, + 1, + 1, + true>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_9_t( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_9_t( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 128, 128, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_2_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_2_1_9_f.cu index ce01fd0d65..5d33459e41 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_2_1_9_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_2_1_9_f.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_2_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 128, + 128, + 128, + 1, + 2, + 1, + false>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor 
bf16bf16bf16_grouped_128_128_128_1_2_1_9_f( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_2_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 128, 128, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_2_1_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_2_1_9_t.cu index 4f972c7526..69fcd1579f 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_2_1_9_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_2_1_9_t.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_2_1_9_t( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 128, + 128, + 128, + 1, + 2, + 1, + true>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor bf16bf16bf16_grouped_128_128_128_1_2_1_9_t( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_2_1_9_t( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 128, 128, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_4_1_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_4_1_9_t.cu index ac4825ff6e..e13f6510eb 100644 --- 
a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_4_1_9_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_4_1_9_t.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_4_1_9_t( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 128, + 128, + 128, + 1, + 4, + 1, + true>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor bf16bf16bf16_grouped_128_128_128_1_4_1_9_t( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_4_1_9_t( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 128, 128, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_2_1_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_2_1_1_9_f.cu index b606fa0421..48c17f57f1 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_2_1_1_9_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_2_1_1_9_f.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 128, + 128, + 128, + 2, + 1, + 1, + false>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor 
bf16bf16bf16_grouped_128_128_128_2_1_1_9_f( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 128, 128, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_2_1_1_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_2_1_1_9_t.cu index 7b8d19285f..0814319673 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_2_1_1_9_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_2_1_1_9_t.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_9_t( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 128, + 128, + 128, + 2, + 1, + 1, + true>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_9_t( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_9_t( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 128, 128, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_2_2_1_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_2_2_1_9_t.cu index 6fe1c7bef3..43e0ce2a11 100644 --- 
a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_2_2_1_9_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_2_2_1_9_t.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_2_1_9_t( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 128, + 128, + 128, + 2, + 2, + 1, + true>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor bf16bf16bf16_grouped_128_128_128_2_2_1_9_t( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_2_1_9_t( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 128, 128, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_2_4_1_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_2_4_1_9_t.cu index 0c59a0efa2..cd196f11a3 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_2_4_1_9_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_2_4_1_9_t.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_4_1_9_t( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 128, + 128, + 128, + 2, + 4, + 1, + true>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor 
bf16bf16bf16_grouped_128_128_128_2_4_1_9_t( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_4_1_9_t( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 128, 128, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_4_2_1_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_4_2_1_9_t.cu index e88a5d67eb..6e4b4f3dec 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_4_2_1_9_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_4_2_1_9_t.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_4_2_1_9_t( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 128, + 128, + 128, + 4, + 2, + 1, + true>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor bf16bf16bf16_grouped_128_128_128_4_2_1_9_t( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_4_2_1_9_t( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 128, 128, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_4_4_1_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_4_4_1_9_t.cu index 3c37ca4271..004abcf8e6 100644 --- 
a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_4_4_1_9_t.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_4_4_1_9_t.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_4_4_1_9_t( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 128, + 128, + 128, + 4, + 4, + 1, + true>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor bf16bf16bf16_grouped_128_128_128_4_4_1_9_t( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_4_4_1_9_t( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 128, 128, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_1_1_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_1_1_1_9_f.cu index d81b76d1e9..b402b10734 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_1_1_1_9_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_1_1_1_9_f.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_256_128_1_1_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 128, + 256, + 128, + 1, + 1, + 1, + false>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor 
bf16bf16bf16_grouped_128_256_128_1_1_1_9_f( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_256_128_1_1_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 128, 256, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_1_2_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_1_2_1_9_f.cu index 72c872a3d3..e4c2090715 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_1_2_1_9_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_1_2_1_9_f.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_256_128_1_2_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 128, + 256, + 128, + 1, + 2, + 1, + false>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor bf16bf16bf16_grouped_128_256_128_1_2_1_9_f( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_256_128_1_2_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 128, 256, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_2_1_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_2_1_1_9_f.cu index 03bb44a9dd..f9a3eae6c6 100644 --- 
a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_2_1_1_9_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_256_128_2_1_1_9_f.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_256_128_2_1_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 128, + 256, + 128, + 2, + 1, + 1, + false>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor bf16bf16bf16_grouped_128_256_128_2_1_1_9_f( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_256_128_2_1_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 128, 256, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_1_1_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_1_1_1_9_f.cu index 845d70cb63..1842299bf3 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_1_1_1_9_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_1_1_1_9_f.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_32_128_1_1_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 128, + 32, + 128, + 1, + 1, + 1, + false>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor 
bf16bf16bf16_grouped_128_32_128_1_1_1_9_f( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_32_128_1_1_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 128, 32, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_1_2_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_1_2_1_9_f.cu index a0acf31483..e62d9b81fb 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_1_2_1_9_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_1_2_1_9_f.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_32_128_1_2_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 128, + 32, + 128, + 1, + 2, + 1, + false>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor bf16bf16bf16_grouped_128_32_128_1_2_1_9_f( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_32_128_1_2_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 128, 32, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_1_4_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_1_4_1_9_f.cu index 5ad86d8cb1..dec464a845 100644 --- 
a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_1_4_1_9_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_1_4_1_9_f.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_32_128_1_4_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 128, + 32, + 128, + 1, + 4, + 1, + false>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor bf16bf16bf16_grouped_128_32_128_1_4_1_9_f( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_32_128_1_4_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 128, 32, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_2_1_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_2_1_1_9_f.cu index 8cdcf1ba69..42a6600e66 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_2_1_1_9_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_2_1_1_9_f.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_32_128_2_1_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 128, + 32, + 128, + 2, + 1, + 1, + false>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor bf16bf16bf16_grouped_128_32_128_2_1_1_9_f( 
@@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_32_128_2_1_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 128, 32, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_2_2_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_2_2_1_9_f.cu index fb1f4a0c19..165a215216 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_2_2_1_9_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_2_2_1_9_f.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_32_128_2_2_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 128, + 32, + 128, + 2, + 2, + 1, + false>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor bf16bf16bf16_grouped_128_32_128_2_2_1_9_f( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_32_128_2_2_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 128, 32, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_4_1_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_4_1_1_9_f.cu index 33f85bc364..594b83f5b6 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_4_1_1_9_f.cu +++ 
b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_32_128_4_1_1_9_f.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_32_128_4_1_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 128, + 32, + 128, + 4, + 1, + 1, + false>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor bf16bf16bf16_grouped_128_32_128_4_1_1_9_f( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_32_128_4_1_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 128, 32, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_1_1_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_1_1_1_9_f.cu index 3dfbd55c78..cb8858901c 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_1_1_1_9_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_1_1_1_9_f.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_64_128_1_1_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 128, + 64, + 128, + 1, + 1, + 1, + false>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor bf16bf16bf16_grouped_128_64_128_1_1_1_9_f( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_64_128_1_1_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional 
M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 128, 64, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_1_4_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_1_4_1_9_f.cu index 5be7365e0c..18f665eab3 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_1_4_1_9_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_1_4_1_9_f.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_64_128_1_4_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 128, + 64, + 128, + 1, + 4, + 1, + false>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor bf16bf16bf16_grouped_128_64_128_1_4_1_9_f( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_64_128_1_4_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 128, 64, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_2_2_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_2_2_1_9_f.cu index 5f70e128ac..8920fbf285 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_2_2_1_9_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_2_2_1_9_f.cu @@ -17,8 +17,15 @@ 
at::Tensor bf16bf16bf16_grouped_128_64_128_2_2_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 128, + 64, + 128, + 2, + 2, + 1, + false>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor bf16bf16bf16_grouped_128_64_128_2_2_1_9_f( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_64_128_2_2_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 128, 64, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_4_1_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_4_1_1_9_f.cu index 68ab0125c3..1ad15bdeaf 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_4_1_1_9_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_4_1_1_9_f.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_64_128_4_1_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 128, + 64, + 128, + 4, + 1, + 1, + false>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor bf16bf16bf16_grouped_128_64_128_4_1_1_9_f( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_64_128_4_1_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 128, 64, diff --git 
a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_4_2_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_4_2_1_9_f.cu index 1d877bf049..ae44b04990 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_4_2_1_9_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_64_128_4_2_1_9_f.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_64_128_4_2_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 128, + 64, + 128, + 4, + 2, + 1, + false>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor bf16bf16bf16_grouped_128_64_128_4_2_1_9_f( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_64_128_4_2_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 128, 64, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_128_128_1_1_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_128_128_1_1_1_9_f.cu index 1a6a9d6acc..60846c22b7 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_128_128_1_1_1_9_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_128_128_1_1_1_9_f.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_256_128_128_1_1_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { 
- return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 256, + 128, + 128, + 1, + 1, + 1, + false>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor bf16bf16bf16_grouped_256_128_128_1_1_1_9_f( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_256_128_128_1_1_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 256, 128, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_128_128_1_2_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_128_128_1_2_1_9_f.cu index 926b859786..0e675f3ebe 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_128_128_1_2_1_9_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_128_128_1_2_1_9_f.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_256_128_128_1_2_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 256, + 128, + 128, + 1, + 2, + 1, + false>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor bf16bf16bf16_grouped_256_128_128_1_2_1_9_f( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_256_128_128_1_2_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 256, 128, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_32_128_1_1_1_9_f.cu 
b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_32_128_1_1_1_9_f.cu index 3f99b5eb7c..924c4e220f 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_32_128_1_1_1_9_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_32_128_1_1_1_9_f.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_256_32_128_1_1_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 256, + 32, + 128, + 1, + 1, + 1, + false>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor bf16bf16bf16_grouped_256_32_128_1_1_1_9_f( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_256_32_128_1_1_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 256, 32, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_32_128_1_2_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_32_128_1_2_1_9_f.cu index a767ce6db9..73255d287a 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_32_128_1_2_1_9_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_32_128_1_2_1_9_f.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_256_32_128_1_2_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return 
bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 256, + 32, + 128, + 1, + 2, + 1, + false>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor bf16bf16bf16_grouped_256_32_128_1_2_1_9_f( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_256_32_128_1_2_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 256, 32, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_32_128_2_1_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_32_128_2_1_1_9_f.cu index 5844d25fa8..14d8910744 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_32_128_2_1_1_9_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_32_128_2_1_1_9_f.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_256_32_128_2_1_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 256, + 32, + 128, + 2, + 1, + 1, + false>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor bf16bf16bf16_grouped_256_32_128_2_1_1_9_f( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_256_32_128_2_1_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 256, 32, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_32_128_4_2_1_9_f.cu 
b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_32_128_4_2_1_9_f.cu index cf6a8e7309..8441f0afcd 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_32_128_4_2_1_9_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_32_128_4_2_1_9_f.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_256_32_128_4_2_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 256, + 32, + 128, + 4, + 2, + 1, + false>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor bf16bf16bf16_grouped_256_32_128_4_2_1_9_f( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_256_32_128_4_2_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 256, 32, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_64_128_1_1_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_64_128_1_1_1_9_f.cu index c7d45d6dc8..83a7ebb9f7 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_64_128_1_1_1_9_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_64_128_1_1_1_9_f.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_256_64_128_1_1_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return 
bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 256, + 64, + 128, + 1, + 1, + 1, + false>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor bf16bf16bf16_grouped_256_64_128_1_1_1_9_f( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_256_64_128_1_1_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 256, 64, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_64_128_2_1_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_64_128_2_1_1_9_f.cu index b5729f89d2..ea61ec8bae 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_64_128_2_1_1_9_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_256_64_128_2_1_1_9_f.cu @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_256_64_128_2_1_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl( - X, W, output, sm_count, zero_start_index_M, M_sizes); + return bf16bf16bf16_grouped_impl_dispatch< + at::Tensor, + 256, + 64, + 128, + 2, + 1, + 1, + false>(X, W, output, sm_count, zero_start_index_M, M_sizes); } at::Tensor bf16bf16bf16_grouped_256_64_128_2_1_1_9_f( @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_256_64_128_2_1_1_9_f( int sm_count, std::optional zero_start_index_M, std::optional M_sizes) { - return bf16bf16bf16_grouped_impl< + return bf16bf16bf16_grouped_impl_dispatch< at::TensorList, 256, 64, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_common.cuh b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_common.cuh index 
9b9ebe23bd..0ddc7ec53a 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_common.cuh +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_common.cuh @@ -21,6 +21,20 @@ namespace fbgemm_gpu { +// Type trait to map torch dtype to CUTLASS element type +template +struct CutlassElementType; + +template <> +struct CutlassElementType { + static constexpr auto torch_dtype = at::kBFloat16; +}; + +template <> +struct CutlassElementType { + static constexpr auto torch_dtype = at::kHalf; +}; + inline int64_t _byte_align(int64_t offset) { int64_t remainder = offset % 16; if (remainder != 0) { @@ -187,6 +201,7 @@ __global__ void set_stacked_kernel_args_kernel( } template < + typename ElementType, typename InputType, int TB_M, int TB_N, @@ -204,9 +219,11 @@ at::Tensor bf16bf16bf16_grouped_impl( std::optional M_sizes) { int64_t G; at::TensorOptions options; + at::ScalarType dtype; if constexpr (std::is_same_v) { G = X.size(); options = X[0].options(); + dtype = X[0].scalar_type(); TORCH_CHECK(W.size() == G); } else { TORCH_CHECK( @@ -214,7 +231,16 @@ at::Tensor bf16bf16bf16_grouped_impl( "One of zero_start_index_M or M_sizes must be provided."); G = W.size(0); options = X.options(); + dtype = X.scalar_type(); } + + // Validate ElementType matches input dtype + TORCH_CHECK( + (std::is_same_v && + dtype == at::kBFloat16) || + (std::is_same_v && dtype == at::kHalf), + "ElementType must match input dtype"); + // The number of groups the kernel uses may vary. int kernel_groups = int(G); // Return early if there are no elements in the output. @@ -225,9 +251,9 @@ at::Tensor bf16bf16bf16_grouped_impl( // Define gemm configuration. 
using ProblemShape = cutlass::gemm::GroupProblemShape>; - using ElementA = cutlass::bfloat16_t; - using ElementB = cutlass::bfloat16_t; - using ElementC = cutlass::bfloat16_t; + using ElementA = ElementType; + using ElementB = ElementType; + using ElementC = ElementType; using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::RowMajor; @@ -239,7 +265,6 @@ at::Tensor bf16bf16bf16_grouped_impl( using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that // supports the intended feature using OperatorClass = cutlass::arch::OpClassTensorOp; - using StageCountType = cutlass::gemm::collective::StageCountAuto; using TileShape = cute::Shape, cute::Int, cute::Int>; using ClusterShape = @@ -814,4 +839,62 @@ at::Tensor bf16bf16bf16_grouped_sm100_impl( } #endif +// Helper function to dispatch based on dtype +template < + typename InputType, + int TB_M, + int TB_N, + int TB_K, + int TBS_M, + int TBS_N, + int TBS_K, + bool PONG> +at::Tensor bf16bf16bf16_grouped_impl_dispatch( + InputType X, + InputType W, + at::Tensor output, + int sm_count, + std::optional zero_start_index_M, + std::optional M_sizes) { + // Get dtype from input + at::ScalarType dtype; + if constexpr (std::is_same_v) { + dtype = X[0].scalar_type(); + } else { + dtype = X.scalar_type(); + } + + // Dispatch to the correct ElementType based on dtype + if (dtype == at::kBFloat16) { + return bf16bf16bf16_grouped_impl< + cutlass::bfloat16_t, + InputType, + TB_M, + TB_N, + TB_K, + TBS_M, + TBS_N, + TBS_K, + PONG>(X, W, output, sm_count, zero_start_index_M, M_sizes); + } else if (dtype == at::kHalf) { + return bf16bf16bf16_grouped_impl< + cutlass::half_t, + InputType, + TB_M, + TB_N, + TB_K, + TBS_M, + TBS_N, + TBS_K, + PONG>(X, W, output, sm_count, zero_start_index_M, M_sizes); + } else { + TORCH_CHECK( + false, + "Unsupported dtype: ", + dtype, + ". 
Only BFloat16 and Float16 are supported."); + return output; // unreachable + } +} + } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_grad.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_grad.cu index b4b29ff956..1ecd022f96 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_grad.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_grad.cu @@ -853,7 +853,7 @@ at::Tensor bf16bf16bf16_grouped_grad( if (out.has_value()) { Y = out.value(); } else { - Y = at::empty(total_M * N, X.options().dtype(at::kBFloat16)); + Y = at::empty(total_M * N, X.options()); } // Early exit for empty inputs. if (total_M == 0) { @@ -894,8 +894,7 @@ at::Tensor bf16bf16bf16_grouped_grad_meta( if (out.has_value()) { return out.value(); } else { - at::Tensor output = - at::empty_symint({total_M, N}, X.options().dtype(at::kBFloat16)); + at::Tensor output = at::empty_symint({total_M, N}, X.options()); return output; } } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad.cu index 1e63443ac9..cf4a42b351 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad.cu @@ -1076,11 +1076,11 @@ at::Tensor bf16bf16bf16_grouped_wgrad( "Output tensor must be Float32 when output_accum=True"); } else { TORCH_CHECK( - Y.dtype() == at::kBFloat16, - "Output tensor must be BFloat16 when output_accum=False"); + Y.dtype() == at::kBFloat16 || Y.dtype() == at::kHalf, + "Output tensor must be BFloat16 or Float16 when output_accum=False"); } } else { - Y = at::empty(G * N * K, X.options().dtype(at::kBFloat16)); + Y = at::empty(G * N * K, 
X.options()); } // Early exit for empty inputs. @@ -1121,7 +1121,7 @@ at::Tensor bf16bf16bf16_grouped_wgrad_meta( const at::SymInt G = M_sizes.size(0); const at::SymInt N = X.sym_size(1); const at::SymInt K = W.sym_size(1); - at::Tensor Y = at::empty_symint({G, N, K}, X.options().dtype(at::kBFloat16)); + at::Tensor Y = at::empty_symint({G, N, K}, X.options()); return Y; } diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py index bdd2279894..711a1271cf 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py @@ -1016,6 +1016,59 @@ def fp8_loopover_bmm( torch.testing.assert_close(y_ref, y_fp8, atol=8.0e-2, rtol=8.0e-2) + @unittest.skipIf( + not torch.version.cuda, "Skip on AMD: FP16 GMM ops not yet supported." + ) + @settings(deadline=None) + @given( + G=st.sampled_from([1, 4, 8]), + M=st.sampled_from([512, 1024, 2048]), + N=st.sampled_from([512, 1024, 2048]), + K=st.sampled_from([512, 1024, 2048]), + dtype=st.sampled_from([torch.float16, torch.bfloat16]), + ) + def test_grouped_gemm_fp16_bf16_support( + self, G: int, M: int, N: int, K: int, dtype: torch.dtype + ): + """Test that CUTLASS grouped GEMM supports both FP16 and BF16 dtypes.""" + # Setup: Create input tensors with specified dtype + ms = torch.randint(1, (M // 64) + 1, (G,), dtype=torch.int64) * 64 + + x_group = [] + w_group = [] + + for m in ms: + x = torch.rand(size=(m, K), dtype=dtype, device=self.device) + w = torch.rand(size=(N, K), dtype=dtype, device=self.device) + x_group.append(x) + w_group.append(w) + + # Execute: Run CUTLASS grouped GEMM + y_cutlass = torch.ops.fbgemm.bf16bf16bf16_grouped(x_group, w_group) + + # Assert: Verify output dtype matches input dtype + if not isinstance(y_cutlass, (tuple, list)): + y_cutlass = torch.split(y_cutlass, tuple(ms.tolist()), dim=0) + + for i, y in enumerate(y_cutlass): + # Verify output
dtype matches input dtype + self.assertEqual( + y.dtype, + dtype, + f"Output dtype {y.dtype} does not match input dtype {dtype} for group {i}", + ) + + # Compute reference output + y_ref = torch.matmul(x_group[i], w_group[i].t()) + + # Verify numerical correctness + torch.testing.assert_close( + y, + y_ref, + atol=1e-2 if dtype == torch.float16 else 8e-3, + rtol=1e-2 if dtype == torch.float16 else 8e-3, + ) + @unittest.skipIf(not SUPPORTS_FP8, "FP8 not supported on this platform") @settings(deadline=None) @given(