
Commit 58ae4dd

jianyuh authored and meta-codesync[bot] committed
Support fp16 for cutlass grouped GEMM (#5111)
Summary:
Pull Request resolved: #5111
X-link: https://github.com/facebookresearch/FBGEMM/pull/2116

Support FP16 grouped GEMM:
* `torch.ops.fbgemm.bf16bf16bf16_grouped_stacked` (fprop)
* `torch.ops.fbgemm.bf16bf16bf16_grouped_grad` (dgrad)
* `torch.ops.fbgemm.bf16bf16bf16_grouped_wgrad` (wgrad)

Reviewed By: jwfromm
Differential Revision: D86718224
fbshipit-source-id: a7817961f4949e87c59e3ca376f8caa03de390db
1 parent 62bdc5f commit 58ae4dd
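The net effect for callers is that each op's output dtype now follows its input dtype rather than being hard-coded to BF16, as the allocation changes in the diffs below show. A minimal ATen sketch of that behavior (`make_grouped_output` is a hypothetical helper for illustration, not part of the diff):

```cpp
#include <ATen/ATen.h>

// Illustrative only: the grouped-GEMM output buffer is now allocated from
// X.options(), so it inherits the input dtype (kHalf or kBFloat16) instead
// of being forced to kBFloat16 as before this commit.
at::Tensor make_grouped_output(const at::Tensor& X, int64_t total_M, int64_t N) {
  return at::empty({total_M * N}, X.options());
}
```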

File tree

38 files changed (+496, -116 lines)


fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu

Lines changed: 20 additions & 6 deletions
@@ -719,21 +719,35 @@ Kernel_bf16bf16bf16_grouped<InputType> get_kernel_via_tuning(
   return kernel;
 }
 
-// BF16 grouped cutlass kernel dispatch.
+// BF16/FP16 grouped cutlass kernel dispatch.
 template <typename InputType>
 at::Tensor dispatch_bf16_grouped_kernel(
     int G,
     int total_M,
     int N,
     int K,
-    InputType X, // BF16
-    InputType W, // BF16
+    InputType X, // BF16 or FP16
+    InputType W, // BF16 or FP16
     at::Tensor output,
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M = std::nullopt,
     std::optional<at::Tensor> M_sizes = std::nullopt) {
   const int arch = getDeviceArch();
 
+  // Get dtype from input
+  at::ScalarType dtype;
+  if constexpr (std::is_same_v<InputType, at::TensorList>) {
+    dtype = X[0].scalar_type();
+  } else {
+    dtype = X.scalar_type();
+  }
+
+  // Validate dtype is supported
+  TORCH_CHECK(
+      dtype == at::kBFloat16 || dtype == at::kHalf,
+      "Only BFloat16 and Float16 dtypes are supported, got ",
+      dtype);
+
   // Select kernel to run via heuristics or tuning.
   auto kernel = [&]() {
     if (std::getenv("FBGEMM_AUTOTUNE_ENABLE")) {
@@ -781,7 +795,7 @@ OutputType _bf16bf16bf16_grouped(at::TensorList X, at::TensorList W) {
     total_output_size += output_size;
     output_sizes.push_back(output_size);
   }
-  Y = at::empty(total_output_size, X[0].options().dtype(at::kBFloat16));
+  Y = at::empty(total_output_size, X[0].options());
 
   int64_t sm_count = getSMCount(Y.device().index(), std::nullopt);
 
@@ -835,7 +849,7 @@ at::Tensor bf16bf16bf16_grouped_stacked(
   if (out.has_value()) {
     Y = out.value();
   } else {
-    Y = at::empty(total_M * N, X.options().dtype(at::kBFloat16));
+    Y = at::empty(total_M * N, X.options());
   }
 
   // Early exit for empty inputs.
@@ -866,7 +880,7 @@ at::Tensor bf16bf16bf16_grouped_dynamic(
   int64_t K = W.size(2);
   int64_t total_output_size = G * M * N;
   at::Tensor Y;
-  Y = at::zeros(total_output_size, X.options().dtype(at::kBFloat16));
+  Y = at::zeros(total_output_size, X.options());
 
   int64_t sm_count = getSMCount(Y.device().index(), std::nullopt);
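One detail worth noting in the first hunk above: the dtype probe uses `if constexpr` so a single template covers both the stacked `at::Tensor` and the `at::TensorList` entry points. The same pattern, pulled out as a standalone helper (`probe_grouped_dtype` is illustrative, not part of the diff):

```cpp
#include <ATen/ATen.h>
#include <type_traits>

// Compile-time branch on the input container: TensorList inputs read the
// dtype from their first tensor, plain Tensor inputs read it directly.
// This mirrors the probe added to dispatch_bf16_grouped_kernel above.
template <typename InputType>
at::ScalarType probe_grouped_dtype(const InputType& X) {
  if constexpr (std::is_same_v<InputType, at::TensorList>) {
    return X[0].scalar_type();  // all tensors in the group share a dtype
  } else {
    return X.scalar_type();
  }
}
```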

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_1_1_9_f.cu

Lines changed: 10 additions & 3 deletions
@@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_9_f(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 1, 1, 1, false>(
-      X, W, output, sm_count, zero_start_index_M, M_sizes);
+  return bf16bf16bf16_grouped_impl_dispatch<
+      at::Tensor,
+      128,
+      128,
+      128,
+      1,
+      1,
+      1,
+      false>(X, W, output, sm_count, zero_start_index_M, M_sizes);
 }
 
 at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_9_f(
@@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_9_f(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<
+  return bf16bf16bf16_grouped_impl_dispatch<
       at::TensorList,
       128,
       128,
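This file and the nine that follow make the same mechanical change: each per-tile shim now forwards to `bf16bf16bf16_grouped_impl_dispatch` with unchanged tile and cluster template arguments. That wrapper's definition is not part of this excerpt; below is a plausible self-contained sketch of what it does, with all `_sketch`/`_stub` names hypothetical:

```cpp
#include <ATen/ATen.h>
#include <cutlass/bfloat16.h>
#include <cutlass/half.h>
#include <type_traits>

// Hypothetical stand-in for the element-templated CUTLASS launch.
template <typename InputType, typename Element,
          int TB_M, int TB_N, int TB_K, int TBS_M, int TBS_N, int TBS_K, bool PONG>
at::Tensor grouped_impl_stub(InputType /*X*/, InputType /*W*/, at::Tensor output) {
  return output;  // a real impl would run the grouped GEMM into `output`
}

// Plausible shape of the dispatch wrapper: keep the tile/cluster template
// parameters, branch once on the runtime dtype, and instantiate the impl for
// the matching CUTLASS element type, so one compiled tile configuration
// serves both BF16 and FP16.
template <typename InputType,
          int TB_M, int TB_N, int TB_K, int TBS_M, int TBS_N, int TBS_K, bool PONG>
at::Tensor grouped_impl_dispatch_sketch(InputType X, InputType W, at::Tensor output) {
  at::ScalarType dtype;
  if constexpr (std::is_same_v<InputType, at::TensorList>) {
    dtype = X[0].scalar_type();
  } else {
    dtype = X.scalar_type();
  }
  if (dtype == at::kHalf) {
    return grouped_impl_stub<InputType, cutlass::half_t,
                             TB_M, TB_N, TB_K, TBS_M, TBS_N, TBS_K, PONG>(X, W, output);
  }
  return grouped_impl_stub<InputType, cutlass::bfloat16_t,
                           TB_M, TB_N, TB_K, TBS_M, TBS_N, TBS_K, PONG>(X, W, output);
}
```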

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_1_1_9_t.cu

Lines changed: 10 additions & 3 deletions
@@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_9_t(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 1, 1, 1, true>(
-      X, W, output, sm_count, zero_start_index_M, M_sizes);
+  return bf16bf16bf16_grouped_impl_dispatch<
+      at::Tensor,
+      128,
+      128,
+      128,
+      1,
+      1,
+      1,
+      true>(X, W, output, sm_count, zero_start_index_M, M_sizes);
 }
 
 at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_9_t(
@@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_9_t(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<
+  return bf16bf16bf16_grouped_impl_dispatch<
       at::TensorList,
       128,
       128,

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_2_1_9_f.cu

Lines changed: 10 additions & 3 deletions
@@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_2_1_9_f(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 1, 2, 1, false>(
-      X, W, output, sm_count, zero_start_index_M, M_sizes);
+  return bf16bf16bf16_grouped_impl_dispatch<
+      at::Tensor,
+      128,
+      128,
+      128,
+      1,
+      2,
+      1,
+      false>(X, W, output, sm_count, zero_start_index_M, M_sizes);
 }
 
 at::Tensor bf16bf16bf16_grouped_128_128_128_1_2_1_9_f(
@@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_2_1_9_f(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<
+  return bf16bf16bf16_grouped_impl_dispatch<
      at::TensorList,
      128,
      128,

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_2_1_9_t.cu

Lines changed: 10 additions & 3 deletions
@@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_2_1_9_t(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 1, 2, 1, true>(
-      X, W, output, sm_count, zero_start_index_M, M_sizes);
+  return bf16bf16bf16_grouped_impl_dispatch<
+      at::Tensor,
+      128,
+      128,
+      128,
+      1,
+      2,
+      1,
+      true>(X, W, output, sm_count, zero_start_index_M, M_sizes);
 }
 
 at::Tensor bf16bf16bf16_grouped_128_128_128_1_2_1_9_t(
@@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_2_1_9_t(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<
+  return bf16bf16bf16_grouped_impl_dispatch<
      at::TensorList,
      128,
      128,

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_4_1_9_t.cu

Lines changed: 10 additions & 3 deletions
@@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_4_1_9_t(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 1, 4, 1, true>(
-      X, W, output, sm_count, zero_start_index_M, M_sizes);
+  return bf16bf16bf16_grouped_impl_dispatch<
+      at::Tensor,
+      128,
+      128,
+      128,
+      1,
+      4,
+      1,
+      true>(X, W, output, sm_count, zero_start_index_M, M_sizes);
 }
 
 at::Tensor bf16bf16bf16_grouped_128_128_128_1_4_1_9_t(
@@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_4_1_9_t(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<
+  return bf16bf16bf16_grouped_impl_dispatch<
      at::TensorList,
      128,
      128,

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_2_1_1_9_f.cu

Lines changed: 10 additions & 3 deletions
@@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_9_f(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 2, 1, 1, false>(
-      X, W, output, sm_count, zero_start_index_M, M_sizes);
+  return bf16bf16bf16_grouped_impl_dispatch<
+      at::Tensor,
+      128,
+      128,
+      128,
+      2,
+      1,
+      1,
+      false>(X, W, output, sm_count, zero_start_index_M, M_sizes);
 }
 
 at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_9_f(
@@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_9_f(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<
+  return bf16bf16bf16_grouped_impl_dispatch<
      at::TensorList,
      128,
      128,

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_2_1_1_9_t.cu

Lines changed: 10 additions & 3 deletions
@@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_9_t(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 2, 1, 1, true>(
-      X, W, output, sm_count, zero_start_index_M, M_sizes);
+  return bf16bf16bf16_grouped_impl_dispatch<
+      at::Tensor,
+      128,
+      128,
+      128,
+      2,
+      1,
+      1,
+      true>(X, W, output, sm_count, zero_start_index_M, M_sizes);
 }
 
 at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_9_t(
@@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_9_t(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<
+  return bf16bf16bf16_grouped_impl_dispatch<
      at::TensorList,
      128,
      128,

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_2_2_1_9_t.cu

Lines changed: 10 additions & 3 deletions
@@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_2_1_9_t(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 2, 2, 1, true>(
-      X, W, output, sm_count, zero_start_index_M, M_sizes);
+  return bf16bf16bf16_grouped_impl_dispatch<
+      at::Tensor,
+      128,
+      128,
+      128,
+      2,
+      2,
+      1,
+      true>(X, W, output, sm_count, zero_start_index_M, M_sizes);
 }
 
 at::Tensor bf16bf16bf16_grouped_128_128_128_2_2_1_9_t(
@@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_2_1_9_t(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<
+  return bf16bf16bf16_grouped_impl_dispatch<
      at::TensorList,
      128,
      128,

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_2_4_1_9_t.cu

Lines changed: 10 additions & 3 deletions
@@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_4_1_9_t(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 2, 4, 1, true>(
-      X, W, output, sm_count, zero_start_index_M, M_sizes);
+  return bf16bf16bf16_grouped_impl_dispatch<
+      at::Tensor,
+      128,
+      128,
+      128,
+      2,
+      4,
+      1,
+      true>(X, W, output, sm_count, zero_start_index_M, M_sizes);
 }
 
 at::Tensor bf16bf16bf16_grouped_128_128_128_2_4_1_9_t(
@@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_4_1_9_t(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<
+  return bf16bf16bf16_grouped_impl_dispatch<
      at::TensorList,
      128,
      128,
