Skip to content

Commit a8139fb

Browse files
samlurye and facebook-github-bot
authored and committed
Semi-private python API for overriding handle_undeliverable_message inside PythonActor (#964)
Summary: Pull Request resolved: #964. This diff makes undeliverable-message handling overridable for Python actors, using the newly introduced `Actor._handle_undeliverable_message` method. Previously, the Rust implementation of `PythonActor` simply used the default `Actor::handle_undeliverable_message` implementation. Now, `PythonActor` overrides `handle_undeliverable_message` to call into the corresponding method on the underlying Python class. ghstack-source-id: 305073646. Reviewed By: mariusae. Differential Revision: D79841379. fbshipit-source-id: 2d285b52df401ed01813ef57e39e158f0104bea1
1 parent 773266a commit a8139fb

File tree

6 files changed

+211
-21
lines changed

6 files changed

+211
-21
lines changed

monarch_hyperactor/src/actor.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ use hyperactor::Handler;
2121
use hyperactor::Instance;
2222
use hyperactor::Named;
2323
use hyperactor::OncePortHandle;
24+
use hyperactor::mailbox::MessageEnvelope;
25+
use hyperactor::mailbox::Undeliverable;
2426
use hyperactor::message::Bind;
2527
use hyperactor::message::Bindings;
2628
use hyperactor::message::Unbind;
@@ -50,6 +52,7 @@ use crate::local_state_broker::BrokerId;
5052
use crate::local_state_broker::LocalStateBrokerMessage;
5153
use crate::mailbox::EitherPortRef;
5254
use crate::mailbox::PyMailbox;
55+
use crate::mailbox::PythonUndeliverableMessageEnvelope;
5356
use crate::proc::InstanceWrapper;
5457
use crate::proc::PyActorId;
5558
use crate::proc::PyProc;
@@ -498,6 +501,41 @@ impl Actor for PythonActor {
498501
);
499502
Ok(())
500503
}
504+
505+
async fn handle_undeliverable_message(
506+
&mut self,
507+
cx: &Instance<Self>,
508+
envelope: Undeliverable<MessageEnvelope>,
509+
) -> Result<(), anyhow::Error> {
510+
assert_eq!(envelope.0.sender(), cx.self_id());
511+
512+
let (envelope, handled) = Python::with_gil(|py| {
513+
let py_envelope = PythonUndeliverableMessageEnvelope {
514+
inner: Some(envelope),
515+
}
516+
.into_bound_py_any(py)?;
517+
let handled = self
518+
.actor
519+
.call_method(py, "_handle_undeliverable_message", (&py_envelope,), None)
520+
.map_err(|err| anyhow::Error::from(SerializablePyErr::from(py, &err)))?
521+
.extract::<bool>(py)?;
522+
Ok::<_, anyhow::Error>((
523+
py_envelope
524+
.downcast::<PythonUndeliverableMessageEnvelope>()
525+
.map_err(PyErr::from)?
526+
.try_borrow_mut()
527+
.map_err(PyErr::from)?
528+
.take()?,
529+
handled,
530+
))
531+
})?;
532+
533+
if !handled {
534+
<Self as Actor>::handle_undeliverable_message(self, cx, envelope).await
535+
} else {
536+
Ok(())
537+
}
538+
}
501539
}
502540

503541
/// Create a new TaskLocals with its own asyncio event loop in a dedicated thread.

monarch_hyperactor/src/mailbox.rs

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -421,8 +421,53 @@ impl PythonPortReceiver {
421421
module = "monarch._rust_bindings.monarch_hyperactor.mailbox"
422422
)]
423423
pub(crate) struct PythonUndeliverableMessageEnvelope {
424-
#[allow(dead_code)] // At this time, field `inner` isn't read.
425-
pub(crate) inner: Undeliverable<MessageEnvelope>,
424+
pub(crate) inner: Option<Undeliverable<MessageEnvelope>>,
425+
}
426+
427+
impl PythonUndeliverableMessageEnvelope {
428+
fn inner(&self) -> PyResult<&Undeliverable<MessageEnvelope>> {
429+
self.inner.as_ref().ok_or_else(|| {
430+
PyErr::new::<PyRuntimeError, _>(
431+
"PythonUndeliverableMessageEnvelope was already consumed",
432+
)
433+
})
434+
}
435+
436+
pub(crate) fn take(&mut self) -> anyhow::Result<Undeliverable<MessageEnvelope>> {
437+
self.inner.take().ok_or_else(|| {
438+
anyhow::anyhow!("PythonUndeliverableMessageEnvelope was already consumed")
439+
})
440+
}
441+
}
442+
443+
#[pymethods]
444+
impl PythonUndeliverableMessageEnvelope {
445+
fn __repr__(&self) -> PyResult<String> {
446+
Ok(format!(
447+
"UndeliverableMessageEnvelope(sender={}, dest={}, error={})",
448+
self.inner()?.0.sender(),
449+
self.inner()?.0.dest(),
450+
self.error_msg()?
451+
))
452+
}
453+
454+
fn sender(&self) -> PyResult<PyActorId> {
455+
Ok(PyActorId {
456+
inner: self.inner()?.0.sender().clone(),
457+
})
458+
}
459+
460+
fn dest(&self) -> PyResult<PyPortId> {
461+
Ok(self.inner()?.0.dest().clone().into())
462+
}
463+
464+
fn error_msg(&self) -> PyResult<String> {
465+
Ok(self
466+
.inner()?
467+
.0
468+
.error()
469+
.map_or("None".to_string(), |e| e.to_string()))
470+
}
426471
}
427472

