@@ -29,6 +29,7 @@ limitations under the License. */
 #include <utility>
 #include <vector>
 
+#include "paddle/fluid/framework/scope_guard.h"
 #include "paddle/fluid/imperative/all_reduce.h"
 #include "paddle/fluid/imperative/amp_auto_cast.h"
 #include "paddle/fluid/imperative/basic_engine.h"
@@ -424,7 +425,15 @@ static void ParseIndexingSlice(
   // We allow indexing by Integers, Slices, Ellipsis, None, tuples of those
   // types, and list of Bool and Integers.
   // wrap to tuple
+
+  // NOTE(zhiqiu): PyTuple_Pack increases refcount.
   PyObject *index = !PyTuple_Check(_index) ? PyTuple_Pack(1, _index) : _index;
+  DEFINE_PADDLE_SCOPE_GUARD([index, _index]() {
+    if (!PyTuple_Check(_index)) {
+      Py_DECREF(index);
+      VLOG(4) << "Call Py_DECREF";
+    }
+  });
   PADDLE_ENFORCE_EQ(
       tensor->IsInitialized(), true,
       platform::errors::InvalidArgument("tensor has not been initialized"));
@@ -550,8 +559,6 @@ static void ParseIndexingSlice(
       platform::errors::InvalidArgument(
           "Too many indices (%d) for tensor of dimension %d.",
           valid_indexs, rank));
-
-  if (!PyTuple_Check(_index)) Py_DecRef(index);
 }
 
 template <typename P>
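Taken together, the two ParseIndexingSlice hunks above move the cleanup of the temporary tuple from a single manual Py_DecRef at the end of the function into a scope guard declared right after PyTuple_Pack, so the reference is released on every exit path, including the throwing PADDLE_ENFORCE_EQ checks. A minimal sketch of the RAII idea, assuming only that DEFINE_PADDLE_SCOPE_GUARD (from paddle/fluid/framework/scope_guard.h) wraps something equivalent to this hypothetical ScopeGuard:

#include <utility>

// Hypothetical stand-in for DEFINE_PADDLE_SCOPE_GUARD: run a callback
// when the guard leaves scope, whether by return or by exception.
template <typename Callback>
class ScopeGuard {
 public:
  explicit ScopeGuard(Callback cb) : cb_(std::move(cb)) {}
  ~ScopeGuard() { cb_(); }
  ScopeGuard(const ScopeGuard &) = delete;
  ScopeGuard &operator=(const ScopeGuard &) = delete;

 private:
  Callback cb_;
};

// Usage mirroring the patch: release the packed tuple on every path.
//   PyObject *index = !PyTuple_Check(_index) ? PyTuple_Pack(1, _index) : _index;
//   ScopeGuard guard([index, _index]() {
//     if (!PyTuple_Check(_index)) Py_DECREF(index);
//   });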
@@ -811,11 +818,21 @@ void BindImperative(py::module *m_ptr) {
       .def("__setitem__",
            [](std::shared_ptr<imperative::VarBase> &self, py::handle _index,
               py::object &value_obj) {
+             VLOG(4) << "Call __setitem__";
+
             auto self_tensor =
                 self->MutableVar()->GetMutable<framework::LoDTensor>();
+             // NOTE(zhiqiu): PyTuple_Pack increases refcount, PyTuple_New does not.
+             // https://github.com/python/cpython/blob/24b63c695ae0a95b06379eaadace66735abac1e2/Objects/tupleobject.c#L251
             PyObject *index_ptr = !PyTuple_Check(_index.ptr())
                                       ? PyTuple_Pack(1, _index.ptr())
                                       : _index.ptr();
+             DEFINE_PADDLE_SCOPE_GUARD([index_ptr, &_index]() {
+               if (!PyTuple_Check(_index.ptr())) {
+                 Py_DECREF(index_ptr);
+                 VLOG(4) << "Call Py_DECREF";
+               }
+             });
             // 1. Check arguments
             // 1.1 Check whether value obj is a tensor.
             bool value_is_tensor = true;
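The NOTE explains why the guard fires only when _index was not already a tuple: PyTuple_Pack returns a new tuple and increments the refcount of each packed argument, whereas an index that is already a tuple is used as a borrowed reference and must not be decremented. A self-contained illustration of this documented CPython behaviour (refcount_demo is an illustrative name, not part of the patch):

#include <Python.h>

// Demonstrates the refcount difference referenced in the NOTE above.
void refcount_demo(PyObject *item) {
  // PyTuple_Pack returns a NEW reference and INCREFs every argument,
  // so the caller owns the tuple and must Py_DECREF it when done.
  PyObject *packed = PyTuple_Pack(1, item);
  Py_DECREF(packed);  // drops the tuple and its extra reference to item

  // PyTuple_New leaves the slots empty; PyTuple_SET_ITEM then *steals*
  // a reference, so the caller INCREFs first to keep `item` alive.
  PyObject *built = PyTuple_New(1);
  Py_INCREF(item);
  PyTuple_SET_ITEM(built, 0, item);
  Py_DECREF(built);
}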
@@ -826,6 +843,18 @@ void BindImperative(py::module *m_ptr) {
               value_is_tensor = false;
             }
 
+             auto is_tensor = [](py::handle var) {
+               if (!var.ptr() || var.ptr() == Py_None) {
+                 return false;
+               }
+               try {
+                 py::cast<std::shared_ptr<imperative::VarBase>>(var);
+                 return true;
+               } catch (py::cast_error &) {
+                 return false;
+               }
+             };
+
             // 1.2 Check whether _index can be parsed.
             const int size = PyTuple_GET_SIZE(index_ptr);
             for (int dim = 0; dim < size; ++dim) {
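The is_tensor lambda above probes a handle by attempting a py::cast and treating py::cast_error as "not a tensor". pybind11's py::isinstance<T> answers the same question for any type registered with py::class_ without paying for a thrown exception; a self-contained sketch under that assumption, where VarBaseLike and probe_demo are illustrative stand-ins rather than Paddle names:

#include <pybind11/pybind11.h>

namespace py = pybind11;

struct VarBaseLike {};  // stand-in for a type registered via py::class_

// Returns true only for live, non-None handles holding the bound type.
bool is_bound_type(py::handle var) {
  return var.ptr() && !var.is_none() && py::isinstance<VarBaseLike>(var);
}

PYBIND11_MODULE(probe_demo, m) {
  py::class_<VarBaseLike>(m, "VarBaseLike").def(py::init<>());
  m.def("is_bound_type", &is_bound_type);
}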
@@ -842,6 +871,7 @@ void BindImperative(py::module *m_ptr) {
             // TODO(liym27): Try not to call TensorToPyArray because it always
             // copies data to cpu place, which reduces performance.
             if (parse_index && value_is_tensor) {
+               VLOG(4) << "index is integer/slice/ellipsis and value is tensor";
               std::vector<int> axes, starts, ends, steps, decrease_axes,
                   none_axes, infer_flags, list_select_idxs;
               // if index is a list, list_select_flag will be true
@@ -850,7 +880,6 @@ void BindImperative(py::module *m_ptr) {
                   &steps, &decrease_axes, &none_axes,
                   &infer_flags, &list_select_idxs,
                   &list_select_flag);
-
               framework::AttributeMap attrs = {
                   {"axes", axes},
                   {"starts", starts},
@@ -882,20 +911,43 @@ void BindImperative(py::module *m_ptr) {
               }
             } else {
               auto self_numpy = TensorToPyArray(*self_tensor);
+               VLOG(4) << "parse_index is false";
 
               if (value_is_tensor) {
+                 VLOG(4) << "value is tensor";
                 auto value =
                     value_obj.cast<std::shared_ptr<imperative::VarBase>>();
                 auto value_tensor =
                     value->MutableVar()->GetMutable<framework::LoDTensor>();
                 auto value_numpy = TensorToPyArray(*value_tensor);
-
-                 self_numpy[_index] = value_numpy;
+                 if (is_tensor(_index)) {
+                   VLOG(4) << "index is tensor";
+                   auto index_var =
+                       py::cast<std::shared_ptr<imperative::VarBase>>(_index);
+                   auto index_tensor = index_var->MutableVar()
+                                           ->GetMutable<framework::LoDTensor>();
+                   auto index_numpy = TensorToPyArray(*index_tensor);
+                   self_numpy[index_numpy] = value_numpy;
+                 } else {
+                   VLOG(4) << "index is not tensor";
+                   self_numpy[_index] = value_numpy;
+                 }
                 SetTensorFromPyArray(self_tensor, self_numpy,
                                      self_tensor->place(), true);
               } else {
-                 auto value_numpy = value_obj;
-                 self_numpy[_index] = value_numpy;
+                 VLOG(4) << "value is not tensor";
+                 if (is_tensor(_index)) {
+                   VLOG(4) << "index is tensor";
+                   auto index_var =
+                       py::cast<std::shared_ptr<imperative::VarBase>>(_index);
+                   auto index_tensor = index_var->MutableVar()
+                                           ->GetMutable<framework::LoDTensor>();
+                   auto index_numpy = TensorToPyArray(*index_tensor);
+                   self_numpy[index_numpy] = value_obj;
+                 } else {
+                   VLOG(4) << "index is not tensor";
+                   self_numpy[_index] = value_obj;
+                 }
                 SetTensorFromPyArray(self_tensor, self_numpy,
                                      self_tensor->place(), true);
               }
@@ -907,6 +959,7 @@ void BindImperative(py::module *m_ptr) {
           })
       .def("_getitem_index_not_tensor",
            [](std::shared_ptr<imperative::VarBase> &self, py::handle _index) {
+             VLOG(4) << "Call _getitem_index_not_tensor";
             std::vector<int> slice_axes, slice_starts, slice_ends,
                 slice_strides, decrease_axis, none_axes, infer_flags,
                 list_select_idxs;