@@ -63,162 +63,59 @@ CINN_REGISTER_HELPER(hip_intrinsics_reduce) {
63
63
MACRO (min_fp16, float16, ##__VA_ARGS__)
64
64
#endif
65
65
66
- #define REGISTER_WARP_REDUCE_FUNC_IMPL (REDUCE_TYPE, DTYPE ) \
67
- REGISTER_FACKED_EXTERN_FUNC_HELPER (cinn_warp_reduce_##REDUCE_TYPE, target) \
68
- .SetRetType <DTYPE>() \
69
- .AddInputType <cinn_buffer_t *>() \
70
- .AddInputType <int >() \
71
- .AddInputType <int >() \
72
- .End ();
73
-
74
- EXPAND_REDUCE_INT32_REGISTER_MARCO (REGISTER_WARP_REDUCE_FUNC_IMPL)
75
- EXPAND_REDUCE_INT64_REGISTER_MARCO (REGISTER_WARP_REDUCE_FUNC_IMPL)
76
- EXPAND_REDUCE_FP32_REGISTER_MACRO (REGISTER_WARP_REDUCE_FUNC_IMPL)
77
- EXPAND_REDUCE_FP64_REGISTER_MACRO (REGISTER_WARP_REDUCE_FUNC_IMPL)
78
- EXPAND_REDUCE_BOOL_REGISTER_MACRO (REGISTER_WARP_REDUCE_FUNC_IMPL)
79
-
80
- #ifdef CINN_HIP_BF16
81
- EXPAND_REDUCE_BF16_REGISTER_MACRO (REGISTER_WARP_REDUCE_FUNC_IMPL)
82
- #endif
83
-
84
- #ifdef CINN_HIP_FP16
85
- EXPAND_REDUCE_FP16_REGISTER_MACRO (REGISTER_WARP_REDUCE_FUNC_IMPL)
86
- #endif
87
-
88
- #undef REGISTER_WARP_REDUCE_FUNC_IMPL
89
-
90
- REGISTER_FACKED_EXTERN_FUNC_HELPER (cinn_warp_reduce_avg_fp32, target)
91
- .SetRetType <float >()
92
- .AddInputType <cinn_buffer_t *>()
93
- .AddInputType <int >()
94
- .AddInputType <int >()
95
- .End ();
96
-
97
- #define REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL (REDUCE_TYPE, DTYPE ) \
98
- REGISTER_FACKED_EXTERN_FUNC_HELPER ( \
99
- cinn_block_reduce_##REDUCE_TYPE##_internal, target) \
100
- .SetRetType <DTYPE>() \
101
- .AddInputType <DTYPE>() \
102
- .End ();
103
-
104
- EXPAND_REDUCE_INT32_REGISTER_MARCO (REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
105
- EXPAND_REDUCE_INT64_REGISTER_MARCO (REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
106
- EXPAND_REDUCE_FP32_REGISTER_MACRO (REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
107
- EXPAND_REDUCE_FP64_REGISTER_MACRO (REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
108
- EXPAND_REDUCE_BOOL_REGISTER_MACRO (REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
109
-
110
- #ifdef CINN_HIP_BF16
111
- EXPAND_REDUCE_BF16_REGISTER_MACRO (REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
112
- #endif
113
-
114
- #ifdef CINN_HIP_FP16
115
- EXPAND_REDUCE_FP16_REGISTER_MACRO (REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
116
- #endif
117
-
118
- #undef REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL
119
-
120
- #define REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL (REDUCE_TYPE, DTYPE ) \
121
- REGISTER_FACKED_EXTERN_FUNC_HELPER ( \
122
- cinn_block_reduce_##REDUCE_TYPE##_internal_shm, target) \
123
- .SetRetType <DTYPE>() \
124
- .AddInputType <DTYPE>() \
125
- .AddInputType <cinn_buffer_t *>() \
66
+ #define REGISTER_BLOCK_REDUCE_FUNC_IMPL (REDUCE_TYPE, DTYPE ) \
67
+ REGISTER_FACKED_EXTERN_FUNC_HELPER (cinn_block_reduce_##REDUCE_TYPE, target) \
68
+ .SetRetType <DTYPE>() \
69
+ .AddInputType <DTYPE>() \
70
+ .AddInputType <cinn_buffer_t *>() \
71
+ .AddInputType <bool >() \
126
72
.End ();
127
73
128
- EXPAND_REDUCE_INT32_REGISTER_MARCO (REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
129
- EXPAND_REDUCE_INT64_REGISTER_MARCO (REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
130
- EXPAND_REDUCE_FP32_REGISTER_MACRO (REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
131
- EXPAND_REDUCE_FP64_REGISTER_MACRO (REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
132
- EXPAND_REDUCE_BOOL_REGISTER_MACRO (REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
133
-
134
- #ifdef CINN_HIP_BF16
135
- EXPAND_REDUCE_BF16_REGISTER_MACRO (REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
136
- #endif
137
-
138
- #ifdef CINN_HIP_FP16
139
- EXPAND_REDUCE_FP16_REGISTER_MACRO (REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
140
- #endif
141
-
142
- #undef REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL
143
-
144
- #define REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL (REDUCE_TYPE, DTYPE ) \
145
- REGISTER_FACKED_EXTERN_FUNC_HELPER ( \
146
- cinn_partial_block_reduce_##REDUCE_TYPE##_internal_shm, target) \
147
- .SetRetType <DTYPE>() \
148
- .AddInputType <DTYPE>() \
149
- .AddInputType <cinn_buffer_t *>() \
150
- .End ();
151
- EXPAND_REDUCE_INT32_REGISTER_MARCO (REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
152
- EXPAND_REDUCE_INT64_REGISTER_MARCO (REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
153
- EXPAND_REDUCE_FP32_REGISTER_MACRO (REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
154
- EXPAND_REDUCE_FP64_REGISTER_MACRO (REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
155
- EXPAND_REDUCE_BOOL_REGISTER_MACRO (REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL)
74
+ EXPAND_REDUCE_INT32_REGISTER_MARCO (REGISTER_BLOCK_REDUCE_FUNC_IMPL)
75
+ EXPAND_REDUCE_INT64_REGISTER_MARCO (REGISTER_BLOCK_REDUCE_FUNC_IMPL)
76
+ EXPAND_REDUCE_FP32_REGISTER_MACRO (REGISTER_BLOCK_REDUCE_FUNC_IMPL)
77
+ EXPAND_REDUCE_FP64_REGISTER_MACRO (REGISTER_BLOCK_REDUCE_FUNC_IMPL)
78
+ EXPAND_REDUCE_BOOL_REGISTER_MACRO (REGISTER_BLOCK_REDUCE_FUNC_IMPL)
156
79
157
80
#ifdef CINN_HIP_BF16
158
- EXPAND_REDUCE_BF16_REGISTER_MACRO (REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL )
81
+ EXPAND_REDUCE_BF16_REGISTER_MACRO (REGISTER_BLOCK_REDUCE_FUNC_IMPL )
159
82
#endif
160
83
161
84
#ifdef CINN_HIP_FP16
162
- EXPAND_REDUCE_FP16_REGISTER_MACRO (REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL )
85
+ EXPAND_REDUCE_FP16_REGISTER_MACRO (REGISTER_BLOCK_REDUCE_FUNC_IMPL )
163
86
#endif
164
87
165
- #undef REGISTER_BLOCK_REDUCE_INTERNAL_FUNC_IMPL
88
+ #undef REGISTER_BLOCK_REDUCE_FUNC_IMPL
166
89
167
- #define REGISTER_DISCRETE_REDUCE_INTERNAL_FUNC_IMPL (REDUCE_TYPE, DTYPE ) \
168
- REGISTER_FACKED_EXTERN_FUNC_HELPER ( \
169
- cinn_discrete_reduce_##REDUCE_TYPE##_internal_shm, target) \
170
- .SetRetType <DTYPE>() \
171
- .AddInputType <DTYPE>() \
172
- .AddInputType <cinn_buffer_t *>() \
90
+ #define REGISTER_DISCRETE_REDUCE_FUNC_IMPL (REDUCE_TYPE, DTYPE ) \
91
+ REGISTER_FACKED_EXTERN_FUNC_HELPER (cinn_discrete_reduce_##REDUCE_TYPE, \
92
+ target) \
93
+ .SetRetType <DTYPE>() \
94
+ .AddInputType <DTYPE>() \
95
+ .AddInputType <cinn_buffer_t *>() \
173
96
.End ();
174
97
175
- EXPAND_REDUCE_INT32_REGISTER_MARCO (
176
- REGISTER_DISCRETE_REDUCE_INTERNAL_FUNC_IMPL)
177
- EXPAND_REDUCE_INT64_REGISTER_MARCO (
178
- REGISTER_DISCRETE_REDUCE_INTERNAL_FUNC_IMPL)
179
- EXPAND_REDUCE_FP32_REGISTER_MACRO (REGISTER_DISCRETE_REDUCE_INTERNAL_FUNC_IMPL)
180
- EXPAND_REDUCE_FP64_REGISTER_MACRO (REGISTER_DISCRETE_REDUCE_INTERNAL_FUNC_IMPL)
181
- EXPAND_REDUCE_BOOL_REGISTER_MACRO (REGISTER_DISCRETE_REDUCE_INTERNAL_FUNC_IMPL)
98
+ EXPAND_REDUCE_INT32_REGISTER_MARCO (REGISTER_DISCRETE_REDUCE_FUNC_IMPL)
99
+ EXPAND_REDUCE_INT64_REGISTER_MARCO (REGISTER_DISCRETE_REDUCE_FUNC_IMPL)
100
+ EXPAND_REDUCE_FP32_REGISTER_MACRO (REGISTER_DISCRETE_REDUCE_FUNC_IMPL)
101
+ EXPAND_REDUCE_FP64_REGISTER_MACRO (REGISTER_DISCRETE_REDUCE_FUNC_IMPL)
102
+ EXPAND_REDUCE_BOOL_REGISTER_MACRO (REGISTER_DISCRETE_REDUCE_FUNC_IMPL)
182
103
183
104
#ifdef CINN_HIP_BF16
184
- EXPAND_REDUCE_BF16_REGISTER_MACRO (REGISTER_DISCRETE_REDUCE_INTERNAL_FUNC_IMPL )
105
+ EXPAND_REDUCE_BF16_REGISTER_MACRO (REGISTER_DISCRETE_REDUCE_FUNC_IMPL )
185
106
#endif
186
107
187
108
#ifdef CINN_HIP_FP16
188
- EXPAND_REDUCE_FP16_REGISTER_MACRO (REGISTER_DISCRETE_REDUCE_INTERNAL_FUNC_IMPL )
109
+ EXPAND_REDUCE_FP16_REGISTER_MACRO (REGISTER_DISCRETE_REDUCE_FUNC_IMPL )
189
110
#endif
190
111
191
- #undef REGISTER_DISCRETE_REDUCE_INTERNAL_FUNC_IMPL
112
+ #undef REGISTER_DISCRETE_REDUCE_FUNC_IMPL
192
113
193
114
REGISTER_FACKED_EXTERN_FUNC_HELPER (cinn_grid_reduce_update_semaphore, target)
194
115
.SetRetType <bool >()
195
116
.AddInputType <int *>()
196
117
.End ();
197
118
198
- #define REGISTER_BLOCK_REDUCE_FUNC_IMPL (REDUCE_TYPE, DTYPE ) \
199
- REGISTER_FACKED_EXTERN_FUNC_HELPER (cinn_block_reduce_##REDUCE_TYPE, target) \
200
- .SetRetType <DTYPE>() \
201
- .AddInputType <cinn_buffer_t *>() \
202
- .AddInputType <int >() \
203
- .AddInputType <int >() \
204
- .End ();
205
-
206
- EXPAND_REDUCE_INT32_REGISTER_MARCO (REGISTER_BLOCK_REDUCE_FUNC_IMPL)
207
- EXPAND_REDUCE_INT64_REGISTER_MARCO (REGISTER_BLOCK_REDUCE_FUNC_IMPL)
208
- EXPAND_REDUCE_FP32_REGISTER_MACRO (REGISTER_BLOCK_REDUCE_FUNC_IMPL)
209
- EXPAND_REDUCE_FP64_REGISTER_MACRO (REGISTER_BLOCK_REDUCE_FUNC_IMPL)
210
- EXPAND_REDUCE_BOOL_REGISTER_MACRO (REGISTER_BLOCK_REDUCE_FUNC_IMPL)
211
-
212
- #ifdef CINN_HIP_BF16
213
- EXPAND_REDUCE_BF16_REGISTER_MACRO (REGISTER_BLOCK_REDUCE_FUNC_IMPL)
214
- #endif
215
-
216
- #ifdef CINN_HIP_FP16
217
- EXPAND_REDUCE_FP16_REGISTER_MACRO (REGISTER_BLOCK_REDUCE_FUNC_IMPL)
218
- #endif
219
-
220
- #undef REGISTER_BLOCK_REDUCE_FUNC_IMPL
221
-
222
119
#define REGISTER_BLOCK_SHUFFLE_FUNC_IMPL (REDUCE_TYPE, DTYPE ) \
223
120
REGISTER_FACKED_EXTERN_FUNC_HELPER (block_shuffle_##REDUCE_TYPE, target) \
224
121
.SetRetType <DTYPE>() \
0 commit comments