@@ -22,7 +22,7 @@
 from .. import core
 from .fetcher import _IterableDatasetFetcher, _MapDatasetFetcher
 from ..multiprocess_utils import _cleanup_mmap, CleanupFuncRegistrar, MP_STATUS_CHECK_INTERVAL
-from ..framework import _non_static_mode
+from ..framework import _non_static_mode, _in_eager_without_dygraph_check
 from .flat import _flatten_batch
 
 # NOTE: queue has a different name in python2 and python3
@@ -339,10 +339,16 @@ def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
                     out_queue.put((idx, batch, None))
                 batch, structure = _flatten_batch(batch)
                 if use_shared_memory:
+                    # NOTE: In eager mode, Tensor._share_memory has no
+                    # effect, fall back to _array_to_share_memory_tensor
+                    def tensor_share_memory(tensor):
+                        if _in_eager_without_dygraph_check():
+                            return core._array_to_share_memory_tensor(tensor)
+                        return tensor._share_memory()
                     tensor_list = [
                         core._array_to_share_memory_tensor(b)
-                        if isinstance(b, np.ndarray) else b._share_memory()
-                        for b in batch
+                        if isinstance(b, np.ndarray) \
+                        else tensor_share_memory(b) for b in batch
                     ]
                     out_queue.put((idx, tensor_list, structure))
                     core._remove_tensor_list_mmap_fds(tensor_list)
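
For readers skimming the hunk: the change routes every batch element through a single shared-memory conversion, with eager-mode Tensors taking the same path as numpy arrays because `Tensor._share_memory` has no effect in eager mode (per the NOTE in the diff). Below is a minimal standalone sketch of that dispatch, assuming it lives alongside the worker module (hence the relative imports); `_to_shared` and `_share_batch` are hypothetical names used only for illustration:

```python
import numpy as np

from .. import core
from ..framework import _in_eager_without_dygraph_check


def _to_shared(b):
    # numpy arrays are copied into a fresh shared-memory tensor
    if isinstance(b, np.ndarray):
        return core._array_to_share_memory_tensor(b)
    # eager-mode Tensor._share_memory has no effect (per the NOTE above),
    # so eager tensors take the same conversion path as numpy arrays
    if _in_eager_without_dygraph_check():
        return core._array_to_share_memory_tensor(b)
    # legacy dygraph tensors move their buffer into shared memory in place
    return b._share_memory()


def _share_batch(batch):
    # equivalent to the rewritten list comprehension in the hunk above
    return [_to_shared(b) for b in batch]
```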