
Commit caa5621

Add bool mask support to the scaled_dot_product_attention function (#72927)
* added bool mask in attention
  modified: python/paddle/nn/functional/flash_attention.py
  modified: test/legacy_test/test_flash_attention.py
* updated tests
  modified: test/legacy_test/test_flash_attention.py
* added a new test
  modified: test/legacy_test/test_flash_attention.py
  new file: test/legacy_test/test_scaled_dot_product_attention.py
* update for test
  modified: test/legacy_test/test_scaled_dot_product_attention.py
* updated test para again
  modified: test/legacy_test/test_scaled_dot_product_attention.py
* updated for tests
  modified: python/paddle/nn/functional/flash_attention.py
  modified: test/legacy_test/test_scaled_dot_product_attention.py
* updated for DCU test
  modified: test/legacy_test/test_scaled_dot_product_attention.py
* updated func
  modified: python/paddle/nn/functional/flash_attention.py
* updated func
  modified: python/paddle/nn/functional/flash_attention.py
* updated test
  modified: test/legacy_test/test_scaled_dot_product_attention.py
* updated test
  modified: test/legacy_test/test_scaled_dot_product_attention.py
1 parent 37d2e1f commit caa5621
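
For context, after this change scaled_dot_product_attention accepts a boolean attn_mask directly (True = may attend, False = masked out) and converts it to an additive float mask internally. A minimal usage sketch, assuming a CUDA build of Paddle; the shapes and tensor values below are illustrative only and are not taken from the commit:

import paddle
from paddle.nn.functional.flash_attention import scaled_dot_product_attention

# Layout is [batch, seq_len, num_heads, head_dim], as in the new test below.
q = paddle.rand([1, 8, 2, 16], dtype='float32')
k = paddle.rand([1, 8, 2, 16], dtype='float32')
v = paddle.rand([1, 8, 2, 16], dtype='float32')

# Boolean mask broadcastable to the [batch, num_heads, seq_len, seq_len]
# attention scores: True keeps a position, False masks it out.
mask = paddle.ones([1, 1, 8, 8], dtype='bool')

out = scaled_dot_product_attention(
    q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
)
print(out.shape)  # [1, 8, 2, 16]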

File tree

2 files changed: +191 -0 lines changed

python/paddle/nn/functional/flash_attention.py

Lines changed: 6 additions & 0 deletions
@@ -1423,6 +1423,12 @@ def scaled_dot_product_attention(
         sdp_func_name = _select_sdp_for_sdpa(
             query, key, attn_mask, dropout_p, is_causal
         )
+        if attn_mask.dtype == paddle.bool:
+            attn_mask = paddle.where(
+                attn_mask,
+                paddle.to_tensor(0.0, dtype=query.dtype),
+                paddle.to_tensor(-float('inf'), dtype=query.dtype),
+            )
         if sdp_func_name == "flash_attn":
             if in_dynamic_or_pir_mode():
                 fixed_seed_offset = None
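
The added branch folds a boolean attn_mask into the additive float mask form the attention computation expects: True positions contribute 0.0 and False positions contribute -inf, so softmax assigns masked positions zero weight. A standalone sketch of that conversion with a hypothetical toy mask (not part of the diff):

import paddle

bool_mask = paddle.to_tensor([[True, False], [True, True]])
additive_mask = paddle.where(
    bool_mask,
    paddle.to_tensor(0.0, dtype='float32'),
    paddle.to_tensor(-float('inf'), dtype='float32'),
)
# additive_mask:
# [[ 0., -inf],
#  [ 0.,   0.]]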
test/legacy_test/test_scaled_dot_product_attention.py (new file)

Lines changed: 185 additions & 0 deletions
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
import paddle.nn.functional as F
from paddle.nn.functional.flash_attention import (
    scaled_dot_product_attention,
    sdp_kernel,
)


def attention_naive(q, k, v, causal=False):
    qt = paddle.transpose(q, [0, 2, 1, 3])
    kt = paddle.transpose(k, [0, 2, 1, 3])
    vt = paddle.transpose(v, [0, 2, 1, 3])
    scale = 1.0 / np.sqrt(q.shape[-1])
    s = paddle.matmul(qt * scale, paddle.transpose(kt, [0, 1, 3, 2]))
    p = (
        paddle.incubate.softmax_mask_fuse_upper_triangle(s)
        if causal
        else F.softmax(s)
    )
    o = paddle.matmul(p, vt)
    return paddle.transpose(o, [0, 2, 1, 3])


def attention_naive_with_mask(q, k, v, attn_bias):
    qt = paddle.transpose(q, [0, 2, 1, 3])
    kt = paddle.transpose(k, [0, 2, 1, 3])
    vt = paddle.transpose(v, [0, 2, 1, 3])
    scale = 1.0 / np.sqrt(q.shape[-1])
    s = paddle.matmul(qt, paddle.transpose(kt, [0, 1, 3, 2]))
    s = paddle.scale(s, scale)
    p = F.softmax(s + attn_bias)
    o = paddle.matmul(p, vt)
    return paddle.transpose(o, [0, 2, 1, 3])


def attention_naive_with_bool_mask(q, k, v, bool_mask):
    qt = paddle.transpose(q, [0, 2, 1, 3])
    kt = paddle.transpose(k, [0, 2, 1, 3])
    vt = paddle.transpose(v, [0, 2, 1, 3])

    scale = 1.0 / np.sqrt(q.shape[-1])
    s = paddle.matmul(qt, paddle.transpose(kt, [0, 1, 3, 2]))
    s = paddle.scale(s, scale)

    float_mask = paddle.where(
        bool_mask,
        paddle.to_tensor(0.0, dtype=q.dtype),
        paddle.to_tensor(-float('inf'), dtype=q.dtype),
    )

    s = s + float_mask
    p = F.softmax(s)

    o = paddle.matmul(p, vt)
    return paddle.transpose(o, [0, 2, 1, 3])


@unittest.skipIf(
    not paddle.is_compiled_with_cuda(),
    "CUDA is not available, this test requires GPU support.",
)
class TestAttentionWithBoolMask(unittest.TestCase):
    def setUp(self):
        self.place = paddle.CUDAPlace(0)
        self.shape = (1, 1, 8, 8)
        self.dtype = 'float32'
        self.dropout = 0.0
        self.causal = False

    def test_dot_scale_product_bool_mask(self):
        # test dynamic
        paddle.disable_static()

        query = np.random.random(self.shape)
        key = np.random.random(self.shape)
        value = np.random.random(self.shape)

        q = paddle.to_tensor(
            query, place=self.place, dtype=self.dtype, stop_gradient=False
        )
        k = paddle.to_tensor(
            key, place=self.place, dtype=self.dtype, stop_gradient=False
        )
        v = paddle.to_tensor(
            value, place=self.place, dtype=self.dtype, stop_gradient=False
        )

        q_ = paddle.to_tensor(
            query, place=self.place, dtype=self.dtype, stop_gradient=False
        )
        k_ = paddle.to_tensor(
            key, place=self.place, dtype=self.dtype, stop_gradient=False
        )
        v_ = paddle.to_tensor(
            value, place=self.place, dtype=self.dtype, stop_gradient=False
        )

        mask_shape = (self.shape[0], 1, self.shape[1], self.shape[1])
        bool_mask = np.random.choice([True, False], size=mask_shape)

        m = paddle.to_tensor(
            bool_mask, place=self.place, dtype=paddle.bool, stop_gradient=False
        )

        with sdp_kernel(
            enable_math=True, enable_flash=False, enable_mem_efficient=False
        ):
            out = scaled_dot_product_attention(
                q, k, v, m, self.dropout, self.causal
            )

        out_ = attention_naive_with_bool_mask(q_, k_, v_, m)

        out.backward()
        out_.backward()

        np.testing.assert_allclose(out.numpy(), out_, rtol=5e-03, atol=1e-03)

    def test_dot_scale_product_float_mask(self):
        # test with mask=float
        paddle.disable_static()

        query = np.random.random(self.shape)
        key = np.random.random(self.shape)
        value = np.random.random(self.shape)

        q = paddle.to_tensor(
            query, place=self.place, dtype=self.dtype, stop_gradient=False
        )
        k = paddle.to_tensor(
            key, place=self.place, dtype=self.dtype, stop_gradient=False
        )
        v = paddle.to_tensor(
            value, place=self.place, dtype=self.dtype, stop_gradient=False
        )

        q_ = paddle.to_tensor(
            query, place=self.place, dtype=self.dtype, stop_gradient=False
        )
        k_ = paddle.to_tensor(
            key, place=self.place, dtype=self.dtype, stop_gradient=False
        )
        v_ = paddle.to_tensor(
            value, place=self.place, dtype=self.dtype, stop_gradient=False
        )

        mask_shape = (self.shape[0], 1, self.shape[1], self.shape[1])
        mask = np.random.random(mask_shape)
        m = paddle.to_tensor(
            mask, place=self.place, dtype=self.dtype, stop_gradient=False
        )

        with sdp_kernel(
            enable_math=True, enable_flash=False, enable_mem_efficient=False
        ):
            out = scaled_dot_product_attention(
                q, k, v, m, self.dropout, self.causal
            )

        out_ = attention_naive_with_mask(q_, k_, v_, m)
        out.backward()
        out_.backward()
        np.testing.assert_allclose(out.numpy(), out_, rtol=5e-03, atol=1e-03)


if __name__ == '__main__':
    unittest.main()
