diff --git a/paddle/phi/kernels/funcs/broadcast_function.h b/paddle/phi/kernels/funcs/broadcast_function.h
index f96a1764c24a5..b4125ab555079 100644
--- a/paddle/phi/kernels/funcs/broadcast_function.h
+++ b/paddle/phi/kernels/funcs/broadcast_function.h
@@ -31,20 +31,49 @@ namespace funcs {
 
 enum BroadcastLoadType { kMixed = 1, kBroadcast = 2, kElementwise = 3 };
 
-template <typename InT, typename OutT, int Arity>
+template <int Index, int VecSize>
+struct UseBroadcast {
+  template <typename ArgsT, typename Array1, typename Array2>
+  static HOSTDEVICE void Apply(
+      const std::vector<const DenseTensor *> &ins_tensor,
+      const ArgsT &args,
+      int64_t numel,
+      Array1 *ins_data,
+      Array2 *use_broadcast,
+      int *broadcast_num,
+      bool *all_elementwise) {
+    (*ins_data)[Index] = (const _ptr_ char *)(ins_tensor[Index]->data());
+    bool is_same_dim = ins_tensor[Index]->numel() == numel;
+    if (is_same_dim) {
+      (*use_broadcast)[Index] = false;
+    } else {
+      (*use_broadcast)[Index] = true;
+      (*broadcast_num)++;
+    }
+    *all_elementwise &= is_same_dim;
+  }
+};
+
+template <typename OutT, typename Functor, int Arity>
 struct LoaderTypeClassifier {
  public:
   int64_t numel{0};
-  int vec_size{1};
+  int vec_size{4};
   int broadcast_num{0};
   bool all_elementwise{true};
-  phi::Array<bool, Arity> use_broadcast;
-  phi::Array<const _ptr_ InT *__restrict__, Arity> ins_data;
+  phi::Array<bool, Arity> use_broadcast;
+  phi::Array<const _ptr_ char *__restrict__, Arity> ins_data;
 
   LoaderTypeClassifier() {}
   LoaderTypeClassifier(const std::vector<const DenseTensor *> &ins,
                        std::vector<DenseTensor *> *outs) {
+    using Traits = phi::funcs::FunctionTraits<Functor>;
+    using ArgsT = typename Traits::ArgsTuple;
+    ArgsT arg;
     uint64_t out_addr = reinterpret_cast<uint64_t>((*outs)[0]->data<OutT>());
+
+    UnrollerWithoutVecSize<VecSizeGetter, Arity>::step(ins, arg, &vec_size);
+
     for (auto i = 1; i < outs->size(); ++i) {
       PADDLE_ENFORCE_EQ(
           (*outs)[i]->dims(),
@@ -56,165 +85,185 @@ struct LoaderTypeClassifier {
       out_addr =
           (out_addr | reinterpret_cast<uint64_t>((*outs)[i]->data<OutT>()));
     }
-    int out_vec_size =
-        phi::GetVectorizedSize<OutT>(reinterpret_cast<OutT *>(out_addr));
-    uint64_t in_addr = static_cast<uint64_t>(0);
+    vec_size = std::min(
+        vec_size,
+        phi::GetVectorizedSize<OutT>(reinterpret_cast<OutT *>(out_addr)));
     numel = (*outs)[0]->numel();
-    for (int i = 0; i < Arity; ++i) {
-      auto in_data = ins[i]->data<InT>();
-      ins_data[i] = (const _ptr_ InT *)(in_data);
-
-      bool is_same_dim = ins[i]->numel() == numel;
-      if (is_same_dim) {
-        use_broadcast[i] = false;
-        in_addr = (in_addr | reinterpret_cast<uint64_t>(in_data));
-      } else {
-        use_broadcast[i] = true;
-        broadcast_num++;
-      }
-      all_elementwise &= is_same_dim;
-    }
-    int in_vec_size = std::min(
-        4, phi::GetVectorizedSize<InT>(reinterpret_cast<InT *>(in_addr)));
-    vec_size = std::min(out_vec_size, in_vec_size);
+    UnrollerWithoutVecSize<UseBroadcast, Arity>::step(ins,
+                                                      arg,
+                                                      numel,
+                                                      &ins_data,
+                                                      &use_broadcast,
+                                                      &broadcast_num,
+                                                      &all_elementwise);
   }
 };
 
-#ifndef PADDLE_WITH_XPU_KP
 // Common broadcast/elementwise Loader.
-template <typename T, int VecSize, int Arity, bool IsBoundary, int LoadType>
+template <int Index, int VecSize, bool IsBoundary, int LoadType>
 struct BroadcastDataLoader {
-  __device__ __forceinline__ void operator()(
-      T args[Arity][VecSize],
-      const phi::Array<const _ptr_ T *__restrict__, Arity> &ins,
-      const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
-      const phi::Array<bool, Arity> &use_broadcast,
-      const int block_offset,
-      const int num,
-      const uint32_t numel) {
-#pragma unroll
-    for (int i = 0; i < Arity; ++i) {
-      kps::Init<T, VecSize>(args[i], static_cast<T>(1.0f));
-      if (use_broadcast[i]) {
-        kps::ReadDataBc<T, VecSize, 1, IsBoundary>(
-            args[i], ins[i], block_offset, configs[i], numel, VecSize);
-      } else {
-        kps::ReadData<T, VecSize, 1, IsBoundary>(
-            args[i], ins[i] + block_offset, num, VecSize);
-      }
-    }
+  template <typename Array1, typename Array2, typename Array3, typename ArgsT>
+  static __device__ __forceinline__ void Apply(const Array1 &ins,
+                                               ArgsT *args,
+                                               const Array2 &configs,
+                                               const Array3 &use_broadcast,
+                                               const int block_offset,
+                                               const int num,
+                                               const uint32_t numel) {
+    using Type = std::tuple_element_t<Index, ArgsT>;
+    kps::Init<Type, ArgsT, Index, VecSize>(args, static_cast<Type>(1.0f));
+
+    if (use_broadcast[Index]) {
+      kps::ReadDataBc<Type, VecSize, 1, ArgsT, Index, IsBoundary>(
+          args,
+          reinterpret_cast<const _ptr_ Type *>(ins[Index]),
+          block_offset,
+          configs[Index],
+          numel,
+          VecSize);
+    }
+    // NOTE: If use if...else... with condition `use_broadcast[Index]` here,
+    // there will be some errs with clang12 while compiling in ROCm.
+    // When the compiler is upgraded, if...else... may be used.
+    if (!use_broadcast[Index]) {
+      kps::ReadData<Type, VecSize, 1, ArgsT, Index, IsBoundary>(
+          args,
+          reinterpret_cast<const _ptr_ Type *>(ins[Index]) + block_offset,
+          num,
+          VecSize);
+    }
   }
 };
 
+/* BroadcastDataLoaders Partial specialization */
+#ifndef PADDLE_WITH_XPU_KP
 // Scalar elementwise Loader with consideration of IsBoundary.
-template <typename T, int VecSize, int Arity>
-struct BroadcastDataLoader<T, VecSize, Arity, true, kElementwise> {
-  __device__ __forceinline__ void operator()(
-      T args[Arity][VecSize],
-      const phi::Array<const _ptr_ T *__restrict__, Arity> &ins,
-      const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
-      const phi::Array<bool, Arity> &use_broadcast,
-      const int block_offset,
-      const int num,
-      const uint32_t numel) {
+template <int Index, int VecSize>
+struct BroadcastDataLoader<Index, VecSize, true, kElementwise> {
+  template <typename Array1, typename Array2, typename Array3, typename ArgsT>
+  static __device__ __forceinline__ void Apply(const Array1 &ins,
+                                               ArgsT *args,
+                                               const Array2 &configs,
+                                               const Array3 &use_broadcast,
+                                               const int block_offset,
+                                               const int num,
+                                               const uint32_t numel) {
+    using Type = std::tuple_element_t<Index, ArgsT>;
     int thread_offset = threadIdx.x * VecSize + block_offset;
 #pragma unroll
-    for (int i = 0; i < Arity; ++i) {
-#pragma unroll
-      for (int idx = 0; idx < VecSize; ++idx) {
-        args[i][idx] = static_cast<T>(1);
-        int index = thread_offset + idx;
-        if (index < numel) {
-          args[i][idx] = ins[i][index];
-        }
+    for (int idx = 0; idx < VecSize; ++idx) {
+      std::get<Index>(args[idx]) = static_cast<Type>(1);
+      int index = thread_offset + idx;
+      if (index < numel) {
+        std::get<Index>(args[idx]) =
+            reinterpret_cast<const _ptr_ Type *>(ins[Index])[index];
      }
     }
   }
 };
 
 // Vectorized elementwise Loader without consideration of IsBoundary.
-template <typename T, int VecSize, int Arity>
-struct BroadcastDataLoader<T, VecSize, Arity, false, kElementwise> {
-  __device__ __forceinline__ void operator()(
-      T args[Arity][VecSize],
-      const phi::Array<const _ptr_ T *__restrict__, Arity> &ins,
-      const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
-      const phi::Array<bool, Arity> &use_broadcast,
-      const int block_offset,
-      const int num,
-      const uint32_t numel) {
-    using VecType = phi::kps::details::VectorType<T, VecSize>;
-    VecType vec_temp[Arity];
+template <int Index, int VecSize>
+struct BroadcastDataLoader<Index, VecSize, false, kElementwise> {
+  template <typename Array1, typename Array2, typename Array3, typename ArgsT>
+  static __device__ __forceinline__ void Apply(const Array1 &ins,
+                                               ArgsT *args,
+                                               const Array2 &configs,
+                                               const Array3 &use_broadcast,
+                                               const int block_offset,
+                                               const int num,
+                                               const uint32_t numel) {
+    using Type = std::tuple_element_t<Index, ArgsT>;
+    using VecType = phi::kps::details::VectorType<Type, VecSize>;
+    VecType vec_temp;
     int thread_offset = threadIdx.x + blockIdx.x * blockDim.x;
+    const VecType *__restrict__ vec_input =
+        reinterpret_cast<const VecType *__restrict__>(ins[Index]);
+    vec_temp = vec_input[thread_offset];
 #pragma unroll
-    for (int i = 0; i < Arity; ++i) {
-      const VecType *__restrict__ vec_input =
-          reinterpret_cast<const VecType *__restrict__>(ins[i]);
-      vec_temp[i] = vec_input[thread_offset];
-#pragma unroll
-      for (int idx = 0; idx < VecSize; ++idx) {
-        args[i][idx] = vec_temp[i].val[idx];
-      }
+    for (int idx = 0; idx < VecSize; ++idx) {
+      std::get<Index>(args[idx]) = vec_temp.val[idx];
     }
   }
 };
 
 // Common broadcast data loader.
-template <typename T, int VecSize, int Arity, bool IsBoundary>
-struct BroadcastDataLoader<T, VecSize, Arity, IsBoundary, kBroadcast> {
-  __device__ __forceinline__ void operator()(
-      T args[Arity][VecSize],
-      const phi::Array<const _ptr_ T *__restrict__, Arity> &ins,
-      const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
-      const phi::Array<bool, Arity> &use_broadcast,
-      const int block_offset,
-      const int num,
-      const uint32_t numel) {
-    uint32_t index_bc[Arity][VecSize];
-#pragma unroll
-    for (int j = 0; j < Arity; ++j) {
+template <int Index, int VecSize, bool IsBoundary>
+struct BroadcastDataLoader<Index, VecSize, IsBoundary, kBroadcast> {
+  template <typename Array1, typename Array2, typename Array3, typename ArgsT>
+  static __device__ __forceinline__ void Apply(const Array1 &ins,
+                                               ArgsT *args,
+                                               const Array2 &configs,
+                                               const Array3 &use_broadcast,
+                                               const int block_offset,
+                                               const int num,
+                                               const uint32_t numel) {
+    using Type = std::tuple_element_t<Index, ArgsT>;
+    uint32_t index_bc[VecSize];
 #pragma unroll
-      for (int k = 0; k < VecSize; ++k) {
-        index_bc[j][k] = 0;
-        args[j][k] = static_cast<T>(1);
-      }
+    for (int k = 0; k < VecSize; ++k) {
+      index_bc[k] = 0;
+      std::get<Index>(args[k]) = static_cast<Type>(1);
     }
     uint32_t thread_offset = block_offset + threadIdx.x * VecSize;
 #pragma unroll
     for (int k = 0; k < VecSize; ++k) {
       uint32_t idx = thread_offset + k;
-      if (IsBoundary) {
-        if (idx == numel) break;
+      if (IsBoundary && idx == numel) {
+        break;
       }
-
 #pragma unroll
       for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
         if (i == configs[0].rank) break;
         auto fast_divmoder = configs[0].divmoders[i].Divmod(idx);
         idx = fast_divmoder.val[0];
-#pragma unroll
-        for (int j = 0; j < Arity; ++j) {
-          index_bc[j][k] += fast_divmoder.val[1] * configs[j].strides[i];
-        }
+        index_bc[k] += fast_divmoder.val[1] * configs[Index].strides[i];
       }
     }
 #pragma unroll
-    for (int j = 0; j < Arity; ++j) {
-#pragma unroll
-      for (int k = 0; k < VecSize; ++k) {
-        args[j][k] = ins[j][index_bc[j][k]];
-      }
+    for (int k = 0; k < VecSize; ++k) {
+      std::get<Index>(args[k]) =
+          reinterpret_cast<const _ptr_ Type *>(ins[Index])[index_bc[k]];
    }
   }
 };
+
 #endif
-template
+template <template <int Index, int VecSize, bool IsBoundary, int LoadType>
+              typename Func,
+          bool IsBoundary,
+          int LoadType,
+          int VecSize,
+          int End,
+          int Begin = 0>
+struct BcUnroller {
+  template <typename... Args>
+  static HOSTDEVICE inline void step(Args &&...args) {
+    Func<Begin, VecSize, IsBoundary, LoadType>::Apply(
+        std::forward<Args>(args)...);
+    BcUnroller<Func, IsBoundary, LoadType, VecSize, End, Begin + 1>::step(
+        args...);
+  }
+};
+
+template
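Note for reviewers: the `BcUnroller` / `UnrollerWithoutVecSize` helpers in this patch follow a compile-time unrolling idiom in which a template-template functor is applied once per input index, and each index pulls its own element type out of the functor's argument tuple via `std::tuple_element_t`. The stand-alone sketch below only illustrates that idiom; the names `Unroller` and `PrintElem` are hypothetical and deliberately simplified (no VecSize, IsBoundary, or LoadType parameters), and the snippet is not part of this patch.

#include <cstdio>
#include <tuple>
#include <utility>

// Per-index functor: reads the Index-th element type and value from the
// argument tuple, mirroring how BroadcastDataLoader::Apply uses
// std::tuple_element_t<Index, ArgsT> to pick each input's dtype.
template <int Index>
struct PrintElem {
  template <typename ArgsT>
  static void Apply(const ArgsT &args) {
    using Type = std::tuple_element_t<Index, ArgsT>;
    (void)sizeof(Type);  // Type would drive the load in the real loader.
    std::printf("arg %d = %f\n", Index,
                static_cast<double>(std::get<Index>(args)));
  }
};

// Compile-time recursion: invoke Func<Begin>::Apply, then recurse with
// Begin + 1 until Begin == End, analogous to BcUnroller's step().
template <template <int> class Func, int End, int Begin = 0>
struct Unroller {
  template <typename... Args>
  static void step(Args &&...args) {
    Func<Begin>::Apply(std::forward<Args>(args)...);
    Unroller<Func, End, Begin + 1>::step(args...);
  }
};

// Termination: Begin == End does nothing.
template <template <int> class Func, int End>
struct Unroller<Func, End, End> {
  template <typename... Args>
  static void step(Args &&...) {}
};

int main() {
  std::tuple<float, double, int> args{1.5f, 2.5, 3};
  Unroller<PrintElem, 3>::step(args);  // expands to PrintElem<0..2>::Apply
  return 0;
}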