Skip to content

Commit 3c3c9e5

Browse files
authored
[NPU] fix bn (PaddlePaddle#569)
1 parent cc62b02 commit 3c3c9e5

File tree

1 file changed

+39
-25
lines changed

1 file changed

+39
-25
lines changed

backends/npu/kernels/batch_norm_kernel.cc

+39-25
Original file line number | Diff line number | Diff line change
@@ -200,27 +200,37 @@ void BatchNormKernel(const Context& dev_ctx,
200200
{*mean_out},
201201
{{"value", static_cast<float>(momentum)}});
202202
mean_muls_runner.Run(stream);
203-
204203
const auto& mean_axpy_runner =
205204
NpuOpRunner("Axpy",
206205
{*mean_out, *saved_mean},
207206
{*mean_out},
208207
{{"alpha", static_cast<float>(1 - momentum)}});
209208
mean_axpy_runner.Run(stream);
210-
211209
const auto& var_muls_runner =
212210
NpuOpRunner("Muls",
213211
{tmp_running_var},
214212
{*variance_out},
215213
{{"value", static_cast<float>(momentum)}});
216214
var_muls_runner.Run(stream);
217-
218215
const auto& var_axpy_runner =
219216
NpuOpRunner("Axpy",
220217
{*variance_out, *saved_variance},
221218
{*variance_out},
222219
{{"alpha", static_cast<float>(1 - momentum)}});
223220
var_axpy_runner.Run(stream);
221+
222+
const auto& adds_runner =
223+
NpuOpRunner("Adds",
224+
{*saved_variance},
225+
{*saved_variance},
226+
{{"value", static_cast<float>(epsilon)}});
227+
adds_runner.Run(stream);
228+
const auto& inv_runner =
229+
NpuOpRunner("Inv", {*saved_variance}, {*saved_variance}, {});
230+
inv_runner.Run(stream);
231+
const auto& sqrt_ruuner =
232+
NpuOpRunner("Sqrt", {*saved_variance}, {*saved_variance}, {});
233+
sqrt_ruuner.Run(stream);
224234
}
225235
}
226236

@@ -326,8 +336,21 @@ void BatchNormGradKernel(
326336
}
327337

328338
const auto* running_mean = use_global_stats ? mean.get_ptr() : &saved_mean;
329-
const auto* running_vstd =
330-
use_global_stats ? variance.get_ptr() : &saved_variance;
339+
phi::DenseTensor running_invstd;
340+
auto* running_vstd = use_global_stats ? variance.get_ptr() : &running_invstd;
341+
if (!use_global_stats) {
342+
running_invstd.Resize(saved_variance.dims());
343+
dev_ctx.template Alloc<float>(&running_invstd);
344+
const auto& square_runner = NpuOpRunner(
345+
"Square",
346+
{*(use_global_stats ? variance.get_ptr() : &saved_variance)},
347+
{running_invstd},
348+
{});
349+
square_runner.Run(stream);
350+
const auto& inv_runner =
351+
NpuOpRunner("Inv", {running_invstd}, {running_invstd}, {});
352+
inv_runner.Run(stream);
353+
}
331354

332355
NpuOpRunner runner_update;
333356
runner_update.SetType(update_name)
@@ -364,26 +387,17 @@ void BatchNormGradKernel(
364387
}
365388
}
366389

367-
if (use_global_stats) {
368-
const auto* running_vstd = variance.get_ptr();
369-
const auto& runner_infer = NpuOpRunner("BNInferGrad",
370-
{dy_tensor, scale, *running_vstd},
371-
{dx_tensor},
372-
{{"epsilon", epsilon}});
373-
runner_infer.Run(stream);
374-
} else {
375-
const auto& runner_reduce = NpuOpRunner(reduce_name,
376-
{dy_tensor,
377-
x_tensor,
378-
*d_scale,
379-
*d_bias,
380-
scale,
381-
saved_mean,
382-
saved_variance},
383-
{dx_tensor},
384-
{{"epsilon", epsilon}});
385-
runner_reduce.Run(stream);
386-
}
390+
const auto& runner_reduce = NpuOpRunner(reduce_name,
391+
{dy_tensor,
392+
x_tensor,
393+
*d_scale,
394+
*d_bias,
395+
scale,
396+
*running_mean,
397+
*running_vstd},
398+
{dx_tensor},
399+
{{"epsilon", epsilon}});
400+
runner_reduce.Run(stream);
387401
}
388402
}
389403

0 commit comments

Comments (0)