Skip to content

Commit 360ec11

Browse files
committed
[Eager] Support allclose and linalg_cond to eager mode (#41545)
1 parent a0b0a32 commit 360ec11

File tree

2 files changed

+27
-5
lines changed

2 files changed

+27
-5
lines changed

python/paddle/fluid/tests/unittests/test_allclose_layer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import paddle.fluid as fluid
1717
import unittest
1818
import numpy as np
19+
from paddle.fluid.framework import _test_eager_guard
1920

2021

2122
class TestAllcloseLayer(unittest.TestCase):
@@ -95,7 +96,7 @@ def test_allclose_gpu_fp64(self):
9596
with fluid.program_guard(main, startup):
9697
self.allclose_check(use_cuda=True, dtype='float64')
9798

98-
def test_dygraph_mode(self):
99+
def func_dygraph_mode(self):
99100
x_1 = np.array([10000., 1e-07]).astype("float32")
100101
y_1 = np.array([10000.1, 1e-08]).astype("float32")
101102
x_2 = np.array([10000., 1e-08]).astype("float32")
@@ -171,9 +172,14 @@ def test_dygraph_mode(self):
171172
x_v_5 = paddle.to_tensor(x_5)
172173
y_v_5 = paddle.to_tensor(y_5)
173174
ret_5 = paddle.allclose(
174-
x_v_5, y_v_5, rtol=0.01, atol=0.0, name='test_8')
175+
x_v_5, y_v_5, rtol=0.015, atol=0.0, name='test_8')
175176
self.assertEqual(ret_5.numpy()[0], True)
176177

178+
def test_dygraph_mode(self):
179+
with _test_eager_guard():
180+
self.func_dygraph_mode()
181+
self.func_dygraph_mode()
182+
177183

178184
if __name__ == "__main__":
179185
unittest.main()

python/paddle/fluid/tests/unittests/test_linalg_cond.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import numpy as np
1919
import paddle
2020
import paddle.static as static
21+
from paddle.fluid.framework import _test_eager_guard
2122

2223
p_list_n_n = ("fro", "nuc", 1, -1, np.inf, -np.inf)
2324
p_list_m_n = (None, 2, -2)
@@ -89,16 +90,21 @@ def test_out(self):
8990

9091

9192
class API_TestDygraphCond(unittest.TestCase):
92-
def test_out(self):
93+
def func_out(self):
9394
paddle.disable_static()
9495
# test calling results of 'cond' in dynamic mode
9596
x_list_n_n, x_list_m_n = gen_input()
9697
test_dygraph_assert_true(self, x_list_n_n, p_list_n_n + p_list_m_n)
9798
test_dygraph_assert_true(self, x_list_m_n, p_list_m_n)
9899

100+
def test_out(self):
101+
with _test_eager_guard():
102+
self.func_out()
103+
self.func_out()
104+
99105

100106
class TestCondAPIError(unittest.TestCase):
101-
def test_dygraph_api_error(self):
107+
def func_dygraph_api_error(self):
102108
paddle.disable_static()
103109
# test raising errors when 'cond' is called in dygraph mode
104110
p_list_error = ('fro_', '_nuc', -0.7, 0, 1.5, 3)
@@ -113,6 +119,11 @@ def test_dygraph_api_error(self):
113119
x_tensor = paddle.to_tensor(x)
114120
self.assertRaises(ValueError, paddle.linalg.cond, x_tensor, p)
115121

122+
def test_dygraph_api_error(self):
123+
with _test_eager_guard():
124+
self.func_dygraph_api_error()
125+
self.func_dygraph_api_error()
126+
116127
def test_static_api_error(self):
117128
paddle.enable_static()
118129
# test raising errors when 'cond' is called in static mode
@@ -149,13 +160,18 @@ def test_static_empty_input_error(self):
149160

150161

151162
class TestCondEmptyTensorInput(unittest.TestCase):
152-
def test_dygraph_empty_tensor_input(self):
163+
def func_dygraph_empty_tensor_input(self):
153164
paddle.disable_static()
154165
# test calling results of 'cond' when input is an empty tensor in dynamic mode
155166
x_list_n_n, x_list_m_n = gen_empty_input()
156167
test_dygraph_assert_true(self, x_list_n_n, p_list_n_n + p_list_m_n)
157168
test_dygraph_assert_true(self, x_list_m_n, p_list_m_n)
158169

170+
def test_dygraph_empty_tensor_input(self):
171+
with _test_eager_guard():
172+
self.func_dygraph_empty_tensor_input()
173+
self.func_dygraph_empty_tensor_input()
174+
159175

160176
if __name__ == "__main__":
161177
paddle.enable_static()

0 commit comments

Comments
 (0)