Skip to content

Commit fccf664

Browse files
authored
[BugFix] fix tensor_array slice bugs in _getitem_impl_ (#46447)
* fix tensor_array slice bugs in _getitem_impl_ * fix when var is a paddle.Tensor * code format
1 parent 97004f6 commit fccf664

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py

+14
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
1516
import unittest
1617

1718
import paddle
@@ -124,6 +125,14 @@ def test_list_append_in_while_loop_with_stack(x, iter_num):
124125
return out
125126

126127

128+
def test_tensor_array_slice(x, iter_num):
129+
a = []
130+
for i in range(paddle.to_tensor(3)):
131+
a.append(paddle.to_tensor(i))
132+
t = a[1:3]
133+
return a[2]
134+
135+
127136
# Situation 2: Test list pop
128137
def test_list_pop_without_control_flow_1(x):
129138
x = fluid.dygraph.to_variable(x)
@@ -292,6 +301,11 @@ def init_dygraph_func(self):
292301
self.all_dygraph_funcs = [test_list_append_in_while_loop_with_stack]
293302

294303

304+
class TestTensorArraySlice(TestListInWhileLoop):
305+
def init_dygraph_func(self):
306+
self.all_dygraph_funcs = [test_tensor_array_slice]
307+
308+
295309
class TestListInForLoop(TestListInWhileLoop):
296310
def init_dygraph_func(self):
297311
self.all_dygraph_funcs = [

python/paddle/fluid/variable_index.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,10 @@ def _getitem_impl_(var, item):
380380
item = replace_ellipsis(var, item)
381381
item, none_axes = replace_none(item)
382382
slice_info = SliceInfo()
383+
is_tensor_array = (
384+
hasattr(var, "desc")
385+
and var.desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY
386+
)
383387

384388
for dim, slice_item in enumerate(item):
385389
if is_integer_or_scalar_tensor(slice_item) and not is_bool_tensor(
@@ -390,13 +394,13 @@ def _getitem_impl_(var, item):
390394
and var.shape[dim] is not None
391395
and var.shape[dim] >= 0
392396
and slice_item >= var.shape[dim]
397+
and not is_tensor_array
393398
):
394399
# For python, if users write a, b = var, the __getitem__
395400
# method will iterate through 0, 1, 2 ... until __getitem__
396401
# throws an IndexError, then stop. The var[0], var[1] will
397402
# be given to a, b respectively. If more values are given,
398403
# the unpack size would cause error.
399-
#
400404
# We raises IndexError here to support grammar like `a, b = var`
401405
raise IndexError(
402406
"slice_item %d at dim %d should be >= 0 and < var.shape[%d]: %d"
@@ -422,7 +426,7 @@ def _getitem_impl_(var, item):
422426
if end is None:
423427
if var.shape[dim] != -1 and (
424428
paddle.fluid.framework._non_static_mode()
425-
or var.desc.type() != core.VarDesc.VarType.LOD_TENSOR_ARRAY
429+
or not is_tensor_array
426430
):
427431
end = var.shape[dim] if step > 0 else -1
428432
else:

0 commit comments

Comments
 (0)