Skip to content

Commit 8290a9e

Browse files
authored
【PIR API adaptor No.36】check_numerics (#58879)
1 parent 91fa5ff commit 8290a9e

File tree

2 files changed

+20
-11
lines changed

2 files changed

+20
-11
lines changed

python/paddle/amp/debugging.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from paddle.base import core
2424
from paddle.base.framework import dygraph_only
2525

26-
from ..framework import LayerHelper, in_dynamic_mode
26+
from ..framework import LayerHelper, in_dynamic_or_pir_mode
2727

2828
__all__ = [
2929
"DebugMode",
@@ -372,7 +372,7 @@ def check_numerics(
372372
stack_height_limit = -1
373373
output_dir = ""
374374

375-
if in_dynamic_mode():
375+
if in_dynamic_or_pir_mode():
376376
return _C_ops.check_numerics(
377377
tensor,
378378
op_type,

test/legacy_test/test_nan_inf.py

+18-9
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import numpy as np
2222

2323
import paddle
24+
from paddle.framework import in_pir_mode
25+
from paddle.pir_utils import test_with_pir_api
2426

2527

2628
class TestNanInfBase(unittest.TestCase):
@@ -299,6 +301,7 @@ def test_eager(self):
299301
debug_mode=paddle.amp.debugging.DebugMode.CHECK_ALL,
300302
)
301303

304+
@test_with_pir_api
302305
def test_static(self):
303306
paddle.enable_static()
304307
shape = [8, 8]
@@ -310,16 +313,22 @@ def test_static(self):
310313
x = paddle.static.data(name='x', shape=[8, 8], dtype="float32")
311314
y = paddle.static.data(name='y', shape=[8, 8], dtype="float32")
312315
out = paddle.add(x, y)
313-
paddle.amp.debugging.check_numerics(
314-
tensor=out,
315-
op_type="elementwise_add",
316-
var_name=out.name,
317-
debug_mode=paddle.amp.debugging.DebugMode.CHECK_ALL,
318-
)
316+
if in_pir_mode():
317+
paddle.amp.debugging.check_numerics(
318+
tensor=out,
319+
op_type="elementwise_add",
320+
var_name=out.id,
321+
debug_mode=paddle.amp.debugging.DebugMode.CHECK_ALL,
322+
)
323+
else:
324+
paddle.amp.debugging.check_numerics(
325+
tensor=out,
326+
op_type="elementwise_add",
327+
var_name=out.name,
328+
debug_mode=paddle.amp.debugging.DebugMode.CHECK_ALL,
329+
)
319330
exe = paddle.static.Executor(paddle.CPUPlace())
320-
exe.run(
321-
main_program, feed={"x": x_np, "y": y_np}, fetch_list=[out.name]
322-
)
331+
exe.run(main_program, feed={"x": x_np, "y": y_np}, fetch_list=[out])
323332
paddle.disable_static()
324333

325334

0 commit comments

Comments
 (0)