Skip to content

Cloudpickle configurable #35422

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/trigger_files/beam_PostCommit_Python.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run.",
"modification": 14
"modification": 1
}

95 changes: 63 additions & 32 deletions sdks/python/apache_beam/internal/cloudpickle/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import dataclasses
import dis
from enum import Enum
import functools
import io
import itertools
import logging
Expand Down Expand Up @@ -100,6 +101,20 @@

PYPY = platform.python_implementation() == "PyPy"


def uuid_generator(_):
  """Default ``id_generator``: ignore the class and return a fresh uuid4 hex.

  Args:
    _: the class being tracked (unused; present so the signature matches the
      ``id_generator(class_def)`` contract used by ``_get_or_create_tracker_id``).

  Returns:
    A 32-character lowercase hexadecimal string unique per call.
  """
  return uuid.uuid4().hex


@dataclasses.dataclass
class CloudPickleConfig:
  """Configuration for cloudpickle behavior."""
  # Callable mapping a class object to a tracker id, or None to disable
  # tracker-id generation for dynamic classes.
  # NOTE: annotated with typing.Callable (the type), not the builtin
  # callable() function, which is not a type.
  id_generator: typing.Optional[typing.Callable] = uuid_generator
  # When True, unpickling skips re-applying state to a dynamic type that has
  # already been registered in the state tracker (see _class_setstate).
  skip_reset_dynamic_type_state: bool = False


DEFAULT_CONFIG = CloudPickleConfig()

builtin_code_type = None
if PYPY:
# builtin-code objects only exist in pypy
Expand All @@ -108,11 +123,11 @@
_extract_code_globals_cache = weakref.WeakKeyDictionary()


def _get_or_create_tracker_id(class_def):
def _get_or_create_tracker_id(class_def, id_generator):
with _DYNAMIC_CLASS_TRACKER_LOCK:
class_tracker_id = _DYNAMIC_CLASS_TRACKER_BY_CLASS.get(class_def)
if class_tracker_id is None:
class_tracker_id = uuid.uuid4().hex
if class_tracker_id is None and id_generator is not None:
class_tracker_id = id_generator(class_def)
_DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
_DYNAMIC_CLASS_TRACKER_BY_ID[class_tracker_id] = class_def
return class_tracker_id
Expand Down Expand Up @@ -593,26 +608,26 @@ def _make_typevar(
return _lookup_class_or_track(class_tracker_id, tv)


def _decompose_typevar(obj):
def _decompose_typevar(obj, config):
return (
obj.__name__,
obj.__bound__,
obj.__constraints__,
obj.__covariant__,
obj.__contravariant__,
_get_or_create_tracker_id(obj),
_get_or_create_tracker_id(obj, config.id_generator),
)


def _typevar_reduce(obj):
def _typevar_reduce(obj, config):
# TypeVar instances require the module information hence why we
# are not using the _should_pickle_by_reference directly
module_and_name = _lookup_module_and_qualname(obj, name=obj.__name__)

if module_and_name is None:
return (_make_typevar, _decompose_typevar(obj))
return (_make_typevar, _decompose_typevar(obj, config))
elif _is_registered_pickle_by_value(module_and_name[0]):
return (_make_typevar, _decompose_typevar(obj))
return (_make_typevar, _decompose_typevar(obj, config))

return (getattr, module_and_name)

Expand Down Expand Up @@ -656,7 +671,7 @@ def _make_dict_items(obj, is_ordered=False):
# -------------------------------------------------


def _class_getnewargs(obj):
def _class_getnewargs(obj, config):
type_kwargs = {}
if "__module__" in obj.__dict__:
type_kwargs["__module__"] = obj.__module__
Expand All @@ -670,20 +685,20 @@ def _class_getnewargs(obj):
obj.__name__,
_get_bases(obj),
type_kwargs,
_get_or_create_tracker_id(obj),
_get_or_create_tracker_id(obj, config.id_generator),
None,
)


