8 files changed (+24, -16 lines), all under fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions

@@ -244,7 +244,8 @@ at::Tensor bf16i4bf16_rowwise_batched_impl(
   size_t workspace_size = Gemm::get_workspace_size(arguments);

   // Allocate workspace memory
-  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+  at::Tensor workspace =
+      at::empty(workspace_size, X.options().dtype(at::kByte));

   // Check the problem size is supported or not
   cutlass::Status status = gemm.can_implement(arguments);
@@ -253,7 +254,7 @@ at::Tensor bf16i4bf16_rowwise_batched_impl(
   }

   // Initialize CUTLASS kernel with arguments and workspace pointer
-  status = gemm.initialize(arguments, workspace.get());
+  status = gemm.initialize(arguments, workspace.data_ptr());
   if (status != cutlass::Status::kSuccess) {
     throw std::runtime_error("cutlass cannot initialize");
   }

@@ -214,7 +214,8 @@ at::Tensor _f4f4bf16(
   size_t workspace_size = Gemm::get_workspace_size(arguments);

   // Allocate workspace memory
-  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+  at::Tensor workspace =
+      at::empty(workspace_size, XQ.options().dtype(at::kByte));

   // Check the problem size is supported or not
   cutlass::Status status = gemm.can_implement(arguments);
@@ -223,7 +224,7 @@ at::Tensor _f4f4bf16(
   }

   // Initialize CUTLASS kernel with arguments and workspace pointer
-  status = gemm.initialize(arguments, workspace.get());
+  status = gemm.initialize(arguments, workspace.data_ptr());
   if (status != cutlass::Status::kSuccess) {
     throw std::runtime_error("cutlass cannot initialize");
   }

@@ -179,7 +179,8 @@ at::Tensor f8f8bf16_impl(
   size_t workspace_size = Gemm::get_workspace_size(arguments);

   // Allocate workspace memory
-  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+  at::Tensor workspace =
+      at::empty(workspace_size, XQ.options().dtype(at::kByte));

   // Check the problem size is supported or not
   cutlass::Status status = gemm.can_implement(arguments);
@@ -188,7 +189,7 @@ at::Tensor f8f8bf16_impl(
   }

   // Initialize CUTLASS kernel with arguments and workspace pointer
-  status = gemm.initialize(arguments, workspace.get());
+  status = gemm.initialize(arguments, workspace.data_ptr());
   if (status != cutlass::Status::kSuccess) {
     throw std::runtime_error("cutlass cannot initialize");
   }

@@ -214,7 +214,8 @@ at::Tensor f8f8bf16_blockwise_impl(
   size_t workspace_size = Gemm::get_workspace_size(arguments);

   // Allocate workspace memory
-  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+  at::Tensor workspace =
+      at::empty(workspace_size, XQ.options().dtype(at::kByte));

   // Check the problem size is supported or not
   cutlass::Status status = gemm.can_implement(arguments);
@@ -223,7 +224,7 @@ at::Tensor f8f8bf16_blockwise_impl(
   }

   // Initialize CUTLASS kernel with arguments and workspace pointer
-  status = gemm.initialize(arguments, workspace.get());
+  status = gemm.initialize(arguments, workspace.data_ptr());
   if (status != cutlass::Status::kSuccess) {
     throw std::runtime_error("cutlass cannot initialize");
   }

@@ -218,14 +218,15 @@ at::Tensor f8f8bf16_conv_impl(
   Conv conv;

   size_t workspace_size = Conv::get_workspace_size(arguments);
-  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+  at::Tensor workspace =
+      at::empty(workspace_size, activation.options().dtype(at::kByte));

   cutlass::Status status = conv.can_implement(arguments);
   if (status != cutlass::Status::kSuccess) {
     throw std::runtime_error("cutlass cannot implement convolution");
   }

-  status = conv.initialize(arguments, workspace.get());
+  status = conv.initialize(arguments, workspace.data_ptr());
   if (status != cutlass::Status::kSuccess) {
     throw std::runtime_error("cutlass cannot initialize convolution");
   }

@@ -209,7 +209,8 @@ at::Tensor f8f8bf16_tensorwise_impl(
   size_t workspace_size = Gemm::get_workspace_size(arguments);

   // Allocate workspace memory
-  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+  at::Tensor workspace =
+      at::empty(workspace_size, XQ.options().dtype(at::kByte));

   // Check the problem size is supported or not
   cutlass::Status status = gemm.can_implement(arguments);
@@ -218,7 +219,7 @@ at::Tensor f8f8bf16_tensorwise_impl(
   }

   // Initialize CUTLASS kernel with arguments and workspace pointer
-  status = gemm.initialize(arguments, workspace.get());
+  status = gemm.initialize(arguments, workspace.data_ptr());
   if (status != cutlass::Status::kSuccess) {
     throw std::runtime_error("cutlass cannot initialize");
   }

@@ -214,7 +214,8 @@ at::Tensor f8i4bf16_rowwise_impl(
   size_t workspace_size = Gemm::get_workspace_size(arguments);

   // Allocate workspace memory
-  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+  at::Tensor workspace =
+      at::empty(workspace_size, XQ.options().dtype(at::kByte));

   // Check the problem size is supported or not
   cutlass::Status status = gemm.can_implement(arguments);
@@ -223,7 +224,7 @@ at::Tensor f8i4bf16_rowwise_impl(
   }

   // Initialize CUTLASS kernel with arguments and workspace pointer
-  status = gemm.initialize(arguments, workspace.get());
+  status = gemm.initialize(arguments, workspace.data_ptr());
   if (status != cutlass::Status::kSuccess) {
     throw std::runtime_error("cutlass cannot initialize");
   }

@@ -257,7 +257,8 @@ at::Tensor i8i8bf16sm90a_impl(
   size_t workspace_size = Gemm::get_workspace_size(arguments);

   // Allocate workspace memory
-  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+  at::Tensor workspace =
+      at::empty(workspace_size, XQ.options().dtype(at::kByte));

   // Check the problem size is supported or not
   cutlass::Status status = gemm.can_implement(arguments);
@@ -266,7 +267,7 @@ at::Tensor i8i8bf16sm90a_impl(
   }

   // Initialize CUTLASS kernel with arguments and workspace pointer
-  status = gemm.initialize(arguments, workspace.get());
+  status = gemm.initialize(arguments, workspace.data_ptr());
   if (status != cutlass::Status::kSuccess) {
     throw std::runtime_error("cutlass cannot initialize");
   }
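
Every hunk above makes the same substitution: the CUTLASS-owned device buffer (cutlass::device_memory::allocation<uint8_t>) becomes a byte tensor obtained from PyTorch's caching allocator via at::empty, and gemm.initialize / conv.initialize receives workspace.data_ptr() instead of workspace.get(). Allocating with the input tensor's options keeps the workspace on the same device as the operands and lets the caching allocator reuse the memory across calls. Below is a minimal sketch of the pattern, assuming CUTLASS and ATen headers are available; the launcher name run_cutlass_gemm and the XQ parameter are illustrative stand-ins, not names from the diff beyond what the hunks themselves show.

#include <ATen/ATen.h>
#include <cutlass/cutlass.h>
#include <stdexcept>

// Sketch only: mirrors the allocation/initialize steps shown in the hunks.
template <typename Gemm>
void run_cutlass_gemm(
    const typename Gemm::Arguments& arguments,
    const at::Tensor& XQ) {
  Gemm gemm;

  // Ask CUTLASS how many workspace bytes this problem needs.
  size_t workspace_size = Gemm::get_workspace_size(arguments);

  // Allocate workspace memory through PyTorch's caching allocator,
  // on the same device and with the same allocator as the inputs.
  at::Tensor workspace = at::empty(
      {static_cast<int64_t>(workspace_size)},
      XQ.options().dtype(at::kByte));

  // Check the problem size is supported or not.
  cutlass::Status status = gemm.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot implement");
  }

  // Initialize CUTLASS kernel with arguments and workspace pointer;
  // the tensor owns the memory and keeps it alive for this call.
  status = gemm.initialize(arguments, workspace.data_ptr());
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot initialize");
  }

  // The real launchers then run the kernel on the current CUDA stream.
}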