
Commit 8d40c67

Fancy-hjyp authored and Wanglongzhi2001 committed
【PIR API adaptor No.143、144】 Migrate margin_cross_entropy、masked_multihead_attention (PaddlePaddle#58762)
1 parent ddc1437 commit 8d40c67

File tree

4 files changed, +35 -12 lines

python/paddle/incubate/nn/functional/masked_multihead_attention.py

+2 -2

```diff
@@ -13,7 +13,7 @@
 # limitations under the License.

 from paddle import _C_ops
-from paddle.framework import LayerHelper, in_dynamic_mode
+from paddle.framework import LayerHelper, in_dynamic_or_pir_mode


 def masked_multihead_attention(
@@ -90,7 +90,7 @@ def masked_multihead_attention(
     """

-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         return _C_ops.masked_multihead_attention_(
             x,
             cache_kv,
```
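The only functional change here is the dispatch predicate: the eager-only `in_dynamic_mode()` check becomes `in_dynamic_or_pir_mode()`, so the `_C_ops.masked_multihead_attention_` fast path is also taken while a program is being built under the new PIR static graph. A minimal sketch of the two predicates, assuming Paddle 2.6+ where `paddle.pir_utils.IrGuard` (the guard that `test_with_pir_api` uses internally) is available; the printed values are expectations, not output captured from this PR:

```python
import paddle
from paddle.framework import in_dynamic_mode, in_dynamic_or_pir_mode
from paddle.pir_utils import IrGuard  # assumption: present in Paddle 2.6+

# Eager (dynamic) mode: both predicates hold, so _C_ops is called directly.
print(in_dynamic_mode(), in_dynamic_or_pir_mode())      # expected: True True

# PIR static graph: in_dynamic_mode() is False, but a migrated API still takes
# the _C_ops path because in_dynamic_or_pir_mode() remains True.
paddle.enable_static()
with IrGuard():
    print(in_dynamic_mode(), in_dynamic_or_pir_mode())  # expected: False True
paddle.disable_static()
```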

python/paddle/nn/functional/loss.py

+1 -1

```diff
@@ -2314,7 +2314,7 @@ def margin_cross_entropy(
     if input_dims - 1 == label_dims:
         label = paddle.unsqueeze(label, axis=-1)

-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         softmax, loss = _C_ops.margin_cross_entropy(
             logits,
             label,
```
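Same one-line change for `margin_cross_entropy`: the `_C_ops.margin_cross_entropy` call is now reachable from both eager mode and PIR program construction. A hedged usage sketch of the public API (shapes and arguments are illustrative only, and the call is guarded on CUDA the same way the tests below guard it, since the kernel is exercised on `CUDAPlace`):

```python
import paddle

if paddle.is_compiled_with_cuda():  # the op is tested on GPU in this PR
    paddle.set_device('gpu')
    logits = paddle.randn([4, 10])                         # [batch, num_classes]
    label = paddle.randint(0, 10, shape=[4], dtype='int64')
    loss, softmax = paddle.nn.functional.margin_cross_entropy(
        logits, label, return_softmax=True, reduction=None
    )
    print(loss.shape, softmax.shape)                       # expected: [4, 1] [4, 10]
```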

test/legacy_test/test_margin_cross_entropy_op.py

+26 -8

```diff
@@ -18,7 +18,9 @@
 from op_test import OpTest, convert_float_to_uint16, paddle_static_guard

 import paddle
-from paddle.base import Program, core, program_guard
+from paddle.base import core
+from paddle.pir_utils import test_with_pir_api
+from paddle.static import Program, program_guard


 def stable_softmax_comm(x):
@@ -148,10 +150,14 @@ def setUp(self):
         }

     def test_check_output(self):
-        self.check_output_with_place(core.CUDAPlace(0), atol=1e-5)
+        self.check_output_with_place(
+            core.CUDAPlace(0), atol=1e-5, check_pir=True
+        )

     def test_check_grad(self):
-        self.check_grad_with_place(core.CUDAPlace(0), ["Logits"], "Loss")
+        self.check_grad_with_place(
+            core.CUDAPlace(0), ["Logits"], "Loss", check_pir=True
+        )


 @unittest.skipIf(
@@ -168,6 +174,7 @@ def test_check_grad(self):
             "Loss",
             numeric_grad_delta=5e-2,
             max_relative_error=5e-2,
+            check_pir=True,
         )


@@ -179,7 +186,9 @@ def init_dtype(self):
         self.dtype = np.float16

     def test_check_output(self):
-        self.check_output_with_place(core.CUDAPlace(0), atol=5e-2)
+        self.check_output_with_place(
+            core.CUDAPlace(0), atol=5e-2, check_pir=True
+        )

     def test_check_grad(self):
         self.check_grad_with_place(
@@ -188,6 +197,7 @@ def test_check_grad(self):
             "Loss",
             numeric_grad_delta=6e-1,
             max_relative_error=6e-1,
+            check_pir=True,
         )


@@ -264,7 +274,9 @@ def setUp(self):
         }

     def test_check_output(self):
-        self.check_output_with_place(core.CUDAPlace(0), atol=5e-2)
+        self.check_output_with_place(
+            core.CUDAPlace(0), atol=5e-2, check_pir=True
+        )

     def test_check_grad(self):
         self.check_grad_with_place(
@@ -273,6 +285,7 @@ def test_check_grad(self):
             "Loss",
             numeric_grad_delta=6e-1,
             max_relative_error=6e-1,
+            check_pir=True,
         )


@@ -301,13 +314,17 @@ def init_loss_params(self):
 class TestMarginCrossEntropyOpCPU(TestMarginCrossEntropyOp):
     def test_check_output(self):
         try:
-            self.check_output_with_place(core.CPUPlace(), atol=1e-5)
+            self.check_output_with_place(
+                core.CPUPlace(), atol=1e-5, check_pir=True
+            )
         except RuntimeError:
             pass

     def test_check_grad(self):
         try:
-            self.check_grad_with_place(core.CPUPlace(), ["Logits"], "Loss")
+            self.check_grad_with_place(
+                core.CPUPlace(), ["Logits"], "Loss", check_pir=True
+            )
         except RuntimeError:
             pass

@@ -347,6 +364,7 @@ def init_dtype(self):
     def init_reduction(self):
         self.reduction = None

+    @test_with_pir_api
     def test_static(self):
         for place in self.places:
             self.check_static_result(place=place)
@@ -404,7 +422,7 @@ def check_static_result(self, place):

             exe = paddle.base.Executor(place)
             [loss_res, softmax_res] = exe.run(
-                paddle.base.default_main_program(),
+                paddle.static.default_main_program(),
                 feed={'logits': logits_np, 'label': labels_np},
                 fetch_list=[loss, softmax],
             )
```
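The test changes follow the usual two-part PIR migration pattern: OpTest-based checks pass `check_pir=True` to `check_output_with_place` / `check_grad_with_place`, and program-level tests are decorated with `@test_with_pir_api`, which reruns the test body a second time under the PIR program set. A minimal sketch of the decorator half of that pattern, using a placeholder op (`softmax`) and class name rather than anything from this PR:

```python
import unittest

import numpy as np
import paddle
from paddle.pir_utils import test_with_pir_api


class TestStaticPattern(unittest.TestCase):  # placeholder test, not from the PR
    @test_with_pir_api  # runs once on the legacy IR, then again under PIR
    def test_static(self):
        paddle.enable_static()
        main = paddle.static.Program()
        startup = paddle.static.Program()
        with paddle.static.program_guard(main, startup):
            x = paddle.static.data('x', [4, 10], 'float32')
            y = paddle.nn.functional.softmax(x)
        exe = paddle.static.Executor(paddle.CPUPlace())
        (out,) = exe.run(
            main,
            feed={'x': np.random.rand(4, 10).astype('float32')},
            fetch_list=[y],
        )
        self.assertEqual(out.shape, (4, 10))
        paddle.disable_static()


if __name__ == '__main__':
    unittest.main()
```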

test/legacy_test/test_masked_multihead_attention_op.py

+6 -1

```diff
@@ -19,6 +19,7 @@
 import paddle
 from paddle.framework import core
 from paddle.incubate.nn.functional import masked_multihead_attention
+from paddle.pir_utils import test_with_pir_api


 @unittest.skipIf(
@@ -213,6 +214,7 @@ def check_main(
         paddle.enable_static()
         return paddle_naive_mmha_out, paddle_mmha_out

+    @test_with_pir_api
     def test_mmha_fp16(self):
         if not paddle.is_compiled_with_cuda():
             return
@@ -234,6 +236,7 @@ def test_mmha_fp16(self):
            atol=1e-3,
        )

+    @test_with_pir_api
     def test_mmha_qkv_out_scale(self):
         if not paddle.is_compiled_with_cuda():
             return
@@ -255,6 +258,7 @@ def test_mmha_qkv_out_scale(self):
            atol=1e-3,
        )

+    @test_with_pir_api
     def test_mmha_outlinear_in_scale(self):
         if not paddle.is_compiled_with_cuda():
             return
@@ -463,11 +467,12 @@ def check_main(
                "bias_static": bias.astype(dtype),
                "src_mask_static": src_mask.astype(dtype),
            },
-            fetch_list=[outs],
+            fetch_list=[outs[0], outs[1]],
        )

        return paddle_naive_mmha_out, out_s

+    @test_with_pir_api
    def test_mmha_fp16(self):
        if not paddle.is_compiled_with_cuda():
            return
```
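Besides decorating the tests with `@test_with_pir_api`, the diff flattens the fetch list: `masked_multihead_attention` returns a pair of outputs in static graph, and fetching the elements individually (`fetch_list=[outs[0], outs[1]]`) works uniformly on both IRs, whereas the nested form appears to be something only the legacy executor tolerated (my reading, not stated in the PR). A CPU-runnable illustration with `paddle.topk` as a stand-in two-output API (nothing below comes from the PR itself):

```python
import numpy as np
import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data('x', [4, 10], 'float32')
    outs = paddle.topk(x, k=3)              # returns a (values, indices) pair
exe = paddle.static.Executor(paddle.CPUPlace())
values, indices = exe.run(
    main,
    feed={'x': np.random.rand(4, 10).astype('float32')},
    fetch_list=[outs[0], outs[1]],          # fetch each output explicitly
)
paddle.disable_static()
print(values.shape, indices.shape)          # expected: (4, 3) (4, 3)
```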
