Skip to content

Commit 7ad9a5f

Browse files
authored
[PIR]Migrate einsum_v2 into pir (PaddlePaddle#58501)
1 parent 8b56e79 commit 7ad9a5f

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

python/paddle/tensor/einsum.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from paddle import _C_ops
2424

2525
from ..base.data_feeder import check_type, check_variable_and_dtype
26-
from ..base.framework import in_dygraph_mode
26+
from ..base.framework import in_dynamic_or_pir_mode
2727
from ..base.layer_helper import LayerHelper
2828
from .linalg import matmul, transpose
2929
from .manipulation import reshape, squeeze, unsqueeze
@@ -832,7 +832,7 @@ def gen_einsum_op(equation, *operands):
832832
EinsumOp Python Interface:
833833
"""
834834

835-
if in_dygraph_mode():
835+
if in_dynamic_or_pir_mode():
836836
return _C_ops.einsum(operands, equation)[0]
837837
else:
838838
assert len(operands) <= 2, "Only support two operands in EinsumOp."

test/legacy_test/test_einsum_v2.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import paddle
2121
from paddle.base import core
22+
from paddle.pir_utils import test_with_pir_api
2223

2324
os.environ['FLAGS_new_einsum'] = "1"
2425

@@ -382,7 +383,7 @@ def check_output_equal(self, actual, expect, rtol=1.0e-5, atol=1.0e-8):
382383
rtol=rtol,
383384
atol=atol,
384385
err_msg=error_msg.format(
385-
paddle.get_device(), expect, actual, self.__class__.__name__
386+
self._get_place(False), expect, actual, self.__class__.__name__
386387
),
387388
)
388389

@@ -465,6 +466,7 @@ def test_sums(self):
465466
self.check_output("i,ij->", y, x)
466467
self.check_output("ij,i->", x, y)
467468

469+
@test_with_pir_api
468470
def test_static_graph(self):
469471
paddle.enable_static()
470472
base = paddle.base
@@ -523,11 +525,12 @@ def setUp(self):
523525
def tearDown(self):
524526
paddle.disable_static()
525527

528+
@test_with_pir_api
526529
def test_shape(self):
527530
A = paddle.static.data(name='x', shape=[-1])
528531
B = paddle.static.data(name='y', shape=[384])
529532
C = paddle.einsum('i,d->id', A, B)
530-
self.assertEqual(C.shape, (-1, 384))
533+
self.assertEqual(tuple(C.shape), (-1, 384))
531534

532535

533536
@unittest.skipIf(

0 commit comments

Comments
 (0)