Skip to content

Commit 99d086e

Browse files
committed
Add batch method for closing writers
1 parent 324d598 commit 99d086e

File tree

5 files changed

+122
-70
lines changed

5 files changed

+122
-70
lines changed

mars/services/storage/api/oscar.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,8 @@
2323
StorageManagerActor,
2424
DataManagerActor,
2525
DataInfo,
26-
WrappedStorageFileObject,
2726
)
28-
from ..handler import StorageHandlerActor
27+
from ..handler import StorageHandlerActor, WrappedStorageFileObject
2928
from .core import AbstractStorageAPI
3029

3130
_is_windows = sys.platform.lower().startswith("win")

mars/services/storage/core.py

Lines changed: 1 addition & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,10 @@
1919
from typing import Dict, List, Optional, Union, Tuple
2020

2121
from ... import oscar as mo
22-
from ...lib.aio import AioFileObject
2322
from ...oscar.backends.allocate_strategy import IdleLabel, NoIdleSlot
2423
from ...resource import cuda_card_stats
2524
from ...storage import StorageLevel, get_storage_backend
26-
from ...storage.base import ObjectInfo, StorageBackend
27-
from ...storage.core import StorageFileObject
25+
from ...storage.base import ObjectInfo
2826
from ...utils import dataslots
2927
from .errors import DataNotExist, StorageFull
3028

@@ -44,50 +42,6 @@ def build_data_info(storage_info: ObjectInfo, level, size, band_name=None):
4442
return DataInfo(storage_info.object_id, level, size, store_size, band_name)
4543

4644

47-
class WrappedStorageFileObject(AioFileObject):
48-
"""
49-
Wrap to hold ref after write close
50-
"""
51-
52-
def __init__(
53-
self,
54-
file: StorageFileObject,
55-
level: StorageLevel,
56-
size: int,
57-
session_id: str,
58-
data_key: str,
59-
data_manager: mo.ActorRefType["DataManagerActor"],
60-
storage_handler: StorageBackend,
61-
):
62-
self._object_id = file.object_id
63-
super().__init__(file)
64-
self._size = size
65-
self._level = level
66-
self._session_id = session_id
67-
self._data_key = data_key
68-
self._data_manager = data_manager
69-
self._storage_handler = storage_handler
70-
71-
def __getattr__(self, item):
72-
return getattr(self._file, item)
73-
74-
async def clean_up(self):
75-
self._file.close()
76-
77-
async def close(self):
78-
self._file.close()
79-
if self._object_id is None:
80-
# for some backends like vineyard,
81-
# object id is generated after write close
82-
self._object_id = self._file.object_id
83-
if "w" in self._file.mode:
84-
object_info = await self._storage_handler.object_info(self._object_id)
85-
data_info = build_data_info(object_info, self._level, self._size)
86-
await self._data_manager.put_data_info(
87-
self._session_id, self._data_key, data_info, object_info
88-
)
89-
90-
9145
class StorageQuotaActor(mo.Actor):
9246
def __init__(
9347
self,

mars/services/storage/handler.py

Lines changed: 88 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typing import Any, Dict, List, Union
1919

2020
from ... import oscar as mo
21+
from ...lib.aio import AioFileObject
2122
from ...storage import StorageLevel, get_storage_backend
2223
from ...storage.core import StorageFileObject
2324
from ...typing import BandType
@@ -29,7 +30,6 @@
2930
DataManagerActor,
3031
DataInfo,
3132
build_data_info,
32-
WrappedStorageFileObject,
3333
)
3434
from .errors import DataNotExist, NoDataToSpill
3535

@@ -39,6 +39,54 @@
3939
logger = logging.getLogger(__name__)
4040

4141

