Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions python/monarch/_src/actor/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import gc
import logging

from typing import Callable

from monarch._rust_bindings.monarch_extension.logging import LoggingMeshClient

from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ProcMesh as HyProcMesh
from monarch._src.actor.future import Future

# Module-level flag: True only when this process is running inside an IPython
# shell/notebook. LoggingManager uses it to decide whether to hook cell-end
# log flushing. Detection happens once at import time.
IN_IPYTHON = False
try:
    # Check if we are in ipython environment
    # pyre-ignore[21]
    from IPython import get_ipython

    # pyre-ignore[21]
    from IPython.core.interactiveshell import ExecutionResult

    # get_ipython() returns None when IPython is importable but we are not
    # actually running under an interactive shell.
    IN_IPYTHON = get_ipython() is not None
except ImportError:
    # IPython not installed: plain-Python environment, no cell hooks needed.
    pass


class LoggingManager:
    """Owns the logging mesh client for a single ProcMesh.

    Responsibilities:
      * lazily spawn the ``LoggingMeshClient`` for a proc mesh (``init``),
      * reconfigure streaming/aggregation at runtime (``logging_option``),
      * when running under IPython, register a ``post_run_cell`` hook that
        flushes buffered logs after every executed cell, and unregister it
        again on ``stop``.
    """

    def __init__(self) -> None:
        # Spawned lazily by init(); stays None until then.
        self._logging_mesh_client: LoggingMeshClient | None = None
        # Reference to the callback registered with IPython so stop() can
        # unregister exactly the handler this instance added.
        self._ipython_flush_logs_handler: Callable[..., None] | None = None

    async def init(self, proc_mesh: HyProcMesh) -> None:
        """Spawn the logging client for ``proc_mesh`` and enable streaming.

        Idempotent: once a client exists, subsequent calls are no-ops.
        Defaults to streaming to the client with a 3-second aggregation
        window at INFO level.
        """
        if self._logging_mesh_client is not None:
            return

        self._logging_mesh_client = await LoggingMeshClient.spawn(proc_mesh=proc_mesh)
        self._logging_mesh_client.set_mode(
            stream_to_client=True,
            aggregate_window_sec=3,
            level=logging.INFO,
        )

        if IN_IPYTHON:
            # For ipython environment, a cell can end fast with threads running
            # in background. Flush all the ongoing logs proactively to avoid
            # missing logs.
            assert self._logging_mesh_client is not None
            # Bind to a local so the closure below captures the non-None
            # client directly rather than going through self.
            logging_client: LoggingMeshClient = self._logging_mesh_client
            ipython = get_ipython()

            # pyre-ignore[11]
            def flush_logs(_: ExecutionResult) -> None:
                try:
                    # Bounded wait (3s): a dead or unresponsive proc mesh must
                    # not hang the notebook at the end of every cell.
                    Future(coro=logging_client.flush().spawn().task()).get(3)
                except TimeoutError:
                    # Best effort: we need to prevent failed proc meshes not
                    # coming back from blocking the cell.
                    pass

            # Force to recycle previous undropped proc_mesh.
            # Otherwise, we may end up with unregistered dead callbacks.
            gc.collect()

            # Store the handler reference so we can unregister it later
            self._ipython_flush_logs_handler = flush_logs
            ipython.events.register("post_run_cell", flush_logs)

    async def logging_option(
        self,
        stream_to_client: bool = True,
        aggregate_window_sec: int | None = 3,
        level: int = logging.INFO,
    ) -> None:
        """Reconfigure log streaming for the mesh.

        Args:
            stream_to_client: forward remote process logs to this client.
            aggregate_window_sec: window for collapsing similar log lines;
                None disables aggregation.
            level: logging level; must fit in a u8 (0..255).

        Raises:
            ValueError: if ``level`` is outside the 0..255 range.
        """
        if level < 0 or level > 255:
            raise ValueError(f"Invalid logging level: {level}")

        # init() must have completed before options can be adjusted.
        assert self._logging_mesh_client is not None
        self._logging_mesh_client.set_mode(
            stream_to_client=stream_to_client,
            aggregate_window_sec=aggregate_window_sec,
            level=level,
        )

    def stop(self) -> None:
        """Unregister the IPython flush hook, if one was installed.

        Safe to call multiple times; after the first call the handler
        reference is cleared so later calls are no-ops.
        """
        if self._ipython_flush_logs_handler is not None:
            # A handler is only ever registered from the IN_IPYTHON branch
            # of init(), so these invariants must hold here.
            assert IN_IPYTHON
            ipython = get_ipython()
            assert ipython is not None
            ipython.events.unregister("post_run_cell", self._ipython_flush_logs_handler)
            self._ipython_flush_logs_handler = None
