Skip to content

Commit d0dc9c5

Browse files
[PIR slice] Optimize bool index logic for setitem and getitem (#72644)
* add single bool index logic && optimize set_tensor_value_op * add single bool branch for getitem --------- Co-authored-by: Eddie-Wang1120 <wangjinheng1120@163.com>
1 parent f51e3ff commit d0dc9c5

File tree

3 files changed

+67
-29
lines changed

3 files changed

+67
-29
lines changed

paddle/fluid/pybind/eager_method.cc

+54-19
Original file line numberDiff line numberDiff line change
@@ -1452,6 +1452,15 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
14521452
PyObject* _index = PyTuple_GET_ITEM(args, 0);
14531453
VLOG(4) << "Call new indexing strategy _getitem_dygraph";
14541454

1455+
PyObject* index_ptr =
1456+
!PyTuple_Check(_index) ? PyTuple_Pack(1, _index) : _index;
1457+
DEFINE_PADDLE_SCOPE_GUARD([index_ptr, &_index]() {
1458+
if (!PyTuple_Check(_index)) {
1459+
Py_DECREF(index_ptr);
1460+
VLOG(4) << "Call Py_DECREF";
1461+
}
1462+
});
1463+
14551464
// Note(0x45f): Using defined() instead of initialized()
14561465
// to support slice tensor which shape like [0, 0, 0].
14571466
PADDLE_ENFORCE_EQ(
@@ -1476,7 +1485,7 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
14761485

14771486
// step1: parsing the index and recording them
14781487
ParseIndex(tensor,
1479-
_index,
1488+
index_ptr,
14801489
&slice_axes,
14811490
&slice_starts,
14821491
&slice_ends,
@@ -1489,6 +1498,23 @@ static PyObject* tensor__getitem_dygraph(TensorObject* self,
14891498
&has_advanced_index,
14901499
&use_strided_slice);
14911500

1501+
// Special: Check if the index is single bool
1502+
if (PyTuple_GET_SIZE(_index) == 1 &&
1503+
PyBool_Check(PyTuple_GetItem(_index, 0))) {
1504+
if (PyTuple_GetItem(_index, 0) == Py_True) {
1505+
// unsqueeze the tensor to a new tensor with shape (1,)
1506+
paddle::Tensor out;
1507+
out.copy_(unsqueeze_ad_func(tensor, {0}), tensor.place(), false);
1508+
return ToPyObject(out);
1509+
} else {
1510+
// create a new tensor with shape (0,)
1511+
auto shape = tensor.shape();
1512+
shape.insert(shape.begin(), 0);
1513+
auto out = paddle::empty(shape, tensor.dtype(), tensor.place());
1514+
return ToPyObject(out);
1515+
}
1516+
}
1517+
14921518
// step2: Dealing with basic indexing
14931519
bool out_is_view = false;
14941520
auto out = getTensorWithBasicIndexing(tensor,
@@ -1748,6 +1774,7 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
17481774
tensor.name()));
17491775
}
17501776
const int rank = tensor.shape().size();
1777+
const int size = PyTuple_GET_SIZE(index_ptr);
17511778
std::vector<int> slice_starts, slice_ends, slice_strides;
17521779
std::vector<int64_t> slice_axes, decrease_axis, infer_flags, none_axes;
17531780

@@ -1760,7 +1787,7 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
17601787

17611788
// step1: parsing the index and recording them
17621789
ParseIndex(tensor,
1763-
_index,
1790+
index_ptr,
17641791
&slice_axes,
17651792
&slice_starts,
17661793
&slice_ends,
@@ -1808,14 +1835,18 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
18081835
if (InputsContainDistTensor(&mesh, self->tensor, value_tensor)) {
18091836
ConvertAllInputsToDistTensor(mesh, self->tensor, value_tensor);
18101837
}
1811-
self->tensor = set_value_with_tensor__ad_func(self->tensor,
1812-
value_tensor,
1813-
slice_starts,
1814-
slice_ends,
1815-
slice_strides,
1816-
slice_axes,
1817-
decrease_axis,
1818-
none_axes);
1838+
if (size == 1 && PyTuple_GetItem(index_ptr, 0) == Py_False) {
1839+
// do nothing
1840+
} else {
1841+
self->tensor = set_value_with_tensor__ad_func(self->tensor,
1842+
value_tensor,
1843+
slice_starts,
1844+
slice_ends,
1845+
slice_strides,
1846+
slice_axes,
1847+
decrease_axis,
1848+
none_axes);
1849+
}
18191850
if (PyCheckTensor(value_obj)) {
18201851
// pass the stop_gradient from value to tensor.
18211852
// pass stop gradient should be done after CheckInplace in
@@ -1830,15 +1861,19 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
18301861
if (InputsContainDistTensor(&mesh, self->tensor)) {
18311862
ConvertAllInputsToDistTensor(mesh, self->tensor);
18321863
}
1833-
self->tensor = set_value__ad_func(self->tensor,
1834-
slice_starts,
1835-
slice_ends,
1836-
slice_strides,
1837-
slice_axes,
1838-
decrease_axis,
1839-
none_axes,
1840-
{1},
1841-
values);
1864+
if (size == 1 && PyTuple_GetItem(index_ptr, 0) == Py_False) {
1865+
// do nothing
1866+
} else {
1867+
self->tensor = set_value__ad_func(self->tensor,
1868+
slice_starts,
1869+
slice_ends,
1870+
slice_strides,
1871+
slice_axes,
1872+
decrease_axis,
1873+
none_axes,
1874+
{1},
1875+
values);
1876+
}
18421877
}
18431878
} else {
18441879
// step3.2: Case for there are advanced indexing.

paddle/fluid/pybind/slice_utils.h

+6-10
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ static int _PySlice_GetIndices(PySliceObject* r,
171171
}
172172

173173
static void ParseIndex(const paddle::Tensor& tensor,
174-
PyObject* _index,
174+
PyObject* index,
175175
std::vector<int64_t>* slice_axes,
176176
std::vector<int>* slice_starts,
177177
std::vector<int>* slice_ends,
@@ -183,14 +183,6 @@ static void ParseIndex(const paddle::Tensor& tensor,
183183
std::vector<paddle::Tensor>* advanced_index,
184184
bool* has_advanced_index,
185185
bool* use_strided_slice) {
186-
// NOTE(zhiqiu): PyTuple_Pack increases refcount.
187-
PyObject* index = !PyTuple_Check(_index) ? PyTuple_Pack(1, _index) : _index;
188-
DEFINE_PADDLE_SCOPE_GUARD([index, _index]() {
189-
if (!PyTuple_Check(_index)) {
190-
Py_DECREF(index);
191-
VLOG(4) << "Call Py_DECREF";
192-
}
193-
});
194186
// for case 0-size tensor in slice
195187
PADDLE_ENFORCE_EQ(
196188
tensor.defined(),
@@ -199,7 +191,11 @@ static void ParseIndex(const paddle::Tensor& tensor,
199191
const auto& shape = tensor.dims();
200192
const int rank = shape.size();
201193
const int size = PyTuple_GET_SIZE(index);
202-
194+
if (size == 1 && PyBool_Check(PyTuple_GetItem(index, 0))) {
195+
// true and none using set_value full_set branch
196+
// false do nothing
197+
return;
198+
}
203199
// Check Ellipsis is valid
204200
int specified_dims = 0;
205201
int ell_count = 0;

paddle/phi/kernels/impl/set_value_kernel_impl.h

+7
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,13 @@ void SetValueImpl(const Context& dev_ctx,
8686
std::vector<int64_t> starts_local = starts.GetData();
8787
std::vector<int64_t> ends_local = ends.GetData();
8888
std::vector<int64_t> steps_local = steps.GetData();
89+
if (starts_local.empty() && ends_local.empty() && steps_local.empty() &&
90+
axes.empty() && decrease_axes.empty() && none_axes.empty() &&
91+
value.numel() == 1) {
92+
ExpandKernel<T, Context>(
93+
dev_ctx, value, IntArray{phi::vectorize<int64_t>(in.dims())}, out);
94+
return;
95+
}
8996
phi::funcs::CheckAndUpdateSliceAttrs(
9097
in_dims, axes, &starts_local, &ends_local, &steps_local);
9198
auto slice_dims = phi::funcs::GetSliceDims(

0 commit comments

Comments
 (0)