42+
class WrappedStorageFileObject(AioFileObject):
    """
    Async file wrapper that keeps a reference to the stored object after
    a write is closed.

    The wrapper remembers the storage level, declared size, session id and
    data key of the object being written, and delegates ``close()`` to the
    storage handler's ``close_writer`` so that closing the file and
    recording its data info happen in one place.

    Parameters
    ----------
    file : StorageFileObject
        Underlying storage file object to wrap.
    level : StorageLevel
        Storage level the object is written to.
    size : int
        Size passed by the caller when the writer was opened.
    session_id : str
        Session the data belongs to.
    data_key : str
        Key of the data being written.
    storage_handler : mo.ActorRefType["StorageHandlerActor"]
        Handler whose ``close_writer`` is awaited on ``close()``.
    """

    def __init__(
        self,
        file: StorageFileObject,
        level: StorageLevel,
        size: int,
        session_id: str,
        data_key: str,
        storage_handler: mo.ActorRefType["StorageHandlerActor"],
    ):
        # May be None at this point: for some backends like vineyard,
        # the object id is generated only after the write is closed.
        self._object_id = file.object_id
        # NOTE(review): ``self._file`` is presumably assigned by the
        # AioFileObject base initializer -- confirm against lib.aio.
        super().__init__(file)
        self._size = size
        self._level = level
        self._session_id = session_id
        self._data_key = data_key
        self._storage_handler = storage_handler

    def __getattr__(self, item):
        # __getattr__ only runs when normal lookup fails. If ``_file``
        # itself is missing (e.g. a blank instance created by copy/pickle
        # before its state is restored), forwarding to ``self._file``
        # would re-enter this method forever and end in RecursionError.
        if item == "_file":
            raise AttributeError(item)
        return getattr(self._file, item)

    @property
    def file(self):
        """The wrapped storage file object."""
        return self._file

    @property
    def object_id(self):
        """Object id of the written data; may be None until closed."""
        return self._object_id

    @property
    def level(self):
        """Storage level the object is written to."""
        return self._level

    @property
    def size(self):
        """Size given when the writer was opened."""
        return self._size

    async def clean_up(self):
        """Close the underlying file directly, bypassing the handler."""
        self._file.close()

    async def close(self):
        """Close the writer through the storage handler's ``close_writer``."""
        await self._storage_handler.close_writer(self)
