@@ -25,12 +25,9 @@ class CrossEntropyOpBase : public framework::OperatorWithKernel {
25
25
using framework::OperatorWithKernel::OperatorWithKernel;
26
26
27
27
void InferShape (framework::InferShapeContext* ctx) const override {
28
- PADDLE_ENFORCE_EQ (ctx->HasInput (" X" ), true , " Input(X) should be not null." );
29
- PADDLE_ENFORCE_EQ (ctx->HasInput (" Label" ), true ,
30
- " Input(Label) should be not null." );
31
-
32
- PADDLE_ENFORCE_EQ (ctx->HasOutput (" Y" ), true ,
33
- " Output(Y) should be not null." );
28
+ OP_INOUT_CHECK (ctx->HasInput (" X" ), " Input" , " X" , " CrossEntropy" );
29
+ OP_INOUT_CHECK (ctx->HasInput (" Label" ), " Input" , " Label" , " CrossEntropy" );
30
+ OP_INOUT_CHECK (ctx->HasOutput (" Y" ), " Output" , " Y" , " CrossEntropy" );
34
31
35
32
auto x_dims = ctx->GetInputDim (" X" );
36
33
auto label_dims = ctx->GetInputDim (" Label" );
@@ -44,53 +41,61 @@ class CrossEntropyOpBase : public framework::OperatorWithKernel {
44
41
PADDLE_ENFORCE_EQ (
45
42
framework::slice_ddim (x_dims, 0 , rank - 1 ),
46
43
framework::slice_ddim (label_dims, 0 , rank - 1 ),
47
- " ShapeError: Input(X) and Input(Label) shall have the same shape "
48
- " except the last dimension. But received: the shape of Input(X) is "
49
- " [%s],"
50
- " the shape of Input(Label) is [%s]." ,
51
- x_dims, label_dims);
44
+ platform::errors::InvalidArgument (
45
+ " Input(X) and Input(Label) shall have the same shape "
46
+ " except the last dimension. But received: the shape of Input(X) "
47
+ " is "
48
+ " [%s], the shape of Input(Label) is [%s]." ,
49
+ x_dims, label_dims));
52
50
}
53
51
54
52
if (IsSoftLabel (ctx)) {
55
53
PADDLE_ENFORCE_EQ (
56
54
rank, label_dims.size (),
57
- " ShapeError: If Attr(soft_label) == true, Input(X) and Input(Label) "
58
- " shall have the same dimensions. But received: the dimensions of "
59
- " Input(X) is [%d],"
60
- " the shape of Input(X) is [%s], the dimensions of Input(Label) is "
61
- " [%d], the shape of"
62
- " Input(Label) is [%s]" ,
63
- rank, x_dims, label_dims.size (), label_dims);
55
+ platform::errors::InvalidArgument (
56
+ " If Attr(soft_label) == true, Input(X) and Input(Label) "
57
+ " shall have the same dimensions. But received: the dimensions of "
58
+ " Input(X) is [%d],"
59
+ " the shape of Input(X) is [%s], the dimensions of Input(Label) "
60
+ " is "
61
+ " [%d], the shape of"
62
+ " Input(Label) is [%s]" ,
63
+ rank, x_dims, label_dims.size (), label_dims));
64
64
65
65
if (check) {
66
66
PADDLE_ENFORCE_EQ (
67
67
x_dims[rank - 1 ], label_dims[rank - 1 ],
68
- " ShapeError: If Attr(soft_label) == true, the last dimension of "
69
- " Input(X) and Input(Label) should be equal. But received: the"
70
- " last dimension of Input(X) is [%d], the shape of Input(X) is [%s],"
71
- " the last dimension of Input(Label) is [%d], the shape of "
72
- " Input(Label)"
73
- " is [%s], the last dimension is [%d]." ,
74
- x_dims[rank - 1 ], x_dims, label_dims[rank - 1 ], label_dims,
75
- rank - 1 );
68
+ platform::errors::InvalidArgument (
69
+ " If Attr(soft_label) == true, the last dimension of "
70
+ " Input(X) and Input(Label) should be equal. But received: the"
71
+ " last dimension of Input(X) is [%d], the shape of Input(X) is "
72
+ " [%s],"
73
+ " the last dimension of Input(Label) is [%d], the shape of "
74
+ " Input(Label)"
75
+ " is [%s], the last dimension is [%d]." ,
76
+ x_dims[rank - 1 ], x_dims, label_dims[rank - 1 ], label_dims,
77
+ rank - 1 ));
76
78
}
77
79
} else {
78
80
if (rank == label_dims.size ()) {
79
81
PADDLE_ENFORCE_EQ (
80
82
label_dims[rank - 1 ], 1UL ,
81
- " ShapeError: the last dimension of Input(Label) should be 1."
82
- " But received: the last dimension of Input(Label) is [%d],"
83
- " the last dimension is [%d]" ,
84
- label_dims[rank - 1 ], rank - 1 );
83
+ platform::errors::InvalidArgument (
84
+ " the last dimension of Input(Label) should be 1."
85
+ " But received: the last dimension of Input(Label) is [%d],"
86
+ " the last dimension is [%d]" ,
87
+ label_dims[rank - 1 ], rank - 1 ));
85
88
} else {
86
- PADDLE_ENFORCE_EQ (rank, label_dims.size () + 1 ,
87
- " ShapeError: The rank of Input(X) should be equal to "
88
- " Input(Label) plus 1."
89
- " But received: The dimension of Input(X) is [%d], "
90
- " the shape of Input(X) is [%s],"
91
- " the dimension of Input(Label) is [%d], the shape of "
92
- " Input(Label) is [%s]" ,
93
- rank, x_dims, label_dims.size (), label_dims);
89
+ PADDLE_ENFORCE_EQ (
90
+ rank, label_dims.size () + 1 ,
91
+ platform::errors::InvalidArgument (
92
+ " ShapeError: The rank of Input(X) should be equal to "
93
+ " Input(Label) plus 1."
94
+ " But received: The dimension of Input(X) is [%d], "
95
+ " the shape of Input(X) is [%s],"
96
+ " the dimension of Input(Label) is [%d], the shape of "
97
+ " Input(Label) is [%s]" ,
98
+ rank, x_dims, label_dims.size (), label_dims));
94
99
}
95
100
}
96
101
@@ -122,19 +127,23 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
122
127
using framework::OperatorWithKernel::OperatorWithKernel;
123
128
124
129
void InferShape (framework::InferShapeContext* ctx) const {
125
- PADDLE_ENFORCE_EQ (ctx->HasInput (" Label" ), true ,
126
- " Input(Label) should be not null. " );
127
- PADDLE_ENFORCE_EQ (ctx->HasInput (framework::GradVarName (" Y" )), true ,
128
- " Input(Y@GRAD) shoudl be not null. " );
129
- PADDLE_ENFORCE_EQ (ctx->HasOutput (framework::GradVarName (" X" )), true ,
130
- " Output(X@GRAD) should be not null. " );
130
+ OP_INOUT_CHECK (ctx->HasInput (" Label" ), " Input " , " Label " ,
131
+ " CrossEntropyGradientOpBase " );
132
+ OP_INOUT_CHECK (ctx->HasInput (framework::GradVarName (" Y" )), " Input " ,
133
+ framework::GradVarName ( " Y " ), " CrossEntropyGradientOpBase " );
134
+ OP_INOUT_CHECK (ctx->HasOutput (framework::GradVarName (" X" )), " Output " ,
135
+ framework::GradVarName ( " X " ), " CrossEntropyGradientOpBase " );
131
136
132
137
auto x_dims = GetXDim (ctx);
133
138
auto label_dims = ctx->GetInputDim (" Label" );
134
139
auto dy_dims = ctx->GetInputDim (framework::GradVarName (" Y" ));
135
140
int rank = x_dims.size ();
136
- PADDLE_ENFORCE_EQ (dy_dims.size (), label_dims.size (),
137
- " Input(Y@Grad) and Input(Y) should have the same rank." );
141
+ PADDLE_ENFORCE_EQ (
142
+ dy_dims.size (), label_dims.size (),
143
+ platform::errors::InvalidArgument (
144
+ " Input(Y@Grad) and Input(Y) should have the same rank."
145
+ " But received: Y@Grad's rank is [%d], Y's rank is [%d]" ,
146
+ dy_dims.size (), label_dims.size ()));
138
147
139
148
bool check = true ;
140
149
if ((!ctx->IsRuntime ()) &&
@@ -143,10 +152,15 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
143
152
}
144
153
145
154
if (check) {
146
- PADDLE_ENFORCE_EQ (framework::slice_ddim (x_dims, 0 , rank - 1 ),
147
- framework::slice_ddim (dy_dims, 0 , rank - 1 ),
148
- " The Input(X) and Input(Y@Grad) should have the same "
149
- " shape except the last dimension." );
155
+ PADDLE_ENFORCE_EQ (
156
+ framework::slice_ddim (x_dims, 0 , rank - 1 ),
157
+ framework::slice_ddim (dy_dims, 0 , rank - 1 ),
158
+ platform::errors::InvalidArgument (
159
+ " The Input(X) and Input(Y@Grad) should have the same "
160
+ " shape except the last dimension. but received: "
161
+ " the shape of Input(X) is [%s], "
162
+ " the shape of Input(Y@Grad) is [%s]." ,
163
+ x_dims, dy_dims));
150
164
}
151
165
152
166
ctx->SetOutputDim (framework::GradVarName (" X" ), x_dims);
@@ -253,7 +267,7 @@ class CrossEntropyGradientOp : public CrossEntropyGradientOpBase {
253
267
using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase;
254
268
255
269
void InferShape (framework::InferShapeContext* ctx) const override {
256
- PADDLE_ENFORCE_EQ (ctx->HasInput (" X" ), true , " Input(X) should be not null. " );
270
+ OP_INOUT_CHECK (ctx->HasInput (" X" ), " Input " , " X " , " CrossEntropyGradientOp " );
257
271
CrossEntropyGradientOpBase::InferShape (ctx);
258
272
}
259
273
};
@@ -281,11 +295,10 @@ class CrossEntropyOp2 : public CrossEntropyOpBase {
281
295
void InferShape (framework::InferShapeContext* ctx) const override {
282
296
CrossEntropyOpBase::InferShape (ctx);
283
297
284
- PADDLE_ENFORCE_EQ (ctx->HasOutput (" XShape" ), true ,
285
- " Output(XShape) should be not null." );
286
-
287
- PADDLE_ENFORCE_EQ (ctx->HasOutput (" MatchX" ), true ,
288
- " Output(MatchX) should be not null." );
298
+ OP_INOUT_CHECK (ctx->HasOutput (" XShape" ), " Output" , " XShape" ,
299
+ " CrossEntropyOp2" );
300
+ OP_INOUT_CHECK (ctx->HasOutput (" MatchX" ), " Output" , " MatchX" ,
301
+ " CrossEntropyOp2" );
289
302
auto x_dims = ctx->GetInputDim (" X" );
290
303
auto x_dims_vec = framework::vectorize (x_dims);
291
304
x_dims_vec.push_back (0 );
@@ -305,8 +318,8 @@ class CrossEntropyGradientOp2 : public CrossEntropyGradientOpBase {
305
318
public:
306
319
using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase;
307
320
void InferShape (framework::InferShapeContext* ctx) const override {
308
- PADDLE_ENFORCE_EQ (ctx->HasInput (" MatchX" ), true ,
309
- " Input(MatchX) must exist " );
321
+ OP_INOUT_CHECK (ctx->HasInput (" MatchX" ), " Input " , " MatchX " ,
322
+ " CrossEntropyGradientOp2 " );
310
323
CrossEntropyGradientOpBase::InferShape (ctx);
311
324
}
312
325
0 commit comments