diff --git a/paddle/cinn/backends/codegen_cuda_host.h b/paddle/cinn/backends/codegen_cuda_host.h
index 3665c018942830..33214a533de3e2 100644
--- a/paddle/cinn/backends/codegen_cuda_host.h
+++ b/paddle/cinn/backends/codegen_cuda_host.h
@@ -43,7 +43,8 @@ class CodeGenGpuHost : public CodeGenHost {
         [&](common::X86Arch) { return CodeGenHost::Visit(op); },
         [&](common::ARMArch) { return CodeGenHost::Visit(op); },
         [&](common::NVGPUArch) {
-          if (op->name == runtime::intrinsic::call_cuda_kernel) {
+          if (op->name == runtime::intrinsic::call_cuda_kernel ||
+              op->name == runtime::intrinsic::call_cuda_cooperative_kernel) {
             return LowerGPUKernelCall(op);
           } else {
             return CodeGenHost::Visit(op);
diff --git a/paddle/cinn/backends/codegen_device_util.cc b/paddle/cinn/backends/codegen_device_util.cc
index ecfd3346790c28..968f9c392a27d6 100644
--- a/paddle/cinn/backends/codegen_device_util.cc
+++ b/paddle/cinn/backends/codegen_device_util.cc
@@ -249,7 +249,11 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
         CINN_NOT_IMPLEMENTED;
       },
       [&](common::NVGPUArch) {
-        call_kernel = runtime::intrinsic::call_cuda_kernel;
+        // TODO(liangshuhao): when cooperative group is supported, change the
+        // second call to `call_cuda_cooperative_kernel`.
+        call_kernel = func->temp_spaces.empty()
+                          ? runtime::intrinsic::call_cuda_kernel
+                          : runtime::intrinsic::call_cuda_kernel;
       },
       [&](common::HygonDCUArchHIP) {
         call_kernel = runtime::intrinsic::call_hip_kernel;
diff --git a/paddle/cinn/runtime/cuda/cuda_intrinsics.cc b/paddle/cinn/runtime/cuda/cuda_intrinsics.cc
index 79188b113cde82..8dd78a96d76f9c 100644
--- a/paddle/cinn/runtime/cuda/cuda_intrinsics.cc
+++ b/paddle/cinn/runtime/cuda/cuda_intrinsics.cc
@@ -468,6 +468,23 @@ CINN_REGISTER_HELPER(cinn_cuda_host_api) {
       .AddInputType<void *>()  // stream
       .End();
 
+  using cinn::runtime::cuda::cinn_call_cuda_cooperative_kernel;
+  REGISTER_EXTERN_FUNC_HELPER(cinn_call_cuda_cooperative_kernel,
+                              cinn::common::DefaultHostTarget())
+      .SetRetType<void>()
+      .AddInputType<void *>()  // kernel_fn
+      .AddInputType<void *>()  // args
+      .AddInputType<int>()     // num_args
+      .AddInputType<int>()     // grid_x
+      .AddInputType<int>()     // grid_y
+      .AddInputType<int>()     // grid_z
+      .AddInputType<int>()     // block_x
+      .AddInputType<int>()     // block_y
+      .AddInputType<int>()     // block_z
+      .AddInputType<int>()     // shared_mem
+      .AddInputType<void *>()  // stream
+      .End();
+
   using cinn::runtime::cuda::cinn_call_cublas;
   REGISTER_EXTERN_FUNC_HELPER(cinn_call_cublas,
                               cinn::common::DefaultHostTarget())
diff --git a/paddle/cinn/runtime/cuda/cuda_util.cc b/paddle/cinn/runtime/cuda/cuda_util.cc
index f3e39881d779d3..f09829b7404865 100644
--- a/paddle/cinn/runtime/cuda/cuda_util.cc
+++ b/paddle/cinn/runtime/cuda/cuda_util.cc
@@ -138,6 +138,56 @@ void cinn_call_cuda_kernel(void *kernel_fn,
   }
 }
 
+void cinn_call_cuda_cooperative_kernel(void *kernel_fn,
+                                       void *v_args,
+                                       int num_args,
+                                       int grid_x,
+                                       int grid_y,
+                                       int grid_z,
+                                       int block_x,
+                                       int block_y,
+                                       int block_z,
+                                       int shared_memory_bytes,
+                                       void *stream) {
+  VLOG(3) << "cinn_call_cuda_cooperative_kernel, grid_dim={" << grid_x << ", "
+          << grid_y << ", " << grid_z << "}, block_dim={" << block_x << ", "
+          << block_y << ", " << block_z << "}, num_args=" << num_args
+          << ", shared_memory_bytes=" << shared_memory_bytes
+          << ", stream=" << stream << ", kernel_fn=" << kernel_fn;
+
+  std::vector<void *> kernel_args;
+  {
+    cinn::utils::RecordEvent record_run("prepare_args",
+                                        cinn::utils::EventType::kInstruction);
+    kernel_args.reserve(num_args);
+    cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);
+    for (int idx = 0; idx < num_args; ++idx) {
+      if (args[idx].type_code() == ::cinn_type_code<cinn_buffer_t *>()) {
+        kernel_args.emplace_back(
+            &((cinn_buffer_t *)(args[idx]))->memory);  // NOLINT
+      } else {
+        kernel_args.emplace_back(args[idx].data_addr());
+      }
+    }
+  }
+
+  {
+    cinn::utils::RecordEvent record_run("cuLaunchCooperativeKernel",
+                                        cinn::utils::EventType::kInstruction);
+    CUDA_DRIVER_CALL(
+        cuLaunchCooperativeKernel(static_cast<CUfunction>(kernel_fn),
+                                  grid_x,
+                                  grid_y,
+                                  grid_z,
+                                  block_x,
+                                  block_y,
+                                  block_z,
+                                  shared_memory_bytes,
+                                  static_cast<CUstream>(stream),
+                                  kernel_args.data()))
+  }
+}
+
 void cinn_call_cublas(void *v_args,
                       int num_args,
                       bool trans_a,
diff --git a/paddle/cinn/runtime/cuda/cuda_util.h b/paddle/cinn/runtime/cuda/cuda_util.h
index 592ef50343bf09..e236e0c9791fac 100644
--- a/paddle/cinn/runtime/cuda/cuda_util.h
+++ b/paddle/cinn/runtime/cuda/cuda_util.h
@@ -112,6 +112,24 @@ void cinn_call_cuda_kernel(void* kernel_fn,
                            int shared_memory_bytes,
                            void* stream);
 
+/**
+ * Call a compiled CUDA kernel that uses cooperative groups.
+ *
+ * @param kernel_fn the compiled PTX kernel.
+ * @param args an array of cinn_pod_value_ts(consists of scalars and buffers).
+ */
+void cinn_call_cuda_cooperative_kernel(void* kernel_fn,
+                                       void* v_args,
+                                       int num_args,
+                                       int grid_x,
+                                       int grid_y,
+                                       int grid_z,
+                                       int block_x,
+                                       int block_y,
+                                       int block_z,
+                                       int shared_memory_bytes,
+                                       void* stream);
+
 void cinn_call_cublas(void* v_args,
                       int num_args,
                       bool trans_a,
diff --git a/paddle/cinn/runtime/intrinsic.h b/paddle/cinn/runtime/intrinsic.h
index 9eeec277bb9123..88cf467edd5800 100644
--- a/paddle/cinn/runtime/intrinsic.h
+++ b/paddle/cinn/runtime/intrinsic.h
@@ -103,6 +103,8 @@ static const char* pod_value_to_void_p = "cinn_pod_value_to_void_p";
 
 static const char* print_debug_args_repr = "cinn_print_debug_args";
 
 static const char* call_cuda_kernel = "cinn_call_cuda_kernel";
+static const char* call_cuda_cooperative_kernel =
+    "cinn_call_cuda_cooperative_kernel";
 
 static const char* call_hip_kernel = "cinn_call_hip_kernel";
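
Note on usage (not part of the patch): a cooperative launch is needed whenever a kernel performs grid-wide synchronization through the CUDA cooperative groups API; such kernels must be launched with cuLaunchCooperativeKernel, which the new cinn_call_cuda_cooperative_kernel host API wraps. Below is a minimal sketch of the kind of kernel this launch path serves; the kernel name and the host-side handles (fn, pod_args, stream) are assumptions for illustration only.

#include <cooperative_groups.h>

namespace cg = cooperative_groups;

// Hypothetical kernel (not from this patch): adds 1 to every element, then
// doubles it after a grid-wide barrier. grid.sync() is only valid when the
// kernel is launched cooperatively, i.e. via cuLaunchCooperativeKernel, so
// every block is resident on the device at the same time.
extern "C" __global__ void grid_sync_example(float *data, int n) {
  cg::grid_group grid = cg::this_grid();
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] += 1.0f;
  grid.sync();  // grid-wide barrier across all blocks
  if (i < n) data[i] *= 2.0f;
}

// Host-side sketch, mirroring how the generated host code already invokes
// cinn_call_cuda_kernel. Here `fn` is the CUfunction loaded from the compiled
// module, `pod_args` an array of cinn_pod_value_t, and `stream` a CUDA stream:
//
//   cinn_call_cuda_cooperative_kernel(fn, pod_args, /*num_args=*/2,
//                                     /*grid=*/grid_x, 1, 1,
//                                     /*block=*/256, 1, 1,
//                                     /*shared_memory_bytes=*/0, stream);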