Skip to content

Commit 2f12cb7

Browse files
committed
fix roi align grad kernel
1 parent f1aae59 commit 2f12cb7

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

paddle/phi/kernels/gpu/roi_align_grad_kernel.cu

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,13 @@ void RoiAlignGradKernel(const Context& dev_ctx,
176176
int sampling_ratio,
177177
bool aligned,
178178
DenseTensor* dx) {
179+
if (x.numel() == 0 || boxes.numel() == 0) {
180+
dev_ctx.template Alloc<T>(dx);
181+
return;
182+
}
183+
179184
int rois_num = boxes.dims()[0];
185+
180186
int channels = x.dims()[1];
181187
int height = x.dims()[2];
182188
int width = x.dims()[3];
@@ -185,6 +191,10 @@ void RoiAlignGradKernel(const Context& dev_ctx,
185191
return;
186192
}
187193

194+
if (dx->numel() == 9) {
195+
dev_ctx.template Alloc<T>(dx);
196+
return;
197+
}
188198
DenseTensor box_batch_id_list;
189199
box_batch_id_list.Resize({rois_num});
190200
int* box_batch_size = dev_ctx.template HostAlloc<int>(&box_batch_id_list);

0 commit comments

Comments
 (0)