Skip to content

Commit 340a104

Browse files
authored
Merge pull request #12658 from velconia/port_pybind11
Port pybind11 and python code to support py3 CI test
2 parents 91e84d3 + a32ce8c commit 340a104

File tree

99 files changed

+1363
-466
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

99 files changed

+1363
-466
lines changed

paddle/fluid/framework/op_desc.cc

+59
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,52 @@ std::vector<std::string> OpDesc::AttrNames() const {
202202
}
203203

204204
void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
205+
// NOTICE(minqiyang): pybind11 will take the empty list in python as
206+
// the std::vector<int> type in C++; so we have to change the attr's type
207+
// here if we meet this issue
208+
proto::AttrType attr_type = static_cast<proto::AttrType>(v.which() - 1);
209+
if (attr_type == proto::AttrType::INTS &&
210+
boost::get<std::vector<int>>(v).size() == 0u) {
211+
// Find current attr via attr name and set the correct attribute value
212+
const proto::OpProto::Attr &attr = GetProtoAttr(name);
213+
switch (attr.type()) {
214+
case proto::AttrType::BOOLEANS: {
215+
VLOG(11) << "SetAttr: " << Type() << ", " << name
216+
<< " from INTS to BOOLEANS";
217+
this->attrs_[name] = std::vector<bool>();
218+
break;
219+
}
220+
case proto::AttrType::INTS: {
221+
VLOG(11) << "SetAttr: " << Type() << ", " << name
222+
<< " from INTS to INTS";
223+
this->attrs_[name] = std::vector<int>();
224+
break;
225+
}
226+
case proto::AttrType::FLOATS: {
227+
VLOG(11) << "SetAttr: " << Type() << ", " << name
228+
<< " from INTS to FLOATS";
229+
this->attrs_[name] = std::vector<float>();
230+
break;
231+
}
232+
case proto::AttrType::STRINGS: {
233+
VLOG(11) << "SetAttr: " << Type() << ", " << name
234+
<< " from INTS to STRINGS";
235+
this->attrs_[name] = std::vector<std::string>();
236+
break;
237+
}
238+
case proto::AttrType::BLOCKS: {
239+
VLOG(11) << "SetAttr: " << Type() << ", " << name
240+
<< " from INTS to BLOCKS";
241+
this->SetBlocksAttr(name, std::vector<BlockDesc *>());
242+
return;
243+
}
244+
default:
245+
PADDLE_THROW("Wrong attr type %d", attr.type());
246+
}
247+
need_update_ = true;
248+
return;
249+
}
250+
205251
this->attrs_[name] = v;
206252
need_update_ = true;
207253
}
@@ -229,6 +275,19 @@ Attribute OpDesc::GetAttr(const std::string &name) const {
229275
return it->second;
230276
}
231277

