Skip to content

Commit df71250

Browse files
committed
BroadcastDataLoader template partial specialization
1 parent b26a777 commit df71250

File tree

2 files changed

+96
-7
lines changed

2 files changed

+96
-7
lines changed

paddle/phi/kernels/funcs/broadcast_function.h

Lines changed: 96 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,102 @@ struct BroadcastDataLoader {
132132
}
133133
};
134134

135+
/* Partial specializations of BroadcastDataLoader */
135136
#ifndef PADDLE_WITH_XPU_KP
136-
// FIXME: add BroadcastDataLoaders Partial specialization here
137+
// Scalar elementwise loader for the boundary case (IsBoundary == true).
// Every thread fills VecSize consecutive slots of `args`, reading input
// number `Index` one element at a time. A slot whose flat index is not below
// `numel` receives the placeholder value 1 instead of a loaded element.
template <int Index, int VecSize>
struct BroadcastDataLoader<Index, VecSize, true, kElementwise> {
  // ins          : array of input pointers; only ins[Index] is read.
  // args         : per-thread array of VecSize argument tuples; tuple element
  //                `Index` of each entry is written.
  // block_offset : added to threadIdx.x * VecSize to obtain the first flat
  //                index handled by the calling thread.
  // numel        : exclusive upper bound for flat indices that may be read.
  // configs, use_broadcast and num are not used by this specialization.
  template <typename Array1, typename Array2, typename Array3, typename ArgsT>
  static __device__ __forceinline__ void Apply(const Array1 &ins,
                                               ArgsT *args,
                                               const Array2 &configs,
                                               const Array3 &use_broadcast,
                                               const int block_offset,
                                               const int num,
                                               const uint32_t numel) {
    using Type = std::tuple_element_t<Index, ArgsT>;
    const _ptr_ Type *src = reinterpret_cast<const _ptr_ Type *>(ins[Index]);
    const int first = block_offset + threadIdx.x * VecSize;
#pragma unroll
    for (int lane = 0; lane < VecSize; ++lane) {
      const int pos = first + lane;
      auto &slot = std::get<Index>(args[lane]);
      if (pos < numel) {
        slot = src[pos];
      } else {
        // Out-of-range lanes still get a defined value.
        slot = static_cast<Type>(1);
      }
    }
  }
};
161+
162+
// Vectorized elementwise loader for the non-boundary case
// (IsBoundary == false). Each thread performs a single VecType-wide read of
// input number `Index` and scatters the VecSize lanes into `args`.
// NOTE(review): the read is unguarded and block_offset is ignored; presumably
// callers only pick this specialization for full tiles and suitably aligned
// inputs — confirm against the kernel that invokes it.
template <int Index, int VecSize>
struct BroadcastDataLoader<Index, VecSize, false, kElementwise> {
  // ins  : array of input pointers; ins[Index] is reinterpreted as an array
  //        of VectorType<Type, VecSize>.
  // args : per-thread array of VecSize argument tuples; tuple element `Index`
  //        of each entry is written.
  // configs, use_broadcast, block_offset, num and numel are not used by this
  // specialization.
  template <typename Array1, typename Array2, typename Array3, typename ArgsT>
  static __device__ __forceinline__ void Apply(const Array1 &ins,
                                               ArgsT *args,
                                               const Array2 &configs,
                                               const Array3 &use_broadcast,
                                               const int block_offset,
                                               const int num,
                                               const uint32_t numel) {
    using Type = std::tuple_element_t<Index, ArgsT>;
    using VecType = phi::kps::details::VectorType<Type, VecSize>;

    const VecType *__restrict__ packed_src =
        reinterpret_cast<const VecType *__restrict__>(ins[Index]);

    // The position is expressed in VecType units and is taken from the
    // thread's index within the whole grid.
    const int vec_pos = threadIdx.x + blockIdx.x * blockDim.x;

    VecType packed;
    packed = packed_src[vec_pos];
#pragma unroll
    for (int lane = 0; lane < VecSize; ++lane) {
      std::get<Index>(args[lane]) = packed.val[lane];
    }
  }
};
187+
188+
// Common broadcast data loader.
// For each of the VecSize flat indices owned by the calling thread, the index
// is decomposed dimension by dimension with configs[0].divmoders and the
// per-dimension parts are combined with configs[Index].strides into an element
// offset into input number `Index`. The element at that offset is then stored
// into tuple element `Index` of the matching `args` entry.
template <int Index, int VecSize, bool IsBoundary>
struct BroadcastDataLoader<Index, VecSize, IsBoundary, kBroadcast> {
  // ins          : array of input pointers; only ins[Index] is read.
  // args         : per-thread array of VecSize argument tuples; tuple element
  //                `Index` of each entry is written.
  // configs      : configs[0] supplies `rank` and `divmoders`; configs[Index]
  //                supplies the `strides` used for this input.
  // block_offset : added to threadIdx.x * VecSize to obtain the first flat
  //                index handled by the calling thread.
  // numel        : exclusive upper bound for flat indices; consulted only
  //                when IsBoundary is true.
  // use_broadcast and num are not used by this specialization.
  template <typename Array1, typename Array2, typename Array3, typename ArgsT>
  static __device__ __forceinline__ void Apply(const Array1 &ins,
                                               ArgsT *args,
                                               const Array2 &configs,
                                               const Array3 &use_broadcast,
                                               const int block_offset,
                                               const int num,
                                               const uint32_t numel) {
    using Type = std::tuple_element_t<Index, ArgsT>;
    uint32_t index_bc[VecSize];
#pragma unroll
    for (int k = 0; k < VecSize; ++k) {
      // Offset 0 is the fallback for lanes that are skipped below.
      index_bc[k] = 0;
      std::get<Index>(args[k]) = static_cast<Type>(1);
    }

    uint32_t thread_offset = block_offset + threadIdx.x * VecSize;
#pragma unroll
    for (int k = 0; k < VecSize; ++k) {
      uint32_t idx = thread_offset + k;
      // Use >= rather than ==: a thread whose first index is already past
      // numel would never hit an equality test and would run the whole
      // divmod chain for indices that are out of range.
      if (IsBoundary && idx >= numel) {
        break;
      }
#pragma unroll
      for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
        if (i == configs[0].rank) break;
        // val[0] is carried into the next dimension and val[1] is scaled by
        // this input's stride (presumably quotient and remainder — confirm
        // against the Divmod implementation).
        auto fast_divmoder = configs[0].divmoders[i].Divmod(idx);
        idx = fast_divmoder.val[0];
        index_bc[k] += fast_divmoder.val[1] * configs[Index].strides[i];
      }
    }

#pragma unroll
    for (int k = 0; k < VecSize; ++k) {
      // Skipped lanes read element 0, as their offset was left at 0.
      std::get<Index>(args[k]) =
          reinterpret_cast<const _ptr_ Type *>(ins[Index])[index_bc[k]];
    }
  }
};
230+
137231
#endif
138232

