Skip to content

Commit c611606

Browse files
cthi authored and facebook-github-bot committed
Use torch allocation instead of cutlass::device_memory::allocation
Summary: `cutlass::device_memory::allocation` is a wrapper around [`cudaMalloc`](https://github.com/NVIDIA/cutlass/blob/2252254ce2c3f11ef5cfff9721ebbe7bd62cf8cb/tools/util/include/cutlass/util/device_memory.h#L56), but this would bypass PyTorch CCA and might use wrong device. We replace all usages with torch tensor allocation instead which would be less error prone. Differential Revision: D86768064
1 parent 648e57a commit c611606

File tree

8 files changed

+24
-16
lines changed

8 files changed

+24
-16
lines changed

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,8 @@ at::Tensor bf16i4bf16_rowwise_batched_impl(
244244
size_t workspace_size = Gemm::get_workspace_size(arguments);
245245

246246
// Allocate workspace memory
247-
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
247+
at::Tensor workspace =
248+
at::empty(workspace_size, X.options().dtype(at::kByte));
248249

249250
// Check the problem size is supported or not
250251
cutlass::Status status = gemm.can_implement(arguments);
@@ -253,7 +254,7 @@ at::Tensor bf16i4bf16_rowwise_batched_impl(
253254
}
254255

255256
// Initialize CUTLASS kernel with arguments and workspace pointer
256-
status = gemm.initialize(arguments, workspace.get());
257+
status = gemm.initialize(arguments, workspace.data_ptr());
257258
if (status != cutlass::Status::kSuccess) {
258259
throw std::runtime_error("cutlass cannot initialize");
259260
}

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16/f4f4bf16_common.cuh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,8 @@ at::Tensor _f4f4bf16(
214214
size_t workspace_size = Gemm::get_workspace_size(arguments);
215215

216216
// Allocate workspace memory
217-
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
217+
at::Tensor workspace =
218+
at::empty(workspace_size, XQ.options().dtype(at::kByte));
218219

219220
// Check the problem size is supported or not
220221
cutlass::Status status = gemm.can_implement(arguments);
@@ -223,7 +224,7 @@ at::Tensor _f4f4bf16(
223224
}
224225

225226
// Initialize CUTLASS kernel with arguments and workspace pointer
226-
status = gemm.initialize(arguments, workspace.get());
227+
status = gemm.initialize(arguments, workspace.data_ptr());
227228
if (status != cutlass::Status::kSuccess) {
228229
throw std::runtime_error("cutlass cannot initialize");
229230
}

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,8 @@ at::Tensor f8f8bf16_impl(
179179
size_t workspace_size = Gemm::get_workspace_size(arguments);
180180

181181
// Allocate workspace memory
182-
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
182+
at::Tensor workspace =
183+
at::empty(workspace_size, XQ.options().dtype(at::kByte));
183184

184185
// Check the problem size is supported or not
185186
cutlass::Status status = gemm.can_implement(arguments);
@@ -188,7 +189,7 @@ at::Tensor f8f8bf16_impl(
188189
}
189190

190191
// Initialize CUTLASS kernel with arguments and workspace pointer
191-
status = gemm.initialize(arguments, workspace.get());
192+
status = gemm.initialize(arguments, workspace.data_ptr());
192193
if (status != cutlass::Status::kSuccess) {
193194
throw std::runtime_error("cutlass cannot initialize");
194195
}

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_blockwise.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,8 @@ at::Tensor f8f8bf16_blockwise_impl(
214214
size_t workspace_size = Gemm::get_workspace_size(arguments);
215215

216216
// Allocate workspace memory
217-
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
217+
at::Tensor workspace =
218+
at::empty(workspace_size, XQ.options().dtype(at::kByte));
218219

219220
// Check the problem size is supported or not
220221
cutlass::Status status = gemm.can_implement(arguments);
@@ -223,7 +224,7 @@ at::Tensor f8f8bf16_blockwise_impl(
223224
}
224225

225226
// Initialize CUTLASS kernel with arguments and workspace pointer
226-
status = gemm.initialize(arguments, workspace.get());
227+
status = gemm.initialize(arguments, workspace.data_ptr());
227228
if (status != cutlass::Status::kSuccess) {
228229
throw std::runtime_error("cutlass cannot initialize");
229230
}

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_conv/f8f8bf16_conv_common.cuh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,14 +218,15 @@ at::Tensor f8f8bf16_conv_impl(
218218
Conv conv;
219219

220220
size_t workspace_size = Conv::get_workspace_size(arguments);
221-
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
221+
at::Tensor workspace =
222+
at::empty(workspace_size, activation.options().dtype(at::kByte));
222223

223224
cutlass::Status status = conv.can_implement(arguments);
224225
if (status != cutlass::Status::kSuccess) {
225226
throw std::runtime_error("cutlass cannot implement convolution");
226227
}
227228

228-
status = conv.initialize(arguments, workspace.get());
229+
status = conv.initialize(arguments, workspace.data_ptr());
229230
if (status != cutlass::Status::kSuccess) {
230231
throw std::runtime_error("cutlass cannot initialize convolution");
231232
}

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_tensorwise.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,8 @@ at::Tensor f8f8bf16_tensorwise_impl(
209209
size_t workspace_size = Gemm::get_workspace_size(arguments);
210210

211211
// Allocate workspace memory
212-
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
212+
at::Tensor workspace =
213+
at::empty(workspace_size, XQ.options().dtype(at::kByte));
213214

214215
// Check the problem size is supported or not
215216
cutlass::Status status = gemm.can_implement(arguments);
@@ -218,7 +219,7 @@ at::Tensor f8f8bf16_tensorwise_impl(
218219
}
219220

220221
// Initialize CUTLASS kernel with arguments and workspace pointer
221-
status = gemm.initialize(arguments, workspace.get());
222+
status = gemm.initialize(arguments, workspace.data_ptr());
222223
if (status != cutlass::Status::kSuccess) {
223224
throw std::runtime_error("cutlass cannot initialize");
224225
}

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8i4bf16_rowwise.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,8 @@ at::Tensor f8i4bf16_rowwise_impl(
214214
size_t workspace_size = Gemm::get_workspace_size(arguments);
215215

216216
// Allocate workspace memory
217-
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
217+
at::Tensor workspace =
218+
at::empty(workspace_size, XQ.options().dtype(at::kByte));
218219

219220
// Check the problem size is supported or not
220221
cutlass::Status status = gemm.can_implement(arguments);
@@ -223,7 +224,7 @@ at::Tensor f8i4bf16_rowwise_impl(
223224
}
224225

225226
// Initialize CUTLASS kernel with arguments and workspace pointer
226-
status = gemm.initialize(arguments, workspace.get());
227+
status = gemm.initialize(arguments, workspace.data_ptr());
227228
if (status != cutlass::Status::kSuccess) {
228229
throw std::runtime_error("cutlass cannot initialize");
229230
}

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/i8i8bf16.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,8 @@ at::Tensor i8i8bf16sm90a_impl(
257257
size_t workspace_size = Gemm::get_workspace_size(arguments);
258258

259259
// Allocate workspace memory
260-
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
260+
at::Tensor workspace =
261+
at::empty(workspace_size, XQ.options().dtype(at::kByte));
261262

262263
// Check the problem size is supported or not
263264
cutlass::Status status = gemm.can_implement(arguments);
@@ -266,7 +267,7 @@ at::Tensor i8i8bf16sm90a_impl(
266267
}
267268

268269
// Initialize CUTLASS kernel with arguments and workspace pointer
269-
status = gemm.initialize(arguments, workspace.get());
270+
status = gemm.initialize(arguments, workspace.data_ptr());
270271
if (status != cutlass::Status::kSuccess) {
271272
throw std::runtime_error("cutlass cannot initialize");
272273
}

0 commit comments

Comments
 (0)