
[CINN] Add the cuLaunchCooperativeKernel call for grid reduce #71941

Merged (1 commit, Mar 28, 2025)
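Background for this change: a grid-level reduction synchronizes every thread block of a kernel through cooperative groups' grid.sync(), and such a kernel is only valid when it is launched through cuLaunchCooperativeKernel (or cudaLaunchCooperativeKernel). The kernel below is not part of this PR; it is a minimal sketch of the kind of grid-reduce kernel that needs a cooperative launch, assuming a fixed block size of 256 threads.

// Minimal sketch (not from this PR): a grid-wide sum that requires a
// cooperative launch because it calls grid.sync() across all blocks.
// Assumes the kernel is launched with blockDim.x == 256.
#include <cooperative_groups.h>
namespace cg = cooperative_groups;

__global__ void grid_reduce_sum(const float* in, float* partials,
                                float* out, int n) {
  cg::grid_group grid = cg::this_grid();
  __shared__ float smem[256];

  // Phase 1: every block reduces its strided slice of `in`.
  float sum = 0.f;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    sum += in[i];
  }
  smem[threadIdx.x] = sum;
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s];
    __syncthreads();
  }
  if (threadIdx.x == 0) partials[blockIdx.x] = smem[0];

  // Grid-wide barrier: only defined for cooperative launches, i.e. kernels
  // started via cuLaunchCooperativeKernel / cudaLaunchCooperativeKernel.
  grid.sync();

  // Phase 2: block 0 folds the per-block partial sums into the result.
  if (blockIdx.x == 0) {
    float total = 0.f;
    for (int b = threadIdx.x; b < gridDim.x; b += blockDim.x) {
      total += partials[b];
    }
    smem[threadIdx.x] = total;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
      if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s];
      __syncthreads();
    }
    if (threadIdx.x == 0) *out = smem[0];
  }
}

Launching such a kernel through a plain cuLaunchKernel makes grid.sync() invalid, which is why the host runtime needs the new cinn_call_cuda_cooperative_kernel intrinsic added below.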
3 changes: 2 additions & 1 deletion paddle/cinn/backends/codegen_cuda_host.h
@@ -43,7 +43,8 @@ class CodeGenGpuHost : public CodeGenHost {
[&](common::X86Arch) { return CodeGenHost::Visit(op); },
[&](common::ARMArch) { return CodeGenHost::Visit(op); },
[&](common::NVGPUArch) {
-if (op->name == runtime::intrinsic::call_cuda_kernel) {
+if (op->name == runtime::intrinsic::call_cuda_kernel ||
+    op->name == runtime::intrinsic::call_cuda_cooperative_kernel) {
return LowerGPUKernelCall(op);
} else {
return CodeGenHost::Visit(op);
6 changes: 5 additions & 1 deletion paddle/cinn/backends/codegen_device_util.cc
@@ -249,7 +249,11 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
CINN_NOT_IMPLEMENTED;
},
[&](common::NVGPUArch) {
-call_kernel = runtime::intrinsic::call_cuda_kernel;
+// TODO(liangshuhao): when cooperative group is supported, change the
+// second call to `call_cuda_cooperative_kernel`.
+call_kernel = func->temp_spaces.empty()
+    ? runtime::intrinsic::call_cuda_kernel
+    : runtime::intrinsic::call_cuda_kernel;
},
[&](common::HygonDCUArchHIP) {
call_kernel = runtime::intrinsic::call_hip_kernel;
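As the TODO above notes, this PR wires up the runtime call but keeps both branches on call_cuda_kernel for now. Once codegen emits cooperative-group kernels, the dispatch is expected to become the following (a sketch based on the TODO, assuming a non-empty temp_spaces continues to mark grid-reduce kernels):

// Intended final form (not in this PR yet): kernels with cross-block temp
// spaces are launched through the cooperative intrinsic.
call_kernel = func->temp_spaces.empty()
                  ? runtime::intrinsic::call_cuda_kernel
                  : runtime::intrinsic::call_cuda_cooperative_kernel;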
17 changes: 17 additions & 0 deletions paddle/cinn/runtime/cuda/cuda_intrinsics.cc
@@ -468,6 +468,23 @@ CINN_REGISTER_HELPER(cinn_cuda_host_api) {
.AddInputType<void *>() // stream
.End();

using cinn::runtime::cuda::cinn_call_cuda_cooperative_kernel;
REGISTER_EXTERN_FUNC_HELPER(cinn_call_cuda_cooperative_kernel,
cinn::common::DefaultHostTarget())
.SetRetType<void>()
.AddInputType<void *>() // kernel_fn
.AddInputType<void *>() // args
.AddInputType<int>() // num_args
.AddInputType<int>() // grid_x
.AddInputType<int>() // grid_y
.AddInputType<int>() // grid_z
.AddInputType<int>() // block_x
.AddInputType<int>() // block_y
.AddInputType<int>() // block_z
.AddInputType<int>() // shared_mem
.AddInputType<void *>() // stream
.End();

using cinn::runtime::cuda::cinn_call_cublas;
REGISTER_EXTERN_FUNC_HELPER(cinn_call_cublas,
cinn::common::DefaultHostTarget())
50 changes: 50 additions & 0 deletions paddle/cinn/runtime/cuda/cuda_util.cc
@@ -138,6 +138,56 @@ void cinn_call_cuda_kernel(void *kernel_fn,
}
}

void cinn_call_cuda_cooperative_kernel(void *kernel_fn,
void *v_args,
int num_args,
int grid_x,
int grid_y,
int grid_z,
int block_x,
int block_y,
int block_z,
int shared_memory_bytes,
void *stream) {
VLOG(3) << "cinn_call_cuda_cooperative_kernel, grid_dim={" << grid_x << ", "
<< grid_y << ", " << grid_z << "}, block_dim={" << block_x << ", "
<< block_y << ", " << block_z << "}, num_args=" << num_args
<< ", shared_memory_bytes=" << shared_memory_bytes
<< ", stream=" << stream << ", kernel_fn=" << kernel_fn;

std::vector<void *> kernel_args;
{
cinn::utils::RecordEvent record_run("prepare_args",
cinn::utils::EventType::kInstruction);
kernel_args.reserve(num_args);
cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);
for (int idx = 0; idx < num_args; ++idx) {
if (args[idx].type_code() == ::cinn_type_code<cinn_buffer_t *>()) {
kernel_args.emplace_back(
&((cinn_buffer_t *)(args[idx]))->memory); // NOLINT
} else {
kernel_args.emplace_back(args[idx].data_addr());
}
}
}

{
cinn::utils::RecordEvent record_run("cuLaunchCooperativeKernel",
cinn::utils::EventType::kInstruction);
CUDA_DRIVER_CALL(
cuLaunchCooperativeKernel(static_cast<CUfunction>(kernel_fn),
grid_x,
grid_y,
grid_z,
block_x,
block_y,
block_z,
shared_memory_bytes,
static_cast<CUstream>(stream),
kernel_args.data()))
}
}

void cinn_call_cublas(void *v_args,
int num_args,
bool trans_a,
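A practical caveat not enforced by cinn_call_cuda_cooperative_kernel itself: cuLaunchCooperativeKernel fails unless the device reports CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH and every block of the grid can be resident on the GPU at once. The snippet below is an illustrative pre-launch check a caller might perform; `fn` and `dev` are hypothetical names for the CUfunction and active CUdevice, and a 1-D launch is assumed.

// Illustrative pre-launch check (not part of this PR).
int coop_supported = 0, sm_count = 0, blocks_per_sm = 0;
CUDA_DRIVER_CALL(cuDeviceGetAttribute(
    &coop_supported, CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH, dev))
CUDA_DRIVER_CALL(cuDeviceGetAttribute(
    &sm_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, dev))
CUDA_DRIVER_CALL(cuOccupancyMaxActiveBlocksPerMultiprocessor(
    &blocks_per_sm, fn, block_x, shared_memory_bytes))
// All blocks of a cooperative grid must fit on the device in one wave.
int max_coop_blocks = sm_count * blocks_per_sm;
if (coop_supported == 1 && grid_x <= max_coop_blocks) {
  // Safe to launch cooperatively with grid_x blocks of block_x threads.
}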
18 changes: 18 additions & 0 deletions paddle/cinn/runtime/cuda/cuda_util.h
@@ -112,6 +112,24 @@ void cinn_call_cuda_kernel(void* kernel_fn,
int shared_memory_bytes,
void* stream);

/**
* Call a CUDA compiled kernel with cooperative groups.
*
* @param kernel_fn the compiled PTX kernel.
* @param v_args an array of cinn_pod_value_t (scalars and buffers).
*/
void cinn_call_cuda_cooperative_kernel(void* kernel_fn,
void* v_args,
int num_args,
int grid_x,
int grid_y,
int grid_z,
int block_x,
int block_y,
int block_z,
int shared_memory_bytes,
void* stream);

void cinn_call_cublas(void* v_args,
int num_args,
bool trans_a,
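For reference, a hypothetical call site for the new entry point, assuming `fn` is a CUfunction obtained via cuModuleGetFunction and `args`/`num_args` follow the same cinn_pod_value_t packing that cinn_call_cuda_kernel already uses:

// Hypothetical usage: launch a 64 x 256 cooperative grid on the default
// stream with no dynamic shared memory.
cinn_call_cuda_cooperative_kernel(reinterpret_cast<void*>(fn),
                                  args,
                                  num_args,
                                  /*grid_x=*/64, /*grid_y=*/1, /*grid_z=*/1,
                                  /*block_x=*/256, /*block_y=*/1, /*block_z=*/1,
                                  /*shared_memory_bytes=*/0,
                                  /*stream=*/nullptr);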
2 changes: 2 additions & 0 deletions paddle/cinn/runtime/intrinsic.h
@@ -103,6 +103,8 @@ static const char* pod_value_to_void_p = "cinn_pod_value_to_void_p";
static const char* print_debug_args_repr = "cinn_print_debug_args";

static const char* call_cuda_kernel = "cinn_call_cuda_kernel";
static const char* call_cuda_cooperative_kernel =
"cinn_call_cuda_cooperative_kernel";

static const char* call_hip_kernel = "cinn_call_hip_kernel";
