Commit 08fa89f

Fix softmax and prelu ut problem. (PaddlePaddle#1262)
1 parent cef4ed5 commit 08fa89f

File tree

2 files changed (+40, -6 lines)


backends/npu/kernels/prelu_kernel.cc

Lines changed: 8 additions & 5 deletions
@@ -91,8 +91,7 @@ void PReluKernel(const Context& dev_ctx,
       FillNpuTensorWithConstant<T>(out, dev_ctx, val);
       out->Resize(out_dim);
     } else {
-      const auto& runner = NpuOpRunner("PRelu", {x, alpha}, {*out}, {});
-      runner.Run(stream);
+      EXEC_NPU_CMD(aclnnPrelu, dev_ctx, x, alpha, *out);
     }
   }
 }
@@ -207,9 +206,13 @@ void PReluGradKernel(const Context& dev_ctx,
       x_grad->Resize(x_grad_dim);
     } else {
       phi::DenseTensor weight(alpha);
-      const auto& runner = NpuOpRunner(
-          "PReluGrad", {out_grad, x, weight}, {*x_grad, *alpha_grad}, {});
-      runner.Run(stream);
+      EXEC_NPU_CMD(aclnnPreluBackward,
+                   dev_ctx,
+                   out_grad,
+                   x,
+                   weight,
+                   *x_grad,
+                   *alpha_grad);
     }
   }
 }
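
Both hunks follow the same pattern: the graph-mode NpuOpRunner dispatch is replaced by the aclnn single-operator macro EXEC_NPU_CMD that this backend uses elsewhere. Below is a minimal sketch of that call-site pattern; the kernel name and the allocation line are assumptions for illustration, while the NpuOpRunner and EXEC_NPU_CMD calls mirror the diff above.

// Sketch only -- not part of the commit. PReluSketch is a hypothetical name;
// headers and macro definitions are assumed from the surrounding file.
template <typename T, typename Context>
void PReluSketch(const Context& dev_ctx,
                 const phi::DenseTensor& x,
                 const phi::DenseTensor& alpha,
                 phi::DenseTensor* out) {
  dev_ctx.template Alloc<T>(out);

  // Old path (removed): build a graph-mode op runner for "PRelu" and launch
  // it on the device stream.
  //   auto stream = dev_ctx.stream();
  //   const auto& runner = NpuOpRunner("PRelu", {x, alpha}, {*out}, {});
  //   runner.Run(stream);

  // New path (added): dispatch the aclnn operator directly.
  EXEC_NPU_CMD(aclnnPrelu, dev_ctx, x, alpha, *out);
}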

backends/npu/kernels/softmax_kernel.cc

Lines changed: 32 additions & 1 deletion
@@ -17,6 +17,12 @@
 
 namespace custom_kernel {
 
+template <typename T, typename Context>
+void CastKernel(const Context& dev_ctx,
+                const phi::DenseTensor& x,
+                phi::DataType dtype,
+                phi::DenseTensor* out);
+
 template <typename T, typename Context>
 void AclopSoftmaxKernel(const Context& dev_ctx,
                         const phi::DenseTensor& x,
@@ -141,7 +147,32 @@ void SoftmaxGradKernel(const Context& dev_ctx,
                            dev_ctx, out, out_grad, axis, x_grad)));
   dev_ctx.template Alloc<T>(x_grad);
   int64_t dim = static_cast<int64_t>(axis);
-  EXEC_NPU_CMD(aclnnSoftmaxBackward, dev_ctx, out_grad, out, dim, *x_grad);
+
+  phi::DenseTensor cast_x;
+  if (out_grad.dtype() == phi::DataType::FLOAT64) {
+    phi::DenseTensorMeta meta(out_grad.meta());
+    meta.dtype = phi::DataType::FLOAT32;
+    cast_x.set_meta(meta);
+
+    custom_kernel::CastKernel<T, Context>(
+        dev_ctx, out_grad, phi::DataType::FLOAT32, &cast_x);
+  } else {
+    cast_x = out_grad;
+  }
+
+  phi::DenseTensor cast_y;
+  if (out.dtype() == phi::DataType::FLOAT64) {
+    phi::DenseTensorMeta meta(out.meta());
+    meta.dtype = phi::DataType::FLOAT32;
+    cast_y.set_meta(meta);
+
+    custom_kernel::CastKernel<T, Context>(
+        dev_ctx, out, phi::DataType::FLOAT32, &cast_y);
+  } else {
+    cast_y = out;
+  }
+
+  EXEC_NPU_CMD(aclnnSoftmaxBackward, dev_ctx, cast_x, cast_y, dim, *x_grad);
 }
 
 }  // namespace custom_kernel
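
The softmax backward change repeats one workaround twice, once for out_grad and once for out: if the incoming tensor is FLOAT64 (which the aclnn path presumably does not accept, hence the failing unit test), it is cast to FLOAT32 through the newly forward-declared CastKernel before EXEC_NPU_CMD runs; otherwise the tensor is passed through unchanged. A hedged sketch of that cast step factored into one helper follows; the helper name CastToFloat32IfNeeded is hypothetical, while the DenseTensorMeta/CastKernel usage mirrors the added lines above.

// Sketch only -- not part of the commit. Helper name is hypothetical.
template <typename T, typename Context>
phi::DenseTensor CastToFloat32IfNeeded(const Context& dev_ctx,
                                       const phi::DenseTensor& in) {
  if (in.dtype() != phi::DataType::FLOAT64) {
    return in;  // dtype is already accepted; pass the tensor through unchanged
  }
  // Copy the tensor metadata, retarget the dtype to FLOAT32, then cast.
  phi::DenseTensor casted;
  phi::DenseTensorMeta meta(in.meta());
  meta.dtype = phi::DataType::FLOAT32;
  casted.set_meta(meta);
  custom_kernel::CastKernel<T, Context>(
      dev_ctx, in, phi::DataType::FLOAT32, &casted);
  return casted;
}

// With such a helper, the backward body above would reduce to:
//   phi::DenseTensor cast_x = CastToFloat32IfNeeded<T, Context>(dev_ctx, out_grad);
//   phi::DenseTensor cast_y = CastToFloat32IfNeeded<T, Context>(dev_ctx, out);
//   EXEC_NPU_CMD(aclnnSoftmaxBackward, dev_ctx, cast_x, cast_y, dim, *x_grad);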
