Skip to content

Commit e842d23

Browse files
authored
Add support for filtering results in PyRosettaCluster (#459)
1 parent 50a63df commit e842d23

File tree

11 files changed

+271
-30
lines changed

11 files changed

+271
-30
lines changed

source/src/python/PyRosetta/src/pyrosetta/distributed/cluster/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
produce,
2424
recreate_environment,
2525
reproduce,
26+
requires_packed_pose,
2627
reserve_scores,
2728
run,
2829
update_scores,
@@ -43,11 +44,12 @@
4344
"produce",
4445
"recreate_environment",
4546
"reproduce",
47+
"requires_packed_pose",
4648
"reserve_scores",
4749
"run",
4850
"update_scores",
4951
]
50-
__version__: str = "2.0.0"
52+
__version__: str = "2.1.0"
5153

5254
_print_conda_warnings()
5355

source/src/python/PyRosetta/src/pyrosetta/distributed/cluster/converter_tasks.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -422,9 +422,17 @@ def _parse_dict(obj: Dict[Any, Any]) -> Dict[Any, Any]:
422422
for k in obj.keys():
423423
if k in ["client", "clients", "input_packed_pose"]:
424424
raise NotImplementedError(
425-
f"The parameter '{k}' must be passed directly to reproduce(), "
425+
f"The parameter '{k}' must be passed directly to `reproduce()`, "
426426
+ "not as a member of the 'instance_kwargs' dictionary."
427427
)
428+
elif k == "filter_results":
429+
raise ValueError(
430+
f"The parameter '{k}' cannot be set as a PyRosettaCluster attribute "
431+
+ "in `reproduce()` because the saved 'decoy_ids' attribute from the "
432+
+ "original simulation depends on the original decoy output order from "
433+
+ "each protocol, so results must be filtered identically. Please remove "
434+
+ "this keyword argument to run `reproduce()`."
435+
)
428436
return obj
429437

430438

@@ -433,6 +441,21 @@ def _default_none(obj: None) -> Dict[Any, Any]:
433441
return {}
434442

435443

444+
@singledispatch
def is_empty(obj: Any) -> NoReturn:
    """
    Test whether a `PackedPose` object is empty.

    This is the fallback of a `functools.singledispatch` function; handlers for
    `NoneType` and `PackedPose` objects are registered on it.

    Args:
        obj: The object to test.

    Returns:
        `False` for a `NoneType` object, otherwise the result of `PackedPose.empty()`
        for a `PackedPose` object.

    Raises:
        `NotImplementedError` (carrying the type of `obj`) if no handler is registered
        for the type of the input object.
    """
    raise NotImplementedError(type(obj))


@is_empty.register(type(None))
def _from_none(obj: None) -> bool:
    # Protocol results return a `None` object when a segmentation fault occurs with `ignore_errors=True`
    # A `None` object is therefore never reported as empty.
    return False


@is_empty.register(PackedPose)
def _from_packed(obj: PackedPose) -> bool:
    # Delegate to the `PackedPose.empty()` method.
    return obj.empty()
457+
458+
436459
def is_bytes(obj: Any) -> bool:
437460
return isinstance(obj, bytes)
438461

source/src/python/PyRosetta/src/pyrosetta/distributed/cluster/converters.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,15 @@
3030
import pyrosetta
3131
import sys
3232
import types
33+
import warnings
3334

