
Commit c30b97c (1 parent: 703bc60)

updated for DCU test

modified: test/legacy_test/test_scaled_dot_product_attention.py

1 file changed (+14, -7)


test/legacy_test/test_scaled_dot_product_attention.py

Lines changed: 14 additions & 7 deletions
@@ -20,6 +20,7 @@
 import paddle.nn.functional as F
 from paddle.nn.functional.flash_attention import (
     scaled_dot_product_attention,
+    sdp_kernel,
 )


@@ -76,7 +77,7 @@ class TestAttentionWithBoolMask(unittest.TestCase):
     def setUp(self):
         self.place = paddle.CUDAPlace(0)
         self.shape = (1, 8, 8, 16)
-        self.dtype = 'float16'
+        self.dtype = 'float32'
         self.dropout = 0.0
         self.causal = False

@@ -115,9 +116,12 @@ def test_dot_scale_product_bool_mask(self):
             bool_mask, place=self.place, dtype=paddle.bool, stop_gradient=False
         )

-        out = scaled_dot_product_attention(
-            q, k, v, m, self.dropout, self.causal
-        )
+        with sdp_kernel(
+            enable_math=True, enable_flash=False, enable_mem_efficient=False
+        ):
+            out = scaled_dot_product_attention(
+                q, k, v, m, self.dropout, self.causal
+            )

         out_ = attention_naive_with_bool_mask(q_, k_, v_, m)

@@ -160,9 +164,12 @@ def test_dot_scale_product_float_mask(self):
             mask, place=self.place, dtype=self.dtype, stop_gradient=False
         )

-        out = scaled_dot_product_attention(
-            q, k, v, m, self.dropout, self.causal
-        )
+        with sdp_kernel(
+            enable_math=True, enable_flash=False, enable_mem_efficient=False
+        ):
+            out = scaled_dot_product_attention(
+                q, k, v, m, self.dropout, self.causal
+            )

         out_ = attention_naive_with_mask(q_, k_, v_, m)
         out.backward()
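
For context, a minimal standalone sketch of the pattern this commit introduces: sdp_kernel is a context manager that restricts which backend scaled_dot_product_attention may use, and enabling only the math path avoids the flash kernel, which may be unavailable on DCU hardware. The shapes and float32 dtype below mirror the test's setUp; reading the layout as (batch, seq_len, num_heads, head_dim) is an assumption based on Paddle's documented API, not something stated in this diff.

# Minimal sketch, not part of the commit: force the math backend so the
# call works on devices without flash-attention kernels.
import paddle
from paddle.nn.functional.flash_attention import (
    scaled_dot_product_attention,
    sdp_kernel,
)

# assumed layout: (batch, seq_len, num_heads, head_dim), as in the test's setUp
q = paddle.rand((1, 8, 8, 16), dtype='float32')
k = paddle.rand((1, 8, 8, 16), dtype='float32')
v = paddle.rand((1, 8, 8, 16), dtype='float32')

# Only the math (naive matmul + softmax) path is enabled here.
with sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False):
    # mask=None, dropout=0.0, causal=False, matching the positional call
    # style used in the test above
    out = scaled_dot_product_attention(q, k, v, None, 0.0, False)

print(out.shape)  # [1, 8, 8, 16]

Pinning the backend this way plausibly also makes the comparison against the naive references (attention_naive_with_bool_mask, attention_naive_with_mask) independent of whichever fused kernel the runtime would otherwise pick; that reading is an inference from the commit message, not stated in the diff.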
