Skip to content

Commit 6f538d1

Browse files
authored
【Hackathon + No.29】为 Paddle 新增 paddle.sparse.slice 稀疏 API (#382)
* api design for sparse slice * update doc for kernel and test case
1 parent f282c1e commit 6f538d1

File tree

1 file changed

+334
-0
lines changed

1 file changed

+334
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,334 @@
1+
# paddle.sparse.slice 设计文档
2+
3+
4+
| API名称 | paddle.sparse.slice |
5+
|-------------|-----------------------------------------|
6+
| 提交作者 | ScottWong98 |
7+
| 提交时间 | 2023-02-25 |
8+
| 版本号 | V1.0.0 |
9+
| 依赖飞桨版本 | develop |
10+
| 文件名 | 20230225_api_design_for_sparse_slice.md |
11+
12+
# 一、概述
13+
## 1、相关背景
14+
15+
简单来说,稀疏 Tensor 是元素大部分为零的矩阵,在实际求解任务时经常出现大规模的稀疏 Tensor。由于其自身的稀疏性,为了节省存储空间,通常会修改稀疏 Tensor 的存储结构。目前比较普遍的存储结构为 COO 和 CSR。
16+
17+
Paddle 目前已经实现了 COO 和 CSR 格式的稀疏 Tensor 的构建以及一些算子操作,然而目前还没有支持对其的 slice 操作,而 slice 操作在实际中是有应用价值的,因此在 Paddle 中集成该功能是有必要的。
18+
19+
## 2、功能目标
20+
21+
为 Paddle 新增 paddle.sparse.slice 稀疏 API。针对 Paddle 的两种稀疏 Tensor 格式 COO 和 CSR,都需新增 slice 的计算逻辑。一共需要新增 2 个 kernel 的前向与反向。动静态图都需要支持。
22+
23+
其中 COO 的 kernel 需要支持任意维度的稀疏 Tensor,CSR 的 kernel 需要支持 2D/3D 的稀疏 Tensor。
24+
25+
## 3、意义
26+
27+
支持稀疏 Tensor 的 slice 操作,丰富基础功能,提升稀疏 Tensor 的 API 完整度。
28+
29+
# 二、飞桨现状
30+
31+
目前paddle缺少相关功能实现。
32+
33+
# 三、业内方案调研
34+
35+
针对 PyTorch,TensorFlow 和 SciPy 三种框架对该功能进行了调研,具体结果如下。
36+
37+
## PyTorch
38+
PyTorch 目前还不支持对稀疏 Tensor 的 slice 功能,参考 [PyTorch 论坛上的回答](https://discuss.pytorch.org/t/column-row-slicing-a-torch-sparse-tensor/19130/2)
39+
40+
## TensorFlow
41+
42+
TensorFlow 只支持 COO 格式的 slice 功能。详情可参考官方文档([tf.sparse.slice](https://www.tensorflow.org/api_docs/python/tf/sparse/slice))。
43+
44+
具体核心实现代码如下所示(截取自 [tensorflow/core/util/sparse/sparse_tensor.h](https://github.com/tensorflow/tensorflow/blob/v2.11.0/tensorflow/core/util/sparse/sparse_tensor.h#L580) 文件):
45+
46+
```cpp
47+
template <typename T>
48+
inline StatusOr<SparseTensor> SparseTensor::Slice(
49+
const SparseTensor& input_tensor, const gtl::ArraySlice<int64_t> start,
50+
const gtl::ArraySlice<int64_t> size) {
51+
TensorShape output_shape(input_tensor.shape());
52+
53+
const int dims = input_tensor.dims();
54+
for (int dim = 0; dim < dims; dim++) {
55+
// Determine the size of the result; if the selected slice goes beyond the
56+
// input boundary, the result will correspond to the size of the overlap
57+
// between the input and the selected slice.
58+
const int64_t input_size = output_shape.dim_size(dim);
59+
const int64_t start_index = start[dim];
60+
const int64_t slice_size = size[dim];
61+
62+
if (start_index < input_size - slice_size) {
63+
// The entire selection is within input boundaries.
64+
TF_RETURN_IF_ERROR(output_shape.SetDimWithStatus(dim, slice_size));
65+
} else if (start_index < input_size) {
66+
// The selection starts within input boundaries, but goes beyond them.
67+
TF_RETURN_IF_ERROR(
68+
output_shape.SetDimWithStatus(dim, input_size - start_index));
69+
} else {
70+
// The selection is entirely out of input boundaries.
71+
TF_RETURN_IF_ERROR(output_shape.SetDimWithStatus(dim, 0));
72+
}
73+
}
74+
75+
auto input_indices_t = input_tensor.indices().matrix<int64_t>();
76+
auto input_values_t = input_tensor.values().vec<T>();
77+
78+
// Find the number of indices that fall inside start and size.
79+
int count = 0;
80+
for (int i = 0; i < input_tensor.indices().dim_size(0); i++) {
81+
// The following will check to see if an input is within the
82+
// range specified by start and size.
83+
// The for loop below iterates through all dimensions. In case
84+
// the index falls outside of the start and size at any dimension,
85+
// it will be considered as a "no hit" (hit = false). In this
86+
// case, it will not be counted as the index that fall inside
87+
// the range specified by start and size.
88+
bool hit = true;
89+
for (int dim = 0; dim < dims; dim++) {
90+
if (!(start[dim] <= input_indices_t(i, dim) &&
91+
input_indices_t(i, dim) < start[dim] + size[dim])) {
92+
hit = false;
93+
break;
94+
}
95+
}
96+
if (!hit) {
97+
continue;
98+
}
99+
count++;
100+
}
101+
102+
Tensor output_values(DataTypeToEnum<T>::v(), TensorShape({count}));
103+
Tensor output_indices(DT_INT64, TensorShape({count, dims}));
104+
105+
auto output_values_t = output_values.vec<T>();
106+
auto output_indices_t = output_indices.matrix<int64_t>();
107+
108+
// Obtain the output indices that fall inside start and size.
109+
int index = 0;
110+
for (int i = 0; i < input_tensor.indices().dim_size(0) && index < count;
111+
i++) {
112+
// The logic here is similar as the above except that the above
113+
// only count the number of indices while here we actually generate
114+
// the output.
115+
bool hit = true;
116+
for (int dim = 0; dim < dims; dim++) {
117+
if (!(start[dim] <= input_indices_t(i, dim) &&
118+
input_indices_t(i, dim) < start[dim] + size[dim])) {
119+
hit = false;
120+
break;
121+
}
122+
}
123+
if (!hit) {
124+
continue;
125+
}
126+
output_values_t(index) = input_values_t(i);
127+
for (int dim = 0; dim < dims; dim++) {
128+
output_indices_t(index, dim) = input_indices_t(i, dim) - start[dim];
129+
}
130+
index++;
131+
}
132+
133+
return SparseTensor(output_indices, output_values, output_shape);
134+
}
135+
```
136+
137+
## SciPy
138+
139+
SciPy 只支持对 CSR 格式的 slice 操作。SciPy 并没有提供对 slice 操作的文档说明,但经过实践,发现与 Numpy 中的 slice 操作形式一样。
140+
141+
SciPy 中对 slice 操作的具体核心实现代码如下所示 (截取自 [scipy/sparse/sparsetools/csr.h](https://github.com/scipy/scipy/blob/v1.10.1/scipy/sparse/sparsetools/csr.h#L1181) 文件):
142+
```c++
143+
template<class I, class T>
144+
void get_csr_submatrix(const I n_row,
145+
const I n_col,
146+
const I Ap[],
147+
const I Aj[],
148+
const T Ax[],
149+
const I ir0,
150+
const I ir1,
151+
const I ic0,
152+
const I ic1,
153+
std::vector<I>* Bp,
154+
std::vector<I>* Bj,
155+
std::vector<T>* Bx)
156+
{
157+
I new_n_row = ir1 - ir0;
158+
//I new_n_col = ic1 - ic0; //currently unused
159+
I new_nnz = 0;
160+
I kk = 0;
161+
162+
// Count nonzeros total/per row.
163+
for(I i = 0; i < new_n_row; i++){
164+
I row_start = Ap[ir0+i];
165+
I row_end = Ap[ir0+i+1];
166+
167+
for(I jj = row_start; jj < row_end; jj++){
168+
if ((Aj[jj] >= ic0) && (Aj[jj] < ic1)) {
169+
new_nnz++;
170+
}
171+
}
172+
}
173+
174+
// Allocate.
175+
Bp->resize(new_n_row+1);
176+
Bj->resize(new_nnz);
177+
Bx->resize(new_nnz);
178+
179+
// Assign.
180+
(*Bp)[0] = 0;
181+
for(I i = 0; i < new_n_row; i++){
182+
I row_start = Ap[ir0+i];
183+
I row_end = Ap[ir0+i+1];
184+
185+
for(I jj = row_start; jj < row_end; jj++){
186+
if ((Aj[jj] >= ic0) && (Aj[jj] < ic1)) {
187+
(*Bj)[kk] = Aj[jj] - ic0;
188+
(*Bx)[kk] = Ax[jj];
189+
kk++;
190+
}
191+
}
192+
(*Bp)[i+1] = kk;
193+
}
194+
}
195+
```
196+
197+
# 四、对比分析
198+
199+
由于 PyTorch 并没有支持稀疏 Tensor 的 slice 操作,故我们只对 TensorFlow 和 SciPy 进行分析。
200+
201+
TensorFlow
202+
- 优点:实现了 COO 格式下对任意维度 slice 的操作
203+
- 缺点:仅支持 COO 格式
204+
205+
SciPy
206+
- 优点:实现了 CSR 格式下 slice 的操作
207+
- 缺点:
208+
- 仅提供 CSR 格式的 API,对于 COO 格式的 slice 操作,只能转换到 CSR 格式进行实现。
209+
- 只支持 2D 稀疏 Tensor 的 slice 操作
210+
211+
因此,我们可以在 TensorFlow 和 SciPy 的实现逻辑之上进行相应的改动,来实现我们所设置的功能目标。
212+
# 五、设计思路与实现方案
213+
214+
## 命名与参数设计
215+
216+
仿照 `DenseTensor` 中 slice kernel 的设计,在 `paddle/phi/kernels/sparse/cpu/slice_kernel.cc``paddle/phi/kernels/sparse/gpu/slice_kernel.cu` 中,前向 kernel 的设计为:
217+
```c++
218+
template <typename T, typename Context>
219+
void SliceCooKernel(const Context& dev_ctx,
220+
const SparseCooTensor& x,
221+
const std::vector<int64_t>& axes,
222+
const phi::IntArray& starts,
223+
const phi::IntArray& ends,
224+
SparseCooTensor* out);
225+
```
226+
```c++
227+
template <typename T, typename Context>
228+
void SliceCsrKernel(const Context& dev_ctx,
229+
const SparseCsrTensor& x,
230+
const std::vector<int64_t>& axes,
231+
const phi::IntArray& starts,
232+
const phi::IntArray& ends,
233+
SparseCsrTensor* out);
234+
```
235+
236+
`paddle/phi/kernels/sparse/cpu/slice_grad_kernel.cc``paddle/phi/kernels/sparse/gpu/slice_grad_kernel.cu` 中,反向 kernel 的设计为:
237+
```c++
238+
template <typename T, typename Context>
239+
void SliceCooGradKernel(const Context& dev_ctx,
240+
const SparseCooTensor& x,
241+
const SparseCooTensor& out_grad,
242+
const std::vector<int64_t>& axes,
243+
const phi::IntArray& starts,
244+
const phi::IntArray& ends,
245+
SparseCooTensor* x_grad);
246+
```
247+
```c++
248+
template <typename T, typename Context>
249+
void SliceCsrGradKernel(const Context& dev_ctx,
250+
const SparseCsrTensor& x,
251+
const SparseCsrTensor& out_grad,
252+
const std::vector<int64_t>& axes,
253+
const phi::IntArray& starts,
254+
const phi::IntArray& ends,
255+
SparseCsrTensor* x_grad);
256+
```
257+
258+
`paddle/phi/api/yaml/sparse_ops.yaml` 中新增对应 API:
259+
```yaml
260+
- op : slice
261+
args : (Tensor x, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis)
262+
output : Tensor(out)
263+
infer_meta :
264+
func : UnchangedInferMeta
265+
param: [x]
266+
kernel :
267+
func : slice_coo{sparse_coo -> sparse_coo},
268+
slice_csr{sparse_csr -> sparse_csr}
269+
layout: x
270+
backward : slice_grad
271+
```
272+
273+
`paddle/phi/api/yaml/sparse_backward.yaml` 中新增对应 API:
274+
```yaml
275+
- backward_op : slice_grad
276+
forward : slice (Tensor x, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis) -> Tensor(out)
277+
args : (Tensor x, Tensor out_grad, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis)
278+
output : Tensor(x_grad)
279+
infer_meta :
280+
func : UnchangedInferMeta
281+
param : [x]
282+
kernel :
283+
func : slice_coo_grad{sparse_coo, sparse_coo -> sparse_coo},
284+
slice_csr_grad{sparse_csr, sparse_csr -> sparse_csr}
285+
```
286+
## 底层OP设计
287+
288+
对于 COO 格式的 slice 操作,可以参考 TensorFlow 的方法,遍历每个非零元素,判断其位置在各维度上是否在 slice 的范围内。
289+
290+
对于 CSR 格式的 slice 操作,可以在 SciPy 的基础上添加对 3D 稀疏 Tensor 的 slice 操作。
291+
- 对于 2D 稀疏 Tensor,处理逻辑与 SciPy 相似
292+
- 对于 3D 稀疏 Tensor,可以先对第 0 维进行 slice,第 1 维和第 2 维的处理与 2D 稀疏 Tensor 的处理逻辑类似。
293+
294+
## API实现方案
295+
296+
预期 Paddle 调用 slice API 的形式为:
297+
```python
298+
paddle.sparse.slice(x, axes, starts, ends)
299+
```
300+
- **x** (Tensor) - 输入的稀疏 Tensor,支持 COO 和 CSR 格式
301+
- **axes** (list|tuple) - 需要进行 slice 操作的维度,如果是 CSR 格式的稀疏 Tensor,确保长度为 2 或 3
302+
- **starts** (list|tuple|Tensor) - 各维度上 slice 的起始位置,如果是 CSR 格式的稀疏 Tensor,确保长度为 2 或 3
303+
- **ends** (list|tule|Tensor) - 各维度上 slice 的结束位置,如果是 CSR 格式的稀疏 Tensor,确保长度为 2 或 3
304+
305+
我们会首先检查 **axes**, **starts** 与 **ends** 的合法性,再进行对应的 slice 操作。
306+
307+
# 六、测试和验收的考量
308+
309+
测试考虑的 case 以及验收标准如下:
310+
311+
| case | 验收标准|
312+
|------|-------|
313+
|axes, starts 和 ends 长度对比 | 对长度不相等的情况能进行报错,相等的情况能返回正确结果|
314+
|axes, starts 和 ends 对边界的处理 | 对超出边界的情况能进行报错,未超出边界的情况能返回正确结果|
315+
|axes, starts 和 ends 对负数的处理 | 能返回正确结果|
316+
|不同 shape, axes, starts 和 ends 下结果的正确性 | 能返回正确结果|
317+
318+
# 七、可行性分析和排期规划
319+
320+
方案主要自行实现核心算法,可行。具体规划为:
321+
322+
- 阶段一:实现 cpu 上的 API 功能开发,并通过测试
323+
- 阶段二:实现 gpu 上的 API 功能开发,并通过测试
324+
- 阶段三:书写该 API 的中英文档
325+
326+
# 八、影响面
327+
为独立新增op,对其他模块没有影响
328+
329+
# 名词解释
330+
331+
332+
# 附件及参考资料
333+
334+

0 commit comments

Comments
 (0)