Skip to content

Commit ee0eede

Browse files
authored
[cherry-pick] Fix the error message of zeros op (#20476) (#20593)
test=release/1.6 * fix the error message of zeros op test=develop * Fix unittest of zeros op test=develop
1 parent cefbcf7 commit ee0eede

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

python/paddle/fluid/layers/tensor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -902,6 +902,13 @@ def zeros(shape, dtype, force_cpu=False):
902902
import paddle.fluid as fluid
903903
data = fluid.layers.zeros(shape=[3, 2], dtype='float32') # [[0., 0.], [0., 0.], [0., 0.]]
904904
"""
905+
if convert_dtype(dtype) not in [
906+
'bool', 'float16', 'float32', 'float64', 'int32', 'int64'
907+
]:
908+
raise TypeError(
909+
"The create data type in zeros must be one of bool, float16, float32,"
910+
" float64, int32 or int64, but received %s." % convert_dtype(
911+
(dtype)))
905912
return fill_constant(value=0.0, **locals())
906913

907914

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
17+
import unittest
18+
import numpy as np
19+
from op_test import OpTest
20+
21+
import paddle.fluid.core as core
22+
from paddle.fluid.op import Operator
23+
import paddle.fluid as fluid
24+
from paddle.fluid import compiler, Program, program_guard
25+
26+
27+
class TestZerosOpError(OpTest):
28+
def test_errors(self):
29+
with program_guard(Program(), Program()):
30+
# The input dtype of zeros_op must be bool, float16, float32, float64, int32, int64.
31+
x1 = fluid.layers.data(name='x1', shape=[4], dtype="int8")
32+
self.assertRaises(TypeError, fluid.layers.zeros, x1)
33+
x2 = fluid.layers.data(name='x2', shape=[4], dtype="uint8")
34+
self.assertRaises(TypeError, fluid.layers.zeros, x2)
35+
36+
37+
if __name__ == "__main__":
38+
unittest.main()

0 commit comments

Comments
 (0)