Skip to content

format python code in python directory #424

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

44 changes: 26 additions & 18 deletions python/paddle/trainer/PyDataProvider2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@
import functools
import itertools

logging.basicConfig(
format="[%(levelname)s %(asctime)s %(filename)s:%(lineno)s]"
" %(message)s")
logging.basicConfig(format="[%(levelname)s %(asctime)s %(filename)s:%(lineno)s]"
" %(message)s")


class SequenceType(object):
Expand All @@ -38,7 +37,7 @@ class DataType(object):


class CacheType(object):
NO_CACHE = 0 # No cache at all
NO_CACHE = 0 # No cache at all

# On the first pass, read data from Python and store it in memory. Read from
# memory during the remaining passes.
Expand Down Expand Up @@ -113,6 +112,7 @@ def integer_sequence(dim):


class SingleSlotWrapper(object):

def __init__(self, generator):
self.generator = generator

Expand All @@ -125,20 +125,24 @@ def __call__(self, obj, filename):


class InputOrderWrapper(object):

def __init__(self, generator, input_order):
self.generator = generator
self.input_order = input_order

def __call__(self, obj, filename):
for item in self.generator(obj, filename):
if isinstance(item, dict):
yield [item.get(input_name, None) for input_name in
self.input_order]
yield [
item.get(input_name, None)
for input_name in self.input_order
]
else:
yield item


class CheckWrapper(object):

def __init__(self, generator, input_types, check_fail_continue, logger):
self.generator = generator
self.input_types = input_types
Expand All @@ -162,8 +166,8 @@ def __call__(self, obj, filename):
yield items
except AssertionError as e:
self.logger.warning(
"Item (%s) is not fit the input type with error %s"
% (repr(item), repr(e)))
"Item (%s) is not fit the input type with error %s" %
(repr(item), repr(e)))

if self.check_fail_continue:
continue
Expand Down Expand Up @@ -202,13 +206,17 @@ def loop_check(callback, item):
callback(each)


def provider(input_types=None, should_shuffle=None, pool_size=-1,
def provider(input_types=None,
should_shuffle=None,
pool_size=-1,
min_pool_size=-1,
can_over_batch_size=True,
calc_batch_size=None,
cache=CacheType.NO_CACHE,
check=False, check_fail_continue=False,
init_hook=None, **kwargs):
check=False,
check_fail_continue=False,
init_hook=None,
**kwargs):
"""
Provider decorator. Use it to turn a function into a PyDataProvider2 object.
In this function, the user only needs to get each sample for some train/test
Expand Down Expand Up @@ -289,7 +297,9 @@ def process(settings, file_name):
"""

def __wrapper__(generator):

class DataProvider(object):

def __init__(self, file_list, **kwargs):
self.logger = logging.getLogger("")
self.logger.setLevel(logging.INFO)
Expand Down Expand Up @@ -318,9 +328,9 @@ def __init__(self, file_list, **kwargs):
"Could not recognize should_shuffle (%s), "
"just use default value of should_shuffle."
" Please set should_shuffle to bool value or "
"something in %s" % (
repr(self.should_shuffle),
repr(true_table + false_table)))
"something in %s" %
(repr(self.should_shuffle),
repr(true_table + false_table)))
self.should_shuffle = None

self.pool_size = pool_size
Expand All @@ -340,7 +350,7 @@ def __init__(self, file_list, **kwargs):
assert self.generator is not None

use_dynamic_order = False
if isinstance(self.slots, dict): # reorder input_types
if isinstance(self.slots, dict): # reorder input_types
self.slots = [self.slots[ipt] for ipt in self.input_order]
use_dynamic_order = True

