|
6 | 6 |
|
7 | 7 | # pyre-unsafe
|
8 | 8 | import asyncio
|
| 9 | +import gc |
9 | 10 | import importlib.resources
|
10 | 11 | import logging
|
11 | 12 | import operator
|
@@ -586,7 +587,7 @@ async def test_actor_log_streaming() -> None:
|
586 | 587 | await am.log.call("has log streaming as level matched")
|
587 | 588 |
|
588 | 589 | # TODO: remove this completely once we hook the flush logic upon dropping device_mesh
|
589 |
| - log_mesh = pm._logging_mesh_client |
| 590 | + log_mesh = pm._logging_manager._logging_mesh_client |
590 | 591 | assert log_mesh is not None
|
591 | 592 | Future(coro=log_mesh.flush().spawn().task()).get()
|
592 | 593 |
|
@@ -705,7 +706,7 @@ async def test_logging_option_defaults() -> None:
|
705 | 706 | await am.log.call("log streaming")
|
706 | 707 |
|
707 | 708 | # TODO: remove this completely once we hook the flush logic upon dropping device_mesh
|
708 |
| - log_mesh = pm._logging_mesh_client |
| 709 | + log_mesh = pm._logging_manager._logging_mesh_client |
709 | 710 | assert log_mesh is not None
|
710 | 711 | Future(coro=log_mesh.flush().spawn().task()).get()
|
711 | 712 |
|
@@ -760,6 +761,151 @@ async def test_logging_option_defaults() -> None:
|
760 | 761 | pass
|
761 | 762 |
|
762 | 763 |
|
| 764 | +# oss_skip: pytest keeps complaining about mocking the get_ipython module |
| 765 | +@pytest.mark.oss_skip |
| 766 | +@pytest.mark.timeout(180) |
| 767 | +async def test_flush_logs_ipython() -> None: |
| 768 | + """Test that logs are flushed when get_ipython is available and the post_run_cell event is triggered.""" |
| 769 | + # Save the original stdout file descriptor |
| 770 | + original_stdout_fd = os.dup(1) # stdout |
| 771 | + |
| 772 | + try: |
| 773 | + # Create a temporary file to capture stdout |
| 774 | + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file: |
| 775 | + stdout_path = stdout_file.name |
| 776 | + |
| 777 | + # Redirect the stdout file descriptor to the temp file |
| 778 | + os.dup2(stdout_file.fileno(), 1) |
| 779 | + |
| 780 | + # Also redirect Python's sys.stdout |
| 781 | + original_sys_stdout = sys.stdout |
| 782 | + sys.stdout = stdout_file |
| 783 | + |
| 784 | + try: |
| 785 | + # Mock IPython environment |
| 786 | + class MockExecutionResult: |
| 787 | + pass |
| 788 | + |
| 789 | + class MockEvents: |
| 790 | + def __init__(self): |
| 791 | + self.callbacks = {} |
| 792 | + self.registers = 0 |
| 793 | + self.unregisters = 0 |
| 794 | + |
| 795 | + def register(self, event_name, callback): |
| 796 | + if event_name not in self.callbacks: |
| 797 | + self.callbacks[event_name] = [] |
| 798 | + self.callbacks[event_name].append(callback) |
| 799 | + self.registers += 1 |
| 800 | + |
| 801 | + def unregister(self, event_name, callback): |
| 802 | + if event_name not in self.callbacks: |
| 803 | + raise ValueError(f"Event {event_name} not registered") |
| 804 | + assert callback in self.callbacks[event_name] |
| 805 | + self.callbacks[event_name].remove(callback) |
| 806 | + self.unregisters += 1 |
| 807 | + |
| 808 | + def trigger(self, event_name, *args, **kwargs): |
| 809 | + if event_name in self.callbacks: |
| 810 | + for callback in self.callbacks[event_name]: |
| 811 | + callback(*args, **kwargs) |
| 812 | + |
| 813 | + class MockIPython: |
| 814 | + def __init__(self): |
| 815 | + self.events = MockEvents() |
| 816 | + |
| 817 | + mock_ipython = MockIPython() |
| 818 | + |
| 819 | + with unittest.mock.patch( |
| 820 | + "monarch._src.actor.logging.get_ipython", |
| 821 | + lambda: mock_ipython, |
| 822 | + ), unittest.mock.patch("monarch._src.actor.logging.IN_IPYTHON", True): |
| 823 | + # Make sure we can register and unregister callbacks |
| 824 | + for i in range(3): |
| 825 | + pm1 = await proc_mesh(gpus=2) |
| 826 | + pm2 = await proc_mesh(gpus=2) |
| 827 | + am1 = await pm1.spawn("printer", Printer) |
| 828 | + am2 = await pm2.spawn("printer", Printer) |
| 829 | + |
| 830 | + # Set aggregation window to ensure logs are buffered |
| 831 | + await pm1.logging_option( |
| 832 | + stream_to_client=True, aggregate_window_sec=600 |
| 833 | + ) |
| 834 | + await pm2.logging_option( |
| 835 | + stream_to_client=True, aggregate_window_sec=600 |
| 836 | + ) |
| 837 | + assert mock_ipython.events.unregisters == 2 * i |
| 838 | + # TODO: remove the `1 +` once controller_controller is no longer attached; each of the two proc meshes registers one callback per iteration |
| 839 | + assert mock_ipython.events.registers == 1 + 2 * (i + 1) |
| 840 | + await asyncio.sleep(1) |
| 841 | + |
| 842 | + # Generate some logs that will be aggregated |
| 843 | + for _ in range(5): |
| 844 | + await am1.print.call("ipython1 test log") |
| 845 | + await am2.print.call("ipython2 test log") |
| 846 | + |
| 847 | + # Trigger the post_run_cell event which should flush logs |
| 848 | + mock_ipython.events.trigger( |
| 849 | + "post_run_cell", MockExecutionResult() |
| 850 | + ) |
| 851 | + |
| 852 | + # Flush the captured stdout to disk |
| 853 | + stdout_file.flush() |
| 854 | + os.fsync(stdout_file.fileno()) |
| 855 | + |
| 856 | + gc.collect() |
| 857 | + |
| 858 | + # TODO: this should be 6 without attaching controller_controller |
| 859 | + assert mock_ipython.events.registers == 7 |
| 860 | + # Not everything has been unregistered yet: other objects still hold references |
| 861 | + assert mock_ipython.events.unregisters == 4 |
| 862 | + # TODO: same as above, this should be 2 without attaching controller_controller |
| 863 | + assert len(mock_ipython.events.callbacks["post_run_cell"]) == 3 |
| 864 | + finally: |
| 865 | + # Restore Python's sys.stdout |
| 866 | + sys.stdout = original_sys_stdout |
| 867 | + |
| 868 | + # Restore the original stdout file descriptor |
| 869 | + os.dup2(original_stdout_fd, 1) |
| 870 | + |
| 871 | + # Read the captured output |
| 872 | + with open(stdout_path, "r") as f: |
| 873 | + stdout_content = f.read() |
| 874 | + |
| 875 | + # TODO: there is a lot of duplicated boilerplate across these tests; factor it into context-manager utilities (see the sketch after this diff) |
| 876 | + |
| 877 | + # Clean up the temp file |
| 878 | + os.unlink(stdout_path) |
| 879 | + |
| 880 | + # Verify that logs were flushed when the post_run_cell event was triggered |
| 881 | + # We should see the aggregated logs in the output |
| 882 | + assert ( |
| 883 | + len( |
| 884 | + re.findall( |
| 885 | + r"\[10 similar log lines\].*ipython1 test log", stdout_content |
| 886 | + ) |
| 887 | + ) |
| 888 | + == 3 |
| 889 | + ), stdout_content |
| 890 | + |
| 891 | + assert ( |
| 892 | + len( |
| 893 | + re.findall( |
| 894 | + r"\[10 similar log lines\].*ipython2 test log", stdout_content |
| 895 | + ) |
| 896 | + ) |
| 897 | + == 3 |
| 898 | + ), stdout_content |
| 899 | + |
| 900 | + finally: |
| 901 | + # Ensure the stdout file descriptor is restored even if something goes wrong |
| 902 | + try: |
| 903 | + os.dup2(original_stdout_fd, 1) |
| 904 | + os.close(original_stdout_fd) |
| 905 | + except OSError: |
| 906 | + pass |
| 907 | + |
| 908 | + |
763 | 909 | # oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited
|
764 | 910 | @pytest.mark.oss_skip
|
765 | 911 | async def test_flush_logs_fast_exit() -> None:
|
@@ -834,7 +980,7 @@ async def test_flush_on_disable_aggregation() -> None:
|
834 | 980 | await am.print.call("single log line")
|
835 | 981 |
|
836 | 982 | # TODO: remove this completely once we hook the flush logic upon dropping device_mesh
|
837 |
| - log_mesh = pm._logging_mesh_client |
| 983 | + log_mesh = pm._logging_manager._logging_mesh_client |
838 | 984 | assert log_mesh is not None
|
839 | 985 | Future(coro=log_mesh.flush().spawn().task()).get()
|
840 | 986 |
|
@@ -894,7 +1040,7 @@ async def test_multiple_ongoing_flushes_no_deadlock() -> None:
|
894 | 1040 | for _ in range(10):
|
895 | 1041 | await am.print.call("aggregated log line")
|
896 | 1042 |
|
897 |
| - log_mesh = pm._logging_mesh_client |
| 1043 | + log_mesh = pm._logging_manager._logging_mesh_client |
898 | 1044 | assert log_mesh is not None
|
899 | 1045 | futures = []
|
900 | 1046 | for _ in range(5):
|
@@ -947,7 +1093,7 @@ async def test_adjust_aggregation_window() -> None:
|
947 | 1093 | await am.print.call("second batch of logs")
|
948 | 1094 |
|
949 | 1095 | # TODO: remove this completely once we hook the flush logic upon dropping device_mesh
|
950 |
| - log_mesh = pm._logging_mesh_client |
| 1096 | + log_mesh = pm._logging_manager._logging_mesh_client |
951 | 1097 | assert log_mesh is not None
|
952 | 1098 | Future(coro=log_mesh.flush().spawn().task()).get()
|
953 | 1099 |
|