428473
#[derive(Debug)]
@@ -445,7 +490,9 @@ impl PythonUndeliverablePortReceiver {
445490
.recv()
446491
.await
447492
.map_err(|err| PyErr::new::<PyEOFError, _>(format!("Port closed: {}", err)))?;
448-
Ok(PythonUndeliverableMessageEnvelope { inner: message })
493+
Ok(PythonUndeliverableMessageEnvelope {
494+
inner: Some(message),
495+
})
449496
})
450497
}
451498

@@ -457,7 +504,9 @@ impl PythonUndeliverablePortReceiver {
457504
let message = signal_safe_block_on(py, async move { receiver.lock().await.recv().await })?
458505
.map_err(|err| PyErr::new::<PyEOFError, _>(format!("Port closed: {}", err)))?;
459506

460-
Ok(PythonUndeliverableMessageEnvelope { inner: message })
507+
Ok(PythonUndeliverableMessageEnvelope {
508+
inner: Some(message),
509+
})
461510
}
462511
}
463512

@@ -796,5 +845,6 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul
796845
hyperactor_mod.add_class::<PythonOncePortReceiver>()?;
797846
hyperactor_mod.add_class::<Instance>()?;
798847
hyperactor_mod.add_class::<Context>()?;
848+
hyperactor_mod.add_class::<PythonUndeliverableMessageEnvelope>()?;
799849
Ok(())
800850
}

python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -210,15 +210,6 @@ class PythonMessage:
210210
@property
211211
def kind(self) -> PythonMessageKind: ...
212212

213-
class UndeliverableMessageEnvelope:
214-
"""
215-
An envelope representing a message that could not be delivered.
216-
217-
This object is opaque; its contents are not accessible from Python.
218-
"""
219-
220-
...
221-
222213
@final
223214
class PythonActorHandle:
224215
"""

python/monarch/_rust_bindings/monarch_hyperactor/mailbox.pyi

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,7 @@
88

99
from typing import final, Protocol
1010

11-
from monarch._rust_bindings.monarch_hyperactor.actor import (
12-
PythonMessage,
13-
UndeliverableMessageEnvelope,
14-
)
11+
from monarch._rust_bindings.monarch_hyperactor.actor import PythonMessage
1512

1613
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
1714
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
@@ -20,9 +17,9 @@ from monarch._rust_bindings.monarch_hyperactor.shape import Shape
2017

2118
@final
2219
class PortId:
23-
def __init__(self, actor_id: ActorId, index: int) -> None:
20+
def __init__(self, *, actor_id: ActorId, port: int) -> None:
2421
"""
25-
Create a new port id given an actor id and an index.
22+
Create a new port id given an actor id and a port index.
2623
"""
2724
...
2825
def __repr__(self) -> str: ...
@@ -68,6 +65,12 @@ class PortRef:
6865
A reference to a remote port over which PythonMessages can be sent.
6966
"""
7067

68+
def __init__(self, port_id: PortId) -> None:
    """
    Build a port ref from the given port id.
    """
    ...
