fix potential tensor leak in tensor.__setitem__ #35013

Merged
merged 11 commits into from Aug 25, 2021
67 changes: 60 additions & 7 deletions paddle/fluid/pybind/imperative.cc
@@ -29,6 +29,7 @@ limitations under the License. */
#include <utility>
#include <vector>

#include "paddle/fluid/framework/scope_guard.h"
#include "paddle/fluid/imperative/all_reduce.h"
#include "paddle/fluid/imperative/amp_auto_cast.h"
#include "paddle/fluid/imperative/basic_engine.h"
@@ -424,7 +425,15 @@ static void ParseIndexingSlice(
// We allow indexing by Integers, Slices, Ellipsis, None, tuples of those
// types, and lists of Bools and Integers.
// wrap to tuple

// NOTE(zhiqiu): PyTuple_Pack increases refcount.
PyObject *index = !PyTuple_Check(_index) ? PyTuple_Pack(1, _index) : _index;
DEFINE_PADDLE_SCOPE_GUARD([index, _index]() {
if (!PyTuple_Check(_index)) {
Py_DECREF(index);
VLOG(4) << "Call Py_DECREF";
}
});
PADDLE_ENFORCE_EQ(
tensor->IsInitialized(), true,
platform::errors::InvalidArgument("tensor has not been initialized"));
@@ -550,8 +559,6 @@ static void ParseIndexingSlice(
platform::errors::InvalidArgument(
"Too many indices (%d) for tensor of dimension %d.",
valid_indexs, rank));

if (!PyTuple_Check(_index)) Py_DecRef(index);
}

template <typename P>
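For context, the sketch below (a hypothetical standalone guard, not Paddle's DEFINE_PADDLE_SCOPE_GUARD macro) illustrates the pattern the two hunks above add and remove: PyTuple_Pack returns a new reference when a non-tuple index is wrapped into a one-element tuple, so that tuple must be released on every exit path. Tying the Py_DECREF to a scope guard makes the release happen on scope exit, whereas the old trailing Py_DecRef at the end of ParseIndexingSlice was skipped whenever PADDLE_ENFORCE raised or the function returned early.

// Sketch only: hypothetical standalone guard standing in for
// DEFINE_PADDLE_SCOPE_GUARD, to show the refcount contract being relied on.
#include <Python.h>
#include <functional>
#include <utility>

struct ScopeGuard {
  explicit ScopeGuard(std::function<void()> f) : fn(std::move(f)) {}
  ScopeGuard(const ScopeGuard &) = delete;
  ScopeGuard &operator=(const ScopeGuard &) = delete;
  ~ScopeGuard() { fn(); }  // runs on every exit path, including exceptions
  std::function<void()> fn;
};

void UsePackedIndex(PyObject *_index) {
  // PyTuple_Pack returns a NEW reference when a non-tuple index is wrapped.
  PyObject *index =
      !PyTuple_Check(_index) ? PyTuple_Pack(1, _index) : _index;
  ScopeGuard release_packed([index, _index]() {
    if (!PyTuple_Check(_index)) Py_DECREF(index);  // drop only the tuple we created
  });
  // ... any early return or thrown exception from here on no longer leaks
  // `index`, which was the failure mode of a single trailing Py_DecRef call.
}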
@@ -811,11 +818,21 @@ void BindImperative(py::module *m_ptr) {
.def("__setitem__",
[](std::shared_ptr<imperative::VarBase> &self, py::handle _index,
py::object &value_obj) {
VLOG(4) << "Call __setitem__";

auto self_tensor =
self->MutableVar()->GetMutable<framework::LoDTensor>();
// NOTE(zhiqiu): PyTuple_Pack increases the refcount of its arguments, while
// PyTuple_New does not. See
// https://github.com/python/cpython/blob/24b63c695ae0a95b06379eaadace66735abac1e2/Objects/tupleobject.c#L251
PyObject *index_ptr = !PyTuple_Check(_index.ptr())
? PyTuple_Pack(1, _index.ptr())
: _index.ptr();
DEFINE_PADDLE_SCOPE_GUARD([index_ptr, &_index]() {
if (!PyTuple_Check(_index.ptr())) {
Py_DECREF(index_ptr);
VLOG(4) << "Call Py_DECREF";
}
});
// 1. Check arguments
// 1.1 Check whether value obj is a tensor.
bool value_is_tensor = true;
@@ -826,6 +843,18 @@
value_is_tensor = false;
}

auto is_tensor = [](py::handle var) {
if (!var.ptr() || var.ptr() == Py_None) {
return false;
}
try {
py::cast<std::shared_ptr<imperative::VarBase>>(var);
return true;
} catch (py::cast_error &) {
return false;
}
};

// 1.2 Check whether _index can be parsed.
const int size = PyTuple_GET_SIZE(index_ptr);
for (int dim = 0; dim < size; ++dim) {
@@ -842,6 +871,7 @@
// TODO(liym27): Try not to call TensorToPyArray because it always
// copies data to cpu place, which reduces performance.
if (parse_index && value_is_tensor) {
VLOG(4) << "index is integer/slice/ellipsis and value is tensor";
std::vector<int> axes, starts, ends, steps, decrease_axes,
none_axes, infer_flags, list_select_idxs;
// if index is a list, list_select_flag will be true
@@ -850,7 +880,6 @@
&steps, &decrease_axes, &none_axes,
&infer_flags, &list_select_idxs,
&list_select_flag);

framework::AttributeMap attrs = {
{"axes", axes},
{"starts", starts},
@@ -882,20 +911,43 @@
}
} else {
auto self_numpy = TensorToPyArray(*self_tensor);
VLOG(4) << "parse_index is false";

if (value_is_tensor) {
VLOG(4) << "value is tensor";
auto value =
value_obj.cast<std::shared_ptr<imperative::VarBase>>();
auto value_tensor =
value->MutableVar()->GetMutable<framework::LoDTensor>();
auto value_numpy = TensorToPyArray(*value_tensor);

self_numpy[_index] = value_numpy;
if (is_tensor(_index)) {
VLOG(4) << "index is tensor";
auto index_var =
py::cast<std::shared_ptr<imperative::VarBase>>(_index);
auto index_tensor = index_var->MutableVar()
->GetMutable<framework::LoDTensor>();
auto index_numpy = TensorToPyArray(*index_tensor);
self_numpy[index_numpy] = value_numpy;
} else {
VLOG(4) << "index is not tensor";
self_numpy[_index] = value_numpy;
}
SetTensorFromPyArray(self_tensor, self_numpy,
self_tensor->place(), true);
} else {
auto value_numpy = value_obj;
self_numpy[_index] = value_numpy;
VLOG(4) << "value is not tensor";
if (is_tensor(_index)) {
VLOG(4) << "index is tensor";
auto index_var =
py::cast<std::shared_ptr<imperative::VarBase>>(_index);
auto index_tensor = index_var->MutableVar()
->GetMutable<framework::LoDTensor>();
auto index_numpy = TensorToPyArray(*index_tensor);
self_numpy[index_numpy] = value_obj;
} else {
VLOG(4) << "index is not tensor";
self_numpy[_index] = value_obj;
}
SetTensorFromPyArray(self_tensor, self_numpy,
self_tensor->place(), true);
}
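The fallback path above relies on the is_tensor helper added earlier in this diff to decide whether the index itself must be converted to a numpy array before assignment. A rough standalone illustration of that detection idiom follows (hypothetical type and function names, not Paddle's code; in Paddle, imperative::VarBase is a registered pybind11 type, so the cast succeeds for real tensors).

// Sketch only: hypothetical names; FakeVarBase stands in for the registered
// imperative::VarBase type (here it is unregistered, so the cast always fails).
#include <pybind11/pybind11.h>
#include <memory>

namespace py = pybind11;

struct FakeVarBase {};

// Treat a failed cast as "not a tensor": plain ints, slices, lists and numpy
// arrays all fall through to ordinary numpy-style indexing.
bool IsBoundTensor(py::handle obj) {
  if (!obj.ptr() || obj.ptr() == Py_None) {
    return false;
  }
  try {
    py::cast<std::shared_ptr<FakeVarBase>>(obj);
    return true;
  } catch (const py::cast_error &) {
    return false;
  }
}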
@@ -907,6 +959,7 @@
})
.def("_getitem_index_not_tensor",
[](std::shared_ptr<imperative::VarBase> &self, py::handle _index) {
VLOG(4) << "Call _getitem_index_not_tensor";
std::vector<int> slice_axes, slice_starts, slice_ends,
slice_strides, decrease_axis, none_axes, infer_flags,
list_select_idxs;