From 95ea5abf4cacfe6c58f85488e5ca04ace156332c Mon Sep 17 00:00:00 2001
From: dangqingqing
Date: Mon, 13 Feb 2017 20:07:55 +0800
Subject: [PATCH] Data reader for API

---
 demo/mnist/reader.py     | 55 ++++++++++++++++++++++++
 python/paddle/v2/data.py | 96 +++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 151 insertions(+)
 create mode 100644 demo/mnist/reader.py
 create mode 100644 python/paddle/v2/data.py

diff --git a/demo/mnist/reader.py b/demo/mnist/reader.py
new file mode 100644
index 00000000000000..ae2035aad27bc5
--- /dev/null
+++ b/demo/mnist/reader.py
@@ -0,0 +1,55 @@
+import os
+import struct
+import numpy as np
+import paddle.v2 as paddle
+
+
+def load_data(filename, data_dir='./data/raw_data/'):
+    # Parse the raw MNIST idx files and scale pixels to [-1.0, 1.0].
+    image = '-images-idx3-ubyte'
+    label = '-labels-idx1-ubyte'
+    if filename == 'train':
+        image_file = os.path.join(data_dir, filename + image)
+        label_file = os.path.join(data_dir, filename + label)
+    else:
+        image_file = os.path.join(data_dir, 't10k' + image)
+        label_file = os.path.join(data_dir, 't10k' + label)
+
+    with open(image_file, "rb") as f:
+        # Header: magic number, image count, rows and columns (big-endian).
+        num_magic, n, num_row, num_col = struct.unpack(">IIII", f.read(16))
+        images = np.fromfile(f, 'ubyte', count=n * num_row * num_col).\
+            reshape(n, num_row * num_col).astype('float32')
+        images = images / 255.0 * 2.0 - 1.0
+
+    with open(label_file, "rb") as fn:
+        # Header: magic number and label count (big-endian).
+        num_magic, num_label = struct.unpack(">II", fn.read(8))
+        labels = np.fromfile(fn, 'ubyte', count=num_label).astype('int32')
+
+    return images, labels
+
+
+def data(images, labels):
+    # Yield one {input_name: value} dict per sample.
+    for i in xrange(len(labels)):
+        yield {"pixel": images[i, :], 'label': labels[i]}
+
+
+def main():
+    train_images, train_label = load_data('train')
+    train_gen = data(train_images, train_label)
+    train_data = paddle.data.CacheAllDataPool(train_gen, 128,
+                                              ['pixel', 'label'])
+
+    test_images, test_label = load_data('test')
+    test_gen = data(test_images[0:128], test_label[0:128])
+    test_data = paddle.data.CacheAllDataPool(test_gen, 128, ['pixel', 'label'],
+                                             shuffle=False)
+
+    for data_batch in test_data:
+        print data_batch
+
+
+if __name__ == "__main__":
+    main()
diff --git a/python/paddle/v2/data.py b/python/paddle/v2/data.py
new file mode 100644
index 00000000000000..a1b71e57b1f7c8
--- /dev/null
+++ b/python/paddle/v2/data.py
@@ -0,0 +1,96 @@
+import collections
+import random
+
+__all__ = [
+    'IDataPool',
+    'CacheAllDataPool',
+]
+
+
+class IDataPool(object):
+    """
+    Interface of DataPool. Note that Python uses duck typing, so it is not
+    necessary to inherit from this interface.
+
+    NOTE: For Paddle developers, NEVER CHECK isinstance(obj, IDataPool).
+
+    It basically contains two methods:
+
+    * next(): Return the next batch of data in the pool. Raise StopIteration
+              when there is no more data in the pool.
+
+    * reset(): Reset the data pool to its initial status.
+
+    The basic usage of this API is the same as a normal Python iterator:
+
+    .. code-block:: python
+
+        pool = DataPool()
+
+        for batch in pool:
+            process_batch(batch)
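+
+    Batches can also be pulled by hand; reset() rewinds the pool and, when
+    shuffle is enabled, reshuffles it. A usage sketch (the names below are
+    illustrative only):
+
+    .. code-block:: python
+
+        pool.reset()
+        first_batch = pool.next()  # raises StopIteration once exhausted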
+
+    NOTE: The DataPool API is not thread-safe.
+    """
+
+    def __iter__(self):
+        self.reset()
+        return self
+
+    def next(self):
+        raise NotImplementedError()
+
+    def __next__(self):
+        return self.next()
+
+    def reset(self):
+        raise NotImplementedError()
+
+
+def input_order_mapper(iterable, input_order):
+    assert isinstance(input_order, collections.Sequence)
+    for each_input_name in input_order:
+        assert isinstance(each_input_name, basestring)
+
+    for each_item in iterable:
+        # Build a fresh list per sample; reusing a single buffer here would
+        # make every cached sample alias the same object.
+        yield [each_item[name] for name in input_order]
+
+
+class CacheAllDataPool(IDataPool):
+    """
+    Load all samples into memory and serve them batch by batch, reshuffling
+    the pool on every reset() when shuffle is enabled.
+    """
+
+    def __init__(self, iterable, batch_size, input_order, shuffle=True):
+        self.__pool__ = list(
+            input_order_mapper(
+                iterable=iterable, input_order=input_order))
+        self.__batch_size__ = batch_size
+        self.__shuffle__ = shuffle
+        self.__idx__ = 0
+
+    def reset(self):
+        self.__idx__ = 0
+        if self.__shuffle__:
+            random.shuffle(self.__pool__)
+
+    def next(self):
+        if self.__idx__ >= len(self.__pool__):
+            raise StopIteration()
+
+        begin = self.__idx__
+        end = min(self.__idx__ + self.__batch_size__, len(self.__pool__))
+        self.__idx__ = end
+        return self.__pool__[begin:end]
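
Below is a minimal smoke test for the pool added in this patch. It is a
sketch that assumes the patch is applied and python/paddle/v2/data.py is
importable; the synthetic sample dicts and the names fake_samples and pool
are illustrative only, not part of the patch.

.. code-block:: python

    import paddle.v2.data as data


    def fake_samples():
        # Stand-ins for real MNIST records, keyed by input name.
        for i in xrange(10):
            yield {'pixel': [i, i], 'label': i}


    pool = data.CacheAllDataPool(fake_samples(), 4, ['pixel', 'label'],
                                 shuffle=False)

    # __iter__() calls reset(), so each for-loop walks the whole pool again.
    for epoch in xrange(2):
        for batch in pool:
            # 10 samples at batch_size 4 give batches of 4, 4 and 2 samples;
            # each sample is a [pixel, label] list ordered by input_order.
            print len(batch), batch[0]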
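The IDataPool docstring stresses duck typing, so a pool only has to provide
__iter__/next/reset. The sketch below, a hypothetical class that is not part
of this patch, streams batches from a generator factory instead of caching
everything in memory:

.. code-block:: python

    import itertools


    class GeneratorDataPool(object):
        """A pool that re-creates its generator on every reset()."""

        def __init__(self, generator_factory, batch_size):
            self.__factory__ = generator_factory
            self.__batch_size__ = batch_size
            self.__generator__ = None

        def __iter__(self):
            self.reset()
            return self

        def reset(self):
            # Restart the stream instead of rewinding a cached list.
            self.__generator__ = self.__factory__()

        def next(self):
            batch = list(
                itertools.islice(self.__generator__, self.__batch_size__))
            if not batch:
                raise StopIteration()
            return batch

        __next__ = next

Because reset() rebuilds the generator, shuffling is left to the factory;
CacheAllDataPool instead trades memory for the ability to shuffle its cached
list in place.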