@@ -832,10 +832,11 @@ void subtract_double_grad(const Tensor& y,
832
832
int axis,
833
833
Tensor* grad_out_grad) {
834
834
if (grad_out_grad) {
835
- // ddout = ddx - ddy
836
835
if (grad_x_grad && grad_y_grad) {
836
+ // ddout = ddx - ddy
837
837
set_output<T>(grad_x_grad.get () - grad_y_grad.get (), grad_out_grad);
838
838
} else if (grad_x_grad) {
839
+ // ddout = ddx
839
840
if (grad_x_grad.get ().dims () != grad_out.dims ()) {
840
841
// broad cast grad_x_grad to grad_out
841
842
auto grad_x_grad_dims = common::vectorize (grad_x_grad.get ().dims ());
@@ -876,6 +877,7 @@ void subtract_double_grad(const Tensor& y,
876
877
by_pass<T>(grad_x_grad.get (), grad_out_grad);
877
878
}
878
879
} else if (grad_y_grad) {
880
+ // ddout = -ddy
879
881
if (grad_y_grad.get ().dims () != grad_out.dims ()) {
880
882
// broad cast grad_y_grad to grad_out
881
883
auto grad_y_grad_dims = common::vectorize (grad_y_grad.get ().dims ());
@@ -902,18 +904,21 @@ void subtract_double_grad(const Tensor& y,
902
904
}
903
905
}
904
906
if (need_reshape && need_tile) {
905
- set_output<T>(tile<T>(reshape<T>(grad_y_grad.get (), broadcast_dims),
907
+ set_output<T>(tile<T>(reshape<T>(scale<T>(grad_y_grad.get (), -1.0 ),
908
+ broadcast_dims),
906
909
repeat_times),
907
910
grad_out_grad);
908
911
} else if (need_reshape) {
909
- set_output<T>(reshape<T>(grad_y_grad.get (), broadcast_dims),
910
- grad_out_grad);
912
+ set_output<T>(
913
+ reshape<T>(scale<T>(grad_y_grad.get (), -1.0 ), broadcast_dims),
914
+ grad_out_grad);
911
915
} else if (need_tile) {
912
- set_output<T>(tile<T>(grad_y_grad.get (), repeat_times),
913
- grad_out_grad);
916
+ set_output<T>(
917
+ tile<T>(scale<T>(grad_y_grad.get (), -1.0 ), repeat_times),
918
+ grad_out_grad);
914
919
}
915
920
} else {
916
- by_pass<T>(- grad_y_grad.get (), grad_out_grad);
921
+ by_pass<T>(scale<T>( grad_y_grad.get (), - 1.0 ), grad_out_grad);
917
922
}
918
923
} else {
919
924
set_output<T>(full<T>(common::vectorize (grad_out.dims ()),
0 commit comments