|
| 1 | +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import unittest |
| 16 | + |
| 17 | +import numpy as np |
| 18 | + |
| 19 | +import paddle |
| 20 | +import paddle.nn.functional as F |
| 21 | +from paddle.nn.functional.flash_attention import ( |
| 22 | + scaled_dot_product_attention, |
| 23 | + sdp_kernel, |
| 24 | +) |
| 25 | + |
| 26 | + |
| 27 | +def attention_naive(q, k, v, causal=False): |
| 28 | + qt = paddle.transpose(q, [0, 2, 1, 3]) |
| 29 | + kt = paddle.transpose(k, [0, 2, 1, 3]) |
| 30 | + vt = paddle.transpose(v, [0, 2, 1, 3]) |
| 31 | + scale = 1.0 / np.sqrt(q.shape[-1]) |
| 32 | + s = paddle.matmul(qt * scale, paddle.transpose(kt, [0, 1, 3, 2])) |
| 33 | + p = ( |
| 34 | + paddle.incubate.softmax_mask_fuse_upper_triangle(s) |
| 35 | + if causal |
| 36 | + else F.softmax(s) |
| 37 | + ) |
| 38 | + o = paddle.matmul(p, vt) |
| 39 | + return paddle.transpose(o, [0, 2, 1, 3]) |
| 40 | + |
| 41 | + |
| 42 | +def attention_naive_with_mask(q, k, v, attn_bias): |
| 43 | + qt = paddle.transpose(q, [0, 2, 1, 3]) |
| 44 | + kt = paddle.transpose(k, [0, 2, 1, 3]) |
| 45 | + vt = paddle.transpose(v, [0, 2, 1, 3]) |
| 46 | + scale = 1.0 / np.sqrt(q.shape[-1]) |
| 47 | + s = paddle.matmul(qt, paddle.transpose(kt, [0, 1, 3, 2])) |
| 48 | + s = paddle.scale(s, scale) |
| 49 | + p = F.softmax(s + attn_bias) |
| 50 | + o = paddle.matmul(p, vt) |
| 51 | + return paddle.transpose(o, [0, 2, 1, 3]) |
| 52 | + |
| 53 | + |
| 54 | +def attention_naive_with_bool_mask(q, k, v, bool_mask): |
| 55 | + qt = paddle.transpose(q, [0, 2, 1, 3]) |
| 56 | + kt = paddle.transpose(k, [0, 2, 1, 3]) |
| 57 | + vt = paddle.transpose(v, [0, 2, 1, 3]) |
| 58 | + |
| 59 | + scale = 1.0 / np.sqrt(q.shape[-1]) |
| 60 | + s = paddle.matmul(qt, paddle.transpose(kt, [0, 1, 3, 2])) |
| 61 | + s = paddle.scale(s, scale) |
| 62 | + |
| 63 | + float_mask = paddle.where( |
| 64 | + bool_mask, |
| 65 | + paddle.to_tensor(0.0, dtype=q.dtype), |
| 66 | + paddle.to_tensor(-float('inf'), dtype=q.dtype), |
| 67 | + ) |
| 68 | + |
| 69 | + s = s + float_mask |
| 70 | + p = F.softmax(s) |
| 71 | + |
| 72 | + o = paddle.matmul(p, vt) |
| 73 | + return paddle.transpose(o, [0, 2, 1, 3]) |
| 74 | + |
| 75 | + |
| 76 | +@unittest.skipIf( |
| 77 | + not paddle.is_compiled_with_cuda(), |
| 78 | + "CUDA is not available, this test requires GPU support.", |
| 79 | +) |
| 80 | +class TestAttentionWithBoolMask(unittest.TestCase): |
| 81 | + def setUp(self): |
| 82 | + self.place = paddle.CUDAPlace(0) |
| 83 | + self.shape = (1, 1, 8, 8) |
| 84 | + self.dtype = 'float32' |
| 85 | + self.dropout = 0.0 |
| 86 | + self.causal = False |
| 87 | + |
| 88 | + def test_dot_scale_product_bool_mask(self): |
| 89 | + # test dynamic |
| 90 | + paddle.disable_static() |
| 91 | + |
| 92 | + query = np.random.random(self.shape) |
| 93 | + key = np.random.random(self.shape) |
| 94 | + value = np.random.random(self.shape) |
| 95 | + |
| 96 | + q = paddle.to_tensor( |
| 97 | + query, place=self.place, dtype=self.dtype, stop_gradient=False |
| 98 | + ) |
| 99 | + k = paddle.to_tensor( |
| 100 | + key, place=self.place, dtype=self.dtype, stop_gradient=False |
| 101 | + ) |
| 102 | + v = paddle.to_tensor( |
| 103 | + value, place=self.place, dtype=self.dtype, stop_gradient=False |
| 104 | + ) |
| 105 | + |
| 106 | + q_ = paddle.to_tensor( |
| 107 | + query, place=self.place, dtype=self.dtype, stop_gradient=False |
| 108 | + ) |
| 109 | + k_ = paddle.to_tensor( |
| 110 | + key, place=self.place, dtype=self.dtype, stop_gradient=False |
| 111 | + ) |
| 112 | + v_ = paddle.to_tensor( |
| 113 | + value, place=self.place, dtype=self.dtype, stop_gradient=False |
| 114 | + ) |
| 115 | + |
| 116 | + mask_shape = (self.shape[0], 1, self.shape[1], self.shape[1]) |
| 117 | + bool_mask = np.random.choice([True, False], size=mask_shape) |
| 118 | + |
| 119 | + m = paddle.to_tensor( |
| 120 | + bool_mask, place=self.place, dtype=paddle.bool, stop_gradient=False |
| 121 | + ) |
| 122 | + |
| 123 | + with sdp_kernel( |
| 124 | + enable_math=True, enable_flash=False, enable_mem_efficient=False |
| 125 | + ): |
| 126 | + out = scaled_dot_product_attention( |
| 127 | + q, k, v, m, self.dropout, self.causal |
| 128 | + ) |
| 129 | + |
| 130 | + out_ = attention_naive_with_bool_mask(q_, k_, v_, m) |
| 131 | + |
| 132 | + out.backward() |
| 133 | + out_.backward() |
| 134 | + |
| 135 | + np.testing.assert_allclose(out.numpy(), out_, rtol=5e-03, atol=1e-03) |
| 136 | + |
| 137 | + def test_dot_scale_product_float_mask(self): |
| 138 | + # test with mask=float |
| 139 | + paddle.disable_static() |
| 140 | + |
| 141 | + query = np.random.random(self.shape) |
| 142 | + key = np.random.random(self.shape) |
| 143 | + value = np.random.random(self.shape) |
| 144 | + |
| 145 | + q = paddle.to_tensor( |
| 146 | + query, place=self.place, dtype=self.dtype, stop_gradient=False |
| 147 | + ) |
| 148 | + k = paddle.to_tensor( |
| 149 | + key, place=self.place, dtype=self.dtype, stop_gradient=False |
| 150 | + ) |
| 151 | + v = paddle.to_tensor( |
| 152 | + value, place=self.place, dtype=self.dtype, stop_gradient=False |
| 153 | + ) |
| 154 | + |
| 155 | + q_ = paddle.to_tensor( |
| 156 | + query, place=self.place, dtype=self.dtype, stop_gradient=False |
| 157 | + ) |
| 158 | + k_ = paddle.to_tensor( |
| 159 | + key, place=self.place, dtype=self.dtype, stop_gradient=False |
| 160 | + ) |
| 161 | + v_ = paddle.to_tensor( |
| 162 | + value, place=self.place, dtype=self.dtype, stop_gradient=False |
| 163 | + ) |
| 164 | + |
| 165 | + mask_shape = (self.shape[0], 1, self.shape[1], self.shape[1]) |
| 166 | + mask = np.random.random(mask_shape) |
| 167 | + m = paddle.to_tensor( |
| 168 | + mask, place=self.place, dtype=self.dtype, stop_gradient=False |
| 169 | + ) |
| 170 | + |
| 171 | + with sdp_kernel( |
| 172 | + enable_math=True, enable_flash=False, enable_mem_efficient=False |
| 173 | + ): |
| 174 | + out = scaled_dot_product_attention( |
| 175 | + q, k, v, m, self.dropout, self.causal |
| 176 | + ) |
| 177 | + |
| 178 | + out_ = attention_naive_with_mask(q_, k_, v_, m) |
| 179 | + out.backward() |
| 180 | + out_.backward() |
| 181 | + np.testing.assert_allclose(out.numpy(), out_, rtol=5e-03, atol=1e-03) |
| 182 | + |
| 183 | + |
| 184 | +if __name__ == '__main__': |
| 185 | + unittest.main() |
0 commit comments