Skip to content

Commit 116c892

Browse files
authored
tensor_array slice in PIR (#60503)
* use slice_array; currently this hits an error of a destroyed OpResult still being in use * disable the PIR test until the bug is fixed
1 parent a11aabd commit 116c892

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

python/paddle/base/variable_index.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,19 @@ def slice_is_same_to_original(start, end, step):
250250

251251

252252
def parse_index(x, indices):
253-
advanced_index = [None] * 2 * len(x.shape) # content is (dim, index)
253+
from .framework import in_pir_mode
254+
255+
if in_pir_mode():
256+
is_tensor_array = x.is_dense_tensor_array_type()
257+
else:
258+
is_tensor_array = (
259+
hasattr(x, "desc")
260+
and x.desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY
261+
)
262+
263+
advanced_index = (
264+
[] if is_tensor_array else [None] * 2 * len(x.shape)
265+
) # content is (dim, index)
254266
# for set_value / slice / strided_slice OP
255267
decrease_axes = []
256268
axes = []
@@ -267,11 +279,6 @@ def parse_index(x, indices):
267279
indices = replace_ellipsis(x, indices)
268280
indices, none_axes = replace_none(indices)
269281

270-
is_tensor_array = (
271-
hasattr(x, "desc")
272-
and x.desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY
273-
)
274-
275282
estimated_dim = 0
276283
dim = 0
277284
for i, slice_item in enumerate(indices):
@@ -740,6 +747,8 @@ def get_tensor_with_basic_indexing(
740747
if isinstance(end, (list, tuple)):
741748
if paddle.utils._contain_var(end):
742749
end = paddle.utils.get_int_tensor_list(end)
750+
if x.is_dense_tensor_array_type():
751+
return paddle._pir_ops.slice_array_dense(x, st)
743752
out = paddle._C_ops.slice(
744753
x,
745754
axes,

test/dygraph_to_static/test_list.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ def init_dygraph_func(self):
292292
test_list_pop_in_while_loop,
293293
]
294294

295+
# TODO(zhangbo): Refine BuildOpFrom for op with sub_block
295296
def train(self, to_static=False):
296297
with base.dygraph.guard():
297298
if to_static:

0 commit comments

Comments
 (0)