Expand All @@ -351,8 +361,7 @@ def __init__(self, file_list, **kwargs):
self.generator = InputOrderWrapper(self.generator,
self.input_order)
if self.check:
self.generator = CheckWrapper(self.generator,
self.slots,
self.generator = CheckWrapper(self.generator, self.slots,
check_fail_continue,
self.logger)

Expand All @@ -368,4 +377,3 @@ def deserialize_args(args):
:return:
"""
return cPickle.loads(args)

61 changes: 36 additions & 25 deletions python/paddle/trainer/PyDataProviderWrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module provides a wrapper (decorator) that wraps a data processing method into a
PyDataProvider. Some examples are shown `here <data_provider/python_case.html>`_.
Expand All @@ -33,7 +32,7 @@
'provider', 'init_hook_wrapper'
]

try: # Just for profile mode, will try to import cProfile first.
try: # Just for profile mode, will try to import cProfile first.
# Most Python installations contain cProfile; cProfile and profile are basically the same.
# ref: https://docs.python.org/2/library/profile.html#introduction-to-the-profilers
import cProfile as profile
Expand All @@ -47,7 +46,8 @@

import io

class SlotType(object): # Just a hint for user.

class SlotType(object): # Just a hint for user.
pass


Expand Down Expand Up @@ -83,6 +83,7 @@ class SparseNonValueSlot(SlotType):
- **SubSeq**: [[[int, int, ...], [int, ....], ...] , \
[[int, int, ...], [int, ....], ...] , ...]
"""

def __init__(self, dim):
"""
:param dim: slot dimension
Expand Down Expand Up @@ -234,11 +235,12 @@ def __call__(self, ele):


class GeneralPyDataProvider:

def __init__(self, *file_list, **kwargs):
"""
:param file_list: input file_list
"""
del kwargs # unused
del kwargs # unused
gc.disable()
assert isinstance(self.logger, logging.Logger)
self.use_seq_flag = hasattr(self, "use_seq_flag") and self.use_seq_flag
Expand Down Expand Up @@ -294,8 +296,9 @@ def reset(self):
fn = "%s_%d" % (self.profile_filename, self.profile_count)
sortby = "cumulative"
with open(fn, "w") as f:
pstats.Stats(self.profiler, stream=f).sort_stats(
sortby).print_stats()
pstats.Stats(
self.profiler,
stream=f).sort_stats(sortby).print_stats()
self.logger.info("saving profile to file %s" % fn)
self.profile_count += 1
self.logger.info("resetting profile")
Expand Down Expand Up @@ -384,7 +387,7 @@ def convertDataImpl(idx, data_callback):
slot_sample_num = len(ret_list)
if self.use_seq_flag:
slot_sample_num = 0
if self.has_subseq[idx]: # has sub-sequence
if self.has_subseq[idx]: # has sub-sequence
slot_subseq_num = 0
for dat in ret_list:
dat = dat[idx]
Expand All @@ -403,7 +406,7 @@ def convertDataImpl(idx, data_callback):
dat = dat[idx]
if self.use_seq_flag:
seq_stream.write(self.int_packer.pack(indices))
if self.has_subseq[idx]: # has sub-sequence
if self.has_subseq[idx]: # has sub-sequence
for sub_dat in dat:
writeDataStream(sub_dat, data_callback)
subseq_stream.write(self.int_packer.pack(indices))
Expand All @@ -416,13 +419,13 @@ def convertDataImpl(idx, data_callback):

def writeDataStream(dat, data_callback):
if self.use_seq_flag > 0:
if data_callback is None: # Special for index slot
if data_callback is None: # Special for index slot
data_stream.write(array.array("i", dat).tostring())
else:
for ele in dat:
data_callback(ele)
else:
if data_callback is None: # Special for index slot
if data_callback is None: # Special for index slot
data_stream.write(self.int_packer.pack(dat))
else:
data_callback(dat)
Expand Down Expand Up @@ -453,9 +456,10 @@ def writeDataStream(dat, data_callback):
seq_stream.flush()
subseq_stream.flush()

return "".join([self.int_packer.pack(current_batch_size),
data_bytes.getvalue(),
seq_bytes.getvalue(), subseq_bytes.getvalue()])
return "".join([
self.int_packer.pack(current_batch_size), data_bytes.getvalue(),
seq_bytes.getvalue(), subseq_bytes.getvalue()
])

finally:
data_stream.close()
Expand All @@ -475,12 +479,12 @@ def hasSubseq(self, ret_list):
dat = ret_list[0][i][0]
if isinstance(slot, IndexSlot) or isinstance(slot, StringSlot):
if isinstance(dat, list) or isinstance(dat, numpy.ndarray):
self.has_subseq.append(1) # has_subseq = True
self.has_subseq.append(1) # has_subseq = True
continue
elif isinstance(dat[0], list) or isinstance(dat[0], numpy.ndarray):
self.has_subseq.append(1) # has_subseq = True
self.has_subseq.append(1) # has_subseq = True
continue
self.has_subseq.append(0) # has_subseq = False
self.has_subseq.append(0) # has_subseq = False

def checkOrder(self):
first_noSubseq_slot = self.slots_num
Expand Down Expand Up @@ -511,12 +515,12 @@ def __prepareData(self, batch_size, ret_list):
if current_batch_size >= batch_size:
could_exit = True
break
if current_batch_size > batch_size and not self.can_over_batch_size: # if cannot over batch size
if current_batch_size > batch_size and not self.can_over_batch_size: # if cannot over batch size
current_batch_size -= self.calculateDataBatchSize(
self.data_pool[idx])
idx -= 1

ret_list += self.data_pool[self.data_pool_idx: idx + 1]
ret_list += self.data_pool[self.data_pool_idx:idx + 1]

# for speed reason, just shift left index, not delete data actually.
self.data_pool_idx = idx + 1
Expand All @@ -525,7 +529,7 @@ def __prepareData(self, batch_size, ret_list):
self.data_pool = []
else:
break
if self.use_seq_flag and not self.has_checked: # compute self.has_subseq and checkOrder only at first time
if self.use_seq_flag and not self.has_checked: # compute self.has_subseq and checkOrder only at first time
self.hasSubseq(ret_list)
self.checkOrder()
return current_batch_size
Expand All @@ -537,8 +541,8 @@ def fillPool(self):
if self.max_pool_size == 0:
for i in xrange(min(self.file_count, len(self.generators))):
self.data_pool += list(self.generators[i])
self.generators = self.generators[
min(self.file_count, len(self.generators)):]
self.generators = self.generators[min(self.file_count,
len(self.generators)):]
self.max_pool_size = len(self.data_pool)
else:
while len(self.data_pool) < self.max_pool_size and len(
Expand All @@ -562,9 +566,15 @@ def default_init_hook(cls, *args, **kwargs):
del cls, args, kwargs


def provider(slots=None, use_seq=False, should_shuffle=True, pool_size=1,
can_over_batch_size=True, calc_batch_size=lambda data: 1,
debug=False, init_hook=default_init_hook, profile_filename=None):
def provider(slots=None,
use_seq=False,
should_shuffle=True,
pool_size=1,
can_over_batch_size=True,
calc_batch_size=lambda data: 1,
debug=False,
init_hook=default_init_hook,
profile_filename=None):
"""
The decorator for PyDataProvider. Users should use this to create a Provider class.
Users only need to be concerned with how to read samples from a file.
Expand Down Expand Up @@ -657,13 +667,14 @@ def process(obj, file_name):
"""

def _wrapper(handler):

class Cls(GeneralPyDataProvider):
""" Real PyDataProvider Class. """

def __init__(self, *file_list, **kwargs):
logging.basicConfig(
format="[%(levelname)s %(asctime)s %(filename)s:%(lineno)s]"
" %(message)s")
" %(message)s")

self.logger = logging.getLogger("")
if debug:
Expand Down
1 change: 0 additions & 1 deletion python/paddle/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Loading