73+
7174
def send(self, mailbox: Mailbox, message: PythonMessage) -> None:
7275
"""Send a single message to the port's receiver."""
7376
...
@@ -218,3 +221,27 @@ class Reducer(Protocol):
218221
219222
This method's Rust counterpart is `CommReducer::reduce`.
220223
"""
224+
225+
class UndeliverableMessageEnvelope:
    """
    Envelope around a message that could not be delivered.

    Exposes the sender, the destination port and a textual description of
    the delivery error.
    """

    def __repr__(self) -> str: ...
    def sender(self) -> ActorId:
        """
        Return the actor id of the sender.
        """
        ...

    def dest(self) -> PortId:
        """
        Return the port id of the destination.
        """
        ...

    def error_msg(self) -> str:
        """
        Return the error message describing why the message could not be
        delivered.
        """
        ...

python/monarch/_src/actor/actor_mesh.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
OncePortRef,
5656
PortReceiver as HyPortReceiver,
5757
PortRef,
58+
UndeliverableMessageEnvelope,
5859
)
5960

6061
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
@@ -895,6 +896,17 @@ def _post_mortem_debug(self, exc_tb) -> None:
895896
pdb_wrapper.post_mortem(exc_tb)
896897
self._maybe_exit_debugger(do_continue=False)
897898

899+
def _handle_undeliverable_message(
    self, message: UndeliverableMessageEnvelope
) -> bool:
    """
    Forward an undeliverable message to the wrapped actor instance.

    Looks up `_handle_undeliverable_message` on `self.instance` and returns
    whatever it returns. If the instance has no such attribute, the message
    is reported as not handled (False).
    """
    handler = getattr(self.instance, "_handle_undeliverable_message", None)
    if handler is None:
        return False
    return handler(message)
909+
898910

899911
def _is_mailbox(x: object) -> bool:
900912
if hasattr(x, "__monarch_ref__"):
@@ -943,6 +955,12 @@ def initialized(self):
943955
"actor implementations are not meshes, but we can't convince the typechecker of it..."
944956
)
945957

958+
def _handle_undeliverable_message(
    self, message: UndeliverableMessageEnvelope
) -> bool:
    """
    Default hook for undeliverable messages.

    Always returns False, which indicates that the undeliverable message
    was not handled.
    """
    return False
963+
946964

947965
class ActorMesh(MeshTrait, Generic[T], DeprecatedNotAFuture):
948966
def __init__(

python/tests/test_python_actors.py

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,24 @@
1919
import unittest
2020
import unittest.mock
2121
from types import ModuleType
22-
from typing import cast
22+
from typing import cast, Tuple
2323

2424
import pytest
2525

2626
import torch
27+
from monarch._rust_bindings.monarch_hyperactor.actor import (
28+
PythonMessage,
29+
PythonMessageKind,
30+
)
31+
from monarch._rust_bindings.monarch_hyperactor.mailbox import (
32+
PortId,
33+
PortRef,
34+
UndeliverableMessageEnvelope,
35+
)
36+
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
2737
from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
2838

29-
from monarch._src.actor.actor_mesh import ActorMesh, Channel, Port
39+
from monarch._src.actor.actor_mesh import ActorMesh, Channel, context, Port
3040
from monarch._src.actor.future import Future
3141

3242
from monarch.actor import (
@@ -1127,3 +1137,59 @@ def test_mesh_len():
11271137
proc_mesh = local_proc_mesh(gpus=12).get()
11281138
s = proc_mesh.spawn("sync_actor", SyncActor).get()
11291139
assert 12 == len(s)
1140+
1141+
1142+
class UndeliverableMessageReceiver(Actor):
    """Test actor that queues reported undeliverable messages for later retrieval."""

    def __init__(self):
        self._messages = asyncio.Queue()

    @endpoint
    async def receive_undeliverable(
        self, sender: ActorId, dest: PortId, error_msg: str
    ) -> None:
        # Store the three fields together as one queue entry.
        report = (sender, dest, error_msg)
        await self._messages.put(report)

    @endpoint
    async def get_messages(self) -> Tuple[ActorId, PortId, str]:
        # Waits until an entry is available, then returns it.
        report = await self._messages.get()
        return report
1155+
1156+
1157+
class UndeliverableMessageSender(Actor):
    """Test actor that sends to a bogus port and reports the resulting bounce."""

    def __init__(self, receiver: UndeliverableMessageReceiver):
        self._receiver = receiver

    @endpoint
    def send_undeliverable(self) -> None:
        mailbox = context().actor_instance._mailbox
        # Address an actor named "bogus" in this actor's own world.
        bogus_actor = ActorId(
            world_name=mailbox.actor_id.world_name, rank=0, actor_name="bogus"
        )
        bogus_port = PortRef(PortId(actor_id=bogus_actor, port=1234))
        payload = PythonMessage(PythonMessageKind.Result(None), b"123")
        bogus_port.send(mailbox, payload)

    def _handle_undeliverable_message(
        self, message: UndeliverableMessageEnvelope
    ) -> bool:
        # Pass the envelope's fields on to the receiver actor and block until
        # that call completes.
        report = self._receiver.receive_undeliverable.call_one(
            message.sender(), message.dest(), message.error_msg()
        )
        report.get()
        # True tells the caller the message was handled here.
        return True
1183+
1184+
1185+
@pytest.mark.timeout(60)
async def test_undeliverable_message() -> None:
    """A message sent to a bogus port is routed to the sender's Python handler."""
    pm = proc_mesh(gpus=1)
    receiver = pm.spawn("undeliverable_receiver", UndeliverableMessageReceiver)
    sender_mesh = pm.spawn(
        "undeliverable_sender", UndeliverableMessageSender, receiver
    )
    sender_mesh.send_undeliverable.call().get()

    # The sender's handler forwards (sender, dest, error_msg) to the receiver.
    sender_id, dest, error_msg = receiver.get_messages.call_one().get()
    assert sender_id.actor_name == "undeliverable_sender"
    assert dest.actor_id.actor_name == "bogus"
    # NOTE(review): `error_msg()` yields the string "None" when no error is
    # recorded, so this check cannot fail; consider a stronger assertion.
    assert error_msg is not None
    pm.stop().get()

0 commit comments

Comments
 (0)