@@ -24,16 +24,20 @@ class SamplingIdOp : public framework::OperatorWithKernel {
24
24
using framework::OperatorWithKernel::OperatorWithKernel;
25
25
26
26
void InferShape (framework::InferShapeContext* ctx) const override {
27
- PADDLE_ENFORCE (ctx->HasInput (" X" ),
28
- " Input(X) of SamplingIdOp should not be null." );
29
- PADDLE_ENFORCE (ctx->HasOutput (" Out" ),
30
- " Output(Out) of SamplingIdOp should not be null." );
31
- PADDLE_ENFORCE_LT (ctx->Attrs ().Get <float >(" min" ),
32
- ctx->Attrs ().Get <float >(" max" ), " min must less then max" );
27
+ OP_INOUT_CHECK (ctx->HasInput (" X" ), " Input" , " X" , " SampleIn" );
28
+ OP_INOUT_CHECK (ctx->HasOutput (" Out" ), " Output" , " X" , " SampleOut" );
29
+ PADDLE_ENFORCE_LT (
30
+ ctx->Attrs ().Get <float >(" min" ), ctx->Attrs ().Get <float >(" max" ),
31
+ platform::errors::InvalidArgument (
32
+ " min must less then max, but here min is %f, max is %f" ,
33
+ ctx->Attrs ().Get <float >(" min" ), ctx->Attrs ().Get <float >(" max" )));
33
34
34
35
auto input_dims = ctx->GetInputDim (" X" );
35
- PADDLE_ENFORCE (input_dims.size () == 2 ,
36
- " Input(X, Filter) should be 2-D tensor." );
36
+ PADDLE_ENFORCE_EQ (
37
+ input_dims.size (), 2 ,
38
+ platform::errors::InvalidArgument (
39
+ " Input(X, Filter) should be 2-D tensor. But X dim is %d" ,
40
+ input_dims.size ()));
37
41
38
42
auto dim0 = input_dims[0 ];
39
43
framework::DDim dims = framework::make_ddim ({dim0});
0 commit comments