@@ -1452,6 +1452,15 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
   PyObject* _index = PyTuple_GET_ITEM(args, 0);
   VLOG(4) << "Call new indexing strategy _getitem_dygraph";

+  PyObject* index_ptr =
+      !PyTuple_Check(_index) ? PyTuple_Pack(1, _index) : _index;
+  DEFINE_PADDLE_SCOPE_GUARD([index_ptr, &_index]() {
+    if (!PyTuple_Check(_index)) {
+      Py_DECREF(index_ptr);
+      VLOG(4) << "Call Py_DECREF";
+    }
+  });
+
   // Note(0x45f): Using defined() instead of initialized()
   // to support slice tensor which shape like [0, 0, 0].
   PADDLE_ENFORCE_EQ(
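The hunk above normalizes a non-tuple index into a one-element tuple via PyTuple_Pack and registers a scope guard that releases the extra reference on every exit path, so downstream parsing only ever sees a tuple. A rough Python sketch of the normalization, purely illustrative (the reference-counting part has no Python analogue):

```python
def normalize_index(index):
    """Illustrative mirror of the C++ normalization: wrap any
    non-tuple index in a 1-tuple so later parsing is uniform."""
    return index if isinstance(index, tuple) else (index,)

assert normalize_index(3) == (3,)
assert normalize_index(slice(0, 2)) == (slice(0, 2),)
assert normalize_index((1, 2)) == (1, 2)  # tuples pass through unchanged
```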
@@ -1476,7 +1485,7 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,

   // step1: parsing the index and recording them
   ParseIndex(tensor,
-             _index,
+             index_ptr,
              &slice_axes,
              &slice_starts,
              &slice_ends,
@@ -1489,6 +1498,23 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
              &has_advanced_index,
              &use_strided_slice);

+  // Special: Check if the index is single bool
+  if (PyTuple_GET_SIZE(_index) == 1 &&
+      PyBool_Check(PyTuple_GetItem(_index, 0))) {
+    if (PyTuple_GetItem(_index, 0) == Py_True) {
+      // unsqueeze the tensor to a new tensor with shape (1,)
+      paddle::Tensor out;
+      out.copy_(unsqueeze_ad_func(tensor, {0}), tensor.place(), false);
+      return ToPyObject(out);
+    } else {
+      // create a new tensor with shape (0,)
+      auto shape = tensor.shape();
+      shape.insert(shape.begin(), 0);
+      auto out = paddle::empty(shape, tensor.dtype(), tensor.place());
+      return ToPyObject(out);
+    }
+  }
+
   // step2: Dealing with basic indexing
   bool out_is_view = false;
   auto out = getTensorWithBasicIndexing(tensor,
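The new branch above gives a single boolean index NumPy-like semantics: `True` returns the tensor unsqueezed at axis 0, while `False` returns an empty tensor with a leading 0 dimension. A hedged usage sketch of the expected behavior (assuming a Paddle build that includes this change):

```python
import paddle

x = paddle.ones([2, 3])

# A True index unsqueezes at axis 0, like x[None].
print(x[True].shape)   # expected: [1, 2, 3]

# A False index yields an empty result with a leading 0 dimension.
print(x[False].shape)  # expected: [0, 2, 3]
```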
@@ -1748,6 +1774,7 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
                        tensor.name()));
   }
   const int rank = tensor.shape().size();
+  const int size = PyTuple_GET_SIZE(index_ptr);
   std::vector<int> slice_starts, slice_ends, slice_strides;
   std::vector<int64_t> slice_axes, decrease_axis, infer_flags, none_axes;

@@ -1760,7 +1787,7 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,

   // step1: parsing the index and recording them
   ParseIndex(tensor,
-             _index,
+             index_ptr,
              &slice_axes,
              &slice_starts,
              &slice_ends,
@@ -1808,14 +1835,18 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
       if (InputsContainDistTensor(&mesh, self->tensor, value_tensor)) {
         ConvertAllInputsToDistTensor(mesh, self->tensor, value_tensor);
       }
-      self->tensor = set_value_with_tensor__ad_func(self->tensor,
-                                                    value_tensor,
-                                                    slice_starts,
-                                                    slice_ends,
-                                                    slice_strides,
-                                                    slice_axes,
-                                                    decrease_axis,
-                                                    none_axes);
+      if (size == 1 && PyTuple_GetItem(index_ptr, 0) == Py_False) {
+        // do nothing
+      } else {
+        self->tensor = set_value_with_tensor__ad_func(self->tensor,
+                                                      value_tensor,
+                                                      slice_starts,
+                                                      slice_ends,
+                                                      slice_strides,
+                                                      slice_axes,
+                                                      decrease_axis,
+                                                      none_axes);
+      }
       if (PyCheckTensor(value_obj)) {
         // pass the stop_gradient from value to tensor.
         // pass stop gradient should be done after CheckInplace in
@@ -1830,15 +1861,19 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
       if (InputsContainDistTensor(&mesh, self->tensor)) {
         ConvertAllInputsToDistTensor(mesh, self->tensor);
       }
-      self->tensor = set_value__ad_func(self->tensor,
-                                        slice_starts,
-                                        slice_ends,
-                                        slice_strides,
-                                        slice_axes,
-                                        decrease_axis,
-                                        none_axes,
-                                        {1},
-                                        values);
+      if (size == 1 && PyTuple_GetItem(index_ptr, 0) == Py_False) {
+        // do nothing
+      } else {
+        self->tensor = set_value__ad_func(self->tensor,
+                                          slice_starts,
+                                          slice_ends,
+                                          slice_strides,
+                                          slice_axes,
+                                          decrease_axis,
+                                          none_axes,
+                                          {1},
+                                          values);
+      }
     }
   } else {
     // step3.2: Case for there are advanced indexing.
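On the __setitem__ side, both the tensor-valued and scalar-valued paths now skip the set_value call entirely when the index is a single `False` (the "do nothing" branches), since such an index selects zero elements. A hedged sketch of the resulting user-facing behavior (same assumptions as above):

```python
import paddle

x = paddle.zeros([2, 3])

# A single False index selects nothing, so the assignment is a no-op
# and x keeps its original values.
x[False] = 1.0
print(x.sum())   # expected: 0.0

# Ordinary slice assignment is unaffected by the change.
x[:, 0] = 1.0
print(x.sum())   # expected: 2.0
```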