
Commit 722f8b6

cthi authored and facebook-github-bot committed
Add CUDAGuard to ensure correct device (#5113)
Summary:
X-link: facebookresearch/FBGEMM#2119

If the input tensors are on a device that differs from the current device, the wrong device is used for things such as workspace allocation (when using `cutlass::device_memory::allocation`), and the kernel is launched on the wrong stream. Either would break the kernel. As a fix, we add a `CUDAGuard` to ensure the correct device is used.

- `cutlass::device_memory::allocation` is a wrapper around [`cudaMalloc`](https://github.com/NVIDIA/cutlass/blob/2252254ce2c3f11ef5cfff9721ebbe7bd62cf8cb/tools/util/include/cutlass/util/device_memory.h#L56), but it bypasses PyTorch's CUDA caching allocator (CCA). We replace all its usages with torch tensor allocation, which is less error prone and allows proper memory reuse.

Differential Revision: D86768064
1 parent 648e57a commit 722f8b6

23 files changed (+99, -16 lines)
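To make the fix concrete before the per-file hunks below, here is a minimal sketch of the pattern the commit applies. It is not taken verbatim from any one file; `example_gemm_entry` and its shapes are hypothetical, but the guard and allocation pattern mirror the changes in entry points such as `bf16bf16bf16_grouped_stacked` and `f8f8bf16_impl`:

```cpp
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

// Hypothetical entry point illustrating the shape of the fix.
at::Tensor example_gemm_entry(at::Tensor X, at::Tensor W) {
  // Switch the current CUDA device to X's device for the scope of this call.
  // Without the guard, a caller whose tensors live on cuda:1 while the
  // thread's current device is cuda:0 would get workspace memory allocated
  // on the wrong GPU and the kernel launched on cuda:0's current stream.
  c10::cuda::CUDAGuard deviceGuard(X.device());

  // Output (and any workspace) allocations now land on X's device.
  at::Tensor Y = at::empty({X.size(0), W.size(0)}, X.options());

  // The stream fetched here is the current stream of X's device, so the
  // CUTLASS kernel runs where its operands actually live.
  auto stream = at::cuda::getCurrentCUDAStream();
  // ... gemm.run(stream) in the real implementations ...
  return Y;
}
```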

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu

Lines changed: 7 additions & 0 deletions
@@ -8,6 +8,7 @@
 
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 
 #include "bf16bf16bf16_grouped/bf16bf16bf16_grouped_manifest.cuh"
 #include "fbgemm_gpu/quantize/tuning_cache.cuh"
@@ -758,6 +759,8 @@ at::Tensor dispatch_bf16_grouped_kernel(
 
 template <typename OutputType>
 OutputType _bf16bf16bf16_grouped(at::TensorList X, at::TensorList W) {
+  c10::cuda::CUDAGuard deviceGuard(X[0].device());
+
   at::Tensor Y;
   int64_t total_M = 0;
   int64_t G = X.size();
@@ -816,6 +819,8 @@ at::Tensor bf16bf16bf16_grouped_stacked(
     at::Tensor M_sizes,
     std::optional<at::Tensor> out,
     std::optional<int64_t> num_sms) {
+  c10::cuda::CUDAGuard deviceGuard(X.device());
+
   int64_t total_M = X.size(0);
   int64_t N = W.size(1);
   int64_t K = W.size(2);
@@ -853,6 +858,8 @@ at::Tensor bf16bf16bf16_grouped_dynamic(
   TORCH_CHECK(
       zero_start_index_M.device() == X.device(),
       "zero_start_index_M must be on same device as inputs.");
+  c10::cuda::CUDAGuard deviceGuard(X.device());
+
   int64_t G = X.size(0);
   int64_t M = X.size(1);
   int64_t N = W.size(1);

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_grad/bf16bf16bf16_grouped_grad_common.cuh

Lines changed: 3 additions & 0 deletions
@@ -8,6 +8,7 @@
 
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <cutlass/util/device_memory.h>
 #include <cutlass/util/packed_stride.hpp>
 
@@ -112,6 +113,8 @@ at::Tensor bf16bf16bf16_grouped_grad_impl(
     at::Tensor output,
     int sm_count,
     std::optional<at::Tensor> M_sizes) {
+  c10::cuda::CUDAGuard deviceGuard(X.device());
+
   int64_t G;
   at::TensorOptions options;
   G = W.size(0);

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_common.cuh

Lines changed: 5 additions & 0 deletions
@@ -8,6 +8,7 @@
 
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <cutlass/util/device_memory.h>
 #include <cutlass/util/packed_stride.hpp>
 
@@ -113,6 +114,8 @@ at::Tensor bf16bf16bf16_grouped_wgrad_impl(
     at::Tensor M_sizes,
     at::Tensor output,
     int sm_count) {
+  c10::cuda::CUDAGuard deviceGuard(X.device());
+
   int64_t G;
   at::TensorOptions options;
   G = M_sizes.size(0);
@@ -377,6 +380,8 @@ at::Tensor bf16bf16bf16_grouped_wgrad_sm100_impl(
     at::Tensor M_sizes,
     at::Tensor output,
     int sm_count) {
+  c10::cuda::CUDAGuard deviceGuard(X.device());
+
   int64_t G;
   at::TensorOptions options;
   G = M_sizes.size(0);

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16.cu

Lines changed: 3 additions & 0 deletions
@@ -8,6 +8,7 @@
 
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 
 #include "cutlass/cutlass.h"
 
@@ -41,6 +42,8 @@ at::Tensor _bf16i4bf16(
     at::Tensor w_scale_group,
     at::Tensor w_zero_group,
     at::Tensor Y) {
+  c10::cuda::CUDAGuard deviceGuard(X.device());
+
   // Get shape information from input tensors.
   int M = size_to_dim_(X.dim() - 1, X.sizes());
   int K = X.size(-1);

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu

Lines changed: 6 additions & 2 deletions
@@ -8,6 +8,7 @@
 
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <cutlass/util/device_memory.h>
 #include <cutlass/util/packed_stride.hpp>
 
@@ -48,6 +49,8 @@ at::Tensor bf16i4bf16_rowwise_batched_impl(
     at::Tensor WQ, // INT4
     at::Tensor w_scale,
     at::Tensor w_zp) {
+  c10::cuda::CUDAGuard deviceGuard(X.device());
+
   // XQ: B x M x K
   // WQ: B x N x K
   // output: B x M x N
@@ -244,7 +247,8 @@ at::Tensor bf16i4bf16_rowwise_batched_impl(
   size_t workspace_size = Gemm::get_workspace_size(arguments);
 
   // Allocate workspace memory
-  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+  at::Tensor workspace =
+      at::empty(workspace_size, X.options().dtype(at::kByte));
 
   // Check the problem size is supported or not
   cutlass::Status status = gemm.can_implement(arguments);
@@ -253,7 +257,7 @@ at::Tensor bf16i4bf16_rowwise_batched_impl(
   }
 
   // Initialize CUTLASS kernel with arguments and workspace pointer
-  status = gemm.initialize(arguments, workspace.get());
+  status = gemm.initialize(arguments, workspace.data_ptr());
   if (status != cutlass::Status::kSuccess) {
     throw std::runtime_error("cutlass cannot initialize");
   }
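The hunks above also show the second half of the fix: the CUTLASS workspace is no longer allocated with `cutlass::device_memory::allocation` (raw `cudaMalloc`) but with a byte tensor. Below is a small hedged sketch of that allocation in isolation; the helper name `allocate_cutlass_workspace` is hypothetical, and the real code simply inlines the `at::empty` call as in the hunk above:

```cpp
#include <ATen/ATen.h>

// Hypothetical helper illustrating why the workspace comes from at::empty:
// the byte tensor lands on XQ's device and is served by PyTorch's CUDA
// caching allocator, so the block can be reused across calls instead of
// paying a fresh cudaMalloc each time.
at::Tensor allocate_cutlass_workspace(size_t workspace_size, const at::Tensor& XQ) {
  return at::empty(
      {static_cast<int64_t>(workspace_size)}, XQ.options().dtype(at::kByte));
}

// The raw pointer handed to CUTLASS then comes from the tensor:
//   status = gemm.initialize(arguments, workspace.data_ptr());
// The tensor keeps the memory alive while the workspace variable is in
// scope, matching the lifetime the cutlass allocation had before.
```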

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_shuffled_grouped.cu

Lines changed: 3 additions & 0 deletions
@@ -8,6 +8,7 @@
 
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 
 #include "cutlass/cutlass.h"
 
@@ -137,6 +138,8 @@ void _bf16i4bf16_shuffled_grouped(
     at::Tensor w_zero_group,
     at::Tensor M_sizes,
     at::Tensor Y) {
+  c10::cuda::CUDAGuard deviceGuard(X.device());
+
   // Get basic shape information.
   int G = M_sizes.size(0);
   // X is shape [total_M, K]

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_common.cuh

Lines changed: 6 additions & 2 deletions
@@ -8,6 +8,7 @@
 
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <cutlass/util/device_memory.h>
 #include <cutlass/util/packed_stride.hpp>
 
@@ -34,6 +35,8 @@ at::Tensor _f4f4bf16(
     at::Tensor x_scale,
     at::Tensor w_scale,
     std::optional<at::Tensor> global_scale) {
+  c10::cuda::CUDAGuard deviceGuard(XQ.device());
+
   int M = XQ.size(0);
   int N = WQ.size(0);
   int K = XQ.size(1) * 2; // Since K is packed
@@ -214,7 +217,8 @@ at::Tensor _f4f4bf16(
   size_t workspace_size = Gemm::get_workspace_size(arguments);
 
   // Allocate workspace memory
-  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+  at::Tensor workspace =
+      at::empty(workspace_size, XQ.options().dtype(at::kByte));
 
   // Check the problem size is supported or not
   cutlass::Status status = gemm.can_implement(arguments);
@@ -223,7 +227,7 @@ at::Tensor _f4f4bf16(
   }
 
   // Initialize CUTLASS kernel with arguments and workspace pointer
-  status = gemm.initialize(arguments, workspace.get());
+  status = gemm.initialize(arguments, workspace.data_ptr());
   if (status != cutlass::Status::kSuccess) {
     throw std::runtime_error("cutlass cannot initialize");
   }

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16_grouped/f4f4bf16_grouped_common.cuh

Lines changed: 3 additions & 0 deletions
@@ -8,6 +8,7 @@
 
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <cutlass/util/device_memory.h>
 #include <cutlass/util/packed_stride.hpp>
 
@@ -160,6 +161,8 @@ at::Tensor f4f4bf16_grouped_impl(
     std::optional<at::Tensor> M_sizes,
     std::optional<at::Tensor> global_scale,
     std::optional<at::Tensor> starting_row_after_padding) {
+  c10::cuda::CUDAGuard deviceGuard(XQ.device());
+
   // The number of groups the kernel uses may vary.
   const int64_t G = [&]() {
     if (M_sizes) {

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16.cu

Lines changed: 5 additions & 2 deletions
@@ -8,6 +8,7 @@
 
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <cutlass/util/host_tensor.h>
 #include <cutlass/util/packed_stride.hpp>
 
@@ -36,6 +37,7 @@ at::Tensor f8f8bf16_impl(
     at::Tensor XQ, // FP8
     at::Tensor WQ, // FP8
     at::Tensor scale) {
+  c10::cuda::CUDAGuard deviceGuard(XQ.device());
   // XQ: M x K
   // WQ: N x K
   // output: M x N
@@ -179,7 +181,8 @@ at::Tensor f8f8bf16_impl(
   size_t workspace_size = Gemm::get_workspace_size(arguments);
 
   // Allocate workspace memory
-  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+  at::Tensor workspace =
+      at::empty(workspace_size, XQ.options().dtype(at::kByte));
 
   // Check the problem size is supported or not
   cutlass::Status status = gemm.can_implement(arguments);
@@ -188,7 +191,7 @@ at::Tensor f8f8bf16_impl(
   }
 
   // Initialize CUTLASS kernel with arguments and workspace pointer
-  status = gemm.initialize(arguments, workspace.get());
+  status = gemm.initialize(arguments, workspace.data_ptr());
   if (status != cutlass::Status::kSuccess) {
     throw std::runtime_error("cutlass cannot initialize");
   }

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_blockwise.cu

Lines changed: 5 additions & 2 deletions
@@ -8,6 +8,7 @@
 
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <cutlass/util/host_tensor.h>
 #include <cutlass/util/packed_stride.hpp>
 
@@ -46,6 +47,7 @@ at::Tensor f8f8bf16_blockwise_impl(
     at::Tensor WQ, // FP8
     at::Tensor x_scale,
     at::Tensor w_scale) {
+  c10::cuda::CUDAGuard deviceGuard(XQ.device());
  // XQ: M x K
  // WQ: N x K
  // output: M x N
@@ -214,7 +216,8 @@ at::Tensor f8f8bf16_blockwise_impl(
   size_t workspace_size = Gemm::get_workspace_size(arguments);
 
   // Allocate workspace memory
-  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+  at::Tensor workspace =
+      at::empty(workspace_size, XQ.options().dtype(at::kByte));
 
   // Check the problem size is supported or not
   cutlass::Status status = gemm.can_implement(arguments);
@@ -223,7 +226,7 @@ at::Tensor f8f8bf16_blockwise_impl(
   }
 
   // Initialize CUTLASS kernel with arguments and workspace pointer
-  status = gemm.initialize(arguments, workspace.get());
+  status = gemm.initialize(arguments, workspace.data_ptr());
   if (status != cutlass::Status::kSuccess) {
     throw std::runtime_error("cutlass cannot initialize");
   }
