Skip to content

Commit 332fa7e

Browse files
authored
support getitem for OpResult in newIR (#57964)
1 parent bbdf952 commit 332fa7e

File tree

4 files changed

+41
-10
lines changed

4 files changed

+41
-10
lines changed

python/paddle/base/variable_index.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1032,7 +1032,9 @@ def get_tensor_with_basic_indexing(
10321032
)
10331033
attrs['infer_flags'] = infer_flags
10341034

1035-
if paddle.in_dynamic_mode():
1035+
from . import in_dynamic_or_pir_mode
1036+
1037+
if in_dynamic_or_pir_mode():
10361038
if "StartsTensorList" in inputs.keys():
10371039
st = inputs['StartsTensorList']
10381040
else:

python/paddle/pir/math_op_patch.py

+20-9
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,25 @@
1414

1515
from . import OpResult
1616

17+
_already_patch_opresult = False
18+
1719

1820
def monkey_patch_opresult():
19-
# Handling Tensor Methods
20-
import paddle.tensor
21-
22-
for method_name in paddle.tensor.tensor_method_func:
23-
if hasattr(OpResult, method_name):
24-
continue
25-
method_impl = getattr(paddle.tensor, method_name, None)
26-
if method_impl:
27-
setattr(OpResult, method_name, method_impl)
21+
global _already_patch_opresult
22+
if not _already_patch_opresult:
23+
# Handling Tensor Methods
24+
import paddle.tensor
25+
26+
for method_name in paddle.tensor.tensor_method_func:
27+
if hasattr(OpResult, method_name):
28+
continue
29+
method_impl = getattr(paddle.tensor, method_name, None)
30+
if method_impl:
31+
setattr(OpResult, method_name, method_impl)
32+
33+
# Handling __getitem__
34+
from ..base.variable_index import _getitem_static
35+
36+
OpResult.__getitem__ = _getitem_static
37+
38+
_already_patch_opresult = True

python/setup.py.in

+1
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,7 @@ packages=['paddle',
429429
'paddle.framework',
430430
'paddle.jit',
431431
'paddle.jit.dy2static',
432+
'paddle.jit.newir_dy2static',
432433
'paddle.inference',
433434
'paddle.inference.contrib',
434435
'paddle.inference.contrib.utils',

test/ir/new_ir/test_new_ir_to_static.py

+17
Original file line numberDiff line numberDiff line change
@@ -195,5 +195,22 @@ def train_step(to_static=True):
195195
)
196196

197197

198+
class TestDy2staticNewIR6(unittest.TestCase):
199+
# test basic-indexing __getitem__ for OpResult
200+
def test_basic_network(self):
201+
def func(x):
202+
shape = paddle.shape(x)
203+
out = shape[1:]
204+
return out
205+
206+
static_func = paddle.jit.to_static(func)
207+
x = paddle.randn((2, 3, 4))
208+
x.stop_gradient = False
209+
ans = func(x)
210+
out = static_func(x)
211+
212+
np.testing.assert_allclose(out.numpy(), ans.numpy())
213+
214+
198215
if __name__ == "__main__":
199216
unittest.main()

0 commit comments

Comments
 (0)