
Commit 35d23af

[BIT] Fix fused_linear, fused_multi_head_attention doc (#7336)
* add the missing description of the fused_linear argument trans_x
* align the pseudocode with the fused_multi_head_attention implementation in fused_transformer.py
1 parent: 1ab8ef4 · commit: 35d23af

2 files changed, 14 insertions(+), 10 deletions(-)
docs/api/paddle/incubate/nn/functional/fused_linear_cn.rst

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ fused_linear
 - **x** (Tensor) – The input Tensor to be multiplied.
 - **weight** (Tensor) – The weight Tensor to be multiplied; its rank must be 2.
 - **bias** (Tensor, optional) – The bias Tensor. If it is None, no bias addition is performed; otherwise, the bias is added to the result of the matrix multiplication. The default value is None.
+- **trans_x** (bool, optional) - Whether to transpose the input Tensor before the multiplication. Default: False.
 - **transpose_weight** (bool, optional) - Whether to transpose the weight before the multiplication. Default: False.
 - **name** (str, optional) - For details, please refer to :ref:`api_guide_Name`. It generally does not need to be set; the default value is None.

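As a companion to the parameter list above, here is a minimal usage sketch. It is not part of the commit; it assumes the signature documented in this file (x, weight, bias, trans_x, transpose_weight, name) and a GPU build of Paddle whose fused kernel accepts these dtypes.

    # Hedged usage sketch (not from the commit). Assumes
    # fused_linear(x, weight, bias=None, trans_x=False, transpose_weight=False, name=None)
    # as documented above, and that the fused kernel is available on the current GPU.
    import paddle
    import paddle.incubate.nn.functional as incubate_f

    paddle.set_device("gpu")  # the fused op is documented as GPU-only

    x = paddle.randn([2, 8, 64], dtype="float16")       # [batch, seq, in_features]
    weight = paddle.randn([64, 128], dtype="float16")   # rank-2 weight, as required
    bias = paddle.randn([128], dtype="float16")

    # With both flags False (the documented defaults), this computes
    # matmul(x, weight) + bias in a single fused kernel.
    out = incubate_f.fused_linear(x, weight, bias=bias,
                                  trans_x=False, transpose_weight=False)
    print(out.shape)  # expected: [2, 8, 128]

Per the added description, trans_x=True would transpose the input Tensor before the multiplication, just as transpose_weight=True transposes the weight.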
docs/api/paddle/incubate/nn/functional/fused_multi_head_attention_cn.rst

Lines changed: 13 additions & 10 deletions
@@ -16,27 +16,30 @@ The fused_multi_head_attention operator currently only supports running on GPU; what it contains
 .. code-block:: ipython

     # pseudocode
+    residual = x
     if pre_layer_norm:
-        out = layer_norm(x)
-        out = linear(out) + qkv) + bias
+        out = layer_norm(x)
     else:
-        out = linear(x) + bias
+        out = x
+    out = matmul(out, qkv_weight) + qkv_bias
     out = transpose(out, perm=[2, 0, 3, 1, 4])
     # extract q, k and v from out.
-    q = out[0:1,::]
+    q = out[0:1,::] * (head_dim ** -0.5)
     k = out[1:2,::]
     v = out[2:3,::]
-    out = q * k^t
+    out = matmul(q, k, transpose_y=True)
     out = attn_mask + out
     out = softmax(out)
     out = dropout(out)
-    out = out * v
+    out = matmul(out, v)
     out = transpose(out, perm=[0, 2, 1, 3])
-    out = out_linear(out)
-    if pre_layer_norm:
-        out = x + dropout(linear_bias + out)
+    out = linear(out, bias=None)
+    if add_residual:
+        out = residual + dropout(out + linear_bias)
     else:
-        out = layer_norm(x + dropout(linear_bias + out))
+        out = dropout(out + linear_bias)
+    if not pre_layer_norm:
+        out = layer_norm(out)


 It is worth noting that in this API, the weights of q, k and v are stored together in one weight Tensor, with shape `[3, num_heads, head_dim, embed_dim]`,
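To make the updated pseudocode concrete, below is a plain NumPy reference sketch of the same data flow. It is an illustration only, not code from the commit or from fused_transformer.py: dropout is treated as identity, layer_norm omits the learnable scale and shift, the reshape back to [batch, seq, embed_dim] before the output projection is made explicit, and all helper and argument names are assumptions.

    # NumPy reference sketch of the pseudocode above (illustration only; helper
    # names, shapes, and the identity dropout are assumptions, not Paddle API).
    import numpy as np

    def softmax(a, axis=-1):
        a = a - a.max(axis=axis, keepdims=True)
        e = np.exp(a)
        return e / e.sum(axis=axis, keepdims=True)

    def layer_norm(a, eps=1e-5):
        # normalization only; the real op also applies learnable scale and shift
        mu = a.mean(axis=-1, keepdims=True)
        var = a.var(axis=-1, keepdims=True)
        return (a - mu) / np.sqrt(var + eps)

    def ref_fused_mha(x, qkv_weight, qkv_bias, linear_weight, linear_bias,
                      attn_mask, pre_layer_norm=False, add_residual=True):
        # x:             [batch, seq, embed_dim]
        # qkv_weight:    [3, num_heads, head_dim, embed_dim]  (as noted above)
        # qkv_bias:      [3, num_heads, head_dim]
        # linear_weight: [embed_dim, embed_dim], linear_bias: [embed_dim]
        # attn_mask:     broadcastable to [batch, num_heads, seq, seq]
        batch, seq, embed_dim = x.shape
        _, num_heads, head_dim, _ = qkv_weight.shape

        residual = x
        out = layer_norm(x) if pre_layer_norm else x

        # fused QKV projection -> [batch, seq, 3, num_heads, head_dim]
        out = np.einsum("bse,knhe->bsknh", out, qkv_weight) + qkv_bias
        out = out.transpose(2, 0, 3, 1, 4)         # [3, batch, num_heads, seq, head_dim]
        q = out[0] * head_dim ** -0.5              # pre-scaled query
        k, v = out[1], out[2]

        attn = np.einsum("bnsh,bnth->bnst", q, k)  # q @ k^T
        attn = softmax(attn + attn_mask)           # attention dropout omitted
        out = np.einsum("bnst,bnth->bnsh", attn, v)

        # merge heads (implicit in the pseudocode) and apply the output projection
        out = out.transpose(0, 2, 1, 3).reshape(batch, seq, num_heads * head_dim)
        out = out @ linear_weight                  # linear(out, bias=None)
        if add_residual:
            out = residual + (out + linear_bias)   # dropout omitted
        else:
            out = out + linear_bias
        if not pre_layer_norm:
            out = layer_norm(out)
        return out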
