21
21
import numpy as np
22
22
23
23
import paddle
24
+ from paddle .framework import in_pir_mode
25
+ from paddle .pir_utils import test_with_pir_api
24
26
25
27
26
28
class TestNanInfBase (unittest .TestCase ):
@@ -299,6 +301,7 @@ def test_eager(self):
299
301
debug_mode = paddle .amp .debugging .DebugMode .CHECK_ALL ,
300
302
)
301
303
304
+ @test_with_pir_api
302
305
def test_static (self ):
303
306
paddle .enable_static ()
304
307
shape = [8 , 8 ]
@@ -310,16 +313,22 @@ def test_static(self):
310
313
x = paddle .static .data (name = 'x' , shape = [8 , 8 ], dtype = "float32" )
311
314
y = paddle .static .data (name = 'y' , shape = [8 , 8 ], dtype = "float32" )
312
315
out = paddle .add (x , y )
313
- paddle .amp .debugging .check_numerics (
314
- tensor = out ,
315
- op_type = "elementwise_add" ,
316
- var_name = out .name ,
317
- debug_mode = paddle .amp .debugging .DebugMode .CHECK_ALL ,
318
- )
316
+ if in_pir_mode ():
317
+ paddle .amp .debugging .check_numerics (
318
+ tensor = out ,
319
+ op_type = "elementwise_add" ,
320
+ var_name = out .id ,
321
+ debug_mode = paddle .amp .debugging .DebugMode .CHECK_ALL ,
322
+ )
323
+ else :
324
+ paddle .amp .debugging .check_numerics (
325
+ tensor = out ,
326
+ op_type = "elementwise_add" ,
327
+ var_name = out .name ,
328
+ debug_mode = paddle .amp .debugging .DebugMode .CHECK_ALL ,
329
+ )
319
330
exe = paddle .static .Executor (paddle .CPUPlace ())
320
- exe .run (
321
- main_program , feed = {"x" : x_np , "y" : y_np }, fetch_list = [out .name ]
322
- )
331
+ exe .run (main_program , feed = {"x" : x_np , "y" : y_np }, fetch_list = [out ])
323
332
paddle .disable_static ()
324
333
325
334
0 commit comments