Commit be77aee

fix Tensor share memory in eager mode. test=develop (#42445)
1 parent d6442df commit be77aee

File tree: 1 file changed (+9 −3 lines)

python/paddle/fluid/dataloader/worker.py

Lines changed: 9 additions & 3 deletions
@@ -22,7 +22,7 @@
 from .. import core
 from .fetcher import _IterableDatasetFetcher, _MapDatasetFetcher
 from ..multiprocess_utils import _cleanup_mmap, CleanupFuncRegistrar, MP_STATUS_CHECK_INTERVAL
-from ..framework import _non_static_mode
+from ..framework import _non_static_mode, _in_eager_without_dygraph_check
 from .flat import _flatten_batch
 
 # NOTE: queue has a different name in python2 and python3
@@ -339,10 +339,16 @@ def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
                 out_queue.put((idx, batch, None))
             batch, structure = _flatten_batch(batch)
             if use_shared_memory:
+                # NOTE: In eager mode, Tensor._share_memory has no
+                # effect, fall back to _array_to_share_memory_tensor
+                def tensor_share_memory(tensor):
+                    if _in_eager_without_dygraph_check():
+                        return core._array_to_share_memory_tensor(tensor)
+                    return tensor._share_memory()
                 tensor_list = [
                     core._array_to_share_memory_tensor(b)
-                    if isinstance(b, np.ndarray) else b._share_memory()
-                    for b in batch
+                    if isinstance(b, np.ndarray) \
+                    else tensor_share_memory(b) for b in batch
                 ]
                 out_queue.put((idx, tensor_list, structure))
                 core._remove_tensor_list_mmap_fds(tensor_list)
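
For context, a rough standalone sketch of the conversion logic this commit introduces: in eager mode Tensor._share_memory has no effect, so framework Tensors are routed through core._array_to_share_memory_tensor instead, while numpy arrays are converted as before. This is not part of the commit; the import paths below are assumptions based on the paddle.fluid package layout this file belongs to, and batch_to_shared_memory is a hypothetical helper used only to show the dispatch in isolation.

    # Sketch only, assuming paddle.fluid internals at the time of this commit.
    import numpy as np
    from paddle.fluid import core
    from paddle.fluid.framework import _in_eager_without_dygraph_check

    def tensor_share_memory(tensor):
        # Tensor._share_memory has no effect for eager Tensors, so fall back
        # to core._array_to_share_memory_tensor (mirrors the helper in the diff).
        if _in_eager_without_dygraph_check():
            return core._array_to_share_memory_tensor(tensor)
        return tensor._share_memory()

    def batch_to_shared_memory(batch):
        # Hypothetical wrapper: numpy arrays are always converted via
        # _array_to_share_memory_tensor; framework Tensors go through the
        # eager-aware helper above.
        return [
            core._array_to_share_memory_tensor(b)
            if isinstance(b, np.ndarray) else tensor_share_memory(b)
            for b in batch
        ]

In the worker loop itself, the flattened batch is converted this way before being put on out_queue, so the parent process can rebuild the batch from shared memory rather than copying the data through the queue.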
