Skip to content

Commit 3ae939e

Browse files
authored
unify PADDLE_ASSERT_MSG into PADDLE_ENFORCE(error_message) (#19631)
* remove assert.h * change PADDLE_ASSERT_MSG to PADDLE_ENFORCE test=develop * fix tensorrt paddle_enforce test=develop
1 parent af692c9 commit 3ae939e

14 files changed

+61
-90
lines changed

paddle/fluid/framework/dim.h

-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
#include <type_traits>
2121

2222
#include "paddle/fluid/framework/array.h"
23-
#include "paddle/fluid/platform/assert.h"
2423
#include "paddle/fluid/platform/enforce.h"
2524
#include "paddle/fluid/platform/hostdevice.h"
2625

paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu

+3-4
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,9 @@ int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
111111
float const* input_ptr = reinterpret_cast<float const*>(inputs[0]);
112112
float* const* h_odatas = reinterpret_cast<float* const*>(outputs);
113113
float** output_ptrs = thrust::raw_pointer_cast(&d_output_ptrs_[0]);
114-
PADDLE_ENFORCE(cudaMemcpyAsync(output_ptrs, h_odatas,
115-
d_output_ptrs_.size() * sizeof(float*),
116-
cudaMemcpyHostToDevice,
117-
stream) == cudaSuccess);
114+
PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemcpyAsync(
115+
output_ptrs, h_odatas, d_output_ptrs_.size() * sizeof(float*),
116+
cudaMemcpyHostToDevice, stream));
118117

119118
int outer_rows = outer_rows_ * batchSize;
120119

paddle/fluid/operators/center_loss_op.cu

+4-5
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ limitations under the License. */
1414

