Skip to content

Commit e62540e

Browse files
committed
Fix ut
1 parent 2376990 commit e62540e

File tree

2 files changed

+17
-14
lines changed

2 files changed

+17
-14
lines changed

mars/services/storage/tests/test_transfer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ async def test_cancel_transfer(create_actors, mock_sender, mock_receiver):
265265

266266
send_task = asyncio.create_task(
267267
sender_actor.send_batch_data(
268-
"mock", ["data_key1"], worker_address_2, StorageLevel.MEMORY
268+
"mock", ["data_key1"], worker_address_2, StorageLevel.MEMORY, is_small_objects=False
269269
)
270270
)
271271

@@ -283,7 +283,7 @@ async def test_cancel_transfer(create_actors, mock_sender, mock_receiver):
283283

284284
send_task = asyncio.create_task(
285285
sender_actor.send_batch_data(
286-
"mock", ["data_key1"], worker_address_2, StorageLevel.MEMORY
286+
"mock", ["data_key1"], worker_address_2, StorageLevel.MEMORY, is_small_objects=False
287287
)
288288
)
289289
await send_task
@@ -294,12 +294,12 @@ async def test_cancel_transfer(create_actors, mock_sender, mock_receiver):
294294
if mock_sender is MockSenderManagerActor:
295295
send_task1 = asyncio.create_task(
296296
sender_actor.send_batch_data(
297-
"mock", ["data_key2"], worker_address_2, StorageLevel.MEMORY
297+
"mock", ["data_key2"], worker_address_2, StorageLevel.MEMORY, is_small_objects=False
298298
)
299299
)
300300
send_task2 = asyncio.create_task(
301301
sender_actor.send_batch_data(
302-
"mock", ["data_key2"], worker_address_2, StorageLevel.MEMORY
302+
"mock", ["data_key2"], worker_address_2, StorageLevel.MEMORY, is_small_objects=False
303303
)
304304
)
305305
await asyncio.sleep(0.5)

mars/services/storage/transfer.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ async def get_receiver_ref(address: str, band_name: str):
9393

9494
async def _send_data(
9595
self,
96-
receiver_ref: Union[mo.ActorRef],
96+
receiver_ref: Union[mo.ActorRef, "ReceiverManagerActor"],
9797
session_id: str,
9898
data_keys: List[str],
9999
data_infos: List[DataInfo],
@@ -218,6 +218,7 @@ async def send_batch_data(
218218
level: StorageLevel,
219219
band_name: str = "numa-0",
220220
block_size: int = None,
221+
is_small_objects=None,
221222
error: str = "raise",
222223
):
223224
logger.debug(
@@ -253,7 +254,17 @@ async def send_batch_data(
253254
if level is None:
254255
level = infos[0].level
255256
total_size = sum(data_sizes)
256-
if total_size > block_size:
257+
if is_small_objects is None:
258+
is_small_objects = total_size <= block_size
259+
if is_small_objects:
260+
logger.debug(
261+
"Choose send_small_objects method for sending data of %s bytes",
262+
total_size,
263+
)
264+
await self._send_small_objects(
265+
session_id, data_keys, infos, address, band_name, level
266+
)
267+
else:
257268
logger.debug("Choose block method for sending data of %s bytes", total_size)
258269
await self._send(
259270
session_id,
@@ -265,14 +276,6 @@ async def send_batch_data(
265276
band_name,
266277
level,
267278
)
268-
else:
269-
logger.debug(
270-
"Choose send_small_objects method for sending data of %s bytes",
271-
total_size,
272-
)
273-
await self._send_small_objects(
274-
session_id, data_keys, infos, address, band_name, level
275-
)
276279
unpin_tasks = []
277280
for data_key in data_keys:
278281
unpin_tasks.append(

0 commit comments

Comments
 (0)