
Commit 4b03294

[PHI] implement index_elementwise op (#72942)
* [PHI] implement index_elementwise op
* update ComputeStrides
* update
* Add stride mechanism and unit tests
* add index_elementwise_get_grad kernel
* fix AdvancedIndex bug
1 parent 5172626 commit 4b03294

14 files changed (+1022 / -191 lines)

paddle/fluid/pybind/slice_utils.h

Lines changed: 195 additions & 145 deletions
@@ -39,6 +39,159 @@ namespace py = pybind11;
 
 namespace paddle {
 namespace pybind {
+static inline common::DDim infer_size_symdimvector(common::DDim a,
+                                                   common::DDim b) {
+  // Use ptrdiff_t to ensure signed comparison.
+  auto dimsA = a.size();
+  auto dimsB = b.size();
+  auto ndim = dimsA > dimsB ? dimsA : dimsB;
+  common::DDim expandedSizes = common::make_ddim(std::vector<int64_t>(ndim, 0));
+
+  for (int64_t i = ndim - 1; i >= 0; --i) {
+    int64_t offset = ndim - 1 - i;
+    int64_t dimA = dimsA - 1 - offset;
+    int64_t dimB = dimsB - 1 - offset;
+    auto sizeA = (dimA >= 0) ? a[dimA] : 1;
+    auto sizeB = (dimB >= 0) ? b[dimB] : 1;
+
+    PADDLE_ENFORCE_EQ(
+        sizeA == sizeB || sizeA == 1 || sizeB == 1,
+        true,
+        common::errors::Fatal("The size of tensor a (",
+                              sizeA,
+                              ") must match the size of tensor b (",
+                              sizeB,
+                              ") at non-singleton dimension ",
+                              i));
+
+    // 1s map to the other size (even 0).
+    expandedSizes[i] = sizeA == 1 ? sizeB : sizeA;
+  }
+
+  return expandedSizes;
+}
+
+static inline std::vector<paddle::Tensor> expand_outplace(
+    std::vector<paddle::Tensor> to_expand) {
+  // expands a list of Tensors; ignores undefined (null) tensors
+  bool first = true;
+  common::DDim sizes;
+  for (size_t i = 0; i < to_expand.size(); i++) {
+    if (!to_expand[i].initialized()) {
+      continue;
+    } else if (first) {
+      sizes = to_expand[i].dims();
+      first = false;
+    } else {
+      sizes = infer_size_symdimvector(sizes, to_expand[i].dims());
+    }
+  }
+
+  std::vector<paddle::Tensor> result(to_expand.size());
+  for (size_t i = 0; i < to_expand.size(); i++) {
+    if (!to_expand[i].initialized()) {
+      continue;
+    } else if (to_expand[i].dims() == sizes) {
+      result[i] = to_expand[i];
+    } else {
+      result[i] =
+          expand_ad_func(to_expand[i], common::vectorize<int64_t>(sizes));
+    }
+  }
+  return result;
+}
+
+struct AdvancedIndex {
+  AdvancedIndex(paddle::Tensor src,
+                std::vector<paddle::Tensor> indices,
+                bool bool_case);
+
+  paddle::Tensor src;
+  std::vector<paddle::Tensor> indices;
+  std::vector<int64_t> indexed_sizes;
+  std::vector<int64_t> indexed_strides;
+  std::vector<int64_t> src_sizes;
+  std::vector<int64_t> src_strides;
+  int64_t dims_before;
+  int64_t dims_after;
+  bool bool_case;
+};
+
+inline static void restride_src(std::vector<int64_t>* shape,
+                                std::vector<int64_t>* strides,
+                                int64_t dims_before,
+                                int64_t dims_indexed,
+                                std::vector<int64_t> replacement_shape) {
+  int64_t end = dims_before + dims_indexed;
+  shape->erase(shape->begin() + dims_before, shape->begin() + end);
+  strides->erase(strides->begin() + dims_before, strides->begin() + end);
+  shape->insert(shape->begin() + dims_before,
+                replacement_shape.begin(),
+                replacement_shape.end());
+  strides->insert(strides->begin() + dims_before, replacement_shape.size(), 0);
+}
+
+// move to cuda kernel
+inline static paddle::Tensor reshape_indexer(paddle::Tensor* index,
+                                             int64_t dims_before,
+                                             int64_t dims_after) {
+  auto orig_shape = common::vectorize<int64_t>(index->dims());
+  auto shape = std::vector<int64_t>{};
+  shape.insert(shape.end(), dims_before, 1);
+  shape.insert(shape.end(), orig_shape.begin(), orig_shape.end());
+  shape.insert(shape.end(), dims_after, 1);
+  *index = reshape_ad_func(*index, shape);
+  return *index;
+}
+
+inline AdvancedIndex::AdvancedIndex(paddle::Tensor src,
+                                    std::vector<paddle::Tensor> indices_list,
+                                    bool bool_case = false) {
+  uint32_t element_size_bytes = phi::SizeOf(src.dtype());
+  int64_t dims_before = 0, dims_after = 0, dims_indexed = 0;
+  std::vector<int64_t> shape_vec = common::vectorize<int64_t>(src.dims());
+  std::vector<int64_t> stride_vec = common::vectorize<int64_t>(src.strides());
+  std::vector<int64_t> replacement_shape;
+  std::vector<int64_t> idx_shape_vec = {};
+  std::vector<int64_t> idx_stride_vec = {};
+
+  for (size_t dim = 0; dim < indices_list.size(); dim++) {
+    if (!indices_list[dim].defined() || indices_list[dim].dims().size() == 0) {
+      if (dims_indexed == 0) {
+        dims_before++;
+      } else {
+        dims_after++;
+      }
+    } else {
+      dims_indexed++;
+      replacement_shape = common::vectorize<int64_t>(indices_list[dim].dims());
+      if (bool_case && !replacement_shape.empty() &&
+          replacement_shape.back() == 1) {
+        replacement_shape.pop_back();
+      }
+
+      idx_shape_vec.push_back(shape_vec[dim]);
+      idx_stride_vec.push_back(stride_vec[dim] * element_size_bytes);
+    }
+  }
+
+  this->dims_before = dims_before;
+  this->dims_after = dims_after;
+  restride_src(
+      &shape_vec, &stride_vec, dims_before, dims_indexed, replacement_shape);
+  this->src_sizes = shape_vec;
+  this->src_strides = stride_vec;
+
+  this->indexed_sizes = idx_shape_vec;
+  this->indexed_strides = idx_stride_vec;
+
+  // use dims_before and dims_after / move to cuda kernel
+  for (auto& index : indices_list) {
+    if (index.defined() && index.dims().size() > 0) {
+      this->indices.push_back(reshape_indexer(&index, dims_before, dims_after));
+    }
+  }
+}
 
 template <typename T>
 inline T GetDenseTensorValue(const phi::DenseTensor* x) {
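
The two shape rules in the hunk above can be hard to read straight out of the diff, so here is a minimal standalone sketch of them (not part of the commit): it re-implements the right-aligned broadcast rule of infer_size_symdimvector and the dimension-replacement step of restride_src over plain std::vector<int64_t>. The names broadcast_shape and restride, and the concrete shapes in main, are illustrative assumptions, not Paddle APIs.

// Standalone sketch of the broadcast and restride rules used by AdvancedIndex.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Right-aligned broadcasting: each pair of trailing dims must match or one of
// them must be 1; a 1 maps to the other size.
std::vector<int64_t> broadcast_shape(const std::vector<int64_t>& a,
                                     const std::vector<int64_t>& b) {
  const int64_t ndim = std::max(a.size(), b.size());
  std::vector<int64_t> out(ndim, 0);
  for (int64_t i = ndim - 1; i >= 0; --i) {
    int64_t offset = ndim - 1 - i;
    int64_t dimA = static_cast<int64_t>(a.size()) - 1 - offset;
    int64_t dimB = static_cast<int64_t>(b.size()) - 1 - offset;
    int64_t sizeA = (dimA >= 0) ? a[dimA] : 1;
    int64_t sizeB = (dimB >= 0) ? b[dimB] : 1;
    assert(sizeA == sizeB || sizeA == 1 || sizeB == 1);
    out[i] = (sizeA == 1) ? sizeB : sizeA;
  }
  return out;
}

// Replace the indexed dims [dims_before, dims_before + dims_indexed) of the
// source shape with the (broadcast) index shape, and give those new dims
// stride 0 so iterating over them does not advance the source pointer.
void restride(std::vector<int64_t>* shape,
              std::vector<int64_t>* strides,
              int64_t dims_before,
              int64_t dims_indexed,
              const std::vector<int64_t>& replacement_shape) {
  int64_t end = dims_before + dims_indexed;
  shape->erase(shape->begin() + dims_before, shape->begin() + end);
  strides->erase(strides->begin() + dims_before, strides->begin() + end);
  shape->insert(shape->begin() + dims_before,
                replacement_shape.begin(), replacement_shape.end());
  strides->insert(strides->begin() + dims_before, replacement_shape.size(), 0);
}

int main() {
  // Broadcasting: [3, 1] and [4] -> [3, 4].
  auto s = broadcast_shape({3, 1}, {4});
  std::cout << s[0] << " " << s[1] << "\n";  // prints: 3 4

  // Restriding a contiguous [2, 3, 4] source whose middle dim is indexed by a
  // [5]-shaped index: shape becomes [2, 5, 4], the new dim gets stride 0.
  std::vector<int64_t> shape = {2, 3, 4};
  std::vector<int64_t> strides = {12, 4, 1};
  restride(&shape, &strides, /*dims_before=*/1, /*dims_indexed=*/1, {5});
  for (auto v : shape) std::cout << v << " ";    // prints: 2 5 4
  std::cout << "\n";
  for (auto v : strides) std::cout << v << " ";  // prints: 12 0 1
  std::cout << "\n";
  return 0;
}

The stride-0 replacement dims mean that walking the indexed output dims never moves the source pointer on its own; presumably the index_elementwise kernels supply the source offset from the index tensors via indexed_strides instead.
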
@@ -493,6 +646,35 @@ static paddle::Tensor dealWithAdvancedIndex(
   return transed_tensor;
 }
 
+static std::vector<paddle::Tensor> PrepareIndices(
+    const paddle::Tensor& tensor,
+    const paddle::Tensor& bool_2_idx,
+    const paddle::Tensor& bool_index) {
+  std::vector<paddle::Tensor> indices;
+  for (int j = 0; j < bool_2_idx.shape()[1]; ++j) {
+    paddle::Tensor sliced_tensor =
+        slice_ad_func(bool_2_idx, {1}, {j}, {j + 1}, {1}, {});
+
+    // Calculate the required dimensionality
+    int64_t original_ndim =
+        tensor.shape().size() - bool_index.shape().size() + 1;
+    int64_t sliced_ndim = sliced_tensor.shape().size();
+    int64_t num_ones_to_add = original_ndim - sliced_ndim;
+
+    // Reshape the tensor by adding 1s if needed
+    if (num_ones_to_add > 0) {
+      std::vector<int64_t> new_shape = sliced_tensor.shape();
+      for (int64_t k = 0; k < num_ones_to_add; ++k) {
+        new_shape.push_back(1);
+      }
+      sliced_tensor = reshape_ad_func(sliced_tensor, new_shape);
+    }
+
+    indices.emplace_back(sliced_tensor);
+  }
+  return indices;
+}
+
 static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
                                             const paddle::Tensor& bool_index) {
   PADDLE_ENFORCE(bool_index.shape().size() <= tensor.shape().size(),
@@ -532,7 +714,20 @@ static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
   }
 
   auto bool_2_idx = nonzero_ad_func(bool_index);
+#ifdef PADDLE_WITH_CUDA
+  auto indices = PrepareIndices(tensor, bool_2_idx, bool_index);
+  AdvancedIndex ad = AdvancedIndex(tensor, indices, true);
+
+  return index_elementwise_get_ad_func(tensor,
+                                       ad.indices,
+                                       ad.src_sizes,
+                                       ad.src_strides,
+                                       ad.indexed_sizes,
+                                       ad.indexed_strides);
+#else
+
   return gather_nd_ad_func(tensor, bool_2_idx);
+#endif
 }
 
 static void ParseBoolAndBroadcastIndices(
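
For the CUDA path above, a brief standalone sketch (again illustrative, not the commit's code) of the shape bookkeeping in PrepareIndices: nonzero_ad_func turns a rank-R boolean mask over the leading dims of a rank-N tensor into a [K, R] coordinate matrix, and each of its R columns (which the slice appears to keep as shape [K, 1]) is padded with N - R trailing ones so it can broadcast against the tensor's unmasked dims. The helper name column_index_shape and the concrete sizes are assumptions for illustration.

// Standalone sketch of the per-column index shape produced by PrepareIndices.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> column_index_shape(int64_t num_true,      // K
                                        int64_t tensor_ndim,   // N
                                        int64_t mask_ndim) {   // R
  // slice_ad_func(bool_2_idx, {1}, {j}, {j+1}, {1}, {}) keeps shape [K, 1].
  std::vector<int64_t> shape = {num_true, 1};
  int64_t target_ndim = tensor_ndim - mask_ndim + 1;
  while (static_cast<int64_t>(shape.size()) < target_ndim) {
    shape.push_back(1);  // pad trailing 1s, as PrepareIndices does
  }
  return shape;
}

int main() {
  // Rank-4 tensor indexed by a rank-2 mask over its first two dims, with
  // K = 6 true elements: each coordinate column is reshaped to [6, 1, 1].
  for (auto v : column_index_shape(/*num_true=*/6, /*tensor_ndim=*/4,
                                   /*mask_ndim=*/2)) {
    std::cout << v << " ";
  }
  std::cout << "\n";
  return 0;
}

The deletion hunk that follows removes the previous definitions of infer_size_symdimvector, expand_outplace, AdvancedIndex, restride_src, and reshape_indexer from further down in the file; the commit moves them above getValueForBoolTensor, adding the bool_case handling, so the new CUDA path can use them.
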
@@ -672,150 +867,5 @@ static paddle::Tensor dealWithValues(const paddle::Tensor& tensor,
   return value_tensor;
 }
 
-static inline common::DDim infer_size_symdimvector(common::DDim a,
-                                                   common::DDim b) {
-  // Use ptrdiff_t to ensure signed comparison.
-  auto dimsA = a.size();
-  auto dimsB = b.size();
-  auto ndim = dimsA > dimsB ? dimsA : dimsB;
-  common::DDim expandedSizes = common::make_ddim(std::vector<int64_t>(ndim, 0));
-
-  for (int64_t i = ndim - 1; i >= 0; --i) {
-    int64_t offset = ndim - 1 - i;
-    int64_t dimA = dimsA - 1 - offset;
-    int64_t dimB = dimsB - 1 - offset;
-    auto sizeA = (dimA >= 0) ? a[dimA] : 1;
-    auto sizeB = (dimB >= 0) ? b[dimB] : 1;
-
-    PADDLE_ENFORCE_EQ(
-        sizeA == sizeB || sizeA == 1 || sizeB == 1,
-        true,
-        common::errors::Fatal("The size of tensor a (",
-                              sizeA,
-                              ") must match the size of tensor b (",
-                              sizeB,
-                              ") at non-singleton dimension ",
-                              i));
-
-    // 1s map to the other size (even 0).
-    expandedSizes[i] = sizeA == 1 ? sizeB : sizeA;
-  }
-
-  return expandedSizes;
-}
-
-static inline std::vector<paddle::Tensor> expand_outplace(
-    std::vector<paddle::Tensor> to_expand) {
-  // expands a list of Tensors; ignores undefined (null) tensors
-  bool first = true;
-  common::DDim sizes;
-  for (size_t i = 0; i < to_expand.size(); i++) {
-    if (!to_expand[i].initialized()) {
-      continue;
-    } else if (first) {
-      sizes = to_expand[i].dims();
-      first = false;
-    } else {
-      sizes = infer_size_symdimvector(sizes, to_expand[i].dims());
-    }
-  }
-
-  std::vector<paddle::Tensor> result(to_expand.size());
-  for (size_t i = 0; i < to_expand.size(); i++) {
-    if (!to_expand[i].initialized()) {
-      continue;
-    } else if (to_expand[i].dims() == sizes) {
-      result[i] = to_expand[i];
-    } else {
-      result[i] =
-          expand_ad_func(to_expand[i], common::vectorize<int64_t>(sizes));
-    }
-  }
-  return result;
-}
-
-struct AdvancedIndex {
-  AdvancedIndex(paddle::Tensor src, std::vector<paddle::Tensor> indices);
-
-  paddle::Tensor src;
-  std::vector<paddle::Tensor> indices;
-  std::vector<int64_t> indexed_sizes;
-  std::vector<int64_t> indexed_strides;
-  std::vector<int64_t> src_sizes;
-  std::vector<int64_t> src_strides;
-  int64_t dims_before;
-  int64_t dims_after;
-};
-
-inline static void restride_src(std::vector<int64_t>* shape,
-                                std::vector<int64_t>* strides,
-                                int64_t dims_before,
-                                int64_t dims_indexed,
-                                std::vector<int64_t> replacement_shape) {
-  int64_t end = dims_before + dims_indexed;
-  shape->erase(shape->begin() + dims_before, shape->begin() + end);
-  strides->erase(strides->begin() + dims_before, strides->begin() + end);
-  shape->insert(shape->begin() + dims_before,
-                replacement_shape.begin(),
-                replacement_shape.end());
-  strides->insert(strides->begin() + dims_before, replacement_shape.size(), 0);
-}
-
-// move to cuda kernel
-inline static paddle::Tensor reshape_indexer(paddle::Tensor* index,
-                                             int64_t dims_before,
-                                             int64_t dims_after) {
-  auto orig_shape = common::vectorize<int64_t>(index->dims());
-  auto shape = std::vector<int64_t>{};
-  shape.insert(shape.end(), dims_before, 1);
-  shape.insert(shape.end(), orig_shape.begin(), orig_shape.end());
-  shape.insert(shape.end(), dims_after, 1);
-  *index = reshape_ad_func(*index, shape);
-  return *index;
-}
-
-inline AdvancedIndex::AdvancedIndex(paddle::Tensor src,
-                                    std::vector<paddle::Tensor> indices_list) {
-  uint32_t element_size_bytes = phi::SizeOf(src.dtype());
-  int64_t dims_before = 0, dims_after = 0, dims_indexed = 0;
-  std::vector<int64_t> shape_vec = common::vectorize<int64_t>(src.dims());
-  std::vector<int64_t> stride_vec = common::vectorize<int64_t>(src.strides());
-  std::vector<int64_t> replacement_shape;
-  std::vector<int64_t> idx_shape_vec = {};
-  std::vector<int64_t> idx_stride_vec = {};
-
-  for (size_t dim = 0; dim < indices_list.size(); dim++) {
-    if (!indices_list[dim].defined() || indices_list[dim].dims().size() == 0) {
-      if (dims_indexed == 0) {
-        dims_before++;
-      } else {
-        dims_after++;
-      }
-    } else {
-      dims_indexed++;
-      replacement_shape = common::vectorize<int64_t>(indices_list[dim].dims());
-      idx_shape_vec.push_back(shape_vec[dim]);
-      idx_stride_vec.push_back(stride_vec[dim] * element_size_bytes);
-    }
-  }
-
-  this->dims_before = dims_before;
-  this->dims_after = dims_after;
-  restride_src(
-      &shape_vec, &stride_vec, dims_before, dims_indexed, replacement_shape);
-  this->src_sizes = shape_vec;
-  this->src_strides = stride_vec;
-
-  this->indexed_sizes = idx_shape_vec;
-  this->indexed_strides = idx_stride_vec;
-
-  // use dims_before and dims_after / move to cuda kernel
-  for (auto& index : indices_list) {
-    if (index.defined() && index.dims().size() > 0) {
-      this->indices.push_back(reshape_indexer(&index, dims_before, dims_after));
-    }
-  }
-}
-
 } // namespace pybind
 } // namespace paddle

paddle/phi/infermeta/backward.cc

Lines changed: 14 additions & 0 deletions
@@ -2091,4 +2091,18 @@ void FusedRMSNormGradInferMeta(const MetaTensor& x,
   scale_grad->set_dims(scale.dims());
   scale_grad->set_dtype(scale.dtype());
 }
+
+void IndexElementwiseGetGradInferMeta(
+    const MetaTensor& x,
+    const std::vector<const MetaTensor*>& index,
+    const MetaTensor& out_grad,
+    const std::vector<int64_t>& input_dims,
+    const std::vector<int64_t>& input_strides,
+    const std::vector<int64_t>& index_dims,
+    const std::vector<int64_t>& index_strides,
+    MetaTensor* x_grad) {
+  if (x_grad) {
+    x_grad->share_meta(x);
+  }
+}
 }  // namespace phi
