@@ -52,7 +52,11 @@ void cumsum_grad(const Tensor& x,
                  Tensor* x_grad) {
   if (x_grad) {
     auto grad = cumsum<T>(out_grad, axis, flatten, exclusive, !reverse);
-    grad = reshape<T>(grad, x.shape());
+    if (has_dynamic_shape(x.shape())) {
+      grad = backend::reshape<T>(grad, shape<T>(x));
+    } else {
+      grad = reshape<T>(grad, x.shape());
+    }
     set_output<T>(grad, x_grad);
   }
 }
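
A quick note on the math this hunk preserves: for an inclusive cumsum $y_i = \sum_{j \le i} x_j$, the chain rule gives

\[
\frac{\partial L}{\partial x_j} = \sum_{i \ge j} \frac{\partial L}{\partial y_i},
\]

i.e. a cumsum of out_grad run in the opposite direction, hence the !reverse in the call above. Only the final reshape differs between the static path and the new dynamic-shape path.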
@@ -146,8 +150,14 @@ void divide_grad(const Tensor& x,
 template <typename T>
 void floor_grad(const Tensor& out_grad, Tensor* x_grad) {
   if (x_grad) {
-    auto zero_tensor =
-        full<T>(common::vectorize(out_grad.dims()), 0.0, out_grad.dtype());
+    Tensor zero_tensor;
+    if (has_dynamic_shape(out_grad.shape())) {
+      zero_tensor = backend::full_with_tensor<T>(
+          shape<T>(out_grad), 0.0, out_grad.dtype());
+    } else {
+      zero_tensor =
+          full<T>(common::vectorize(out_grad.dims()), 0.0, out_grad.dtype());
+    }
     set_output<T>(zero_tensor, x_grad);
   }
 }
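
floor is flat almost everywhere, so its gradient is identically zero; the only work is materializing a zero tensor of the right, possibly runtime-determined, shape. The static-vs-dynamic dispatch seen here recurs throughout the patch. As a minimal sketch of the assumed helper semantics (not Paddle's actual implementation), has_dynamic_shape presumably flags any symbolic extent, conventionally encoded as -1:

// Sketch only: assumed semantics of the dynamic-shape check used above.
#include <cstdint>
#include <vector>

inline bool has_dynamic_shape(const std::vector<int64_t>& shape) {
  for (int64_t d : shape) {
    if (d < 0) return true;  // -1 marks an extent unknown until runtime
  }
  return false;
}

When this returns true, a concrete shape tensor (shape<T>(out_grad)) is threaded to backend::full_with_tensor at runtime instead of baking common::vectorize(out_grad.dims()) into the program.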
@@ -303,9 +313,12 @@ void gelu_grad(const Tensor& x,
   if (approximate) {
     float kbeta = M_SQRT2 * M_2_SQRTPI * 0.5;
     float kkappa = 0.044715;
+    Tensor kbeta_ = full_scalar<T>(kbeta, promoted_x.dtype());
+    Tensor kkappa_ = full_scalar<T>(kkappa, promoted_x.dtype());
+
     auto x_sq = promoted_x * promoted_x;
     auto x_cube = x_sq * promoted_x;
-    auto inner = kbeta * (promoted_x + kkappa * x_cube);
+    auto inner = kbeta_ * (promoted_x + kkappa_ * x_cube);
     auto tanh_inner = tanh<T>(inner);

     auto left = scale<T>(promoted_x, 0.5);
@@ -314,7 +327,7 @@ void gelu_grad(const Tensor& x,
     auto left_derivative = scale<T>(right, 0.5);

     auto tanh_derivative = scale<T>(tanh_inner * tanh_inner, -1., 1.);
-    auto inner_derivative = kbeta * (scale<T>(3 * kkappa * x_sq, 1., 1.));
+    auto inner_derivative = kbeta_ * (scale<T>(3 * kkappa_ * x_sq, 1., 1.));
     auto right_derivative = left * tanh_derivative * inner_derivative;

     set_output<T>(
@@ -324,8 +337,11 @@ void gelu_grad(const Tensor& x,
   } else {
     float kalpha = M_SQRT1_2;
     float kbeta = M_2_SQRTPI * M_SQRT1_2 * 0.5;
-    auto cdf = scale<T>(scale<T>(erf<T>(kalpha * promoted_x), 1., 1.), 0.5);
-    auto pdf = kbeta * exp<T>(scale<T>(promoted_x * promoted_x, -0.5));
+    Tensor kalpha_ = full_scalar<T>(kalpha, promoted_x.dtype());
+    Tensor kbeta_ = full_scalar<T>(kbeta, promoted_x.dtype());
+
+    auto cdf = scale<T>(scale<T>(erf<T>(kalpha_ * promoted_x), 1., 1.), 0.5);
+    auto pdf = kbeta_ * exp<T>(scale<T>(promoted_x * promoted_x, -0.5));
     set_output<T>(
         cast<T>(promoted_out_grad * (cdf + promoted_x * pdf), x.type()),
         x_grad);
@@ -336,9 +352,12 @@ void gelu_grad(const Tensor& x,
   if (approximate) {
     auto kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
     auto kKappa = 0.044715;
+    Tensor kBeta_ = full_scalar<T>(kBeta, x.dtype());
+    Tensor kKappa_ = full_scalar<T>(kKappa, x.dtype());
+
     auto x_sq = x * x;
     auto x_cube = x_sq * x;
-    auto inner = kBeta * (x + kKappa * x_cube);
+    auto inner = kBeta_ * (x + kKappa_ * x_cube);
     auto tanh_inner = tanh<T>(inner);

     auto left = scale<T>(x, 0.5);
@@ -347,15 +366,18 @@ void gelu_grad(const Tensor& x,
     auto left_derivative = scale<T>(right, 0.5);

     auto tanh_derivative = scale<T>(tanh_inner * tanh_inner, -1., 1.);
-    auto inner_derivative = kBeta * (scale<T>(3 * kKappa * x_sq, 1., 1.));
+    auto inner_derivative = kBeta_ * (scale<T>(3 * kKappa_ * x_sq, 1., 1.));
     auto right_derivative = left * tanh_derivative * inner_derivative;

     set_output<T>(out_grad * (left_derivative + right_derivative), x_grad);
   } else {
     auto kAlpha = M_SQRT1_2;
     auto kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5;
-    auto cdf = scale<T>(scale<T>(erf<T>(kAlpha * x), 1., 1.), 0.5);
-    auto pdf = kBeta * exp<T>(scale<T>(x * x, -0.5));
+    Tensor kAlpha_ = full_scalar<T>(kAlpha, x.dtype());
+    Tensor kBeta_ = full_scalar<T>(kBeta, x.dtype());
+
+    auto cdf = scale<T>(scale<T>(erf<T>(kAlpha_ * x), 1., 1.), 0.5);
+    auto pdf = kBeta_ * exp<T>(scale<T>(x * x, -0.5));
     set_output<T>(out_grad * (cdf + x * pdf), x_grad);
   }
 }
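
Both gelu_grad paths above implement the same derivative; writing it out (standard GELU calculus, matching the left_derivative/right_derivative split in the code): with $u = \beta\,(x + \kappa x^3)$, $\beta = \sqrt{2/\pi}$, $\kappa = 0.044715$,

\[
\mathrm{gelu}(x) \approx \frac{x}{2}\bigl(1 + \tanh u\bigr), \qquad
\frac{d\,\mathrm{gelu}}{dx} = \underbrace{\frac{1}{2}\bigl(1 + \tanh u\bigr)}_{\texttt{left\_derivative}}
 + \underbrace{\frac{x}{2}\bigl(1 - \tanh^2 u\bigr)\,\beta\bigl(1 + 3\kappa x^2\bigr)}_{\texttt{right\_derivative}}.
\]

In the exact branch, $\mathrm{gelu}(x) = x\,\Phi(x)$ gives $\Phi(x) + x\,\varphi(x)$, which is the cdf + x * pdf expression; note M_2_SQRTPI * M_SQRT1_2 * 0.5 $= 1/\sqrt{2\pi}$, so pdf is exactly $\varphi(x)$.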
@@ -409,8 +431,13 @@ void reduce_as_grad(const Tensor& x,
 template <typename T>
 void reshape_grad(const Tensor& x, const Tensor& grad_out, Tensor* grad_x) {
   if (grad_x) {
-    const auto& x_dims = x.dims();
-    auto grad_x_tmp = reshape<T>(grad_out, common::vectorize(x_dims));
+    Tensor grad_x_tmp;
+    if (has_dynamic_shape(x.shape())) {
+      grad_x_tmp = backend::reshape<T>(grad_out, shape<T>(x));
+    } else {
+      const auto& x_dims = x.dims();
+      grad_x_tmp = reshape<T>(grad_out, common::vectorize(x_dims));
+    }
     set_output<T>(grad_x_tmp, grad_x);
   }
 }
@@ -503,7 +530,7 @@ void cos_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
 template <typename T>
 void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
   if (!grad_x) return;
-  auto grad_x_tmp = grad_out * (1 - out * out);
+  auto grad_x_tmp = grad_out * (full_scalar<T>(1.0, out.dtype()) - out * out);
   set_output<T>(grad_x_tmp, grad_x);
 }

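
The only change here is representational: the literal 1 becomes a scalar tensor in out's dtype via full_scalar, presumably so the subtraction stays in-dtype rather than relying on implicit scalar promotion. The identity itself is unchanged:

\[
\frac{d}{dx}\tanh(x) = 1 - \tanh^2(x).
\]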
@@ -961,9 +988,8 @@ void dropout_grad(const Tensor& mask,
 template <typename T>
 void erf_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
   if (x_grad) {
-    auto m_2_sqrt_pi =
-        full<T>(common::vectorize(x.dims()), M_2_SQRTPI, x.dtype());
-    auto neg_one = full<T>(common::vectorize(x.dims()), -1.0, x.dtype());
+    auto m_2_sqrt_pi = full_scalar<T>(M_2_SQRTPI, x.dtype());
+    auto neg_one = full_scalar<T>(-1.0, x.dtype());
     auto neg_tmp = neg_one * x * x;
     auto mul_tmp = m_2_sqrt_pi * exp<T>(neg_tmp);
     set_output<T>(out_grad * mul_tmp, x_grad);
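
This matches the textbook derivative, with M_2_SQRTPI $= 2/\sqrt{\pi}$:

\[
\frac{d}{dx}\,\operatorname{erf}(x) = \frac{2}{\sqrt{\pi}}\,e^{-x^2},
\]

so neg_tmp is $-x^2$ and mul_tmp the full derivative; the change only swaps full-shape constants for 0-d scalars.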
@@ -1000,7 +1026,8 @@ void log_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
 template <typename T>
 void square_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
   if (x_grad) {
-    Tensor x_grad_tmp = 2 * x * out_grad;
+    auto two = full_scalar<T>(2.0, x.dtype());
+    Tensor x_grad_tmp = two * x * out_grad;
     set_output<T>(x_grad_tmp, x_grad);
   }
 }
@@ -1046,17 +1073,17 @@ void silu_grad(const Tensor& x,
                const Tensor& out_grad,
                Tensor* x_grad) {
   if (x_grad) {
+    auto one = full_scalar<T>(1.0, x.dtype());
     auto org_dtype = x.dtype();
     bool need_cast = org_dtype == phi::DataType::FLOAT16 ||
                      org_dtype == phi::DataType::BFLOAT16;
     if (need_cast) {
       auto x_cast = cast<T>(x, phi::DataType::FLOAT32);
       auto out_cast = cast<T>(out, phi::DataType::FLOAT32);
       auto out_grad_cast = cast<T>(out_grad, phi::DataType::FLOAT32);
-      auto res = out_grad_cast * sigmoid<T>(x_cast) * (1.0 + x_cast - out_cast);
+      auto res = out_grad_cast * sigmoid<T>(x_cast) * (one + x_cast - out_cast);
       set_output<T>(cast<T>(res, org_dtype), x_grad);
     } else {
-      auto one = full_scalar<T>(1.0, x.dtype());
       auto res = out_grad * sigmoid<T>(x) * (one + x - out);
       set_output<T>(res, x_grad);
     }
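
For reference, the identity silu_grad relies on, with $\sigma$ the logistic sigmoid and out $= \mathrm{silu}(x) = x\,\sigma(x)$:

\[
\mathrm{silu}'(x) = \sigma(x) + x\,\sigma(x)\bigl(1 - \sigma(x)\bigr) = \sigma(x)\bigl(1 + x - \mathrm{silu}(x)\bigr),
\]

which is exactly sigmoid<T>(x) * (one + x - out). Hoisting one out of the else branch lets both precision paths share the same scalar.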
@@ -1243,29 +1270,39 @@ void maximum_grad(const Tensor& x,
   if (x_grad) {
     auto x_tmp = cast<T>(greater_than<T>(x, y), out_grad.dtype());
     auto dx_res = out_grad * x_tmp;
-    if (out_grad.dims() != x.dims()) {
-      auto reduce_dim = get_reduce_dims_from_out(out_grad.dims(), x.dims());
-      auto dx_reduce_res =
-          dx_res.sum(common::vectorize(reduce_dim), x.dtype(), false);
-      auto dx_tmp = reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
-      set_output<T>(dx_tmp, x_grad);
+    if (has_dynamic_shape(x.shape())) {
+      auto dx_reduce_res = reduce_as<T>(dx_res, x);
+      set_output<T>(dx_reduce_res, x_grad);
     } else {
-      set_output<T>(dx_res, x_grad);
+      if (out_grad.dims() != x.dims()) {
+        auto reduce_dim = get_reduce_dims_from_out(out_grad.dims(), x.dims());
+        auto dx_reduce_res =
+            dx_res.sum(common::vectorize(reduce_dim), x.dtype(), false);
+        auto dx_tmp = reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
+        set_output<T>(dx_tmp, x_grad);
+      } else {
+        set_output<T>(dx_res, x_grad);
+      }
     }
   }

   if (y_grad) {
     auto y_tmp = cast<T>(less_equal<T>(x, y), out_grad.dtype());
     auto dy_res = out_grad * y_tmp;
-    if (out_grad.dims() != y.dims()) {
-      phi::DDim reduce_dim =
-          get_reduce_dims_from_out(out_grad.dims(), y.dims());
-      auto dy_reduce_res =
-          dy_res.sum(common::vectorize(reduce_dim), y.dtype(), false);
-      auto dy_tmp = reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
-      set_output<T>(dy_tmp, y_grad);
+    if (has_dynamic_shape(y.shape())) {
+      auto dy_reduce_res = reduce_as<T>(dy_res, y);
+      set_output<T>(dy_reduce_res, y_grad);
     } else {
-      set_output<T>(dy_res, y_grad);
+      if (out_grad.dims() != y.dims()) {
+        phi::DDim reduce_dim =
+            get_reduce_dims_from_out(out_grad.dims(), y.dims());
+        auto dy_reduce_res =
+            dy_res.sum(common::vectorize(reduce_dim), y.dtype(), false);
+        auto dy_tmp = reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
+        set_output<T>(dy_tmp, y_grad);
+      } else {
+        set_output<T>(dy_res, y_grad);
+      }
     }
   }
 }
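
On the new dynamic path the explicit sum-then-reshape pair collapses into a single reduce_as<T>(value, target). A rough sketch of the behavior assumed here — sum away the broadcast dimensions so the gradient matches the target's possibly runtime-only shape; broadcast_axes is a hypothetical stand-in, not Paddle's API:

// Sketch: assumed behavior of reduce_as for gradients of broadcast ops.
template <typename T>
Tensor reduce_as_sketch(const Tensor& value, const Tensor& target) {
  // Axes along which `target` was broadcast (extent 1, or leading axes
  // that `target` lacks) must be summed away.
  std::vector<int64_t> axes = broadcast_axes(value, target);  // hypothetical
  auto summed = sum<T>(value, axes, target.dtype(), /*keepdim=*/false);
  // Restore target's exact rank and extents, known only at runtime.
  return backend::reshape<T>(summed, shape<T>(target));
}

This is why the dynamic branch needs no get_reduce_dims_from_out: the axis bookkeeping moves inside reduce_as.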
@@ -1664,13 +1701,19 @@ void tile_grad(const Tensor& x,
 template <typename T>
 void hardswish_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
   if (x_grad) {
-    auto offset = full<T>(common::vectorize(x.dims()), 3.0, x.dtype());
+    const Tensor offset = full_scalar<T>(3.0, x.dtype());
+    Tensor zero;
+    if (has_dynamic_shape(x.shape())) {
+      zero = backend::full_with_tensor<T>(shape<T>(x), 0.0, x.dtype());
+    } else {
+      zero = full<T>(common::vectorize(x.dims()), 0.0, x.dtype());
+    }
     auto condition = less_equal<T>(x, offset);
-    auto tmp1 = where<T>(condition, out_grad * ((x / 3.0) + 0.5), out_grad);
-    auto res = where<T>(
-        less_than<T>(x, full<T>(common::vectorize(x.dims()), -3.0, x.dtype())),
-        full<T>(common::vectorize(x.dims()), 0.0, x.dtype()),
-        tmp1);
+    auto factor = full_scalar<T>(0.5, x.dtype());
+    auto tmp1 =
+        where<T>(condition, out_grad * ((x / offset) + factor), out_grad);
+    auto res =
+        where<T>(less_than<T>(x, full_scalar<T>(-3.0, x.dtype())), zero, tmp1);
     set_output<T>(res, x_grad);
   }
 }
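
The rewrite keeps the piecewise hardswish derivative intact while sourcing constants from full_scalar and the zero tensor from the shape-aware path:

\[
\frac{d}{dx}\,\mathrm{hardswish}(x) =
\begin{cases}
0, & x < -3,\\[2pt]
\dfrac{x}{3} + \dfrac{1}{2}, & -3 \le x \le 3,\\[2pt]
1, & x > 3.
\end{cases}
\]

The inner where applies $x/3 + 1/2$ for $x \le 3$ (out_grad otherwise), and the outer where zeroes everything below $-3$.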
@@ -1681,8 +1724,8 @@ void leaky_relu_grad(const Tensor& out,
                      float negative_slope,
                      Tensor* x_grad) {
   if (x_grad) {
-    auto condition = greater_than<T>(
-        out, full<T>(common::vectorize(out.dims()), 0.0, out.dtype()));
+    auto zero = full_scalar<T>(0.0, out.dtype());
+    auto condition = greater_than<T>(out, zero);
     auto res = where<T>(condition, out_grad, out_grad * negative_slope);
     set_output<T>(res, x_grad);
   }
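
The gradient is 1 on the positive side and negative_slope elsewhere. One subtlety worth noting: the condition tests out > 0 rather than x > 0, which is equivalent whenever negative_slope is positive, since leaky ReLU then preserves sign:

\[
\frac{d}{dx}\,\mathrm{lrelu}(x) = \begin{cases} 1, & x > 0,\\ \alpha, & x \le 0. \end{cases}
\]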
@@ -2015,29 +2058,39 @@ void minimum_grad(const Tensor& x,
   if (x_grad) {
     auto x_tmp = cast<T>(less_than<T>(x, y), out_grad.dtype());
     auto dx_res = out_grad * x_tmp;
-    if (out_grad.dims() != x.dims()) {
-      auto reduce_dim = get_reduce_dims_from_out(out_grad.dims(), x.dims());
-      auto dx_reduce_res =
-          dx_res.sum(common::vectorize(reduce_dim), x.dtype(), false);
-      auto dx_tmp = reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
-      set_output<T>(dx_tmp, x_grad);
+    if (has_dynamic_shape(x.shape())) {
+      auto dx_reduce_res = reduce_as<T>(dx_res, x);
+      set_output<T>(dx_reduce_res, x_grad);
     } else {
-      set_output<T>(dx_res, x_grad);
+      if (out_grad.dims() != x.dims()) {
+        auto reduce_dim = get_reduce_dims_from_out(out_grad.dims(), x.dims());
+        auto dx_reduce_res =
+            dx_res.sum(common::vectorize(reduce_dim), x.dtype(), false);
+        auto dx_tmp = reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
+        set_output<T>(dx_tmp, x_grad);
+      } else {
+        set_output<T>(dx_res, x_grad);
+      }
     }
   }

   if (y_grad) {
     auto y_tmp = cast<T>(greater_equal<T>(x, y), out_grad.dtype());
     auto dy_res = out_grad * y_tmp;
-    if (out_grad.dims() != y.dims()) {
-      phi::DDim reduce_dim =
-          get_reduce_dims_from_out(out_grad.dims(), y.dims());
-      auto dy_reduce_res =
-          dy_res.sum(common::vectorize(reduce_dim), y.dtype(), false);
-      auto dy_tmp = reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
-      set_output<T>(dy_tmp, y_grad);
+    if (has_dynamic_shape(y.shape())) {
+      auto dy_reduce_res = reduce_as<T>(dy_res, y);
+      set_output<T>(dy_reduce_res, y_grad);
     } else {
-      set_output<T>(dy_res, y_grad);
+      if (out_grad.dims() != y.dims()) {
+        phi::DDim reduce_dim =
+            get_reduce_dims_from_out(out_grad.dims(), y.dims());
+        auto dy_reduce_res =
+            dy_res.sum(common::vectorize(reduce_dim), y.dtype(), false);
+        auto dy_tmp = reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
+        set_output<T>(dy_tmp, y_grad);
+      } else {
+        set_output<T>(dy_res, y_grad);
+      }
     }
   }
 }