diff --git a/paddle/phi/kernels/gpu/elementwise_grad.h b/paddle/phi/kernels/gpu/elementwise_grad.h index 5bd695ac8c124e..961af8b96738a3 100644 --- a/paddle/phi/kernels/gpu/elementwise_grad.h +++ b/paddle/phi/kernels/gpu/elementwise_grad.h @@ -340,7 +340,17 @@ void ElementwiseDivGrad(const GPUContext &dev_ctx, DenseTensor *dy, int axis = -1) { const auto place = dev_ctx.GetPlace(); - if (dx != nullptr && dy != nullptr) { + if (dx != nullptr && dx->numel() == 0) { + dev_ctx.Alloc(dx); + } + + if (dy != nullptr && dy->numel() == 0) { + dev_ctx.Alloc(dy); + } + + bool need_dx = (dx != nullptr) && (dx->numel() != 0); + bool need_dy = (dy != nullptr) && (dy->numel() != 0); + if (need_dx && need_dy) { std::vector ins = {&dout, &out, &y}; GetGradXAndYOut(dev_ctx, place, @@ -350,11 +360,11 @@ void ElementwiseDivGrad(const GPUContext &dev_ctx, dx, dy, funcs::DivGradXYFunctor()); - } else if (dx != nullptr && dy == nullptr) { + } else if (need_dx) { std::vector ins = {&dout, &y}; GetGradXOrYOut( dev_ctx, place, axis, ins, dout, dx, funcs::DivGradXFunctor()); - } else if (dy != nullptr && dx == nullptr) { + } else if (need_dy) { std::vector ins = {&dout, &out, &y}; GetGradXOrYOut( dev_ctx, place, axis, ins, dout, dy, funcs::DivGradYFunctor()); @@ -377,7 +387,17 @@ void ElementwiseMulGrad(const GPUContext &dev_ctx, int axis) { const auto place = dev_ctx.GetPlace(); - if (dx != nullptr && dy != nullptr) { + if (dx != nullptr && dx->numel() == 0) { + dev_ctx.Alloc(dx); + } + + if (dy != nullptr && dy->numel() == 0) { + dev_ctx.Alloc(dy); + } + + bool need_dx = (dx != nullptr) && (dx->numel() != 0); + bool need_dy = (dy != nullptr) && (dy->numel() != 0); + if (need_dx && need_dy) { std::vector ins = {&dout, &y, &x}; GetGradXAndYOut(dev_ctx, place, @@ -387,11 +407,11 @@ void ElementwiseMulGrad(const GPUContext &dev_ctx, dx, dy, funcs::MultiplyGradXYFunctor()); - } else if (dx != nullptr && dy == nullptr) { + } else if (need_dx) { std::vector ins = {&dout, &y}; GetGradXOrYOut( dev_ctx, place, axis, ins, dout, dx, funcs::MultiplyGradFunctor()); - } else if (dx == nullptr && dy != nullptr) { + } else if (need_dy) { std::vector ins = {&dout, &x}; GetGradXOrYOut( dev_ctx, place, axis, ins, dout, dy, funcs::MultiplyGradFunctor());