// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <cmath>
#include <numeric>
#include <set>
#include <vector>

#include <cub/cub.cuh>  // NOLINT
#include "paddle/fluid/framework/tensor.h"

namespace paddle {
namespace operators {

namespace detail {
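// A small fixed-size array that can be passed by value into CUDA kernels
// (std::vector cannot cross the host/device boundary). From() copies the
// entries of a host-side vector-like object whose size must equal
// ElementCount.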
template <typename T, size_t ElementCount>
struct Array {
 public:
  HOSTDEVICE inline Array() {}

  HOSTDEVICE inline T& operator[](size_t index) { return data_[index]; }

  HOSTDEVICE inline const T& operator[](size_t index) const {
    return data_[index];
  }

  HOSTDEVICE constexpr inline size_t size() const { return ElementCount; }

  template <typename VectorLikeType>
  static inline Array<T, ElementCount> From(const VectorLikeType& vec) {
    PADDLE_ENFORCE_EQ(vec.size(), ElementCount,
                      "Vector size does not match the Array's ElementCount.");
    size_t n = static_cast<size_t>(vec.size());
    Array<T, ElementCount> ret;
    for (size_t i = 0; i < n; ++i) ret[i] = vec[i];
    return ret;
  }

 private:
  T data_[ElementCount];
};

// Reduces the last axis of a 2-D array of shape [left_num, reduce_num]:
// one thread block accumulates one row, then cub::BlockReduce combines the
// per-thread partial results.
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
          int BlockDim>
__global__ void ReduceKernel2D(const Tx* x, Ty* y, ReduceOp reducer,
                               TransformOp transformer, Ty init,
                               int reduce_num) {
  __shared__ typename cub::BlockReduce<Ty, BlockDim>::TempStorage temp_storage;
  int idx_x = blockIdx.x * reduce_num;
  Ty reduce_var = init;
  for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += BlockDim)
    reduce_var = reducer(reduce_var, transformer(x[idx_x + idx_y]));

  reduce_var =
      cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);

  if (threadIdx.x == 0) {
    y[blockIdx.x] = reduce_var;
  }
}

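// Generic N-d reduction: each thread block produces one output element.
// blockIdx.x enumerates the flattened "left" (kept) dimensions; every thread
// walks the flattened reduce space with stride BlockDim, decodes its linear
// reduce index into per-dimension sub-indices via reduce_strides, combines
// them with the left sub-indices, and maps the result to an offset into x
// via x_strides. A final cub::BlockReduce combines the per-thread partials.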
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
          int BlockDim, int Rank, int ReduceRank>
__global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer,
                             TransformOp transformer, Ty init, int reduce_num,
                             Array<int, Rank> x_strides,
                             Array<int, ReduceRank> reduce_dim,
                             Array<int, ReduceRank> reduce_strides,
                             Array<int, Rank - ReduceRank> left_dim,
                             Array<int, Rank - ReduceRank> left_strides) {
  __shared__ typename cub::BlockReduce<Ty, BlockDim>::TempStorage temp_storage;
  Array<int, Rank> sub_index;
  int left_idx = blockIdx.x;
  for (int i = 0; i < Rank - ReduceRank; ++i) {
    sub_index[left_dim[i]] = left_idx / left_strides[i];
    left_idx %= left_strides[i];
  }

  int reduce_idx = threadIdx.x;
  for (int j = 0; j < ReduceRank; ++j) {
    sub_index[reduce_dim[j]] = reduce_idx / reduce_strides[j];
    reduce_idx %= reduce_strides[j];
  }

  int idx_x = 0;
  for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]);
  Ty reduce_var = static_cast<Ty>(transformer(x[idx_x]));

  for (int i = threadIdx.x + BlockDim; i < reduce_num; i += BlockDim) {
    int reduce_idx = i;
    for (int j = 0; j < ReduceRank; ++j) {
      sub_index[reduce_dim[j]] = reduce_idx / reduce_strides[j];
      reduce_idx %= reduce_strides[j];
    }

    int idx_x = 0;
    for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]);
    reduce_var = static_cast<Ty>(reducer(reduce_var, transformer(x[idx_x])));
  }

  reduce_var =
      cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);

  if (threadIdx.x == 0) {
    y[blockIdx.x] = reduce_var;
  }
}

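// Row-major strides of `dims`. The two-argument overload returns the strides
// of the sub-shape whose extents are dims[idx[0]], dims[idx[1]], ...; the
// kernels use it to decode a linear index over just those dimensions.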
static inline std::vector<int> GetStrides(const std::vector<int>& dims) {
  int n = static_cast<int>(dims.size());
  if (n == 0) return std::vector<int>();
  std::vector<int> strides(n);
  strides.back() = 1;
  for (int i = n - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * dims[i + 1];
  }
  return strides;
}

static inline std::vector<int> GetStrides(const std::vector<int>& dims,
                                          const std::vector<int>& idx) {
  int n = static_cast<int>(idx.size());
  if (n == 0) return std::vector<int>();
  std::vector<int> strides(n);
  strides.back() = 1;
  for (int i = n - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * dims[idx[i + 1]];
  }
  return strides;
}

constexpr int kMaxBlockDim = 512;

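// Round the requested block size down to a power of two, capped at
// kMaxBlockDim, so it can be matched against one of the compile-time
// CUB_BLOCK_DIM_CASE instantiations below.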
static inline int GetDesiredBlockDim(int block_dim) {
  return block_dim >= kMaxBlockDim
             ? kMaxBlockDim
             : (1 << static_cast<int>(std::log2(block_dim)));
}

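// Dispatches to a ReduceKernel instantiation for the compile-time pair
// (rank, reduce_rank) so the index arithmetic can use fixed-size Arrays.
// Two cases are short-circuited: a full reduction (rank == reduce_rank) and
// a 2-D tensor reduced along its last axis.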
template <typename Tx, typename Ty, int BlockDim, typename ReduceOp,
          typename TransformOp>
static void TensorReduceImpl(
    const Tx* x_data, Ty* y_data, const platform::Place& place,
    const ReduceOp& reducer, const TransformOp& transformer, const Ty& init,
    int left_num, int reduce_num, const std::vector<int>& x_strides,
    const std::vector<int>& reduce_dim, const std::vector<int>& reduce_strides,
    const std::vector<int>& left_dim, const std::vector<int>& left_strides,
    cudaStream_t stream) {
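// The two macros below map the run-time (rank, reduce_rank) pair onto the
// matching compile-time ReduceKernel instantiation.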
#define CUB_RANK_CASE(i, ...)              \
  case i: {                                \
    constexpr auto kRank = i;              \
    switch (reduce_rank) { __VA_ARGS__; }  \
  } break

#define CUB_REDUCE_RANK_CASE(i, ...)                               \
  case i: {                                                        \
    constexpr auto kReduceRank = i;                                \
    ReduceKernel<Tx, Ty, ReduceOp, TransformOp, BlockDim, kRank,   \
                 kReduceRank><<<left_num, BlockDim, 0, stream>>>(  \
        x_data, y_data, reducer, transformer, init, reduce_num,    \
        Array<int, kRank>::From(x_strides),                        \
        Array<int, kReduceRank>::From(reduce_dim),                 \
        Array<int, kReduceRank>::From(reduce_strides),             \
        Array<int, kRank - kReduceRank>::From(left_dim),           \
        Array<int, kRank - kReduceRank>::From(left_strides));      \
  } break

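  // Full reduction: every axis is reduced, so the transform is fused into a
  // cub::TransformInputIterator and the job is handed to cub::DeviceReduce
  // (the first call only computes the required temporary-storage size).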
  int rank = x_strides.size();
  int reduce_rank = reduce_strides.size();
  if (rank == reduce_rank) {
    cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(
        x_data, transformer);
    size_t temp_storage_bytes = 0;
    cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
                              reduce_num, reducer, init, stream);
    framework::Tensor tmp;
    auto* temp_storage = tmp.mutable_data<uint8_t>(
        framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
        place);
    cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
                              reduce_num, reducer, init, stream);
    return;
  }
  if (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) {
    ReduceKernel2D<Tx, Ty, ReduceOp, TransformOp,
                   BlockDim><<<left_num, BlockDim, 0, stream>>>(
        x_data, y_data, reducer, transformer, init, reduce_num);
    return;
  }
  /*
  if (rank == 3 && reduce_rank == 1 && reduce_dim[0] == 1) {
    // TODO(liangdun): optimize the 3-D case in which the 2nd axis is reduced.
    // Currently it is handled by the generic code below, which is inefficient.
    return;
  }
  */

  switch (rank) {
    CUB_RANK_CASE(2, CUB_REDUCE_RANK_CASE(1););

    CUB_RANK_CASE(3, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2););

    CUB_RANK_CASE(4, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2);
                  CUB_REDUCE_RANK_CASE(3););

    CUB_RANK_CASE(5, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2);
                  CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4););

    CUB_RANK_CASE(6, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2);
                  CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4);
                  CUB_REDUCE_RANK_CASE(5););

    CUB_RANK_CASE(7, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2);
                  CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4);
                  CUB_REDUCE_RANK_CASE(5); CUB_REDUCE_RANK_CASE(6););

    CUB_RANK_CASE(8, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2);
                  CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4);
                  CUB_REDUCE_RANK_CASE(5); CUB_REDUCE_RANK_CASE(6);
                  CUB_REDUCE_RANK_CASE(7););

    CUB_RANK_CASE(9, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2);
                  CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4);
                  CUB_REDUCE_RANK_CASE(5); CUB_REDUCE_RANK_CASE(6);
                  CUB_REDUCE_RANK_CASE(7); CUB_REDUCE_RANK_CASE(8););
  }

#undef CUB_REDUCE_RANK_CASE
#undef CUB_RANK_CASE
}

}  // namespace detail

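// TensorReduce is the host-side entry point: it merges adjacent input axes
// that are either all reduced or all kept, splits the remaining axes into a
// "left" (kept) set and a "reduce" set, builds the stride tables, and
// dispatches to detail::TensorReduceImpl with a power-of-two block size.
//
// Hypothetical usage sketch (the functor below is illustrative only and is
// not defined in this header): summing a float tensor over axis 1 might look
// like
//
//   struct IdentityFunctor {
//     template <typename T>
//     HOSTDEVICE T operator()(const T& v) const { return v; }
//   };
//   TensorReduce<float, float, cub::Sum, IdentityFunctor>(
//       x, &y, {1}, 0.0f, cub::Sum(), IdentityFunctor(), stream);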
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp>
void TensorReduce(const framework::Tensor& x, framework::Tensor* y,
                  std::vector<int> origin_reduce_dims, const Ty& init,
                  const ReduceOp& reducer, const TransformOp& transformer,
                  cudaStream_t stream) {
  auto x_dim = framework::vectorize2int(x.dims());
  std::vector<int> new_x_dim, new_reduce_dims;
  // Record which axes are reduced (normalizing negative axes), then merge
  // runs of adjacent axes that are either all reduced or all kept, so the
  // kernels see the smallest possible rank.
  int is_reduced = 0;
  for (auto e : origin_reduce_dims) {
    int pos = e >= 0 ? e : e + static_cast<int>(x_dim.size());
    is_reduced |= 1 << pos;
  }
  for (int i = 0; i < static_cast<int>(x_dim.size()); i++) {
    if ((i == 0) || (((is_reduced >> i) ^ (is_reduced >> (i - 1))) & 1)) {
      new_x_dim.push_back(x_dim[i]);
      if ((is_reduced >> i) & 1)
        new_reduce_dims.push_back(new_x_dim.size() - 1);
    } else {
      new_x_dim[new_x_dim.size() - 1] *= x_dim[i];
    }
  }
  x_dim = new_x_dim;
  origin_reduce_dims = new_reduce_dims;
  int x_rank = static_cast<int>(x_dim.size());
  std::set<int> left_set, reduce_set;
  for (int i = 0; i < x_rank; ++i) left_set.insert(i);

  for (auto e : origin_reduce_dims) {
    left_set.erase(e);
    reduce_set.insert(e);
  }

  std::vector<int> reduce_dim(reduce_set.begin(), reduce_set.end());
  std::vector<int> left_dim(left_set.begin(), left_set.end());

  std::vector<int> x_strides = detail::GetStrides(x_dim);
  std::vector<int> reduce_strides = detail::GetStrides(x_dim, reduce_dim);
  std::vector<int> left_strides = detail::GetStrides(x_dim, left_dim);
  int reduce_num = reduce_strides[0] * x_dim[reduce_dim[0]];
  int left_num = 1;
  if (left_dim.size()) left_num = left_strides[0] * x_dim[left_dim[0]];

  std::vector<int> y_dim(left_dim.size());
  for (int i = 0; i < left_dim.size(); ++i) {
    y_dim[i] = x_dim[left_dim[i]];
  }
  auto x_data = x.data<Tx>();
  auto y_data = y->mutable_data<Ty>(x.place());
  if (reduce_num == 1) return;

#define CUB_BLOCK_DIM_CASE(block_dim)                                     \
  case block_dim: {                                                       \
    constexpr auto kBlockDim = block_dim;                                 \
    detail::TensorReduceImpl<Tx, Ty, block_dim, ReduceOp, TransformOp>(   \
        x_data, y_data, x.place(), reducer, transformer, init, left_num,  \
        reduce_num, x_strides, reduce_dim, reduce_strides, left_dim,      \
        left_strides, stream);                                            \
  } break

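  // Instantiate TensorReduceImpl for the compile-time block size selected by
  // GetDesiredBlockDim(reduce_num).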
  switch (detail::GetDesiredBlockDim(reduce_num)) {
    CUB_BLOCK_DIM_CASE(512);
    CUB_BLOCK_DIM_CASE(256);
    CUB_BLOCK_DIM_CASE(128);
    CUB_BLOCK_DIM_CASE(64);
    CUB_BLOCK_DIM_CASE(32);
    CUB_BLOCK_DIM_CASE(16);
    CUB_BLOCK_DIM_CASE(8);
    CUB_BLOCK_DIM_CASE(4);
    CUB_BLOCK_DIM_CASE(2);
  }
#undef CUB_BLOCK_DIM_CASE
}

}  // namespace operators
}  // namespace paddle