
Commit 5b1fa93

highker authored and facebook-github-bot committed
flush log upon ipython notebook cell exit (#816)
Summary: In an ipython notebook a cell can finish quickly while the work it launched keeps running in the background. Once the cell has exited, that background process no longer flushes its logs into the cell's output. This patch registers a flush handler that runs when each cell exits.

Reviewed By: ahmadsharif1

Differential Revision: D79982702
1 parent a8139fb commit 5b1fa93
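
For orientation, here is a minimal sketch of the IPython hook this change builds on, assuming an interactive IPython session; flush_pending_logs is a hypothetical stand-in for the real handler added in python/monarch/_src/actor/logging.py below.

# Minimal sketch of the post_run_cell mechanism, not the actual patch.
# flush_pending_logs is a hypothetical placeholder for the real flush handler.
from IPython import get_ipython


def flush_pending_logs(result) -> None:
    # The real handler drains buffered actor logs with a short timeout.
    print(f"cell finished, flushing logs (result: {result})")


ipython = get_ipython()
if ipython is not None:
    # post_run_cell fires after every cell finishes executing.
    ipython.events.register("post_run_cell", flush_pending_logs)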

File tree

4 files changed: +255 -20 lines changed


python/monarch/_src/actor/logging.py

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import gc
+import logging
+
+from typing import Callable
+
+from monarch._rust_bindings.monarch_extension.logging import LoggingMeshClient
+
+from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ProcMesh as HyProcMesh
+from monarch._src.actor.future import Future
+
+IN_IPYTHON = False
+try:
+    # Check if we are in an ipython environment
+    # pyre-ignore[21]
+    from IPython import get_ipython
+
+    # pyre-ignore[21]
+    from IPython.core.interactiveshell import ExecutionResult
+
+    IN_IPYTHON = get_ipython() is not None
+except ImportError:
+    pass
+
+
+class LoggingManager:
+    def __init__(self) -> None:
+        self._logging_mesh_client: LoggingMeshClient | None = None
+        self._ipython_flush_logs_handler: Callable[..., None] | None = None
+
+    async def init(self, proc_mesh: HyProcMesh) -> None:
+        if self._logging_mesh_client is not None:
+            return
+
+        self._logging_mesh_client = await LoggingMeshClient.spawn(proc_mesh=proc_mesh)
+        self._logging_mesh_client.set_mode(
+            stream_to_client=True,
+            aggregate_window_sec=3,
+            level=logging.INFO,
+        )
+
+        if IN_IPYTHON:
+            # In an ipython environment, a cell can end fast while threads keep running in the background.
+            # Flush all the ongoing logs proactively to avoid missing logs.
+            assert self._logging_mesh_client is not None
+            logging_client: LoggingMeshClient = self._logging_mesh_client
+            ipython = get_ipython()
+
+            # pyre-ignore[11]
+            def flush_logs(_: ExecutionResult) -> None:
+                try:
+                    Future(coro=logging_client.flush().spawn().task()).get(3)
+                except TimeoutError:
+                    # Don't hang if a failed proc mesh never comes back
+                    pass
+
+            # Force to recycle previous undropped proc_mesh.
+            # Otherwise, we may end up with unregistered dead callbacks.
+            gc.collect()
+
+            # Store the handler reference so we can unregister it later
+            self._ipython_flush_logs_handler = flush_logs
+            ipython.events.register("post_run_cell", flush_logs)
+
+    async def logging_option(
+        self,
+        stream_to_client: bool = True,
+        aggregate_window_sec: int | None = 3,
+        level: int = logging.INFO,
+    ) -> None:
+        if level < 0 or level > 255:
+            raise ValueError("Invalid logging level: {}".format(level))
+
+        assert self._logging_mesh_client is not None
+        self._logging_mesh_client.set_mode(
+            stream_to_client=stream_to_client,
+            aggregate_window_sec=aggregate_window_sec,
+            level=level,
+        )
+
+    def stop(self) -> None:
+        if self._ipython_flush_logs_handler is not None:
+            assert IN_IPYTHON
+            ipython = get_ipython()
+            assert ipython is not None
+            ipython.events.unregister("post_run_cell", self._ipython_flush_logs_handler)
+            self._ipython_flush_logs_handler = None
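
Not part of the diff: a rough usage sketch of the class above, based only on the methods it defines; hy_proc_mesh here is a placeholder for the low-level proc mesh handle that ProcMesh obtains during allocation.

# Sketch of LoggingManager's lifecycle, assuming a hy_proc_mesh handle obtained elsewhere.
import logging

from monarch._src.actor.logging import LoggingManager

manager = LoggingManager()


async def attach_logging(hy_proc_mesh) -> None:
    # Spawns the LoggingMeshClient once; in IPython this also registers
    # the post_run_cell flush handler.
    await manager.init(hy_proc_mesh)
    # Streaming behaviour can be adjusted later through logging_option.
    await manager.logging_option(
        stream_to_client=True, aggregate_window_sec=5, level=logging.INFO
    )


def detach_logging() -> None:
    # Unregisters the IPython handler so a stopped mesh is not flushed again.
    manager.stop()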

python/monarch/_src/actor/proc_mesh.py

Lines changed: 9 additions & 14 deletions
@@ -33,7 +33,6 @@
)
from weakref import WeakValueDictionary

-from monarch._rust_bindings.monarch_extension.logging import LoggingMeshClient
from monarch._rust_bindings.monarch_hyperactor.alloc import (  # @manual=//monarch/monarch_extension:monarch_extension
    Alloc,
    AllocConstraints,
@@ -67,10 +66,12 @@

from monarch._src.actor.endpoint import endpoint
from monarch._src.actor.future import DeprecatedNotAFuture, Future
+from monarch._src.actor.logging import LoggingManager
from monarch._src.actor.shape import MeshTrait
from monarch.tools.config import Workspace
from monarch.tools.utils import conda as conda_utils

+
HAS_TENSOR_ENGINE = False
try:
    # Torch is needed for tensor engine
@@ -191,7 +192,7 @@ def __init__(
        # of whether this is a slice of a real proc_meshg
        self._slice = False
        self._code_sync_client: Optional[CodeSyncMeshClient] = None
-        self._logging_mesh_client: Optional[LoggingMeshClient] = None
+        self._logging_manager: LoggingManager = LoggingManager()
        self._maybe_device_mesh: Optional["DeviceMesh"] = _device_mesh
        self._stopped = False
        self._controller_controller: Optional["_ControllerController"] = None
@@ -311,14 +312,7 @@ async def task(
        ) -> HyProcMesh:
            hy_proc_mesh = await hy_proc_mesh_task

-            pm._logging_mesh_client = await LoggingMeshClient.spawn(
-                proc_mesh=hy_proc_mesh
-            )
-            pm._logging_mesh_client.set_mode(
-                stream_to_client=True,
-                aggregate_window_sec=3,
-                level=logging.INFO,
-            )
+            await pm._logging_manager.init(hy_proc_mesh)

            if setup_actor is not None:
                await setup_actor.setup.call()
@@ -482,12 +476,9 @@ async def logging_option(
        Returns:
            None
        """
-        if level < 0 or level > 255:
-            raise ValueError("Invalid logging level: {}".format(level))
        await self.initialized

-        assert self._logging_mesh_client is not None
-        self._logging_mesh_client.set_mode(
+        await self._logging_manager.logging_option(
            stream_to_client=stream_to_client,
            aggregate_window_sec=aggregate_window_sec,
            level=level,
@@ -499,6 +490,8 @@ async def __aenter__(self) -> "ProcMesh":
        return self

    def stop(self) -> Future[None]:
+        self._logging_manager.stop()
+
        async def _stop_nonblocking() -> None:
            await (await self._proc_mesh).stop_nonblocking()
            self._stopped = True
@@ -516,6 +509,8 @@ async def __aexit__(
    # Finalizer to check if the proc mesh was closed properly.
    def __del__(self) -> None:
        if not self._stopped:
+            self._logging_manager.stop()
+
            warnings.warn(
                f"unstopped ProcMesh {self!r}",
                ResourceWarning,
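
The effect of the stop()/__del__ changes, sketched below under the assumption that pm is a ProcMesh created elsewhere (e.g. await proc_mesh(gpus=2) as in the tests); this snippet is illustrative, not part of the commit.

# Illustrative only: pm is assumed to be an already-created ProcMesh.
def shutdown(pm) -> None:
    # stop() returns a Future[None]; .get() blocks until the mesh has stopped.
    # With this change it first calls pm._logging_manager.stop(), unregistering
    # the IPython post_run_cell handler so a dead mesh is never flushed again.
    pm.stop().get()
    # Meshes that are garbage collected without an explicit stop() get the same
    # cleanup from __del__, alongside the existing ResourceWarning.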

python/tests/python_actor_test_binary.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ async def _flush_logs() -> None:
    await am.print.call("has print streaming")

    # TODO: remove this completely once we hook the flush logic upon dropping device_mesh
-    log_mesh = pm._logging_mesh_client
+    log_mesh = pm._logging_manager._logging_mesh_client
    assert log_mesh is not None
    Future(coro=log_mesh.flush().spawn().task()).get()

python/tests/test_python_actors.py

Lines changed: 151 additions & 5 deletions
@@ -6,6 +6,7 @@

# pyre-unsafe
import asyncio
+import gc
import importlib.resources
import logging
import operator
@@ -586,7 +587,7 @@ async def test_actor_log_streaming() -> None:
    await am.log.call("has log streaming as level matched")

    # TODO: remove this completely once we hook the flush logic upon dropping device_mesh
-    log_mesh = pm._logging_mesh_client
+    log_mesh = pm._logging_manager._logging_mesh_client
    assert log_mesh is not None
    Future(coro=log_mesh.flush().spawn().task()).get()
@@ -705,7 +706,7 @@ async def test_logging_option_defaults() -> None:
    await am.log.call("log streaming")

    # TODO: remove this completely once we hook the flush logic upon dropping device_mesh
-    log_mesh = pm._logging_mesh_client
+    log_mesh = pm._logging_manager._logging_mesh_client
    assert log_mesh is not None
    Future(coro=log_mesh.flush().spawn().task()).get()
@@ -760,6 +761,151 @@
    pass


+# oss_skip: pytest keeps complaining about mocking get_ipython module
+@pytest.mark.oss_skip
+@pytest.mark.timeout(180)
+async def test_flush_logs_ipython() -> None:
+    """Test that logs are flushed when get_ipython is available and post_run_cell event is triggered."""
+    # Save original file descriptors
+    original_stdout_fd = os.dup(1)  # stdout
+
+    try:
+        # Create temporary files to capture output
+        with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file:
+            stdout_path = stdout_file.name
+
+            # Redirect file descriptors to our temp files
+            os.dup2(stdout_file.fileno(), 1)
+
+            # Also redirect Python's sys.stdout
+            original_sys_stdout = sys.stdout
+            sys.stdout = stdout_file
+
+            try:
+                # Mock IPython environment
+                class MockExecutionResult:
+                    pass
+
+                class MockEvents:
+                    def __init__(self):
+                        self.callbacks = {}
+                        self.registers = 0
+                        self.unregisters = 0
+
+                    def register(self, event_name, callback):
+                        if event_name not in self.callbacks:
+                            self.callbacks[event_name] = []
+                        self.callbacks[event_name].append(callback)
+                        self.registers += 1
+
+                    def unregister(self, event_name, callback):
+                        if event_name not in self.callbacks:
+                            raise ValueError(f"Event {event_name} not registered")
+                        assert callback in self.callbacks[event_name]
+                        self.callbacks[event_name].remove(callback)
+                        self.unregisters += 1
+
+                    def trigger(self, event_name, *args, **kwargs):
+                        if event_name in self.callbacks:
+                            for callback in self.callbacks[event_name]:
+                                callback(*args, **kwargs)
+
+                class MockIPython:
+                    def __init__(self):
+                        self.events = MockEvents()
+
+                mock_ipython = MockIPython()
+
+                with unittest.mock.patch(
+                    "monarch._src.actor.logging.get_ipython",
+                    lambda: mock_ipython,
+                ), unittest.mock.patch("monarch._src.actor.logging.IN_IPYTHON", True):
+                    # Make sure we can register and unregister callbacks
+                    for i in range(3):
+                        pm1 = await proc_mesh(gpus=2)
+                        pm2 = await proc_mesh(gpus=2)
+                        am1 = await pm1.spawn("printer", Printer)
+                        am2 = await pm2.spawn("printer", Printer)
+
+                        # Set aggregation window to ensure logs are buffered
+                        await pm1.logging_option(
+                            stream_to_client=True, aggregate_window_sec=600
+                        )
+                        await pm2.logging_option(
+                            stream_to_client=True, aggregate_window_sec=600
+                        )
+                        assert mock_ipython.events.unregisters == 2 * i
+                        # TODO: remove `1 +` from attaching controller_controller
+                        assert mock_ipython.events.registers == 1 + 2 * (i + 1)
+                        await asyncio.sleep(1)
+
+                        # Generate some logs that will be aggregated
+                        for _ in range(5):
+                            await am1.print.call("ipython1 test log")
+                            await am2.print.call("ipython2 test log")
+
+                        # Trigger the post_run_cell event which should flush logs
+                        mock_ipython.events.trigger(
+                            "post_run_cell", MockExecutionResult()
+                        )
+
+                    # Flush all outputs
+                    stdout_file.flush()
+                    os.fsync(stdout_file.fileno())
+
+                    gc.collect()
+
+                    # TODO: this should be 6 without attaching controller_controller
+                    assert mock_ipython.events.registers == 7
+                    # There are many objects still taking refs
+                    assert mock_ipython.events.unregisters == 4
+                    # TODO: same, this should be 2
+                    assert len(mock_ipython.events.callbacks["post_run_cell"]) == 3
+            finally:
+                # Restore Python's sys.stdout
+                sys.stdout = original_sys_stdout
+
+                # Restore original file descriptors
+                os.dup2(original_stdout_fd, 1)
+
+        # Read the captured output
+        with open(stdout_path, "r") as f:
+            stdout_content = f.read()
+
+        # TODO: there are quite a lot of code dups and boilerplate; make them contextmanager utils
+
+        # Clean up temp files
+        os.unlink(stdout_path)
+
+        # Verify that logs were flushed when the post_run_cell event was triggered
+        # We should see the aggregated logs in the output
+        assert (
+            len(
+                re.findall(
+                    r"\[10 similar log lines\].*ipython1 test log", stdout_content
                )
            )
            == 3
        ), stdout_content
+
+        assert (
+            len(
+                re.findall(
+                    r"\[10 similar log lines\].*ipython2 test log", stdout_content
+                )
+            )
+            == 3
+        ), stdout_content
+
+    finally:
+        # Ensure file descriptors are restored even if something goes wrong
+        try:
+            os.dup2(original_stdout_fd, 1)
+            os.close(original_stdout_fd)
+        except OSError:
+            pass
+
+
# oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited
@pytest.mark.oss_skip
async def test_flush_logs_fast_exit() -> None:
@@ -834,7 +980,7 @@ async def test_flush_on_disable_aggregation() -> None:
    await am.print.call("single log line")

    # TODO: remove this completely once we hook the flush logic upon dropping device_mesh
-    log_mesh = pm._logging_mesh_client
+    log_mesh = pm._logging_manager._logging_mesh_client
    assert log_mesh is not None
    Future(coro=log_mesh.flush().spawn().task()).get()
@@ -894,7 +1040,7 @@ async def test_multiple_ongoing_flushes_no_deadlock() -> None:
    for _ in range(10):
        await am.print.call("aggregated log line")

-    log_mesh = pm._logging_mesh_client
+    log_mesh = pm._logging_manager._logging_mesh_client
    assert log_mesh is not None
    futures = []
    for _ in range(5):
@@ -947,7 +1093,7 @@ async def test_adjust_aggregation_window() -> None:
    await am.print.call("second batch of logs")

    # TODO: remove this completely once we hook the flush logic upon dropping device_mesh
-    log_mesh = pm._logging_mesh_client
+    log_mesh = pm._logging_manager._logging_mesh_client
    assert log_mesh is not None
    Future(coro=log_mesh.flush().spawn().task()).get()
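
One way to address the TODO in test_flush_logs_ipython about the stdout-capture boilerplate; this helper is hypothetical and not part of the commit.

# Hypothetical helper for the capture boilerplate the TODO above mentions.
import contextlib
import os
import sys
import tempfile
from typing import Iterator, List


@contextlib.contextmanager
def captured_stdout() -> Iterator[List[str]]:
    """Redirect fd 1 and sys.stdout to a temp file; after the block exits,
    the yielded list holds the captured text as its only element."""
    captured: List[str] = []
    original_fd = os.dup(1)
    original_sys_stdout = sys.stdout
    with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f:
        path = f.name
        os.dup2(f.fileno(), 1)
        sys.stdout = f
        try:
            yield captured
            f.flush()
            os.fsync(f.fileno())
        finally:
            sys.stdout = original_sys_stdout
            os.dup2(original_fd, 1)
            os.close(original_fd)
    with open(path, "r") as f:
        captured.append(f.read())
    os.unlink(path)

A test could then wrap its body in "with captured_stdout() as captured:" and run its regex assertions against captured[0] after the block.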
