diff --git a/docs/api/paddle/incubate/nn/functional/fused_linear_cn.rst b/docs/api/paddle/incubate/nn/functional/fused_linear_cn.rst
index b3843626de4..d5322691d2d 100644
--- a/docs/api/paddle/incubate/nn/functional/fused_linear_cn.rst
+++ b/docs/api/paddle/incubate/nn/functional/fused_linear_cn.rst
@@ -12,6 +12,7 @@ fused_linear
     - **x** (Tensor) - The input Tensor to be multiplied.
     - **weight** (Tensor) - The weight Tensor to be multiplied. Its rank must be 2.
     - **bias** (Tensor, optional) - The input bias Tensor. If it is None, no bias addition is performed; otherwise, the bias is added to the result of the matrix multiplication. Default: None.
+    - **trans_x** (bool, optional) - Whether to transpose the input Tensor before the multiplication. Default: False.
     - **transpose_weight** (bool, optional) - Whether to transpose the weight before the multiplication. Default: False.
     - **name** (str, optional) - For details, please refer to :ref:`api_guide_Name`. Generally there is no need to set it. Default: None.
 
diff --git a/docs/api/paddle/incubate/nn/functional/fused_multi_head_attention_cn.rst b/docs/api/paddle/incubate/nn/functional/fused_multi_head_attention_cn.rst
index 741e1421a97..aefe992da37 100644
--- a/docs/api/paddle/incubate/nn/functional/fused_multi_head_attention_cn.rst
+++ b/docs/api/paddle/incubate/nn/functional/fused_multi_head_attention_cn.rst
@@ -16,27 +16,30 @@ The fused_multi_head_attention operator currently only supports running on GPU, and the computation it contains
 
 .. code-block:: ipython
 
     # pseudocode
+    residual = x
     if pre_layer_norm:
-        out = layer_norm(x)
-        out = linear(out) + qkv) + bias
+        out = layer_norm(x)
     else:
-        out = linear(x) + bias
+        out = x
+    out = matmul(out, qkv_weight) + qkv_bias
     out = transpose(out, perm=[2, 0, 3, 1, 4])
     # extract q, k and v from out.
-    q = out[0:1,::]
+    q = out[0:1,::] * (head_dim ** -0.5)
     k = out[1:2,::]
     v = out[2:3,::]
-    out = q * k^t
+    out = matmul(q, k, transpose_y=True)
     out = attn_mask + out
     out = softmax(out)
     out = dropout(out)
-    out = out * v
+    out = matmul(out, v)
     out = transpose(out, perm=[0, 2, 1, 3])
-    out = out_linear(out)
-    if pre_layer_norm:
-        out = x + dropout(linear_bias + out)
+    out = linear(out, bias=None)
+    if add_residual:
+        out = residual + dropout(out + linear_bias)
     else:
-        out = layer_norm(x + dropout(linear_bias + out))
+        out = dropout(out + linear_bias)
+    if not pre_layer_norm:
+        out = layer_norm(out)
 
 Note that in this API, the weights of q, k and v are stored together in a single weight Tensor, whose shape is `[3, num_heads, head_dim, embed_dim]`,
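
For the `fused_linear` change above, a minimal usage sketch may help illustrate the documented parameters. It assumes a CUDA-enabled Paddle build and that `paddle.incubate.nn.functional.fused_linear` accepts the newly documented `trans_x` flag alongside `transpose_weight`; the shapes and dtype below are illustrative only.

.. code-block:: python

    # Hedged sketch: assumes fused_linear exposes trans_x as documented above and
    # that a CUDA-enabled Paddle build is available (the fused kernel is GPU-only).
    import paddle
    import paddle.incubate.nn.functional as F

    x = paddle.randn([16, 64]).astype('float16')        # input, [batch, in_features]
    weight = paddle.randn([64, 128]).astype('float16')  # rank-2 weight, [in_features, out_features]
    bias = paddle.randn([128]).astype('float16')        # optional bias added to the matmul result

    # Equivalent to matmul(x, weight) + bias; trans_x / transpose_weight would
    # transpose the corresponding operand before the multiplication.
    out = F.fused_linear(x, weight, bias, trans_x=False, transpose_weight=False)
    print(out.shape)  # [16, 128]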
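
The updated pseudocode in the second hunk can also be read as a runnable, unfused reference written with plain Paddle ops. The sketch below follows that pseudocode and the `[3, num_heads, head_dim, embed_dim]` qkv weight layout mentioned after it; the function name `naive_fused_mha`, the `linear_weight` / `linear_bias` names, and the assumed `[embed_dim, embed_dim]` output-projection shape are illustrative assumptions, not the fused API's actual signature.

.. code-block:: python

    # Unfused reference sketch of the updated pseudocode above, using plain Paddle ops.
    # Variable names and shapes (qkv_weight: [3, num_heads, head_dim, embed_dim],
    # linear_weight: [embed_dim, embed_dim], add_residual, ...) mirror the pseudocode
    # and are assumptions for illustration, not the fused operator's argument layout.
    import paddle
    import paddle.nn.functional as F

    def naive_fused_mha(x, qkv_weight, qkv_bias, linear_weight, linear_bias,
                        attn_mask, num_heads, head_dim,
                        pre_layer_norm=False, add_residual=True, dropout_p=0.0):
        batch, seq_len, embed_dim = x.shape
        residual = x
        out = F.layer_norm(x, [embed_dim]) if pre_layer_norm else x

        # qkv projection: fold [3, num_heads, head_dim, embed_dim] into a 2-D matmul
        w = qkv_weight.reshape([3 * num_heads * head_dim, embed_dim]).transpose([1, 0])
        out = paddle.matmul(out, w) + qkv_bias.reshape([-1])
        out = out.reshape([batch, seq_len, 3, num_heads, head_dim])
        out = out.transpose([2, 0, 3, 1, 4])          # [3, batch, num_heads, seq_len, head_dim]

        q = out[0] * (head_dim ** -0.5)               # pre-scaled query, as in the new pseudocode
        k, v = out[1], out[2]

        attn = paddle.matmul(q, k, transpose_y=True)  # [batch, num_heads, seq_len, seq_len]
        attn = F.softmax(attn + attn_mask, axis=-1)
        attn = F.dropout(attn, p=dropout_p)

        out = paddle.matmul(attn, v)                  # [batch, num_heads, seq_len, head_dim]
        out = out.transpose([0, 2, 1, 3]).reshape([batch, seq_len, embed_dim])

        # output projection, then dropout / residual / (post) layer_norm
        out = paddle.matmul(out, linear_weight)
        out = F.dropout(out + linear_bias, p=dropout_p)
        if add_residual:
            out = residual + out
        if not pre_layer_norm:
            out = F.layer_norm(out, [embed_dim])
        return out

With random inputs of matching shapes and dropout disabled, this reference could be compared numerically against `paddle.incubate.nn.functional.fused_multi_head_attention` to sanity-check the documented computation.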