File tree 4 files changed +41
-10
lines changed
4 files changed +41
-10
lines changed Original file line number Diff line number Diff line change @@ -1032,7 +1032,9 @@ def get_tensor_with_basic_indexing(
1032
1032
)
1033
1033
attrs ['infer_flags' ] = infer_flags
1034
1034
1035
- if paddle .in_dynamic_mode ():
1035
+ from . import in_dynamic_or_pir_mode
1036
+
1037
+ if in_dynamic_or_pir_mode ():
1036
1038
if "StartsTensorList" in inputs .keys ():
1037
1039
st = inputs ['StartsTensorList' ]
1038
1040
else :
Original file line number Diff line number Diff line change 14
14
15
15
from . import OpResult
16
16
17
+ _already_patch_opresult = False
18
+
17
19
18
20
def monkey_patch_opresult ():
19
- # Handling Tensor Methods
20
- import paddle .tensor
21
-
22
- for method_name in paddle .tensor .tensor_method_func :
23
- if hasattr (OpResult , method_name ):
24
- continue
25
- method_impl = getattr (paddle .tensor , method_name , None )
26
- if method_impl :
27
- setattr (OpResult , method_name , method_impl )
21
+ global _already_patch_opresult
22
+ if not _already_patch_opresult :
23
+ # Handling Tensor Methods
24
+ import paddle .tensor
25
+
26
+ for method_name in paddle .tensor .tensor_method_func :
27
+ if hasattr (OpResult , method_name ):
28
+ continue
29
+ method_impl = getattr (paddle .tensor , method_name , None )
30
+ if method_impl :
31
+ setattr (OpResult , method_name , method_impl )
32
+
33
+ # Handling __getitem__
34
+ from ..base .variable_index import _getitem_static
35
+
36
+ OpResult .__getitem__ = _getitem_static
37
+
38
+ _already_patch_opresult = True
Original file line number Diff line number Diff line change @@ -429,6 +429,7 @@ packages=['paddle',
429
429
'paddle.framework',
430
430
'paddle.jit',
431
431
'paddle.jit.dy2static',
432
+ 'paddle.jit.newir_dy2static',
432
433
'paddle.inference',
433
434
'paddle.inference.contrib',
434
435
'paddle.inference.contrib.utils',
Original file line number Diff line number Diff line change @@ -195,5 +195,22 @@ def train_step(to_static=True):
195
195
)
196
196
197
197
198
+ class TestDy2staticNewIR6 (unittest .TestCase ):
199
+ # test basic-indexing __getitem__ for OpResult
200
+ def test_basic_network (self ):
201
+ def func (x ):
202
+ shape = paddle .shape (x )
203
+ out = shape [1 :]
204
+ return out
205
+
206
+ static_func = paddle .jit .to_static (func )
207
+ x = paddle .randn ((2 , 3 , 4 ))
208
+ x .stop_gradient = False
209
+ ans = func (x )
210
+ out = static_func (x )
211
+
212
+ np .testing .assert_allclose (out .numpy (), ans .numpy ())
213
+
214
+
198
215
if __name__ == "__main__" :
199
216
unittest .main ()
You can’t perform that action at this time.
0 commit comments