Skip to content

Commit eb51342

Browse files
committed
fix
1 parent 9fc182f commit eb51342

File tree

3 files changed

+10
-11
lines changed

3 files changed

+10
-11
lines changed

paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -149,19 +149,20 @@ void MatmulCsrCsrGradKernel(const Context& dev_ctx,
149149
SparseCsrTensor* dy) {
150150
#if CUDA_VERSION >= 11000
151151
auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
152+
SparseCsrTensor tmp_dout;
153+
CastCsrKernel<T, Context>(
154+
dev_ctx, dout, phi::DataType::INT32, dout.values().dtype(), &tmp_dout);
152155
// dx{SparseCsr} = dout{SparseCsr} * y'{SparseCsr}
153156
if (dx) {
154157
auto dims_numel = y.dims().size();
155-
SparseCsrTensor transpose_y, tmp_dout, tmp_y;
158+
SparseCsrTensor transpose_y, tmp_y;
156159
if (dims_numel == 2) {
157160
TransposeCsrKernel<T, Context>(dev_ctx, y, {1, 0}, &transpose_y);
158161
} else {
159162
TransposeCsrKernel<T, Context>(dev_ctx, y, {0, 2, 1}, &transpose_y);
160163
}
161164
CastCsrKernel<T, Context>(
162-
dev_ctx, dout, phi::DATATYPE::INT32, dout.values().dtype(), &tmp_dout);
163-
CastCsrKernel<T, Context>(
164-
dev_ctx, transpose_y, phi::DATATYPE::INT32, y.values().dtype(), &tmp_y);
165+
dev_ctx, transpose_y, phi::DataType::INT32, y.values().dtype(), &tmp_y);
165166

166167
sparse_blas.SPMM(false,
167168
false,
@@ -175,16 +176,14 @@ void MatmulCsrCsrGradKernel(const Context& dev_ctx,
175176
// dy{SparseCsr} = x'{SparseCsr} * dout{SparseCsr}
176177
if (dy) {
177178
auto dims_numel = x.dims().size();
178-
SparseCsrTensor transpose_x, tmp_dout, tmp_x;
179+
SparseCsrTensor transpose_x, tmp_x;
179180
if (dims_numel == 2) {
180181
TransposeCsrKernel<T, Context>(dev_ctx, x, {1, 0}, &transpose_x);
181182
} else {
182183
TransposeCsrKernel<T, Context>(dev_ctx, x, {0, 2, 1}, &transpose_x);
183184
}
184185
CastCsrKernel<T, Context>(
185-
dev_ctx, dout, phi::DATATYPE::INT32, dout.values().dtype(), &tmp_dout);
186-
CastCsrKernel<T, Context>(
187-
dev_ctx, transpose_x, phi::DATATYPE::INT32, x.values().dtype(), &tmp_x);
186+
dev_ctx, transpose_x, phi::DataType::INT32, x.values().dtype(), &tmp_x);
188187
sparse_blas.SPMM(false,
189188
false,
190189
static_cast<T>(1),

paddle/phi/kernels/sparse/gpu/matmul_kernel.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ limitations under the License. */
2929
#include "paddle/phi/kernels/funcs/sparse/sparse_blas.h"
3030
#include "paddle/phi/kernels/sparse/empty_kernel.h"
3131
#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"
32+
#include "paddle/phi/kernels/sparse/unary_kernel.h"
3233

3334
namespace phi {
3435
namespace sparse {
@@ -248,9 +249,9 @@ void MatmulCsrCsrKernel(const Context& dev_ctx,
248249
out->set_dims(common::make_ddim(out_dim_vec));
249250
SparseCsrTensor x_tmp, y_tmp;
250251
CastCsrKernel<T, Context>(
251-
dev_ctx, x, phi::DATATYPE::INT32, x.values().dtype(), &x_tmp);
252+
dev_ctx, x, phi::DataType::INT32, x.values().dtype(), &x_tmp);
252253
CastCsrKernel<T, Context>(
253-
dev_ctx, y, phi::DATATYPE::INT32, y.values().dtype(), &y_tmp);
254+
dev_ctx, y, phi::DataType::INT32, y.values().dtype(), &y_tmp);
254255

255256
auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
256257
sparse_blas.SPMM(

paddle/phi/kernels/sparse/gpu/transpose_kernel.cu

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,6 @@ void TransposeCsrKernel(const Context &dev_ctx,
312312
out_values_data);
313313
}
314314
}
315-
316315
} // namespace sparse
317316
} // namespace phi
318317

0 commit comments

Comments
 (0)