Skip to content

Commit b4302bb

Browse files
authored
Merge pull request #6990 from typhoonzero/refine_pipe_reader
refine pipe_reader
2 parents 80dafdf + 9b67688 commit b4302bb

File tree

2 files changed

+62
-85
lines changed

2 files changed

+62
-85
lines changed

python/paddle/v2/reader/decorator.py

Lines changed: 51 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
__all__ = [
1616
'map_readers', 'buffered', 'compose', 'chain', 'shuffle',
17-
'ComposeNotAligned', 'firstn', 'xmap_readers', 'pipe_reader'
17+
'ComposeNotAligned', 'firstn', 'xmap_readers', 'PipeReader'
1818
]
1919

2020
from threading import Thread
@@ -334,93 +334,72 @@ def _buf2lines(buf, line_break="\n"):
334334
return lines[:-1], lines[-1]
335335

336336

337-
def pipe_reader(left_cmd,
338-
parser,
339-
bufsize=8192,
340-
file_type="plain",
341-
cut_lines=True,
342-
line_break="\n"):
337+
class PipeReader:
343338
"""
344-
pipe_reader read data by stream from a command, take it's
345-
stdout into a pipe buffer and redirect it to the parser to
346-
parse, then yield data as your desired format.
339+
PipeReader read data by stream from a command, take it's
340+
stdout into a pipe buffer and redirect it to the parser to
341+
parse, then yield data as your desired format.
347342
348-
You can using standard linux command or call another program
349-
to read data, from HDFS, Ceph, URL, AWS S3 etc:
343+
You can using standard linux command or call another program
344+
to read data, from HDFS, Ceph, URL, AWS S3 etc:
350345
351-
cmd = "hadoop fs -cat /path/to/some/file"
352-
cmd = "cat sample_file.tar.gz"
353-
cmd = "curl http://someurl"
354-
cmd = "python print_s3_bucket.py"
346+
.. code-block:: python
347+
cmd = "hadoop fs -cat /path/to/some/file"
348+
cmd = "cat sample_file.tar.gz"
349+
cmd = "curl http://someurl"
350+
cmd = "python print_s3_bucket.py"
355351
356-
A sample parser:
352+
An example:
353+
354+
.. code-block:: python
357355
358-
def sample_parser(lines):
359-
# parse each line as one sample data,
360-
# return a list of samples as batches.
361-
ret = []
362-
for l in lines:
363-
ret.append(l.split(" ")[1:5])
364-
return ret
365-
366-
:param left_cmd: command to excute to get stdout from.
367-
:type left_cmd: string
368-
:param parser: parser function to parse lines of data.
369-
if cut_lines is True, parser will receive list
370-
of lines.
371-
if cut_lines is False, parser will receive a
372-
raw buffer each time.
373-
parser should return a list of parsed values.
374-
:type parser: callable
375-
:param bufsize: the buffer size used for the stdout pipe.
376-
:type bufsize: int
377-
:param file_type: can be plain/gzip, stream buffer data type.
378-
:type file_type: string
379-
:param cut_lines: whether to pass lines instead of raw buffer
380-
to the parser
381-
:type cut_lines: bool
382-
:param line_break: line break of the file, like \n or \r
383-
:type line_break: string
384-
385-
:return: the reader generator.
386-
:rtype: callable
356+
def example_reader():
357+
for f in myfiles:
358+
pr = PipeReader("cat %s"%f)
359+
for l in pr.get_line():
360+
sample = l.split(" ")
361+
yield sample
387362
"""
388-
if not isinstance(left_cmd, str):
389-
raise TypeError("left_cmd must be a string")
390-
if not callable(parser):
391-
raise TypeError("parser must be a callable object")
392-
393-
# TODO(typhoonzero): add a thread to read stderr
394-
395-
# Always init a decompress object is better than
396-
# create in the loop.
397-
dec = zlib.decompressobj(
398-
32 + zlib.MAX_WBITS) # offset 32 to skip the header
399363

400-
def reader():
401-
process = subprocess.Popen(
402-
left_cmd.split(" "), bufsize=bufsize, stdout=subprocess.PIPE)
364+
def __init__(self, command, bufsize=8192, file_type="plain"):
365+
if not isinstance(command, str):
366+
raise TypeError("left_cmd must be a string")
367+
if file_type == "gzip":
368+
self.dec = zlib.decompressobj(
369+
32 + zlib.MAX_WBITS) # offset 32 to skip the header
370+
self.file_type = file_type
371+
self.bufsize = bufsize
372+
self.process = subprocess.Popen(
373+
command.split(" "), bufsize=bufsize, stdout=subprocess.PIPE)
374+
375+
def get_line(self, cut_lines=True, line_break="\n"):
376+
"""
377+
:param cut_lines: cut buffer to lines
378+
:type cut_lines: bool
379+
:param line_break: line break of the file, like \n or \r
380+
:type line_break: string
381+
382+
:return: one line or a buffer of bytes
383+
:rtype: string
384+
"""
403385
remained = ""
404386
while True:
405-
buff = process.stdout.read(bufsize)
387+
buff = self.process.stdout.read(self.bufsize)
406388
if buff:
407-
if file_type == "gzip":
408-
decomp_buff = dec.decompress(buff)
409-
elif file_type == "plain":
389+
if self.file_type == "gzip":
390+
decomp_buff = self.dec.decompress(buff)
391+
elif self.file_type == "plain":
410392
decomp_buff = buff
411393
else:
412-
raise TypeError("file_type %s is not allowed" % file_type)
394+
raise TypeError("file_type %s is not allowed" %
395+
self.file_type)
413396

414397
if cut_lines:
415398
lines, remained = _buf2lines(''.join(
416399
[remained, decomp_buff]), line_break)
417-
parsed_list = parser(lines)
418-
for ret in parsed_list:
419-
yield ret
400+
for line in lines:
401+
yield line
420402
else:
421-
for ret in parser(decomp_buff):
422-
yield ret
403+
yield decomp_buff
423404
else:
424405
break
425-
426-
return reader

python/paddle/v2/reader/tests/decorator_test.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,11 @@ def mapper(x):
147147

148148
class TestPipeReader(unittest.TestCase):
149149
def test_pipe_reader(self):
150-
def simple_parser(lines):
151-
return lines
150+
def example_reader(myfiles):
151+
for f in myfiles:
152+
pr = paddle.v2.reader.PipeReader("cat %s" % f, bufsize=128)
153+
for l in pr.get_line():
154+
yield l
152155

153156
import tempfile
154157

@@ -159,17 +162,12 @@ def simple_parser(lines):
159162
for r in records:
160163
f.write('%s\n' % r)
161164

162-
cmd = "cat %s" % temp.name
163-
reader = paddle.v2.reader.pipe_reader(
164-
cmd, simple_parser, bufsize=128)
165-
for i in xrange(4):
166-
result = []
167-
for r in reader():
168-
result.append(r)
169-
170-
for idx, e in enumerate(records):
171-
print e, result[idx]
172-
self.assertEqual(e, result[idx])
165+
result = []
166+
for r in example_reader([temp.name]):
167+
result.append(r)
168+
169+
for idx, e in enumerate(records):
170+
self.assertEqual(e, result[idx])
173171
finally:
174172
# delete the temporary file
175173
temp.close()

0 commit comments

Comments
 (0)