Commit 1bde725

[Prim][PIR] fix bugs that not use full_scalar (#67170)
* add unittest case for dynamic shape; fix bugs
* fix code style
* support dynamic shape, add unittest
* fix gelu fp16 unittest
1 parent a30b293 commit 1bde725

5 files changed: +894 -61 lines changed

paddle/fluid/primitive/composite/composite.h

Lines changed: 3 additions & 2 deletions
@@ -200,9 +200,10 @@ std::tuple<Tensor, Tensor> huber_loss_decomp(const Tensor& input,
   }
   auto val = label - input;
   auto abs_val = abs<T>(val);
+  auto factor = full_scalar<T>(0.5, input.dtype());
   auto ans = where<T>(abs_val <= delta_full,
-                      0.5 * val * val,
-                      delta_full * (abs_val - 0.5 * delta_full));
+                      factor * val * val,
+                      delta_full * (abs_val - factor * delta_full));
   return std::make_tuple(ans, val);
 }
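
For reference, the decomposition above is the standard Huber loss: 0.5 * r^2 when |r| <= delta and delta * (|r| - 0.5 * delta) otherwise, with r = label - input; the patch only replaces the raw 0.5 literal with a full_scalar tensor so the constant follows the input dtype. A minimal standalone sketch of the same formula in plain C++ (no Paddle types; the helper name huber_loss is only for illustration):

#include <cmath>
#include <cstdio>

// Scalar reference for the value huber_loss_decomp builds elementwise:
// 0.5 * r^2 inside the delta band, delta * (|r| - 0.5 * delta) outside.
double huber_loss(double input, double label, double delta) {
  double r = label - input;     // "val" in the decomposition
  double abs_r = std::fabs(r);  // "abs_val"
  return abs_r <= delta ? 0.5 * r * r
                        : delta * (abs_r - 0.5 * delta);
}

int main() {
  // Inside the band: quadratic; outside: linear continuation.
  std::printf("%f\n", huber_loss(0.2, 0.0, 1.0));  // 0.5 * 0.04 = 0.020000
  std::printf("%f\n", huber_loss(3.0, 0.0, 1.0));  // 1 * (3 - 0.5) = 2.500000
  return 0;
}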

paddle/fluid/primitive/rule/vjp/details.h

Lines changed: 111 additions & 58 deletions
@@ -52,7 +52,11 @@ void cumsum_grad(const Tensor& x,
                  Tensor* x_grad) {
   if (x_grad) {
     auto grad = cumsum<T>(out_grad, axis, flatten, exclusive, !reverse);
-    grad = reshape<T>(grad, x.shape());
+    if (has_dynamic_shape(x.shape())) {
+      grad = backend::reshape<T>(grad, shape<T>(x));
+    } else {
+      grad = reshape<T>(grad, x.shape());
+    }
     set_output<T>(grad, x_grad);
   }
 }
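
Context for the unchanged logic above: for the default forward, non-exclusive cumsum, out[i] = x[0] + ... + x[i], so dL/dx[j] is the sum of g[i] over i >= j, i.e. a reversed cumulative sum of the upstream gradient; that is why the VJP calls cumsum with !reverse, and only the reshape back to x's shape needed a dynamic-shape branch. A standalone sketch of that identity in plain C++ (vector of doubles; the helper name cumsum_vjp is illustrative):

#include <cstdio>
#include <vector>

// For out[i] = x[0] + ... + x[i], dL/dx[j] = sum_{i >= j} g[i],
// i.e. a reversed cumulative sum of the upstream gradient g.
std::vector<double> cumsum_vjp(const std::vector<double>& g) {
  std::vector<double> dx(g.size(), 0.0);
  double acc = 0.0;
  for (size_t i = g.size(); i-- > 0;) {  // walk backwards: reversed cumsum
    acc += g[i];
    dx[i] = acc;
  }
  return dx;
}

int main() {
  std::vector<double> g = {1.0, 2.0, 3.0};
  for (double v : cumsum_vjp(g)) std::printf("%g ", v);  // prints: 6 5 3
  std::printf("\n");
  return 0;
}
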
@@ -146,8 +150,14 @@ void divide_grad(const Tensor& x,
 template <typename T>
 void floor_grad(const Tensor& out_grad, Tensor* x_grad) {
   if (x_grad) {
-    auto zero_tensor =
-        full<T>(common::vectorize(out_grad.dims()), 0.0, out_grad.dtype());
+    Tensor zero_tensor;
+    if (has_dynamic_shape(out_grad.shape())) {
+      zero_tensor = backend::full_with_tensor<T>(
+          shape<T>(out_grad), 0.0, out_grad.dtype());
+    } else {
+      zero_tensor =
+          full<T>(common::vectorize(out_grad.dims()), 0.0, out_grad.dtype());
+    }
     set_output<T>(zero_tensor, x_grad);
   }
 }
@@ -303,9 +313,12 @@ void gelu_grad(const Tensor& x,
     if (approximate) {
       float kbeta = M_SQRT2 * M_2_SQRTPI * 0.5;
       float kkappa = 0.044715;
+      Tensor kbeta_ = full_scalar<T>(kbeta, promoted_x.dtype());
+      Tensor kkappa_ = full_scalar<T>(kkappa, promoted_x.dtype());
+
       auto x_sq = promoted_x * promoted_x;
       auto x_cube = x_sq * promoted_x;
-      auto inner = kbeta * (promoted_x + kkappa * x_cube);
+      auto inner = kbeta_ * (promoted_x + kkappa_ * x_cube);
       auto tanh_inner = tanh<T>(inner);
 
       auto left = scale<T>(promoted_x, 0.5);
@@ -314,7 +327,7 @@ void gelu_grad(const Tensor& x,
       auto left_derivative = scale<T>(right, 0.5);
 
       auto tanh_derivative = scale<T>(tanh_inner * tanh_inner, -1., 1.);
-      auto inner_derivative = kbeta * (scale<T>(3 * kkappa * x_sq, 1., 1.));
+      auto inner_derivative = kbeta_ * (scale<T>(3 * kkappa_ * x_sq, 1., 1.));
       auto right_derivative = left * tanh_derivative * inner_derivative;
 
       set_output<T>(
@@ -324,8 +337,11 @@ void gelu_grad(const Tensor& x,
     } else {
       float kalpha = M_SQRT1_2;
       float kbeta = M_2_SQRTPI * M_SQRT1_2 * 0.5;
-      auto cdf = scale<T>(scale<T>(erf<T>(kalpha * promoted_x), 1., 1.), 0.5);
-      auto pdf = kbeta * exp<T>(scale<T>(promoted_x * promoted_x, -0.5));
+      Tensor kalpha_ = full_scalar<T>(kalpha, promoted_x.dtype());
+      Tensor kbeta_ = full_scalar<T>(kbeta, promoted_x.dtype());
+
+      auto cdf = scale<T>(scale<T>(erf<T>(kalpha_ * promoted_x), 1., 1.), 0.5);
+      auto pdf = kbeta_ * exp<T>(scale<T>(promoted_x * promoted_x, -0.5));
       set_output<T>(
           cast<T>(promoted_out_grad * (cdf + promoted_x * pdf), x.type()),
           x_grad);
@@ -336,9 +352,12 @@ void gelu_grad(const Tensor& x,
     if (approximate) {
       auto kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
       auto kKappa = 0.044715;
+      Tensor kBeta_ = full_scalar<T>(kBeta, x.dtype());
+      Tensor kKappa_ = full_scalar<T>(kKappa, x.dtype());
+
       auto x_sq = x * x;
       auto x_cube = x_sq * x;
-      auto inner = kBeta * (x + kKappa * x_cube);
+      auto inner = kBeta_ * (x + kKappa_ * x_cube);
       auto tanh_inner = tanh<T>(inner);
 
       auto left = scale<T>(x, 0.5);
@@ -347,15 +366,18 @@ void gelu_grad(const Tensor& x,
       auto left_derivative = scale<T>(right, 0.5);
 
       auto tanh_derivative = scale<T>(tanh_inner * tanh_inner, -1., 1.);
-      auto inner_derivative = kBeta * (scale<T>(3 * kKappa * x_sq, 1., 1.));
+      auto inner_derivative = kBeta_ * (scale<T>(3 * kKappa_ * x_sq, 1., 1.));
       auto right_derivative = left * tanh_derivative * inner_derivative;
 
       set_output<T>(out_grad * (left_derivative + right_derivative), x_grad);
     } else {
       auto kAlpha = M_SQRT1_2;
       auto kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5;
-      auto cdf = scale<T>(scale<T>(erf<T>(kAlpha * x), 1., 1.), 0.5);
-      auto pdf = kBeta * exp<T>(scale<T>(x * x, -0.5));
+      Tensor kAlpha_ = full_scalar<T>(kAlpha, x.dtype());
+      Tensor kBeta_ = full_scalar<T>(kBeta, x.dtype());
+
+      auto cdf = scale<T>(scale<T>(erf<T>(kAlpha_ * x), 1., 1.), 0.5);
+      auto pdf = kBeta_ * exp<T>(scale<T>(x * x, -0.5));
       set_output<T>(out_grad * (cdf + x * pdf), x_grad);
     }
   }
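
Both dtype branches above implement the analytic GELU derivative: the approximate path differentiates 0.5 * x * (1 + tanh(inner)) with inner = kBeta * (x + kKappa * x^3), and the exact path uses cdf + x * pdf. The patch only turns the float constants (kBeta, kKappa, kAlpha) into 0-D full_scalar tensors so the elementwise ops keep the input dtype, which is what the gelu fp16 unittest exercises. A standalone plain-C++ sketch of the tanh-approximation derivative, checked against a finite difference (no Paddle calls; function names are illustrative):

#include <cmath>
#include <cstdio>

// Tanh-approximated GELU and its derivative, matching the formula the VJP builds:
// d/dx [0.5 * x * (1 + tanh(inner))], inner = kBeta * (x + kKappa * x^3).
double gelu_tanh(double x) {
  const double kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;  // sqrt(2/pi), as in the diff
  const double kKappa = 0.044715;
  return 0.5 * x * (1.0 + std::tanh(kBeta * (x + kKappa * x * x * x)));
}

double gelu_tanh_grad(double x) {
  const double kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
  const double kKappa = 0.044715;
  double inner = kBeta * (x + kKappa * x * x * x);
  double t = std::tanh(inner);
  double left_derivative = 0.5 * (1.0 + t);  // matches scale<T>(right, 0.5)
  double right_derivative =                  // left * (1 - t^2) * inner'
      0.5 * x * (1.0 - t * t) * kBeta * (1.0 + 3.0 * kKappa * x * x);
  return left_derivative + right_derivative;
}

int main() {
  double x = 0.7, eps = 1e-6;
  double fd = (gelu_tanh(x + eps) - gelu_tanh(x - eps)) / (2.0 * eps);
  std::printf("analytic=%.6f finite-diff=%.6f\n", gelu_tanh_grad(x), fd);
  return 0;
}
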
@@ -409,8 +431,13 @@ void reduce_as_grad(const Tensor& x,
 template <typename T>
 void reshape_grad(const Tensor& x, const Tensor& grad_out, Tensor* grad_x) {
   if (grad_x) {
-    const auto& x_dims = x.dims();
-    auto grad_x_tmp = reshape<T>(grad_out, common::vectorize(x_dims));
+    Tensor grad_x_tmp;
+    if (has_dynamic_shape(x.shape())) {
+      grad_x_tmp = backend::reshape<T>(grad_out, shape<T>(x));
+    } else {
+      const auto& x_dims = x.dims();
+      grad_x_tmp = reshape<T>(grad_out, common::vectorize(x_dims));
+    }
     set_output<T>(grad_x_tmp, grad_x);
   }
 }
@@ -503,7 +530,7 @@ void cos_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
 template <typename T>
 void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
   if (!grad_x) return;
-  auto grad_x_tmp = grad_out * (1 - out * out);
+  auto grad_x_tmp = grad_out * (full_scalar<T>(1.0, out.dtype()) - out * out);
   set_output<T>(grad_x_tmp, grad_x);
 }
 
@@ -961,9 +988,8 @@ void dropout_grad(const Tensor& mask,
 template <typename T>
 void erf_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
   if (x_grad) {
-    auto m_2_sqrt_pi =
-        full<T>(common::vectorize(x.dims()), M_2_SQRTPI, x.dtype());
-    auto neg_one = full<T>(common::vectorize(x.dims()), -1.0, x.dtype());
+    auto m_2_sqrt_pi = full_scalar<T>(M_2_SQRTPI, x.dtype());
+    auto neg_one = full_scalar<T>(-1.0, x.dtype());
     auto neg_tmp = neg_one * x * x;
     auto mul_tmp = m_2_sqrt_pi * exp<T>(neg_tmp);
     set_output<T>(out_grad * mul_tmp, x_grad);
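
The formula here is d/dx erf(x) = (2 / sqrt(pi)) * exp(-x^2); the change only builds the two constants as 0-D scalars instead of full-sized tensors. A standalone plain-C++ check (the helper name erf_grad_ref is illustrative):

#include <cmath>
#include <cstdio>

// d/dx erf(x) = (2 / sqrt(pi)) * exp(-x^2);
// mirrors m_2_sqrt_pi * exp<T>(neg_one * x * x) in the VJP.
double erf_grad_ref(double x) { return M_2_SQRTPI * std::exp(-x * x); }

int main() {
  double x = 0.3, eps = 1e-6;
  double fd = (std::erf(x + eps) - std::erf(x - eps)) / (2.0 * eps);
  std::printf("analytic=%.6f finite-diff=%.6f\n", erf_grad_ref(x), fd);
  return 0;
}
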
@@ -1000,7 +1026,8 @@ void log_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
 template <typename T>
 void square_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
   if (x_grad) {
-    Tensor x_grad_tmp = 2 * x * out_grad;
+    auto two = full_scalar<T>(2.0, x.dtype());
+    Tensor x_grad_tmp = two * x * out_grad;
     set_output<T>(x_grad_tmp, x_grad);
   }
 }
@@ -1046,17 +1073,17 @@ void silu_grad(const Tensor& x,
                const Tensor& out_grad,
                Tensor* x_grad) {
   if (x_grad) {
+    auto one = full_scalar<T>(1.0, x.dtype());
     auto org_dtype = x.dtype();
     bool need_cast = org_dtype == phi::DataType::FLOAT16 ||
                      org_dtype == phi::DataType::BFLOAT16;
     if (need_cast) {
       auto x_cast = cast<T>(x, phi::DataType::FLOAT32);
       auto out_cast = cast<T>(out, phi::DataType::FLOAT32);
       auto out_grad_cast = cast<T>(out_grad, phi::DataType::FLOAT32);
-      auto res = out_grad_cast * sigmoid<T>(x_cast) * (1.0 + x_cast - out_cast);
+      auto res = out_grad_cast * sigmoid<T>(x_cast) * (one + x_cast - out_cast);
       set_output<T>(cast<T>(res, org_dtype), x_grad);
     } else {
-      auto one = full_scalar<T>(1.0, x.dtype());
       auto res = out_grad * sigmoid<T>(x) * (one + x - out);
       set_output<T>(res, x_grad);
     }
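
With out = silu(x) = x * sigmoid(x), the derivative can be written sigmoid(x) * (1 + x - out), which is the expression used in both branches above; the patch hoists the full_scalar one so the fp16/bf16 path no longer mixes a raw 1.0 literal into the cast computation. A standalone plain-C++ check (helper names are illustrative):

#include <cmath>
#include <cstdio>

double sigmoid(double x) { return 1.0 / (1.0 + std::exp(-x)); }
double silu(double x) { return x * sigmoid(x); }

// d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x - x * sigmoid(x))
//                       = sigmoid(x) * (1 + x - out)
double silu_grad_ref(double x) { return sigmoid(x) * (1.0 + x - silu(x)); }

int main() {
  double x = -1.2, eps = 1e-6;
  double fd = (silu(x + eps) - silu(x - eps)) / (2.0 * eps);
  std::printf("analytic=%.6f finite-diff=%.6f\n", silu_grad_ref(x), fd);
  return 0;
}
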
@@ -1243,29 +1270,39 @@ void maximum_grad(const Tensor& x,
   if (x_grad) {
     auto x_tmp = cast<T>(greater_than<T>(x, y), out_grad.dtype());
     auto dx_res = out_grad * x_tmp;
-    if (out_grad.dims() != x.dims()) {
-      auto reduce_dim = get_reduce_dims_from_out(out_grad.dims(), x.dims());
-      auto dx_reduce_res =
-          dx_res.sum(common::vectorize(reduce_dim), x.dtype(), false);
-      auto dx_tmp = reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
-      set_output<T>(dx_tmp, x_grad);
+    if (has_dynamic_shape(x.shape())) {
+      auto dx_reduce_res = reduce_as<T>(dx_res, x);
+      set_output<T>(dx_reduce_res, x_grad);
     } else {
-      set_output<T>(dx_res, x_grad);
+      if (out_grad.dims() != x.dims()) {
+        auto reduce_dim = get_reduce_dims_from_out(out_grad.dims(), x.dims());
+        auto dx_reduce_res =
+            dx_res.sum(common::vectorize(reduce_dim), x.dtype(), false);
+        auto dx_tmp = reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
+        set_output<T>(dx_tmp, x_grad);
+      } else {
+        set_output<T>(dx_res, x_grad);
+      }
     }
   }
 
   if (y_grad) {
     auto y_tmp = cast<T>(less_equal<T>(x, y), out_grad.dtype());
     auto dy_res = out_grad * y_tmp;
-    if (out_grad.dims() != y.dims()) {
-      phi::DDim reduce_dim =
-          get_reduce_dims_from_out(out_grad.dims(), y.dims());
-      auto dy_reduce_res =
-          dy_res.sum(common::vectorize(reduce_dim), y.dtype(), false);
-      auto dy_tmp = reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
-      set_output<T>(dy_tmp, y_grad);
+    if (has_dynamic_shape(y.shape())) {
+      auto dy_reduce_res = reduce_as<T>(dy_res, y);
+      set_output<T>(dy_reduce_res, y_grad);
     } else {
-      set_output<T>(dy_res, y_grad);
+      if (out_grad.dims() != y.dims()) {
+        phi::DDim reduce_dim =
+            get_reduce_dims_from_out(out_grad.dims(), y.dims());
+        auto dy_reduce_res =
+            dy_res.sum(common::vectorize(reduce_dim), y.dtype(), false);
+        auto dy_tmp = reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
+        set_output<T>(dy_tmp, y_grad);
+      } else {
+        set_output<T>(dy_res, y_grad);
+      }
     }
   }
 }
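
Elementwise maximum routes the upstream gradient to whichever operand won the comparison: dx = out_grad where x > y, dy = out_grad where x <= y, so ties go to y and dx + dy reproduces out_grad. When an operand was broadcast, that per-element gradient is then summed back to the operand's shape, which the static branch does with sum + reshape and the dynamic branch delegates to reduce_as. A scalar sketch of the mask split in plain C++ (the helper name maximum_vjp is illustrative):

#include <cstdio>

// Gradient split used by maximum_grad: the upstream gradient g goes to x where
// x > y and to y where x <= y (ties credited to y), so dx + dy always equals g.
void maximum_vjp(double x, double y, double g, double* dx, double* dy) {
  *dx = (x > y) ? g : 0.0;
  *dy = (x <= y) ? g : 0.0;
}

int main() {
  double dx, dy;
  maximum_vjp(2.0, 1.0, 0.5, &dx, &dy);  // x wins: dx=0.5, dy=0
  std::printf("dx=%g dy=%g\n", dx, dy);
  maximum_vjp(1.0, 1.0, 0.5, &dx, &dy);  // tie: gradient goes to y
  std::printf("dx=%g dy=%g\n", dx, dy);
  return 0;
}
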
@@ -1664,13 +1701,19 @@ void tile_grad(const Tensor& x,
 template <typename T>
 void hardswish_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
   if (x_grad) {
-    auto offset = full<T>(common::vectorize(x.dims()), 3.0, x.dtype());
+    const Tensor offset = full_scalar<T>(3.0, x.dtype());
+    Tensor zero;
+    if (has_dynamic_shape(x.shape())) {
+      zero = backend::full_with_tensor<T>(shape<T>(x), 0.0, x.dtype());
+    } else {
+      zero = full<T>(common::vectorize(x.dims()), 0.0, x.dtype());
+    }
     auto condition = less_equal<T>(x, offset);
-    auto tmp1 = where<T>(condition, out_grad * ((x / 3.0) + 0.5), out_grad);
-    auto res = where<T>(
-        less_than<T>(x, full<T>(common::vectorize(x.dims()), -3.0, x.dtype())),
-        full<T>(common::vectorize(x.dims()), 0.0, x.dtype()),
-        tmp1);
+    auto factor = full_scalar<T>(0.5, x.dtype());
+    auto tmp1 =
+        where<T>(condition, out_grad * ((x / offset) + factor), out_grad);
+    auto res =
+        where<T>(less_than<T>(x, full_scalar<T>(-3.0, x.dtype())), zero, tmp1);
     set_output<T>(res, x_grad);
   }
 }
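
hardswish(x) = x * min(max(x + 3, 0), 6) / 6, so its derivative is 0 for x < -3, x / 3 + 0.5 on [-3, 3], and 1 above 3; that is exactly the pair of where<T> selections above, now built from full_scalar constants and a shape-aware zero tensor instead of full-sized literals. A standalone plain-C++ check against a finite difference (helper names are illustrative):

#include <algorithm>
#include <cstdio>

double hardswish(double x) {
  return x * std::min(std::max(x + 3.0, 0.0), 6.0) / 6.0;
}

// Piecewise derivative matching the two where<T> selections in hardswish_grad:
// 0 below -3, x / 3 + 0.5 inside [-3, 3], 1 above 3.
double hardswish_grad_ref(double x) {
  if (x < -3.0) return 0.0;
  if (x <= 3.0) return x / 3.0 + 0.5;
  return 1.0;
}

int main() {
  const double xs[] = {-4.0, -1.5, 0.0, 2.0, 5.0};
  const double eps = 1e-6;
  for (double x : xs) {
    double fd = (hardswish(x + eps) - hardswish(x - eps)) / (2.0 * eps);
    std::printf("x=%4.1f analytic=%.6f finite-diff=%.6f\n",
                x, hardswish_grad_ref(x), fd);
  }
  return 0;
}
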
@@ -1681,8 +1724,8 @@ void leaky_relu_grad(const Tensor& out,
                      float negative_slope,
                      Tensor* x_grad) {
   if (x_grad) {
-    auto condition = greater_than<T>(
-        out, full<T>(common::vectorize(out.dims()), 0.0, out.dtype()));
+    auto zero = full_scalar<T>(0.0, out.dtype());
+    auto condition = greater_than<T>(out, zero);
     auto res = where<T>(condition, out_grad, out_grad * negative_slope);
     set_output<T>(res, x_grad);
   }
@@ -2015,29 +2058,39 @@ void minimum_grad(const Tensor& x,
   if (x_grad) {
     auto x_tmp = cast<T>(less_than<T>(x, y), out_grad.dtype());
     auto dx_res = out_grad * x_tmp;
-    if (out_grad.dims() != x.dims()) {
-      auto reduce_dim = get_reduce_dims_from_out(out_grad.dims(), x.dims());
-      auto dx_reduce_res =
-          dx_res.sum(common::vectorize(reduce_dim), x.dtype(), false);
-      auto dx_tmp = reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
-      set_output<T>(dx_tmp, x_grad);
+    if (has_dynamic_shape(x.shape())) {
+      auto dx_reduce_res = reduce_as<T>(dx_res, x);
+      set_output<T>(dx_reduce_res, x_grad);
     } else {
-      set_output<T>(dx_res, x_grad);
+      if (out_grad.dims() != x.dims()) {
+        auto reduce_dim = get_reduce_dims_from_out(out_grad.dims(), x.dims());
+        auto dx_reduce_res =
+            dx_res.sum(common::vectorize(reduce_dim), x.dtype(), false);
+        auto dx_tmp = reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
+        set_output<T>(dx_tmp, x_grad);
+      } else {
+        set_output<T>(dx_res, x_grad);
+      }
     }
   }
 
   if (y_grad) {
     auto y_tmp = cast<T>(greater_equal<T>(x, y), out_grad.dtype());
     auto dy_res = out_grad * y_tmp;
-    if (out_grad.dims() != y.dims()) {
-      phi::DDim reduce_dim =
-          get_reduce_dims_from_out(out_grad.dims(), y.dims());
-      auto dy_reduce_res =
-          dy_res.sum(common::vectorize(reduce_dim), y.dtype(), false);
-      auto dy_tmp = reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
-      set_output<T>(dy_tmp, y_grad);
+    if (has_dynamic_shape(y.shape())) {
+      auto dy_reduce_res = reduce_as<T>(dy_res, y);
+      set_output<T>(dy_reduce_res, y_grad);
     } else {
-      set_output<T>(dy_res, y_grad);
+      if (out_grad.dims() != y.dims()) {
+        phi::DDim reduce_dim =
+            get_reduce_dims_from_out(out_grad.dims(), y.dims());
+        auto dy_reduce_res =
+            dy_res.sum(common::vectorize(reduce_dim), y.dtype(), false);
+        auto dy_tmp = reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
+        set_output<T>(dy_tmp, y_grad);
+      } else {
+        set_output<T>(dy_res, y_grad);
+      }
     }
   }
 }

python/paddle/autograd/backward_utils.py

Lines changed: 10 additions & 1 deletion
@@ -49,12 +49,21 @@
     "pd_op.split",
     "pd_op.multiply",
     "pd_op.relu",
-    "pd_op.sigmoid",
     "pd_op.divide",
     "pd_op.pow",
     "pd_op.elementwise_pow",
     "pd_op.softmax",
     "pd_op.matmul",
+    "pd_op.cumsum",
+    "pd_op.erf",
+    "pd_op.floor",
+    "pd_op.reshape",
+    "pd_op.leaky_relu",
+    "pd_op.softsign",
+    "pd_op.maximum",
+    "pd_op.minimum",
+    "pd_op.gelu",
+    "pd_op.hardswish",
     "pd_op.reduce_as",
 ]
