|
| 1 | +# paddle.sparse.nn.Softmax 设计文档 |
| 2 | + |
| 3 | +| API名称 | paddle.sparse.nn.Softmax | |
| 4 | +| ------------------------------------------------------------ | ------------------------------------------------- | |
| 5 | +| 提交作者<input type="checkbox" class="rowselector hidden"> | thunder95 | |
| 6 | +| 提交时间<input type="checkbox" class="rowselector hidden"> | 2023-04-19 | |
| 7 | +| 版本号 | V1.0 | |
| 8 | +| 依赖飞桨版本<input type="checkbox" class="rowselector hidden"> | Develop | |
| 9 | +| 文件名 | 20230419_api_design_for_sparse_coo_nn_softmax.md<br> | |
| 10 | + |
| 11 | + |
| 12 | +# 一、概述 |
| 13 | + |
| 14 | +## 1、相关背景 |
| 15 | + |
| 16 | +稀疏 Tensor 是元素大部分为零的矩阵,在实际求解任务时经常出现大规模的稀疏 Tensor。由于其自身的稀疏性,为了节省存储空间,通常会修改稀疏 Tensor 的存储结构。目前比较普遍的存储结构为 COO 和 CSR。 |
| 17 | + |
| 18 | +Paddle 目前已经实现了 COO 和 CSR 格式的稀疏 Tensor 的构建以及一些算子操作,Softmax目前仅仅支持了 CSR 格式的稀疏 Tensor, 还需要对COO格式的支持。 |
| 19 | + |
| 20 | +## 2、功能目标 |
| 21 | + |
| 22 | +在飞桨中增加 paddle.sparse.nn.Softmax 对COO稀疏格式的支持。 |
| 23 | + |
| 24 | +## 3、意义 |
| 25 | + |
| 26 | +飞桨将支持 paddle.sparse.nn.Softmax 在coo 稀疏格式下的计算逻辑。 |
| 27 | + |
| 28 | +# 二、飞桨现状 |
| 29 | + |
| 30 | +目前飞桨的paddle.sparse.nn.Softmax API 仅支持CSR 格式, 还不支持COO稀疏格式。 |
| 31 | + |
| 32 | + |
| 33 | +# 三、业内方案调研 |
| 34 | + |
| 35 | +## TensorFlow |
| 36 | + |
| 37 | +Tensorflow中提供了softmax稀疏算子支持, 详情可参考官方文档([tf.sparse.softmax](https://tensorflow.google.cn/api_docs/python/tf/sparse/softmax)) 。 |
| 38 | + |
| 39 | +``` python |
| 40 | +tf.sparse.softmax( |
| 41 | + sp_input, name=None |
| 42 | +) |
| 43 | +``` |
| 44 | +具体核心实现代码如下所示(截取自 [tensorflow/core/kernels/sparse_softmax_op.cc](https://github.com/tensorflow/tensorflow/blob/v2.12.0/tensorflow/core/kernels/sparse_softmax_op.cc) |
| 45 | +```cpp |
| 46 | +template <typename Device, typename T> |
| 47 | +class SparseSoftmaxOp : public OpKernel { |
| 48 | + public: |
| 49 | + explicit SparseSoftmaxOp(OpKernelConstruction *context) : OpKernel(context) {} |
| 50 | + |
| 51 | + void Compute(OpKernelContext *context) override { |
| 52 | + const Tensor *indices_t, *values_t, *shape_t; |
| 53 | + OP_REQUIRES_OK(context, context->input("sp_indices", &indices_t)); |
| 54 | + OP_REQUIRES_OK(context, context->input("sp_values", &values_t)); |
| 55 | + OP_REQUIRES_OK(context, context->input("sp_shape", &shape_t)); |
| 56 | + |
| 57 | + // Validations. |
| 58 | + OP_REQUIRES(context, TensorShapeUtils::IsMatrix(indices_t->shape()), |
| 59 | + errors::InvalidArgument( |
| 60 | + "Input sp_indices should be a matrix but received shape: ", |
| 61 | + indices_t->shape().DebugString())); |
| 62 | + OP_REQUIRES(context, |
| 63 | + TensorShapeUtils::IsVector(values_t->shape()) && |
| 64 | + TensorShapeUtils::IsVector(shape_t->shape()), |
| 65 | + errors::InvalidArgument( |
| 66 | + "Inputs sp_values and sp_shape should be vectors " |
| 67 | + "but received shapes: ", |
| 68 | + values_t->shape().DebugString(), " and ", |
| 69 | + shape_t->shape().DebugString())); |
| 70 | + OP_REQUIRES(context, shape_t->NumElements() >= 2, |
| 71 | + errors::InvalidArgument( |
| 72 | + "Input should have rank >= 2, but received shape: ", |
| 73 | + shape_t->SummarizeValue(3))); |
| 74 | + TensorShape shape; |
| 75 | + OP_REQUIRES_OK(context, TensorShape::BuildTensorShape( |
| 76 | + shape_t->flat<int64_t>(), &shape)); |
| 77 | + |
| 78 | + const int64_t nnz = indices_t->dim_size(0); |
| 79 | + const int rank = static_cast<int>(indices_t->dim_size(1)); |
| 80 | + SparseTensor st; |
| 81 | + OP_REQUIRES_OK( |
| 82 | + context, SparseTensor::Create(tensor::DeepCopy(*indices_t), |
| 83 | + tensor::DeepCopy(*values_t), shape, &st)); |
| 84 | + |
| 85 | + Tensor *output_values = nullptr; |
| 86 | + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({nnz}), |
| 87 | + &output_values)); |
| 88 | + typename TTypes<T>::Flat output_flat = output_values->flat<T>(); |
| 89 | + |
| 90 | + Tensor tmp_t; |
| 91 | + OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value, |
| 92 | + TensorShape({}), &tmp_t)); |
| 93 | + typename TTypes<T>::Scalar tmp_scalar = tmp_t.scalar<T>(); |
| 94 | + |
| 95 | + gtl::InlinedVector<int64_t, 4> dims(rank); |
| 96 | + std::iota(dims.begin(), dims.end(), 0); |
| 97 | + // { 0, ..., rank-1 }. |
| 98 | + const ArraySlice<int64_t> kReorderDims(dims); |
| 99 | + // All but the last dim -- the class dimension to be max-reduced along. |
| 100 | + const ArraySlice<int64_t> kGroupByDims = kReorderDims.subspan(0, rank - 1); |
| 101 | + st.Reorder<T>(kReorderDims); |
| 102 | + int count = 0; |
| 103 | + |
| 104 | + // The SparseTensor has logical shape [..., b, c], where the |
| 105 | + // innermost size-"c" dimension is the class dimension to be max-reduced. |
| 106 | + // Therefore we group by the first (rank - 1) dimensions. |
| 107 | + const Device &device = context->eigen_device<Device>(); |
| 108 | + for (const auto &g : st.group(kGroupByDims)) { |
| 109 | + const auto group_vals = g.values<T>(); |
| 110 | + const int group_size = group_vals.size(); |
| 111 | + |
| 112 | + // Shifts by max, exponentiates, then renormalizes. |
| 113 | + tmp_scalar.device(context->eigen_device<Device>()) = group_vals.maximum(); |
| 114 | + const T group_max = tmp_scalar(); |
| 115 | + |
| 116 | + Eigen::Tensor<T, 1, Eigen::RowMajor> tmp(group_size); |
| 117 | + tmp.device(device) = (group_vals - tmp.constant(group_max)).exp(); |
| 118 | + |
| 119 | + tmp_scalar.device(device) = tmp.sum().inverse(); |
| 120 | + tmp.device(device) = tmp * tmp.constant(tmp_scalar()); |
| 121 | + |
| 122 | + // Assigns back to output[count, count + group_size). |
| 123 | + Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>> output_part( |
| 124 | + output_flat.data() + count, group_size); |
| 125 | + output_part.device(device) = tmp; |
| 126 | + |
| 127 | + count += group_size; |
| 128 | + } |
| 129 | + } |
| 130 | +}; |
| 131 | +``` |
| 132 | + |
| 133 | +## SciPy |
| 134 | + |
| 135 | +SciPy 不支持对稀疏 Tensor 的 softmax 操作。 |
| 136 | + |
| 137 | +## Pytorch |
| 138 | + |
| 139 | +Pytorch中支持了softmax API的COO格式稀疏算子, 详情可参考官方文档([torch.sparse.softmax](https://pytorch.org/docs/stable/generated/torch.sparse.softmax.html) 。 |
| 140 | +具体核心实现代码如下所示(截取自 [pytorch/src/ATen/native/sparse/SoftMax.cpp](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/sparse/SoftMax.cpp) |
| 141 | + |
| 142 | +``` cpp |
| 143 | +template <typename scalar_t, bool LogSoftMax> |
| 144 | +void cpu_sparse_coo_softmax(Tensor output, const Tensor& input, const int64_t dim) { |
| 145 | + auto sparse_dim = input.sparse_dim(); |
| 146 | + auto indices = input._indices().contiguous(); |
| 147 | + auto values = input._values().contiguous(); |
| 148 | + auto out_values = output._values(); |
| 149 | + auto out_indices = output._indices(); |
| 150 | + out_values.resize_as_(values); |
| 151 | + out_indices.resize_as_(indices); |
| 152 | + out_indices.copy_(indices); |
| 153 | + |
| 154 | + if (dim >= sparse_dim) { |
| 155 | + if (LogSoftMax) { |
| 156 | + auto new_values = |
| 157 | + at::cpu::_log_softmax(values, dim - sparse_dim + 1, false); |
| 158 | + out_values.set_(new_values); |
| 159 | + } else { |
| 160 | + auto new_values = at::cpu::_softmax(values, dim - sparse_dim + 1, false); |
| 161 | + out_values.set_(new_values); |
| 162 | + } |
| 163 | + return; |
| 164 | + } |
| 165 | + |
| 166 | + auto nnz = values.size(0); |
| 167 | + auto sizes = input.sizes(); |
| 168 | + auto nvalues = get_nvalues(sizes, sparse_dim); |
| 169 | + |
| 170 | + /* Prepare accessors */ |
| 171 | + auto values_2 = values.view({nnz, nvalues}); |
| 172 | + auto values_accessor = values_2.accessor<scalar_t, 2>(); |
| 173 | + |
| 174 | + auto out_values_2 = out_values.view({nnz, nvalues}); |
| 175 | + auto out_values_accessor = out_values_2.accessor<scalar_t, 2>(); |
| 176 | + |
| 177 | + /* Compute independent pools of indices */ |
| 178 | + auto pools = get_pools(indices, sizes, dim); |
| 179 | + |
| 180 | + int64_t grain_size = 1; |
| 181 | + parallel_for(0, pools.size(), grain_size, [&](int64_t begin, int64_t end) { |
| 182 | + for (const auto p : c10::irange(begin, end)) { |
| 183 | + auto pool_indices = pools[p]; |
| 184 | + |
| 185 | + // Skip empty pools |
| 186 | + if (pool_indices.empty()) |
| 187 | + continue; |
| 188 | + |
| 189 | + /* Prepare scratch space */ |
| 190 | + std::vector<scalar_t> mx_row(nvalues, -std::numeric_limits<scalar_t>::infinity()); |
| 191 | + std::vector<scalar_t> exp_sums_row(nvalues, 0); |
| 192 | + |
| 193 | + /* Compute mx */ |
| 194 | + for (int64_t i : pool_indices) { |
| 195 | + auto values_row = values_accessor[i]; |
| 196 | + for (const auto j : c10::irange(nvalues)) { |
| 197 | + mx_row[j] = std::max(mx_row[j], values_row[j]); |
| 198 | + } |
| 199 | + } |
| 200 | + |
| 201 | + /* Apply exp to (v - mx) and sum the results */ |
| 202 | + for (int64_t i : pool_indices) { |
| 203 | + auto values_row = values_accessor[i]; |
| 204 | + auto out_values_row = out_values_accessor[i]; |
| 205 | + for (const auto j : c10::irange(nvalues)) { |
| 206 | + auto v = std::exp(values_row[j] - mx_row[j]); |
| 207 | + if (!LogSoftMax) { |
| 208 | + out_values_row[j] = v; |
| 209 | + } |
| 210 | + exp_sums_row[j] += v; |
| 211 | + } |
| 212 | + } |
| 213 | + |
| 214 | + for (const auto j : c10::irange(nvalues)) { |
| 215 | + if (LogSoftMax) { |
| 216 | + mx_row[j] += std::log(exp_sums_row[j]); |
| 217 | + } else { |
| 218 | + exp_sums_row[j] = 1.0 / exp_sums_row[j]; |
| 219 | + } |
| 220 | + } |
| 221 | + |
| 222 | + /* Normalize with the sum of exponents */ |
| 223 | + for (int64_t i : pool_indices) { |
| 224 | + auto values_row = values_accessor[i]; |
| 225 | + auto out_values_row = out_values_accessor[i]; |
| 226 | + for (const auto j : c10::irange(nvalues)) { |
| 227 | + if (LogSoftMax) { |
| 228 | + out_values_row[j] = values_row[j] - mx_row[j]; |
| 229 | + } else { |
| 230 | + out_values_row[j] *= exp_sums_row[j]; |
| 231 | + } |
| 232 | + } |
| 233 | + } |
| 234 | + } |
| 235 | + }); |
| 236 | +} |
| 237 | +``` |
| 238 | +
|
| 239 | +# 四、对比分析 |
| 240 | +
|
| 241 | +Tensorflow基于Eigen计算,支持COO稀疏格式,不支持axis传入。 |
| 242 | +Scipy没有直接支持softmax的稀疏算子计算。 |
| 243 | +Pytorch中能支持axis传入,且支持COO格式的稀疏算子。 |
| 244 | +
|
| 245 | +
|
| 246 | +# 五、设计思路与实现方案 |
| 247 | +
|
| 248 | +## 命名与参数设计 |
| 249 | +
|
| 250 | +sparse softmax 已经支持 CSR 格式,这个稀疏张量上的方法的命名和参数不需要额外设计,只需要添加相应的COO格式支持。 |
| 251 | +
|
| 252 | +在 paddle/phi/api/yaml 下新增注册该算子COO格式的前向以及反向。 |
| 253 | +
|
| 254 | +``` yaml |
| 255 | +- op : softmax |
| 256 | + args : (Tensor x, int axis=-1) |
| 257 | + output : Tensor(out) |
| 258 | + infer_meta : |
| 259 | + func : UnchangedInferMeta |
| 260 | + param : [x] |
| 261 | + kernel : |
| 262 | + func : softmax_coo{sparse_coo -> sparse_coo}, |
| 263 | + softmax_csr{sparse_csr -> sparse_csr} |
| 264 | + layout : x |
| 265 | + backward : softmax_grad |
| 266 | +``` |
| 267 | + |
| 268 | + |
| 269 | +``` yaml |
| 270 | +- backward_op : softmax_grad |
| 271 | + forward : softmax(Tensor x, int axis=-1) -> Tensor(out) |
| 272 | + args : (Tensor out, Tensor out_grad, int axis) |
| 273 | + output : Tensor(x_grad) |
| 274 | + infer_meta : |
| 275 | + func : UnchangedInferMeta |
| 276 | + param : [out] |
| 277 | + kernel : |
| 278 | + func : softmax_coo_grad{sparse_coo, sparse_coo -> sparse_coo}, |
| 279 | + softmax_csr_grad{sparse_csr, sparse_csr -> sparse_csr} |
| 280 | +``` |
| 281 | +
|
| 282 | +## 底层OP设计 |
| 283 | +
|
| 284 | +新增一个COO格式的前向以及反向Kernel: |
| 285 | +
|
| 286 | +``` cpp |
| 287 | +template <typename T, typename Context> |
| 288 | +void SoftmaxCooKernel(const Context& dev_ctx, |
| 289 | + const SparseCooTensor& x, |
| 290 | + int axis, |
| 291 | + SparseCooTensor* out); |
| 292 | +``` |
| 293 | + |
| 294 | +``` cpp |
| 295 | +template <typename T, typename Context> |
| 296 | +void SoftmaxCooGradKernel(const Context& dev_ctx, |
| 297 | + const SparseCooTensor& out, |
| 298 | + const SparseCooTensor& dout, |
| 299 | + int axis, |
| 300 | + SparseCooTensor* dx); |
| 301 | +``` |
| 302 | +
|
| 303 | +## API实现方案 |
| 304 | +
|
| 305 | +在python/paddle/sparse/nn/functional/activation.py 文件和 python/paddle/sparse/nn/layer/activation.py 文件中的原API上没有改动。 |
| 306 | +
|
| 307 | +参考pytorch的计算方式,先计算索引的pool映射,在对应维度上一次计算max以及求和,最终对指数的求和进行normalize计算。 |
| 308 | +
|
| 309 | +在cuda的kernel中, 若指定的axis大于等于稀疏维度,将使用稠密张量的softmax算子,若小于则分两步; |
| 310 | +- 先计算pool和max, 基于Thrust库设计函数ComputePoolMax, 计算出指定维度上索引的pools以及每个pool对应的最大值, |
| 311 | +- 基于pool的数量,设计对应的block和grid,调用SparseCooSoftmaxKernel, 计算pool内的softmax值 |
| 312 | + |
| 313 | +在反向梯度SparseCooSoftmaxGradKernel计算中,需先设计函数GetOffsets, 基于稀疏张量的索引计算对应稠密张量的偏移量,进而通过反向求导的公式计算梯度。 |
| 314 | +
|
| 315 | +
|
| 316 | +# 六、测试和验收的考量 |
| 317 | +
|
| 318 | +完善单测代码,python/paddle/fluid/tests/unittests/test_sparse_softmax_op.py 文件中新增测试COO稀疏格式的case如下: |
| 319 | +
|
| 320 | +- 数值正确性 |
| 321 | +- COO数据格式 |
| 322 | +- 不同输入tensor的数据类型下检查输出结果 |
| 323 | +- 计算结果与dense tensor进行比较 |
| 324 | +
|
| 325 | +# 七、可行性分析和排期规划 |
| 326 | +
|
| 327 | +前两周实现代码、文档和测试。 |
| 328 | +
|
| 329 | +第三周进行 Code Review 和继续迭代。 |
| 330 | +
|
| 331 | +# 八、影响面 |
| 332 | +
|
| 333 | +对其它模块没有影响。 |
| 334 | +
|
| 335 | +# 名词解释 |
| 336 | +
|
| 337 | +# 附件及参考资料 |
0 commit comments