278+
// Returns the attribute declaration called `name` from the proto registered
// for this op's type in OpInfoMap. Throws (PADDLE_THROW) when the proto
// declares no attribute with that name.
const proto::OpProto::Attr &OpDesc::GetProtoAttr(
    const std::string &name) const {
  const proto::OpProto &op_proto = OpInfoMap::Instance().Get(Type()).Proto();

  // Linear scan over the declared attributes; the first name match wins.
  for (int idx = 0, total = op_proto.attrs_size(); idx < total; ++idx) {
    const proto::OpProto::Attr &candidate = op_proto.attrs(idx);
    if (candidate.name() != name) {
      continue;
    }
    return candidate;
  }

  PADDLE_THROW("Attribute %s is not found in proto %s", name, op_proto.type());
}
290+
232291
Attribute OpDesc::GetNullableAttr(const std::string &name) const {
233292
auto it = attrs_.find(name);
234293
if (it != attrs_.end()) {

paddle/fluid/framework/op_desc.h

+2
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ class OpDesc {
8181

8282
Attribute GetAttr(const std::string &name) const;
8383

84+
const proto::OpProto::Attr &GetProtoAttr(const std::string &name) const;
85+
8486
Attribute GetNullableAttr(const std::string &name) const;
8587

8688
int GetBlockAttrId(const std::string &name) const;

paddle/fluid/pybind/protobuf.cc

+1-6
Original file line numberDiff line numberDiff line change
@@ -205,12 +205,7 @@ void BindBlockDesc(pybind11::module *m) {
205205
void BindVarDsec(pybind11::module *m) {
206206
pybind11::class_<pd::VarDesc> var_desc(*m, "VarDesc", "");
207207
var_desc
208-
.def("name",
209-
[](pd::VarDesc &self) {
210-
pybind11::bytes name = self.Name();
211-
return name;
212-
},
213-
pybind11::return_value_policy::reference)
208+
.def("name", &pd::VarDesc::Name, pybind11::return_value_policy::reference)
214209
.def("set_name", &pd::VarDesc::SetName)
215210
.def("set_shape", &pd::VarDesc::SetShape)
216211
.def("set_shapes", &pd::VarDesc::SetShapes)

paddle/fluid/pybind/pybind.cc

+2
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ limitations under the License. */
5454
#include "paddle/fluid/platform/gpu_info.h"
5555
#endif
5656

57+
#include "pybind11/stl.h"
58+
5759
// disable auto conversion to list in Python
5860
PYBIND11_MAKE_OPAQUE(paddle::framework::LoDTensorArray);
5961

python/paddle/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,5 @@
2424
import paddle.reader
2525
import paddle.dataset
2626
import paddle.batch
27+
import paddle.compat
2728
batch = batch.batch

python/paddle/compat.py

+237
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import six
16+
import math
17+
18+
# Names exported by `from paddle.compat import *`.
__all__ = [
    'long_type', 'to_text', 'to_bytes', 'round', 'floor_division',
    'get_exception_message'
]
26+
27+
# Python2 has separate `int` and `long` types; Python3 only has `int`.
# `long` is only looked up when running under Python2.
int_type = int
long_type = long if six.PY2 else int
33+
34+
35+
# str and bytes related functions
36+
def to_text(obj, encoding='utf-8', inplace=False):
    """
    All string in PaddlePaddle should be represented as a literal string.
    This function converts an object, or every item of a list or set, to a
    text string by delegating each value to _to_text.

    In Python3:
        Decode the bytes type object to str type with specific encoding

    In Python2:
        Decode the str type object to unicode type with specific encoding

    Args:
        obj(unicode|str|bytes|list|set) : The object to be decoded.
        encoding(str) : The encoding format to decode a string
        inplace(bool) : If True, a list or set argument is modified and
            returned; otherwise a new container is created. It has no effect
            on non-container arguments.

    Returns:
        Decoded result of obj, or None if obj is None
    """
    if obj is None:
        return obj

    if isinstance(obj, list):
        if inplace:
            # Assigning by index never changes the list's length, so it is
            # safe to do while enumerating.
            for i, item in enumerate(obj):
                obj[i] = _to_text(item, encoding)
            return obj
        return [_to_text(item, encoding) for item in obj]

    if isinstance(obj, set):
        converted = set(_to_text(item, encoding) for item in obj)
        if inplace:
            # A set must not be added to or removed from while it is being
            # iterated (items can be skipped or a RuntimeError raised), so
            # convert everything first and then swap the contents.
            obj.clear()
            obj.update(converted)
            return obj
        return converted

    return _to_text(obj, encoding)
77+
78+
79+
def _to_text(obj, encoding):
    """
    Convert a single object to a text string.

    In Python3:
        Decode the bytes type object to str type with specific encoding

    In Python2:
        Decode the str type object to unicode type with specific encoding,
        or we just return the unicode string of object

    Args:
        obj(unicode|str|bytes) : The object to be decoded.
        encoding(str) : The encoding format

    Returns:
        decoded result of obj, or None if obj is None
    """
    if obj is None:
        return None

    if isinstance(obj, six.binary_type):
        return obj.decode(encoding)

    # Text is returned untouched; anything else is handed to six.u.
    return obj if isinstance(obj, six.text_type) else six.u(obj)
104+
105+
106+
def to_bytes(obj, encoding='utf-8', inplace=False):
    """
    All string in PaddlePaddle should be represented as a literal string.
    This function converts an object, or every item of a list or set, to
    bytes with the given encoding by delegating each value to _to_bytes.

    In Python3:
        Encode the str type object to bytes type with specific encoding

    In Python2:
        Encode the unicode type object to str type with specific encoding,
        or we just return the 8-bit string of object

    Args:
        obj(unicode|str|bytes|list|set) : The object to be encoded.
        encoding(str) : The encoding format to encode a string
        inplace(bool) : If True, a list or set argument is modified and
            returned; otherwise a new container is created. It has no effect
            on non-container arguments.

    Returns:
        Encoded result of obj, or None if obj is None
    """
    if obj is None:
        return obj

    if isinstance(obj, list):
        if inplace:
            # Assigning by index never changes the list's length, so it is
            # safe to do while enumerating.
            for i, item in enumerate(obj):
                obj[i] = _to_bytes(item, encoding)
            return obj
        return [_to_bytes(item, encoding) for item in obj]

    if isinstance(obj, set):
        converted = set(_to_bytes(item, encoding) for item in obj)
        if inplace:
            # A set must not be added to or removed from while it is being
            # iterated (items can be skipped or a RuntimeError raised), so
            # convert everything first and then swap the contents.
            obj.clear()
            obj.update(converted)
            return obj
        return converted

    return _to_bytes(obj, encoding)
148+
149+
150+
def _to_bytes(obj, encoding):
    """
    Convert a single object to bytes.

    In Python3:
        Encode the str type object to bytes type with specific encoding

    In Python2:
        Encode the unicode type object to str type with specific encoding,
        or we just return the 8-bit string of object

    Args:
        obj(unicode|str|bytes) : The object to be encoded.
        encoding(str) : The encoding format, must not be None

    Returns:
        encoded result of obj, or None if obj is None
    """
    if obj is None:
        return None

    assert encoding is not None

    if isinstance(obj, six.text_type):
        return obj.encode(encoding)

    # Bytes are returned untouched; anything else is handed to six.b.
    return obj if isinstance(obj, six.binary_type) else six.b(obj)
176+
177+
178+
# math related functions
179+
def round(x, d=0):
    """
    Compatible round which behaves the same in Python2 and Python3: halfway
    cases are rounded away from zero, as the Python2 builtin does.

    Args:
        x(float) : The number to round halfway.
        d(int) : The number of decimal places to keep, 0 by default.

    Returns:
        round result of x
    """
    if six.PY3:
        # The Python3 builtin rounds halfway cases to the nearest even value,
        # so the Python2 behaviour is re-implemented here according to this
        # answer: https://www.techforgeek.info/round_python.html
        if math.isnan(x) or math.isinf(x):
            # math.floor/math.ceil cannot convert inf or nan to an integer;
            # the Python2 builtin returns such values unchanged.
            return float(x)

        if x > 0.0:
            p = 10**d
            return float(math.floor((x * p) + math.copysign(0.5, x))) / p
        elif x < 0.0:
            p = 10**d
            return float(math.ceil((x * p) + math.copysign(0.5, x))) / p
        else:
            # Keep the sign of a (possibly negative) zero.
            return math.copysign(0.0, x)
    else:
        import __builtin__
        return __builtin__.round(x, d)
203+
204+
205+
def floor_division(x, y):
    """
    Division that gives the same result in Python2 and Python3 by always
    using the floor division operator.

    Args:
        x(int|float) : The number to divide.
        y(int|float) : The number to be divided

    Returns:
        division result of x // y
    """
    quotient = x // y
    return quotient
219+
220+
221+
# exception related functions
222+
def get_exception_message(exc):
    """
    Get the error message of a specific exception

    Args:
        exc(Exception) : The exception to get error message. Must not be None.

    Returns:
        the error message of exc: its `message` attribute in Python2,
        str(exc) otherwise
    """
    assert exc is not None

    return exc.message if six.PY2 else str(exc)

python/paddle/dataset/cifar.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
import numpy
3333
import paddle.dataset.common
3434
import tarfile
35-
from six.moves import zip
35+
import six
3636
from six.moves import cPickle as pickle
3737

3838
__all__ = ['train100', 'test100', 'train10', 'test10', 'convert']
@@ -46,10 +46,11 @@
4646

4747
def reader_creator(filename, sub_name, cycle=False):
4848
def read_batch(batch):
49-
data = batch['data']
50-
labels = batch.get('labels', batch.get('fine_labels', None))
49+
data = batch[six.b('data')]
50+
labels = batch.get(
51+
six.b('labels'), batch.get(six.b('fine_labels'), None))
5152
assert labels is not None
52-
for sample, label in zip(data, labels):
53+
for sample, label in six.moves.zip(data, labels):
5354
yield (sample / 255.0).astype(numpy.float32), int(label)
5455

5556
def reader():
@@ -59,7 +60,11 @@ def reader():
5960

6061
while True:
6162
for name in names:
62-
batch = pickle.load(f.extractfile(name))
63+
if six.PY2:
64+
batch = pickle.load(f.extractfile(name))
65+
else:
66+
batch = pickle.load(
67+
f.extractfile(name), encoding='bytes')
6368
for item in read_batch(batch):
6469
yield item
6570
if not cycle:

python/paddle/dataset/common.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,10 @@ def download(url, module_name, md5sum, save_name=None):
8585
total_length = r.headers.get('content-length')
8686

8787
if total_length is None:
88-
with open(filename, 'w') as f:
88+
with open(filename, 'wb') as f:
8989
shutil.copyfileobj(r.raw, f)
9090
else:
91-
with open(filename, 'w') as f:
91+
with open(filename, 'wb') as f:
9292
dl = 0
9393
total_length = int(total_length)
9494
for data in r.iter_content(chunk_size=4096):

0 commit comments

Comments
 (0)