Skip to content

Commit ea9ec5f

Browse files
committed
【Hackathon 4 No.83】为神经网络编译器 CINN 增加 resize 算子
1 parent f34a1ad commit ea9ec5f

File tree

2 files changed

+109
-2
lines changed

2 files changed

+109
-2
lines changed
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# CINN resize 设计文档
2+
|API名称 | resize |
3+
|---|---|
4+
| 提交作者<input type="checkbox" class="rowselector hidden"> | 无名 |
5+
| 提交时间<input type="checkbox" class="rowselector hidden"> | 2023-02-26 |
6+
| 版本号 | V0.0 |
7+
| 依赖飞桨版本<input type="checkbox" class="rowselector hidden"> | paddlepaddle-gpu==0.0 |
8+
| 文件名 | 20230226_cinn_api_design_resize_qlog.md<br> |
9+
10+
11+
12+
# 一、概述
13+
14+
## 1、相关背景
15+
CINN是一种在不改变模型代码的条件下加速飞桨模型运行速度的深度学习编译器。在对接上层框架时,编译器会将上层的框架算子进一步拆分为若干基础算子,这样做的目的一方面是为了减少算子开发的工作量,仅实现有限的基础算子便可以组合出大量的上层框架算子;另一方面便于算子融合技术在编译器中可以实现跨算子自动融合,减少最终执行时的kernel数目和访存开销,达到更好的性能。
16+
17+
为了丰富 CINN 的基础算子,本次任务计算增加 `resize` 算子。
18+
19+
## 2、名词解释
20+
NCHW :一种图的数据格式。N 指 Batch,C 指 Channel,H 指 Height,W 指 width。
21+
22+
## 3、功能目标
23+
实现 `resize` 算子,将输入图片通过指定插值方法调整为指定大小,输入图片应该是 4-D 张量,且形状为`[N, C, H, W]`,注意调整仅适用于H、W对应维度。
24+
25+
26+
## 4、意义
27+
实现 `resize` 算子,将能进一步完善CINN的基础算子库。
28+
29+
# 二、CINN现状
30+
CINN框架暂不支持 `resize` 算子,需要实现。
31+
32+
# 三、业内方案调研
33+
**TVM 的 `resize` 算子**
34+
35+
在 TVM 中,与本次任务将要实现的算子对应的是 `resize2d` 算子,核心代码如下:
36+
```c++
37+
Expr MakeResize2D(Expr data, Expr size, Expr roi, String layout, String method,
38+
String coordinate_transformation_mode, String rounding_method, double cubic_alpha,
39+
double cubic_exclude, double extrapolation_value, DataType out_dtype) {
40+
auto attrs = make_object<Resize2DAttrs>();
41+
attrs->layout = std::move(layout);
42+
attrs->method = std::move(method);
43+
attrs->coordinate_transformation_mode = coordinate_transformation_mode;
44+
attrs->rounding_method = rounding_method;
45+
attrs->cubic_alpha = cubic_alpha;
46+
attrs->cubic_exclude = cubic_exclude;
47+
attrs->extrapolation_value = extrapolation_value;
48+
attrs->out_dtype = out_dtype;
49+
static const Op& op = Op::Get("dyn.image.resize2d");
50+
return Call(op, {data, size, roi}, Attrs(attrs), {});
51+
}
52+
```
53+
代码在 C++ 侧通过 Call 调用 python 侧的 `resize2d` 实现,python 侧已经以 te 形式实现了算子。
54+
55+
[resize2d compute的核心代码](https://github.com/apache/tvm/blob/5e652c1a7aa173cec6f9e68207b410ad06b2fcec/python/tvm/topi/image/resize.py#L531)
56+
57+
58+
# 四、对比分析
59+
TVM 的 `resize2d` 算子实现详细,可作为参考。本次任务计划使用 custom call 实现 `resize` 算子,参考 [cholesky 算子的实现](https://github.com/PaddlePaddle/CINN/pull/1133)。
60+
61+
# 五、设计思路与实现方案
62+
63+
## 命名与参数设计
64+
**算子参数:**
65+
66+
| 类别 | 类型 | 名称 | Shape | 描述 |
67+
| :-------: | :---------: | :-------: | :------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------: |
68+
| Input | Tensor\<T\> | x | [N, C, in_H, in_W] | 输入张量 |
69+
| Attribute | vector<int> | out_shape | [out_H, out_W] | 调整后的张量大小,只需指定H、W两个维度上的值 |
70+
| Attribute | string | mode | - | 指定插值方法,可选项包括:<br>**NEAREST**(最近邻插值,选取H和W维上最近的值);<br>**BILINEAR**(双线性插值,选取H和W维上相邻四个点做线性插值).<br>默认值BILINEAR |
71+
| Output | Tensor\<T\> | out | [N, C, out_H, out_W] | 输出张量,数据类型与输入张量相同同 |
72+
73+
**支持的数据类型:**
74+
75+
`int32`、`float32`
76+
77+
## 底层OP设计
78+
在 `cinn/hlir/op/contrib` 中新增 `resize` 算子。
79+
```c++
80+
ir::Tensor Resize(const ir::Tensor &x,
81+
vector out_shape,
82+
std::string mode,
83+
std::string &output_name)
84+
```
85+
实现 `resize` 的 strategy:`StrategyForResize``InferDtypeForResize``InferShapeForResize`,并注册算子。
86+
## API实现方案
87+
- c++ 接口
88+
89+
`cinn/frontend` 中的 `NetBuild` 类中增加 `Resize` 函数。
90+
- python 接口
91+
92+
`cinn/pybind/frontend.cc` 中增加 `resize` 算子的接口。
93+
# 六、测试和验收的考量。
94+
`python/tests/ops/test_example_op.py` 中添加 `resize` 算子的测试。测试内容覆盖所有 resize 模式,数据类型。
95+
# 七、可行性分析和排期规划
96+
- 可行性分析
97+
CINN中已经实现了大量的基础算子,在现有的框架基础上能够很好地增加算子功能。
98+
- 排期规划
99+
2月27日 ~ 3月11日完成 API 的开发与调试。
100+
3月12日 ~ 3月19日完成测试代码的开发。
101+
# 八、影响面
102+
本次任务影响模块如下,
103+
`cinn\backends``cinn\frontend``cinn\hlir``cinn\pybind``cinn\runtime`
104+
均是在原模块内增加代码,不影响原模块的已有功能。
105+
# 附件及参考资料
106+
1. [CINN项目贡献指南](https://github.com/PaddlePaddle/CINN/pull/810)
107+
2. [CINN IR抽象语法树](https://github.com/PaddlePaddle/CINN/pull/775)
108+
3. [CINN IR DSL在C++的matmul写法例子](https://github.com/PaddlePaddle/CINN/blob/develop/tutorials/matmul.cc)
109+
4. [CINN算子开发示例:pool2d_grad算子](https://github.com/PaddlePaddle/CINN/pull/858)

rfcs/FastDeploy/20230226_tvm_for_FastDeploy.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,5 +44,3 @@
4444

4545

4646
# 五、影响面
47-
48-
暂无

0 commit comments

Comments
 (0)