139233
// static broadcast unroller
@@ -685,7 +779,6 @@ struct LaunchBroadcastKernelWithInt64IndexHelper<OutT,
685779
};
686780
#endif
687781

688-
// FIXME: delete ElementwiseType
689782
template <ElementwiseType ET,
690783
typename OutT,
691784
typename Functor,
@@ -825,8 +918,6 @@ void BroadcastKernelForDifferentVecSize(
825918
}
826919
}
827920

828-
// FIXME: delete (ElementwiseType ET)
829-
// default: axis = -1
830921
template <ElementwiseType ET,
831922
typename InT,
832923
typename OutT,
@@ -839,7 +930,6 @@ void BroadcastKernel(const KPDevice &ctx,
839930
Functor func) {
840931
// When there are multiple inputs, the outputs' rank should be equal to the
841932
// maximum rank of all inputs.
842-
// FIXME: delete ET ?
843933
using Traits = phi::funcs::FunctionTraits<Functor>;
844934
const int kArity = Traits::arity;
845935
PADDLE_ENFORCE_EQ(
@@ -888,7 +978,7 @@ void ElementwiseCompute(const GPUContext &dev_ctx,
888978
std::vector<const DenseTensor *> ins = {&x, &y};
889979
std::vector<DenseTensor *> outs = {z};
890980
dev_ctx.template Alloc<OutType>(z);
891-
// FIXME: delete ElementwiseType
981+
892982
BroadcastKernel<ElementwiseType::kBinary, T, OutType, Functor, 1>(
893983
dev_ctx, ins, &outs, axis, func);
894984
}

paddle/phi/kernels/funcs/elementwise_base.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ namespace kps = phi::kps;
3535

3636
namespace phi {
3737

38-
// FIXME: delete this enum
3938
// Elementwise operation kind tag with explicit values 1, 2 and 3.
// NOTE(review): presumably each value is the operation's input count —
// confirm against the users of this enum.
enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3 };
4039
/* Packing scalar type T(float, int etc.) into Array<T, NumOuts> type
4140
for supporting multiple-output feature in elementwise system.*/

0 commit comments

Comments
 (0)