Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -718,21 +718,35 @@ Kernel_bf16bf16bf16_grouped<InputType> get_kernel_via_tuning(
return kernel;
}

// BF16 grouped cutlass kernel dispatch.
// BF16/FP16 grouped cutlass kernel dispatch.
template <typename InputType>
at::Tensor dispatch_bf16_grouped_kernel(
int G,
int total_M,
int N,
int K,
InputType X, // BF16
InputType W, // BF16
InputType X, // BF16 or FP16
InputType W, // BF16 or FP16
at::Tensor output,
int sm_count,
std::optional<at::Tensor> zero_start_index_M = std::nullopt,
std::optional<at::Tensor> M_sizes = std::nullopt) {
const int arch = getDeviceArch();

// Get dtype from input
at::ScalarType dtype;
if constexpr (std::is_same_v<InputType, at::TensorList>) {
dtype = X[0].scalar_type();
} else {
dtype = X.scalar_type();
}

// Validate dtype is supported
TORCH_CHECK(
dtype == at::kBFloat16 || dtype == at::kHalf,
"Only BFloat16 and Float16 dtypes are supported, got ",
dtype);

// Select kernel to run via heuristics or tuning.
auto kernel = [&]() {
if (std::getenv("FBGEMM_AUTOTUNE_ENABLE")) {
Expand Down Expand Up @@ -778,7 +792,7 @@ OutputType _bf16bf16bf16_grouped(at::TensorList X, at::TensorList W) {
total_output_size += output_size;
output_sizes.push_back(output_size);
}
Y = at::empty(total_output_size, X[0].options().dtype(at::kBFloat16));
Y = at::empty(total_output_size, X[0].options());

int64_t sm_count = getSMCount(Y.device().index(), std::nullopt);

Expand Down Expand Up @@ -830,7 +844,7 @@ at::Tensor bf16bf16bf16_grouped_stacked(
if (out.has_value()) {
Y = out.value();
} else {
Y = at::empty(total_M * N, X.options().dtype(at::kBFloat16));
Y = at::empty(total_M * N, X.options());
}

// Early exit for empty inputs.
Expand Down Expand Up @@ -859,7 +873,7 @@ at::Tensor bf16bf16bf16_grouped_dynamic(
int64_t K = W.size(2);
int64_t total_output_size = G * M * N;
at::Tensor Y;
Y = at::zeros(total_output_size, X.options().dtype(at::kBFloat16));
Y = at::zeros(total_output_size, X.options());

int64_t sm_count = getSMCount(Y.device().index(), std::nullopt);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_9_f(
int sm_count,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes) {
return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 1, 1, 1, false>(
X, W, output, sm_count, zero_start_index_M, M_sizes);
return bf16bf16bf16_grouped_impl_dispatch<
at::Tensor,
128,
128,
128,
1,
1,
1,
false>(X, W, output, sm_count, zero_start_index_M, M_sizes);
}

at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_9_f(
Expand All @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_9_f(
int sm_count,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes) {
return bf16bf16bf16_grouped_impl<
return bf16bf16bf16_grouped_impl_dispatch<
at::TensorList,
128,
128,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_9_t(
int sm_count,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes) {
return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 1, 1, 1, true>(
X, W, output, sm_count, zero_start_index_M, M_sizes);
return bf16bf16bf16_grouped_impl_dispatch<
at::Tensor,
128,
128,
128,
1,
1,
1,
true>(X, W, output, sm_count, zero_start_index_M, M_sizes);
}

at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_9_t(
Expand All @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_1_1_9_t(
int sm_count,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes) {
return bf16bf16bf16_grouped_impl<
return bf16bf16bf16_grouped_impl_dispatch<
at::TensorList,
128,
128,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_2_1_9_f(
int sm_count,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes) {
return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 1, 2, 1, false>(
X, W, output, sm_count, zero_start_index_M, M_sizes);
return bf16bf16bf16_grouped_impl_dispatch<
at::Tensor,
128,
128,
128,
1,
2,
1,
false>(X, W, output, sm_count, zero_start_index_M, M_sizes);
}

at::Tensor bf16bf16bf16_grouped_128_128_128_1_2_1_9_f(
Expand All @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_2_1_9_f(
int sm_count,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes) {
return bf16bf16bf16_grouped_impl<
return bf16bf16bf16_grouped_impl_dispatch<
at::TensorList,
128,
128,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_2_1_9_t(
int sm_count,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes) {
return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 1, 2, 1, true>(
X, W, output, sm_count, zero_start_index_M, M_sizes);
return bf16bf16bf16_grouped_impl_dispatch<
at::Tensor,
128,
128,
128,
1,
2,
1,
true>(X, W, output, sm_count, zero_start_index_M, M_sizes);
}

at::Tensor bf16bf16bf16_grouped_128_128_128_1_2_1_9_t(
Expand All @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_2_1_9_t(
int sm_count,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes) {
return bf16bf16bf16_grouped_impl<
return bf16bf16bf16_grouped_impl_dispatch<
at::TensorList,
128,
128,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_4_1_9_t(
int sm_count,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes) {
return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 1, 4, 1, true>(
X, W, output, sm_count, zero_start_index_M, M_sizes);
return bf16bf16bf16_grouped_impl_dispatch<
at::Tensor,
128,
128,
128,
1,
4,
1,
true>(X, W, output, sm_count, zero_start_index_M, M_sizes);
}

at::Tensor bf16bf16bf16_grouped_128_128_128_1_4_1_9_t(
Expand All @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_1_4_1_9_t(
int sm_count,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes) {
return bf16bf16bf16_grouped_impl<
return bf16bf16bf16_grouped_impl_dispatch<
at::TensorList,
128,
128,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_9_f(
int sm_count,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes) {
return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 2, 1, 1, false>(
X, W, output, sm_count, zero_start_index_M, M_sizes);
return bf16bf16bf16_grouped_impl_dispatch<
at::Tensor,
128,
128,
128,
2,
1,
1,
false>(X, W, output, sm_count, zero_start_index_M, M_sizes);
}

at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_9_f(
Expand All @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_9_f(
int sm_count,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes) {
return bf16bf16bf16_grouped_impl<
return bf16bf16bf16_grouped_impl_dispatch<
at::TensorList,
128,
128,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_9_t(
int sm_count,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes) {
return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 2, 1, 1, true>(
X, W, output, sm_count, zero_start_index_M, M_sizes);
return bf16bf16bf16_grouped_impl_dispatch<
at::Tensor,
128,
128,
128,
2,
1,
1,
true>(X, W, output, sm_count, zero_start_index_M, M_sizes);
}

at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_9_t(
Expand All @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_1_1_9_t(
int sm_count,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes) {
return bf16bf16bf16_grouped_impl<
return bf16bf16bf16_grouped_impl_dispatch<
at::TensorList,
128,
128,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_2_1_9_t(
int sm_count,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes) {
return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 2, 2, 1, true>(
X, W, output, sm_count, zero_start_index_M, M_sizes);
return bf16bf16bf16_grouped_impl_dispatch<
at::Tensor,
128,
128,
128,
2,
2,
1,
true>(X, W, output, sm_count, zero_start_index_M, M_sizes);
}

at::Tensor bf16bf16bf16_grouped_128_128_128_2_2_1_9_t(
Expand All @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_2_1_9_t(
int sm_count,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes) {
return bf16bf16bf16_grouped_impl<
return bf16bf16bf16_grouped_impl_dispatch<
at::TensorList,
128,
128,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_4_1_9_t(
int sm_count,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes) {
return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 2, 4, 1, true>(
X, W, output, sm_count, zero_start_index_M, M_sizes);
return bf16bf16bf16_grouped_impl_dispatch<
at::Tensor,
128,
128,
128,
2,
4,
1,
true>(X, W, output, sm_count, zero_start_index_M, M_sizes);
}

at::Tensor bf16bf16bf16_grouped_128_128_128_2_4_1_9_t(
Expand All @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_2_4_1_9_t(
int sm_count,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes) {
return bf16bf16bf16_grouped_impl<
return bf16bf16bf16_grouped_impl_dispatch<
at::TensorList,
128,
128,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_4_2_1_9_t(
int sm_count,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes) {
return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 4, 2, 1, true>(
X, W, output, sm_count, zero_start_index_M, M_sizes);
return bf16bf16bf16_grouped_impl_dispatch<
at::Tensor,
128,
128,
128,
4,
2,
1,
true>(X, W, output, sm_count, zero_start_index_M, M_sizes);
}

at::Tensor bf16bf16bf16_grouped_128_128_128_4_2_1_9_t(
Expand All @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_4_2_1_9_t(
int sm_count,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes) {
return bf16bf16bf16_grouped_impl<
return bf16bf16bf16_grouped_impl_dispatch<
at::TensorList,
128,
128,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_4_4_1_9_t(
int sm_count,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes) {
return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 4, 4, 1, true>(
X, W, output, sm_count, zero_start_index_M, M_sizes);
return bf16bf16bf16_grouped_impl_dispatch<
at::Tensor,
128,
128,
128,
4,
4,
1,
true>(X, W, output, sm_count, zero_start_index_M, M_sizes);
}

at::Tensor bf16bf16bf16_grouped_128_128_128_4_4_1_9_t(
Expand All @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_128_128_4_4_1_9_t(
int sm_count,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes) {
return bf16bf16bf16_grouped_impl<
return bf16bf16bf16_grouped_impl_dispatch<
at::TensorList,
128,
128,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,15 @@ at::Tensor bf16bf16bf16_grouped_128_256_128_1_1_1_9_f(
int sm_count,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes) {
return bf16bf16bf16_grouped_impl<at::Tensor, 128, 256, 128, 1, 1, 1, false>(
X, W, output, sm_count, zero_start_index_M, M_sizes);
return bf16bf16bf16_grouped_impl_dispatch<
at::Tensor,
128,
256,
128,
1,
1,
1,
false>(X, W, output, sm_count, zero_start_index_M, M_sizes);
}

at::Tensor bf16bf16bf16_grouped_128_256_128_1_1_1_9_f(
Expand All @@ -28,7 +35,7 @@ at::Tensor bf16bf16bf16_grouped_128_256_128_1_1_1_9_f(
int sm_count,
std::optional<at::Tensor> zero_start_index_M,
std::optional<at::Tensor> M_sizes) {
return bf16bf16bf16_grouped_impl<
return bf16bf16bf16_grouped_impl_dispatch<
at::TensorList,
128,
256,
Expand Down
Loading
Loading