
Commit 2cd86ea

[CINN] Use cooperative_groups for grid reduce synchronization (#71999)
1 parent 3e00f80 commit 2cd86ea

13 files changed: +95 -326 lines

paddle/cinn/backends/codegen_cuda_dev.cc (+2, -1)

@@ -34,6 +34,7 @@ using cinn::common::float162;
 using cinn::common::bfloat168;
 using cinn::common::bfloat164;
 using cinn::common::bfloat162;
+#include <cooperative_groups.h>
 #include "cinn_cuda_runtime_source.cuh"
 )";
 const std::string CodeGenCudaDev::source_header_ = // NOLINT
@@ -55,8 +56,8 @@ using cinn::common::float162;
 using cinn::common::bfloat168;
 using cinn::common::bfloat164;
 using cinn::common::bfloat162;
+#include <cooperative_groups.h>
 #include <cinn_cuda_runtime_source_h>
-
 )";

 const std::string &CodeGenCudaDev::GetSourceHeader() { return source_header_; }
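Note on the added header: including <cooperative_groups.h> lets the generated device code issue a grid-wide barrier (grid.sync()) inside a single kernel instead of splitting a grid reduce across separate launches. A minimal illustrative sketch of that pattern follows; it is not code that CINN emits, and the kernel name and layout are hypothetical:

#include <cooperative_groups.h>
namespace cg = cooperative_groups;

// Each block folds its slice into a per-block partial (partials assumed
// zero-initialized by the caller), then a grid-wide barrier makes all
// partials visible before block 0 combines them into the final result.
__global__ void GridSumSketch(const float* in, float* partials, float* out,
                              int n) {
  cg::grid_group grid = cg::this_grid();
  float local = 0.0f;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    local += in[i];
  }
  atomicAdd(&partials[blockIdx.x], local);  // block-level reduction elided
  grid.sync();  // legal only when the kernel was launched cooperatively
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    float total = 0.0f;
    for (int b = 0; b < gridDim.x; ++b) total += partials[b];
    *out = total;
  }
}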

paddle/cinn/backends/codegen_device_util.cc (+11, -4)

@@ -177,6 +177,15 @@ static std::string CurTailFnName(const std::string &origin_fn_name) {
   return new_fn_name;
 }

+bool RequiresCooperativeLaunch(const ir::LoweredFunc &func) {
+  for (auto &space : func->temp_spaces) {
+    if (space.size() != ir::Expr(0)) {
+      return true;
+    }
+  }
+  return false;
+}
+
 std::string
 detail::CollectBucketStrategyHostFunctionVisitor::GenDeviceKernelName(
     const std::string &fn_name, ir::Expr predicate) {
@@ -257,10 +266,8 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
         CINN_NOT_IMPLEMENTED;
       },
       [&](common::NVGPUArch) {
-        // TODO(liangshuhao): when cooperative group is supported, change the
-        // second call to `call_cuda_cooperative_kernel`.
-        call_kernel = func->temp_spaces.empty()
-                          ? runtime::intrinsic::call_cuda_kernel
+        call_kernel = RequiresCooperativeLaunch(func)
+                          ? runtime::intrinsic::call_cuda_cooperative_kernel
                           : runtime::intrinsic::call_cuda_kernel;
       },
       [&](common::HygonDCUArchHIP) {
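A kernel that calls grid.sync() is only valid under a cooperative launch, which is why the host stub now selects the call_cuda_cooperative_kernel intrinsic whenever the lowered function allocates grid-reduce temp spaces. A hedged sketch of the underlying CUDA runtime call (names are hypothetical; this is not CINN's actual runtime glue):

#include <cuda_runtime.h>

// Hypothetical wrapper: launch the GridSumSketch kernel from the previous
// sketch cooperatively so that grid.sync() is legal.
void LaunchGridSumSketch(const float* d_in, float* d_partials, float* d_out,
                         int n, int num_blocks, int threads_per_block,
                         cudaStream_t stream) {
  void* args[] = {&d_in, &d_partials, &d_out, &n};
  cudaLaunchCooperativeKernel(reinterpret_cast<const void*>(GridSumSketch),
                              dim3(num_blocks), dim3(threads_per_block), args,
                              /*sharedMem=*/0, stream);
  // All blocks must be co-resident on the device at once, so num_blocks must
  // stay within the occupancy limit (e.g. via
  // cudaOccupancyMaxActiveBlocksPerMultiprocessor).
}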

paddle/cinn/common/target.cc (+45)

@@ -409,5 +409,50 @@ const Target &DefaultTarget() {
 #endif
 }

+bool GetSupportsCooperativeLaunchImpl(UnknownArch) {
+  LOG(FATAL)
+      << "The target is not GPU! Cannot get supports cooperative launch.";
+}
+
+bool GetSupportsCooperativeLaunchImpl(X86Arch) {
+  LOG(FATAL)
+      << "The target is not GPU! Cannot get supports cooperative launch.";
+}
+
+bool GetSupportsCooperativeLaunchImpl(ARMArch) {
+  LOG(FATAL)
+      << "The target is not GPU! Cannot get supports cooperative launch.";
+}
+
+bool GetSupportsCooperativeLaunchImpl(NVGPUArch) {
+  int supportsCoopLaunch = 0;
+#ifdef CINN_WITH_CUDA
+  cudaDeviceGetAttribute(&supportsCoopLaunch, cudaDevAttrCooperativeLaunch, 0);
+#endif
+  return supportsCoopLaunch != 0;
+}
+
+bool GetSupportsCooperativeLaunchImpl(HygonDCUArchHIP) {
+  CINN_NOT_IMPLEMENTED
+  LOG(FATAL)
+      << "The target is not GPU! Cannot get supports cooperative launch.";
+}
+
+bool GetSupportsCooperativeLaunchImpl(HygonDCUArchSYCL) {
+  CINN_NOT_IMPLEMENTED
+  LOG(FATAL)
+      << "The target is not GPU! Cannot get supports cooperative launch.";
+}
+
+bool GetSupportsCooperativeLaunch(Arch arch) {
+  return std::visit(
+      [](const auto &impl) { return GetSupportsCooperativeLaunchImpl(impl); },
+      arch.variant());
+}
+
+bool Target::get_supports_cooperative_launch() const {
+  return GetSupportsCooperativeLaunch(arch);
+}
+
 } // namespace common
 } // namespace cinn
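Callers can now gate grid reduce on this capability query. A hedged usage sketch (it assumes the DefaultNVGPUTarget() helper; the actual call site in CINN is the BucketLower change below):

// Illustrative only: enable the grid-reduce path only when the device
// supports cooperative launch.
const cinn::common::Target& target = cinn::common::DefaultNVGPUTarget();
bool use_grid_reduce = FLAGS_cinn_enable_grid_reduce &&
                       target.get_supports_cooperative_launch();
if (!use_grid_reduce) {
  // fall back to the non-cooperative reduction strategy
}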

paddle/cinn/common/target.h (+2)

@@ -87,6 +87,8 @@ struct Target {

   std::vector<Lib> get_target_libs() const;

+  bool get_supports_cooperative_launch() const;
+
   std::string arch_str() const;

   std::string device_name_str() const;

paddle/cinn/hlir/framework/pir/op_lowering_impl.cc (+4)

@@ -141,6 +141,10 @@ BucketLoweredFuncsWrapper OpLowererImpl::BucketLower(
     fusion_group_info->can_apply_grid_reduce = false;
   }

+  if (!target_.get_supports_cooperative_launch()) {
+    fusion_group_info->can_apply_grid_reduce = false;
+  }
+
   if (FLAGS_cinn_check_tensor_buffer_map) {
     optim::CheckTensorBufferMap(func_bodies, "BucketLower OpFusion");
     VLOG(3) << "OpFusion tensor-buffer map check succeed";

paddle/cinn/hlir/framework/pir/trivial_op_impl.cc (+1, -2)

@@ -607,8 +607,7 @@ std::shared_ptr<FusionGroupInfo> GetFusionGroupInfo(
   }

   if (FLAGS_cinn_enable_grid_reduce) {
-    group_info->can_apply_grid_reduce =
-        GetCanApplyGridReduce(op_compute_bodies, group_info->reduce_axis);
+    group_info->can_apply_grid_reduce = true;
   }

   if (FLAGS_cinn_enable_vectorize) {

paddle/cinn/ir/group_schedule/config/group_tile_util.cc (-76)

@@ -386,82 +386,6 @@ std::vector<int64_t> GetLoopStrides(const ir::Expr& body) {
   return loop_strides;
 }

-bool GetCanApplyGridReduce(const std::vector<ir::Expr>& op_compute_bodies,
-                           const std::vector<int64_t>& reduce_axis) {
-  // Names of tensors that are downstream of reduce.
-  // A tensor is downstream of reduce either if it is produced by a reduce, or
-  // if it has data dependency on another tensor that is downstream of reduce.
-  std::unordered_set<std::string> reduce_downstream_tensor_names;
-
-  const auto IsReduceDownstream = [&](const ir::Expr& expr_block) {
-    for (auto& expr_load : GetRValueLoads(expr_block)) {
-      std::string load_tensor_name = expr_load.As<ir::Load>()->name();
-      if (reduce_downstream_tensor_names.count(load_tensor_name) > 0) {
-        return true;
-      }
-    }
-    return false;
-  };
-
-  const auto AddReduceDownstream = [&](const ir::Expr& expr_block) {
-    auto expr_store = analyzer::GetStoreOfSBlock(expr_block);
-    std::string store_tensor_name = expr_store.As<ir::Store>()->name();
-    reduce_downstream_tensor_names.insert(store_tensor_name);
-  };
-
-  const auto CheckOutputHasReduceAxis = [&](const ir::Expr& body,
-                                            const ir::Expr& expr_block) {
-    std::vector<ir::Var> all_loop_vars = GetAllForIters(body);
-    std::unordered_set<std::string> reduce_loop_vars;
-    for (int64_t axis : reduce_axis) {
-      reduce_loop_vars.insert(all_loop_vars[axis]->name);
-    }
-
-    std::unordered_set<std::string> reduce_iter_vars;
-    auto* block = expr_block.As<ir::ScheduleBlockRealize>();
-    auto& iter_vars = block->schedule_block.As<ir::ScheduleBlock>()->iter_vars;
-    for (int i = 0; i < iter_vars.size(); i++) {
-      if (block->iter_values[i].is_var() &&
-          reduce_loop_vars.count(block->iter_values[i].as_var()->name) > 0) {
-        reduce_iter_vars.insert(iter_vars[i]->name);
-      }
-    }
-
-    // The result is true if the indices of the output tensor contain any
-    // reduce iter vars.
-    auto expr_store = analyzer::GetStoreOfSBlock(expr_block);
-    for (auto& index_expr : expr_store.As<ir::Store>()->indices) {
-      if (index_expr.is_var() &&
-          reduce_iter_vars.count(index_expr.as_var_ref()->name) > 0) {
-        return true;
-      }
-    }
-    return false;
-  };
-
-  for (const auto& body : op_compute_bodies) {
-    ir::Expr expr_block =
-        (ChildScheduleBlockRealizes * ScheduleBlockRealizeIsNotInit)
-            .GetSingle(body);
-    bool is_reduce = analyzer::IsReductionSBlock(expr_block);
-    bool is_reduce_downstream = IsReduceDownstream(expr_block);
-    bool output_has_reduce_axis = CheckOutputHasReduceAxis(body, expr_block);
-
-    if (is_reduce_downstream || is_reduce) {
-      AddReduceDownstream(expr_block);
-    }
-
-    // When a block is downstream of reduce, its loop iters shouldn't contain
-    // any reduce axis. Otherwise, it broadcasts the result of reduce. If this
-    // is the case, we cannot apply grid reduce.
-    if (is_reduce_downstream && (is_reduce || output_has_reduce_axis)) {
-      VLOG(4) << "grid reduce is prohibited by block: " << expr_block;
-      return false;
-    }
-  }
-  return true;
-}
-
 GroupVectorizeInfo GetGroupVectorizeInfo(
     const std::vector<ir::Expr>& op_compute_bodies,
     const std::unordered_set<std::string>& group_args) {

paddle/cinn/ir/group_schedule/config/group_tile_util.h (-6)

@@ -56,12 +56,6 @@ namespace ir {
  */
 std::vector<int64_t> GetLoopStrides(const ir::Expr& reduce_compute_body);

-// Check whether we can apply grid reduce in this group.
-// We can apply grid reduce if there is no reduce-then-broadcast dependency
-// in this group.
-bool GetCanApplyGridReduce(const std::vector<ir::Expr>& op_compute_bodies,
-                           const std::vector<int64_t>& reduce_axis);
-
 // Check whether we can apply vectorize in this group.
 GroupVectorizeInfo GetGroupVectorizeInfo(
     const std::vector<ir::Expr>& op_compute_bodies,
