Skip to content

Commit 10c2fd4

Browse files
fix slice grad (PaddlePaddle#64820)
1 parent 72e8d4e commit 10c2fd4

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

paddle/fluid/prim/api/composite_backward/composite_backward_api.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -744,13 +744,20 @@ void slice_grad(const Tensor& input,
744744
paddings.push_back(offsets[i]);
745745
paddings.push_back((in_dims[i] - out_dims[i]) - offsets[i]);
746746
}
747+
Tensor reshape_out_grad;
748+
if (out_grad.shape().size() == 0) {
749+
reshape_out_grad = full<T>({1}, 1, input.dtype());
750+
} else {
751+
reshape_out_grad = out_grad;
752+
}
753+
747754
if (decrease_size > 0 &&
748755
(decrease_size != static_cast<size_t>(in_dims.size()))) {
749756
auto out_tmp =
750-
pad<T>(reshape<T>(out_grad, origin_out_shape), paddings, 0.0);
757+
pad<T>(reshape<T>(reshape_out_grad, origin_out_shape), paddings, 0.0);
751758
set_output<T>(out_tmp, input_grad);
752759
} else {
753-
auto out_tmp = pad<T>(out_grad, paddings, 0.0);
760+
auto out_tmp = pad<T>(reshape_out_grad, paddings, 0.0);
754761
set_output<T>(out_tmp, input_grad);
755762
}
756763
}

paddle/fluid/primitive/rule/vjp/details.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1483,13 +1483,20 @@ void slice_grad(const Tensor& input,
14831483
paddings.push_back(offsets[i]);
14841484
paddings.push_back((in_dims[i] - out_dims[i]) - offsets[i]);
14851485
}
1486+
Tensor reshape_out_grad;
1487+
if (out_grad.shape().size() == 0) {
1488+
reshape_out_grad = full<T>({1}, 1, input.dtype());
1489+
} else {
1490+
reshape_out_grad = out_grad;
1491+
}
1492+
14861493
if (decrease_size > 0 &&
14871494
(decrease_size != static_cast<size_t>(in_dims.size()))) {
14881495
auto out_tmp =
1489-
pad<T>(reshape<T>(out_grad, origin_out_shape), paddings, 0.0);
1496+
pad<T>(reshape<T>(reshape_out_grad, origin_out_shape), paddings, 0.0);
14901497
set_output<T>(out_tmp, input_grad);
14911498
} else {
1492-
auto out_tmp = pad<T>(out_grad, paddings, 0.0);
1499+
auto out_tmp = pad<T>(reshape_out_grad, paddings, 0.0);
14931500
set_output<T>(out_tmp, input_grad);
14941501
}
14951502
}

0 commit comments

Comments
 (0)