Skip to content

Commit 9106243

Browse files
authored
add method introduction for op-perf (#4937)
* add method introduction for op-perf * add some changes * add some changes * change according to the advices
1 parent 32c967b commit 9106243

6 files changed

+119
-4
lines changed
161 KB
Loading
237 KB
Loading
456 KB
Loading

docs/dev_guides/op_optimization/op_optimization_accpetance_criteria_cn.md

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,34 @@
88

99
## 性能测试
1010

11-
性能测试建议采用OP Benchmark测试算子性能。经过性能优化,[OP Benchmark](https://github.com/PaddlePaddle/benchmark/tree/master/api)中全部case不能出现性能下降,需要通过列表,对比性能优化前后的OP性能情况。
11+
[OP Benchmark](https://github.com/PaddlePaddle/benchmark/tree/master/api)作为一套测试飞桨内算子性能的专业工具, 如下图所示能够输出各类case下的OP性能真实状态, 建议用其进行算子性能测试。经过性能优化,OP Benchmark中全部case不能出现性能下降,需要通过列表,对比性能优化前后的OP性能情况。
12+
13+
```
14+
===========================================================================
15+
-- paddle version : 0.0.0
16+
-- paddle commit : 9b7126d05987b725ad3fb31f31298218c860b2f5
17+
-- benchmark commit : a6ba32197d7b3adb1dcc95b803f8f0d7fa18322c
18+
-- benchmark last update time : Wed Jun 15 02:45:49 2022 +0000
19+
===========================================================================
20+
run command: nvprof --profile-from-start off /work/.virtualenvs_cuda10.2/paddle_py38/bin/python /work/benchmark/api/dynamic_tests_v2/adaptive_avg_pool2d.py --api_name adaptive_avg_pool2d --task speed --framework paddle --testing_mode dynamic --json_file /work/benchmark/api/tests_v2/configs/adaptive_avg_pool2d.json --config_id 0 --backward False --use_gpu True --repeat 1000 --allow_adaptive_repeat True --profiler nvprof
21+
Type Time(%) Time Calls Avg Min Max Name
22+
GPU activities: 100.00% 437.44ms 1000 437.44us 422.98us 472.29us void phi::funcs::KernelPool2D<phi::funcs::AvgPool<float>, float>(int, phi::funcs::AvgPool<float> const *, int, int, int, int, int, int, int, int, int, int, int, phi::funcs::FastDivModForPooling, float, bool, bool, phi::funcs::KernelPool2D<phi::funcs::AvgPool<float>, float>*, bool)
23+
24+
total gpu_time: 437.4400 ms
25+
26+
W0615 14:55:43.819144 28877 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.4, Runtime API Version: 10.2 , cuDNN Version: 7.6.
27+
28+
[paddle][adaptive_avg_pool2d] adaptive_avg_pool2d {
29+
run_tf: True
30+
run_torch: True
31+
data_format: NCHW
32+
output_size: [32, 32]
33+
x_shape: [4, 2048, 64, 128]
34+
x_dtype: float32
35+
atol: 1e-06
36+
}
37+
{"framework": "paddle", "version": "0.0.0", "name": "adaptive_avg_pool2d", "device": "GPU", "backward": false, "speed": {"repeat": 1000, "begin": 10, "end": 990, "total": 0.6467142883612185, "wall_time": 0, "total_include_wall_time": 0.6467142883612185, "gpu_time": 0.43744}, "parameters": "x (Variable) - dtype: float32, shape: [4, 2048, 64, 128]\ndata_format (string): NCHW\noutput_size (list): [32, 32]\n"}
38+
```
1239

1340
## PR内容描述要求
1441

@@ -20,7 +47,7 @@
2047

2148
3. PR性能优化方案概述
2249

23-
4. 性能优化对比表格
50+
4. 优化前后算子性能对比表格
2451

2552
## OP测试内容及单元测试要求
2653

docs/dev_guides/op_optimization/op_optimization_contributing_guides_cn.rst

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
::::::::::::::::::::::
1010

1111
.. image:: ../images/op_optimization_contributing_guides.png
12-
:width: 1000
12+
:width: 900
1313
:alt: op_optimization_contributing_guides
1414
:align: center
1515

@@ -65,7 +65,9 @@
6565
"算子性能优化实现代码", "- `Paddle代码规范 <https://www.paddlepaddle.org.cn/documentation/docs/zh/dev_guides/style_guides_cn.html>`_
6666
- `C++ OP开发指南 <../api_contributing_guides/new_cpp_op_cn.html>`_
6767
- `OP Benchmark使用指南 <https://github.com/PaddlePaddle/benchmark/blob/master/api>`_
68-
- `算子性能优化 验收规范 <./op_optimization_accpetance_criteria_cn.html>`_ ", "`Github飞桨训练框架仓库 <https://github.com/PaddlePaddle/Paddle>`_"
68+
- `算子性能优化 优化方法 <./op_optimization_method_introduction_cn.html>`_
69+
- `算子性能优化 验收规范 <./op_optimization_accpetance_criteria_cn.html>`_
70+
", "`Github飞桨训练框架仓库 <https://github.com/PaddlePaddle/Paddle>`_"
6971

7072

7173
当你完成以上代码设计后,需要将代码提交至 `Github飞桨训练框架仓库 <https://github.com/PaddlePaddle/Paddle>`_ ,并根据 `本地开发指南 <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/dev_guides/git_guides/local_dev_guide_cn.html>`_ 提交PR、准备接受社区的评审。
@@ -95,4 +97,5 @@
9597
.. toctree::
9698
:hidden:
9799

100+
op_optimization_method_introduction_cn.md
98101
op_optimization_accpetance_criteria_cn.md
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# 算子性能优化 方法介绍
2+
3+
提供高性能的计算服务是飞桨的特色之一, 欢迎开发者为飞桨贡献高性能算子, 本文旨在向开发者提供一些快速实现高性能算子的方法。
4+
5+
# 基本介绍
6+
7+
- 算子性能优化工作的业务范围涵盖前向算子、反向算子、优化器等.
8+
9+
- 算子性能优化工作的基本目标是获得明显的算子性能提升, 力争达到业界一流的性能水平, 同时保证精度不会下降.
10+
11+
- 飞桨内算子性能优化主要围绕GPU计算开展, 因此需要用户掌握基本的[GPU编程模型](https://developer.nvidia.com/zh-cn/blog/cuda-model-intro-cn/).
12+
13+
14+
# 优化技巧
15+
16+
## 1.通用优化技巧
17+
18+
GPU Kernel直接影响了算子性能, 我们推荐采用以下等通用优化策略提升GPU Kernel的性能, 从而削减算子的计算开销.
19+
20+
| 通用技巧 |
21+
| -- |
22+
| [向量化读写](https://developer.nvidia.com/blog/cuda-pro-tip-increase-performance-with-vectorized-memory-access>)|
23+
| [协线程操作](https://developer.nvidia.com/blog/cooperative-groups/>) |
24+
| [Warp级操作](https://developer.nvidia.com/blog/using-cuda-warp-level-primitives>) |
25+
| [共享内存操作](<https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/>) ([注意Bank Conflicts](https://developer.nvidia.com/blog/using-shared-memory-cuda-cc/)) |
26+
27+
28+
## 2. 飞桨内置优化技巧
29+
30+
我们在飞桨内开发并封装了一些优化技巧, 具体如下表所示, 欢迎使用, 也欢迎在使用过程中提出修改建议.
31+
32+
### 2.1 [线程配置优化](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/backends/gpu/gpu_launch_config.h)
33+
34+
我们推荐结合OP的使用场景设计对于的线程配置策略,如下图所示[IndexSample OP](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/index_sample_cn.html#index-sample)常用于处理2维数据, 因此使用[2维的线程配置策略](https://github.com/PaddlePaddle/Paddle/blob/30838aa698d6f3f3b0860b052f6a50ef53ac6784/paddle/phi/kernels/gpu/index_sample_kernel.cu#L82-L91)相对比1维配置策略,性能可提升20%左右。
35+
36+
<figure align="center">
37+
<img src="../images/index_sample.png" width=80% height=80%/>
38+
<figcaption><center>图1. IndexSample OP 线程配置策略</center></figcaption>
39+
</figure>
40+
41+
优化GPU Kernel中的线程配置策略, 涵盖一维、二维、三维线程配置策略, 目前已经在`Elementwise`, `Stack`, `IndexSample`等OP中使用.
42+
43+
### 2.2 [Warp计算优化](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/kernels/funcs/math_cuda_utils.h)
44+
45+
飞桨内对上文中提到的**Warp级操作**进行了封装, 提供了简易的调用接口, 开发者可调用接口快速获得Warp内或者Block内的全部数据的求和、最大值、最小值.
46+
47+
<figure align="center">
48+
<img src="../images/cuda_math_utils.png" width=80% height=80%/>
49+
<figcaption><center>图2. Warp级操作封装</center></figcaption>
50+
</figure>
51+
52+
### 2.3 [索引计算优化](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/platform/fast_divmod.h):
53+
54+
当GPU Kernel的索引计算中存在除法或取模操作, 将在导致汇编层面计算开销变大, 我们建议采用快速除法优化这部分的计算开销。飞桨内[Pooling OP](https://github.com/PaddlePaddle/Paddle/blob/890c73158f663b327be7664ed6c4d08fb2c236a9/paddle/phi/kernels/funcs/pooling.cu#L41-L101) 采用索引优化计算后, 性能提升1倍.
55+
56+
<figure align="center">
57+
<img src="../images/fast_divmod.png" width=50% height=50%/>
58+
<figcaption><center>图3. 快速整型除法操作</center></figcaption>
59+
</figure>
60+
61+
### 2.4 [Kps优化工具库](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/dev_guides/kernel_primitive_api/index_cn.html)
62+
63+
飞桨综合了一系列GPU Kernel通用性能优化技巧推出了Kernel Primitive API,提供高性能的 Block 级 IO 运算和 Compute 运算。使用 Kernel Primitive API 进行 Kernel 开发可以更加专注计算逻辑的实现,在保证性能的同时大幅减少代码量,同时实现了算子计算与硬件解耦,详情见官网[Kernel Primitive API](https://www.paddlepaddle.org.cn/documentation/docs/zh/dev_guides/kernel_primitive_api/index_cn.html), 建议参考案例[ElementwiseAdd](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/dev_guides/kernel_primitive_api/add_example_cn.html)[Reduce](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/dev_guides/kernel_primitive_api/reduce_example_cn.html) 使用。
64+
65+
66+
### 3. C++模板特性
67+
68+
我们也鼓励充分挖掘C++侧的可用优化点, 如使用`#pragma unroll`编译阶段加速指令,编译期自动展循环, 加速运行时循环的执行效率.
69+
70+
- 案例: [Elementwise_add OP](https://github.com/PaddlePaddle/Paddle/blob/30838aa698d6f3f3b0860b052f6a50ef53ac6784/paddle/phi/kernels/funcs/elementwise_base.h#L658-L661) 采用模板参数加速循环展开, 性能提升约5%
71+
72+
```
73+
struct SameDimsElementwisePrimitiveCaller {
74+
__device__ inline void operator()(Functor func, ArgsT *args, OutT *result) {
75+
#pragma unroll
76+
for (int idx = 0; idx < VecSize; ++idx) {
77+
result[idx] = static_cast<OutT>(Apply(func, args[idx]));
78+
}
79+
}
80+
};
81+
```
82+
83+
### 4. 内置第三方库
84+
85+
飞桨内置了cuBLAS, cuDNN, cuSOLVER, Thrust等一系列第三方库, 若采用这些第三方等高性能计算库能获得显著的性能收益,也欢迎使用。cuBLAS使用示例见[matmul_kernel_impl.h](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/kernels/impl/matmul_kernel_impl.h), cuDNN的使用示例见[conv_kernel.cu](https://github.com/PaddlePaddle/Paddle/blob/30838aa698d6f3f3b0860b052f6a50ef53ac6784/paddle/phi/kernels/gpudnn/conv_kernel.cu#L366-L379), cuSOLVER使用示例见[values_vectors_functor.h](https://github.com/PaddlePaddle/Paddle/blob/30838aa698d6f3f3b0860b052f6a50ef53ac6784/paddle/phi/kernels/funcs/values_vectors_functor.h#L219-L260), Thrust使用示例见[coalesced_kernel.cu](https://github.com/PaddlePaddle/Paddle/blob/30838aa698d6f3f3b0860b052f6a50ef53ac6784/paddle/phi/kernels/sparse/gpu/coalesced_kernel.cu#L93-L106).

0 commit comments

Comments
 (0)