Skip to content

Commit 3a6092c

Browse files
authored
1 parent f137fe4 commit 3a6092c

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

python/paddle/tensor/manipulation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4461,7 +4461,7 @@ def as_complex(x, name=None):
44614461
[[1j , (2+3j) , (4+5j) ],
44624462
[(6+7j) , (8+9j) , (10+11j)]])
44634463
"""
4464-
if in_dynamic_mode():
4464+
if in_dynamic_or_pir_mode():
44654465
return _C_ops.as_complex(x)
44664466
else:
44674467
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'as_complex')
@@ -4512,7 +4512,7 @@ def as_real(x, name=None):
45124512
[8. , 9. ],
45134513
[10., 11.]]])
45144514
"""
4515-
if in_dynamic_mode():
4515+
if in_dynamic_or_pir_mode():
45164516
return _C_ops.as_real(x)
45174517
else:
45184518
check_variable_and_dtype(x, 'x', ['complex64', 'complex128'], 'as_real')

test/legacy_test/test_complex_view_op.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import paddle
2121
from paddle import static
2222
from paddle.base import dygraph
23+
from paddle.pir_utils import test_with_pir_api
2324

2425
paddle.enable_static()
2526

@@ -43,7 +44,7 @@ def setUp(self):
4344
self.outputs = {'Out': out_ref}
4445

4546
def test_check_output(self):
46-
self.check_output()
47+
self.check_output(check_pir=True)
4748

4849
def test_check_grad(self):
4950
self.check_grad(
@@ -64,7 +65,7 @@ def setUp(self):
6465
self.python_api = paddle.as_real
6566

6667
def test_check_output(self):
67-
self.check_output()
68+
self.check_output(check_pir=True)
6869

6970
def test_check_grad(self):
7071
self.check_grad(
@@ -84,6 +85,7 @@ def test_dygraph(self):
8485
out_np = paddle.as_complex(x).numpy()
8586
np.testing.assert_allclose(self.out, out_np, rtol=1e-05)
8687

88+
@test_with_pir_api
8789
def test_static(self):
8890
mp, sp = static.Program(), static.Program()
8991
with static.program_guard(mp, sp):
@@ -107,6 +109,7 @@ def test_dygraph(self):
107109
out_np = paddle.as_real(x).numpy()
108110
np.testing.assert_allclose(self.out, out_np, rtol=1e-05)
109111

112+
@test_with_pir_api
110113
def test_static(self):
111114
mp, sp = static.Program(), static.Program()
112115
with static.program_guard(mp, sp):

0 commit comments

Comments
 (0)