@@ -132,8 +132,102 @@ struct BroadcastDataLoader {
}
};

+ /* BroadcastDataLoaders Partial specialization */
#ifndef PADDLE_WITH_XPU_KP
- // FIXME: add BroadcastDataLoaders Partial specialization here
+ // Scalar elementwise Loader with consideration of IsBoundary.
+ template <int Index, int VecSize>
+ struct BroadcastDataLoader<Index, VecSize, true, kElementwise> {
+   template <typename Array1, typename Array2, typename Array3, typename ArgsT>
+   static __device__ __forceinline__ void Apply(const Array1 &ins,
+                                                ArgsT *args,
+                                                const Array2 &configs,
+                                                const Array3 &use_broadcast,
+                                                const int block_offset,
+                                                const int num,
+                                                const uint32_t numel) {
+     using Type = std::tuple_element_t<Index, ArgsT>;
+     int thread_offset = threadIdx.x * VecSize + block_offset;
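+     // Load VecSize scalars one by one; elements whose index is >= numel keep
+     // the default value 1 assigned below.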
+ #pragma unroll
+     for (int idx = 0; idx < VecSize; ++idx) {
+       std::get<Index>(args[idx]) = static_cast<Type>(1);
+       int index = thread_offset + idx;
+       if (index < numel) {
+         std::get<Index>(args[idx]) =
+             reinterpret_cast<const _ptr_ Type *>(ins[Index])[index];
+       }
+     }
+   }
+ };
+
+ // Vectorized elementwise Loader without consideration of IsBoundary.
+ template <int Index, int VecSize>
+ struct BroadcastDataLoader<Index, VecSize, false, kElementwise> {
+   template <typename Array1, typename Array2, typename Array3, typename ArgsT>
+   static __device__ __forceinline__ void Apply(const Array1 &ins,
+                                                ArgsT *args,
+                                                const Array2 &configs,
+                                                const Array3 &use_broadcast,
+                                                const int block_offset,
+                                                const int num,
+                                                const uint32_t numel) {
+     using Type = std::tuple_element_t<Index, ArgsT>;
+     using VecType = phi::kps::details::VectorType<Type, VecSize>;
+     VecType vec_temp;
+
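+     // No boundary check here: a single vectorized load fetches VecSize
+     // elements at once.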
+     int thread_offset = threadIdx.x + blockIdx.x * blockDim.x;
+     const VecType *__restrict__ vec_input =
+         reinterpret_cast<const VecType *__restrict__>(ins[Index]);
+     vec_temp = vec_input[thread_offset];
+ #pragma unroll
+     for (int idx = 0; idx < VecSize; ++idx) {
+       std::get<Index>(args[idx]) = vec_temp.val[idx];
+     }
+   }
+ };
+
+ // Common broadcast data loader.
+ template <int Index, int VecSize, bool IsBoundary>
+ struct BroadcastDataLoader<Index, VecSize, IsBoundary, kBroadcast> {
+   template <typename Array1, typename Array2, typename Array3, typename ArgsT>
+   static __device__ __forceinline__ void Apply(const Array1 &ins,
+                                                ArgsT *args,
+                                                const Array2 &configs,
+                                                const Array3 &use_broadcast,
+                                                const int block_offset,
+                                                const int num,
+                                                const uint32_t numel) {
+     using Type = std::tuple_element_t<Index, ArgsT>;
+     uint32_t index_bc[VecSize];
+ #pragma unroll
+     for (int k = 0; k < VecSize; ++k) {
+       index_bc[k] = 0;
+       std::get<Index>(args[k]) = static_cast<Type>(1);
+     }
+
+     uint32_t thread_offset = block_offset + threadIdx.x * VecSize;
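+     // Decompose each flat output index dimension by dimension with fast
+     // divmod, accumulating the lane's input offset from configs[Index].strides.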
+ #pragma unroll
+     for (int k = 0; k < VecSize; ++k) {
+       uint32_t idx = thread_offset + k;
+       if (IsBoundary && idx == numel) {
+         break;
+       }
+ #pragma unroll
+       for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
+         if (i == configs[0].rank) break;
+         auto fast_divmoder = configs[0].divmoders[i].Divmod(idx);
+         idx = fast_divmoder.val[0];
+         index_bc[k] += fast_divmoder.val[1] * configs[Index].strides[i];
+       }
+     }
+
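+     // Gather: read each lane's element through its precomputed offset.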
+ #pragma unroll
+     for (int k = 0; k < VecSize; ++k) {
+       std::get<Index>(args[k]) =
+           reinterpret_cast<const _ptr_ Type *>(ins[Index])[index_bc[k]];
+     }
+   }
+ };
+
#endif

// static broadcast unroller
@@ -685,7 +779,6 @@ struct LaunchBroadcastKernelWithInt64IndexHelper<OutT,
};
#endif

- // FIXME: delete ElementwiseType
template <ElementwiseType ET,
          typename OutT,
          typename Functor,
@@ -825,8 +918,6 @@ void BroadcastKernelForDifferentVecSize(
  }
}

- // FIXME: delete (ElementwiseType ET)
- // default: axis = -1
template <ElementwiseType ET,
          typename InT,
          typename OutT,
@@ -839,7 +930,6 @@ void BroadcastKernel(const KPDevice &ctx,
                     Functor func) {
  // When there are multiple inputs, the output's rank should be equal to the
  // maximum rank of all inputs.
- // FIXME: delete ET ?
  using Traits = phi::funcs::FunctionTraits<Functor>;
  const int kArity = Traits::arity;
  PADDLE_ENFORCE_EQ(
@@ -888,7 +978,7 @@ void ElementwiseCompute(const GPUContext &dev_ctx,
  std::vector<const DenseTensor *> ins = {&x, &y};
  std::vector<DenseTensor *> outs = {z};
  dev_ctx.template Alloc<OutType>(z);
- // FIXME: delete ElementwiseType
+
  BroadcastKernel<ElementwiseType::kBinary, T, OutType, Functor, 1>(
      dev_ctx, ins, &outs, axis, func);
}