Skip to content

Commit df75a9c

Browse files
authored
Add cooldown time for PyRosettaCluster unit tests (#512)
1 parent 58d657c commit df75a9c

File tree

4 files changed

+22
-7
lines changed

4 files changed

+22
-7
lines changed

source/src/python/PyRosetta/src/pyrosetta/distributed/cluster/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
"run",
5050
"update_scores",
5151
]
52-
__version__: str = "2.1.0"
52+
__version__: str = "2.1.1"
5353

5454
_print_conda_warnings()
5555

source/src/python/PyRosetta/src/pyrosetta/distributed/cluster/core.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,10 @@
227227
dry_run: A `bool` object specifying whether or not to save '.pdb' files to
228228
disk. If `True`, then do not write '.pdb' or '.pdb.bz2' files to disk.
229229
Default: False
230+
cooldown_time: A `float` or `int` object specifying how many seconds to sleep after the
231+
simulation is complete to allow loggers to flush. For very slow network filesystems,
232+
2.0 or more seconds may be reasonable.
233+
Default: 0.5
230234
231235
Returns:
232236
A PyRosettaCluster instance.
@@ -604,6 +608,12 @@ class PyRosettaCluster(IO[G], LoggingSupport[G], SchedulerManager[G], TaskBase[G
604608
validator=attr.validators.instance_of(bool),
605609
converter=_parse_yield_results,
606610
)
611+
cooldown_time = attr.ib(
612+
type=float,
613+
default=0.5,
614+
validator=[_validate_float, attr.validators.instance_of((float, int))],
615+
converter=attr.converters.default_if_none(default=0.5),
616+
)
607617
protocols_key = attr.ib(
608618
type=str,
609619
default="PyRosettaCluster_protocols_container",

source/src/python/PyRosetta/src/pyrosetta/distributed/cluster/logging_support.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
import logging
2727
import os
28+
import time
2829
import warnings
2930

3031
from contextlib import suppress
@@ -108,6 +109,7 @@ def _close_logger(self) -> None:
108109
for handler in logger.handlers[:]:
109110
logger.removeHandler(handler)
110111
with suppress(Exception):
112+
handler.flush()
111113
handler.close()
112114
logging.shutdown()
113115

@@ -153,8 +155,9 @@ def _close_socket_listener(self, clients: Dict[int, Client]) -> None:
153155
self._close_socket_logger_plugins(clients)
154156
self.socket_listener.stop()
155157
handler = self.socket_listener.handler
156-
handler.flush()
157158
with suppress(Exception):
159+
handler.flush()
160+
self._cooldown()
158161
handler.close()
159162

160163
def _close_socket_logger_plugins(self, clients: Dict[int, Client]) -> None:
@@ -175,6 +178,9 @@ def _close_socket_logger_plugins(self, clients: Dict[int, Client]) -> None:
175178
f"Logger was not closed cleanly on dask worker ({worker_address}) - {result}"
176179
)
177180

181+
def _cooldown(self) -> None:
182+
time.sleep(self.cooldown_time)
183+
178184

179185
def purge_socket_logger_plugin_address(
180186
socket_listener_address: Tuple[str, int],

source/src/python/PyRosetta/src/test/T900_distributed.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,10 @@
5353
print("Printing pip environment failed with return code: {0}.".format(ex.returncode))
5454

5555

56-
def e(cmd, sleep=1):
57-
"""Run command getting return code and output with subsequent sleep step to provide extra time to flush loggers."""
56+
def e(cmd):
57+
"""Run command getting return code and output."""
5858
print("Executing:\n{0}".format(cmd))
59-
status, output = subprocess.getstatusoutput("{0} && sleep {1}".format(cmd, sleep))
59+
status, output = subprocess.getstatusoutput(cmd)
6060
print("Output:\n{0}".format(output))
6161
if status != 0:
6262
print(
@@ -106,9 +106,8 @@ def e(cmd, sleep=1):
106106
tests = distributed_cluster_test_cases + distributed_test_suites
107107

108108
for test in tests:
109-
sleep = 5 if test in distributed_cluster_test_cases else 1
110109
t0 = time.time()
111-
e("{python} -m unittest {test}".format(python=sys.executable, test=test), sleep=sleep)
110+
e("{python} -m unittest {test}".format(python=sys.executable, test=test))
112111
t1 = time.time()
113112
dt = t1 - t0
114113
print("Finished running test in {0} seconds: {1}\n".format(round(dt, 6), test))

0 commit comments

Comments
 (0)