Skip to content

Commit 6dafdc9

Browse files
authored
[AutoParallel] add locallayer doc (#7033)
* add locallayer doc * add locallayer doc * update example codes * update * update * update doc
1 parent 6704fff commit 6dafdc9

File tree

1 file changed

+106
-0
lines changed

1 file changed

+106
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
.. _cn_api_paddle_distributed_LocalLayer:
2+
3+
LocalLayer
4+
-------------------------------
5+
6+
.. py:class:: paddle.distributed.LocalLayer(out_dist_attrs)
7+
8+
LocalLayer 用于在分布式训练中实现局部计算操作。在自动并行训练中,某些操作(如带 mask 的 loss 计算、MoE 相关计算等)需要在每张卡上独立进行局部计算,而不是直接在全局分布式张量上计算。LocalLayer 通过自动处理张量转换,使得用户可以像编写单卡代码一样实现这些局部操作。
9+
10+
参数
11+
:::::::::
12+
13+
- **out_dist_attrs** (list[tuple[ProcessMesh, list[Placement]]]) - 指定输出张量的分布策略。每个元素是一个元组,包含:
14+
15+
- ProcessMesh: 计算设备网格,定义计算资源的拓扑结构
16+
- list[Placement]: 张量分布方式的列表,描述如何将局部计算结果转换回分布式张量
17+
18+
**代码示例**
19+
20+
.. code-block:: python
21+
22+
import paddle
23+
import paddle.distributed as dist
24+
from paddle.distributed import Placement, ProcessMesh, LocalLayer
25+
26+
class CustomLayer(dist.LocalLayer):
27+
def __init__(self, out_dist_attrs):
28+
super().__init__(out_dist_attrs)
29+
self.local_result = paddle.to_tensor(0.0)
30+
def forward(self, x):
31+
mask = paddle.zeros_like(x)
32+
if dist.get_rank() == 0:
33+
mask[1:3] = 1
34+
else:
35+
mask[4:7] = 1
36+
x = x * mask
37+
mask_sum = paddle.sum(x)
38+
mask_sum = mask_sum / mask.sum()
39+
self.local_result = mask_sum
40+
return mask_sum
41+
42+
dist.init_parallel_env()
43+
mesh = ProcessMesh([0, 1], dim_names=["x"])
44+
out_dist_attrs = [
45+
(mesh, [dist.Partial(dist.ReduceType.kRedSum)]),
46+
]
47+
48+
local_input = paddle.arange(0, 10, dtype='float32')
49+
local_input = local_input + dist.get_rank()
50+
input_dist = dist.auto_parallel.api.dtensor_from_local(
51+
local_input,
52+
mesh,
53+
[dist.Shard(0)]
54+
)
55+
56+
custom_layer = CustomLayer(out_dist_attrs)
57+
output_dist = custom_layer(input_dist)
58+
local_value = custom_layer.local_result
59+
60+
gathered_values = []
61+
dist.all_gather(gathered_values, local_value)
62+
print(f"[Rank 0] local_loss={gathered_values[0]}")
63+
# [Rank 0] local_loss=1.5
64+
print(f"[Rank 1] local_loss={gathered_values[1]}")
65+
# [Rank 1] local_loss=6.0
66+
print(f"global_loss (distributed)={output_dist}")
67+
# global_loss (distributed)=7.5
68+
69+
70+
方法
71+
:::::::::
72+
73+
__call__()
74+
'''''''''
75+
76+
执行局部计算的核心方法。该方法会:
77+
78+
1. 将输入的分布式张量转换为本地张量
79+
2. 在本地执行前向计算
80+
3. 将计算结果按照指定的分布策略转换回分布式张量
81+
82+
**参数**
83+
84+
- **inputs** (Any) - 输入张量,通常是分布式张量
85+
- **kwargs** (Any) - 额外的关键字参数
86+
87+
**返回**
88+
89+
按照 out_dist_attrs 指定的分布策略转换后的分布式张量
90+
91+
**使用场景**
92+
93+
LocalLayer 可以用于但不限于以下场景:
94+
95+
1. 带 mask 的 loss 计算:需要在每张卡上独立计算 masked token 的 loss
96+
2. MoE (混合专家模型)相关计算:
97+
- aux_loss 计算:基于每张卡上专家分配到的局部 token 数进行计算
98+
- z_loss 计算:对每张卡上的 logits 独立计算 z_loss
99+
- 张量 reshape 操作:在局部维度上进行 shape 变换
100+
3. 其他需要保持局部计算语义的场景
101+
102+
**注意事项**
103+
104+
1. LocalLayer 的输出必须指定正确的分布策略,以确保结果的正确性
105+
2. 在 forward 方法中编写计算逻辑时,可以像单卡编程一样使用常规的 tensor 操作
106+
3. 局部计算结果会自动根据分布策略进行聚合,无需手动添加通信操作

0 commit comments

Comments
 (0)