Skip to content

Commit 08beb1e

Browse files
authored
fix put along axis momory (#71863)
* fix put along axis momory * fix bug * update * fix bug * try fix bug * fix compile bug
1 parent c6e5870 commit 08beb1e

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

paddle/phi/kernels/stride/slice_grad_kernel.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "paddle/common/flags.h"
1717
#include "paddle/phi/backends/all_context.h"
1818
#include "paddle/phi/core/kernel_registry.h"
19+
#include "paddle/phi/kernels/full_kernel.h"
1920
#include "paddle/phi/kernels/funcs/strided_utils.h"
2021
#include "paddle/phi/kernels/slice_kernel.h"
2122

@@ -46,6 +47,16 @@ void SliceGradStridedKernel(const Context& dev_ctx,
4647
}));
4748
DenseTensor tmp;
4849
tmp.set_meta(out_grad.meta());
50+
if (out_grad.numel() == 0) {
51+
// set zero to input_grad
52+
53+
PD_VISIT_ALL_TYPES(input.dtype(), "SliceGradStridedKernel", ([&] {
54+
phi::StridedTensorFill<data_t>(
55+
*input_grad, 0, input_grad);
56+
}));
57+
58+
return;
59+
}
4960
SliceStridedKernel<Context>(dev_ctx,
5061
*input_grad,
5162
axes,

0 commit comments

Comments
 (0)