 import paddle.nn.functional as F
 from paddle.nn.functional.flash_attention import (
     scaled_dot_product_attention,
+    sdp_kernel,
 )


@@ -76,7 +77,7 @@ class TestAttentionWithBoolMask(unittest.TestCase):
     def setUp(self):
         self.place = paddle.CUDAPlace(0)
         self.shape = (1, 8, 8, 16)
-        self.dtype = 'float16'
+        self.dtype = 'float32'
         self.dropout = 0.0
         self.causal = False

@@ -115,9 +116,12 @@ def test_dot_scale_product_bool_mask(self):
             bool_mask, place=self.place, dtype=paddle.bool, stop_gradient=False
         )

-        out = scaled_dot_product_attention(
-            q, k, v, m, self.dropout, self.causal
-        )
+        with sdp_kernel(
+            enable_math=True, enable_flash=False, enable_mem_efficient=False
+        ):
+            out = scaled_dot_product_attention(
+                q, k, v, m, self.dropout, self.causal
+            )

         out_ = attention_naive_with_bool_mask(q_, k_, v_, m)

@@ -160,9 +164,12 @@ def test_dot_scale_product_float_mask(self):
             mask, place=self.place, dtype=self.dtype, stop_gradient=False
         )

-        out = scaled_dot_product_attention(
-            q, k, v, m, self.dropout, self.causal
-        )
+        with sdp_kernel(
+            enable_math=True, enable_flash=False, enable_mem_efficient=False
+        ):
+            out = scaled_dot_product_attention(
+                q, k, v, m, self.dropout, self.causal
+            )

         out_ = attention_naive_with_mask(q_, k_, v_, m)
         out.backward()
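
For reference, below is a minimal standalone sketch of the pattern these tests now use: pinning scaled_dot_product_attention to the math backend with the sdp_kernel context manager. Only the import path, the backend flags, and the call signature come from the diff above; the GPU device selection, the random inputs, and the use of no attention mask are illustrative assumptions.

```python
# Minimal sketch (not part of this diff) of the backend-selection pattern
# adopted by the tests above.
import paddle
from paddle.nn.functional.flash_attention import (
    scaled_dot_product_attention,
    sdp_kernel,
)

# Assumption: a CUDA device is available, matching paddle.CUDAPlace(0) in setUp.
paddle.device.set_device('gpu')

# Illustrative inputs shaped like the tests' self.shape = (1, 8, 8, 16),
# using the float32 dtype the tests switch to.
q = paddle.randn((1, 8, 8, 16), dtype='float32')
k = paddle.randn((1, 8, 8, 16), dtype='float32')
v = paddle.randn((1, 8, 8, 16), dtype='float32')

# Disable the flash and memory-efficient kernels so only the math path is eligible.
with sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False):
    # Positional arguments mirror the tests: query, key, value, mask, dropout, causal.
    out = scaled_dot_product_attention(q, k, v, None, 0.0, False)

print(out.shape)  # [1, 8, 8, 16]
```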