@@ -250,7 +250,19 @@ def slice_is_same_to_original(start, end, step):
250
250
251
251
252
252
def parse_index (x , indices ):
253
- advanced_index = [None ] * 2 * len (x .shape ) # content is (dim, index)
253
+ from .framework import in_pir_mode
254
+
255
+ if in_pir_mode ():
256
+ is_tensor_array = x .is_dense_tensor_array_type ()
257
+ else :
258
+ is_tensor_array = (
259
+ hasattr (x , "desc" )
260
+ and x .desc .type () == core .VarDesc .VarType .LOD_TENSOR_ARRAY
261
+ )
262
+
263
+ advanced_index = (
264
+ [] if is_tensor_array else [None ] * 2 * len (x .shape )
265
+ ) # content is (dim, index)
254
266
# for set_value / slice / strided_slice OP
255
267
decrease_axes = []
256
268
axes = []
@@ -267,11 +279,6 @@ def parse_index(x, indices):
267
279
indices = replace_ellipsis (x , indices )
268
280
indices , none_axes = replace_none (indices )
269
281
270
- is_tensor_array = (
271
- hasattr (x , "desc" )
272
- and x .desc .type () == core .VarDesc .VarType .LOD_TENSOR_ARRAY
273
- )
274
-
275
282
estimated_dim = 0
276
283
dim = 0
277
284
for i , slice_item in enumerate (indices ):
@@ -740,6 +747,8 @@ def get_tensor_with_basic_indexing(
740
747
if isinstance (end , (list , tuple )):
741
748
if paddle .utils ._contain_var (end ):
742
749
end = paddle .utils .get_int_tensor_list (end )
750
+ if x .is_dense_tensor_array_type ():
751
+ return paddle ._pir_ops .slice_array_dense (x , st )
743
752
out = paddle ._C_ops .slice (
744
753
x ,
745
754
axes ,
0 commit comments