1515
#include <iostream>
1616
#include "paddle/fluid/operators/center_loss_op.h"
17-
#include "paddle/fluid/platform/assert.h"
1817
#include "paddle/fluid/platform/cuda_primitives.h"
1918
#include "paddle/fluid/platform/gpu_info.h"
2019
namespace paddle {
@@ -31,8 +30,8 @@ __global__ void ComputeDifferent(T *centers_diff, const T *X, const T *centers,
3130

3231
while (idy < K) {
3332
int64_t id = ids[idy];
34-
PADDLE_ASSERT_MSG(id >= 0, "received id:", id);
35-
PADDLE_ASSERT_MSG(id < N, "received id:", id);
33+
PADDLE_ENFORCE(id >= 0, "received id:", id);
34+
PADDLE_ENFORCE(id < N, "received id:", id);
3635
T *out = centers_diff + idy * D;
3736
const T *x = X + idy * D;
3837
const T *cent = centers + id * D;
@@ -53,8 +52,8 @@ __global__ void UpdateCenters(T *centers, T *centers_diff, const int64_t *ids,
5352
while (idy < K) {
5453
int count = 1;
5554
int64_t id = ids[idy];
56-
PADDLE_ASSERT_MSG(id >= 0, "received id:", id);
57-
PADDLE_ASSERT_MSG(id < N, "received id:", id);
55+
PADDLE_ENFORCE(id >= 0, "received id:", id);
56+
PADDLE_ENFORCE(id < N, "received id:", id);
5857

5958
for (int i = 0; i < K; i++) {
6059
if (ids[i] == id) {

paddle/fluid/operators/cross_entropy_op.h

100755100644
+5-5
Original file line numberDiff line numberDiff line change
@@ -155,11 +155,11 @@ struct HardLabelCrossEntropyForwardFunctor {
155155
HOSTDEVICE void operator()(int64_t idx) const {
156156
auto label = label_[idx];
157157
if (label != ignore_index_) {
158-
PADDLE_ASSERT_MSG(label >= 0 && label < feature_size_,
159-
"Variable value (label) of "
160-
"OP(fluid.layers.cross_entropy) expected >= 0 "
161-
"and < %ld, but got %ld. Please check label value.",
162-
feature_size_, label);
158+
PADDLE_ENFORCE(label >= 0 && label < feature_size_,
159+
"Variable value (label) of "
160+
"OP(fluid.layers.cross_entropy) expected >= 0 "
161+
"and < %ld, but got %ld. Please check label value.",
162+
feature_size_, label);
163163
auto match_x = x_[idx * feature_size_ + label];
164164
y_[idx] = -math::TolerableValue<T>()(real_log(match_x));
165165
match_x_[idx] = match_x;

paddle/fluid/operators/lookup_table_op.cu

+4-5
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ limitations under the License. */
1515
#include "paddle/fluid/framework/eigen.h"
1616
#include "paddle/fluid/framework/op_registry.h"
1717
#include "paddle/fluid/operators/lookup_table_op.h"
18-
#include "paddle/fluid/platform/assert.h"
1918
#include "paddle/fluid/platform/cuda_primitives.h"
2019
#include "paddle/fluid/platform/float16.h"
2120

@@ -32,12 +31,12 @@ __global__ void LookupTable(T *output, const T *table, const int64_t *ids,
3231

3332
while (idy < K) {
3433
int64_t id = ids[idy];
35-
PADDLE_ASSERT_MSG(
34+
PADDLE_ENFORCE(
3635
id >= 0,
3736
"Variable value (input) of OP(fluid.layers.embedding) "
3837
"expected >= 0 and < %ld, but got %ld. Please check input value.",
3938
N, id);
40-
PADDLE_ASSERT_MSG(
39+
PADDLE_ENFORCE(
4140
id < N,
4241
"Variable value (input) of OP(fluid.layers.embedding) "
4342
"expected >= 0 and < %ld, but got %ld. Please check input value.",
@@ -67,12 +66,12 @@ __global__ void LookupTableGrad(T *table, const T *output, const int64_t *ids,
6766

6867
while (idy < K) {
6968
int64_t id = ids[idy];
70-
PADDLE_ASSERT_MSG(
69+
PADDLE_ENFORCE(
7170
id >= 0,
7271
"Variable value (input) of OP(fluid.layers.embedding) "
7372
"expected >= 0 and < %ld, but got %ld. Please check input value.",
7473
N, id);
75-
PADDLE_ASSERT_MSG(
74+
PADDLE_ENFORCE(
7675
id < N,
7776
"Variable value (input) of OP(fluid.layers.embedding) "
7877
"expected >= 0 and < %ld, but got %ld. Please check input value.",

paddle/fluid/operators/math/CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ function(math_library TARGET)
77
set(cc_srcs)
88
set(cu_srcs)
99
set(hip_srcs)
10-
set(math_common_deps device_context framework_proto)
10+
set(math_common_deps device_context framework_proto enforce)
1111
set(multiValueArgs DEPS)
1212
cmake_parse_arguments(math_library "${options}" "${oneValueArgs}"
1313
"${multiValueArgs}" ${ARGN})

paddle/fluid/operators/math/cross_entropy.cu

+4-4
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label,
2727
const int ignore_index) {
2828
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
2929
i += blockDim.x * gridDim.x) {
30-
PADDLE_ASSERT_MSG(label[i] >= 0 && label[i] < D || label[i] == ignore_index,
31-
"label[%d] expected >= 0 and < %ld, or == %ld, but got "
32-
"%ld. Please check input value.",
33-
i, D, ignore_index, label[i]);
30+
PADDLE_ENFORCE(label[i] >= 0 && label[i] < D || label[i] == ignore_index,
31+
"label[%d] expected >= 0 and < %ld, or == %ld, but got "
32+
"%ld. Please check input value.",
33+
i, D, ignore_index, label[i]);
3434
Y[i] = ignore_index == label[i]
3535
? static_cast<T>(0)
3636
: -math::TolerableValue<T>()(real_log(X[i * D + label[i]]));

paddle/fluid/operators/math/cross_entropy.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ namespace math {
2525
template <typename T>
2626
struct TolerableValue {
2727
HOSTDEVICE T operator()(const T& x) const {
28-
PADDLE_ASSERT_MSG(std::is_floating_point<T>::value,
29-
"TolerableValue should be float in cross_entropy.");
28+
PADDLE_ENFORCE(std::is_floating_point<T>::value,
29+
"TolerableValue should be float in cross_entropy.");
3030
const T kApproInf = 1e20;
3131

3232
if (x == INFINITY) return kApproInf;

paddle/fluid/operators/math/unpooling.cu

+8-8
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ __global__ void KernelUnpool2dMax(const int nthreads, const T* input_data,
3737
int cidx = boffset / in_c_stride;
3838
int out_offset = bidx * out_n_stride + cidx * out_c_stride;
3939
int out_index = indices_data[i];
40-
PADDLE_ASSERT_MSG(out_index < out_c_stride,
41-
"out_index < out_c_stride. Expected %ld < %ld, but got "
42-
"%ld >= %ld. Please check input value.",
43-
out_index, out_c_stride, out_index, out_c_stride);
40+
PADDLE_ENFORCE(out_index < out_c_stride,
41+
"out_index < out_c_stride. Expected %ld < %ld, but got "
42+
"%ld >= %ld. Please check input value.",
43+
out_index, out_c_stride, out_index, out_c_stride);
4444
output_data[out_offset + out_index] = input_data[i];
4545
}
4646
}
@@ -62,10 +62,10 @@ __global__ void KernelUnpool2dMaxGrad(
6262
int cidx = boffset / in_c_stride;
6363
int out_offset = bidx * out_n_stride + cidx * out_c_stride;
6464
int out_index = indices_data[i];
65-
PADDLE_ASSERT_MSG(out_index < out_c_stride,
66-
"out_index < out_c_stride. Expected %ld < %ld, but got "
67-
"%ld >= %ld. Please check input value.",
68-
out_index, out_c_stride, out_index, out_c_stride);
65+
PADDLE_ENFORCE(out_index < out_c_stride,
66+
"out_index < out_c_stride. Expected %ld < %ld, but got "
67+
"%ld >= %ld. Please check input value.",
68+
out_index, out_c_stride, out_index, out_c_stride);
6969
input_grad[i] = output_grad[out_offset + out_index];
7070
}
7171
}

paddle/fluid/operators/modified_huber_loss_op.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@ using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
2929
template <typename T>
3030
struct CheckLabelValue {
3131
HOSTDEVICE T operator()(const T& val) const {
32-
PADDLE_ASSERT_MSG(val == static_cast<T>(0) || val == static_cast<T>(1),
33-
"LabelValue of modified_huber_loss_op expected to be 0 "
34-
"or 1, but got %ld. Please check input value.",
35-
val);
32+
PADDLE_ENFORCE(val == static_cast<T>(0) || val == static_cast<T>(1),
33+
"LabelValue of modified_huber_loss_op expected to be 0 "
34+
"or 1, but got %ld. Please check input value.",
35+
val);
3636
}
3737
};
3838

paddle/fluid/operators/random_crop_op.h

+10-10
Original file line numberDiff line numberDiff line change
@@ -60,16 +60,16 @@ HOSTDEVICE inline void StridedMemcpy(const T* x, const size_t* x_dims, T* out,
6060
size_t offset_i = offsets[i];
6161

6262
if (i == rank - 1) {
63-
PADDLE_ASSERT_MSG(x_stride == 1,
64-
"When i:%d == rank:%d - 1, x_stride of random_crop_op "
65-
"expected to be 1, but got %ld. Please check input "
66-
"value.",
67-
i, rank, x_stride);
68-
PADDLE_ASSERT_MSG(out_stride == 1,
69-
"When i:%d == rank:%d - 1, out_stride of random_crop_op "
70-
"expected to be 1, but got %ld. Please check input "
71-
"value.",
72-
i, rank, out_stride);
63+
PADDLE_ENFORCE(x_stride == 1,
64+
"When i:%d == rank:%d - 1, x_stride of random_crop_op "
65+
"expected to be 1, but got %ld. Please check input "
66+
"value.",
67+
i, rank, x_stride);
68+
PADDLE_ENFORCE(out_stride == 1,
69+
"When i:%d == rank:%d - 1, out_stride of random_crop_op "
70+
"expected to be 1, but got %ld. Please check input "
71+
"value.",
72+
i, rank, out_stride);
7373
x += offset_i;
7474
for (size_t j = 0; j < out_dim_i; ++j) {
7575
*out++ = *x++;

paddle/fluid/operators/sample_logits_op.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
3434
template <typename T>
3535
struct TolerableValue {
3636
HOSTDEVICE T operator()(const T& x) const {
37-
PADDLE_ASSERT_MSG(std::is_floating_point<T>::value,
38-
"TolerableValue should be float in sample_logits_op.");
37+
PADDLE_ENFORCE(std::is_floating_point<T>::value,
38+
"TolerableValue should be float in sample_logits_op.");
3939
const T kApproInf = 1e20;
4040
if (x == INFINITY) return kApproInf;
4141
if (x == -INFINITY) return -kApproInf;

paddle/fluid/platform/assert.h

-39
This file was deleted.

paddle/fluid/platform/enforce.h

+14
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,19 @@ DEFINE_CUDA_STATUS_TYPE(ncclResult_t, ncclSuccess);
289289
::paddle::string::Sprintf(__VA_ARGS__), __FILE__, __LINE__); \
290290
} while (0)
291291

292+
#if defined(__CUDA_ARCH__)
293+
// For cuda, the assertions can affect performance and it is therefore
294+
// recommended to disable them in production code
295+
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#assertion
296+
#define PADDLE_ENFORCE(_IS_NOT_ERROR, __FORMAT, ...) \
297+
do { \
298+
if (!(_IS_NOT_ERROR)) { \
299+
printf("Exception: %s:%d Assertion `%s` failed. " __FORMAT "\n", \
300+
__FILE__, __LINE__, #_IS_NOT_ERROR, ##__VA_ARGS__); \
301+
asm("trap;"); \
302+
} \
303+
} while (0)
304+
#else
292305
#define PADDLE_ENFORCE(COND, ...) \
293306
do { \
294307
auto __cond__ = (COND); \
@@ -302,6 +315,7 @@ DEFINE_CUDA_STATUS_TYPE(ncclResult_t, ncclSuccess);
302315
} \
303316
} \
304317
} while (0)
318+
#endif
305319

306320
#ifdef PADDLE_WITH_CUDA
307321
#define PADDLE_ENFORCE_CUDA_SUCCESS(COND, ...) \

0 commit comments

Comments
 (0)