Skip to content

format python code in python directory #446

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 12, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

35 changes: 19 additions & 16 deletions python/paddle/trainer/PyDataProvider2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@
import functools
import itertools

logging.basicConfig(
format="[%(levelname)s %(asctime)s %(filename)s:%(lineno)s]"
" %(message)s")
logging.basicConfig(format="[%(levelname)s %(asctime)s %(filename)s:%(lineno)s]"
" %(message)s")


class SequenceType(object):
Expand Down Expand Up @@ -132,8 +131,10 @@ def __init__(self, generator, input_order):
def __call__(self, obj, filename):
for item in self.generator(obj, filename):
if isinstance(item, dict):
yield [item.get(input_name, None) for input_name in
self.input_order]
yield [
item.get(input_name, None)
for input_name in self.input_order
]
else:
yield item

Expand Down Expand Up @@ -162,8 +163,8 @@ def __call__(self, obj, filename):
yield items
except AssertionError as e:
self.logger.warning(
"Item (%s) is not fit the input type with error %s"
% (repr(item), repr(e)))
"Item (%s) is not fit the input type with error %s" %
(repr(item), repr(e)))

if self.check_fail_continue:
continue
Expand Down Expand Up @@ -202,13 +203,17 @@ def loop_check(callback, item):
callback(each)


def provider(input_types=None, should_shuffle=None, pool_size=-1,
def provider(input_types=None,
should_shuffle=None,
pool_size=-1,
min_pool_size=-1,
can_over_batch_size=True,
calc_batch_size=None,
cache=CacheType.NO_CACHE,
check=False, check_fail_continue=False,
init_hook=None, **kwargs):
check=False,
check_fail_continue=False,
init_hook=None,
**kwargs):
"""
Provider decorator. Use it to turn a function into a PyDataProvider2 object.
In this function, the user only needs to get each sample for some train/test
Expand Down Expand Up @@ -318,9 +323,9 @@ def __init__(self, file_list, **kwargs):
"Could not recognize should_shuffle (%s), "
"just use default value of should_shuffle."
" Please set should_shuffle to bool value or "
"something in %s" % (
repr(self.should_shuffle),
repr(true_table + false_table)))
"something in %s" %
(repr(self.should_shuffle),
repr(true_table + false_table)))
self.should_shuffle = None

self.pool_size = pool_size
Expand Down Expand Up @@ -351,8 +356,7 @@ def __init__(self, file_list, **kwargs):
self.generator = InputOrderWrapper(self.generator,
self.input_order)
if self.check:
self.generator = CheckWrapper(self.generator,
self.slots,
self.generator = CheckWrapper(self.generator, self.slots,
check_fail_continue,
self.logger)

Expand All @@ -368,4 +372,3 @@ def deserialize_args(args):
:return:
"""
return cPickle.loads(args)

35 changes: 22 additions & 13 deletions python/paddle/trainer/PyDataProviderWrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module provides a wrapper (decorator) that wraps a data-processing method into a
PyDataProvider. Some examples are shown `here <data_provider/python_case.html>`_.
Expand Down Expand Up @@ -47,6 +46,7 @@

import io


class SlotType(object): # Just a hint for user.
pass

Expand Down Expand Up @@ -83,6 +83,7 @@ class SparseNonValueSlot(SlotType):
- **SubSeq**: [[[int, int, ...], [int, ....], ...] , \
[[int, int, ...], [int, ....], ...] , ...]
"""

def __init__(self, dim):
"""
:param dim: slot dimension
Expand Down Expand Up @@ -294,8 +295,9 @@ def reset(self):
fn = "%s_%d" % (self.profile_filename, self.profile_count)
sortby = "cumulative"
with open(fn, "w") as f:
pstats.Stats(self.profiler, stream=f).sort_stats(
sortby).print_stats()
pstats.Stats(
self.profiler,
stream=f).sort_stats(sortby).print_stats()
self.logger.info("saving profile to file %s" % fn)
self.profile_count += 1
self.logger.info("resetting profile")
Expand Down Expand Up @@ -453,9 +455,10 @@ def writeDataStream(dat, data_callback):
seq_stream.flush()
subseq_stream.flush()

return "".join([self.int_packer.pack(current_batch_size),
data_bytes.getvalue(),
seq_bytes.getvalue(), subseq_bytes.getvalue()])
return "".join([
self.int_packer.pack(current_batch_size), data_bytes.getvalue(),
seq_bytes.getvalue(), subseq_bytes.getvalue()
])

finally:
data_stream.close()
Expand Down Expand Up @@ -516,7 +519,7 @@ def __prepareData(self, batch_size, ret_list):
self.data_pool[idx])
idx -= 1

ret_list += self.data_pool[self.data_pool_idx: idx + 1]
ret_list += self.data_pool[self.data_pool_idx:idx + 1]

# For speed, just advance the left index rather than actually deleting the data.
self.data_pool_idx = idx + 1
Expand All @@ -537,8 +540,8 @@ def fillPool(self):
if self.max_pool_size == 0:
for i in xrange(min(self.file_count, len(self.generators))):
self.data_pool += list(self.generators[i])
self.generators = self.generators[
min(self.file_count, len(self.generators)):]
self.generators = self.generators[min(self.file_count,
len(self.generators)):]
self.max_pool_size = len(self.data_pool)
else:
while len(self.data_pool) < self.max_pool_size and len(
Expand All @@ -562,9 +565,15 @@ def default_init_hook(cls, *args, **kwargs):
del cls, args, kwargs


def provider(slots=None, use_seq=False, should_shuffle=True, pool_size=1,
can_over_batch_size=True, calc_batch_size=lambda data: 1,
debug=False, init_hook=default_init_hook, profile_filename=None):
def provider(slots=None,
use_seq=False,
should_shuffle=True,
pool_size=1,
can_over_batch_size=True,
calc_batch_size=lambda data: 1,
debug=False,
init_hook=default_init_hook,
profile_filename=None):
"""
The decorator for PyDataProvider. Users should use this to create a Provider class.
Users only need to be concerned with how to read samples from a file.
Expand Down Expand Up @@ -663,7 +672,7 @@ class Cls(GeneralPyDataProvider):
def __init__(self, *file_list, **kwargs):
logging.basicConfig(
format="[%(levelname)s %(asctime)s %(filename)s:%(lineno)s]"
" %(message)s")
" %(message)s")

self.logger = logging.getLogger("")
if debug:
Expand Down
1 change: 0 additions & 1 deletion python/paddle/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Loading