16
16
import unittest
17
17
18
18
import numpy as np
19
- from op_test import OpTest , convert_float_to_uint16 , paddle_static_guard
19
+ from op_test import OpTest , convert_float_to_uint16
20
+ from utils import dygraph_guard , static_guard
20
21
21
22
import paddle
22
23
from paddle .base import core
@@ -260,7 +261,7 @@ def setUp(self):
260
261
def test_api (self ):
261
262
262
263
def test_static ():
263
- with paddle_static_guard ():
264
+ with static_guard ():
264
265
with paddle .static .program_guard (paddle .static .Program ()):
265
266
inputs = [
266
267
paddle .static .data (
@@ -273,8 +274,7 @@ def test_static():
273
274
paddle .broadcast_tensors (inputs )
274
275
275
276
def test_dynamic ():
276
- paddle .disable_static ()
277
- try :
277
+ with dygraph_guard ():
278
278
inputs = [
279
279
paddle .to_tensor (
280
280
np .random .random ([4 , 1 , 4 , 1 ]).astype (self .dtype )
@@ -294,8 +294,6 @@ def test_dynamic():
294
294
),
295
295
]
296
296
paddle .broadcast_tensors (inputs )
297
- finally :
298
- paddle .enable_static ()
299
297
300
298
test_static ()
301
299
test_dynamic ()
@@ -314,7 +312,7 @@ def setUp(self):
314
312
class TestRaiseBroadcastTensorsError (unittest .TestCase ):
315
313
def test_errors (self ):
316
314
def test_type ():
317
- with paddle_static_guard ():
315
+ with static_guard ():
318
316
with paddle .static .program_guard (paddle .static .Program ()):
319
317
inputs = [
320
318
paddle .static .data (
@@ -327,7 +325,7 @@ def test_type():
327
325
paddle .broadcast_tensors (inputs )
328
326
329
327
def test_dtype ():
330
- with paddle_static_guard ():
328
+ with static_guard ():
331
329
with paddle .static .program_guard (paddle .static .Program ()):
332
330
inputs = [
333
331
paddle .static .data (
@@ -340,7 +338,7 @@ def test_dtype():
340
338
paddle .broadcast_tensors (inputs )
341
339
342
340
def test_bcast_semantics ():
343
- with paddle_static_guard ():
341
+ with static_guard ():
344
342
with paddle .static .program_guard (paddle .static .Program ()):
345
343
inputs = [
346
344
paddle .static .data (
@@ -353,7 +351,7 @@ def test_bcast_semantics():
353
351
paddle .broadcast_tensors (inputs )
354
352
355
353
def test_bcast_semantics_complex64 ():
356
- with paddle_static_guard ():
354
+ with static_guard ():
357
355
with paddle .static .program_guard (paddle .static .Program ()):
358
356
inputs = [
359
357
paddle .static .data (
@@ -382,8 +380,7 @@ def test_bcast_semantics_complex64():
382
380
class TestRaiseBroadcastTensorsErrorDyGraph (unittest .TestCase ):
383
381
def test_errors (self ):
384
382
def test_type ():
385
- paddle .disable_static ()
386
- try :
383
+ with dygraph_guard ():
387
384
inputs = [
388
385
paddle .to_tensor (
389
386
np .ones (shape = [1 , 1 , 1 , 1 ], dtype = 'float32' , name = "x4" )
@@ -393,12 +390,9 @@ def test_type():
393
390
),
394
391
]
395
392
paddle .broadcast_tensors (inputs )
396
- finally :
397
- paddle .enable_static ()
398
393
399
394
def test_dtype ():
400
- paddle .disable_static ()
401
- try :
395
+ with dygraph_guard ():
402
396
inputs = [
403
397
paddle .to_tensor (
404
398
np .ones (shape = [1 , 1 , 1 , 1 ], dtype = 'int8' , name = "x6" )
@@ -408,12 +402,9 @@ def test_dtype():
408
402
),
409
403
]
410
404
paddle .broadcast_tensors (inputs )
411
- finally :
412
- paddle .enable_static ()
413
405
414
406
def test_bcast_semantics ():
415
- paddle .disable_static ()
416
- try :
407
+ with dygraph_guard ():
417
408
inputs = [
418
409
paddle .to_tensor (
419
410
np .ones (shape = [1 , 3 , 1 , 1 ], dtype = 'float32' , name = "x9" )
@@ -423,8 +414,6 @@ def test_bcast_semantics():
423
414
),
424
415
]
425
416
paddle .broadcast_tensors (inputs )
426
- finally :
427
- paddle .enable_static ()
428
417
429
418
self .assertRaises (TypeError , test_type )
430
419
self .assertRaises (TypeError , test_dtype )
@@ -440,7 +429,7 @@ def set_dtypes(self):
440
429
pass
441
430
442
431
def test_single_static (self ):
443
- with paddle_static_guard ():
432
+ with static_guard ():
444
433
with paddle .static .program_guard (paddle .static .Program ()):
445
434
inputs = [
446
435
paddle .static .data (
@@ -451,17 +440,14 @@ def test_single_static(self):
451
440
self .assertEqual (len (outputs ), 1 )
452
441
453
442
def test_single_dynamic (self ):
454
- paddle .disable_static ()
455
- try :
443
+ with dygraph_guard ():
456
444
inputs = [
457
445
paddle .to_tensor (
458
446
np .random .random ([1 , 4 , 1 , 4 ]).astype (self .dtype )
459
447
),
460
448
]
461
449
outputs = paddle .broadcast_tensors (inputs )
462
450
self .assertEqual (len (outputs ), 1 )
463
- finally :
464
- paddle .enable_static ()
465
451
466
452
467
453
class TestBroadcastTensorsAPIZeroSize (unittest .TestCase ):
@@ -476,7 +462,7 @@ def set_dtype(self):
476
462
pass
477
463
478
464
def test_zero_size_static (self ):
479
- with paddle_static_guard ():
465
+ with static_guard ():
480
466
with paddle .static .program_guard (paddle .static .Program ()):
481
467
inputs = [
482
468
paddle .static .data (
@@ -491,8 +477,7 @@ def test_zero_size_static(self):
491
477
self .assertEqual (outputs [1 ].shape , self .expected_shape )
492
478
493
479
def test_zero_size_dynamic (self ):
494
- paddle .disable_static ()
495
- try :
480
+ with dygraph_guard ():
496
481
data1 = np .zeros (self .shape1 , dtype = self .dtype )
497
482
data2 = np .zeros (self .shape2 , dtype = self .dtype )
498
483
@@ -503,8 +488,6 @@ def test_zero_size_dynamic(self):
503
488
outputs = paddle .broadcast_tensors (inputs )
504
489
self .assertEqual (outputs [0 ].shape , self .expected_shape )
505
490
self .assertEqual (outputs [1 ].shape , self .expected_shape )
506
- finally :
507
- paddle .enable_static ()
508
491
509
492
510
493
class TestBroadcastTensorsAPIZeroSize_bool (TestBroadcastTensorsAPIZeroSize ):
0 commit comments