@@ -39,6 +39,159 @@ namespace py = pybind11;
namespace paddle {
namespace pybind {
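+ // Computes the broadcast shape of two DDims, NumPy-style: dimensions are
+ // aligned from the trailing end and a size of 1 expands to the other size.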
+ static inline common::DDim infer_size_symdimvector(common::DDim a,
+                                                     common::DDim b) {
+   // Use ptrdiff_t to ensure signed comparison.
+   auto dimsA = a.size();
+   auto dimsB = b.size();
+   auto ndim = dimsA > dimsB ? dimsA : dimsB;
+   common::DDim expandedSizes = common::make_ddim(std::vector<int64_t>(ndim, 0));
+
+   for (int64_t i = ndim - 1; i >= 0; --i) {
+     int64_t offset = ndim - 1 - i;
+     int64_t dimA = dimsA - 1 - offset;
+     int64_t dimB = dimsB - 1 - offset;
+     auto sizeA = (dimA >= 0) ? a[dimA] : 1;
+     auto sizeB = (dimB >= 0) ? b[dimB] : 1;
+
+     PADDLE_ENFORCE_EQ(
+         sizeA == sizeB || sizeA == 1 || sizeB == 1,
+         true,
+         common::errors::Fatal("The size of tensor a (",
+                               sizeA,
+                               ") must match the size of tensor b (",
+                               sizeB,
+                               ") at non-singleton dimension ",
+                               i));
+
+     // 1s map to the other size (even 0).
+     expandedSizes[i] = sizeA == 1 ? sizeB : sizeA;
+   }
+
+   return expandedSizes;
+ }
+
+ static inline std::vector<paddle::Tensor> expand_outplace(
+     std::vector<paddle::Tensor> to_expand) {
+   // expands a list of Tensors; ignores undefined (null) tensors
+   bool first = true;
+   common::DDim sizes;
+   for (size_t i = 0; i < to_expand.size(); i++) {
+     if (!to_expand[i].initialized()) {
+       continue;
+     } else if (first) {
+       sizes = to_expand[i].dims();
+       first = false;
+     } else {
+       sizes = infer_size_symdimvector(sizes, to_expand[i].dims());
+     }
+   }
+
+   std::vector<paddle::Tensor> result(to_expand.size());
+   for (size_t i = 0; i < to_expand.size(); i++) {
+     if (!to_expand[i].initialized()) {
+       continue;
+     } else if (to_expand[i].dims() == sizes) {
+       result[i] = to_expand[i];
+     } else {
+       result[i] =
+           expand_ad_func(to_expand[i], common::vectorize<int64_t>(sizes));
+     }
+   }
+   return result;
+ }
+
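+ // Metadata for advanced (tensor) indexing: the source sizes/strides with the
+ // indexed dims replaced by the index shape (stride 0), the sizes and byte
+ // strides of the indexed dims, and the index tensors reshaped for broadcast.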
+ struct AdvancedIndex {
+   AdvancedIndex(paddle::Tensor src,
+                 std::vector<paddle::Tensor> indices,
+                 bool bool_case);
+
+   paddle::Tensor src;
+   std::vector<paddle::Tensor> indices;
+   std::vector<int64_t> indexed_sizes;
+   std::vector<int64_t> indexed_strides;
+   std::vector<int64_t> src_sizes;
+   std::vector<int64_t> src_strides;
+   int64_t dims_before;
+   int64_t dims_after;
+   bool bool_case;
+ };
+
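+ // Replaces the indexed span [dims_before, dims_before + dims_indexed) of
+ // shape/strides with the broadcast index shape, using stride 0 for the
+ // replacement dims so they do not advance through src memory.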
+ inline static void restride_src(std::vector<int64_t>* shape,
+                                 std::vector<int64_t>* strides,
+                                 int64_t dims_before,
+                                 int64_t dims_indexed,
+                                 std::vector<int64_t> replacement_shape) {
+   int64_t end = dims_before + dims_indexed;
+   shape->erase(shape->begin() + dims_before, shape->begin() + end);
+   strides->erase(strides->begin() + dims_before, strides->begin() + end);
+   shape->insert(shape->begin() + dims_before,
+                 replacement_shape.begin(),
+                 replacement_shape.end());
+   strides->insert(strides->begin() + dims_before, replacement_shape.size(), 0);
+ }
+
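+ // Pads the index with dims_before leading and dims_after trailing singleton
+ // dims so it lines up with the non-indexed dimensions of src.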
+ // move to cuda kernel
+ inline static paddle::Tensor reshape_indexer(paddle::Tensor* index,
+                                              int64_t dims_before,
+                                              int64_t dims_after) {
+   auto orig_shape = common::vectorize<int64_t>(index->dims());
+   auto shape = std::vector<int64_t>{};
+   shape.insert(shape.end(), dims_before, 1);
+   shape.insert(shape.end(), orig_shape.begin(), orig_shape.end());
+   shape.insert(shape.end(), dims_after, 1);
+   *index = reshape_ad_func(*index, shape);
+   return *index;
+ }
+
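+ // Source dims with no index tensor count toward dims_before until the first
+ // indexed dim is seen, then toward dims_after; in the bool case the trailing
+ // 1 left by the sliced nonzero() output is dropped from the replacement shape.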
+ inline AdvancedIndex::AdvancedIndex(paddle::Tensor src,
+                                     std::vector<paddle::Tensor> indices_list,
+                                     bool bool_case = false) {
+   uint32_t element_size_bytes = phi::SizeOf(src.dtype());
+   int64_t dims_before = 0, dims_after = 0, dims_indexed = 0;
+   std::vector<int64_t> shape_vec = common::vectorize<int64_t>(src.dims());
+   std::vector<int64_t> stride_vec = common::vectorize<int64_t>(src.strides());
+   std::vector<int64_t> replacement_shape;
+   std::vector<int64_t> idx_shape_vec = {};
+   std::vector<int64_t> idx_stride_vec = {};
+
+   for (size_t dim = 0; dim < indices_list.size(); dim++) {
+     if (!indices_list[dim].defined() || indices_list[dim].dims().size() == 0) {
+       if (dims_indexed == 0) {
+         dims_before++;
+       } else {
+         dims_after++;
+       }
+     } else {
+       dims_indexed++;
+       replacement_shape = common::vectorize<int64_t>(indices_list[dim].dims());
+       if (bool_case && !replacement_shape.empty() &&
+           replacement_shape.back() == 1) {
+         replacement_shape.pop_back();
+       }
+
+       idx_shape_vec.push_back(shape_vec[dim]);
+       idx_stride_vec.push_back(stride_vec[dim] * element_size_bytes);
+     }
+   }
+
+   this->dims_before = dims_before;
+   this->dims_after = dims_after;
+   restride_src(
+       &shape_vec, &stride_vec, dims_before, dims_indexed, replacement_shape);
+   this->src_sizes = shape_vec;
+   this->src_strides = stride_vec;
+
+   this->indexed_sizes = idx_shape_vec;
+   this->indexed_strides = idx_stride_vec;
+
+   // use dims_before and dims_after / move to cuda kernel
+   for (auto& index : indices_list) {
+     if (index.defined() && index.dims().size() > 0) {
+       this->indices.push_back(reshape_indexer(&index, dims_before, dims_after));
+     }
+   }
+ }

template <typename T>
inline T GetDenseTensorValue(const phi::DenseTensor* x) {
@@ -493,6 +646,35 @@ static paddle::Tensor dealWithAdvancedIndex(
  return transed_tensor;
}
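
+ // Splits the nonzero() output of a bool mask into one index tensor per masked
+ // dimension, padding each with trailing singleton dims to match the rank left
+ // after the masked dims are collapsed.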
+ static std::vector<paddle::Tensor> PrepareIndices(
+     const paddle::Tensor& tensor,
+     const paddle::Tensor& bool_2_idx,
+     const paddle::Tensor& bool_index) {
+   std::vector<paddle::Tensor> indices;
+   for (int j = 0; j < bool_2_idx.shape()[1]; ++j) {
+     paddle::Tensor sliced_tensor =
+         slice_ad_func(bool_2_idx, {1}, {j}, {j + 1}, {1}, {});
+
+     // Calculate the required dimensionality
+     int64_t original_ndim =
+         tensor.shape().size() - bool_index.shape().size() + 1;
+     int64_t sliced_ndim = sliced_tensor.shape().size();
+     int64_t num_ones_to_add = original_ndim - sliced_ndim;
+
+     // Reshape the tensor by adding 1s if needed
+     if (num_ones_to_add > 0) {
+       std::vector<int64_t> new_shape = sliced_tensor.shape();
+       for (int64_t k = 0; k < num_ones_to_add; ++k) {
+         new_shape.push_back(1);
+       }
+       sliced_tensor = reshape_ad_func(sliced_tensor, new_shape);
+     }
+
+     indices.emplace_back(sliced_tensor);
+   }
+   return indices;
+ }
+
static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
                                            const paddle::Tensor& bool_index) {
  PADDLE_ENFORCE(bool_index.shape().size() <= tensor.shape().size(),
@@ -532,7 +714,20 @@ static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
}

  auto bool_2_idx = nonzero_ad_func(bool_index);
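+ // On CUDA, gather through the strided index_elementwise kernel using the
+ // AdvancedIndex metadata; other backends fall back to gather_nd on the
+ // nonzero() coordinates.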
+ #ifdef PADDLE_WITH_CUDA
+   auto indices = PrepareIndices(tensor, bool_2_idx, bool_index);
+   AdvancedIndex ad = AdvancedIndex(tensor, indices, true);
+
+   return index_elementwise_get_ad_func(tensor,
+                                        ad.indices,
+                                        ad.src_sizes,
+                                        ad.src_strides,
+                                        ad.indexed_sizes,
+                                        ad.indexed_strides);
+ #else
+
  return gather_nd_ad_func(tensor, bool_2_idx);
+ #endif
}

static void ParseBoolAndBroadcastIndices(
@@ -672,150 +867,5 @@ static paddle::Tensor dealWithValues(const paddle::Tensor& tensor,
  return value_tensor;
}

- static inline common::DDim infer_size_symdimvector(common::DDim a,
-                                                     common::DDim b) {
-   // Use ptrdiff_t to ensure signed comparison.
-   auto dimsA = a.size();
-   auto dimsB = b.size();
-   auto ndim = dimsA > dimsB ? dimsA : dimsB;
-   common::DDim expandedSizes = common::make_ddim(std::vector<int64_t>(ndim, 0));
-
-   for (int64_t i = ndim - 1; i >= 0; --i) {
-     int64_t offset = ndim - 1 - i;
-     int64_t dimA = dimsA - 1 - offset;
-     int64_t dimB = dimsB - 1 - offset;
-     auto sizeA = (dimA >= 0) ? a[dimA] : 1;
-     auto sizeB = (dimB >= 0) ? b[dimB] : 1;
-
-     PADDLE_ENFORCE_EQ(
-         sizeA == sizeB || sizeA == 1 || sizeB == 1,
-         true,
-         common::errors::Fatal("The size of tensor a (",
-                               sizeA,
-                               ") must match the size of tensor b (",
-                               sizeB,
-                               ") at non-singleton dimension ",
-                               i));
-
-     // 1s map to the other size (even 0).
-     expandedSizes[i] = sizeA == 1 ? sizeB : sizeA;
-   }
-
-   return expandedSizes;
- }
-
- static inline std::vector<paddle::Tensor> expand_outplace(
-     std::vector<paddle::Tensor> to_expand) {
-   // expands a list of Tensors; ignores undefined (null) tensors
-   bool first = true;
-   common::DDim sizes;
-   for (size_t i = 0; i < to_expand.size(); i++) {
-     if (!to_expand[i].initialized()) {
-       continue;
-     } else if (first) {
-       sizes = to_expand[i].dims();
-       first = false;
-     } else {
-       sizes = infer_size_symdimvector(sizes, to_expand[i].dims());
-     }
-   }
-
-   std::vector<paddle::Tensor> result(to_expand.size());
-   for (size_t i = 0; i < to_expand.size(); i++) {
-     if (!to_expand[i].initialized()) {
-       continue;
-     } else if (to_expand[i].dims() == sizes) {
-       result[i] = to_expand[i];
-     } else {
-       result[i] =
-           expand_ad_func(to_expand[i], common::vectorize<int64_t>(sizes));
-     }
-   }
-   return result;
- }
-
- struct AdvancedIndex {
-   AdvancedIndex(paddle::Tensor src, std::vector<paddle::Tensor> indices);
-
-   paddle::Tensor src;
-   std::vector<paddle::Tensor> indices;
-   std::vector<int64_t> indexed_sizes;
-   std::vector<int64_t> indexed_strides;
-   std::vector<int64_t> src_sizes;
-   std::vector<int64_t> src_strides;
-   int64_t dims_before;
-   int64_t dims_after;
- };
-
- inline static void restride_src(std::vector<int64_t>* shape,
-                                 std::vector<int64_t>* strides,
-                                 int64_t dims_before,
-                                 int64_t dims_indexed,
-                                 std::vector<int64_t> replacement_shape) {
-   int64_t end = dims_before + dims_indexed;
-   shape->erase(shape->begin() + dims_before, shape->begin() + end);
-   strides->erase(strides->begin() + dims_before, strides->begin() + end);
-   shape->insert(shape->begin() + dims_before,
-                 replacement_shape.begin(),
-                 replacement_shape.end());
-   strides->insert(strides->begin() + dims_before, replacement_shape.size(), 0);
- }
-
- // move to cuda kernel
- inline static paddle::Tensor reshape_indexer(paddle::Tensor* index,
-                                              int64_t dims_before,
-                                              int64_t dims_after) {
-   auto orig_shape = common::vectorize<int64_t>(index->dims());
-   auto shape = std::vector<int64_t>{};
-   shape.insert(shape.end(), dims_before, 1);
-   shape.insert(shape.end(), orig_shape.begin(), orig_shape.end());
-   shape.insert(shape.end(), dims_after, 1);
-   *index = reshape_ad_func(*index, shape);
-   return *index;
- }
-
- inline AdvancedIndex::AdvancedIndex(paddle::Tensor src,
-                                     std::vector<paddle::Tensor> indices_list) {
-   uint32_t element_size_bytes = phi::SizeOf(src.dtype());
-   int64_t dims_before = 0, dims_after = 0, dims_indexed = 0;
-   std::vector<int64_t> shape_vec = common::vectorize<int64_t>(src.dims());
-   std::vector<int64_t> stride_vec = common::vectorize<int64_t>(src.strides());
-   std::vector<int64_t> replacement_shape;
-   std::vector<int64_t> idx_shape_vec = {};
-   std::vector<int64_t> idx_stride_vec = {};
-
-   for (size_t dim = 0; dim < indices_list.size(); dim++) {
-     if (!indices_list[dim].defined() || indices_list[dim].dims().size() == 0) {
-       if (dims_indexed == 0) {
-         dims_before++;
-       } else {
-         dims_after++;
-       }
-     } else {
-       dims_indexed++;
-       replacement_shape = common::vectorize<int64_t>(indices_list[dim].dims());
-       idx_shape_vec.push_back(shape_vec[dim]);
-       idx_stride_vec.push_back(stride_vec[dim] * element_size_bytes);
-     }
-   }
-
-   this->dims_before = dims_before;
-   this->dims_after = dims_after;
-   restride_src(
-       &shape_vec, &stride_vec, dims_before, dims_indexed, replacement_shape);
-   this->src_sizes = shape_vec;
-   this->src_strides = stride_vec;
-
-   this->indexed_sizes = idx_shape_vec;
-   this->indexed_strides = idx_stride_vec;
-
-   // use dims_before and dims_after / move to cuda kernel
-   for (auto& index : indices_list) {
-     if (index.defined() && index.dims().size() > 0) {
-       this->indices.push_back(reshape_indexer(&index, dims_before, dims_after));
-     }
-   }
- }
-
} // namespace pybind
} // namespace paddle