Commit 40de41b

fix subtract double grad bug
Parent: a0e8d37

File tree

1 file changed (+12, -7 lines)

paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h

Lines changed: 12 additions & 7 deletions
@@ -832,10 +832,11 @@ void subtract_double_grad(const Tensor& y,
                            int axis,
                            Tensor* grad_out_grad) {
   if (grad_out_grad) {
-    // ddout = ddx - ddy
     if (grad_x_grad && grad_y_grad) {
+      // ddout = ddx - ddy
       set_output<T>(grad_x_grad.get() - grad_y_grad.get(), grad_out_grad);
     } else if (grad_x_grad) {
+      // ddout = ddx
       if (grad_x_grad.get().dims() != grad_out.dims()) {
         // broad cast grad_x_grad to grad_out
         auto grad_x_grad_dims = common::vectorize(grad_x_grad.get().dims());
@@ -876,6 +877,7 @@ void subtract_double_grad(const Tensor& y,
         by_pass<T>(grad_x_grad.get(), grad_out_grad);
       }
     } else if (grad_y_grad) {
+      // ddout = -ddy
       if (grad_y_grad.get().dims() != grad_out.dims()) {
         // broad cast grad_y_grad to grad_out
         auto grad_y_grad_dims = common::vectorize(grad_y_grad.get().dims());
@@ -902,18 +904,21 @@ void subtract_double_grad(const Tensor& y,
           }
         }
         if (need_reshape && need_tile) {
-          set_output<T>(tile<T>(reshape<T>(grad_y_grad.get(), broadcast_dims),
+          set_output<T>(tile<T>(reshape<T>(scale<T>(grad_y_grad.get(), -1.0),
+                                           broadcast_dims),
                                 repeat_times),
                         grad_out_grad);
         } else if (need_reshape) {
-          set_output<T>(reshape<T>(grad_y_grad.get(), broadcast_dims),
-                        grad_out_grad);
+          set_output<T>(
+              reshape<T>(scale<T>(grad_y_grad.get(), -1.0), broadcast_dims),
+              grad_out_grad);
         } else if (need_tile) {
-          set_output<T>(tile<T>(grad_y_grad.get(), repeat_times),
-                        grad_out_grad);
+          set_output<T>(
+              tile<T>(scale<T>(grad_y_grad.get(), -1.0), repeat_times),
+              grad_out_grad);
         }
       } else {
-        by_pass<T>(-grad_y_grad.get(), grad_out_grad);
+        by_pass<T>(scale<T>(grad_y_grad.get(), -1.0), grad_out_grad);
       }
     } else {
       set_output<T>(full<T>(common::vectorize(grad_out.dims()),
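Why the fix is correct: for out = x - y, d(out)/dx = 1 and d(out)/dy = -1, so the double grad is ddout = ddx - ddy, and when only ddy is supplied, ddout = -ddy. Before this commit the negation happened only in the by_pass branch (shapes already matching), while the broadcast branches (reshape and/or tile) forwarded ddy with its sign unchanged; wrapping grad_y_grad in scale<T>(..., -1.0) restores the minus sign on all paths. The sketch below restates the rule in standalone C++; SubtractDoubleGrad and the use of std::vector<double> in place of Tensor are illustrative assumptions, not Paddle code, and the broadcast machinery is omitted:

// subtract_ddout_sketch.cc -- minimal sketch of the sign rule this commit
// fixes (assumption: std::vector<double> stands in for Tensor, no broadcast).
#include <cstddef>
#include <iostream>
#include <optional>
#include <vector>

// For out = x - y: ddout = ddx - ddy. With only one input grad present,
// ddout = ddx or ddout = -ddy; the bug was dropping that minus sign in the
// reshape/tile branches for ddy.
std::optional<std::vector<double>> SubtractDoubleGrad(
    const std::optional<std::vector<double>>& ddx,
    const std::optional<std::vector<double>>& ddy) {
  if (ddx && ddy) {
    std::vector<double> out(ddx->size());
    for (std::size_t i = 0; i < out.size(); ++i)
      out[i] = (*ddx)[i] - (*ddy)[i];  // ddout = ddx - ddy
    return out;
  }
  if (ddx) return ddx;  // ddout = ddx, sign unchanged
  if (ddy) {
    std::vector<double> out(ddy->size());
    for (std::size_t i = 0; i < out.size(); ++i)
      out[i] = -(*ddy)[i];  // ddout = -ddy: negate before any broadcast
    return out;
  }
  return std::nullopt;  // neither given: caller emits zeros, cf. full<T>(...)
}

int main() {
  auto ddout = SubtractDoubleGrad(std::nullopt,
                                  std::vector<double>{1.0, 2.0});
  for (double v : *ddout) std::cout << v << ' ';  // prints: -1 -2
  std::cout << '\n';
}

Since negation commutes with reshape and tile, scaling before or after broadcasting is mathematically equivalent; applying scale<T> first, as the patch does, also touches fewer elements whenever ddy is smaller than grad_out.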
