@@ -149,19 +149,20 @@ void MatmulCsrCsrGradKernel(const Context& dev_ctx,
149
149
SparseCsrTensor* dy) {
150
150
#if CUDA_VERSION >= 11000
151
151
auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
152
+ SparseCsrTensor tmp_dout;
153
+ CastCsrKernel<T, Context>(
154
+ dev_ctx, dout, phi::DataType::INT32, dout.values ().dtype (), &tmp_dout);
152
155
// dx{SparseCsr} = dout{SparseCsr} * y'{SparseCsr}
153
156
if (dx) {
154
157
auto dims_numel = y.dims ().size ();
155
- SparseCsrTensor transpose_y, tmp_dout, tmp_y;
158
+ SparseCsrTensor transpose_y, tmp_y;
156
159
if (dims_numel == 2 ) {
157
160
TransposeCsrKernel<T, Context>(dev_ctx, y, {1 , 0 }, &transpose_y);
158
161
} else {
159
162
TransposeCsrKernel<T, Context>(dev_ctx, y, {0 , 2 , 1 }, &transpose_y);
160
163
}
161
164
CastCsrKernel<T, Context>(
162
- dev_ctx, dout, phi::DATATYPE::INT32, dout.values ().dtype (), &tmp_dout);
163
- CastCsrKernel<T, Context>(
164
- dev_ctx, transpose_y, phi::DATATYPE::INT32, y.values ().dtype (), &tmp_y);
165
+ dev_ctx, transpose_y, phi::DataType::INT32, y.values ().dtype (), &tmp_y);
165
166
166
167
sparse_blas.SPMM (false ,
167
168
false ,
@@ -175,16 +176,14 @@ void MatmulCsrCsrGradKernel(const Context& dev_ctx,
175
176
// dy{SparseCsr} = x'{SparseCsr} * dout{SparseCsr}
176
177
if (dy) {
177
178
auto dims_numel = x.dims ().size ();
178
- SparseCsrTensor transpose_x, tmp_dout, tmp_x;
179
+ SparseCsrTensor transpose_x, tmp_x;
179
180
if (dims_numel == 2 ) {
180
181
TransposeCsrKernel<T, Context>(dev_ctx, x, {1 , 0 }, &transpose_x);
181
182
} else {
182
183
TransposeCsrKernel<T, Context>(dev_ctx, x, {0 , 2 , 1 }, &transpose_x);
183
184
}
184
185
CastCsrKernel<T, Context>(
185
- dev_ctx, dout, phi::DATATYPE::INT32, dout.values ().dtype (), &tmp_dout);
186
- CastCsrKernel<T, Context>(
187
- dev_ctx, transpose_x, phi::DATATYPE::INT32, x.values ().dtype (), &tmp_x);
186
+ dev_ctx, transpose_x, phi::DataType::INT32, x.values ().dtype (), &tmp_x);
188
187
sparse_blas.SPMM (false ,
189
188
false ,
190
189
static_cast <T>(1 ),
0 commit comments