Commit aeb8c2e

Author: zhangkaihuo
[Sparse] Fix the bug of elementwise_grad (#52102)

1 parent 8b622d5 commit aeb8c2e

3 files changed, +203 -76 lines changed


paddle/phi/kernels/sparse/cpu/elementwise_grad_kernel.cc
+161 -28
@@ -27,6 +27,7 @@ limitations under the License. */
 #include "paddle/phi/kernels/elementwise_kernel.h"
 #include "paddle/phi/kernels/empty_kernel.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/sparse/flatten_indices.h"
 #include "paddle/phi/kernels/sparse/empty_kernel.h"
 
 namespace phi {
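
Note on the next two hunks: AllocCsrPtr and AllocCooPtr now call dx->set_meta(x.meta()) before SetMember. phi::EmptyLike only allocates component buffers shaped like the source; the sparse tensor's own metadata (dims, dtype, layout) must be propagated explicitly, or the freshly allocated gradient carries default meta even though its buffers are right. A minimal standalone sketch of that invariant, with simplified stand-in types (Meta, MiniSparse, EmptyLike here are illustrative, not phi's real classes):

#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in for tensor-level metadata.
struct Meta {
  std::vector<int64_t> dims;
  int dtype = -1;  // -1 models "default/unset"
};

// Stand-in for a sparse tensor: wrapper metadata plus component buffers.
struct MiniSparse {
  Meta meta;
  std::vector<float> values;
  void set_meta(const Meta& m) { meta = m; }
};

// Models EmptyLike: sizes the buffers like x, leaves the wrapper meta alone.
MiniSparse EmptyLike(const MiniSparse& x) {
  MiniSparse out;
  out.values.resize(x.values.size());
  return out;
}

int main() {
  MiniSparse x{{{4, 4}, 0}, {1.f, 2.f, 3.f}};
  MiniSparse dx = EmptyLike(x);
  dx.set_meta(x.meta);  // the fix: propagate metadata explicitly
  assert(dx.meta.dtype == x.meta.dtype && dx.meta.dims == x.meta.dims);
  return 0;
}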
@@ -39,6 +40,7 @@ void AllocCsrPtr(const Context& dev_ctx,
   DenseTensor dx_crows = phi::EmptyLike<IntT>(dev_ctx, x.crows());
   DenseTensor dx_cols = phi::EmptyLike<IntT>(dev_ctx, x.cols());
   DenseTensor dx_values = phi::EmptyLike<T>(dev_ctx, x.values());
+  dx->set_meta(x.meta());
   dx->SetMember(dx_crows, dx_cols, dx_values, x.dims());
 }
 
@@ -48,9 +50,117 @@ void AllocCooPtr(const Context& dev_ctx,
                  SparseCooTensor* dx) {
   DenseTensor dx_indices = phi::EmptyLike<IntT>(dev_ctx, x.indices());
   DenseTensor dx_values = phi::EmptyLike<T>(dev_ctx, x.values());
+  dx->set_meta(x.meta());
   dx->SetMember(dx_indices, dx_values, x.dims(), x.coalesced());
 }
 
+template <typename T, typename IntT, typename Context>
+void CopyCooValues(const Context& dev_ctx,
+                   const SparseCooTensor& dout,
+                   const SparseCooTensor& x,
+                   SparseCooTensor* dx) {
+  Copy(dev_ctx, x.indices(), dev_ctx.GetPlace(), false, dx->mutable_indices());
+
+  const int sparse_dim = x.sparse_dim();
+  std::vector<IntT> sparse_offsets(sparse_dim), dout_indexs(dout.nnz()),
+      x_indexs(x.nnz());
+
+  phi::funcs::sparse::CalcOffsetsPerDim<IntT>(
+      dout.dims(), sparse_dim, sparse_offsets.data());
+
+  phi::funcs::sparse::FlattenIndices(dout.indices().data<IntT>(),
+                                     sparse_offsets.data(),
+                                     dout.nnz(),
+                                     sparse_dim,
+                                     0,
+                                     1,
+                                     dout_indexs.data());
+
+  phi::funcs::sparse::FlattenIndices(x.indices().data<IntT>(),
+                                     sparse_offsets.data(),
+                                     x.nnz(),
+                                     sparse_dim,
+                                     0,
+                                     1,
+                                     x_indexs.data());
+
+  size_t i = 0, j = 0;
+  T* dx_values_ptr = dx->mutable_values()->data<T>();
+  const T* dout_values_ptr = dout.values().data<T>();
+
+  int64_t element_size = 1;
+  for (auto j = 1; j < x.values().dims().size(); ++j) {
+    element_size *= x.values().dims()[j];
+  }
+
+  while (i < dout_indexs.size() && j < x_indexs.size()) {
+    if (dout_indexs[i] == x_indexs[j]) {
+      memcpy(dx_values_ptr + j * element_size,
+             dout_values_ptr + i * element_size,
+             element_size * sizeof(T));
+      ++i;
+      ++j;
+    } else if (dout_indexs[i] > x_indexs[j]) {
+      memset(dx_values_ptr + j * element_size, 0, element_size * sizeof(T));
+      ++j;
+    } else {
+      ++i;
+    }
+  }
+  while (j < x_indexs.size()) {
+    memset(dx_values_ptr + j * element_size, 0, element_size * sizeof(T));
+    ++j;
+  }
+}
+
+template <typename T, typename IntT, typename Context>
+void CopyCsrValues(const Context& dev_ctx,
+                   const SparseCsrTensor& dout,
+                   const SparseCsrTensor& x,
+                   SparseCsrTensor* dx) {
+  Copy(dev_ctx, x.crows(), dev_ctx.GetPlace(), false, dx->mutable_crows());
+  Copy(dev_ctx, x.cols(), dev_ctx.GetPlace(), false, dx->mutable_cols());
+
+  const auto& x_dims = x.dims();
+  int batch = x_dims.size() == 2 ? 1 : x_dims[0];
+  int rows = x_dims.size() == 2 ? x_dims[0] : x_dims[1];
+
+  const IntT* x_crows_ptr = x.crows().data<IntT>();
+  const IntT* x_cols_ptr = x.cols().data<IntT>();
+
+  const IntT* dout_crows_ptr = dout.crows().data<IntT>();
+  const IntT* dout_cols_ptr = dout.cols().data<IntT>();
+  const T* dout_values_ptr = dout.values().data<T>();
+
+  T* dx_values_ptr = dx->mutable_values()->data<T>();
+
+  for (int b = 0; b < batch; b++) {
+    for (int r = 0; r < rows; r++) {
+      int x_start = x_crows_ptr[b * (rows + 1) + r];
+      int dout_start = dout_crows_ptr[b * (rows + 1) + r];
+      int x_row_nnz = x_crows_ptr[b * (rows + 1) + r + 1] - x_start;
+      int dout_row_nnz = dout_crows_ptr[b * (rows + 1) + r + 1] - dout_start;
+      int i = 0, j = 0;
+      while (i < x_row_nnz && j < dout_row_nnz) {
+        if (x_cols_ptr[x_start + i] == dout_cols_ptr[dout_start + j]) {
+          dx_values_ptr[x_start + i] = dout_values_ptr[dout_start + j];
+          ++i;
+          ++j;
+        } else if (x_cols_ptr[x_start + i] < dout_cols_ptr[dout_start + j]) {
+          dx_values_ptr[x_start + i] = static_cast<T>(0);
+          ++i;
+        } else {
+          ++j;
+        }
+      }
+      while (i < x_row_nnz) {
+        dx_values_ptr[x_start + i] = static_cast<T>(0);
+        ++i;
+      }
+    }
+  }
+}
+
 template <typename T, typename IntT, typename Context>
 void ElementWiseAddCsrGradCPUKernel(const Context& dev_ctx,
                                     const SparseCsrTensor& x,
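
The two helpers added above are the heart of this fix. CopyCooValues flattens the multi-dimensional COO indices of dout and x into scalar keys (row-major strides over the sparse dims, via CalcOffsetsPerDim and FlattenIndices) and merges the two sorted key lists with two pointers: on a key match, dout's value row is copied into dx; entries of x that dout does not cover are zeroed; entries of dout outside x's pattern are dropped. A standalone sketch of the same idea with plain std:: types (illustrative only, not the phi API; scalar values rather than value rows for brevity):

#include <cstdint>
#include <vector>

// Row-major strides over the leading sparse_dim dims
// (mirrors what CalcOffsetsPerDim computes).
std::vector<int64_t> Strides(const std::vector<int64_t>& dims, int sparse_dim) {
  std::vector<int64_t> s(sparse_dim, 1);
  for (int i = sparse_dim - 2; i >= 0; --i) s[i] = s[i + 1] * dims[i + 1];
  return s;
}

// indices is laid out [sparse_dim][nnz], like a COO indices tensor; returns
// one scalar key per nonzero (mirrors what FlattenIndices computes).
std::vector<int64_t> Flatten(const std::vector<std::vector<int64_t>>& indices,
                             const std::vector<int64_t>& strides) {
  std::vector<int64_t> keys(indices[0].size(), 0);
  for (size_t d = 0; d < indices.size(); ++d) {
    for (size_t n = 0; n < keys.size(); ++n) keys[n] += indices[d][n] * strides[d];
  }
  return keys;
}

// Two-pointer merge over sorted keys: project dout's values onto x's
// sparsity pattern, zero-filling entries of x that dout does not cover.
std::vector<float> MaskCopy(const std::vector<int64_t>& dout_keys,
                            const std::vector<float>& dout_vals,
                            const std::vector<int64_t>& x_keys) {
  std::vector<float> dx_vals(x_keys.size(), 0.f);
  size_t i = 0, j = 0;
  while (i < dout_keys.size() && j < x_keys.size()) {
    if (dout_keys[i] == x_keys[j]) {
      dx_vals[j++] = dout_vals[i++];  // shared entry: take dout's value
    } else if (dout_keys[i] > x_keys[j]) {
      ++j;  // x-only entry stays zero
    } else {
      ++i;  // dout-only entry is dropped
    }
  }
  return dx_vals;
}

Worked example with made-up numbers: if x has nonzeros at flattened positions {0, 2, 5} and dout carries values [g0, g1, g2, g5] at positions {0, 1, 2, 5}, MaskCopy yields [g0, g2, g5]. The raw Copy used before this commit would have paired [g0, g1, g2] with x's three positions, i.e. wrong values at wrong indices whenever the two patterns differ.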
@@ -62,16 +172,16 @@ void ElementWiseAddCsrGradCPUKernel(const Context& dev_ctx,
   if (dx != nullptr && dy == nullptr) {
     VLOG(4) << "Special case when dy is not needed";
     AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, dout, x, dx);
   } else if (dx == nullptr && dy != nullptr) {
     VLOG(4) << "Special case when dx is not needed";
     AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, dout, y, dy);
   } else {
     AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
     AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, dout, x, dx);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, dout, y, dy);
   }
 }
 
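
CopyCsrValues plays the same role for CSR, and this hunk shows why the add kernels need it: out = x + y has the union of the two input patterns, so dout generally holds more nonzeros than x, and a raw Copy misaligns values. Per row, the helper walks the two sorted column lists in lockstep. A minimal single-batch sketch with plain arrays (illustrative, not the phi API):

#include <vector>

// Standard CSR layout: crows has rows + 1 entries; cols/values per nonzero.
// Projects dout's values onto x's pattern, row by row.
std::vector<float> MaskCopyCsr(const std::vector<int>& x_crows,
                               const std::vector<int>& x_cols,
                               const std::vector<int>& dout_crows,
                               const std::vector<int>& dout_cols,
                               const std::vector<float>& dout_vals) {
  std::vector<float> dx_vals(x_cols.size(), 0.f);
  const int rows = static_cast<int>(x_crows.size()) - 1;
  for (int r = 0; r < rows; ++r) {
    int i = x_crows[r], j = dout_crows[r];
    while (i < x_crows[r + 1] && j < dout_crows[r + 1]) {
      if (x_cols[i] == dout_cols[j]) {
        dx_vals[i++] = dout_vals[j++];  // shared column: take dout's value
      } else if (x_cols[i] < dout_cols[j]) {
        ++i;  // x-only column stays zero (dx_vals pre-zeroed)
      } else {
        ++j;  // dout-only column is dropped
      }
    }
  }
  return dx_vals;
}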
@@ -84,12 +194,12 @@ void ElementWiseSubtractCsrGradCPUKernel(const Context& dev_ctx,
                                          SparseCsrTensor* dy) {
   if (dx) {
     AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, dout, x, dx);
   }
 
   if (dy) {
     AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, dout, y, dy);
     phi::NegativeKernel<T, Context>(
         dev_ctx, dout.values(), dy->mutable_values());
   }
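
In formulas, the subtract kernels above now implement, for z = x - y,

\[
\frac{\partial L}{\partial x} \;=\; \left.\frac{\partial L}{\partial z}\right|_{\operatorname{supp}(x)},
\qquad
\frac{\partial L}{\partial y} \;=\; \left.-\,\frac{\partial L}{\partial z}\right|_{\operatorname{supp}(y)},
\]

where the restriction to supp(.) is the masking CopyCsrValues performs, and phi::NegativeKernel supplies the sign flip for dy.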
@@ -105,13 +215,19 @@ void ElementWiseMultiplyCsrGradCPUKernel(const Context& dev_ctx,
   if (dx) {
     // dout*y
     AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
-    sparse::ElementWiseMultiplyCsrKernel<T, Context>(dev_ctx, dout, y, dx);
+    SparseCsrTensor tmp_dx;
+    AllocCsrPtr<T, IntT>(dev_ctx, x, &tmp_dx);
+    sparse::ElementWiseMultiplyCsrKernel<T, Context>(dev_ctx, dout, y, &tmp_dx);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, tmp_dx, x, dx);
   }
 
   if (dy) {
     // dout*x
     AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
-    sparse::ElementWiseMultiplyCsrKernel<T, Context>(dev_ctx, dout, x, dy);
+    SparseCsrTensor tmp_dy;
+    AllocCsrPtr<T, IntT>(dev_ctx, y, &tmp_dy);
+    sparse::ElementWiseMultiplyCsrKernel<T, Context>(dev_ctx, dout, x, &tmp_dy);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, tmp_dy, y, dy);
   }
 }
 
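
The temporary introduced here is the substance of the fix for multiply: ElementWiseMultiplyCsrKernel produces a result whose sparsity pattern is determined by its own operands (dout and y), not by x, so writing it straight into dx, as the deleted lines did, could hand back a gradient on the wrong pattern. The product now lands in tmp_dx and is then projected onto x's pattern. For elementwise z = x * y:

\[
\frac{\partial L}{\partial x} \;=\; \left.\left(\frac{\partial L}{\partial z} \odot y\right)\right|_{\operatorname{supp}(x)},
\qquad
\frac{\partial L}{\partial y} \;=\; \left.\left(\frac{\partial L}{\partial z} \odot x\right)\right|_{\operatorname{supp}(y)}.
\]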
@@ -126,17 +242,24 @@ void ElementWiseDivideCsrGradCPUKernel(const Context& dev_ctx,
   if (dx) {
     // dout/y
     AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
-    sparse::ElementWiseDivideCsrKernel<T, Context>(dev_ctx, dout, y, dx);
+    SparseCsrTensor tmp_dx;
+    AllocCsrPtr<T, IntT>(dev_ctx, x, &tmp_dx);
+    sparse::ElementWiseDivideCsrKernel<T, Context>(dev_ctx, dout, y, &tmp_dx);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, tmp_dx, x, dx);
   }
 
   if (dy) {
     // -dout * out / y
     AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    SparseCsrTensor tmp_dy;
+    AllocCsrPtr<T, IntT>(dev_ctx, y, &tmp_dy);
+
+    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, &tmp_dy);
     phi::NegativeKernel<T, Context>(
-        dev_ctx, dout.values(), dy->mutable_values());
-    auto tmp = sparse::ElementWiseMultiplyCsr<T, Context>(dev_ctx, *dy, out);
-    sparse::ElementWiseDivideCsrKernel<T, Context>(dev_ctx, tmp, y, dy);
+        dev_ctx, dout.values(), tmp_dy.mutable_values());
+    auto tmp = sparse::ElementWiseMultiplyCsr<T, Context>(dev_ctx, tmp_dy, out);
+    sparse::ElementWiseDivideCsrKernel<T, Context>(dev_ctx, tmp, y, &tmp_dy);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, tmp_dy, y, dy);
   }
 }
 
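
The dy branch of divide follows from the quotient rule rewritten in terms of the saved forward output: for z = x / y,

\[
\frac{\partial z}{\partial y} \;=\; -\frac{x}{y^{2}} \;=\; -\frac{z}{y}
\quad\Longrightarrow\quad
\frac{\partial L}{\partial y} \;=\; \left.\left(-\frac{\partial L}{\partial z}\odot\frac{z}{y}\right)\right|_{\operatorname{supp}(y)},
\]

which is the sequence in the new code: copy dout into tmp_dy and negate it, multiply by out, divide by y, then project the result onto y's pattern with CopyCsrValues.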
@@ -151,16 +274,16 @@ void ElementWiseAddCooGradCPUKernel(const Context& dev_ctx,
   if (dx != nullptr && dy == nullptr) {
     VLOG(4) << "Special case when dy is not needed";
     AllocCooPtr<T, IntT>(dev_ctx, x, dx);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
+    CopyCooValues<T, IntT, Context>(dev_ctx, dout, x, dx);
   } else if (dx == nullptr && dy != nullptr) {
     VLOG(4) << "Special case when dx is not needed";
     AllocCooPtr<T, IntT>(dev_ctx, y, dy);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    CopyCooValues<T, IntT, Context>(dev_ctx, dout, y, dy);
   } else {
     AllocCooPtr<T, IntT>(dev_ctx, x, dx);
     AllocCooPtr<T, IntT>(dev_ctx, y, dy);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    CopyCooValues<T, IntT, Context>(dev_ctx, dout, x, dx);
+    CopyCooValues<T, IntT, Context>(dev_ctx, dout, y, dy);
   }
 }
 
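
This hunk and the three that follow mirror the CSR changes one-for-one on the COO kernels: each former raw Copy of dout becomes a CopyCooValues projection onto the input's own indices, and the multiply/divide gradients are first computed into a temporary and then projected.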
@@ -173,12 +296,12 @@ void ElementWiseSubtractCooGradCPUKernel(const Context& dev_ctx,
                                          SparseCooTensor* dy) {
   if (dx) {
     AllocCooPtr<T, IntT>(dev_ctx, x, dx);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
+    CopyCooValues<T, IntT, Context>(dev_ctx, dout, x, dx);
   }
 
   if (dy) {
     AllocCooPtr<T, IntT>(dev_ctx, y, dy);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    CopyCooValues<T, IntT, Context>(dev_ctx, dout, y, dy);
     phi::NegativeKernel<T, Context>(
         dev_ctx, dout.values(), dy->mutable_values());
   }
@@ -194,13 +317,19 @@ void ElementWiseMultiplyCooGradCPUKernel(const Context& dev_ctx,
   if (dx) {
     // dout*y
     AllocCooPtr<T, IntT>(dev_ctx, x, dx);
-    sparse::ElementWiseMultiplyCooKernel<T, Context>(dev_ctx, dout, y, dx);
+    SparseCooTensor tmp_dx;
+    AllocCooPtr<T, IntT>(dev_ctx, x, &tmp_dx);
+    sparse::ElementWiseMultiplyCooKernel<T, Context>(dev_ctx, dout, y, &tmp_dx);
+    CopyCooValues<T, IntT, Context>(dev_ctx, tmp_dx, x, dx);
   }
 
   if (dy) {
     // dout*x
     AllocCooPtr<T, IntT>(dev_ctx, y, dy);
-    sparse::ElementWiseMultiplyCooKernel<T, Context>(dev_ctx, dout, x, dy);
+    SparseCooTensor tmp_dy;
+    AllocCooPtr<T, IntT>(dev_ctx, y, &tmp_dy);
+    sparse::ElementWiseMultiplyCooKernel<T, Context>(dev_ctx, dout, x, &tmp_dy);
+    CopyCooValues<T, IntT, Context>(dev_ctx, tmp_dy, y, dy);
   }
 }
 
@@ -215,22 +344,26 @@ void ElementWiseDivideCooGradCPUKernel(const Context& dev_ctx,
   if (dx) {
     // dout/y
     AllocCooPtr<T, IntT>(dev_ctx, x, dx);
-    sparse::ElementWiseDivideCooKernel<T, Context>(dev_ctx, dout, y, dx);
+    SparseCooTensor tmp_dx;
+    AllocCooPtr<T, IntT>(dev_ctx, x, &tmp_dx);
+    sparse::ElementWiseDivideCooKernel<T, Context>(dev_ctx, dout, y, &tmp_dx);
+    CopyCooValues<T, IntT, Context>(dev_ctx, tmp_dx, x, dx);
   }
 
   if (dy) {
     // -dout * out / y
     AllocCooPtr<T, IntT>(dev_ctx, y, dy);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    SparseCooTensor tmp_dy;
+    AllocCooPtr<T, IntT>(dev_ctx, y, &tmp_dy);
+    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, &tmp_dy);
     phi::NegativeKernel<T, Context>(
-        dev_ctx, dout.values(), dy->mutable_values());
-    auto tmp = sparse::ElementWiseMultiplyCoo<T, Context>(dev_ctx, *dy, out);
-    sparse::ElementWiseDivideCooKernel<T, Context>(dev_ctx, tmp, y, dy);
+        dev_ctx, dout.values(), tmp_dy.mutable_values());
+    auto tmp = sparse::ElementWiseMultiplyCoo<T, Context>(dev_ctx, tmp_dy, out);
+    sparse::ElementWiseDivideCooKernel<T, Context>(dev_ctx, tmp, y, &tmp_dy);
+    CopyCooValues<T, IntT, Context>(dev_ctx, tmp_dy, y, dy);
   }
 }
-// CPU Kernel end
 
-// Kernel
 template <typename T, typename Context>
 void ElementWiseDivideCsrGradKernel(const Context& dev_ctx,
                                     const SparseCsrTensor& x,
