diff --git a/hyperactor_mesh/src/logging.rs b/hyperactor_mesh/src/logging.rs index 496f58144..950e0af20 100644 --- a/hyperactor_mesh/src/logging.rs +++ b/hyperactor_mesh/src/logging.rs @@ -11,6 +11,7 @@ use std::fmt; use std::path::Path; use std::path::PathBuf; use std::pin::Pin; +use std::sync::Arc; use std::task::Context as TaskContext; use std::task::Poll; use std::time::Duration; @@ -22,12 +23,15 @@ use chrono::DateTime; use chrono::Local; use hyperactor::Actor; use hyperactor::ActorRef; +use hyperactor::Bind; use hyperactor::Context; use hyperactor::HandleClient; use hyperactor::Handler; use hyperactor::Instance; use hyperactor::Named; +use hyperactor::OncePortRef; use hyperactor::RefClient; +use hyperactor::Unbind; use hyperactor::channel; use hyperactor::channel::ChannelAddr; use hyperactor::channel::ChannelRx; @@ -39,14 +43,12 @@ use hyperactor::channel::TxStatus; use hyperactor::clock::Clock; use hyperactor::clock::RealClock; use hyperactor::data::Serialized; -use hyperactor::message::Bind; -use hyperactor::message::Bindings; -use hyperactor::message::Unbind; use hyperactor_telemetry::env; use hyperactor_telemetry::log_file_path; use serde::Deserialize; use serde::Serialize; use tokio::io; +use tokio::sync::Mutex; use tokio::sync::watch::Receiver; use crate::bootstrap::BOOTSTRAP_LOG_CHANNEL; @@ -260,7 +262,11 @@ pub enum LogMessage { }, /// Flush the log - Flush {}, + Flush { + /// Indicate if the current flush is synced or non-synced. + /// If synced, a version number is available. Otherwise, none. + sync_version: Option, + }, } /// Messages that can be sent to the LogClient locally. @@ -279,6 +285,16 @@ pub enum LogClientMessage { /// The time window in seconds to aggregate logs. If None, aggregation is disabled. aggregate_window_sec: Option, }, + + /// Synchronously flush all the logs from all the procs. This is for the client to call. + StartSyncFlush { + /// Expect this many procs to ack the flush message. 
+ expected_procs: usize, + /// Return once we have received the acks from all the procs + reply: OncePortRef<()>, + /// Return to the caller the current flush version + version: OncePortRef, + }, } /// Trait for sending logs @@ -352,7 +368,7 @@ impl LogSender for LocalLogSender { // send will make sure message is delivered if TxStatus::Active == *self.status.borrow() { // Do not use tx.send, it will block the allocator as the child process state is unknown. - self.tx.post(LogMessage::Flush {}); + self.tx.post(LogMessage::Flush { sync_version: None }); } else { tracing::debug!( "log sender {} is not active, skip sending flush message", @@ -558,7 +574,9 @@ impl Named, Handler, HandleClient, - RefClient + RefClient, + Bind, + Unbind )] pub enum LogForwardMessage { /// Receive the log from the parent process and forward ti to the client. @@ -566,18 +584,9 @@ pub enum LogForwardMessage { /// If to stream the log back to the client. SetMode { stream_to_client: bool }, -} -impl Bind for LogForwardMessage { - fn bind(&mut self, _bindings: &mut Bindings) -> anyhow::Result<()> { - Ok(()) - } -} - -impl Unbind for LogForwardMessage { - fn unbind(&self, _bindings: &mut Bindings) -> anyhow::Result<()> { - Ok(()) - } + /// Flush the log with a version number. + ForceSyncFlush { version: u64 }, } /// A log forwarder that receives the log from its parent process and forward it back to the client @@ -588,6 +597,8 @@ impl Unbind for LogForwardMessage { )] pub struct LogForwardActor { rx: ChannelRx, + flush_tx: Arc>>, + next_flush_deadline: SystemTime, logging_client_ref: ActorRef, stream_to_client: bool, } @@ -630,8 +641,15 @@ impl Actor for LogForwardActor { .1 } }; + + // Dial the same channel to send flush message to drain the log queue. 
+ let flush_tx = Arc::new(Mutex::new(channel::dial::(log_channel)?)); + let now = RealClock.system_time_now(); + Ok(Self { rx, + flush_tx, + next_flush_deadline: now, logging_client_ref, stream_to_client: true, }) @@ -639,6 +657,13 @@ impl Actor for LogForwardActor { async fn init(&mut self, this: &Instance) -> Result<(), anyhow::Error> { this.self_message_with_delay(LogForwardMessage::Forward {}, Duration::from_secs(0))?; + + // Make sure we start the flush loop periodically so the log channel will not deadlock. + self.flush_tx + .lock() + .await + .send(LogMessage::Flush { sync_version: None }) + .await?; Ok(()) } } @@ -647,17 +672,48 @@ impl Actor for LogForwardActor { #[hyperactor::forward(LogForwardMessage)] impl LogForwardMessageHandler for LogForwardActor { async fn forward(&mut self, ctx: &Context) -> Result<(), anyhow::Error> { - if let Ok(LogMessage::Log { - hostname, - pid, - output_target, - payload, - }) = self.rx.recv().await - { - if self.stream_to_client { - self.logging_client_ref - .log(ctx, hostname, pid, output_target, payload) - .await?; + match self.rx.recv().await { + Ok(LogMessage::Flush { sync_version }) => { + let now = RealClock.system_time_now(); + match sync_version { + None => { + // Schedule another flush to keep the log channel from deadlocking. 
+ let delay = Duration::from_secs(1); + if now >= self.next_flush_deadline { + self.next_flush_deadline = now + delay; + let flush_tx = self.flush_tx.clone(); + tokio::spawn(async move { + RealClock.sleep(delay).await; + if let Err(e) = flush_tx + .lock() + .await + .send(LogMessage::Flush { sync_version: None }) + .await + { + tracing::error!("failed to send flush message: {}", e); + } + }); + } + } + version => { + self.logging_client_ref.flush(ctx, version).await?; + } + } + } + Ok(LogMessage::Log { + hostname, + pid, + output_target, + payload, + }) => { + if self.stream_to_client { + self.logging_client_ref + .log(ctx, hostname, pid, output_target, payload) + .await?; + } + } + Err(e) => { + return Err(e.into()); } } @@ -675,6 +731,21 @@ impl LogForwardMessageHandler for LogForwardActor { self.stream_to_client = stream_to_client; Ok(()) } + + async fn force_sync_flush( + &mut self, + _cx: &Context, + version: u64, + ) -> Result<(), anyhow::Error> { + self.flush_tx + .lock() + .await + .send(LogMessage::Flush { + sync_version: Some(version), + }) + .await + .map_err(anyhow::Error::from) + } } /// Deserialize a serialized message and split it into UTF-8 lines @@ -707,6 +778,11 @@ pub struct LogClientActor { aggregators: HashMap, last_flush_time: SystemTime, next_flush_deadline: Option, + + // For flush sync barrier + current_flush_version: u64, + current_flush_port: Option>, + current_unflushed_procs: usize, } impl LogClientActor { @@ -736,6 +812,12 @@ impl LogClientActor { OutputTarget::Stderr => eprintln!("{}", message), } } + + fn flush_internal(&mut self) { + self.print_aggregators(); + self.last_flush_time = RealClock.system_time_now(); + self.next_flush_deadline = None; + } } #[async_trait] @@ -754,6 +836,9 @@ impl Actor for LogClientActor { aggregators, last_flush_time: RealClock.system_time_now(), next_flush_deadline: None, + current_flush_version: 0, + current_flush_port: None, + current_unflushed_procs: 0, }) } } @@ -805,20 +890,26 @@ impl 
LogMessageHandler for LogClientActor { let new_deadline = self.last_flush_time + Duration::from_secs(window); let now = RealClock.system_time_now(); if new_deadline <= now { - self.flush(cx).await?; + self.flush_internal(); } else { let delay = new_deadline.duration_since(now)?; match self.next_flush_deadline { None => { self.next_flush_deadline = Some(new_deadline); - cx.self_message_with_delay(LogMessage::Flush {}, delay)?; + cx.self_message_with_delay( + LogMessage::Flush { sync_version: None }, + delay, + )?; } Some(deadline) => { // Some early log lines have alrady triggered the flush. if new_deadline < deadline { // This can happen if the user has adjusted the aggregation window. self.next_flush_deadline = Some(new_deadline); - cx.self_message_with_delay(LogMessage::Flush {}, delay)?; + cx.self_message_with_delay( + LogMessage::Flush { sync_version: None }, + delay, + )?; } } } @@ -829,10 +920,45 @@ impl LogMessageHandler for LogClientActor { Ok(()) } - async fn flush(&mut self, _cx: &Context) -> Result<(), anyhow::Error> { - self.print_aggregators(); - self.last_flush_time = RealClock.system_time_now(); - self.next_flush_deadline = None; + async fn flush( + &mut self, + cx: &Context, + sync_version: Option, + ) -> Result<(), anyhow::Error> { + match sync_version { + None => { + self.flush_internal(); + } + Some(version) => { + if version != self.current_flush_version { + tracing::error!( + "found mismatched flush versions: got {}, expect {}; this can happen if some previous flush didn't finish fully", + version, + self.current_flush_version + ); + return Ok(()); + } + + if self.current_unflushed_procs == 0 || self.current_flush_port.is_none() { + // This is a serious issue; it's better to error out. 
+ anyhow::bail!("found no ongoing flush request"); + } + self.current_unflushed_procs -= 1; + + tracing::debug!( + "ack sync flush: version {}; remaining procs: {}", + self.current_flush_version, + self.current_unflushed_procs + ); + + if self.current_unflushed_procs == 0 { + self.flush_internal(); + let reply = self.current_flush_port.take().unwrap(); + self.current_flush_port = None; + reply.send(cx, ()).map_err(anyhow::Error::from)?; + } + } + } Ok(()) } @@ -853,6 +979,34 @@ impl LogClientMessageHandler for LogClientActor { self.aggregate_window_sec = aggregate_window_sec; Ok(()) } + + async fn start_sync_flush( + &mut self, + cx: &Context, + expected_procs_flushed: usize, + reply: OncePortRef<()>, + version: OncePortRef, + ) -> Result<(), anyhow::Error> { + if self.current_unflushed_procs > 0 || self.current_flush_port.is_some() { + tracing::warn!( + "found unfinished ongoing flush: version {}; {} unflushed procs", + self.current_flush_version, + self.current_unflushed_procs, + ); + } + + self.current_flush_version += 1; + tracing::debug!( + "start sync flush with version {}", + self.current_flush_version + ); + self.current_flush_port = Some(reply.clone()); + self.current_unflushed_procs = expected_procs_flushed; + version + .send(cx, self.current_flush_version) + .map_err(anyhow::Error::from)?; + Ok(()) + } } #[cfg(test)] diff --git a/monarch_extension/src/logging.rs b/monarch_extension/src/logging.rs index 9ff8b208b..51820c019 100644 --- a/monarch_extension/src/logging.rs +++ b/monarch_extension/src/logging.rs @@ -38,6 +38,44 @@ pub struct LoggingMeshClient { client_actor: ActorHandle, } +impl LoggingMeshClient { + async fn flush_internal( + client_actor: ActorHandle, + forwarder_mesh: SharedCell>, + ) -> Result<(), anyhow::Error> { + let forwarder_inner_mesh = forwarder_mesh.borrow().map_err(anyhow::Error::msg)?; + let (reply_tx, reply_rx) = forwarder_inner_mesh + .proc_mesh() + .client() + .open_once_port::<()>(); + let (version_tx, version_rx) = 
forwarder_inner_mesh + .proc_mesh() + .client() + .open_once_port::(); + + // First initialize a sync flush. + client_actor.send(LogClientMessage::StartSyncFlush { + expected_procs: forwarder_inner_mesh.proc_mesh().shape().slice().len(), + reply: reply_tx.bind(), + version: version_tx.bind(), + })?; + + let version = version_rx.recv().await?; + + // Then ask all the flushers to ask the log forwarders to sync flush + forwarder_inner_mesh.cast( + forwarder_inner_mesh.proc_mesh().client(), + Selection::True, + LogForwardMessage::ForceSyncFlush { version }, + )?; + + // Finally the forwarder will send sync point back to the client, flush, and return. + reply_rx.recv().await?; + + Ok(()) + } +} + #[pymethods] impl LoggingMeshClient { #[staticmethod] @@ -97,6 +135,18 @@ impl LoggingMeshClient { Ok(()) } + + // A sync flush mechanism for the client to make sure all the stdout/stderr are streamed back and flushed. + fn flush(&self) -> PyResult { + let forwarder_mesh = self.forwarder_mesh.clone(); + let client_actor = self.client_actor.clone(); + + PyPythonTask::new(async move { + Self::flush_internal(client_actor, forwarder_mesh) + .await + .map_err(|e| PyErr::new::(e.to_string())) + }) + } } impl Drop for LoggingMeshClient { diff --git a/python/monarch/_rust_bindings/monarch_extension/logging.pyi b/python/monarch/_rust_bindings/monarch_extension/logging.pyi index 5d6f11960..fa3d732af 100644 --- a/python/monarch/_rust_bindings/monarch_extension/logging.pyi +++ b/python/monarch/_rust_bindings/monarch_extension/logging.pyi @@ -21,3 +21,4 @@ class LoggingMeshClient: def set_mode( self, stream_to_client: bool, aggregate_window_sec: int | None, level: int ) -> None: ... + def flush(self) -> PythonTask[None]: ... 
diff --git a/python/tests/python_actor_test_binary.py b/python/tests/python_actor_test_binary.py index 12a10b0f5..9cff72087 100644 --- a/python/tests/python_actor_test_binary.py +++ b/python/tests/python_actor_test_binary.py @@ -10,6 +10,7 @@ import logging import click +from monarch._src.actor.future import Future from monarch.actor import Actor, endpoint, proc_mesh @@ -40,8 +41,10 @@ async def _flush_logs() -> None: for _ in range(5): await am.print.call("has print streaming") - # TODO: will soon be removed by D80051803 - await asyncio.sleep(2) + # TODO: remove this completely once we hook the flush logic upon dropping device_mesh + log_mesh = pm._logging_mesh_client + assert log_mesh is not None + Future(coro=log_mesh.flush().spawn().task()).get() @main.command("flush-logs") diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index 5d9ce8d98..95a05f23d 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -27,6 +27,7 @@ from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask from monarch._src.actor.actor_mesh import ActorMesh, Channel, Port +from monarch._src.actor.future import Future from monarch.actor import ( Accumulator, @@ -559,8 +560,10 @@ async def test_actor_log_streaming() -> None: await am.print.call("has print streaming too") await am.log.call("has log streaming as level matched") - # Give it some time to reflect and aggregate - await asyncio.sleep(1) + # TODO: remove this completely once we hook the flush logic upon dropping device_mesh + log_mesh = pm._logging_mesh_client + assert log_mesh is not None + Future(coro=log_mesh.flush().spawn().task()).get() # Flush all outputs stdout_file.flush() @@ -675,7 +678,11 @@ async def test_logging_option_defaults() -> None: for _ in range(5): await am.print.call("print streaming") await am.log.call("log streaming") - await asyncio.sleep(4) + + # TODO: remove this completely once we hook the flush logic upon dropping 
device_mesh + log_mesh = pm._logging_mesh_client + assert log_mesh is not None + Future(coro=log_mesh.flush().spawn().task()).get() # Flush all outputs stdout_file.flush() @@ -801,8 +808,10 @@ async def test_flush_on_disable_aggregation() -> None: for _ in range(5): await am.print.call("single log line") - # Wait a bit to ensure flush completes - await asyncio.sleep(1) + # TODO: remove this completely once we hook the flush logic upon dropping device_mesh + log_mesh = pm._logging_mesh_client + assert log_mesh is not None + Future(coro=log_mesh.flush().spawn().task()).get() # Flush all outputs stdout_file.flush() @@ -846,6 +855,32 @@ async def test_flush_on_disable_aggregation() -> None: pass +@pytest.mark.timeout(120) +async def test_multiple_ongoing_flushes_no_deadlock() -> None: + """ + The goal is to make sure when a user sends multiple sync flushes, we are not deadlocked. + Because now a flush call is purely sync, it is very easy to get into a deadlock. + So we assert the last flush call will not get into such a state. + """ + pm = await proc_mesh(gpus=4) + am = await pm.spawn("printer", Printer) + + # Generate some logs that will be aggregated but not flushed immediately + for _ in range(10): + await am.print.call("aggregated log line") + + log_mesh = pm._logging_mesh_client + assert log_mesh is not None + futures = [] + for _ in range(5): + # FIXME: the order of futures doesn't necessarily mean the order of flushes due to the async nature. + await asyncio.sleep(0.1) + futures.append(Future(coro=log_mesh.flush().spawn().task())) + + # The last flush should not block + futures[-1].get() + + @pytest.mark.timeout(60) async def test_adjust_aggregation_window() -> None: """Test that the flush deadline is updated when the aggregation window is adjusted. 
@@ -886,8 +921,10 @@ async def test_adjust_aggregation_window() -> None: for _ in range(3): await am.print.call("second batch of logs") - # Wait just enough time for the shorter window to trigger a flush - await asyncio.sleep(1) + # TODO: remove this completely once we hook the flush logic upon dropping device_mesh + log_mesh = pm._logging_mesh_client + assert log_mesh is not None + Future(coro=log_mesh.flush().spawn().task()).get() # Flush all outputs stdout_file.flush()