Skip to content

[CINN] [New Hardware Update]Clean obsolete HIP Reduce templates #72142

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 0 additions & 26 deletions paddle/cinn/backends/codegen_device_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ ir::Module CreateSwitchWithBroadcastConditionModule(
const auto &symbolic_arg_define = [&]() -> std::vector<ir::Expr> {
std::vector<ir::Expr> arg_defs;
for (const auto &item : symbolic_shape_var_index) {
#ifdef CINN_WITH_CUDA
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

call Node 不需要区分CUDA runtime::intrinsic::get_value_in_cuda_kernel_args和HIP/SYCL的区别吗?

ir::Expr call_get_value_in_kernel_args =
ir::Call::Make(Int(64),
runtime::intrinsic::get_value_in_cuda_kernel_args,
Expand All @@ -57,18 +56,6 @@ ir::Module CreateSwitchWithBroadcastConditionModule(
ir::CallType::Extern,
ir::FunctionRef(),
0);
#elif defined(CINN_WITH_HIP)
ir::Expr call_get_value_in_kernel_args =
ir::Call::Make(Int(64),
runtime::intrinsic::get_value_in_hip_kernel_args,
{kernel_args, ir::Expr(item.first)},
{},
ir::CallType::Extern,
ir::FunctionRef(),
0);
#else
CINN_NOT_IMPLEMENTED
#endif
ir::Expr let_symbol = ir::Expr(item.second);
let_symbol->set_type(type_of<int64_t>());
ir::Expr stmt = ir::Let::Make(let_symbol, call_get_value_in_kernel_args);
Expand Down Expand Up @@ -366,7 +353,6 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessArgs(
const std::vector<ir::Argument> &args = func->args;
for (int i = 0; i < args.size(); ++i) {
if (args[i].is_var()) {
#ifdef CINN_WITH_CUDA
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

ir::Expr call_get_value_in_kernel_args =
ir::Call::Make(Int(64),
runtime::intrinsic::get_value_in_cuda_kernel_args,
Expand All @@ -375,18 +361,6 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessArgs(
ir::CallType::Extern,
ir::FunctionRef(),
0);
#elif defined(CINN_WITH_HIP)
ir::Expr call_get_value_in_kernel_args =
ir::Call::Make(Int(64),
runtime::intrinsic::get_value_in_hip_kernel_args,
{kernel_args_, ir::Expr(i)},
{},
ir::CallType::Extern,
ir::FunctionRef(),
0);
#else
CINN_NOT_IMPLEMENTED
#endif
ir::Expr let_symbol = ir::ir_utils::IRCopy(args[i].var_arg());
let_symbol->set_type(type_of<int64_t>());
ir::stmt::StmtRef stmt =
Expand Down
26 changes: 5 additions & 21 deletions paddle/cinn/backends/codegen_invoke_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,27 +68,11 @@ class CodeGenSwitchHost : public CodeGenInvokeModule {
: CodeGenInvokeModule(m, b, vars) {}
// only support call of args get function and inner case host function call
llvm::Value *Visit(const ir::Call *op) override {
return common::DefaultDeviceTarget().arch.Match(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

[&](common::NVGPUArch) -> llvm::Value * {
if (op->name == runtime::intrinsic::get_value_in_cuda_kernel_args) {
return CodeGenLLVM::Visit(op);
} else {
return LowerInnerCaseCall(op);
}
},
[&](common::HygonDCUArchHIP) -> llvm::Value * {
if (op->name == runtime::intrinsic::get_value_in_hip_kernel_args) {
return CodeGenLLVM::Visit(op);
} else {
return LowerInnerCaseCall(op);
}
},
[&](std::variant<common::UnknownArch,
common::X86Arch,
common::ARMArch,
common::HygonDCUArchSYCL>) -> llvm::Value * {
CINN_NOT_IMPLEMENTED;
});
if (op->name == runtime::intrinsic::get_value_in_cuda_kernel_args) {
return CodeGenLLVM::Visit(op);
} else {
return LowerInnerCaseCall(op);
}
}

private:
Expand Down
272 changes: 51 additions & 221 deletions paddle/cinn/runtime/hip/cinn_hip_runtime_source.h

Large diffs are not rendered by default.

16 changes: 0 additions & 16 deletions paddle/cinn/runtime/hip/hip_intrinsics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,6 @@ CINN_REGISTER_HELPER(cinn_hip_host_api) {
GlobalSymbolRegistry::Global().RegisterFn(
"backend_api.hip", reinterpret_cast<void *>(HIPBackendAPI::Global()));

using cinn::runtime::hip::cinn_get_value_in_hip_kernel_args;
REGISTER_EXTERN_FUNC_HELPER(cinn_get_value_in_hip_kernel_args,
cinn::common::DefaultHostTarget())
.SetRetType<int64_t>()
.AddInputType<void *>() // args
.AddInputType<int>() // index
.End();

using cinn::runtime::hip::cinn_get_item_in_hip_kernel_args;
REGISTER_EXTERN_FUNC_HELPER(cinn_get_item_in_hip_kernel_args,
cinn::common::DefaultHostTarget())
.SetRetType<void *>()
.AddInputType<void *>() // args
.AddInputType<int>() // index
.End();

REGISTER_EXTERN_FUNC_HELPER(cinn_call_hip_kernel,
cinn::common::DefaultHostTarget())
.SetRetType<void>()
Expand Down
159 changes: 28 additions & 131 deletions paddle/cinn/runtime/hip/hip_intrinsics_reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,162 +63,59 @@ CINN_REGISTER_HELPER(hip_intrinsics_reduce) {
MACRO(min_fp16, float16, ##__VA_ARGS__)
#endif

#define REGISTER_WARP_REDUCE_FUNC_IMPL(REDUCE_TYPE, DTYPE) \
REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_warp_reduce_##REDUCE_TYPE, target) \
.SetRetType<DTYPE>() \
.AddInputType<cinn_buffer_t *>() \
.AddInputType<int>() \
.AddInputType<int>() \
.End();

EXPAND_REDUCE_INT32_REGISTER_MARCO(REGISTER_WARP_REDUCE_FUNC_IMPL)
EXPAND_REDUCE_INT64_REGISTER_MARCO(REGISTER_WARP_REDUCE_FUNC_IMPL)
EXPAND_REDUCE_FP32_REGISTER_MACRO(REGISTER_WARP_REDUCE_FUNC_IMPL)
EXPAND_REDUCE_FP64_REGISTER_MACRO(REGISTER_WARP_REDUCE_FUNC_IMPL)
EXPAND_REDUCE_BOOL_REGISTER_MACRO(REGISTER_WARP_REDUCE_FUNC_IMPL)

#ifdef CINN_HIP_BF16
EXPAND_REDUCE_BF16_REGISTER_MACRO(REGISTER_WARP_REDUCE_FUNC_IMPL)
#endif

#ifdef CINN_HIP_FP16
EXPAND_REDUCE_FP16_REGISTER_MACRO(REGISTER_WARP_REDUCE_FUNC_IMPL)
#endif

#undef REGISTER_WARP_REDUCE_FUNC_IMPL

REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_warp_reduce_avg_fp32, target)
.SetRetType<float>()
.AddInputType<cinn_buffer_t *>()
.AddInputType<int>()
.AddInputType<int>()
.End();

#define REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL(REDUCE_TYPE, DTYPE) \
REGISTER_FACKED_EXTERN_FUNC_HELPER( \
cinn_block_reduce_##REDUCE_TYPE##_internal, target) \
.SetRetType<DTYPE>() \
.AddInputType<DTYPE>() \
.End();

EXPAND_REDUCE_INT32_REGISTER_MARCO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
EXPAND_REDUCE_INT64_REGISTER_MARCO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
EXPAND_REDUCE_FP32_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
EXPAND_REDUCE_FP64_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
EXPAND_REDUCE_BOOL_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)

#ifdef CINN_HIP_BF16
EXPAND_REDUCE_BF16_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
#endif

#ifdef CINN_HIP_FP16
EXPAND_REDUCE_FP16_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
#endif

#undef REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL

#define REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL(REDUCE_TYPE, DTYPE) \
REGISTER_FACKED_EXTERN_FUNC_HELPER( \
cinn_block_reduce_##REDUCE_TYPE##_internal_shm, target) \
.SetRetType<DTYPE>() \
.AddInputType<DTYPE>() \
.AddInputType<cinn_buffer_t *>() \
#define REGISTER_BLOCK_REDUCE_FUNC_IMPL(REDUCE_TYPE, DTYPE) \
REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_block_reduce_##REDUCE_TYPE, target) \
.SetRetType<DTYPE>() \
.AddInputType<DTYPE>() \
.AddInputType<cinn_buffer_t *>() \
.AddInputType<bool>() \
.End();

EXPAND_REDUCE_INT32_REGISTER_MARCO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
EXPAND_REDUCE_INT64_REGISTER_MARCO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
EXPAND_REDUCE_FP32_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
EXPAND_REDUCE_FP64_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
EXPAND_REDUCE_BOOL_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)

#ifdef CINN_HIP_BF16
EXPAND_REDUCE_BF16_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
#endif

#ifdef CINN_HIP_FP16
EXPAND_REDUCE_FP16_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
#endif

#undef REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL

#define REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL(REDUCE_TYPE, DTYPE) \
REGISTER_FACKED_EXTERN_FUNC_HELPER( \
cinn_partial_block_reduce_##REDUCE_TYPE##_internal_shm, target) \
.SetRetType<DTYPE>() \
.AddInputType<DTYPE>() \
.AddInputType<cinn_buffer_t *>() \
.End();
EXPAND_REDUCE_INT32_REGISTER_MARCO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
EXPAND_REDUCE_INT64_REGISTER_MARCO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
EXPAND_REDUCE_FP32_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
EXPAND_REDUCE_FP64_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
EXPAND_REDUCE_BOOL_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
EXPAND_REDUCE_INT32_REGISTER_MARCO(REGISTER_BLOCK_REDUCE_FUNC_IMPL)
EXPAND_REDUCE_INT64_REGISTER_MARCO(REGISTER_BLOCK_REDUCE_FUNC_IMPL)
EXPAND_REDUCE_FP32_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_FUNC_IMPL)
EXPAND_REDUCE_FP64_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_FUNC_IMPL)
EXPAND_REDUCE_BOOL_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_FUNC_IMPL)

#ifdef CINN_HIP_BF16
EXPAND_REDUCE_BF16_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
EXPAND_REDUCE_BF16_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_FUNC_IMPL)
#endif

#ifdef CINN_HIP_FP16
EXPAND_REDUCE_FP16_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
EXPAND_REDUCE_FP16_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_FUNC_IMPL)
#endif

#undef REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL
#undef REGISTER_BLOCK_REDUCE_FUNC_IMPL

#define REGISTER_DISCRETE_REDUCE_INTERNAL_FUNC_IMPL(REDUCE_TYPE, DTYPE) \
REGISTER_FACKED_EXTERN_FUNC_HELPER( \
cinn_discrete_reduce_##REDUCE_TYPE##_internal_shm, target) \
.SetRetType<DTYPE>() \
.AddInputType<DTYPE>() \
.AddInputType<cinn_buffer_t *>() \
#define REGISTER_DISCRETE_REDUCE_FUNC_IMPL(REDUCE_TYPE, DTYPE) \
REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_discrete_reduce_##REDUCE_TYPE, \
target) \
.SetRetType<DTYPE>() \
.AddInputType<DTYPE>() \
.AddInputType<cinn_buffer_t *>() \
.End();

EXPAND_REDUCE_INT32_REGISTER_MARCO(
REGISTER_DISCRETE_REDUCE_INTERNAL_FUNC_IMPL)
EXPAND_REDUCE_INT64_REGISTER_MARCO(
REGISTER_DISCRETE_REDUCE_INTERNAL_FUNC_IMPL)
EXPAND_REDUCE_FP32_REGISTER_MACRO(REGISTER_DISCRETE_REDUCE_INTERNAL_FUNC_IMPL)
EXPAND_REDUCE_FP64_REGISTER_MACRO(REGISTER_DISCRETE_REDUCE_INTERNAL_FUNC_IMPL)
EXPAND_REDUCE_BOOL_REGISTER_MACRO(REGISTER_DISCRETE_REDUCE_INTERNAL_FUNC_IMPL)
EXPAND_REDUCE_INT32_REGISTER_MARCO(REGISTER_DISCRETE_REDUCE_FUNC_IMPL)
EXPAND_REDUCE_INT64_REGISTER_MARCO(REGISTER_DISCRETE_REDUCE_FUNC_IMPL)
EXPAND_REDUCE_FP32_REGISTER_MACRO(REGISTER_DISCRETE_REDUCE_FUNC_IMPL)
EXPAND_REDUCE_FP64_REGISTER_MACRO(REGISTER_DISCRETE_REDUCE_FUNC_IMPL)
EXPAND_REDUCE_BOOL_REGISTER_MACRO(REGISTER_DISCRETE_REDUCE_FUNC_IMPL)

#ifdef CINN_HIP_BF16
EXPAND_REDUCE_BF16_REGISTER_MACRO(REGISTER_DISCRETE_REDUCE_INTERNAL_FUNC_IMPL)
EXPAND_REDUCE_BF16_REGISTER_MACRO(REGISTER_DISCRETE_REDUCE_FUNC_IMPL)
#endif

#ifdef CINN_HIP_FP16
EXPAND_REDUCE_FP16_REGISTER_MACRO(REGISTER_DISCRETE_REDUCE_INTERNAL_FUNC_IMPL)
EXPAND_REDUCE_FP16_REGISTER_MACRO(REGISTER_DISCRETE_REDUCE_FUNC_IMPL)
#endif

#undef REGISTER_DISCRETE_REDUCE_INTERNAL_FUNC_IMPL
#undef REGISTER_DISCRETE_REDUCE_FUNC_IMPL

REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_grid_reduce_update_semaphore, target)
.SetRetType<bool>()
.AddInputType<int *>()
.End();

#define REGISTER_BLOCK_REDUCE_FUNC_IMPL(REDUCE_TYPE, DTYPE) \
REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_block_reduce_##REDUCE_TYPE, target) \
.SetRetType<DTYPE>() \
.AddInputType<cinn_buffer_t *>() \
.AddInputType<int>() \
.AddInputType<int>() \
.End();

EXPAND_REDUCE_INT32_REGISTER_MARCO(REGISTER_BLOCK_REDUCE_FUNC_IMPL)
EXPAND_REDUCE_INT64_REGISTER_MARCO(REGISTER_BLOCK_REDUCE_FUNC_IMPL)
EXPAND_REDUCE_FP32_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_FUNC_IMPL)
EXPAND_REDUCE_FP64_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_FUNC_IMPL)
EXPAND_REDUCE_BOOL_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_FUNC_IMPL)

#ifdef CINN_HIP_BF16
EXPAND_REDUCE_BF16_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_FUNC_IMPL)
#endif

#ifdef CINN_HIP_FP16
EXPAND_REDUCE_FP16_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_FUNC_IMPL)
#endif

#undef REGISTER_BLOCK_REDUCE_FUNC_IMPL

#define REGISTER_BLOCK_SHUFFLE_FUNC_IMPL(REDUCE_TYPE, DTYPE) \
REGISTER_FACKED_EXTERN_FUNC_HELPER(block_shuffle_##REDUCE_TYPE, target) \
.SetRetType<DTYPE>() \
Expand Down
10 changes: 0 additions & 10 deletions paddle/cinn/runtime/hip/hip_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,6 @@ void infer_shape_set_value(int row, int col, int64_t value, int64_t **v) {
v[row][col] = value;
}

int64_t cinn_get_value_in_hip_kernel_args(void *v_args, int idx) {
cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);
return args[idx].operator int64_t();
}

void *cinn_get_item_in_hip_kernel_args(void *v_args, int idx) {
cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);
return static_cast<void *>(&args[idx]);
}

} // namespace hip
} // namespace runtime
} // namespace cinn
3 changes: 0 additions & 3 deletions paddle/cinn/runtime/hip/hip_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,6 @@ void cinn_call_hip_kernel(void *kernel_fn,

void infer_shape_set_value(int row, int col, int64_t value, int64_t **v);

// Reads element `idx` of the cinn_pod_value_t array pointed to by `v_args`
// and returns it converted to int64_t. No bounds or null checking is done.
int64_t cinn_get_value_in_hip_kernel_args(void *v_args, int idx);
// Returns the address (as void*) of element `idx` of the cinn_pod_value_t
// array pointed to by `v_args`. No bounds or null checking is done.
void *cinn_get_item_in_hip_kernel_args(void *v_args, int idx);

} // namespace hip
} // namespace runtime
} // namespace cinn
6 changes: 0 additions & 6 deletions paddle/cinn/runtime/intrinsic.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,9 @@ static const char* call_cuda_memset = "cinn_call_cuda_memset";
static const char* get_value_in_cuda_kernel_args =
"cinn_get_value_in_cuda_kernel_args";

static const char* get_value_in_hip_kernel_args =
"cinn_get_value_in_hip_kernel_args";

static const char* get_item_in_cuda_kernel_args =
"cinn_get_item_in_cuda_kernel_args";

static const char* get_item_in_hip_kernel_args =
"cinn_get_item_in_hip_kernel_args";

static const char* infer_shape_set_value = "infer_shape_set_value";

static const char* pod_values_to_array_repr = "pod_values_to_array";
Expand Down