Skip to content

在scaled_dot_product_attention函数中加入bool mask #72927

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 6, 2025
6 changes: 6 additions & 0 deletions python/paddle/nn/functional/flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1423,6 +1423,12 @@ def scaled_dot_product_attention(
sdp_func_name = _select_sdp_for_sdpa(
query, key, attn_mask, dropout_p, is_causal
)
if attn_mask.dtype == paddle.bool:
attn_mask = paddle.where(
attn_mask,
paddle.to_tensor(0.0, dtype=query.dtype),
paddle.to_tensor(-float('inf'), dtype=query.dtype),
)
if sdp_func_name == "flash_attn":
if in_dynamic_or_pir_mode():
fixed_seed_offset = None
Expand Down
185 changes: 185 additions & 0 deletions test/legacy_test/test_scaled_dot_product_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
import paddle.nn.functional as F
from paddle.nn.functional.flash_attention import (
scaled_dot_product_attention,
sdp_kernel,
)


def attention_naive(q, k, v, causal=False):
    """Reference attention built from plain transpose/matmul/softmax ops.

    Axes 1 and 2 of ``q``, ``k`` and ``v`` are swapped before the score
    matmul and swapped back on the result. ``q`` is multiplied by
    ``1 / sqrt(q.shape[-1])`` before the scores are computed.

    Args:
        q, k, v: 4-D tensors accepted by ``paddle.transpose`` with a
            four-element permutation.
        causal: when true, the scores go through
            ``paddle.incubate.softmax_mask_fuse_upper_triangle`` instead
            of a plain softmax.

    Returns:
        The attention output, with axes 1 and 2 swapped back.
    """
    swap_1_2 = [0, 2, 1, 3]
    q_t = paddle.transpose(q, swap_1_2)
    k_t = paddle.transpose(k, swap_1_2)
    v_t = paddle.transpose(v, swap_1_2)

    inv_sqrt_dim = 1.0 / np.sqrt(q.shape[-1])
    # Keep the tensor on the left of `*` so the Tensor operator is used.
    scores = paddle.matmul(
        q_t * inv_sqrt_dim, paddle.transpose(k_t, [0, 1, 3, 2])
    )

    if causal:
        probs = paddle.incubate.softmax_mask_fuse_upper_triangle(scores)
    else:
        probs = F.softmax(scores)

    return paddle.transpose(paddle.matmul(probs, v_t), swap_1_2)


def attention_naive_with_mask(q, k, v, attn_bias):
    """Reference attention with an additive bias on the scaled scores.

    Axes 1 and 2 of ``q``, ``k`` and ``v`` are swapped before the score
    matmul and swapped back on the result. The raw scores are scaled by
    ``1 / sqrt(q.shape[-1])``, ``attn_bias`` is added, and a softmax is
    applied.

    Args:
        q, k, v: 4-D tensors accepted by ``paddle.transpose`` with a
            four-element permutation.
        attn_bias: value added to the scaled scores before the softmax.

    Returns:
        The attention output, with axes 1 and 2 swapped back.
    """
    swap_1_2 = [0, 2, 1, 3]
    q_t, k_t, v_t = (paddle.transpose(t, swap_1_2) for t in (q, k, v))

    inv_sqrt_dim = 1.0 / np.sqrt(q.shape[-1])
    raw_scores = paddle.matmul(q_t, paddle.transpose(k_t, [0, 1, 3, 2]))
    scaled_scores = paddle.scale(raw_scores, inv_sqrt_dim)

    probs = F.softmax(scaled_scores + attn_bias)
    weighted = paddle.matmul(probs, v_t)
    return paddle.transpose(weighted, swap_1_2)


def attention_naive_with_bool_mask(q, k, v, bool_mask):
    """Reference attention driven by a boolean mask.

    ``bool_mask`` is converted to an additive mask in ``q.dtype``: ``0.0``
    where the mask is true and ``-inf`` where it is false. That mask is
    added to the scaled scores before the softmax.

    Args:
        q, k, v: 4-D tensors accepted by ``paddle.transpose`` with a
            four-element permutation.
        bool_mask: boolean tensor used as the condition of
            ``paddle.where``.

    Returns:
        The attention output, with axes 1 and 2 swapped back.
    """
    swap_1_2 = [0, 2, 1, 3]
    q_t = paddle.transpose(q, swap_1_2)
    k_t = paddle.transpose(k, swap_1_2)
    v_t = paddle.transpose(v, swap_1_2)

    inv_sqrt_dim = 1.0 / np.sqrt(q.shape[-1])
    scores = paddle.scale(
        paddle.matmul(q_t, paddle.transpose(k_t, [0, 1, 3, 2])),
        inv_sqrt_dim,
    )

    # True -> 0.0 (keep), False -> -inf (drop after softmax).
    keep_value = paddle.to_tensor(0.0, dtype=q.dtype)
    drop_value = paddle.to_tensor(-float('inf'), dtype=q.dtype)
    additive_mask = paddle.where(bool_mask, keep_value, drop_value)

    probs = F.softmax(scores + additive_mask)
    return paddle.transpose(paddle.matmul(probs, v_t), swap_1_2)


@unittest.skipIf(
    not paddle.is_compiled_with_cuda(),
    "CUDA is not available, this test requires GPU support.",
)
class TestAttentionWithBoolMask(unittest.TestCase):
    """Compare the math SDPA kernel with naive references for bool/float masks."""

    RTOL = 5e-03
    ATOL = 1e-03

    def setUp(self):
        # Local, seeded generator: failures are reproducible and the global
        # numpy RNG state is left untouched.
        self.rng = np.random.RandomState(2025)
        self.place = paddle.CUDAPlace(0)
        # The naive references swap axes 1 and 2 before the score matmul, so
        # axis 1 is the sequence axis. It must be > 1, otherwise the mask
        # built from shape[1] is a single element and masking is not tested.
        self.shape = (1, 8, 2, 8)
        self.dtype = 'float32'
        self.dropout = 0.0
        self.causal = False

    def _leaf(self, array, dtype=None):
        """Create a tensor on the test device that takes part in autograd."""
        return paddle.to_tensor(
            array,
            place=self.place,
            dtype=self.dtype if dtype is None else dtype,
            stop_gradient=False,
        )

    def _make_qkv_pairs(self):
        """Return two independent (q, k, v) leaf sets built from the same data."""
        arrays = [self.rng.random_sample(self.shape) for _ in range(3)]
        actual = [self._leaf(a) for a in arrays]
        expected = [self._leaf(a) for a in arrays]
        return actual, expected

    def _mask_shape(self):
        """Shape of a mask broadcast over axis 1 of the score tensor."""
        seq_len = self.shape[1]
        return (self.shape[0], 1, seq_len, seq_len)

    def _run_sdpa(self, q, k, v, mask):
        """Call scaled_dot_product_attention with only the math kernel enabled."""
        with sdp_kernel(
            enable_math=True, enable_flash=False, enable_mem_efficient=False
        ):
            return scaled_dot_product_attention(
                q, k, v, mask, self.dropout, self.causal
            )

    def _assert_matches(self, out, ref, actual_inputs, ref_inputs):
        """Compare forward outputs and input gradients of the two graphs."""
        out.backward()
        ref.backward()

        # equal_nan=False: a NaN on both sides must fail, not silently pass.
        np.testing.assert_allclose(
            out.numpy(),
            ref.numpy(),
            rtol=self.RTOL,
            atol=self.ATOL,
            equal_nan=False,
        )
        for actual, expected in zip(actual_inputs, ref_inputs):
            np.testing.assert_allclose(
                actual.grad.numpy(),
                expected.grad.numpy(),
                rtol=self.RTOL,
                atol=self.ATOL,
                equal_nan=False,
            )

    def test_dot_scale_product_bool_mask(self):
        # test dynamic
        paddle.disable_static()

        (q, k, v), (q_, k_, v_) = self._make_qkv_pairs()

        bool_mask = self.rng.choice([True, False], size=self._mask_shape())
        # A query row with every key masked becomes a softmax over all -inf,
        # i.e. NaN in both implementations. Keep the diagonal attendable so
        # every row has at least one valid key.
        bool_mask |= np.eye(self.shape[1], dtype=bool)
        m = self._leaf(bool_mask, dtype=paddle.bool)

        out = self._run_sdpa(q, k, v, m)
        out_ = attention_naive_with_bool_mask(q_, k_, v_, m)

        self._assert_matches(out, out_, (q, k, v), (q_, k_, v_))

    def test_dot_scale_product_float_mask(self):
        # test with mask=float
        paddle.disable_static()

        (q, k, v), (q_, k_, v_) = self._make_qkv_pairs()

        mask = self.rng.random_sample(self._mask_shape())
        m = self._leaf(mask)

        out = self._run_sdpa(q, k, v, m)
        out_ = attention_naive_with_mask(q_, k_, v_, m)

        self._assert_matches(out, out_, (q, k, v), (q_, k_, v_))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

请在本地 PaConvert 的 sdpa 单测中增加 attn_mask 为 bool 类型的测试用例,验证计算结果是否与 PyTorch 一致,并附上 PaConvert 的测试结果。

然后映射文档也记得修改下

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

嗯,已修改。之前用的是python3.8虚拟环境,升级python3.9重装环境花费了一些时间。
PaddlePaddle/PaConvert#586



# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
Loading