Skip to content

Commit 6d3f14b

Browse files
authored
[CINN] [New Hardware Update]Clean obsolete HIP Reduce templates (#72142)
* remove hip call_get_value_in_kernel_args * fix hip reduce * remove cinn_get_value_in_hip_kernel fix bugs
1 parent e86e638 commit 6d3f14b

8 files changed

+84
-434
lines changed

paddle/cinn/backends/codegen_device_util.cc

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ ir::Module CreateSwitchWithBroadcastConditionModule(
4848
const auto &symbolic_arg_define = [&]() -> std::vector<ir::Expr> {
4949
std::vector<ir::Expr> arg_defs;
5050
for (const auto &item : symbolic_shape_var_index) {
51-
#ifdef CINN_WITH_CUDA
5251
ir::Expr call_get_value_in_kernel_args =
5352
ir::Call::Make(Int(64),
5453
runtime::intrinsic::get_value_in_cuda_kernel_args,
@@ -57,18 +56,6 @@ ir::Module CreateSwitchWithBroadcastConditionModule(
5756
ir::CallType::Extern,
5857
ir::FunctionRef(),
5958
0);
60-
#elif defined(CINN_WITH_HIP)
61-
ir::Expr call_get_value_in_kernel_args =
62-
ir::Call::Make(Int(64),
63-
runtime::intrinsic::get_value_in_hip_kernel_args,
64-
{kernel_args, ir::Expr(item.first)},
65-
{},
66-
ir::CallType::Extern,
67-
ir::FunctionRef(),
68-
0);
69-
#else
70-
CINN_NOT_IMPLEMENTED
71-
#endif
7259
ir::Expr let_symbol = ir::Expr(item.second);
7360
let_symbol->set_type(type_of<int64_t>());
7461
ir::Expr stmt = ir::Let::Make(let_symbol, call_get_value_in_kernel_args);
@@ -373,7 +360,6 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessArgs(
373360
const std::vector<ir::Argument> &args = func->args;
374361
for (int i = 0; i < args.size(); ++i) {
375362
if (args[i].is_var()) {
376-
#ifdef CINN_WITH_CUDA
377363
ir::Expr call_get_value_in_kernel_args =
378364
ir::Call::Make(Int(64),
379365
runtime::intrinsic::get_value_in_cuda_kernel_args,
@@ -382,18 +368,6 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessArgs(
382368
ir::CallType::Extern,
383369
ir::FunctionRef(),
384370
0);
385-
#elif defined(CINN_WITH_HIP)
386-
ir::Expr call_get_value_in_kernel_args =
387-
ir::Call::Make(Int(64),
388-
runtime::intrinsic::get_value_in_hip_kernel_args,
389-
{kernel_args_, ir::Expr(i)},
390-
{},
391-
ir::CallType::Extern,
392-
ir::FunctionRef(),
393-
0);
394-
#else
395-
CINN_NOT_IMPLEMENTED
396-
#endif
397371
ir::Expr let_symbol = ir::ir_utils::IRCopy(args[i].var_arg());
398372
let_symbol->set_type(type_of<int64_t>());
399373
ir::stmt::StmtRef stmt =

paddle/cinn/backends/codegen_invoke_module.h

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -68,27 +68,11 @@ class CodeGenSwitchHost : public CodeGenInvokeModule {
6868
: CodeGenInvokeModule(m, b, vars) {}
6969
// only support call of args get function and inner case host function call
7070
llvm::Value *Visit(const ir::Call *op) override {
71-
return common::DefaultDeviceTarget().arch.Match(
72-
[&](common::NVGPUArch) -> llvm::Value * {
73-
if (op->name == runtime::intrinsic::get_value_in_cuda_kernel_args) {
74-
return CodeGenLLVM::Visit(op);
75-
} else {
76-
return LowerInnerCaseCall(op);
77-
}
78-
},
79-
[&](common::HygonDCUArchHIP) -> llvm::Value * {
80-
if (op->name == runtime::intrinsic::get_value_in_hip_kernel_args) {
81-
return CodeGenLLVM::Visit(op);
82-
} else {
83-
return LowerInnerCaseCall(op);
84-
}
85-
},
86-
[&](std::variant<common::UnknownArch,
87-
common::X86Arch,
88-
common::ARMArch,
89-
common::HygonDCUArchSYCL>) -> llvm::Value * {
90-
CINN_NOT_IMPLEMENTED;
91-
});
71+
if (op->name == runtime::intrinsic::get_value_in_cuda_kernel_args) {
72+
return CodeGenLLVM::Visit(op);
73+
} else {
74+
return LowerInnerCaseCall(op);
75+
}
9276
}
9377

9478
private:

paddle/cinn/runtime/hip/cinn_hip_runtime_source.h

Lines changed: 51 additions & 221 deletions
Large diffs are not rendered by default.

paddle/cinn/runtime/hip/hip_intrinsics.cc

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,6 @@ CINN_REGISTER_HELPER(cinn_hip_host_api) {
2323
GlobalSymbolRegistry::Global().RegisterFn(
2424
"backend_api.hip", reinterpret_cast<void *>(HIPBackendAPI::Global()));
2525

26-
using cinn::runtime::hip::cinn_get_value_in_hip_kernel_args;
27-
REGISTER_EXTERN_FUNC_HELPER(cinn_get_value_in_hip_kernel_args,
28-
cinn::common::DefaultHostTarget())
29-
.SetRetType<int64_t>()
30-
.AddInputType<void *>() // args
31-
.AddInputType<int>() // index
32-
.End();
33-
34-
using cinn::runtime::hip::cinn_get_item_in_hip_kernel_args;
35-
REGISTER_EXTERN_FUNC_HELPER(cinn_get_item_in_hip_kernel_args,
36-
cinn::common::DefaultHostTarget())
37-
.SetRetType<void *>()
38-
.AddInputType<void *>() // args
39-
.AddInputType<int>() // index
40-
.End();
41-
4226
REGISTER_EXTERN_FUNC_HELPER(cinn_call_hip_kernel,
4327
cinn::common::DefaultHostTarget())
4428
.SetRetType<void>()

paddle/cinn/runtime/hip/hip_intrinsics_reduce.cc

Lines changed: 28 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -63,162 +63,59 @@ CINN_REGISTER_HELPER(hip_intrinsics_reduce) {
6363
MACRO(min_fp16, float16, ##__VA_ARGS__)
6464
#endif
6565

66-
#define REGISTER_WARP_REDUCE_FUNC_IMPL(REDUCE_TYPE, DTYPE) \
67-
REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_warp_reduce_##REDUCE_TYPE, target) \
68-
.SetRetType<DTYPE>() \
69-
.AddInputType<cinn_buffer_t *>() \
70-
.AddInputType<int>() \
71-
.AddInputType<int>() \
72-
.End();
73-
74-
EXPAND_REDUCE_INT32_REGISTER_MARCO(REGISTER_WARP_REDUCE_FUNC_IMPL)
75-
EXPAND_REDUCE_INT64_REGISTER_MARCO(REGISTER_WARP_REDUCE_FUNC_IMPL)
76-
EXPAND_REDUCE_FP32_REGISTER_MACRO(REGISTER_WARP_REDUCE_FUNC_IMPL)
77-
EXPAND_REDUCE_FP64_REGISTER_MACRO(REGISTER_WARP_REDUCE_FUNC_IMPL)
78-
EXPAND_REDUCE_BOOL_REGISTER_MACRO(REGISTER_WARP_REDUCE_FUNC_IMPL)
79-
80-
#ifdef CINN_HIP_BF16
81-
EXPAND_REDUCE_BF16_REGISTER_MACRO(REGISTER_WARP_REDUCE_FUNC_IMPL)
82-
#endif
83-
84-
#ifdef CINN_HIP_FP16
85-
EXPAND_REDUCE_FP16_REGISTER_MACRO(REGISTER_WARP_REDUCE_FUNC_IMPL)
86-
#endif
87-
88-
#undef REGISTER_WARP_REDUCE_FUNC_IMPL
89-
90-
REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_warp_reduce_avg_fp32, target)
91-
.SetRetType<float>()
92-
.AddInputType<cinn_buffer_t *>()
93-
.AddInputType<int>()
94-
.AddInputType<int>()
95-
.End();
96-
97-
#define REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL(REDUCE_TYPE, DTYPE) \
98-
REGISTER_FACKED_EXTERN_FUNC_HELPER( \
99-
cinn_block_reduce_##REDUCE_TYPE##_internal, target) \
100-
.SetRetType<DTYPE>() \
101-
.AddInputType<DTYPE>() \
102-
.End();
103-
104-
EXPAND_REDUCE_INT32_REGISTER_MARCO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
105-
EXPAND_REDUCE_INT64_REGISTER_MARCO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
106-
EXPAND_REDUCE_FP32_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
107-
EXPAND_REDUCE_FP64_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
108-
EXPAND_REDUCE_BOOL_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
109-
110-
#ifdef CINN_HIP_BF16
111-
EXPAND_REDUCE_BF16_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
112-
#endif
113-
114-
#ifdef CINN_HIP_FP16
115-
EXPAND_REDUCE_FP16_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
116-
#endif
117-
118-
#undef REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL
119-
120-
#define REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL(REDUCE_TYPE, DTYPE) \
121-
REGISTER_FACKED_EXTERN_FUNC_HELPER( \
122-
cinn_block_reduce_##REDUCE_TYPE##_internal_shm, target) \
123-
.SetRetType<DTYPE>() \
124-
.AddInputType<DTYPE>() \
125-
.AddInputType<cinn_buffer_t *>() \
66+
#define REGISTER_BLOCK_REDUCE_FUNC_IMPL(REDUCE_TYPE, DTYPE) \
67+
REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_block_reduce_##REDUCE_TYPE, target) \
68+
.SetRetType<DTYPE>() \
69+
.AddInputType<DTYPE>() \
70+
.AddInputType<cinn_buffer_t *>() \
71+
.AddInputType<bool>() \
12672
.End();
12773

128-
EXPAND_REDUCE_INT32_REGISTER_MARCO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
129-
EXPAND_REDUCE_INT64_REGISTER_MARCO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
130-
EXPAND_REDUCE_FP32_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
131-
EXPAND_REDUCE_FP64_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
132-
EXPAND_REDUCE_BOOL_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
133-
134-
#ifdef CINN_HIP_BF16
135-
EXPAND_REDUCE_BF16_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
136-
#endif
137-
138-
#ifdef CINN_HIP_FP16
139-
EXPAND_REDUCE_FP16_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
140-
#endif
141-
142-
#undef REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL
143-
144-
#define REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL(REDUCE_TYPE, DTYPE) \
145-
REGISTER_FACKED_EXTERN_FUNC_HELPER( \
146-
cinn_partial_block_reduce_##REDUCE_TYPE##_internal_shm, target) \
147-
.SetRetType<DTYPE>() \
148-
.AddInputType<DTYPE>() \
149-
.AddInputType<cinn_buffer_t *>() \
150-
.End();
151-
EXPAND_REDUCE_INT32_REGISTER_MARCO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
152-
EXPAND_REDUCE_INT64_REGISTER_MARCO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
153-
EXPAND_REDUCE_FP32_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
154-
EXPAND_REDUCE_FP64_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
155-
EXPAND_REDUCE_BOOL_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
74+
EXPAND_REDUCE_INT32_REGISTER_MARCO(REGISTER_BLOCK_REDUCE_FUNC_IMPL)
75+
EXPAND_REDUCE_INT64_REGISTER_MARCO(REGISTER_BLOCK_REDUCE_FUNC_IMPL)
76+
EXPAND_REDUCE_FP32_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_FUNC_IMPL)
77+
EXPAND_REDUCE_FP64_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_FUNC_IMPL)
78+
EXPAND_REDUCE_BOOL_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_FUNC_IMPL)
15679

15780
#ifdef CINN_HIP_BF16
158-
EXPAND_REDUCE_BF16_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
81+
EXPAND_REDUCE_BF16_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_FUNC_IMPL)
15982
#endif
16083

16184
#ifdef CINN_HIP_FP16
162-
EXPAND_REDUCE_FP16_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
85+
EXPAND_REDUCE_FP16_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_FUNC_IMPL)
16386
#endif
16487

165-
#undef REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL
88+
#undef REGISTER_BLOCK_REDUCE_FUNC_IMPL
16689

167-
#define REGISTER_DISCRETE_REDUCE_INTERNAL_FUNC_IMPL(REDUCE_TYPE, DTYPE) \
168-
REGISTER_FACKED_EXTERN_FUNC_HELPER( \
169-
cinn_discrete_reduce_##REDUCE_TYPE##_internal_shm, target) \
170-
.SetRetType<DTYPE>() \
171-
.AddInputType<DTYPE>() \
172-
.AddInputType<cinn_buffer_t *>() \
90+
#define REGISTER_DISCRETE_REDUCE_FUNC_IMPL(REDUCE_TYPE, DTYPE) \
91+
REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_discrete_reduce_##REDUCE_TYPE, \
92+
target) \
93+
.SetRetType<DTYPE>() \
94+
.AddInputType<DTYPE>() \
95+
.AddInputType<cinn_buffer_t *>() \
17396
.End();
17497

175-
EXPAND_REDUCE_INT32_REGISTER_MARCO(
176-
REGISTER_DISCRETE_REDUCE_INTERNAL_FUNC_IMPL)
177-
EXPAND_REDUCE_INT64_REGISTER_MARCO(
178-
REGISTER_DISCRETE_REDUCE_INTERNAL_FUNC_IMPL)
179-
EXPAND_REDUCE_FP32_REGISTER_MACRO(REGISTER_DISCRETE_REDUCE_INTERNAL_FUNC_IMPL)
180-
EXPAND_REDUCE_FP64_REGISTER_MACRO(REGISTER_DISCRETE_REDUCE_INTERNAL_FUNC_IMPL)
181-
EXPAND_REDUCE_BOOL_REGISTER_MACRO(REGISTER_DISCRETE_REDUCE_INTERNAL_FUNC_IMPL)
98+
EXPAND_REDUCE_INT32_REGISTER_MARCO(REGISTER_DISCRETE_REDUCE_FUNC_IMPL)
99+
EXPAND_REDUCE_INT64_REGISTER_MARCO(REGISTER_DISCRETE_REDUCE_FUNC_IMPL)
100+
EXPAND_REDUCE_FP32_REGISTER_MACRO(REGISTER_DISCRETE_REDUCE_FUNC_IMPL)
101+
EXPAND_REDUCE_FP64_REGISTER_MACRO(REGISTER_DISCRETE_REDUCE_FUNC_IMPL)
102+
EXPAND_REDUCE_BOOL_REGISTER_MACRO(REGISTER_DISCRETE_REDUCE_FUNC_IMPL)
182103

183104
#ifdef CINN_HIP_BF16
184-
EXPAND_REDUCE_BF16_REGISTER_MACRO(REGISTER_DISCRETE_REDUCE_INTERNAL_FUNC_IMPL)
105+
EXPAND_REDUCE_BF16_REGISTER_MACRO(REGISTER_DISCRETE_REDUCE_FUNC_IMPL)
185106
#endif
186107

187108
#ifdef CINN_HIP_FP16
188-
EXPAND_REDUCE_FP16_REGISTER_MACRO(REGISTER_DISCRETE_REDUCE_INTERNAL_FUNC_IMPL)
109+
EXPAND_REDUCE_FP16_REGISTER_MACRO(REGISTER_DISCRETE_REDUCE_FUNC_IMPL)
189110
#endif
190111

191-
#undef REGISTER_DISCRETE_REDUCE_INTERNAL_FUNC_IMPL
112+
#undef REGISTER_DISCRETE_REDUCE_FUNC_IMPL
192113

193114
REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_grid_reduce_update_semaphore, target)
194115
.SetRetType<bool>()
195116
.AddInputType<int *>()
196117
.End();
197118

198-
#define REGISTER_BLOCK_REDUCE_FUNC_IMPL(REDUCE_TYPE, DTYPE) \
199-
REGISTER_FACKED_EXTERN_FUNC_HELPER(cinn_block_reduce_##REDUCE_TYPE, target) \
200-
.SetRetType<DTYPE>() \
201-
.AddInputType<cinn_buffer_t *>() \
202-
.AddInputType<int>() \
203-
.AddInputType<int>() \
204-
.End();
205-
206-
EXPAND_REDUCE_INT32_REGISTER_MARCO(REGISTER_BLOCK_REDUCE_FUNC_IMPL)
207-
EXPAND_REDUCE_INT64_REGISTER_MARCO(REGISTER_BLOCK_REDUCE_FUNC_IMPL)
208-
EXPAND_REDUCE_FP32_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_FUNC_IMPL)
209-
EXPAND_REDUCE_FP64_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_FUNC_IMPL)
210-
EXPAND_REDUCE_BOOL_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_FUNC_IMPL)
211-
212-
#ifdef CINN_HIP_BF16
213-
EXPAND_REDUCE_BF16_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_FUNC_IMPL)
214-
#endif
215-
216-
#ifdef CINN_HIP_FP16
217-
EXPAND_REDUCE_FP16_REGISTER_MACRO(REGISTER_BLOCK_REDUCE_FUNC_IMPL)
218-
#endif
219-
220-
#undef REGISTER_BLOCK_REDUCE_FUNC_IMPL
221-
222119
#define REGISTER_BLOCK_SHUFFLE_FUNC_IMPL(REDUCE_TYPE, DTYPE) \
223120
REGISTER_FACKED_EXTERN_FUNC_HELPER(block_shuffle_##REDUCE_TYPE, target) \
224121
.SetRetType<DTYPE>() \

paddle/cinn/runtime/hip/hip_util.cc

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,16 +77,6 @@ void infer_shape_set_value(int row, int col, int64_t value, int64_t **v) {
7777
v[row][col] = value;
7878
}
7979

80-
int64_t cinn_get_value_in_hip_kernel_args(void *v_args, int idx) {
81-
cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);
82-
return args[idx].operator int64_t();
83-
}
84-
85-
void *cinn_get_item_in_hip_kernel_args(void *v_args, int idx) {
86-
cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);
87-
return static_cast<void *>(&args[idx]);
88-
}
89-
9080
} // namespace hip
9181
} // namespace runtime
9282
} // namespace cinn

paddle/cinn/runtime/hip/hip_util.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,6 @@ void cinn_call_hip_kernel(void *kernel_fn,
6969

7070
void infer_shape_set_value(int row, int col, int64_t value, int64_t **v);
7171

72-
int64_t cinn_get_value_in_hip_kernel_args(void *v_args, int idx);
73-
void *cinn_get_item_in_hip_kernel_args(void *v_args, int idx);
74-
7572
} // namespace hip
7673
} // namespace runtime
7774
} // namespace cinn

paddle/cinn/runtime/intrinsic.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,15 +115,9 @@ static const char* call_cuda_memset = "cinn_call_cuda_memset";
115115
static const char* get_value_in_cuda_kernel_args =
116116
"cinn_get_value_in_cuda_kernel_args";
117117

118-
static const char* get_value_in_hip_kernel_args =
119-
"cinn_get_value_in_hip_kernel_args";
120-
121118
static const char* get_item_in_cuda_kernel_args =
122119
"cinn_get_item_in_cuda_kernel_args";
123120

124-
static const char* get_item_in_hip_kernel_args =
125-
"cinn_get_item_in_hip_kernel_args";
126-
127121
static const char* infer_shape_set_value = "infer_shape_set_value";
128122

129123
static const char* pod_values_to_array_repr = "pod_values_to_array";

0 commit comments

Comments
 (0)