Skip to content

Commit 58e1b3b

Browse files
authored
Merge pull request #446 from QiJune/format_py_code_2nd
format python code in python directory
2 parents ef5e483 + a1ba3f4 commit 58e1b3b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+3498
-2926
lines changed

python/paddle/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,3 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-

python/paddle/trainer/PyDataProvider2.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,8 @@
1818
import functools
1919
import itertools
2020

21-
logging.basicConfig(
22-
format="[%(levelname)s %(asctime)s %(filename)s:%(lineno)s]"
23-
" %(message)s")
21+
logging.basicConfig(format="[%(levelname)s %(asctime)s %(filename)s:%(lineno)s]"
22+
" %(message)s")
2423

2524

2625
class SequenceType(object):
@@ -132,8 +131,10 @@ def __init__(self, generator, input_order):
132131
def __call__(self, obj, filename):
133132
for item in self.generator(obj, filename):
134133
if isinstance(item, dict):
135-
yield [item.get(input_name, None) for input_name in
136-
self.input_order]
134+
yield [
135+
item.get(input_name, None)
136+
for input_name in self.input_order
137+
]
137138
else:
138139
yield item
139140

@@ -162,8 +163,8 @@ def __call__(self, obj, filename):
162163
yield items
163164
except AssertionError as e:
164165
self.logger.warning(
165-
"Item (%s) is not fit the input type with error %s"
166-
% (repr(item), repr(e)))
166+
"Item (%s) is not fit the input type with error %s" %
167+
(repr(item), repr(e)))
167168

168169
if self.check_fail_continue:
169170
continue
@@ -202,13 +203,17 @@ def loop_check(callback, item):
202203
callback(each)
203204

204205

205-
def provider(input_types=None, should_shuffle=None, pool_size=-1,
206+
def provider(input_types=None,
207+
should_shuffle=None,
208+
pool_size=-1,
206209
min_pool_size=-1,
207210
can_over_batch_size=True,
208211
calc_batch_size=None,
209212
cache=CacheType.NO_CACHE,
210-
check=False, check_fail_continue=False,
211-
init_hook=None, **kwargs):
213+
check=False,
214+
check_fail_continue=False,
215+
init_hook=None,
216+
**kwargs):
212217
"""
213218
Provider decorator. Use it to make a function into PyDataProvider2 object.
214219
In this function, user only need to get each sample for some train/test
@@ -318,9 +323,9 @@ def __init__(self, file_list, **kwargs):
318323
"Could not recognize should_shuffle (%s), "
319324
"just use default value of should_shuffle."
320325
" Please set should_shuffle to bool value or "
321-
"something in %s" % (
322-
repr(self.should_shuffle),
323-
repr(true_table + false_table)))
326+
"something in %s" %
327+
(repr(self.should_shuffle),
328+
repr(true_table + false_table)))
324329
self.should_shuffle = None
325330

326331
self.pool_size = pool_size
@@ -351,8 +356,7 @@ def __init__(self, file_list, **kwargs):
351356
self.generator = InputOrderWrapper(self.generator,
352357
self.input_order)
353358
if self.check:
354-
self.generator = CheckWrapper(self.generator,
355-
self.slots,
359+
self.generator = CheckWrapper(self.generator, self.slots,
356360
check_fail_continue,
357361
self.logger)
358362

@@ -368,4 +372,3 @@ def deserialize_args(args):
368372
:return:
369373
"""
370374
return cPickle.loads(args)
371-

python/paddle/trainer/PyDataProviderWrapper.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
1514
"""
1615
This module provide a wrapper(decorator) to wrap a data process method into a
1716
PyDataProvider. Some examples are shown `here <data_provider/python_case.html>`_.
@@ -47,6 +46,7 @@
4746

4847
import io
4948

49+
5050
class SlotType(object): # Just a hint for user.
5151
pass
5252

@@ -83,6 +83,7 @@ class SparseNonValueSlot(SlotType):
8383
- **SubSeq**: [[[int, int, ...], [int, ....], ...] , \
8484
[[int, int, ...], [int, ....], ...] , ...]
8585
"""
86+
8687
def __init__(self, dim):
8788
"""
8889
:param dim: slot dimension
@@ -294,8 +295,9 @@ def reset(self):
294295
fn = "%s_%d" % (self.profile_filename, self.profile_count)
295296
sortby = "cumulative"
296297
with open(fn, "w") as f:
297-
pstats.Stats(self.profiler, stream=f).sort_stats(
298-
sortby).print_stats()
298+
pstats.Stats(
299+
self.profiler,
300+
stream=f).sort_stats(sortby).print_stats()
299301
self.logger.info("saving profile to file %s" % fn)
300302
self.profile_count += 1
301303
self.logger.info("resetting profile")
@@ -453,9 +455,10 @@ def writeDataStream(dat, data_callback):
453455
seq_stream.flush()
454456
subseq_stream.flush()
455457

456-
return "".join([self.int_packer.pack(current_batch_size),
457-
data_bytes.getvalue(),
458-
seq_bytes.getvalue(), subseq_bytes.getvalue()])
458+
return "".join([
459+
self.int_packer.pack(current_batch_size), data_bytes.getvalue(),
460+
seq_bytes.getvalue(), subseq_bytes.getvalue()
461+
])
459462

460463
finally:
461464
data_stream.close()
@@ -516,7 +519,7 @@ def __prepareData(self, batch_size, ret_list):
516519
self.data_pool[idx])
517520
idx -= 1
518521

519-
ret_list += self.data_pool[self.data_pool_idx: idx + 1]
522+
ret_list += self.data_pool[self.data_pool_idx:idx + 1]
520523

521524
# for speed reason, just shift left index, not delete data actually.
522525
self.data_pool_idx = idx + 1
@@ -537,8 +540,8 @@ def fillPool(self):
537540
if self.max_pool_size == 0:
538541
for i in xrange(min(self.file_count, len(self.generators))):
539542
self.data_pool += list(self.generators[i])
540-
self.generators = self.generators[
541-
min(self.file_count, len(self.generators)):]
543+
self.generators = self.generators[min(self.file_count,
544+
len(self.generators)):]
542545
self.max_pool_size = len(self.data_pool)
543546
else:
544547
while len(self.data_pool) < self.max_pool_size and len(
@@ -562,9 +565,15 @@ def default_init_hook(cls, *args, **kwargs):
562565
del cls, args, kwargs
563566

564567

565-
def provider(slots=None, use_seq=False, should_shuffle=True, pool_size=1,
566-
can_over_batch_size=True, calc_batch_size=lambda data: 1,
567-
debug=False, init_hook=default_init_hook, profile_filename=None):
568+
def provider(slots=None,
569+
use_seq=False,
570+
should_shuffle=True,
571+
pool_size=1,
572+
can_over_batch_size=True,
573+
calc_batch_size=lambda data: 1,
574+
debug=False,
575+
init_hook=default_init_hook,
576+
profile_filename=None):
568577
"""
569578
The decorator for PyDataProvider. User should use this to create Provider class.
570579
User should only concern how to read sample from file.
@@ -663,7 +672,7 @@ class Cls(GeneralPyDataProvider):
663672
def __init__(self, *file_list, **kwargs):
664673
logging.basicConfig(
665674
format="[%(levelname)s %(asctime)s %(filename)s:%(lineno)s]"
666-
" %(message)s")
675+
" %(message)s")
667676

668677
self.logger = logging.getLogger("")
669678
if debug:

python/paddle/trainer/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,3 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-

0 commit comments

Comments
 (0)