@@ -38,42 +38,64 @@ class GenerateProposalLabelsOp : public framework::OperatorWithKernel {
38
38
using framework::OperatorWithKernel::OperatorWithKernel;
39
39
40
40
void InferShape (framework::InferShapeContext* ctx) const override {
41
- PADDLE_ENFORCE (ctx->HasInput (" RpnRois" ),
42
- " Input(RpnRois) shouldn't be null." );
43
- PADDLE_ENFORCE (ctx->HasInput (" GtClasses" ),
44
- " Input(GtClasses) shouldn't be null." );
45
- PADDLE_ENFORCE (ctx->HasInput (" IsCrowd" ),
46
- " Input(IsCrowd) shouldn't be null." );
47
- PADDLE_ENFORCE (ctx->HasInput (" GtBoxes" ),
48
- " Input(GtBoxes) shouldn't be null." );
49
- PADDLE_ENFORCE (ctx->HasInput (" ImInfo" ), " Input(ImInfo) shouldn't be null." );
50
-
51
- PADDLE_ENFORCE (
52
- ctx->HasOutput (" Rois" ),
53
- " Output(Rois) of GenerateProposalLabelsOp should not be null" );
54
- PADDLE_ENFORCE (
55
- ctx->HasOutput (" LabelsInt32" ),
56
- " Output(LabelsInt32) of GenerateProposalLabelsOp should not be null" );
57
- PADDLE_ENFORCE (
58
- ctx->HasOutput (" BboxTargets" ),
59
- " Output(BboxTargets) of GenerateProposalLabelsOp should not be null" );
60
- PADDLE_ENFORCE (ctx->HasOutput (" BboxInsideWeights" ),
61
- " Output(BboxInsideWeights) of GenerateProposalLabelsOp "
62
- " should not be null" );
63
- PADDLE_ENFORCE (ctx->HasOutput (" BboxOutsideWeights" ),
64
- " Output(BboxOutsideWeights) of GenerateProposalLabelsOp "
65
- " should not be null" );
41
+ PADDLE_ENFORCE_EQ (
42
+ ctx->HasInput (" RpnRois" ), true ,
43
+ platform::errors::NotFound (" Input(RpnRois) shouldn't be null." ));
44
+ PADDLE_ENFORCE_EQ (
45
+ ctx->HasInput (" GtClasses" ), true ,
46
+ platform::errors::NotFound (" Input(GtClasses) shouldn't be null." ));
47
+ PADDLE_ENFORCE_EQ (
48
+ ctx->HasInput (" IsCrowd" ), true ,
49
+ platform::errors::NotFound (" Input(IsCrowd) shouldn't be null." ));
50
+ PADDLE_ENFORCE_EQ (
51
+ ctx->HasInput (" GtBoxes" ), true ,
52
+ platform::errors::NotFound (" Input(GtBoxes) shouldn't be null." ));
53
+ PADDLE_ENFORCE_EQ (
54
+ ctx->HasInput (" ImInfo" ), true ,
55
+ platform::errors::NotFound (" Input(ImInfo) shouldn't be null." ));
56
+
57
+ PADDLE_ENFORCE_EQ (
58
+ ctx->HasOutput (" Rois" ), true ,
59
+ platform::errors::NotFound (
60
+ " Output(Rois) of GenerateProposalLabelsOp should not be null" ));
61
+ PADDLE_ENFORCE_EQ (ctx->HasOutput (" LabelsInt32" ), true ,
62
+ platform::errors::NotFound (" Output(LabelsInt32) of "
63
+ " GenerateProposalLabelsOp "
64
+ " should not be null" ));
65
+ PADDLE_ENFORCE_EQ (ctx->HasOutput (" BboxTargets" ), true ,
66
+ platform::errors::NotFound (" Output(BboxTargets) of "
67
+ " GenerateProposalLabelsOp "
68
+ " should not be null" ));
69
+ PADDLE_ENFORCE_EQ (
70
+ ctx->HasOutput (" BboxInsideWeights" ), true ,
71
+ platform::errors::NotFound (
72
+ " Output(BboxInsideWeights) of GenerateProposalLabelsOp "
73
+ " should not be null" ));
74
+ PADDLE_ENFORCE_EQ (
75
+ ctx->HasOutput (" BboxOutsideWeights" ), true ,
76
+ platform::errors::NotFound (
77
+ " Output(BboxOutsideWeights) of GenerateProposalLabelsOp "
78
+ " should not be null" ));
66
79
67
80
auto rpn_rois_dims = ctx->GetInputDim (" RpnRois" );
68
81
auto gt_boxes_dims = ctx->GetInputDim (" GtBoxes" );
69
82
auto im_info_dims = ctx->GetInputDim (" ImInfo" );
70
83
71
84
PADDLE_ENFORCE_EQ (rpn_rois_dims.size (), 2 ,
72
- " The rank of Input(RpnRois) must be 2." );
85
+ platform::errors::InvalidArgument (
86
+ " The dimensions size of Input(RpnRois) must be 2. "
87
+ " But received dimensions size=[%d], dimensions=[%s]." ,
88
+ rpn_rois_dims.size (), rpn_rois_dims));
73
89
PADDLE_ENFORCE_EQ (gt_boxes_dims.size (), 2 ,
74
- " The rank of Input(GtBoxes) must be 2." );
90
+ platform::errors::InvalidArgument (
91
+ " The dimensions size of Input(GtBoxes) must be 2. "
92
+ " But received dimensions size=[%d], dimensions=[%s]." ,
93
+ gt_boxes_dims.size (), gt_boxes_dims));
75
94
PADDLE_ENFORCE_EQ (im_info_dims.size (), 2 ,
76
- " The rank of Input(ImInfo) must be 2." );
95
+ platform::errors::InvalidArgument (
96
+ " The dimensions size of Input(ImInfo) must be 2. But "
97
+ " received dimensions size=[%d], dimensions=[%s]." ,
98
+ im_info_dims.size (), im_info_dims));
77
99
78
100
int class_nums = ctx->Attrs ().Get <int >(" class_nums" );
79
101
@@ -399,15 +421,30 @@ class GenerateProposalLabelsKernel : public framework::OpKernel<T> {
399
421
bool use_random = context.Attr <bool >(" use_random" );
400
422
bool is_cascade_rcnn = context.Attr <bool >(" is_cascade_rcnn" );
401
423
bool is_cls_agnostic = context.Attr <bool >(" is_cls_agnostic" );
402
- PADDLE_ENFORCE_EQ (rpn_rois->lod ().size (), 1UL ,
403
- " GenerateProposalLabelsOp rpn_rois needs 1 level of LoD" );
424
+ PADDLE_ENFORCE_EQ (
425
+ rpn_rois->lod ().size (), 1UL ,
426
+ platform::errors::InvalidArgument (
427
+ " GenerateProposalLabelsOp rpn_rois needs 1 level of LoD. But "
428
+ " received level of LoD is [%d], LoD is [%s]." ,
429
+ rpn_rois->lod ().size (), rpn_rois->lod ()));
404
430
PADDLE_ENFORCE_EQ (
405
431
gt_classes->lod ().size (), 1UL ,
406
- " GenerateProposalLabelsOp gt_classes needs 1 level of LoD" );
407
- PADDLE_ENFORCE_EQ (is_crowd->lod ().size (), 1UL ,
408
- " GenerateProposalLabelsOp is_crowd needs 1 level of LoD" );
409
- PADDLE_ENFORCE_EQ (gt_boxes->lod ().size (), 1UL ,
410
- " GenerateProposalLabelsOp gt_boxes needs 1 level of LoD" );
432
+ platform::errors::InvalidArgument (
433
+ " GenerateProposalLabelsOp gt_classes needs 1 level of LoD. But "
434
+ " received level of LoD is [%d], LoD is [%s]." ,
435
+ gt_classes->lod ().size (), gt_classes->lod ()));
436
+ PADDLE_ENFORCE_EQ (
437
+ is_crowd->lod ().size (), 1UL ,
438
+ platform::errors::InvalidArgument (
439
+ " GenerateProposalLabelsOp is_crowd needs 1 level of LoD. But "
440
+ " received level of LoD is [%d], LoD is [%s]." ,
441
+ is_crowd->lod ().size (), is_crowd->lod ()));
442
+ PADDLE_ENFORCE_EQ (
443
+ gt_boxes->lod ().size (), 1UL ,
444
+ platform::errors::InvalidArgument (
445
+ " GenerateProposalLabelsOp gt_boxes needs 1 level of LoD. But "
446
+ " received level of LoD is [%d], LoD is [%s]." ,
447
+ gt_boxes->lod ().size (), gt_boxes->lod ()));
411
448
int64_t n = static_cast <int64_t >(rpn_rois->lod ().back ().size () - 1 );
412
449
413
450
rois->mutable_data <T>({n * batch_size_per_im, kBoxDim }, context.GetPlace ());
0 commit comments