Skip to content

Commit 763b6d9

Browse files
authored
fix potential tensor leak in tensor.__setitem__ (#35013)
* fix index tensor leak in __setitem__ * fix another usage of PyTuple_Pack * refine code * refine code * handle None index * add Py_DecRef * revert ut * refine code * merge develop * use RAII * follow comments
1 parent 4bfd044 commit 763b6d9

File tree

1 file changed

+60
-7
lines changed

1 file changed

+60
-7
lines changed

paddle/fluid/pybind/imperative.cc

+60-7
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ limitations under the License. */
2929
#include <utility>
3030
#include <vector>
3131

32+
#include "paddle/fluid/framework/scope_guard.h"
3233
#include "paddle/fluid/imperative/all_reduce.h"
3334
#include "paddle/fluid/imperative/amp_auto_cast.h"
3435
#include "paddle/fluid/imperative/basic_engine.h"
@@ -424,7 +425,15 @@ static void ParseIndexingSlice(
424425
// We allow indexing by Integers, Slices, Ellipsis, None, tuples of those
425426
// types, and list of Bool and Integers.
426427
// wrap to tuple
428+
429+
// NOTE(zhiqiu): PyTuple_Pack increases refcount.
427430
PyObject *index = !PyTuple_Check(_index) ? PyTuple_Pack(1, _index) : _index;
431+
DEFINE_PADDLE_SCOPE_GUARD([index, _index]() {
432+
if (!PyTuple_Check(_index)) {
433+
Py_DECREF(index);
434+
VLOG(4) << "Call Py_DECREF";
435+
}
436+
});
428437
PADDLE_ENFORCE_EQ(
429438
tensor->IsInitialized(), true,
430439
platform::errors::InvalidArgument("tensor has not been initialized"));
@@ -550,8 +559,6 @@ static void ParseIndexingSlice(
550559
platform::errors::InvalidArgument(
551560
"Too many indices (%d) for tensor of dimension %d.",
552561
valid_indexs, rank));
553-
554-
if (!PyTuple_Check(_index)) Py_DecRef(index);
555562
}
556563

557564
template <typename P>
@@ -811,11 +818,21 @@ void BindImperative(py::module *m_ptr) {
811818
.def("__setitem__",
812819
[](std::shared_ptr<imperative::VarBase> &self, py::handle _index,
813820
py::object &value_obj) {
821+
VLOG(4) << "Call __setitem__";
822+
814823
auto self_tensor =
815824
self->MutableVar()->GetMutable<framework::LoDTensor>();
825+
// NOTE(zhiqiu): PyTuple_Pack increases refcount while PyTuple_New
826+
// https://github.com/python/cpython/blob/24b63c695ae0a95b06379eaadace66735abac1e2/Objects/tupleobject.c#L251
816827
PyObject *index_ptr = !PyTuple_Check(_index.ptr())
817828
? PyTuple_Pack(1, _index.ptr())
818829
: _index.ptr();
830+
DEFINE_PADDLE_SCOPE_GUARD([index_ptr, &_index]() {
831+
if (!PyTuple_Check(_index.ptr())) {
832+
Py_DECREF(index_ptr);
833+
VLOG(4) << "Call Py_DECREF";
834+
}
835+
});
819836
// 1. Check argumnets
820837
// 1.1 Check whether value obj is a tensor.
821838
bool value_is_tensor = true;
@@ -826,6 +843,18 @@ void BindImperative(py::module *m_ptr) {
826843
value_is_tensor = false;
827844
}
828845

846+
auto is_tensor = [](py::handle var) {
847+
if (!var.ptr() || var.ptr() == Py_None) {
848+
return false;
849+
}
850+
try {
851+
py::cast<std::shared_ptr<imperative::VarBase>>(var);
852+
return true;
853+
} catch (py::cast_error &) {
854+
return false;
855+
}
856+
};
857+
829858
// 1.2 Check whether _index can be parsed.
830859
const int size = PyTuple_GET_SIZE(index_ptr);
831860
for (int dim = 0; dim < size; ++dim) {
@@ -842,6 +871,7 @@ void BindImperative(py::module *m_ptr) {
842871
// TODO(liym27): Try not to call TensorToPyArray because it always
843872
// copys data to cpu place, which reduces performance.
844873
if (parse_index && value_is_tensor) {
874+
VLOG(4) << "index is integer/slice/ellipsis and value is tensor";
845875
std::vector<int> axes, starts, ends, steps, decrease_axes,
846876
none_axes, infer_flags, list_select_idxs;
847877
// if index is a list, list_select_flag will be true
@@ -850,7 +880,6 @@ void BindImperative(py::module *m_ptr) {
850880
&steps, &decrease_axes, &none_axes,
851881
&infer_flags, &list_select_idxs,
852882
&list_select_flag);
853-
854883
framework::AttributeMap attrs = {
855884
{"axes", axes},
856885
{"starts", starts},
@@ -882,20 +911,43 @@ void BindImperative(py::module *m_ptr) {
882911
}
883912
} else {
884913
auto self_numpy = TensorToPyArray(*self_tensor);
914+
VLOG(4) << "parse_index is false";
885915

886916
if (value_is_tensor) {
917+
VLOG(4) << "value is tensor";
887918
auto value =
888919
value_obj.cast<std::shared_ptr<imperative::VarBase>>();
889920
auto value_tensor =
890921
value->MutableVar()->GetMutable<framework::LoDTensor>();
891922
auto value_numpy = TensorToPyArray(*value_tensor);
892-
893-
self_numpy[_index] = value_numpy;
923+
if (is_tensor(_index)) {
924+
VLOG(4) << "index is tensor";
925+
auto index_var =
926+
py::cast<std::shared_ptr<imperative::VarBase>>(_index);
927+
auto index_tensor = index_var->MutableVar()
928+
->GetMutable<framework::LoDTensor>();
929+
auto index_numpy = TensorToPyArray(*index_tensor);
930+
self_numpy[index_numpy] = value_numpy;
931+
} else {
932+
VLOG(4) << "index is not tensor";
933+
self_numpy[_index] = value_numpy;
934+
}
894935
SetTensorFromPyArray(self_tensor, self_numpy,
895936
self_tensor->place(), true);
896937
} else {
897-
auto value_numpy = value_obj;
898-
self_numpy[_index] = value_numpy;
938+
VLOG(4) << "value is not tensor";
939+
if (is_tensor(_index)) {
940+
VLOG(4) << "index is tensor";
941+
auto index_var =
942+
py::cast<std::shared_ptr<imperative::VarBase>>(_index);
943+
auto index_tensor = index_var->MutableVar()
944+
->GetMutable<framework::LoDTensor>();
945+
auto index_numpy = TensorToPyArray(*index_tensor);
946+
self_numpy[index_numpy] = value_obj;
947+
} else {
948+
VLOG(4) << "index is not tensor";
949+
self_numpy[_index] = value_obj;
950+
}
899951
SetTensorFromPyArray(self_tensor, self_numpy,
900952
self_tensor->place(), true);
901953
}
@@ -907,6 +959,7 @@ void BindImperative(py::module *m_ptr) {
907959
})
908960
.def("_getitem_index_not_tensor",
909961
[](std::shared_ptr<imperative::VarBase> &self, py::handle _index) {
962+
VLOG(4) << "Call _getitem_index_not_tensor";
910963
std::vector<int> slice_axes, slice_starts, slice_ends,
911964
slice_strides, decrease_axis, none_axes, infer_flags,
912965
list_select_idxs;

0 commit comments

Comments
 (0)