Commit 078f4d9

klimaj and rclune authored
Supporting secure unpickling in PyRosetta (#523)
Currently, `PackedPose` objects are serialized and deserialized using the `pickle` module (introduced in ~2019), and the `Pose.cache` dictionary (introduced in #430) uses the `pickle` module to cache arbitrary datatypes in the `Pose` object. Additionally, #462 enables saving compressed `PackedPose` objects to disk (i.e., as `*.b64_pose` and `*.pkl_pose` files) for sharing PyRosetta `Pose` objects with the scientific community. However, the `pickle` module is not secure (see the warning [here](https://docs.python.org/3/library/pickle.html), as outlined in #519).

In this PR, a secure `pickle.loads` method is developed and slotted into the `PackedPose` and `Pose.cache` infrastructure to permanently disallow certain risky packages, modules, and namespaces from being unpickled/loaded (e.g., `exec`, `eval`, `os.system`, and `subprocess.run`; the list will be updated over time as needed). This significantly improves the security of handling `PackedPose` and `Pose` objects in memory when they are received from a second party (e.g., over a socket, queue, or interprocess communication) or when reading a file received from a second party (e.g., using `pyrosetta.distributed.io.pose_from_file` with a `*.b64_pose` or `*.pkl_pose` file). By default, only the `pyrosetta` and `numpy` packages, and certain `builtins` names (like `dict`, `complex`, and `tuple`), are considered secure and permitted to be unpickled/loaded. Other packages that the user may want to serialize/deserialize may be marked as secure per-process by the user in code (see the functions below). It is worth noting that PyTorch developers have implemented a similar strategy with the [torch.serialization.add_safe_globals()](https://docs.pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals) method.

Another aim of this PR is to implement an optional Hash-based Message Authentication Code (HMAC) key in the `Pose.cache` dictionary for data integrity verification. While not a security feature, this new API allows the user to set an HMAC key to be prepended to every score value in the `Pose.cache` dictionary that effectively says "this was saved by PyRosetta". Retrieval intentionally raises an error when the HMAC key is missing or differs, indicating that the data appears to have been tampered with or modified. By default, the HMAC key is disabled (set to `None`) in order to reduce the memory overhead of the `Pose.cache` dictionary; e.g., if 32 bytes are prepended to each score value, 1,000 score values add 32,000 bytes (32 KB) of overhead, and a million score values add 32 MB of overhead.

The following are newly added functions:

- `pyrosetta.secure_unpickle.add_secure_package`: Add a package to the unpickle allowed list
- `pyrosetta.secure_unpickle.remove_secure_package`: Remove a package from the unpickle allowed list
- `pyrosetta.secure_unpickle.clear_secure_packages`: Remove all packages from the unpickle allowed list
- `pyrosetta.secure_unpickle.get_disallowed_packages`: Return all permanently disallowed packages/modules/prefixes
- `pyrosetta.secure_unpickle.get_secure_packages`: Return all packages in the unpickle allowed list
- `pyrosetta.secure_unpickle.set_secure_packages`: Set all packages in the unpickle allowed list
- `pyrosetta.secure_unpickle.set_unpickle_hmac_key`: Set the HMAC key for the `Pose.cache` dictionary
- `pyrosetta.secure_unpickle.get_unpickle_hmac_key`: Return the HMAC key for the `Pose.cache` dictionary

---------

Co-authored-by: Rachel Clune <rachel.clune@omsf.io>
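The HMAC integrity check described in the commit message can be illustrated with the standard library. This is only a minimal sketch, not the PyRosetta implementation: the helper names `sign` and `verify`, the choice of SHA-256 (whose 32-byte digest matches the overhead arithmetic above), and the layout of a tag prepended to the pickled bytes are all assumptions made for illustration.

```python
import hashlib
import hmac
import pickle

# Assumption: SHA-256 is used, giving a 32-byte tag per value.
DIGEST_SIZE = hashlib.sha256().digest_size


def sign(payload: bytes, key: bytes) -> bytes:
    """Prepend an HMAC-SHA256 tag (computed with `key`) to the payload."""
    tag = hmac.new(key, payload, hashlib.sha256).digest()
    return tag + payload


def verify(blob: bytes, key: bytes) -> bytes:
    """Return the payload if its tag matches, otherwise raise ValueError."""
    if len(blob) < DIGEST_SIZE:
        raise ValueError("HMAC tag is missing.")
    tag, payload = blob[:DIGEST_SIZE], blob[DIGEST_SIZE:]
    expected = hmac.new(key, payload, hashlib.sha256).digest()
    # compare_digest performs a constant-time comparison.
    if not hmac.compare_digest(tag, expected):
        raise ValueError("HMAC tag mismatch: data may have been modified.")
    return payload


key = b"hypothetical-secret-key"  # hypothetical key for illustration
blob = sign(pickle.dumps({"total_score": -123.4}), key)
restored = pickle.loads(verify(blob, key))

# Overhead for 1,000 tagged values, matching the estimate in the message.
overhead_bytes = DIGEST_SIZE * 1_000
```

Verifying before unpickling means a modified or untagged value is rejected before any deserialization takes place, at the cost of the fixed per-value overhead.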
1 parent dd237bf commit 078f4d9

6 files changed: +833 -30 lines changed

source/src/python/PyRosetta/src/pyrosetta/bindings/scores/serialization.py

Lines changed: 9 additions & 21 deletions
@@ -9,48 +9,36 @@
 __author__ = "Jason C. Klima"
 
 
-import base64
 import collections
-import pickle
+
+from pyrosetta.secure_unpickle import SecureSerializerBase
 
 
 class PoseScoreSerializerBase(object):
     """Base class for `PoseScoreSerializer` methods."""
-
     @staticmethod
     def to_pickle(value):
-        try:
-            return pickle.dumps(value)
-        except (TypeError, OverflowError, MemoryError, pickle.PicklingError) as ex:
-            raise TypeError(
-                "Only pickle-serializable object types are allowed to be set "
-                + "as score values. Received: %r. %s" % (type(value), ex)
-            )
+        return SecureSerializerBase.to_pickle(value)
 
     @staticmethod
     def from_pickle(value):
-        try:
-            return pickle.loads(value)
-        except (TypeError, OverflowError, MemoryError, EOFError, pickle.UnpicklingError) as ex:
-            raise TypeError(
-                "Could not deserialize score value of type %r. %s" % (type(value), ex)
-            )
+        return SecureSerializerBase.secure_loads(value)
 
     @staticmethod
     def to_base64(value):
-        return base64.b64encode(value).decode()
+        return SecureSerializerBase.to_base64(value)
 
     @staticmethod
     def from_base64(value):
-        return base64.b64decode(value, validate=True)
+        return SecureSerializerBase.from_base64(value)
 
     @staticmethod
-    def to_base64_pickle(value):
-        return PoseScoreSerializerBase.to_base64(PoseScoreSerializerBase.to_pickle(value))
+    def to_base64_pickle(obj):
+        return SecureSerializerBase.secure_to_base64_pickle(obj)
 
     @staticmethod
     def from_base64_pickle(value):
-        return PoseScoreSerializerBase.from_pickle(PoseScoreSerializerBase.from_base64(value))
+        return SecureSerializerBase.secure_from_base64_pickle(value)
 
     @staticmethod
     def bool_from_str(value):
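This diff delegates `from_pickle` to `SecureSerializerBase.secure_loads`, whose implementation is not shown on this page. As a rough illustration of the general technique, the sketch below follows the "restricting globals" pattern from the Python `pickle` documentation: it subclasses `pickle.Unpickler` and overrides `find_class` so that only allow-listed globals can be resolved. The names `RestrictedUnpickler`, `restricted_loads`, `ALLOWED_PACKAGES`, and `ALLOWED_BUILTINS`, and the contents of both sets, are hypothetical and do not reflect PyRosetta's actual lists.

```python
import collections
import io
import pickle

# Hypothetical allow lists for illustration only.
ALLOWED_PACKAGES = {"collections"}
ALLOWED_BUILTINS = {"dict", "list", "tuple", "set", "frozenset", "complex"}


class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that resolves only allow-listed globals."""

    def find_class(self, module, name):
        top_level = module.split(".")[0]
        if module == "builtins":
            # Builtins are filtered per name, since `eval` and `exec` live here too.
            if name in ALLOWED_BUILTINS:
                return super().find_class(module, name)
        elif top_level in ALLOWED_PACKAGES:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(
            f"Global '{module}.{name}' is not allowed to be unpickled."
        )


def restricted_loads(data: bytes):
    """Drop-in analogue of `pickle.loads` that uses the restricted unpickler."""
    return RestrictedUnpickler(io.BytesIO(data)).load()


class Malicious:
    """Object whose pickle asks the loader to call `eval` on a string."""

    def __reduce__(self):
        return (eval, ("__import__('os').getcwd()",))


# Allowed data round-trips normally.
safe = restricted_loads(
    pickle.dumps(collections.OrderedDict(a=(1, 2.5), b=complex(1, 2)))
)

# The malicious payload is rejected before `eval` is ever resolved or called.
try:
    restricted_loads(pickle.dumps(Malicious()))
    blocked = False
except pickle.UnpicklingError:
    blocked = True
```

Because the check happens at global lookup time, the callable in a hostile payload is never resolved, so it cannot be invoked during loading.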

source/src/python/PyRosetta/src/pyrosetta/distributed/cluster/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@
     "run",
     "update_scores",
 ]
-__version__: str = "2.1.1"
+__version__: str = "2.1.2"
 
 _print_conda_warnings()

source/src/python/PyRosetta/src/pyrosetta/distributed/cluster/serialization.py

Lines changed: 5 additions & 4 deletions
@@ -38,6 +38,7 @@
 
 from functools import singledispatch, wraps
 from pyrosetta.distributed.packed_pose.core import PackedPose
+from pyrosetta.secure_unpickle import SecureSerializerBase
 from typing import (
     Any,
     Dict,
@@ -189,7 +190,7 @@ def compress_packed_pose(self, packed_pose: Any) -> Union[NoReturn, None, bytes]
             compressed_packed_pose = None
         elif isinstance(packed_pose, PackedPose):
             packed_pose = update_scores(packed_pose)
-            compressed_packed_pose = self.encoder(packed_pose.pickled_pose)
+            compressed_packed_pose = self.encoder(io.to_pickle(packed_pose))
         else:
             raise TypeError(
                 "The 'packed_pose' argument parameter must be of type `NoneType` or `PackedPose`."
@@ -200,8 +201,8 @@ def compress_packed_pose(self, packed_pose: Any) -> Union[NoReturn, None, bytes]
     @requires_compression
     def decompress_packed_pose(self, compressed_packed_pose: Any) -> Union[NoReturn, None, PackedPose]:
         """
-        Decompress a `bytes` object with the custom serialization and `cloudpickle` modules. If the 'compressed_packed_pose'
-        argument parameter is `None`, then just return `None`.
+        Decompress a `bytes` object with the custom serialization module and secure implementation of the `pickle` module.
+        If the 'compressed_packed_pose' argument parameter is `None`, then just return `None`.
 
         Args:
             compressed_packed_pose: the input `bytes` object to decompress. If `None`, then just return `None`.
@@ -215,7 +216,7 @@ def decompress_packed_pose(self, compressed_packed_pose: Any) -> Union[NoReturn,
         if compressed_packed_pose is None:
             packed_pose = None
         elif isinstance(compressed_packed_pose, bytes):
-            pose = cloudpickle.loads(self.decoder(compressed_packed_pose))
+            pose = SecureSerializerBase.secure_loads(self.decoder(compressed_packed_pose))
             packed_pose = io.to_packed(pose)
         else:
             raise TypeError(

source/src/python/PyRosetta/src/pyrosetta/distributed/packed_pose/core.py

Lines changed: 8 additions & 3 deletions
@@ -13,6 +13,9 @@
 import pyrosetta.rosetta.core.pose as pose
 import pyrosetta.distributed
 
+from pyrosetta.secure_unpickle import SecureSerializerBase
+
+
 __all__ = ["pack_result", "pose_result", "to_packed", "to_pose", "to_dict", "to_base64", "to_pickle", "PackedPose"]
 
 
@@ -38,7 +41,7 @@ class PackedPose:
     def __init__(self, pose_or_pack):
         """Create a packed pose from pose, pack, or pickled bytes."""
         if isinstance(pose_or_pack, pose.Pose):
-            self.pickled_pose = pickle.dumps(pose_or_pack)
+            self.pickled_pose = SecureSerializerBase.to_pickle(pose_or_pack)
             self.scores = dict(pose_or_pack.cache)
 
         elif isinstance(pose_or_pack, PackedPose):
@@ -55,7 +58,7 @@ def __init__(self, pose_or_pack):
     @property
     @pyrosetta.distributed.requires_init
     def pose(self):
-        return pickle.loads(self.pickled_pose)
+        return SecureSerializerBase.secure_loads(self.pickled_pose)
 
     def update_scores(self, *score_dicts, **score_kwargs):
         new_scores = {}
@@ -71,7 +74,9 @@ def update_scores(self, *score_dicts, **score_kwargs):
 
     def clone(self):
         result = PackedPose(self.pose)
-        result.scores = pickle.loads(pickle.dumps(self.scores))
+        result.scores = SecureSerializerBase.secure_loads(
+            SecureSerializerBase.to_pickle(self.scores)
+        )
         return result
 
     def empty(self):
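The `PackedPose` changes above keep the existing pattern of storing pickled bytes and deserializing on access, with `clone` round-tripping the scores through the serializer. The sketch below shows that pattern in isolation. `PackedObject` is a hypothetical stand-in class, the score name is made up, and plain `pickle` is used here only because the data is trusted and built in place; the commit routes the same calls through `SecureSerializerBase` instead.

```python
import pickle


class PackedObject:
    """Hypothetical stand-in for a packed container that stores pickled bytes."""

    def __init__(self, obj, scores=None):
        # Serialize once at construction; the live object is not retained.
        self.pickled = pickle.dumps(obj)
        self.scores = dict(scores or {})

    @property
    def obj(self):
        # Each access deserializes a fresh, independent copy.
        return pickle.loads(self.pickled)

    def clone(self):
        result = PackedObject(self.obj)
        # A pickle round trip copies nested containers, unlike dict().
        result.scores = pickle.loads(pickle.dumps(self.scores))
        return result


original = PackedObject([1, 2, 3], scores={"terms": {"example_term": -1.0}})
copy = original.clone()
copy.scores["terms"]["example_term"] = 0.0  # does not affect the original
```

Since every property access and every clone passes through deserialization, swapping in a restricted loader at those call sites covers all the paths by which the stored bytes become live objects.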
