-
Notifications
You must be signed in to change notification settings - Fork 5.7k
use pyreader to read data in dygraph mode #17314
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
wopeizl
merged 18 commits into
PaddlePaddle:develop
from
wopeizl:use_pyreader_in_dygraph
Jun 5, 2019
Merged
Changes from 14 commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
d002f31
use pyreader to read data
wopeizl 239973a
correct to the right flower dataset
wopeizl 3ab3e0b
add return_list to PyReader to support return value represented as list
wopeizl 9526933
Merge remote-tracking branch 'upstream/develop' into use_pyreader_in_…
wopeizl f758d65
Merge remote-tracking branch 'upstream/develop' into use_pyreader_in_…
wopeizl 17135a7
test=develop
wopeizl 29dfa74
test=develop
wopeizl 59cd8bd
add extra test case for pyreader
wopeizl ac8a6ff
test=develop
wopeizl e29d936
sync to latest
wopeizl f5f3e4f
test=develop
wopeizl a01bc62
Merge branch 'use_pyreader_in_dygraph' of https://github.com/wopeizl/…
wopeizl 3ecb635
test=develop
wopeizl a24921a
test=develop
wopeizl 54f584e
test=develop
wopeizl c260376
test=develop
wopeizl 361506d
test=develop
wopeizl d7ca993
test=develop
wopeizl File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,12 +12,13 @@ | |
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from . import core | ||
from . import core, dygraph | ||
import six | ||
import numpy as np | ||
import threading | ||
from .framework import Program, Variable, program_guard, default_main_program, default_startup_program | ||
from .framework import Program, Variable, program_guard, default_main_program, default_startup_program, in_dygraph_mode | ||
from .executor import global_scope | ||
from .data_feeder import DataFeeder, BatchedTensorProvider | ||
from .data_feeder import DataFeeder, BatchedTensorProvider, ListTensorProvider | ||
from .layers.io import monkey_patch_reader_methods, _copy_reader_var_, double_buffer | ||
from .unique_name import UniqueNameGenerator | ||
|
||
|
@@ -48,12 +49,13 @@ class PyReader(object): | |
|
||
Args: | ||
feed_list (list(Variable)|tuple(Variable)): feed variable list. | ||
The variables should be created by :code:`fluid.layers.data()`. | ||
The variables should be created by :code:`fluid.layers.data()`. | ||
it can be None under iterable mode. | ||
capacity (int): capacity of the queue maintained in PyReader object. | ||
use_double_buffer (bool): whether to use double_buffer_reader to | ||
speed up data feeding. | ||
iterable (bool): whether the created reader object is iterable. | ||
|
||
return_list (bool): whether the return value presented as list. | ||
Returns: | ||
reader (Reader): the created reader object. | ||
|
||
|
@@ -124,7 +126,7 @@ def reader(): | |
return reader | ||
|
||
image = fluid.layers.data(name='image', shape=[784, 784], dtype='float32') | ||
reader = fluid.io.PyReader(feed_list=[image], capacity=4, iterable=True) | ||
reader = fluid.io.PyReader(feed_list=[image], capacity=4, iterable=True, return_list=False) | ||
|
||
user_defined_reader = reader_creator_random_image(784, 784) | ||
reader.decorate_sample_list_generator( | ||
|
@@ -138,26 +140,71 @@ def reader(): | |
for data in reader(): | ||
executor.run(feed=data) | ||
|
||
|
||
3. If return_list=True, the return values would be presented as list instead of dict`. | ||
|
||
.. code-block:: python | ||
|
||
import paddle | ||
import paddle.fluid as fluid | ||
import numpy as np | ||
|
||
EPOCH_NUM = 3 | ||
ITER_NUM = 5 | ||
BATCH_SIZE = 10 | ||
|
||
def reader_creator_random_image(height, width): | ||
def reader(): | ||
for i in range(ITER_NUM): | ||
yield np.random.uniform(low=0, high=255, size=[height, width]), | ||
return reader | ||
|
||
image = fluid.layers.data(name='image', shape=[784, 784], dtype='float32') | ||
reader = fluid.io.PyReader(feed_list=[image], capacity=4, iterable=True, return_list=True) | ||
|
||
user_defined_reader = reader_creator_random_image(784, 784) | ||
reader.decorate_sample_list_generator( | ||
paddle.batch(user_defined_reader, batch_size=BATCH_SIZE), | ||
fluid.core.CPUPlace()) | ||
# definition of network is omitted | ||
executor = fluid.Executor(fluid.core.CPUPlace()) | ||
executor.run(fluid.default_main_program()) | ||
|
||
for _ in range(EPOCH_NUM): | ||
for data in reader(): | ||
executor.run(feed={"image": data[0]}) | ||
""" | ||
|
||
unique_name_generator = UniqueNameGenerator() | ||
|
||
def __init__(self, | ||
feed_list, | ||
capacity, | ||
feed_list=None, | ||
capacity=1, | ||
use_double_buffer=True, | ||
iterable=False): | ||
iterable=False, | ||
return_list=False): | ||
wopeizl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self._tensor_reader = None | ||
self._thread = None | ||
self._iterable = iterable | ||
self._feed_list = feed_list | ||
# force to use iterable mode under dygraph mode | ||
if in_dygraph_mode(): | ||
self._iterable = True | ||
wopeizl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self._return_list = True | ||
else: | ||
self._iterable = iterable | ||
self._return_list = return_list | ||
if not self._feed_list: | ||
raise Exception("Feed list must be given under static mode.") | ||
wopeizl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self._use_double_buffer = use_double_buffer | ||
self._capacity = capacity | ||
self._feed_list = feed_list | ||
if not self._iterable: | ||
self._init_non_iterable() | ||
|
||
def _init_iterable(self, places): | ||
self._var_names = [v.name for v in self._feed_list] | ||
if in_dygraph_mode(): | ||
self._var_names = [] | ||
else: | ||
self._var_names = [v.name for v in self._feed_list] | ||
self._places = _convert_places(places) | ||
self._queue = core.init_lod_tensor_blocking_queue(core.Variable(), | ||
self._capacity) | ||
|
@@ -240,6 +287,7 @@ class Iterator(object): | |
def __init__(self, reader): | ||
self._reader = reader._reader | ||
self._reset = reader._reset | ||
self._return_list = reader._return_list | ||
|
||
def __iter__(self): | ||
return self | ||
|
@@ -248,12 +296,29 @@ def __next__(self): | |
return self.next() | ||
|
||
def next(self): | ||
ret = self._reader.read_next() | ||
if ret: | ||
return ret | ||
if not in_dygraph_mode(): | ||
ret = None | ||
wopeizl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if self._return_list: | ||
ret = self._reader.read_next_list() | ||
ret = ret[0] if ret is not None and len( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What if multi-card training is enabled? |
||
ret) > 0 else None | ||
else: | ||
ret = self._reader.read_next() | ||
if ret: | ||
return ret | ||
else: | ||
self._reset() | ||
raise StopIteration | ||
else: | ||
self._reset() | ||
raise StopIteration | ||
ret = self._reader.read_next_list() | ||
if ret and ret[0]: | ||
return [ | ||
dygraph.base.to_variable(np.array(v)) | ||
for v in ret[0] | ||
] | ||
else: | ||
self._reset() | ||
raise StopIteration | ||
|
||
self._start() | ||
return Iterator(self) | ||
|
@@ -293,8 +358,9 @@ def generator(): | |
break | ||
|
||
''' | ||
assert not self._iterable, "start() cannot be called when PyReader is iterable" | ||
self._start() | ||
if not in_dygraph_mode(): | ||
assert not self._iterable, "start() cannot be called when PyReader is iterable" | ||
self._start() | ||
|
||
def reset(self): | ||
''' | ||
|
@@ -327,8 +393,9 @@ def generator(): | |
break | ||
|
||
''' | ||
assert not self._iterable, "reset() cannot be called when PyReader is iterable" | ||
self._reset() | ||
if not in_dygraph_mode(): | ||
assert not self._iterable, "reset() cannot be called when PyReader is iterable" | ||
self._reset() | ||
|
||
def _start(self): | ||
def __thread_main__(): | ||
|
@@ -488,14 +555,22 @@ def generator(): | |
''' | ||
assert self._tensor_reader is None, \ | ||
"Cannot reset the data source of PyReader" | ||
with program_guard(Program(), Program()): | ||
feeder = DataFeeder( | ||
feed_list=self._feed_list, place=core.CPUPlace()) | ||
paddle_reader = feeder.decorate_reader(reader, multi_devices=False) | ||
|
||
def __tensor_reader_impl__(): | ||
for slots in paddle_reader(): | ||
yield [slots[var.name] for var in self._feed_list] | ||
if not in_dygraph_mode(): | ||
with program_guard(Program(), Program()): | ||
feeder = DataFeeder( | ||
feed_list=self._feed_list, place=core.CPUPlace()) | ||
paddle_reader = feeder.decorate_reader( | ||
reader, multi_devices=False) | ||
|
||
def __tensor_reader_impl__(): | ||
for slots in paddle_reader(): | ||
yield [slots[var.name] for var in self._feed_list] | ||
else: | ||
provider = ListTensorProvider(reader, place=core.CPUPlace()) | ||
|
||
def __tensor_reader_impl__(): | ||
for slots in provider(): | ||
yield slots | ||
|
||
self.decorate_batch_generator(__tensor_reader_impl__, places) | ||
|
||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What if number of input slots change?