23 changes: 9 additions & 14 deletions python/monarch/_src/actor/proc_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
)
from weakref import WeakValueDictionary

from monarch._rust_bindings.monarch_extension.logging import LoggingMeshClient
from monarch._rust_bindings.monarch_hyperactor.alloc import ( # @manual=//monarch/monarch_extension:monarch_extension
Alloc,
AllocConstraints,
Expand Down Expand Up @@ -67,10 +66,12 @@

from monarch._src.actor.endpoint import endpoint
from monarch._src.actor.future import DeprecatedNotAFuture, Future
from monarch._src.actor.logging import LoggingManager
from monarch._src.actor.shape import MeshTrait
from monarch.tools.config import Workspace
from monarch.tools.utils import conda as conda_utils


HAS_TENSOR_ENGINE = False
try:
# Torch is needed for tensor engine
Expand Down Expand Up @@ -191,7 +192,7 @@ def __init__(
# of whether this is a slice of a real proc_meshg
self._slice = False
self._code_sync_client: Optional[CodeSyncMeshClient] = None
self._logging_mesh_client: Optional[LoggingMeshClient] = None
self._logging_manager: LoggingManager = LoggingManager()
self._maybe_device_mesh: Optional["DeviceMesh"] = _device_mesh
self._stopped = False
self._controller_controller: Optional["_ControllerController"] = None
Expand Down Expand Up @@ -311,14 +312,7 @@ async def task(
) -> HyProcMesh:
hy_proc_mesh = await hy_proc_mesh_task

pm._logging_mesh_client = await LoggingMeshClient.spawn(
proc_mesh=hy_proc_mesh
)
pm._logging_mesh_client.set_mode(
stream_to_client=True,
aggregate_window_sec=3,
level=logging.INFO,
)
await pm._logging_manager.init(hy_proc_mesh)

if setup_actor is not None:
await setup_actor.setup.call()
Expand Down Expand Up @@ -482,12 +476,9 @@ async def logging_option(
Returns:
None
"""
if level < 0 or level > 255:
raise ValueError("Invalid logging level: {}".format(level))
await self.initialized

assert self._logging_mesh_client is not None
self._logging_mesh_client.set_mode(
await self._logging_manager.logging_option(
stream_to_client=stream_to_client,
aggregate_window_sec=aggregate_window_sec,
level=level,
Expand All @@ -499,6 +490,8 @@ async def __aenter__(self) -> "ProcMesh":
return self

def stop(self) -> Future[None]:
self._logging_manager.stop()

async def _stop_nonblocking() -> None:
await (await self._proc_mesh).stop_nonblocking()
self._stopped = True
Expand All @@ -516,6 +509,8 @@ async def __aexit__(
# Finalizer to check if the proc mesh was closed properly.
def __del__(self) -> None:
if not self._stopped:
self._logging_manager.stop()

warnings.warn(
f"unstopped ProcMesh {self!r}",
ResourceWarning,
Expand Down
2 changes: 1 addition & 1 deletion python/tests/python_actor_test_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def _flush_logs() -> None:
await am.print.call("has print streaming")

# TODO: remove this completely once we hook the flush logic upon dropping device_mesh
log_mesh = pm._logging_mesh_client
log_mesh = pm._logging_manager._logging_mesh_client
assert log_mesh is not None
Future(coro=log_mesh.flush().spawn().task()).get()

Expand Down
156 changes: 151 additions & 5 deletions python/tests/test_python_actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

# pyre-unsafe
import asyncio
import gc
import importlib.resources
import logging
import operator
Expand Down Expand Up @@ -586,7 +587,7 @@ async def test_actor_log_streaming() -> None:
await am.log.call("has log streaming as level matched")

# TODO: remove this completely once we hook the flush logic upon dropping device_mesh
log_mesh = pm._logging_mesh_client
log_mesh = pm._logging_manager._logging_mesh_client
assert log_mesh is not None
Future(coro=log_mesh.flush().spawn().task()).get()

Expand Down Expand Up @@ -705,7 +706,7 @@ async def test_logging_option_defaults() -> None:
await am.log.call("log streaming")

# TODO: remove this completely once we hook the flush logic upon dropping device_mesh
log_mesh = pm._logging_mesh_client
log_mesh = pm._logging_manager._logging_mesh_client
assert log_mesh is not None
Future(coro=log_mesh.flush().spawn().task()).get()

Expand Down Expand Up @@ -760,6 +761,151 @@ async def test_logging_option_defaults() -> None:
pass


# oss_skip: pytest keeps complaining about mocking get_ipython module
@pytest.mark.oss_skip
@pytest.mark.timeout(180)
async def test_flush_logs_ipython() -> None:
    """Test that logs are flushed when get_ipython is available and post_run_cell event is triggered."""
    # Save original file descriptors
    # NOTE: fd 1 is duplicated (not just sys.stdout) because the remote
    # procs' streamed logs are written at the OS fd level.
    original_stdout_fd = os.dup(1)  # stdout

    try:
        # Create temporary files to capture output
        with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file:
            stdout_path = stdout_file.name

            # Redirect file descriptors to our temp files
            os.dup2(stdout_file.fileno(), 1)

            # Also redirect Python's sys.stdout
            original_sys_stdout = sys.stdout
            sys.stdout = stdout_file

            try:
                # Mock IPython environment
                class MockExecutionResult:
                    # Stand-in for IPython's ExecutionResult passed to
                    # post_run_cell callbacks; the handler ignores it.
                    pass

                class MockEvents:
                    # Minimal fake of ipython.events: records register/
                    # unregister counts so the test can assert on them.
                    def __init__(self):
                        self.callbacks = {}
                        self.registers = 0
                        self.unregisters = 0

                    def register(self, event_name, callback):
                        if event_name not in self.callbacks:
                            self.callbacks[event_name] = []
                        self.callbacks[event_name].append(callback)
                        self.registers += 1

                    def unregister(self, event_name, callback):
                        if event_name not in self.callbacks:
                            raise ValueError(f"Event {event_name} not registered")
                        assert callback in self.callbacks[event_name]
                        self.callbacks[event_name].remove(callback)
                        self.unregisters += 1

                    def trigger(self, event_name, *args, **kwargs):
                        # Simulates IPython firing an event: invoke every
                        # callback registered for event_name.
                        if event_name in self.callbacks:
                            for callback in self.callbacks[event_name]:
                                callback(*args, **kwargs)

                class MockIPython:
                    def __init__(self):
                        self.events = MockEvents()

                mock_ipython = MockIPython()

                # Patch the logging module (not IPython itself) so
                # LoggingManager sees our mock shell and takes the
                # IN_IPYTHON code path.
                with unittest.mock.patch(
                    "monarch._src.actor.logging.get_ipython",
                    lambda: mock_ipython,
                ), unittest.mock.patch("monarch._src.actor.logging.IN_IPYTHON", True):
                    # Make sure we can register and unregister callbacks
                    for i in range(3):
                        pm1 = await proc_mesh(gpus=2)
                        pm2 = await proc_mesh(gpus=2)
                        am1 = await pm1.spawn("printer", Printer)
                        am2 = await pm2.spawn("printer", Printer)

                        # Set aggregation window to ensure logs are buffered
                        await pm1.logging_option(
                            stream_to_client=True, aggregate_window_sec=600
                        )
                        await pm2.logging_option(
                            stream_to_client=True, aggregate_window_sec=600
                        )
                        assert mock_ipython.events.unregisters == 2 * i
                        # TODO: remove `1 +` from attaching controller_controller
                        assert mock_ipython.events.registers == 1 + 2 * (i + 1)
                        await asyncio.sleep(1)

                        # Generate some logs that will be aggregated
                        # (2 procs per mesh x 5 calls = 10 similar lines each)
                        for _ in range(5):
                            await am1.print.call("ipython1 test log")
                            await am2.print.call("ipython2 test log")

                        # Trigger the post_run_cell event which should flush logs
                        mock_ipython.events.trigger(
                            "post_run_cell", MockExecutionResult()
                        )

                    # Flush all outputs
                    stdout_file.flush()
                    os.fsync(stdout_file.fileno())

                    # Collect dropped proc meshes so their dead callbacks
                    # get unregistered before the final count assertions.
                    gc.collect()

                    # TODO: this should be 6 without attaching controller_controller
                    assert mock_ipython.events.registers == 7
                    # There are many objects still taking refs
                    assert mock_ipython.events.unregisters == 4
                    # TODO: same, this should be 2
                    assert len(mock_ipython.events.callbacks["post_run_cell"]) == 3
            finally:
                # Restore Python's sys.stdout
                sys.stdout = original_sys_stdout

                # Restore original file descriptors
                os.dup2(original_stdout_fd, 1)

        # Read the captured output
        with open(stdout_path, "r") as f:
            stdout_content = f.read()

        # TODO: there are quite a lot of code dups and boilerplate; make them contextmanager utils

        # Clean up temp files
        os.unlink(stdout_path)

        # Verify that logs were flushed when the post_run_cell event was triggered
        # We should see the aggregated logs in the output
        assert (
            len(
                re.findall(
                    r"\[10 similar log lines\].*ipython1 test log", stdout_content
                )
            )
            == 3
        ), stdout_content

        assert (
            len(
                re.findall(
                    r"\[10 similar log lines\].*ipython2 test log", stdout_content
                )
            )
            == 3
        ), stdout_content

    finally:
        # Ensure file descriptors are restored even if something goes wrong
        try:
            os.dup2(original_stdout_fd, 1)
            os.close(original_stdout_fd)
        except OSError:
            pass


# oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited
@pytest.mark.oss_skip
async def test_flush_logs_fast_exit() -> None:
Expand Down Expand Up @@ -834,7 +980,7 @@ async def test_flush_on_disable_aggregation() -> None:
await am.print.call("single log line")

# TODO: remove this completely once we hook the flush logic upon dropping device_mesh
log_mesh = pm._logging_mesh_client
log_mesh = pm._logging_manager._logging_mesh_client
assert log_mesh is not None
Future(coro=log_mesh.flush().spawn().task()).get()

Expand Down Expand Up @@ -894,7 +1040,7 @@ async def test_multiple_ongoing_flushes_no_deadlock() -> None:
for _ in range(10):
await am.print.call("aggregated log line")

log_mesh = pm._logging_mesh_client
log_mesh = pm._logging_manager._logging_mesh_client
assert log_mesh is not None
futures = []
for _ in range(5):
Expand Down Expand Up @@ -947,7 +1093,7 @@ async def test_adjust_aggregation_window() -> None:
await am.print.call("second batch of logs")

# TODO: remove this completely once we hook the flush logic upon dropping device_mesh
log_mesh = pm._logging_mesh_client
log_mesh = pm._logging_manager._logging_mesh_client
assert log_mesh is not None
Future(coro=log_mesh.flush().spawn().task()).get()

Expand Down