Skip to content

Commit 677b0fb

Browse files
authored
[NPU] Fix inplace multiply_grad (#1274)
1 parent 2444df5 commit 677b0fb

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

backends/npu/kernels/elementwise_mul_kernel.cc

+6
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,12 @@ void MultiplyGradKernel(const Context& dev_ctx,
212212
if (dx) {
213213
phi::DenseTensor trans_y;
214214
NpuBroadcast<T>(dev_ctx, &y, y_axis, dst_dims, &trans_y);
215+
// For inplace strategy, dx will be stored in addr of dout, which makes
216+
// the result of dy wrong.
217+
if (dx->IsSharedWith(dout)) {
218+
dx->clear();
219+
dx->Resize(x.dims());
220+
}
215221
if (dx->dims() == dout.dims()) {
216222
dev_ctx.template Alloc<T>(dx);
217223
EXEC_NPU_CMD(aclnnMul, dev_ctx, dout, trans_y, *dx);

0 commit comments

Comments
 (0)