@@ -39,6 +39,152 @@ namespace py = pybind11;
namespace paddle {
namespace pybind {
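+ // Returns the broadcast of shapes `a` and `b` (NumPy-style rules), raising a
+ // fatal error when two non-singleton dimensions disagree.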
+ static inline common::DDim infer_size_symdimvector(common::DDim a,
+                                                     common::DDim b) {
+   // Use ptrdiff_t to ensure signed comparison.
+   auto dimsA = a.size();
+   auto dimsB = b.size();
+   auto ndim = dimsA > dimsB ? dimsA : dimsB;
+   common::DDim expandedSizes = common::make_ddim(std::vector<int64_t>(ndim, 0));
+
+   for (int64_t i = ndim - 1; i >= 0; --i) {
+     int64_t offset = ndim - 1 - i;
+     int64_t dimA = dimsA - 1 - offset;
+     int64_t dimB = dimsB - 1 - offset;
+     auto sizeA = (dimA >= 0) ? a[dimA] : 1;
+     auto sizeB = (dimB >= 0) ? b[dimB] : 1;
+
+     PADDLE_ENFORCE(sizeA == sizeB || sizeA == 1 || sizeB == 1,
+                    common::errors::Fatal("The size of tensor a (",
+                                          sizeA,
+                                          ") must match the size of tensor b (",
+                                          sizeB,
+                                          ") at non-singleton dimension ",
+                                          i));
+
+     // 1s map to the other size (even 0).
+     expandedSizes[i] = sizeA == 1 ? sizeB : sizeA;
+   }
+
+   return expandedSizes;
+ }
+
+ static inline std::vector<paddle::Tensor> expand_outplace(
+     std::vector<paddle::Tensor> to_expand) {
+   // expands a list of Tensors; ignores undefined (null) tensors
+   bool first = true;
+   common::DDim sizes;
+   for (size_t i = 0; i < to_expand.size(); i++) {
+     if (!to_expand[i].initialized()) {
+       continue;
+     } else if (first) {
+       sizes = to_expand[i].dims();
+       first = false;
+     } else {
+       sizes = infer_size_symdimvector(sizes, to_expand[i].dims());
+     }
+   }
+
+   std::vector<paddle::Tensor> result(to_expand.size());
+   for (size_t i = 0; i < to_expand.size(); i++) {
+     if (!to_expand[i].initialized()) {
+       continue;
+     } else if (to_expand[i].dims() == sizes) {
+       result[i] = to_expand[i];
+     } else {
+       result[i] =
+           expand_ad_func(to_expand[i], common::vectorize<int64_t>(sizes));
+     }
+   }
+   return result;
+ }
+
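+ // Precomputed metadata for advanced (tensor) indexing: the restrided source
+ // sizes/strides plus the broadcast-ready index tensors and their byte strides.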
+ struct AdvancedIndex {
+   AdvancedIndex(paddle::Tensor src, std::vector<paddle::Tensor> indices);
+
+   paddle::Tensor src;
+   std::vector<paddle::Tensor> indices;
+   std::vector<int64_t> indexed_sizes;
+   std::vector<int64_t> indexed_strides;
+   std::vector<int64_t> src_sizes;
+   std::vector<int64_t> src_strides;
+   int64_t dims_before;
+   int64_t dims_after;
+ };
+
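+ // Replaces the indexed dimensions of shape/strides with replacement_shape;
+ // the inserted dimensions get stride 0.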
+ inline static void restride_src(std::vector<int64_t>* shape,
+                                 std::vector<int64_t>* strides,
+                                 int64_t dims_before,
+                                 int64_t dims_indexed,
+                                 std::vector<int64_t> replacement_shape) {
+   int64_t end = dims_before + dims_indexed;
+   shape->erase(shape->begin() + dims_before, shape->begin() + end);
+   strides->erase(strides->begin() + dims_before, strides->begin() + end);
+   shape->insert(shape->begin() + dims_before,
+                 replacement_shape.begin(),
+                 replacement_shape.end());
+   strides->insert(strides->begin() + dims_before, replacement_shape.size(), 0);
+ }
+
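+ // Reshapes an index tensor to [1] * dims_before + index.shape + [1] * dims_after
+ // so it broadcasts against the restrided source.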
+ // move to cuda kernel
+ inline static paddle::Tensor reshape_indexer(paddle::Tensor* index,
+                                              int64_t dims_before,
+                                              int64_t dims_after) {
+   auto orig_shape = common::vectorize<int64_t>(index->dims());
+   auto shape = std::vector<int64_t>{};
+   shape.insert(shape.end(), dims_before, 1);
+   shape.insert(shape.end(), orig_shape.begin(), orig_shape.end());
+   shape.insert(shape.end(), dims_after, 1);
+   *index = reshape_ad_func(*index, shape);
+   return *index;
+ }
+
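+ // Splits src dims into (before, indexed, after), records the indexed dims'
+ // sizes and byte strides, then restrides src and reshapes each index so all
+ // pieces broadcast together.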
+ inline AdvancedIndex::AdvancedIndex(paddle::Tensor src,
+                                     std::vector<paddle::Tensor> indices_list) {
+   uint32_t element_size_bytes = phi::SizeOf(src.dtype());
+   int64_t dims_before = 0, dims_after = 0, dims_indexed = 0;
+   std::vector<int64_t> shape_vec = common::vectorize<int64_t>(src.dims());
+   std::vector<int64_t> stride_vec = common::vectorize<int64_t>(src.strides());
+   std::vector<int64_t> replacement_shape;
+   std::vector<int64_t> idx_shape_vec = {};
+   std::vector<int64_t> idx_stride_vec = {};
+
+   for (size_t dim = 0; dim < indices_list.size(); dim++) {
+     if (!indices_list[dim].defined() || indices_list[dim].dims().size() == 0) {
+       if (dims_indexed == 0) {
+         dims_before++;
+       } else {
+         dims_after++;
+       }
+     } else {
+       dims_indexed++;
+       replacement_shape = common::vectorize<int64_t>(indices_list[dim].dims());
+       if (!replacement_shape.empty() && replacement_shape.back() == 1) {
+         replacement_shape.pop_back();
+       }
+
+       idx_shape_vec.push_back(shape_vec[dim]);
+       idx_stride_vec.push_back(stride_vec[dim] * element_size_bytes);
+     }
+   }
+
+   this->dims_before = dims_before;
+   this->dims_after = dims_after;
+   restride_src(
+       &shape_vec, &stride_vec, dims_before, dims_indexed, replacement_shape);
+   this->src_sizes = shape_vec;
+   this->src_strides = stride_vec;
+
+   this->indexed_sizes = idx_shape_vec;
+   this->indexed_strides = idx_stride_vec;
+
+   // use dims_before and dims_after / move to cuda kernel
+   for (auto& index : indices_list) {
+     if (index.defined() && index.dims().size() > 0) {
+       this->indices.push_back(reshape_indexer(&index, dims_before, dims_after));
+     }
+   }
+ }

template <typename T>
inline T GetDenseTensorValue(const phi::DenseTensor* x) {
@@ -493,18 +639,33 @@ static paddle::Tensor dealWithAdvancedIndex(
  return transed_tensor;
}

- inline std::vector<int64_t> ComputeIndexStrides(const paddle::Tensor& input,
-                                                 const size_t index_dims_size) {
-   const auto& input_strides = input.strides();
-   size_t element_size_bytes = phi::SizeOf(input.dtype());
-   std::vector<int64_t> strides(index_dims_size, 0);
-   const size_t min_size =
-       std::min(static_cast<size_t>(input_strides.size()), index_dims_size);
-   for (size_t i = 0; i < min_size; ++i) {
-     strides[i] = input_strides[i] * element_size_bytes;
-   }
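+ // Splits the nonzero() result into one column per indexed dimension and pads
+ // each column with trailing size-1 dims so it can broadcast later.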
+ static std::vector<paddle::Tensor> PrepareIndices(
+     const paddle::Tensor& tensor,
+     const paddle::Tensor& bool_2_idx,
+     const paddle::Tensor& bool_index) {
+   std::vector<paddle::Tensor> indices;
+   for (int j = 0; j < bool_2_idx.shape()[1]; ++j) {
+     paddle::Tensor sliced_tensor =
+         slice_ad_func(bool_2_idx, {1}, {j}, {j + 1}, {1}, {});

-   return strides;
+     // Calculate the required dimensionality
+     int64_t original_ndim =
+         tensor.shape().size() - bool_index.shape().size() + 1;
+     int64_t sliced_ndim = sliced_tensor.shape().size();
+     int64_t num_ones_to_add = original_ndim - sliced_ndim;
+
+     // Reshape the tensor by adding 1s if needed
+     if (num_ones_to_add > 0) {
+       std::vector<int64_t> new_shape = sliced_tensor.shape();
+       for (int64_t k = 0; k < num_ones_to_add; ++k) {
+         new_shape.push_back(1);
+       }
+       sliced_tensor = reshape_ad_func(sliced_tensor, new_shape);
+     }
+
+     indices.emplace_back(sliced_tensor);
+   }
+   return indices;
}

static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
@@ -547,17 +708,15 @@ static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,

  auto bool_2_idx = nonzero_ad_func(bool_index);
#ifdef PADDLE_WITH_CUDA
-   std::vector<paddle::Tensor> indices;
-   for (int j = 0; j < bool_2_idx.shape()[1]; ++j) {
-     paddle::Tensor sliced_tensor =
-         slice_ad_func(bool_2_idx, {1}, {j}, {j + 1}, {1}, {});
-     indices.emplace_back(sliced_tensor);
-   }
-   auto index_dims_vec = common::vectorize<int64_t>(bool_index.dims());
-   auto index_stride = ComputeIndexStrides(tensor, index_dims_vec.size());
-
-   return index_elementwise_ad_func(
-       tensor, indices, index_dims_vec, index_stride);
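+   // Build broadcast-ready indices and strided-view metadata, then gather
+   // through the strided CUDA path.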
+   auto indices = PrepareIndices(tensor, bool_2_idx, bool_index);
+   AdvancedIndex ad = AdvancedIndex(tensor, indices);
+
+   return index_elementwise_get_ad_func(tensor,
+                                        ad.indices,
+                                        ad.src_sizes,
+                                        ad.src_strides,
+                                        ad.indexed_sizes,
+                                        ad.indexed_strides);
#else

  return gather_nd_ad_func(tensor, bool_2_idx);