def _enum_getnewargs(obj):
def _enum_getnewargs(obj, config):
members = {e.name: e.value for e in obj}
return (
obj.__bases__,
obj.__name__,
obj.__qualname__,
members,
obj.__module__,
_get_or_create_tracker_id(obj),
_get_or_create_tracker_id(obj, config.id_generator),
None,
)

Expand Down Expand Up @@ -1028,7 +1043,7 @@ def _weakset_reduce(obj):
return weakref.WeakSet, (list(obj), )


def _dynamic_class_reduce(obj):
def _dynamic_class_reduce(obj, config):
"""Save a class that can't be referenced as a module attribute.

This method is used to serialize classes that are defined inside
Expand All @@ -1038,24 +1053,28 @@ def _dynamic_class_reduce(obj):
if Enum is not None and issubclass(obj, Enum):
return (
_make_skeleton_enum,
_enum_getnewargs(obj),
_enum_getnewargs(obj, config),
_enum_getstate(obj),
None,
None,
_class_setstate,
functools.partial(
_class_setstate,
skip_reset_dynamic_type_state=config.skip_reset_dynamic_type_state),
)
else:
return (
_make_skeleton_class,
_class_getnewargs(obj),
_class_getnewargs(obj, config),
_class_getstate(obj),
None,
None,
_class_setstate,
functools.partial(
_class_setstate,
skip_reset_dynamic_type_state=config.skip_reset_dynamic_type_state),
)


def _class_reduce(obj):
def _class_reduce(obj, config):
"""Select the reducer depending on the dynamic nature of the class obj."""
if obj is type(None): # noqa
return type, (None, )
Expand All @@ -1066,7 +1085,7 @@ def _class_reduce(obj):
elif obj in _BUILTIN_TYPE_NAMES:
return _builtin_type, (_BUILTIN_TYPE_NAMES[obj], )
elif not _should_pickle_by_reference(obj):
return _dynamic_class_reduce(obj)
return _dynamic_class_reduce(obj, config)
return NotImplemented


Expand Down Expand Up @@ -1150,14 +1169,12 @@ def _function_setstate(obj, state):
setattr(obj, k, v)


def _class_setstate(obj, state):
# This breaks the ability to modify the state of a dynamic type in the main
# process wth the assumption that the type is updatable in the child process.
def _class_setstate(obj, state, skip_reset_dynamic_type_state):
# Lock while potentially modifying class state.
with _DYNAMIC_CLASS_TRACKER_LOCK:
if obj in _DYNAMIC_CLASS_STATE_TRACKER_BY_CLASS:
if skip_reset_dynamic_type_state and obj in _DYNAMIC_CLASS_STATE_TRACKER_BY_CLASS:
return obj
_DYNAMIC_CLASS_STATE_TRACKER_BY_CLASS[obj] = True

state, slotstate = state
registry = None
for attrname, attr in state.items():
Expand Down Expand Up @@ -1229,7 +1246,6 @@ class Pickler(pickle.Pickler):
_dispatch_table[types.MethodType] = _method_reduce
_dispatch_table[types.MappingProxyType] = _mappingproxy_reduce
_dispatch_table[weakref.WeakSet] = _weakset_reduce
_dispatch_table[typing.TypeVar] = _typevar_reduce
_dispatch_table[_collections_abc.dict_keys] = _dict_keys_reduce
_dispatch_table[_collections_abc.dict_values] = _dict_values_reduce
_dispatch_table[_collections_abc.dict_items] = _dict_items_reduce
Expand Down Expand Up @@ -1309,7 +1325,8 @@ def dump(self, obj):
else:
raise

def __init__(self, file, protocol=None, buffer_callback=None):
def __init__(
self, file, protocol=None, buffer_callback=None, config=DEFAULT_CONFIG):
if protocol is None:
protocol = DEFAULT_PROTOCOL
super().__init__(file, protocol=protocol, buffer_callback=buffer_callback)
Expand All @@ -1318,6 +1335,7 @@ def __init__(self, file, protocol=None, buffer_callback=None):
# their global namespace at unpickling time.
self.globals_ref = {}
self.proto = int(protocol)
self.config = config

if not PYPY:
# pickle.Pickler is the C implementation of the CPython pickler and
Expand Down Expand Up @@ -1384,7 +1402,9 @@ def reducer_override(self, obj):
is_anyclass = False

if is_anyclass:
return _class_reduce(obj)
return _class_reduce(obj, self.config)
elif isinstance(obj, typing.TypeVar): # Add this check
return _typevar_reduce(obj, self.config)
elif isinstance(obj, types.FunctionType):
return self._function_reduce(obj)
else:
Expand Down Expand Up @@ -1454,12 +1474,20 @@ def save_global(self, obj, name=None, pack=struct.pack):
if name is not None:
super().save_global(obj, name=name)
elif not _should_pickle_by_reference(obj, name=name):
self._save_reduce_pickle5(*_dynamic_class_reduce(obj), obj=obj)
self._save_reduce_pickle5(
*_dynamic_class_reduce(obj, self.config), obj=obj)
else:
super().save_global(obj, name=name)

dispatch[type] = save_global

def save_typevar(self, obj, name=None):
  """Reduce a TypeVar instance, threading the pickler's config through.

  Registered in the dispatch table below so that TypeVar serialization can
  honor this pickler's CloudPickleConfig (e.g. its id_generator).
  """
  reduction = _typevar_reduce(obj, self.config)
  return self._save_reduce_pickle5(*reduction, obj=obj)

dispatch[typing.TypeVar] = save_typevar

def save_function(self, obj, name=None):
"""Registered with the dispatch to handle all function types.

