Skip to content

Commit a31c5cd

Browse files
committed
refine test
1 parent b44d10b commit a31c5cd

File tree

1 file changed

+15
-32
lines changed

1 file changed

+15
-32
lines changed

test/legacy_test/test_broadcast_tensors_op.py

+15-32
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
import unittest
1717

1818
import numpy as np
19-
from op_test import OpTest, convert_float_to_uint16, paddle_static_guard
19+
from op_test import OpTest, convert_float_to_uint16
20+
from utils import dygraph_guard, static_guard
2021

2122
import paddle
2223
from paddle.base import core
@@ -260,7 +261,7 @@ def setUp(self):
260261
def test_api(self):
261262

262263
def test_static():
263-
with paddle_static_guard():
264+
with static_guard():
264265
with paddle.static.program_guard(paddle.static.Program()):
265266
inputs = [
266267
paddle.static.data(
@@ -273,8 +274,7 @@ def test_static():
273274
paddle.broadcast_tensors(inputs)
274275

275276
def test_dynamic():
276-
paddle.disable_static()
277-
try:
277+
with dygraph_guard():
278278
inputs = [
279279
paddle.to_tensor(
280280
np.random.random([4, 1, 4, 1]).astype(self.dtype)
@@ -294,8 +294,6 @@ def test_dynamic():
294294
),
295295
]
296296
paddle.broadcast_tensors(inputs)
297-
finally:
298-
paddle.enable_static()
299297

300298
test_static()
301299
test_dynamic()
@@ -314,7 +312,7 @@ def setUp(self):
314312
class TestRaiseBroadcastTensorsError(unittest.TestCase):
315313
def test_errors(self):
316314
def test_type():
317-
with paddle_static_guard():
315+
with static_guard():
318316
with paddle.static.program_guard(paddle.static.Program()):
319317
inputs = [
320318
paddle.static.data(
@@ -327,7 +325,7 @@ def test_type():
327325
paddle.broadcast_tensors(inputs)
328326

329327
def test_dtype():
330-
with paddle_static_guard():
328+
with static_guard():
331329
with paddle.static.program_guard(paddle.static.Program()):
332330
inputs = [
333331
paddle.static.data(
@@ -340,7 +338,7 @@ def test_dtype():
340338
paddle.broadcast_tensors(inputs)
341339

342340
def test_bcast_semantics():
343-
with paddle_static_guard():
341+
with static_guard():
344342
with paddle.static.program_guard(paddle.static.Program()):
345343
inputs = [
346344
paddle.static.data(
@@ -353,7 +351,7 @@ def test_bcast_semantics():
353351
paddle.broadcast_tensors(inputs)
354352

355353
def test_bcast_semantics_complex64():
356-
with paddle_static_guard():
354+
with static_guard():
357355
with paddle.static.program_guard(paddle.static.Program()):
358356
inputs = [
359357
paddle.static.data(
@@ -382,8 +380,7 @@ def test_bcast_semantics_complex64():
382380
class TestRaiseBroadcastTensorsErrorDyGraph(unittest.TestCase):
383381
def test_errors(self):
384382
def test_type():
385-
paddle.disable_static()
386-
try:
383+
with dygraph_guard():
387384
inputs = [
388385
paddle.to_tensor(
389386
np.ones(shape=[1, 1, 1, 1], dtype='float32', name="x4")
@@ -393,12 +390,9 @@ def test_type():
393390
),
394391
]
395392
paddle.broadcast_tensors(inputs)
396-
finally:
397-
paddle.enable_static()
398393

399394
def test_dtype():
400-
paddle.disable_static()
401-
try:
395+
with dygraph_guard():
402396
inputs = [
403397
paddle.to_tensor(
404398
np.ones(shape=[1, 1, 1, 1], dtype='int8', name="x6")
@@ -408,12 +402,9 @@ def test_dtype():
408402
),
409403
]
410404
paddle.broadcast_tensors(inputs)
411-
finally:
412-
paddle.enable_static()
413405

414406
def test_bcast_semantics():
415-
paddle.disable_static()
416-
try:
407+
with dygraph_guard():
417408
inputs = [
418409
paddle.to_tensor(
419410
np.ones(shape=[1, 3, 1, 1], dtype='float32', name="x9")
@@ -423,8 +414,6 @@ def test_bcast_semantics():
423414
),
424415
]
425416
paddle.broadcast_tensors(inputs)
426-
finally:
427-
paddle.enable_static()
428417

429418
self.assertRaises(TypeError, test_type)
430419
self.assertRaises(TypeError, test_dtype)
@@ -440,7 +429,7 @@ def set_dtypes(self):
440429
pass
441430

442431
def test_single_static(self):
443-
with paddle_static_guard():
432+
with static_guard():
444433
with paddle.static.program_guard(paddle.static.Program()):
445434
inputs = [
446435
paddle.static.data(
@@ -451,17 +440,14 @@ def test_single_static(self):
451440
self.assertEqual(len(outputs), 1)
452441

453442
def test_single_dynamic(self):
454-
paddle.disable_static()
455-
try:
443+
with dygraph_guard():
456444
inputs = [
457445
paddle.to_tensor(
458446
np.random.random([1, 4, 1, 4]).astype(self.dtype)
459447
),
460448
]
461449
outputs = paddle.broadcast_tensors(inputs)
462450
self.assertEqual(len(outputs), 1)
463-
finally:
464-
paddle.enable_static()
465451

466452

467453
class TestBroadcastTensorsAPIZeroSize(unittest.TestCase):
@@ -476,7 +462,7 @@ def set_dtype(self):
476462
pass
477463

478464
def test_zero_size_static(self):
479-
with paddle_static_guard():
465+
with static_guard():
480466
with paddle.static.program_guard(paddle.static.Program()):
481467
inputs = [
482468
paddle.static.data(
@@ -491,8 +477,7 @@ def test_zero_size_static(self):
491477
self.assertEqual(outputs[1].shape, self.expected_shape)
492478

493479
def test_zero_size_dynamic(self):
494-
paddle.disable_static()
495-
try:
480+
with dygraph_guard():
496481
data1 = np.zeros(self.shape1, dtype=self.dtype)
497482
data2 = np.zeros(self.shape2, dtype=self.dtype)
498483

@@ -503,8 +488,6 @@ def test_zero_size_dynamic(self):
503488
outputs = paddle.broadcast_tensors(inputs)
504489
self.assertEqual(outputs[0].shape, self.expected_shape)
505490
self.assertEqual(outputs[1].shape, self.expected_shape)
506-
finally:
507-
paddle.enable_static()
508491

509492

510493
class TestBroadcastTensorsAPIZeroSize_bool(TestBroadcastTensorsAPIZeroSize):

0 commit comments

Comments
 (0)