18
18
import numpy as np
19
19
import paddle
20
20
import paddle .static as static
21
+ from paddle .fluid .framework import _test_eager_guard
21
22
22
23
p_list_n_n = ("fro" , "nuc" , 1 , - 1 , np .inf , - np .inf )
23
24
p_list_m_n = (None , 2 , - 2 )
@@ -89,16 +90,21 @@ def test_out(self):
89
90
90
91
91
92
class API_TestDygraphCond (unittest .TestCase ):
92
- def test_out (self ):
93
+ def func_out (self ):
93
94
paddle .disable_static ()
94
95
# test calling results of 'cond' in dynamic mode
95
96
x_list_n_n , x_list_m_n = gen_input ()
96
97
test_dygraph_assert_true (self , x_list_n_n , p_list_n_n + p_list_m_n )
97
98
test_dygraph_assert_true (self , x_list_m_n , p_list_m_n )
98
99
100
+ def test_out (self ):
101
+ with _test_eager_guard ():
102
+ self .func_out ()
103
+ self .func_out ()
104
+
99
105
100
106
class TestCondAPIError (unittest .TestCase ):
101
- def test_dygraph_api_error (self ):
107
+ def func_dygraph_api_error (self ):
102
108
paddle .disable_static ()
103
109
# test raising errors when 'cond' is called in dygraph mode
104
110
p_list_error = ('fro_' , '_nuc' , - 0.7 , 0 , 1.5 , 3 )
@@ -113,6 +119,11 @@ def test_dygraph_api_error(self):
113
119
x_tensor = paddle .to_tensor (x )
114
120
self .assertRaises (ValueError , paddle .linalg .cond , x_tensor , p )
115
121
122
+ def test_dygraph_api_error (self ):
123
+ with _test_eager_guard ():
124
+ self .func_dygraph_api_error ()
125
+ self .func_dygraph_api_error ()
126
+
116
127
def test_static_api_error (self ):
117
128
paddle .enable_static ()
118
129
# test raising errors when 'cond' is called in static mode
@@ -149,13 +160,18 @@ def test_static_empty_input_error(self):
149
160
150
161
151
162
class TestCondEmptyTensorInput (unittest .TestCase ):
152
- def test_dygraph_empty_tensor_input (self ):
163
+ def func_dygraph_empty_tensor_input (self ):
153
164
paddle .disable_static ()
154
165
# test calling results of 'cond' when input is an empty tensor in dynamic mode
155
166
x_list_n_n , x_list_m_n = gen_empty_input ()
156
167
test_dygraph_assert_true (self , x_list_n_n , p_list_n_n + p_list_m_n )
157
168
test_dygraph_assert_true (self , x_list_m_n , p_list_m_n )
158
169
170
+ def test_dygraph_empty_tensor_input (self ):
171
+ with _test_eager_guard ():
172
+ self .func_dygraph_empty_tensor_input ()
173
+ self .func_dygraph_empty_tensor_input ()
174
+
159
175
160
176
if __name__ == "__main__" :
161
177
paddle .enable_static ()
0 commit comments