Commit b76a6de

Cherry Pick Fix error message (#20164)
* fix the error message for reduce_mean and reduce_sum op (#20063)
  * fix the error message for reduce_mean and reduce_sum op (test=develop)
  * fix typo (test=develop)
  * fix according to review advice (test=develop)
  * fix the test (test=develop)
  * fix (test=develop)
* Fill constant error message fix (#20075)
  * fix the constant error message (test=develop)
  * fix typo (test=develop)
  * fix typo (test=develop)
  * fix code style (test=develop)
  * fix comment and bugs (test=develop)
  * fix the bug (test=develop)
  * fix and add unittest (test=develop)
  * fix the typo (test=develop)
  * add support for the fill_constant op (test=develop)
  * add test for ci coverage (test=develop)
1 parent 907a853 commit b76a6de

File tree

8 files changed: +130, -11 lines


paddle/fluid/operators/fill_constant_op.cc

+1
@@ -87,4 +87,5 @@ REGISTER_OP_CPU_KERNEL(fill_constant, ops::FillConstantKernel<float>,
                        ops::FillConstantKernel<double>,
                        ops::FillConstantKernel<int64_t>,
                        ops::FillConstantKernel<int>,
+                       ops::FillConstantKernel<bool>,
                        ops::FillConstantKernel<paddle::platform::float16>);
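With FillConstantKernel<bool> registered, fill_constant can materialize boolean tensors on CPU. A minimal sketch of the effect, assuming the fluid 1.x static-graph API at this commit (the program/executor scaffolding is illustrative, not part of the change):

import paddle.fluid as fluid

# Build a tiny program that fills a 2x3 boolean tensor with True.
# This exercises the newly registered FillConstantKernel<bool>.
main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    flags = fluid.layers.fill_constant(shape=[2, 3], dtype='bool', value=True)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_prog)
result, = exe.run(main_prog, fetch_list=[flags])
print(result.dtype, result)  # expected: bool, all True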

paddle/fluid/operators/reduce_ops/reduce_op.h

+17, -8
@@ -165,13 +165,20 @@ class ReduceOp : public framework::OperatorWithKernel {
                    "Output(Out) of ReduceOp should not be null.");
     auto x_dims = ctx->GetInputDim("X");
     auto x_rank = x_dims.size();
-    PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported.");
+    PADDLE_ENFORCE_LE(x_rank, 6,
+                      "ShapeError: The input tensor X's dimensions of Reduce "
+                      "should be less equal than 6. But received X's "
+                      "dimensions = %d, X's shape = [%s].",
+                      x_rank, x_dims);
     auto dims = ctx->Attrs().Get<std::vector<int>>("dim");
+
     for (size_t i = 0; i < dims.size(); ++i) {
+      PADDLE_ENFORCE_LT(dims[i], x_rank,
+                        "ShapeError: The reduce dim index %d should be in the "
+                        "range [-dimension(X), dimension(X)]."
+                        "which dimesion = %d, But received dim index = %d",
+                        i, x_rank, dims[i]);
       if (dims[i] < 0) dims[i] = x_rank + dims[i];
-      PADDLE_ENFORCE_LT(
-          dims[i], x_rank,
-          "The dim should be in the range [-rank(input), rank(input)).");
     }
     sort(dims.begin(), dims.end());
     bool reduce_all = ctx->Attrs().Get<bool>("reduce_all");
@@ -202,7 +209,7 @@ class ReduceOp : public framework::OperatorWithKernel {
     }
     auto out_dims = framework::make_ddim(dims_vector);
     ctx->SetOutputDim("Out", out_dims);
-    if (dims[0] != 0) {
+    if (dims.size() > 0 && dims[0] != 0) {
       // Only pass LoD when not reducing on the first dim.
       ctx->ShareLoD("X", /*->*/ "Out");
     }
@@ -223,10 +230,12 @@ class ReduceGradOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_LE(x_rank, 6, "Tensors with rank at most 6 are supported.");
     auto dims = ctx->Attrs().Get<std::vector<int>>("dim");
     for (size_t i = 0; i < dims.size(); ++i) {
+      PADDLE_ENFORCE_LT(dims[i], x_rank,
+                        "ShapeError: The reduce dim index %d should be in the "
+                        "range [-dimension(X), dimension(X)]."
+                        "which dimesion = %d, But received dim index = %d",
+                        i, x_rank, dims[i]);
       if (dims[i] < 0) dims[i] = x_rank + dims[i];
-      PADDLE_ENFORCE_LT(
-          dims[i], x_rank,
-          "The dim should be in the range [-rank(input), rank(input)).");
     }
     sort(dims.begin(), dims.end());
     auto x_grad_name = framework::GradVarName("X");
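In the Python API these enforces fire when the reduce op is appended with an out-of-range dim attribute. A hedged sketch under the fluid 1.x API; the commit does not specify which Python exception class wraps the C++ enforce, so a broad except is used here:

import paddle.fluid as fluid

with fluid.program_guard(fluid.Program(), fluid.Program()):
    x = fluid.layers.data(name='x', shape=[3, 4], dtype='float32')
    try:
        # dim index 6 is not less than rank(x), so ReduceOp::InferShape
        # rejects it with the new "ShapeError: The reduce dim index ..." text.
        fluid.layers.reduce_sum(x, dim=6)
    except Exception as e:
        print(type(e).__name__, e)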

python/paddle/fluid/data_feeder.py

+14, -2
@@ -27,7 +27,19 @@


 def convert_dtype(dtype):
-    if dtype == core.VarDesc.VarType.FP32:
+    if isinstance(dtype, str):
+        if dtype in [
+                'float32', 'int64', 'float64', 'float16', 'int32', 'uint8',
+                'bool'
+        ]:
+            return dtype
+        else:
+            raise ValueError(
+                "dtype must be any of [bool, int32, float32, int64, "
+                "float64, uint8]")
+    elif dtype == core.VarDesc.VarType.BOOL:
+        return 'bool'
+    elif dtype == core.VarDesc.VarType.FP32:
         return 'float32'
     elif dtype == core.VarDesc.VarType.INT64:
         return 'int64'
@@ -40,7 +52,7 @@ def convert_dtype(dtype):
     elif dtype == core.VarDesc.VarType.UINT8:
         return 'uint8'
     else:
-        raise ValueError("dtype must be any of [int32, float32, int64, "
+        raise ValueError("dtype must be any of [bool,int32, float32, int64, "
                          "float64, uint8]")
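convert_dtype now accepts dtype names given as strings as well as core.VarDesc.VarType.BOOL; anything outside the listed names still raises ValueError. A quick sketch of the new behavior, assuming the fluid 1.x import paths shown in this file:

from paddle.fluid import core
from paddle.fluid.data_feeder import convert_dtype

print(convert_dtype('bool'))                     # 'bool': string names pass through
print(convert_dtype(core.VarDesc.VarType.BOOL))  # 'bool': new VarType branch
try:
    convert_dtype('int16')                       # not in the accepted list
except ValueError as e:
    print(e)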

python/paddle/fluid/layers/nn.py

+18
@@ -5611,6 +5611,15 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None):

     """
     helper = LayerHelper('reduce_sum', **locals())
+    if not isinstance(input, Variable):
+        raise TypeError(
+            "The type of 'input' in reduce_sum must be Variable, but received %s"
+            % (type(input)))
+    if convert_dtype(
+            input.dtype) not in ['float32', 'float64', 'int32', 'int64']:
+        raise TypeError(
+            "The data type of 'input' in reduce_sum must be float32 or float64 or int32 or int64, but received %s."
+            % (convert_dtype(input.dtype)))
     out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
     if dim is not None and not isinstance(dim, list):
         dim = [dim]
@@ -5670,6 +5679,15 @@ def reduce_mean(input, dim=None, keep_dim=False, name=None):
             fluid.layers.reduce_mean(y, dim=[0, 1]) # [4.0, 5.0]
     """
     helper = LayerHelper('reduce_mean', **locals())
+    if not isinstance(input, Variable):
+        raise TypeError(
+            "The type of 'input' in reduce_mean must be Variable, but received %s"
+            % (type(input)))
+    if convert_dtype(
+            input.dtype) not in ['float32', 'float64', 'int32', 'int64']:
+        raise TypeError(
+            "The data type of 'input' in reduce_mean must be float32 or float64 or int32 or int64, but received %s."
+            % (convert_dtype(input.dtype)))
     out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
     if dim is not None and not isinstance(dim, list):
         dim = [dim]
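Both reduce layers now validate their input before appending the op. A small sketch of the new failure modes, assuming the fluid 1.x API (it mirrors the new unit tests further down):

import numpy as np
import paddle.fluid as fluid

with fluid.program_guard(fluid.Program(), fluid.Program()):
    # A plain numpy array is not a Variable, so reduce_sum fails fast.
    try:
        fluid.layers.reduce_sum(np.ones([2, 2], dtype='float32'))
    except TypeError as e:
        print(e)
    # A Variable with an unsupported dtype is rejected as well.
    x = fluid.layers.data(name='x', shape=[4], dtype='uint8')
    try:
        fluid.layers.reduce_mean(x)
    except TypeError as e:
        print(e)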

python/paddle/fluid/layers/tensor.py

+14
@@ -21,6 +21,7 @@
 from ..initializer import Constant, force_init_on_cpu
 from ..core import VarDesc
 from .layer_function_generator import templatedoc
+from ..data_feeder import convert_dtype
 import numpy

 __all__ = [
@@ -397,8 +398,21 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
     """

     helper = LayerHelper("fill_constant", **locals())
+    if convert_dtype(dtype) not in [
+            'bool', 'float16', 'float32', 'float64', 'int32', 'int64'
+    ]:
+        raise TypeError(
+            "The create data type in fill_constant must be one of 'bool', float16, float32,"
+            "float64, int32 or int64, but received %s." % convert_dtype(
+                (dtype)))
     if out is None:
         out = helper.create_variable_for_type_inference(dtype=dtype)
+    else:
+        if not (convert_dtype(dtype) == convert_dtype(out.dtype)):
+            raise TypeError(
+                "The create data type in op must be same with out type"
+                "but received %s and out dtype %s." % (convert_dtype(
+                    (dtype), convert_dtype(out.dtype))))
     helper.append_op(
         type='fill_constant',
         inputs={},
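The two new guards in fill_constant reject unsupported dtypes and a mismatched out variable. A hedged sketch under the fluid 1.x API, mirroring the assertions added to test_fill_constant_op.py below:

import paddle.fluid as fluid

with fluid.program_guard(fluid.Program(), fluid.Program()):
    # uint8 is outside the allowed set, so the dtype guard raises TypeError.
    try:
        fluid.layers.fill_constant(shape=[1], dtype='uint8', value=5)
    except TypeError as e:
        print(e)
    # When `out` is supplied, its dtype must match the requested dtype.
    out_var = fluid.layers.data(name='out_var', shape=[1], dtype='int32')
    try:
        fluid.layers.fill_constant(shape=[1], dtype='float64', value=5, out=out_var)
    except TypeError as e:
        print(e)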

python/paddle/fluid/tests/test_if_else_op.py

+1, -1
@@ -183,7 +183,7 @@ def compare_ifelse_op_and_numpy(self, place):
                false_target = fluid.layers.tanh(false_target)
                ie.output(false_target)
            if_out = ie()
-            out = layers.reduce_sum(if_out)
+            out = layers.reduce_sum(if_out[0])

            exe = fluid.Executor(place)
            exe.run(fluid.default_startup_program())
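IfElse() returns a list of output Variables, so with the new isinstance check in reduce_sum the list itself is no longer accepted, and the test now indexes the first element. A compressed sketch of that pattern, assuming the documented fluid 1.x IfElse usage (the surrounding test's data feeding is omitted):

import paddle.fluid as fluid
import paddle.fluid.layers as layers

with fluid.program_guard(fluid.Program(), fluid.Program()):
    x = layers.data(name='x', shape=[1], dtype='float32')
    y = layers.data(name='y', shape=[1], dtype='float32')
    cond = layers.less_than(x=x, y=y)
    ie = layers.IfElse(cond)
    with ie.true_block():
        ie.output(ie.input(x))
    with ie.false_block():
        ie.output(ie.input(y))
    if_out = ie()                       # a list of Variables, not a Variable
    out = layers.reduce_sum(if_out[0])  # index first, as in the updated test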

python/paddle/fluid/tests/unittests/test_fill_constant_op.py

+38
@@ -20,6 +20,8 @@

 import paddle.fluid.core as core
 from paddle.fluid.op import Operator
+import paddle.fluid as fluid
+from paddle.fluid import compiler, Program, program_guard


 class TestFillConstantOp1(OpTest):
@@ -104,5 +106,41 @@ def test_fill_constant_with_selected_rows(self):
         self.check_with_place(place)


+class TestFillConstantOpError(OpTest):
+    def test_errors(self):
+        with program_guard(Program(), Program()):
+            #for ci coverage
+            x1 = fluid.layers.data(name='x1', shape=[1], dtype="int16")
+            self.assertRaises(
+                ValueError,
+                fluid.layers.fill_constant,
+                shape=[1],
+                value=5,
+                dtype='uint4')
+            self.assertRaises(
+                ValueError,
+                fluid.layers.fill_constant,
+                shape=[1],
+                value=5,
+                dtype='int16',
+                out=x1)
+            # The input dtype of fill_constant must be one of bool, float16,
+            #float32, float64, int32 or int64
+            x2 = fluid.layers.data(name='x2', shape=[1], dtype="int32")
+            self.assertRaises(
+                TypeError,
+                fluid.layers.fill_constant,
+                shape=[1],
+                value=5,
+                dtype='uint8')
+            self.assertRaises(
+                TypeError,
+                fluid.layers.fill_constant,
+                shape=[1],
+                value=5,
+                dtype='float64',
+                out=x2)
+
+
 if __name__ == "__main__":
     unittest.main()

python/paddle/fluid/tests/unittests/test_reduce_op.py

+27
@@ -17,6 +17,9 @@
 import unittest
 import numpy as np
 from op_test import OpTest
+import paddle.fluid.core as core
+import paddle.fluid as fluid
+from paddle.fluid import compiler, Program, program_guard


 class TestSumOp(OpTest):
@@ -411,5 +414,29 @@ def test_check_grad(self):
         self.check_grad(['X'], 'Out')


+class TestReduceSumOpError(OpTest):
+    def test_errors(self):
+        with program_guard(Program(), Program()):
+            # The input type of reduce_sum_op must be Variable.
+            x1 = fluid.create_lod_tensor(
+                np.array([[-1]]), [[1]], fluid.CPUPlace())
+            self.assertRaises(TypeError, fluid.layers.reduce_sum, x1)
+            # The input dtype of reduce_sum_op must be float32 or float64 or int32 or int64.
+            x2 = fluid.layers.data(name='x2', shape=[4], dtype="uint8")
+            self.assertRaises(TypeError, fluid.layers.reduce_sum, x2)
+
+
+class TestReduceMeanOpError(OpTest):
+    def test_errors(self):
+        with program_guard(Program(), Program()):
+            # The input type of reduce_mean_op must be Variable.
+            x1 = fluid.create_lod_tensor(
+                np.array([[-1]]), [[1]], fluid.CPUPlace())
+            self.assertRaises(TypeError, fluid.layers.reduce_mean, x1)
+            # The input dtype of reduce_mean_op must be float32 or float64 or int32 or int64.
+            x2 = fluid.layers.data(name='x2', shape=[4], dtype="uint8")
+            self.assertRaises(TypeError, fluid.layers.reduce_mean, x2)
+
+
 if __name__ == '__main__':
     unittest.main()
