
Commit 3bb5290

jianyuh authored and meta-codesync[bot] committed
Support fp16 for cutlass grouped GEMM (pytorch#5111)
Summary:
Pull Request resolved: pytorch#5111
X-link: https://github.com/facebookresearch/FBGEMM/pull/2116

Support FP16 grouped GEMM:
* `torch.ops.fbgemm.bf16bf16bf16_grouped_stacked` (fprop)
* `torch.ops.fbgemm.bf16bf16bf16_grouped_grad` (dgrad)
* `torch.ops.fbgemm.bf16bf16bf16_grouped_wgrad` (wgrad)

Differential Revision: D86718224
1 parent a2d5c3a commit 3bb5290
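
For context, a minimal usage sketch (not part of the commit) of what the FP16 path enables. The argument order and shapes of `bf16bf16bf16_grouped_stacked` used here are assumptions inferred from the diff below (stacked activations `X: [total_M, K]`, per-group weights `W: [G, N, K]`, per-group row counts `M_sizes: [G]`); confirm against the op schema in FBGEMM before relying on it.

```cpp
// Hedged sketch: run the stacked grouped GEMM with FP16 inputs.
// The declaration and tensor shapes below are assumptions inferred
// from this diff, not a documented API.
#include <ATen/ATen.h>
#include <optional>

// Assumed declaration, normally provided by FBGEMM's quantize headers.
at::Tensor bf16bf16bf16_grouped_stacked(
    at::Tensor X,
    at::Tensor W,
    at::Tensor M_sizes,
    std::optional<at::Tensor> out = std::nullopt);

at::Tensor grouped_gemm_fp16_example() {
  const int64_t G = 4, M = 64, N = 256, K = 512;
  const auto opts = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
  at::Tensor X = at::randn({G * M, K}, opts); // stacked activations
  at::Tensor W = at::randn({G, N, K}, opts);  // per-group weights
  at::Tensor M_sizes = at::full(
      {G}, M, at::TensorOptions().dtype(at::kLong).device(at::kCUDA));
  // With this commit, Y is allocated with X.options(), so the output
  // comes back in FP16 instead of being forced to BF16.
  return bf16bf16bf16_grouped_stacked(X, W, M_sizes);
}
```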

File tree

38 files changed: +493 -116 lines changed

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu

Lines changed: 20 additions & 6 deletions
```diff
@@ -718,21 +718,35 @@ Kernel_bf16bf16bf16_grouped<InputType> get_kernel_via_tuning(
   return kernel;
 }
 
-// BF16 grouped cutlass kernel dispatch.
+// BF16/FP16 grouped cutlass kernel dispatch.
 template <typename InputType>
 at::Tensor dispatch_bf16_grouped_kernel(
     int G,
     int total_M,
     int N,
     int K,
-    InputType X, // BF16
-    InputType W, // BF16
+    InputType X, // BF16 or FP16
+    InputType W, // BF16 or FP16
     at::Tensor output,
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M = std::nullopt,
     std::optional<at::Tensor> M_sizes = std::nullopt) {
   const int arch = getDeviceArch();
 
+  // Get dtype from input
+  at::ScalarType dtype;
+  if constexpr (std::is_same_v<InputType, at::TensorList>) {
+    dtype = X[0].scalar_type();
+  } else {
+    dtype = X.scalar_type();
+  }
+
+  // Validate dtype is supported
+  TORCH_CHECK(
+      dtype == at::kBFloat16 || dtype == at::kHalf,
+      "Only BFloat16 and Float16 dtypes are supported, got ",
+      dtype);
+
   // Select kernel to run via heuristics or tuning.
   auto kernel = [&]() {
     if (std::getenv("FBGEMM_AUTOTUNE_ENABLE")) {
@@ -778,7 +792,7 @@ OutputType _bf16bf16bf16_grouped(at::TensorList X, at::TensorList W) {
     total_output_size += output_size;
     output_sizes.push_back(output_size);
   }
-  Y = at::empty(total_output_size, X[0].options().dtype(at::kBFloat16));
+  Y = at::empty(total_output_size, X[0].options());
 
   int64_t sm_count = getSMCount(Y.device().index(), std::nullopt);
 
@@ -830,7 +844,7 @@ at::Tensor bf16bf16bf16_grouped_stacked(
   if (out.has_value()) {
     Y = out.value();
   } else {
-    Y = at::empty(total_M * N, X.options().dtype(at::kBFloat16));
+    Y = at::empty(total_M * N, X.options());
   }
 
   // Early exit for empty inputs.
@@ -859,7 +873,7 @@ at::Tensor bf16bf16bf16_grouped_dynamic(
   int64_t K = W.size(2);
   int64_t total_output_size = G * M * N;
   at::Tensor Y;
-  Y = at::zeros(total_output_size, X.options().dtype(at::kBFloat16));
+  Y = at::zeros(total_output_size, X.options());
 
   int64_t sm_count = getSMCount(Y.device().index(), std::nullopt);
```
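
The per-configuration files below switch their calls from `bf16bf16bf16_grouped_impl` to `bf16bf16bf16_grouped_impl_dispatch`, whose body is not shown in this diff. A self-contained sketch of the likely pattern follows; `grouped_kernel` and `grouped_dispatch` are hypothetical stand-ins, not FBGEMM identifiers. The idea is that the runtime dtype is folded into a compile-time element type exactly once, so each tile configuration covers both BF16 and FP16.

```cpp
// Illustrative sketch of runtime-dtype -> compile-time-type dispatch.
// grouped_kernel/grouped_dispatch are hypothetical names; the real
// FBGEMM internals may differ.
#include <ATen/ATen.h>
#include <c10/util/Exception.h>

template <typename Element>
at::Tensor grouped_kernel(const at::Tensor& X) {
  // A real implementation would launch a CUTLASS grouped GEMM templated
  // on Element (e.g. cutlass::bfloat16_t or cutlass::half_t); this
  // placeholder just echoes the input.
  return X;
}

at::Tensor grouped_dispatch(const at::Tensor& X) {
  const at::ScalarType dtype = X.scalar_type();
  // Same validation the commit adds in dispatch_bf16_grouped_kernel.
  TORCH_CHECK(
      dtype == at::kBFloat16 || dtype == at::kHalf,
      "Only BFloat16 and Float16 dtypes are supported, got ",
      dtype);
  return dtype == at::kHalf ? grouped_kernel<at::Half>(X)
                            : grouped_kernel<at::BFloat16>(X);
}
```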

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_1_1_9_f.cu

Lines changed: 10 additions & 3 deletions
```diff
@@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_9_f(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 1, 1, 1, false>(
-      X, W, output, sm_count, zero_start_index_M, M_sizes);
+  return bf16bf16bf16_grouped_impl_dispatch<
+      at::Tensor,
+      128,
+      128,
+      128,
+      1,
+      1,
+      1,
+      false>(X, W, output, sm_count, zero_start_index_M, M_sizes);
 }
 
 at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_9_f(
@@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_9_f(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<
+  return bf16bf16bf16_grouped_impl_dispatch<
       at::TensorList,
       128,
       128,
```
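
The remaining instantiation files repeat this same mechanical change: each tile configuration's `at::Tensor` and `at::TensorList` entry points switch from `bf16bf16bf16_grouped_impl` to `bf16bf16bf16_grouped_impl_dispatch`, with the template arguments unchanged. Judging by these diffs, the numeric fields in each filename mirror those template arguments (e.g. `..._128_128_128_1_1_1_9_f` pairs with `128, 128, 128, 1, 1, 1` and `false`, with `9` plausibly naming the SM90 architecture); that mapping is inferred from the diffs, not from documented naming.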

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_1_1_9_t.cu

Lines changed: 10 additions & 3 deletions
```diff
@@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_9_t(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 1, 1, 1, true>(
-      X, W, output, sm_count, zero_start_index_M, M_sizes);
+  return bf16bf16bf16_grouped_impl_dispatch<
+      at::Tensor,
+      128,
+      128,
+      128,
+      1,
+      1,
+      1,
+      true>(X, W, output, sm_count, zero_start_index_M, M_sizes);
 }
 
 at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_9_t(
@@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_9_t(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<
+  return bf16bf16bf16_grouped_impl_dispatch<
       at::TensorList,
       128,
       128,
```

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_2_1_9_f.cu

Lines changed: 10 additions & 3 deletions
```diff
@@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_2_1_9_f(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 1, 2, 1, false>(
-      X, W, output, sm_count, zero_start_index_M, M_sizes);
+  return bf16bf16bf16_grouped_impl_dispatch<
+      at::Tensor,
+      128,
+      128,
+      128,
+      1,
+      2,
+      1,
+      false>(X, W, output, sm_count, zero_start_index_M, M_sizes);
 }
 
 at::Tensor bf16bf16bf16_grouped_128_128_128_1_2_1_9_f(
@@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_2_1_9_f(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<
+  return bf16bf16bf16_grouped_impl_dispatch<
       at::TensorList,
       128,
       128,
```

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_2_1_9_t.cu

Lines changed: 10 additions & 3 deletions
```diff
@@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_2_1_9_t(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 1, 2, 1, true>(
-      X, W, output, sm_count, zero_start_index_M, M_sizes);
+  return bf16bf16bf16_grouped_impl_dispatch<
+      at::Tensor,
+      128,
+      128,
+      128,
+      1,
+      2,
+      1,
+      true>(X, W, output, sm_count, zero_start_index_M, M_sizes);
 }
 
 at::Tensor bf16bf16bf16_grouped_128_128_128_1_2_1_9_t(
@@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_2_1_9_t(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<
+  return bf16bf16bf16_grouped_impl_dispatch<
       at::TensorList,
       128,
       128,
```

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_1_4_1_9_t.cu

Lines changed: 10 additions & 3 deletions
```diff
@@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_4_1_9_t(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 1, 4, 1, true>(
-      X, W, output, sm_count, zero_start_index_M, M_sizes);
+  return bf16bf16bf16_grouped_impl_dispatch<
+      at::Tensor,
+      128,
+      128,
+      128,
+      1,
+      4,
+      1,
+      true>(X, W, output, sm_count, zero_start_index_M, M_sizes);
 }
 
 at::Tensor bf16bf16bf16_grouped_128_128_128_1_4_1_9_t(
@@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_4_1_9_t(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<
+  return bf16bf16bf16_grouped_impl_dispatch<
       at::TensorList,
       128,
       128,
```

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_2_1_1_9_f.cu

Lines changed: 10 additions & 3 deletions
```diff
@@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_9_f(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 2, 1, 1, false>(
-      X, W, output, sm_count, zero_start_index_M, M_sizes);
+  return bf16bf16bf16_grouped_impl_dispatch<
+      at::Tensor,
+      128,
+      128,
+      128,
+      2,
+      1,
+      1,
+      false>(X, W, output, sm_count, zero_start_index_M, M_sizes);
 }
 
 at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_9_f(
@@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_9_f(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<
+  return bf16bf16bf16_grouped_impl_dispatch<
       at::TensorList,
       128,
       128,
```

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_2_1_1_9_t.cu

Lines changed: 10 additions & 3 deletions
```diff
@@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_9_t(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 2, 1, 1, true>(
-      X, W, output, sm_count, zero_start_index_M, M_sizes);
+  return bf16bf16bf16_grouped_impl_dispatch<
+      at::Tensor,
+      128,
+      128,
+      128,
+      2,
+      1,
+      1,
+      true>(X, W, output, sm_count, zero_start_index_M, M_sizes);
 }
 
 at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_9_t(
@@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_9_t(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<
+  return bf16bf16bf16_grouped_impl_dispatch<
       at::TensorList,
       128,
       128,
```

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_2_2_1_9_t.cu

Lines changed: 10 additions & 3 deletions
```diff
@@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_2_1_9_t(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 2, 2, 1, true>(
-      X, W, output, sm_count, zero_start_index_M, M_sizes);
+  return bf16bf16bf16_grouped_impl_dispatch<
+      at::Tensor,
+      128,
+      128,
+      128,
+      2,
+      2,
+      1,
+      true>(X, W, output, sm_count, zero_start_index_M, M_sizes);
 }
 
 at::Tensor bf16bf16bf16_grouped_128_128_128_2_2_1_9_t(
@@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_2_1_9_t(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<
+  return bf16bf16bf16_grouped_impl_dispatch<
       at::TensorList,
       128,
       128,
```

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped/bf16bf16bf16_grouped_128_128_128_2_4_1_9_t.cu

Lines changed: 10 additions & 3 deletions
```diff
@@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_4_1_9_t(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 2, 4, 1, true>(
-      X, W, output, sm_count, zero_start_index_M, M_sizes);
+  return bf16bf16bf16_grouped_impl_dispatch<
+      at::Tensor,
+      128,
+      128,
+      128,
+      2,
+      4,
+      1,
+      true>(X, W, output, sm_count, zero_start_index_M, M_sizes);
 }
 
 at::Tensor bf16bf16bf16_grouped_128_128_128_2_4_1_9_t(
@@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_4_1_9_t(
     int sm_count,
     std::optional<at::Tensor> zero_start_index_M,
     std::optional<at::Tensor> M_sizes) {
-  return bf16bf16bf16_grouped_impl<
+  return bf16bf16bf16_grouped_impl_dispatch<
       at::TensorList,
       128,
       128,
```
