|
18 | 18 | from typing import Any, Dict, List, Union
|
19 | 19 |
|
20 | 20 | from ... import oscar as mo
|
| 21 | +from ...lib.aio import AioFileObject |
21 | 22 | from ...storage import StorageLevel, get_storage_backend
|
22 | 23 | from ...storage.core import StorageFileObject
|
23 | 24 | from ...typing import BandType
|
|
29 | 30 | DataManagerActor,
|
30 | 31 | DataInfo,
|
31 | 32 | build_data_info,
|
32 |
| - WrappedStorageFileObject, |
33 | 33 | )
|
34 | 34 | from .errors import DataNotExist, NoDataToSpill
|
35 | 35 |
|
|
39 | 39 | logger = logging.getLogger(__name__)
|
40 | 40 |
|
41 | 41 |
|
class WrappedStorageFileObject(AioFileObject):
    """
    Wrap a storage file object to hold a reference after write close.

    Besides the wrapped file, the wrapper keeps the storage level, the
    declared size, the session id and the data key of the data being
    accessed, plus a reference to the storage handler that ``close()``
    is delegated to.

    Parameters
    ----------
    file
        The underlying storage file object.
    level
        Storage level the file belongs to.
    size
        Size recorded for the data.
    session_id
        Session the data belongs to.
    data_key
        Key of the data.
    storage_handler
        Reference to the storage handler actor that closes this writer.
    """

    def __init__(
        self,
        file: StorageFileObject,
        level: StorageLevel,
        size: int,
        session_id: str,
        data_key: str,
        storage_handler: mo.ActorRefType["StorageHandlerActor"],
    ):
        # Snapshot the object id before initializing the base class; it may
        # be None at this point and be refreshed after the file is closed.
        self._object_id = file.object_id
        # NOTE(review): the base initializer presumably stores ``file`` as
        # ``self._file``, which the members below rely on -- confirm.
        super().__init__(file)
        self._size = size
        self._level = level
        self._session_id = session_id
        self._data_key = data_key
        self._storage_handler = storage_handler

    def __getattr__(self, item):
        # ``__getattr__`` only runs when normal lookup fails. If ``_file``
        # itself is missing (e.g. while the instance is being rebuilt by
        # copy/pickle), delegating to ``self._file`` would re-enter this
        # method forever, so fail fast with AttributeError instead.
        if item == "_file":
            raise AttributeError(item)
        return getattr(self._file, item)

    @property
    def file(self):
        """The wrapped storage file object."""
        return self._file

    @property
    def object_id(self):
        """Cached object id of the wrapped file, possibly ``None``."""
        return self._object_id

    @property
    def level(self):
        """Storage level passed at construction."""
        return self._level

    @property
    def size(self):
        """Size passed at construction."""
        return self._size

    async def clean_up(self):
        """Close the wrapped file directly, bypassing the storage handler."""
        self._file.close()

    async def close(self):
        """Ask the storage handler to close this writer."""
        await self._storage_handler.close_writer(self)
| 88 | + |
| 89 | + |
42 | 90 | class StorageHandlerActor(mo.Actor):
|
43 | 91 | """
|
44 | 92 | Storage handler actor, provide methods like `get`, `put`, etc.
|
@@ -360,8 +408,7 @@ async def open_writer(
|
360 | 408 | size,
|
361 | 409 | session_id,
|
362 | 410 | data_key,
|
363 |
| - self._data_manager_ref, |
364 |
| - self._clients[level], |
| 411 | + self, |
365 | 412 | )
|
366 | 413 |
|
367 | 414 | @open_writer.batch
|
@@ -392,12 +439,48 @@ async def batch_open_writers(self, args_list, kwargs_list):
|
392 | 439 | size,
|
393 | 440 | session_id,
|
394 | 441 | data_key,
|
395 |
| - self._data_manager_ref, |
396 |
| - self._clients[level], |
| 442 | + self, |
397 | 443 | )
|
398 | 444 | )
|
399 | 445 | return wrapped_writers
|
400 | 446 |
|
async def _close_writer_and_build_info(self, writer: WrappedStorageFileObject):
    """
    Close the writer's underlying file and build the info to register.

    Shared by ``close_writer`` and ``batch_close_writers`` so both paths
    perform exactly the same steps.

    Returns
    -------
    A ``(data_info, object_info)`` tuple when the file mode contains
    ``"w"``, otherwise ``None`` (nothing has to be registered).
    """
    writer.file.close()
    if writer.object_id is None:
        # for some backends like vineyard,
        # object id is generated after write close
        writer._object_id = writer.file.object_id
    if "w" not in writer.file.mode:
        return None
    client = self._clients[writer.level]
    object_info = await client.object_info(writer.object_id)
    data_info = build_data_info(object_info, writer.level, writer.size)
    return data_info, object_info

@mo.extensible
async def close_writer(self, writer: WrappedStorageFileObject):
    """
    Close a single writer and, for write-mode files, record its data info.

    Parameters
    ----------
    writer
        The wrapped file object to close.
    """
    infos = await self._close_writer_and_build_info(writer)
    if infos is not None:
        data_info, object_info = infos
        await self._data_manager_ref.put_data_info(
            writer._session_id, writer._data_key, data_info, object_info
        )

@close_writer.batch
async def batch_close_writers(self, args_list, kwargs_list):
    """
    Batch variant of ``close_writer``.

    Writers are closed one by one in the given order; the resulting
    ``put_data_info`` calls are collected and sent to the data manager
    as one batch.
    """
    put_info_tasks = []
    for args, kwargs in zip(args_list, kwargs_list):
        (writer,) = self.close_writer.bind(*args, **kwargs)
        infos = await self._close_writer_and_build_info(writer)
        if infos is None:
            continue
        data_info, object_info = infos
        put_info_tasks.append(
            self._data_manager_ref.put_data_info.delay(
                writer._session_id, writer._data_key, data_info, object_info
            )
        )
    # Skip the batch call entirely when there is nothing to register.
    if put_info_tasks:
        await self._data_manager_ref.put_data_info.batch(*put_info_tasks)
| 483 | + |
401 | 484 | async def _get_meta_api(self, session_id: str):
|
402 | 485 | if self._supervisor_address is None:
|
403 | 486 | cluster_api = await ClusterAPI.create(self.address)
|
|
0 commit comments