Skip to content

Commit 4d58e6d

Browse files
claudevdmClaude
andauthored
Cloudpickle configurable (#35422)
* Add configurable cloudpickler. * Trigger postcommits. --------- Co-authored-by: Claude <cvandermerwe@google.com>
1 parent f56d0fa commit 4d58e6d

File tree

3 files changed

+70
-35
lines changed

3 files changed

+70
-35
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{
22
"comment": "Modify this file in a trivial way to cause this test suite to run.",
3-
"modification": 14
3+
"modification": 1
44
}
55

sdks/python/apache_beam/internal/cloudpickle/cloudpickle.py

Lines changed: 63 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
import dataclasses
6262
import dis
6363
from enum import Enum
64+
import functools
6465
import io
6566
import itertools
6667
import logging
@@ -100,6 +101,20 @@
100101

101102
PYPY = platform.python_implementation() == "PyPy"
102103

104+
105+
def uuid_generator(_):
106+
return uuid.uuid4().hex
107+
108+
109+
@dataclasses.dataclass
110+
class CloudPickleConfig:
111+
"""Configuration for cloudpickle behavior."""
112+
id_generator: typing.Optional[callable] = uuid_generator
113+
skip_reset_dynamic_type_state: bool = False
114+
115+
116+
DEFAULT_CONFIG = CloudPickleConfig()
117+
103118
builtin_code_type = None
104119
if PYPY:
105120
# builtin-code objects only exist in pypy
@@ -108,11 +123,11 @@
108123
_extract_code_globals_cache = weakref.WeakKeyDictionary()
109124

110125

111-
def _get_or_create_tracker_id(class_def):
126+
def _get_or_create_tracker_id(class_def, id_generator):
112127
with _DYNAMIC_CLASS_TRACKER_LOCK:
113128
class_tracker_id = _DYNAMIC_CLASS_TRACKER_BY_CLASS.get(class_def)
114-
if class_tracker_id is None:
115-
class_tracker_id = uuid.uuid4().hex
129+
if class_tracker_id is None and id_generator is not None:
130+
class_tracker_id = id_generator(class_def)
116131
_DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
117132
_DYNAMIC_CLASS_TRACKER_BY_ID[class_tracker_id] = class_def
118133
return class_tracker_id
@@ -593,26 +608,26 @@ def _make_typevar(
593608
return _lookup_class_or_track(class_tracker_id, tv)
594609

595610

596-
def _decompose_typevar(obj):
611+
def _decompose_typevar(obj, config):
597612
return (
598613
obj.__name__,
599614
obj.__bound__,
600615
obj.__constraints__,
601616
obj.__covariant__,
602617
obj.__contravariant__,
603-
_get_or_create_tracker_id(obj),
618+
_get_or_create_tracker_id(obj, config.id_generator),
604619
)
605620

606621

607-
def _typevar_reduce(obj):
622+
def _typevar_reduce(obj, config):
608623
# TypeVar instances require the module information hence why we
609624
# are not using the _should_pickle_by_reference directly
610625
module_and_name = _lookup_module_and_qualname(obj, name=obj.__name__)
611626

612627
if module_and_name is None:
613-
return (_make_typevar, _decompose_typevar(obj))
628+
return (_make_typevar, _decompose_typevar(obj, config))
614629
elif _is_registered_pickle_by_value(module_and_name[0]):
615-
return (_make_typevar, _decompose_typevar(obj))
630+
return (_make_typevar, _decompose_typevar(obj, config))
616631

617632
return (getattr, module_and_name)
618633

@@ -656,7 +671,7 @@ def _make_dict_items(obj, is_ordered=False):
656671
# -------------------------------------------------
657672

658673

659-
def _class_getnewargs(obj):
674+
def _class_getnewargs(obj, config):
660675
type_kwargs = {}
661676
if "__module__" in obj.__dict__:
662677
type_kwargs["__module__"] = obj.__module__
@@ -670,20 +685,20 @@ def _class_getnewargs(obj):
670685
obj.__name__,
671686
_get_bases(obj),
672687
type_kwargs,
673-
_get_or_create_tracker_id(obj),
688+
_get_or_create_tracker_id(obj, config.id_generator),
674689
None,
675690
)
676691

677692

678-
def _enum_getnewargs(obj):
693+
def _enum_getnewargs(obj, config):
679694
members = {e.name: e.value for e in obj}
680695
return (
681696
obj.__bases__,
682697
obj.__name__,
683698
obj.__qualname__,
684699
members,
685700
obj.__module__,
686-
_get_or_create_tracker_id(obj),
701+
_get_or_create_tracker_id(obj, config.id_generator),
687702
None,
688703
)
689704

@@ -1028,7 +1043,7 @@ def _weakset_reduce(obj):
10281043
return weakref.WeakSet, (list(obj), )
10291044

10301045

1031-
def _dynamic_class_reduce(obj):
1046+
def _dynamic_class_reduce(obj, config):
10321047
"""Save a class that can't be referenced as a module attribute.
10331048
10341049
This method is used to serialize classes that are defined inside
@@ -1038,24 +1053,28 @@ def _dynamic_class_reduce(obj):
10381053
if Enum is not None and issubclass(obj, Enum):
10391054
return (
10401055
_make_skeleton_enum,
1041-
_enum_getnewargs(obj),
1056+
_enum_getnewargs(obj, config),
10421057
_enum_getstate(obj),
10431058
None,
10441059
None,
1045-
_class_setstate,
1060+
functools.partial(
1061+
_class_setstate,
1062+
skip_reset_dynamic_type_state=config.skip_reset_dynamic_type_state),
10461063
)
10471064
else:
10481065
return (
10491066
_make_skeleton_class,
1050-
_class_getnewargs(obj),
1067+
_class_getnewargs(obj, config),
10511068
_class_getstate(obj),
10521069
None,
10531070
None,
1054-
_class_setstate,
1071+
functools.partial(
1072+
_class_setstate,
1073+
skip_reset_dynamic_type_state=config.skip_reset_dynamic_type_state),
10551074
)
10561075

10571076

1058-
def _class_reduce(obj):
1077+
def _class_reduce(obj, config):
10591078
"""Select the reducer depending on the dynamic nature of the class obj."""
10601079
if obj is type(None): # noqa
10611080
return type, (None, )
@@ -1066,7 +1085,7 @@ def _class_reduce(obj):
10661085
elif obj in _BUILTIN_TYPE_NAMES:
10671086
return _builtin_type, (_BUILTIN_TYPE_NAMES[obj], )
10681087
elif not _should_pickle_by_reference(obj):
1069-
return _dynamic_class_reduce(obj)
1088+
return _dynamic_class_reduce(obj, config)
10701089
return NotImplemented
10711090

10721091

@@ -1150,14 +1169,12 @@ def _function_setstate(obj, state):
11501169
setattr(obj, k, v)
11511170

11521171

1153-
def _class_setstate(obj, state):
1154-
# This breaks the ability to modify the state of a dynamic type in the main
1155-
# process wth the assumption that the type is updatable in the child process.
1172+
def _class_setstate(obj, state, skip_reset_dynamic_type_state):
1173+
# Lock while potentially modifying class state.
11561174
with _DYNAMIC_CLASS_TRACKER_LOCK:
1157-
if obj in _DYNAMIC_CLASS_STATE_TRACKER_BY_CLASS:
1175+
if skip_reset_dynamic_type_state and obj in _DYNAMIC_CLASS_STATE_TRACKER_BY_CLASS:
11581176
return obj
11591177
_DYNAMIC_CLASS_STATE_TRACKER_BY_CLASS[obj] = True
1160-
11611178
state, slotstate = state
11621179
registry = None
11631180
for attrname, attr in state.items():
@@ -1229,7 +1246,6 @@ class Pickler(pickle.Pickler):
12291246
_dispatch_table[types.MethodType] = _method_reduce
12301247
_dispatch_table[types.MappingProxyType] = _mappingproxy_reduce
12311248
_dispatch_table[weakref.WeakSet] = _weakset_reduce
1232-
_dispatch_table[typing.TypeVar] = _typevar_reduce
12331249
_dispatch_table[_collections_abc.dict_keys] = _dict_keys_reduce
12341250
_dispatch_table[_collections_abc.dict_values] = _dict_values_reduce
12351251
_dispatch_table[_collections_abc.dict_items] = _dict_items_reduce
@@ -1309,7 +1325,8 @@ def dump(self, obj):
13091325
else:
13101326
raise
13111327

1312-
def __init__(self, file, protocol=None, buffer_callback=None):
1328+
def __init__(
1329+
self, file, protocol=None, buffer_callback=None, config=DEFAULT_CONFIG):
13131330
if protocol is None:
13141331
protocol = DEFAULT_PROTOCOL
13151332
super().__init__(file, protocol=protocol, buffer_callback=buffer_callback)
@@ -1318,6 +1335,7 @@ def __init__(self, file, protocol=None, buffer_callback=None):
13181335
# their global namespace at unpickling time.
13191336
self.globals_ref = {}
13201337
self.proto = int(protocol)
1338+
self.config = config
13211339

13221340
if not PYPY:
13231341
# pickle.Pickler is the C implementation of the CPython pickler and
@@ -1384,7 +1402,9 @@ def reducer_override(self, obj):
13841402
is_anyclass = False
13851403

13861404
if is_anyclass:
1387-
return _class_reduce(obj)
1405+
return _class_reduce(obj, self.config)
1406+
elif isinstance(obj, typing.TypeVar): # Add this check
1407+
return _typevar_reduce(obj, self.config)
13881408
elif isinstance(obj, types.FunctionType):
13891409
return self._function_reduce(obj)
13901410
else:
@@ -1454,12 +1474,20 @@ def save_global(self, obj, name=None, pack=struct.pack):
14541474
if name is not None:
14551475
super().save_global(obj, name=name)
14561476
elif not _should_pickle_by_reference(obj, name=name):
1457-
self._save_reduce_pickle5(*_dynamic_class_reduce(obj), obj=obj)
1477+
self._save_reduce_pickle5(
1478+
*_dynamic_class_reduce(obj, self.config), obj=obj)
14581479
else:
14591480
super().save_global(obj, name=name)
14601481

14611482
dispatch[type] = save_global
14621483

1484+
def save_typevar(self, obj, name=None):
1485+
"""Handle TypeVar objects with access to config."""
1486+
return self._save_reduce_pickle5(
1487+
*_typevar_reduce(obj, self.config), obj=obj)
1488+
1489+
dispatch[typing.TypeVar] = save_typevar
1490+
14631491
def save_function(self, obj, name=None):
14641492
"""Registered with the dispatch to handle all function types.
14651493
@@ -1505,7 +1533,7 @@ def save_pypy_builtin_func(self, obj):
15051533
# Shorthands similar to pickle.dump/pickle.dumps
15061534

15071535

1508-
def dump(obj, file, protocol=None, buffer_callback=None):
1536+
def dump(obj, file, protocol=None, buffer_callback=None, config=DEFAULT_CONFIG):
15091537
"""Serialize obj as bytes streamed into file
15101538
15111539
protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
@@ -1518,10 +1546,12 @@ def dump(obj, file, protocol=None, buffer_callback=None):
15181546
implementation details that can change from one Python version to the
15191547
next).
15201548
"""
1521-
Pickler(file, protocol=protocol, buffer_callback=buffer_callback).dump(obj)
1549+
Pickler(
1550+
file, protocol=protocol, buffer_callback=buffer_callback,
1551+
config=config).dump(obj)
15221552

15231553

1524-
def dumps(obj, protocol=None, buffer_callback=None):
1554+
def dumps(obj, protocol=None, buffer_callback=None, config=DEFAULT_CONFIG):
15251555
"""Serialize obj as a string of bytes allocated in memory
15261556
15271557
protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
@@ -1535,7 +1565,8 @@ def dumps(obj, protocol=None, buffer_callback=None):
15351565
next).
15361566
"""
15371567
with io.BytesIO() as file:
1538-
cp = Pickler(file, protocol=protocol, buffer_callback=buffer_callback)
1568+
cp = Pickler(
1569+
file, protocol=protocol, buffer_callback=buffer_callback, config=config)
15391570
cp.dump(obj)
15401571
return file.getvalue()
15411572

sdks/python/apache_beam/internal/cloudpickle_pickler.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@
3737

3838
from apache_beam.internal.cloudpickle import cloudpickle
3939

40+
DEFAULT_CONFIG = cloudpickle.CloudPickleConfig(
41+
skip_reset_dynamic_type_state=True)
42+
4043
try:
4144
from absl import flags
4245
except (ImportError, ModuleNotFoundError):
@@ -113,7 +116,8 @@ def dumps(
113116
o,
114117
enable_trace=True,
115118
use_zlib=False,
116-
enable_best_effort_determinism=False) -> bytes:
119+
enable_best_effort_determinism=False,
120+
config: cloudpickle.CloudPickleConfig = DEFAULT_CONFIG) -> bytes:
117121
"""For internal use only; no backwards-compatibility guarantees."""
118122
if enable_best_effort_determinism:
119123
# TODO: Add support once https://github.com/cloudpipe/cloudpickle/pull/563
@@ -123,7 +127,7 @@ def dumps(
123127
'This has only been implemented for dill.')
124128
with _pickle_lock:
125129
with io.BytesIO() as file:
126-
pickler = cloudpickle.CloudPickler(file)
130+
pickler = cloudpickle.CloudPickler(file, config=config)
127131
try:
128132
pickler.dispatch_table[type(flags.FLAGS)] = _pickle_absl_flags
129133
except NameError:

0 commit comments

Comments
 (0)