@@ -27,6 +27,7 @@ limitations under the License. */
 #include "paddle/phi/kernels/elementwise_kernel.h"
 #include "paddle/phi/kernels/empty_kernel.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/sparse/flatten_indices.h"
 #include "paddle/phi/kernels/sparse/empty_kernel.h"

 namespace phi {
@@ -39,6 +40,7 @@ void AllocCsrPtr(const Context& dev_ctx,
   DenseTensor dx_crows = phi::EmptyLike<IntT>(dev_ctx, x.crows());
   DenseTensor dx_cols = phi::EmptyLike<IntT>(dev_ctx, x.cols());
   DenseTensor dx_values = phi::EmptyLike<T>(dev_ctx, x.values());
+  dx->set_meta(x.meta());
   dx->SetMember(dx_crows, dx_cols, dx_values, x.dims());
 }

@@ -48,9 +50,117 @@ void AllocCooPtr(const Context& dev_ctx,
                  SparseCooTensor* dx) {
   DenseTensor dx_indices = phi::EmptyLike<IntT>(dev_ctx, x.indices());
   DenseTensor dx_values = phi::EmptyLike<T>(dev_ctx, x.values());
+  dx->set_meta(x.meta());
   dx->SetMember(dx_indices, dx_values, x.dims(), x.coalesced());
 }

+template <typename T, typename IntT, typename Context>
+void CopyCooValues(const Context& dev_ctx,
+                   const SparseCooTensor& dout,
+                   const SparseCooTensor& x,
+                   SparseCooTensor* dx) {
+  Copy(dev_ctx, x.indices(), dev_ctx.GetPlace(), false, dx->mutable_indices());
+
+  const int sparse_dim = x.sparse_dim();
+  std::vector<IntT> sparse_offsets(sparse_dim), dout_indexs(dout.nnz()),
+      x_indexs(x.nnz());
+
+  phi::funcs::sparse::CalcOffsetsPerDim<IntT>(
+      dout.dims(), sparse_dim, sparse_offsets.data());
+
+  phi::funcs::sparse::FlattenIndices(dout.indices().data<IntT>(),
+                                     sparse_offsets.data(),
+                                     dout.nnz(),
+                                     sparse_dim,
+                                     0,
+                                     1,
+                                     dout_indexs.data());
+
+  phi::funcs::sparse::FlattenIndices(x.indices().data<IntT>(),
+                                     sparse_offsets.data(),
+                                     x.nnz(),
+                                     sparse_dim,
+                                     0,
+                                     1,
+                                     x_indexs.data());
+
+  size_t i = 0, j = 0;
+  T* dx_values_ptr = dx->mutable_values()->data<T>();
+  const T* dout_values_ptr = dout.values().data<T>();
+
+  int64_t element_size = 1;
+  for (auto j = 1; j < x.values().dims().size(); ++j) {
+    element_size *= x.values().dims()[j];
+  }
+
+  while (i < dout_indexs.size() && j < x_indexs.size()) {
+    if (dout_indexs[i] == x_indexs[j]) {
+      memcpy(dx_values_ptr + j * element_size,
+             dout_values_ptr + i * element_size,
+             element_size * sizeof(T));
+      ++i;
+      ++j;
+    } else if (dout_indexs[i] > x_indexs[j]) {
+      memset(dx_values_ptr + j * element_size, 0, element_size * sizeof(T));
+      ++j;
+    } else {
+      ++i;
+    }
+  }
+  while (j < x_indexs.size()) {
+    memset(dx_values_ptr + j * element_size, 0, element_size * sizeof(T));
+    ++j;
+  }
+}
+
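
For reference, the loop at the heart of CopyCooValues above is a two-pointer merge over two ascending index lists: every entry of x that also appears in dout receives dout's value row, and every entry of x missing from dout is zero-filled. A minimal standalone sketch of that merge (illustrative only, not part of the patch; the helper name is made up; it assumes both flattened index lists are sorted ascending, as they are for coalesced COO tensors, and uses std::vector in place of DenseTensor):

    #include <cstring>
    #include <vector>

    template <typename T, typename IntT>
    void MergeValuesByIndex(const std::vector<IntT>& dout_index,  // sorted
                            const std::vector<T>& dout_values,    // one row per index
                            const std::vector<IntT>& x_index,     // sorted
                            int64_t element_size,                 // values per index
                            std::vector<T>* dx_values) {
      // Start from all zeros so unmatched entries of x need no extra handling.
      dx_values->assign(x_index.size() * element_size, static_cast<T>(0));
      size_t i = 0, j = 0;
      while (i < dout_index.size() && j < x_index.size()) {
        if (dout_index[i] == x_index[j]) {
          // x's j-th entry also exists in dout: copy its value row.
          std::memcpy(dx_values->data() + j * element_size,
                      dout_values.data() + i * element_size,
                      element_size * sizeof(T));
          ++i;
          ++j;
        } else if (dout_index[i] > x_index[j]) {
          ++j;  // entry of x absent from dout: stays zero
        } else {
          ++i;  // entry of dout absent from x: dropped
        }
      }
    }
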
+template <typename T, typename IntT, typename Context>
+void CopyCsrValues(const Context& dev_ctx,
+                   const SparseCsrTensor& dout,
+                   const SparseCsrTensor& x,
+                   SparseCsrTensor* dx) {
+  Copy(dev_ctx, x.crows(), dev_ctx.GetPlace(), false, dx->mutable_crows());
+  Copy(dev_ctx, x.cols(), dev_ctx.GetPlace(), false, dx->mutable_cols());
+
+  const auto& x_dims = x.dims();
+  int batch = x_dims.size() == 2 ? 1 : x_dims[0];
+  int rows = x_dims.size() == 2 ? x_dims[0] : x_dims[1];
+
+  const IntT* x_crows_ptr = x.crows().data<IntT>();
+  const IntT* x_cols_ptr = x.cols().data<IntT>();
+
+  const IntT* dout_crows_ptr = dout.crows().data<IntT>();
+  const IntT* dout_cols_ptr = dout.cols().data<IntT>();
+  const T* dout_values_ptr = dout.values().data<T>();
+
+  T* dx_values_ptr = dx->mutable_values()->data<T>();
+
+  for (int b = 0; b < batch; b++) {
+    for (int r = 0; r < rows; r++) {
+      int x_start = x_crows_ptr[b * (rows + 1) + r];
+      int dout_start = dout_crows_ptr[b * (rows + 1) + r];
+      int x_row_nnz = x_crows_ptr[b * (rows + 1) + r + 1] - x_start;
+      int dout_row_nnz = dout_crows_ptr[b * (rows + 1) + r + 1] - dout_start;
+      int i = 0, j = 0;
+      while (i < x_row_nnz && j < dout_row_nnz) {
+        if (x_cols_ptr[x_start + i] == dout_cols_ptr[dout_start + j]) {
+          dx_values_ptr[x_start + i] = dout_values_ptr[dout_start + j];
+          ++i;
+          ++j;
+        } else if (x_cols_ptr[x_start + i] < dout_cols_ptr[dout_start + j]) {
+          dx_values_ptr[x_start + i] = static_cast<T>(0);
+          ++i;
+        } else {
+          ++j;
+        }
+      }
+      while (i < x_row_nnz) {
+        dx_values_ptr[x_start + i] = static_cast<T>(0);
+        ++i;
+      }
+    }
+  }
+}
+
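
CopyCsrValues above applies the same masking row by row: within each CSR row it walks x's columns and dout's columns in parallel, copying dout's value where the columns match and writing zero where x has a column that dout lacks. A hand-worked check of that row merge (the column sets and values are assumed for illustration, not taken from the patch):

    #include <cassert>
    #include <vector>

    int main() {
      // One CSR row: x keeps columns {0, 2, 5}; dout only has columns {2, 4}.
      std::vector<int> x_cols = {0, 2, 5};
      std::vector<int> dout_cols = {2, 4};
      std::vector<float> dout_vals = {7.0f, 9.0f};
      std::vector<float> dx_vals(x_cols.size(), 0.0f);

      size_t i = 0, j = 0;
      while (i < x_cols.size() && j < dout_cols.size()) {
        if (x_cols[i] == dout_cols[j]) {
          dx_vals[i++] = dout_vals[j++];  // matching column: take dout's value
        } else if (x_cols[i] < dout_cols[j]) {
          dx_vals[i++] = 0.0f;            // column missing from dout: zero
        } else {
          ++j;                            // dout column with no slot in x: dropped
        }
      }
      // Only column 2 matches, so dx keeps x's pattern with a single nonzero.
      assert((dx_vals == std::vector<float>{0.0f, 7.0f, 0.0f}));
      return 0;
    }
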
 template <typename T, typename IntT, typename Context>
 void ElementWiseAddCsrGradCPUKernel(const Context& dev_ctx,
                                     const SparseCsrTensor& x,
@@ -62,16 +172,16 @@ void ElementWiseAddCsrGradCPUKernel(const Context& dev_ctx,
   if (dx != nullptr && dy == nullptr) {
     VLOG(4) << "Special case when dy is not needed";
     AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, dout, x, dx);
   } else if (dx == nullptr && dy != nullptr) {
     VLOG(4) << "Special case when dx is not needed";
     AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, dout, y, dy);
   } else {
     AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
     AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, dout, x, dx);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, dout, y, dy);
   }
 }

@@ -84,12 +194,12 @@ void ElementWiseSubtractCsrGradCPUKernel(const Context& dev_ctx,
                                          SparseCsrTensor* dy) {
   if (dx) {
     AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, dout, x, dx);
   }

   if (dy) {
     AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, dout, y, dy);
     phi::NegativeKernel<T, Context>(
         dev_ctx, dout.values(), dy->mutable_values());
   }
@@ -105,13 +215,19 @@ void ElementWiseMultiplyCsrGradCPUKernel(const Context& dev_ctx,
   if (dx) {
     // dout*y
     AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
-    sparse::ElementWiseMultiplyCsrKernel<T, Context>(dev_ctx, dout, y, dx);
+    SparseCsrTensor tmp_dx;
+    AllocCsrPtr<T, IntT>(dev_ctx, x, &tmp_dx);
+    sparse::ElementWiseMultiplyCsrKernel<T, Context>(dev_ctx, dout, y, &tmp_dx);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, tmp_dx, x, dx);
   }

   if (dy) {
     // dout*x
     AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
-    sparse::ElementWiseMultiplyCsrKernel<T, Context>(dev_ctx, dout, x, dy);
+    SparseCsrTensor tmp_dy;
+    AllocCsrPtr<T, IntT>(dev_ctx, y, &tmp_dy);
+    sparse::ElementWiseMultiplyCsrKernel<T, Context>(dev_ctx, dout, x, &tmp_dy);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, tmp_dy, y, dy);
   }
 }

@@ -126,17 +242,24 @@ void ElementWiseDivideCsrGradCPUKernel(const Context& dev_ctx,
   if (dx) {
     // dout/y
     AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
-    sparse::ElementWiseDivideCsrKernel<T, Context>(dev_ctx, dout, y, dx);
+    SparseCsrTensor tmp_dx;
+    AllocCsrPtr<T, IntT>(dev_ctx, x, &tmp_dx);
+    sparse::ElementWiseDivideCsrKernel<T, Context>(dev_ctx, dout, y, &tmp_dx);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, tmp_dx, x, dx);
   }

   if (dy) {
     // -dout * out / y
     AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    SparseCsrTensor tmp_dy;
+    AllocCsrPtr<T, IntT>(dev_ctx, y, &tmp_dy);
+
+    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, &tmp_dy);
     phi::NegativeKernel<T, Context>(
-        dev_ctx, dout.values(), dy->mutable_values());
-    auto tmp = sparse::ElementWiseMultiplyCsr<T, Context>(dev_ctx, *dy, out);
-    sparse::ElementWiseDivideCsrKernel<T, Context>(dev_ctx, tmp, y, dy);
+        dev_ctx, dout.values(), tmp_dy.mutable_values());
+    auto tmp = sparse::ElementWiseMultiplyCsr<T, Context>(dev_ctx, tmp_dy, out);
+    sparse::ElementWiseDivideCsrKernel<T, Context>(dev_ctx, tmp, y, &tmp_dy);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, tmp_dy, y, dy);
   }
 }

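
The dy branch above follows the analytic gradient of element-wise division: with out = x / y, d(out)/dy = -x / y^2 = -out / y, so dy = -dout * out / y, computed into the temporary tmp_dy and then projected onto y's sparsity pattern by CopyCsrValues. A quick standalone numerical check of that formula on scalar sample values (values assumed, independent of the Paddle code):

    #include <cassert>
    #include <cmath>

    int main() {
      const double x = 3.0, y = 2.0, dout = 0.5;
      const double out = x / y;
      // Formula used by the kernel: dy = -dout * out / y.
      const double dy_formula = -dout * out / y;
      // Central finite difference of out(y) = x / y, scaled by dout.
      const double eps = 1e-6;
      const double dy_numeric = dout * (x / (y + eps) - x / (y - eps)) / (2 * eps);
      assert(std::fabs(dy_formula - dy_numeric) < 1e-6);
      return 0;
    }
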
@@ -151,16 +274,16 @@ void ElementWiseAddCooGradCPUKernel(const Context& dev_ctx,
   if (dx != nullptr && dy == nullptr) {
     VLOG(4) << "Special case when dy is not needed";
     AllocCooPtr<T, IntT>(dev_ctx, x, dx);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
+    CopyCooValues<T, IntT, Context>(dev_ctx, dout, x, dx);
   } else if (dx == nullptr && dy != nullptr) {
     VLOG(4) << "Special case when dx is not needed";
     AllocCooPtr<T, IntT>(dev_ctx, y, dy);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    CopyCooValues<T, IntT, Context>(dev_ctx, dout, y, dy);
   } else {
     AllocCooPtr<T, IntT>(dev_ctx, x, dx);
     AllocCooPtr<T, IntT>(dev_ctx, y, dy);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    CopyCooValues<T, IntT, Context>(dev_ctx, dout, x, dx);
+    CopyCooValues<T, IntT, Context>(dev_ctx, dout, y, dy);
   }
 }

@@ -173,12 +296,12 @@ void ElementWiseSubtractCooGradCPUKernel(const Context& dev_ctx,
                                          SparseCooTensor* dy) {
   if (dx) {
     AllocCooPtr<T, IntT>(dev_ctx, x, dx);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
+    CopyCooValues<T, IntT, Context>(dev_ctx, dout, x, dx);
   }

   if (dy) {
     AllocCooPtr<T, IntT>(dev_ctx, y, dy);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    CopyCooValues<T, IntT, Context>(dev_ctx, dout, y, dy);
     phi::NegativeKernel<T, Context>(
         dev_ctx, dout.values(), dy->mutable_values());
   }
@@ -194,13 +317,19 @@ void ElementWiseMultiplyCooGradCPUKernel(const Context& dev_ctx,
   if (dx) {
     // dout*y
     AllocCooPtr<T, IntT>(dev_ctx, x, dx);
-    sparse::ElementWiseMultiplyCooKernel<T, Context>(dev_ctx, dout, y, dx);
+    SparseCooTensor tmp_dx;
+    AllocCooPtr<T, IntT>(dev_ctx, x, &tmp_dx);
+    sparse::ElementWiseMultiplyCooKernel<T, Context>(dev_ctx, dout, y, &tmp_dx);
+    CopyCooValues<T, IntT, Context>(dev_ctx, tmp_dx, x, dx);
   }

   if (dy) {
     // dout*x
     AllocCooPtr<T, IntT>(dev_ctx, y, dy);
-    sparse::ElementWiseMultiplyCooKernel<T, Context>(dev_ctx, dout, x, dy);
+    SparseCooTensor tmp_dy;
+    AllocCooPtr<T, IntT>(dev_ctx, y, &tmp_dy);
+    sparse::ElementWiseMultiplyCooKernel<T, Context>(dev_ctx, dout, x, &tmp_dy);
+    CopyCooValues<T, IntT, Context>(dev_ctx, tmp_dy, y, dy);
   }
 }

@@ -215,22 +344,26 @@ void ElementWiseDivideCooGradCPUKernel(const Context& dev_ctx,
   if (dx) {
     // dout/y
     AllocCooPtr<T, IntT>(dev_ctx, x, dx);
-    sparse::ElementWiseDivideCooKernel<T, Context>(dev_ctx, dout, y, dx);
+    SparseCooTensor tmp_dx;
+    AllocCooPtr<T, IntT>(dev_ctx, x, &tmp_dx);
+    sparse::ElementWiseDivideCooKernel<T, Context>(dev_ctx, dout, y, &tmp_dx);
+    CopyCooValues<T, IntT, Context>(dev_ctx, tmp_dx, x, dx);
   }

   if (dy) {
     // -dout * out / y
     AllocCooPtr<T, IntT>(dev_ctx, y, dy);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    SparseCooTensor tmp_dy;
+    AllocCooPtr<T, IntT>(dev_ctx, y, &tmp_dy);
+    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, &tmp_dy);
     phi::NegativeKernel<T, Context>(
-        dev_ctx, dout.values(), dy->mutable_values());
-    auto tmp = sparse::ElementWiseMultiplyCoo<T, Context>(dev_ctx, *dy, out);
-    sparse::ElementWiseDivideCooKernel<T, Context>(dev_ctx, tmp, y, dy);
+        dev_ctx, dout.values(), tmp_dy.mutable_values());
+    auto tmp = sparse::ElementWiseMultiplyCoo<T, Context>(dev_ctx, tmp_dy, out);
+    sparse::ElementWiseDivideCooKernel<T, Context>(dev_ctx, tmp, y, &tmp_dy);
+    CopyCooValues<T, IntT, Context>(dev_ctx, tmp_dy, y, dy);
   }
 }
-// CPU Kernel end

-// Kernel
 template <typename T, typename Context>
 void ElementWiseDivideCsrGradKernel(const Context& dev_ctx,
                                     const SparseCsrTensor& x,