Skip to content

Commit 8ac0227

Browse files
authored
Fix the proformance problem of enforce (#6085)
* Fix Proformance problem of enforce * Fix missing `;` in code * Fix CI
1 parent 3a8311f commit 8ac0227

File tree

6 files changed

+29
-21
lines changed

6 files changed

+29
-21
lines changed

paddle/operators/concat_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class ConcatOp : public framework::OperatorWithKernel {
2525

2626
void InferShape(framework::InferShapeContext *ctx) const override {
2727
PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL,
28-
"Inputs(X) of ConcatOp should be empty.")
28+
"Inputs(X) of ConcatOp should be empty.");
2929
PADDLE_ENFORCE(ctx->HasOutput("Out"),
3030
"Output(Out) of ConcatOp should not be null.");
3131

@@ -45,7 +45,7 @@ class ConcatOp : public framework::OperatorWithKernel {
4545
}
4646
PADDLE_ENFORCE_EQ(out_dims[j], ins[i][j],
4747
"Input tensors should have the same "
48-
"elements except the specify axis.")
48+
"elements except the specify axis.");
4949
}
5050
}
5151
ctx->SetOutputDim("Out", out_dims);

paddle/operators/elementwise_op.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
3535
auto x_dim = ctx->GetInputDim("X");
3636
auto y_dim = ctx->GetInputDim("Y");
3737
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
38-
"Rank of first input must >= rank of second input.")
38+
"Rank of first input must >= rank of second input.");
3939
ctx->SetOutputDim("Out", x_dim);
4040
ctx->ShareLoD("X", /*->*/ "Out");
4141
}
@@ -120,7 +120,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
120120
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
121121

122122
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
123-
"Rank of first input must >= rank of second input.")
123+
"Rank of first input must >= rank of second input.");
124124

125125
auto x_grad_name = framework::GradVarName("X");
126126
auto y_grad_name = framework::GradVarName("Y");

paddle/operators/elementwise_op_function.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ void ElementwiseCompute(const framework::ExecutionContext& ctx) {
106106
auto x_dims = x->dims();
107107
auto y_dims = y->dims();
108108
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
109-
"Rank of first input must >= rank of second input.")
109+
"Rank of first input must >= rank of second input.");
110110

111111
if (x_dims == y_dims) {
112112
functor f;

paddle/operators/sequence_slice_op.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,10 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {
5454
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
5555
PADDLE_ENFORCE_EQ(
5656
n, static_cast<size_t>(length->dims()[0]),
57-
"The size of input-sequence and length-array should be the same")
57+
"The size of input-sequence and length-array should be the same");
5858
PADDLE_ENFORCE_EQ(
5959
n, static_cast<size_t>(offset->dims()[0]),
60-
"The size of input-sequence and offset-array should be the same")
60+
"The size of input-sequence and offset-array should be the same");
6161

6262
const int64_t* offset_data = offset->data<int64_t>();
6363
const int64_t* length_data = length->data<int64_t>();
@@ -78,11 +78,11 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {
7878

7979
for (size_t i = 0; i < n; ++i) {
8080
PADDLE_ENFORCE_LT(0, offset_data[i],
81-
"The offset[%d] must greater than zero.", i)
81+
"The offset[%d] must greater than zero.", i);
8282
PADDLE_ENFORCE_LT(0, length_data[i],
83-
"The length[%d] must greater than zero.", i)
83+
"The length[%d] must greater than zero.", i);
8484
PADDLE_ENFORCE_LT(lod[0][i] + offset_data[i] + length_data[i],
85-
lod[0][i + 1], "The target tensor's length overflow.")
85+
lod[0][i + 1], "The target tensor's length overflow.");
8686
}
8787

8888
out->mutable_data<T>(ctx.GetPlace());

paddle/operators/sum_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class SumKernel : public framework::OpKernel<T> {
8484
int64_t offset = 0;
8585
for (int i = 0; i < N; i++) {
8686
PADDLE_ENFORCE_EQ(out->height(),
87-
in_vars[i]->Get<SelectedRows>().height())
87+
in_vars[i]->Get<SelectedRows>().height());
8888
functor(context.device_context(), in_vars[i]->Get<SelectedRows>(),
8989
offset, out);
9090
offset += in_vars[i]->Get<SelectedRows>().value().numel();

paddle/platform/enforce.h

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -234,16 +234,24 @@ inline void throw_on_error(T e) {
234234
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <, >=, __VA_ARGS__)
235235
#define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \
236236
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__)
237-
#define PADDLE_ENFORCE_NOT_NULL(__VAL, ...) \
238-
PADDLE_ENFORCE(nullptr != (__VAL), #__VAL " should not be null\n%s", \
239-
paddle::string::Sprintf("" __VA_ARGS__));
240-
241-
#define __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, __CMP, __INV_CMP, ...) \
242-
PADDLE_ENFORCE(__VAL0 __CMP __VAL1, \
243-
"enforce %s " #__CMP " %s failed, %s " #__INV_CMP " %s\n%s", \
244-
#__VAL0, #__VAL1, paddle::string::to_string(__VAL0), \
245-
paddle::string::to_string(__VAL1), \
246-
paddle::string::Sprintf("" __VA_ARGS__));
237+
#define PADDLE_ENFORCE_NOT_NULL(__VAL, ...) \
238+
do { \
239+
if (UNLIKELY(nullptr == (__VAL))) { \
240+
PADDLE_THROW(#__VAL " should not be null\n%s", \
241+
paddle::string::Sprintf("" __VA_ARGS__)); \
242+
} \
243+
} while (0)
244+
245+
#define __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, __CMP, __INV_CMP, ...) \
246+
do { \
247+
if (!UNLIKELY((__VAL0)__CMP(__VAL1))) { \
248+
PADDLE_THROW("enforce %s " #__CMP " %s failed, %s " #__INV_CMP \
249+
" %s\n%s", \
250+
#__VAL0, #__VAL1, paddle::string::to_string(__VAL0), \
251+
paddle::string::to_string(__VAL1), \
252+
paddle::string::Sprintf("" __VA_ARGS__)); \
253+
} \
254+
} while (0)
247255

248256
} // namespace platform
249257
} // namespace paddle

0 commit comments

Comments
 (0)