Skip to content

Commit f6050da

Browse files
authored
SamplingID Op fix error print (#24521) (#24552)
* fix error print for sampling_id_op * fix spell err * fix spell err test=develop
1 parent 6f65b07 commit f6050da

File tree

2 files changed

+21
-11
lines changed

2 files changed

+21
-11
lines changed

paddle/fluid/operators/sampling_id_op.cc

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,20 @@ class SamplingIdOp : public framework::OperatorWithKernel {
2424
using framework::OperatorWithKernel::OperatorWithKernel;
2525

2626
void InferShape(framework::InferShapeContext* ctx) const override {
27-
PADDLE_ENFORCE(ctx->HasInput("X"),
28-
"Input(X) of SamplingIdOp should not be null.");
29-
PADDLE_ENFORCE(ctx->HasOutput("Out"),
30-
"Output(Out) of SamplingIdOp should not be null.");
31-
PADDLE_ENFORCE_LT(ctx->Attrs().Get<float>("min"),
32-
ctx->Attrs().Get<float>("max"), "min must less then max");
27+
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SampleIn");
28+
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "X", "SampleOut");
29+
PADDLE_ENFORCE_LT(
30+
ctx->Attrs().Get<float>("min"), ctx->Attrs().Get<float>("max"),
31+
platform::errors::InvalidArgument(
32+
"min must less then max, but here min is %f, max is %f",
33+
ctx->Attrs().Get<float>("min"), ctx->Attrs().Get<float>("max")));
3334

3435
auto input_dims = ctx->GetInputDim("X");
35-
PADDLE_ENFORCE(input_dims.size() == 2,
36-
"Input(X, Filter) should be 2-D tensor.");
36+
PADDLE_ENFORCE_EQ(
37+
input_dims.size(), 2,
38+
platform::errors::InvalidArgument(
39+
"Input(X, Filter) should be 2-D tensor. But X dim is %d",
40+
input_dims.size()));
3741

3842
auto dim0 = input_dims[0];
3943
framework::DDim dims = framework::make_ddim({dim0});

paddle/fluid/operators/sampling_id_op.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,15 @@ class SamplingIdKernel : public framework::OpKernel<T> {
3636
const int batch_size = static_cast<int>(input->dims()[0]);
3737
const int width = static_cast<int>(input->dims()[1]);
3838

39-
PADDLE_ENFORCE_GE(batch_size, 0,
40-
"batch_size(dims[0]) must be nonnegative.");
41-
PADDLE_ENFORCE_GE(width, 0, "width(dims[1]) must be nonnegative.");
39+
PADDLE_ENFORCE_GE(
40+
batch_size, 0,
41+
platform::errors::InvalidArgument(
42+
"batch_size(dims[0]) must be nonnegative. but it is %d.",
43+
batch_size));
44+
PADDLE_ENFORCE_GE(
45+
width, 0,
46+
platform::errors::InvalidArgument(
47+
"width(dims[1]) must be nonnegative. but it is %d.", width));
4248

4349
std::vector<T> ins_vector;
4450
framework::TensorToVector(*input, context.device_context(), &ins_vector);

0 commit comments

Comments
 (0)