Expand Down Expand Up @@ -1505,7 +1533,7 @@ def save_pypy_builtin_func(self, obj):
# Shorthands similar to pickle.dump/pickle.dumps


def dump(obj, file, protocol=None, buffer_callback=None):
def dump(obj, file, protocol=None, buffer_callback=None, config=DEFAULT_CONFIG):
"""Serialize obj as bytes streamed into file

protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
Expand All @@ -1518,10 +1546,12 @@ def dump(obj, file, protocol=None, buffer_callback=None):
implementation details that can change from one Python version to the
next).
"""
Pickler(file, protocol=protocol, buffer_callback=buffer_callback).dump(obj)
Pickler(
file, protocol=protocol, buffer_callback=buffer_callback,
config=config).dump(obj)


def dumps(obj, protocol=None, buffer_callback=None):
def dumps(obj, protocol=None, buffer_callback=None, config=DEFAULT_CONFIG):
"""Serialize obj as a string of bytes allocated in memory

protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
Expand All @@ -1535,7 +1565,8 @@ def dumps(obj, protocol=None, buffer_callback=None):
next).
"""
with io.BytesIO() as file:
cp = Pickler(file, protocol=protocol, buffer_callback=buffer_callback)
cp = Pickler(
file, protocol=protocol, buffer_callback=buffer_callback, config=config)
cp.dump(obj)
return file.getvalue()

Expand Down
8 changes: 6 additions & 2 deletions sdks/python/apache_beam/internal/cloudpickle_pickler.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@

from apache_beam.internal.cloudpickle import cloudpickle

DEFAULT_CONFIG = cloudpickle.CloudPickleConfig(
skip_reset_dynamic_type_state=True)

try:
from absl import flags
except (ImportError, ModuleNotFoundError):
Expand Down Expand Up @@ -113,7 +116,8 @@ def dumps(
o,
enable_trace=True,
use_zlib=False,
enable_best_effort_determinism=False) -> bytes:
enable_best_effort_determinism=False,
config: cloudpickle.CloudPickleConfig = DEFAULT_CONFIG) -> bytes:
"""For internal use only; no backwards-compatibility guarantees."""
if enable_best_effort_determinism:
# TODO: Add support once https://github.com/cloudpipe/cloudpickle/pull/563
Expand All @@ -123,7 +127,7 @@ def dumps(
'This has only been implemented for dill.')
with _pickle_lock:
with io.BytesIO() as file:
pickler = cloudpickle.CloudPickler(file)
pickler = cloudpickle.CloudPickler(file, config=config)
try:
pickler.dispatch_table[type(flags.FLAGS)] = _pickle_absl_flags
except NameError:
Expand Down
Loading