89+
4290
class StorageHandlerActor(mo.Actor):
4391
"""
4492
Storage handler actor, provide methods like `get`, `put`, etc.
@@ -360,8 +408,7 @@ async def open_writer(
360408
size,
361409
session_id,
362410
data_key,
363-
self._data_manager_ref,
364-
self._clients[level],
411+
self,
365412
)
366413

367414
@open_writer.batch
@@ -392,12 +439,48 @@ async def batch_open_writers(self, args_list, kwargs_list):
392439
size,
393440
session_id,
394441
data_key,
395-
self._data_manager_ref,
396-
self._clients[level],
442+
self,
397443
)
398444
)
399445
return wrapped_writers
400446

447+
async def _finalize_writer(self, writer: WrappedStorageFileObject):
    """
    Close ``writer``'s underlying file and build its ``put_data_info`` args.

    Shared by ``close_writer`` and its batch variant so that both paths
    perform exactly the same steps.

    Returns
    -------
    tuple or None
        ``(session_id, data_key, data_info, object_info)`` when the file
        was opened in a write mode, otherwise ``None`` (nothing to record).
    """
    writer.file.close()
    if writer.object_id is None:
        # for some backends like vineyard,
        # object id is generated after write close
        writer._object_id = writer.file.object_id
    if "w" not in writer.file.mode:
        return None
    client = self._clients[writer.level]
    object_info = await client.object_info(writer.object_id)
    data_info = build_data_info(object_info, writer.level, writer.size)
    return writer._session_id, writer._data_key, data_info, object_info


@mo.extensible
async def close_writer(self, writer: WrappedStorageFileObject):
    """
    Close a writer and, for write-mode files, record its data info in the
    data manager.

    Parameters
    ----------
    writer : WrappedStorageFileObject
        Writer previously returned by ``open_writer``.
    """
    put_args = await self._finalize_writer(writer)
    if put_args is not None:
        await self._data_manager_ref.put_data_info(*put_args)


@close_writer.batch
async def batch_close_writers(self, args_list, kwargs_list):
    """
    Batch form of ``close_writer``.

    Each writer is closed in turn; the resulting ``put_data_info`` calls
    are collected as delayed calls and sent to the data manager in one
    batch.

    Parameters
    ----------
    args_list : list
        Positional arguments of each delayed ``close_writer`` call.
    kwargs_list : list
        Keyword arguments of each delayed ``close_writer`` call.
    """
    put_info_tasks = []
    for args, kwargs in zip(args_list, kwargs_list):
        (writer,) = self.close_writer.bind(*args, **kwargs)
        put_args = await self._finalize_writer(writer)
        if put_args is not None:
            put_info_tasks.append(
                self._data_manager_ref.put_data_info.delay(*put_args)
            )
    # Skip the batch call entirely when no writer was in write mode.
    if put_info_tasks:
        await self._data_manager_ref.put_data_info.batch(*put_info_tasks)
483+
401484
async def _get_meta_api(self, session_id: str):
402485
if self._supervisor_address is None:
403486
cluster_api = await ClusterAPI.create(self.address)

mars/services/storage/tests/test_transfer.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,11 @@ async def test_cancel_transfer(create_actors, mock_sender, mock_receiver):
266266

267267
send_task = asyncio.create_task(
268268
sender_actor.send_batch_data(
269-
"mock", ["data_key1"], worker_address_2, StorageLevel.MEMORY, is_small_objects=False
269+
"mock",
270+
["data_key1"],
271+
worker_address_2,
272+
StorageLevel.MEMORY,
273+
is_small_objects=False,
270274
)
271275
)
272276

@@ -284,7 +288,11 @@ async def test_cancel_transfer(create_actors, mock_sender, mock_receiver):
284288

285289
send_task = asyncio.create_task(
286290
sender_actor.send_batch_data(
287-
"mock", ["data_key1"], worker_address_2, StorageLevel.MEMORY, is_small_objects=False
291+
"mock",
292+
["data_key1"],
293+
worker_address_2,
294+
StorageLevel.MEMORY,
295+
is_small_objects=False,
288296
)
289297
)
290298
await send_task
@@ -295,12 +303,20 @@ async def test_cancel_transfer(create_actors, mock_sender, mock_receiver):
295303
if mock_sender is MockSenderManagerActor:
296304
send_task1 = asyncio.create_task(
297305
sender_actor.send_batch_data(
298-
"mock", ["data_key2"], worker_address_2, StorageLevel.MEMORY, is_small_objects=False
306+
"mock",
307+
["data_key2"],
308+
worker_address_2,
309+
StorageLevel.MEMORY,
310+
is_small_objects=False,
299311
)
300312
)
301313
send_task2 = asyncio.create_task(
302314
sender_actor.send_batch_data(
303-
"mock", ["data_key2"], worker_address_2, StorageLevel.MEMORY, is_small_objects=False
315+
"mock",
316+
["data_key2"],
317+
worker_address_2,
318+
StorageLevel.MEMORY,
319+
is_small_objects=False,
304320
)
305321
)
306322
await asyncio.sleep(0.5)

mars/services/storage/transfer.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
from ...lib.aio import alru_cache
2222
from ...storage import StorageLevel
2323
from ...utils import dataslots
24-
from .core import DataManagerActor, WrappedStorageFileObject, DataInfo
25-
from .handler import StorageHandlerActor
24+
from .core import DataManagerActor, DataInfo
25+
from .handler import StorageHandlerActor, WrappedStorageFileObject
2626

2727
DEFAULT_TRANSFER_BLOCK_SIZE = 4 * 1024**2
2828

@@ -96,9 +96,7 @@ async def send(self, buffer, eof_mark, key):
9696
open_reader_tasks = []
9797
storage_client = await self._storage_handler.get_client(level)
9898
for info in data_infos:
99-
open_reader_tasks.append(
100-
storage_client.open_reader(info.object_id)
101-
)
99+
open_reader_tasks.append(storage_client.open_reader(info.object_id))
102100
readers = await asyncio.gather(*open_reader_tasks)
103101

104102
for data_key, reader in zip(data_keys, readers):
@@ -129,7 +127,9 @@ async def _send(
129127
band_name: str,
130128
level: StorageLevel,
131129
):
132-
receiver_ref: mo.ActorRefType[ReceiverManagerActor] = await self.get_receiver_ref(address, band_name)
130+
receiver_ref: mo.ActorRefType[
131+
ReceiverManagerActor
132+
] = await self.get_receiver_ref(address, band_name)
133133
is_transferring_list = await receiver_ref.open_writers(
134134
session_id, data_keys, data_sizes, level
135135
)
@@ -163,11 +163,11 @@ async def _send_small_objects(
163163
):
164164
# simple get all objects and send them all to receiver
165165
storage_client = await self._storage_handler.get_client(level)
166-
get_tasks = [
167-
storage_client.get(info.object_id) for info in data_infos
168-
]
166+
get_tasks = [storage_client.get(info.object_id) for info in data_infos]
169167
data_list = list(await asyncio.gather(*get_tasks))
170-
receiver_ref: mo.ActorRefType[ReceiverManagerActor] = await self.get_receiver_ref(address, band_name)
168+
receiver_ref: mo.ActorRefType[
169+
ReceiverManagerActor
170+
] = await self.get_receiver_ref(address, band_name)
171171
await receiver_ref.put_small_objects(session_id, data_keys, data_list, level)
172172

173173
async def send_batch_data(
@@ -358,9 +358,9 @@ async def do_write(
358358
if data:
359359
await writer.write(data)
360360
if is_eof:
361-
close_tasks.append(writer.close())
361+
close_tasks.append(self._storage_handler.close_writer.delay(writer))
362362
finished_keys.append(data_key)
363-
await asyncio.gather(*close_tasks)
363+
await self._storage_handler.close_writer.batch(*close_tasks)
364364
async with self._lock:
365365
for data_key in finished_keys:
366366
event = self._writing_infos[(session_id, data_key)].event

0 commit comments

Comments
 (0)