@@ -200,27 +200,37 @@ void BatchNormKernel(const Context& dev_ctx,
200
200
{*mean_out},
201
201
{{" value" , static_cast <float >(momentum)}});
202
202
mean_muls_runner.Run (stream);
203
-
204
203
const auto & mean_axpy_runner =
205
204
NpuOpRunner (" Axpy" ,
206
205
{*mean_out, *saved_mean},
207
206
{*mean_out},
208
207
{{" alpha" , static_cast <float >(1 - momentum)}});
209
208
mean_axpy_runner.Run (stream);
210
-
211
209
const auto & var_muls_runner =
212
210
NpuOpRunner (" Muls" ,
213
211
{tmp_running_var},
214
212
{*variance_out},
215
213
{{" value" , static_cast <float >(momentum)}});
216
214
var_muls_runner.Run (stream);
217
-
218
215
const auto & var_axpy_runner =
219
216
NpuOpRunner (" Axpy" ,
220
217
{*variance_out, *saved_variance},
221
218
{*variance_out},
222
219
{{" alpha" , static_cast <float >(1 - momentum)}});
223
220
var_axpy_runner.Run (stream);
221
+
222
+ const auto & adds_runner =
223
+ NpuOpRunner (" Adds" ,
224
+ {*saved_variance},
225
+ {*saved_variance},
226
+ {{" value" , static_cast <float >(epsilon)}});
227
+ adds_runner.Run (stream);
228
+ const auto & inv_runner =
229
+ NpuOpRunner (" Inv" , {*saved_variance}, {*saved_variance}, {});
230
+ inv_runner.Run (stream);
231
+ const auto & sqrt_ruuner =
232
+ NpuOpRunner (" Sqrt" , {*saved_variance}, {*saved_variance}, {});
233
+ sqrt_ruuner.Run (stream);
224
234
}
225
235
}
226
236
@@ -326,8 +336,21 @@ void BatchNormGradKernel(
326
336
}
327
337
328
338
const auto * running_mean = use_global_stats ? mean.get_ptr () : &saved_mean;
329
- const auto * running_vstd =
330
- use_global_stats ? variance.get_ptr () : &saved_variance;
339
+ phi::DenseTensor running_invstd;
340
+ auto * running_vstd = use_global_stats ? variance.get_ptr () : &running_invstd;
341
+ if (!use_global_stats) {
342
+ running_invstd.Resize (saved_variance.dims ());
343
+ dev_ctx.template Alloc <float >(&running_invstd);
344
+ const auto & square_runner = NpuOpRunner (
345
+ " Square" ,
346
+ {*(use_global_stats ? variance.get_ptr () : &saved_variance)},
347
+ {running_invstd},
348
+ {});
349
+ square_runner.Run (stream);
350
+ const auto & inv_runner =
351
+ NpuOpRunner (" Inv" , {running_invstd}, {running_invstd}, {});
352
+ inv_runner.Run (stream);
353
+ }
331
354
332
355
NpuOpRunner runner_update;
333
356
runner_update.SetType (update_name)
@@ -364,26 +387,17 @@ void BatchNormGradKernel(
364
387
}
365
388
}
366
389
367
- if (use_global_stats) {
368
- const auto * running_vstd = variance.get_ptr ();
369
- const auto & runner_infer = NpuOpRunner (" BNInferGrad" ,
370
- {dy_tensor, scale, *running_vstd},
371
- {dx_tensor},
372
- {{" epsilon" , epsilon}});
373
- runner_infer.Run (stream);
374
- } else {
375
- const auto & runner_reduce = NpuOpRunner (reduce_name,
376
- {dy_tensor,
377
- x_tensor,
378
- *d_scale,
379
- *d_bias,
380
- scale,
381
- saved_mean,
382
- saved_variance},
383
- {dx_tensor},
384
- {{" epsilon" , epsilon}});
385
- runner_reduce.Run (stream);
386
- }
390
+ const auto & runner_reduce = NpuOpRunner (reduce_name,
391
+ {dy_tensor,
392
+ x_tensor,
393
+ *d_scale,
394
+ *d_bias,
395
+ scale,
396
+ *running_mean,
397
+ *running_vstd},
398
+ {dx_tensor},
399
+ {{" epsilon" , epsilon}});
400
+ runner_reduce.Run (stream);
387
401
}
388
402
}
389
403
0 commit comments