3435
from functools import singledispatch
3536
from pyrosetta.distributed.cluster.converter_tasks import (
3637
environment_cmd,
3738
get_yml,
3839
is_bytes,
3940
is_dict,
41+
is_empty,
4042
is_packed,
4143
parse_input_packed_pose as _parse_input_packed_pose,
4244
to_int,
@@ -63,6 +65,38 @@
6365
S = TypeVar("S", bound=Serialization)
6466

6567

68+
def _parse_filter_results(obj: Any) -> Union[bool, NoReturn]:
    """
    Parse the input `filter_results` attribute of PyRosettaCluster.

    Args:
        obj: The user-provided value, expected to be a `bool` object or `None`.

    Returns:
        The input `bool` object unchanged, or `True` if the input is `None` (in which
        case a `FutureWarning` is also issued).

    Raises:
        `ValueError` if the input is neither a `bool` object nor `None`.
    """
    # Local switch gating the `FutureWarning` that is issued for a `None` input
    _issue_future_warning = True

    @singledispatch
    def converter(obj: Any) -> NoReturn:
        # Fallback for every type without a registered handler
        raise ValueError("'filter_results' must be of type `bool` or `NoneType`!")

    @converter.register(type(None))
    def _parse_none(obj: None) -> bool:
        # A `None` input is treated as `True`, optionally with a notice to the user
        if _issue_future_warning:
            warnings.warn(
                (
                    "As of PyRosettaCluster version 2.1.0, the 'filter_results' instance attribute "
                    "is enabled by default, which automatically filters empty `PackedPose` objects between "
                    "user-provided PyRosetta protocols to help reduce compute overhead. Please explicitly set "
                    "either `filter_results=True` (the currently enabled, new setting) or `filter_results=False` "
                    "(to revert to legacy behavior before version 2.1.0) to silence this notice. This notice "
                    "will disappear in a future version of PyRosettaCluster."
                ),
                FutureWarning,
                # NOTE(review): presumably chosen so that the warning is attributed to the user's
                # call site rather than to the dispatch/converter frames in between -- confirm
                # before changing the nesting of these functions.
                stacklevel=5,
            )
        return True

    @converter.register(bool)
    def _parse_bool(obj: bool) -> bool:
        # An explicit `bool` object is passed through unchanged
        return obj

    return converter(obj)
98+
99+
66100
def _parse_decoy_ids(objs: Any) -> List[int]:
67101
"""
68102
Normalize user-provided PyRosetta 'decoy_ids' to a `list` object containing `int` objects.

source/src/python/PyRosetta/src/pyrosetta/distributed/cluster/core.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,16 @@
206206
It's recommended to set this option to at least 1 second, but longer times may
207207
be used as a safety throttle in cases of overwhelmed dask scheduler processes.
208208
Default: 3.0
209+
filter_results: A `bool` object specifying whether or not to filter out empty
210+
`PackedPose` objects between user-provided PyRosetta protocols. When a protocol
211+
returns or yields `NoneType`, PyRosettaCluster converts it to an empty `PackedPose`
212+
object that gets passed to the next protocol. If `True`, then filter out any empty
213+
`PackedPose` objects where there are no residues in the conformation as given by
214+
`Pose.empty()`, otherwise if `False` then continue to pass empty `PackedPose` objects
215+
to the next protocol. This is used for filtering out decoys mid-trajectory through
216+
user-provided PyRosetta protocols if protocols return or yield any `None`, empty
217+
`Pose`, or empty `PackedPose` objects.
218+
Default: True
209219
save_all: A `bool` object specifying whether or not to save all of the returned
210220
or yielded `Pose` and `PackedPose` objects from all user-provided
211221
PyRosetta protocols. This option may be used for checkpointing trajectories.
@@ -257,7 +267,9 @@
257267
from datetime import datetime
258268
from pyrosetta.distributed.cluster.base import TaskBase, _get_residue_type_set
259269
from pyrosetta.distributed.cluster.converters import (
270+
is_empty as _is_empty,
260271
_parse_decoy_ids,
272+
_parse_filter_results,
261273
_parse_environment,
262274
_parse_input_packed_pose,
263275
_parse_logging_address,
@@ -567,6 +579,12 @@ class PyRosettaCluster(IO[G], LoggingSupport[G], SchedulerManager[G], TaskBase[G
567579
validator=[_validate_float, attr.validators.instance_of((float, int))],
568580
converter=attr.converters.default_if_none(default=3.0),
569581
)
582+
filter_results = attr.ib(
583+
type=bool,
584+
default=None,
585+
validator=attr.validators.instance_of(bool),
586+
converter=_parse_filter_results,
587+
)
570588
save_all = attr.ib(
571589
type=bool,
572590
default=False,
@@ -822,6 +840,8 @@ def _run(
822840
compressed_packed_pose,
823841
self.serializer.deepcopy_kwargs(kwargs),
824842
)
843+
if self.filter_results and _is_empty(self.serializer.decompress_packed_pose(compressed_packed_pose)):
844+
continue
825845
compressed_kwargs, pyrosetta_init_kwargs, protocol, clients_index, resource = self._setup_kwargs(
826846
kwargs, clients_indices, resources
827847
)

source/src/python/PyRosetta/src/pyrosetta/distributed/cluster/exceptions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ def wrapper(
133133
+ WorkerError._get_msg(protocol_name, ignore_errors)
134134
)
135135
if ignore_errors:
136+
# Return a `NoneType` object to be converted to an empty `PackedPose` object
137+
# when a non-system-exiting Python exception is raised
136138
result = None
137139
else:
138140
raise

source/src/python/PyRosetta/src/pyrosetta/distributed/cluster/toolkit.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
produce,
2222
recreate_environment,
2323
reproduce,
24+
requires_packed_pose,
2425
reserve_scores,
2526
run,
2627
update_scores,

source/src/python/PyRosetta/src/pyrosetta/distributed/cluster/tools.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from functools import wraps
3535
from pyrosetta.distributed.cluster.converters import _parse_protocols, _parse_yield_results
3636
from pyrosetta.distributed.cluster.converter_tasks import (
37+
is_empty,
3738
get_protocols_list_of_str,
3839
get_yml,
3940
parse_client,
@@ -427,6 +428,53 @@ def wrapper(packed_pose, **kwargs):
427428
return cast(P, wrapper)
428429

429430

431+
def requires_packed_pose(func: P) -> P:
    """
    Use this as a Python decorator of any user-provided PyRosetta protocol.
    If a user-provided PyRosetta protocol requires that the first argument
    parameter be a non-empty `PackedPose` object, then return any received empty
    `PackedPose` objects or `NoneType` objects and skip the decorated protocol,
    otherwise run the decorated protocol.

    If using `PyRosettaCluster(filter_results=False)` and the preceding protocol
    returns or yields either `None`, an empty `Pose` object, or an empty `PackedPose`
    object, then an empty `PackedPose` object is distributed to the next user-provided
    PyRosetta protocol, in which case the next protocol and/or any downstream
    protocols are skipped if they are decorated with this decorator. If using
    `PyRosettaCluster(ignore_errors=True)` and an error is raised in the preceding
    protocol, then a `NoneType` object is distributed to the next user-provided
    PyRosetta protocol, in which case the next protocol and/or any downstream
    protocols are skipped if they are decorated with this decorator.

    For example:

    @requires_packed_pose
    def my_pyrosetta_protocol(packed_pose, **kwargs):
        assert packed_pose.pose.size() > 0
        return packed_pose

    Args:
        func: A user-provided PyRosetta function.

    Returns:
        The decorated function. When called, it returns the input `packed_pose` argument
        parameter if it is an empty `PackedPose` object or a `NoneType` object, otherwise
        the results from the decorated protocol.
    """
    @wraps(func)
    def wrapper(packed_pose, **kwargs):
        _msg = "User-provided PyRosetta protocol '{0}' received and is duly returning {1} object."
        # A `NoneType` object is never reported as empty by `is_empty`, so it is safe
        # to test for it first and hand it straight back without running the protocol.
        if packed_pose is None:
            logging.info(_msg.format(func.__name__, "a `NoneType`"))
            return packed_pose
        # An empty `PackedPose` object is likewise handed back without running the protocol.
        if is_empty(packed_pose):
            logging.info(_msg.format(func.__name__, "an empty `PackedPose`"))
            return packed_pose
        return func(packed_pose, **kwargs)

    return cast(P, wrapper)
476+
477+
430478
def reproduce(
431479
input_file: Optional[str] = None,
432480
scorefile: Optional[str] = None,

source/src/python/PyRosetta/src/pyrosetta/distributed/packed_pose/core.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ def clone(self):
7474
result.scores = pickle.loads(pickle.dumps(self.scores))
7575
return result
7676

77+
def empty(self):
    """Return the result of calling `empty()` on the pose held by this object."""
    held_pose = self.pose
    return held_pose.empty()
79+
7780

7881
def pack_result(func):
7982
@functools.wraps(func)

source/src/python/PyRosetta/src/pyrosetta/tests/distributed/cluster/test_logging.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def my_pyrosetta_protocol_2(packed_pose, **kwargs):
147147
save_all=False,
148148
system_info=None,
149149
pyrosetta_build=None,
150+
filter_results=None,
150151
max_delay_time=3.0,
151152
)
152153
cluster.distribute(my_pyrosetta_protocol_1, my_pyrosetta_protocol_2)

0 commit comments

Comments
 (0)