61
61
import dataclasses
62
62
import dis
63
63
from enum import Enum
64
+ import functools
64
65
import io
65
66
import itertools
66
67
import logging
100
101
101
102
PYPY = platform .python_implementation () == "PyPy"
102
103
104
+
105
def uuid_generator(_):
    """Return a fresh random identifier as a 32-character hex string.

    The single positional argument is accepted so that this function can be
    called as ``id_generator(class_def)``, but its value is ignored.
    """
    return uuid.uuid4().hex


@dataclasses.dataclass
class CloudPickleConfig:
    """Configuration for cloudpickle behavior."""

    # Callable invoked with the object being tracked; its result is used as
    # the tracker id (and therefore as a dictionary key, so it must be
    # hashable). ``None`` means no new tracker id is generated.
    id_generator: typing.Optional[
        typing.Callable[[typing.Any], typing.Hashable]] = uuid_generator
    # Forwarded to ``_class_setstate`` when a dynamic class is reduced.
    skip_reset_dynamic_type_state: bool = False


# Shared instance used as the default ``config`` argument of the pickling
# entry points; treat it as read-only because every caller that does not pass
# its own config receives this same object.
DEFAULT_CONFIG = CloudPickleConfig()
117
+
103
118
builtin_code_type = None
104
119
if PYPY :
105
120
# builtin-code objects only exist in pypy
108
123
_extract_code_globals_cache = weakref .WeakKeyDictionary ()
109
124
110
125
111
def _get_or_create_tracker_id(class_def, id_generator):
    """Return the tracker id of ``class_def``, creating one when possible.

    Args:
        class_def: the object to look up in the dynamic class tracker.
        id_generator: callable invoked as ``id_generator(class_def)`` to
            produce an id when none is registered yet, or ``None`` to never
            create a new id.

    Returns:
        The id already registered for ``class_def``, else the newly
        generated id, else ``None`` when no id exists and none was generated.
    """
    with _DYNAMIC_CLASS_TRACKER_LOCK:
        class_tracker_id = _DYNAMIC_CLASS_TRACKER_BY_CLASS.get(class_def)
        if class_tracker_id is None and id_generator is not None:
            class_tracker_id = id_generator(class_def)
            # A generator may decline to produce an id; registering ``None``
            # would make every such class share the same bogus key.
            if class_tracker_id is not None:
                _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
                _DYNAMIC_CLASS_TRACKER_BY_ID[class_tracker_id] = class_def
    return class_tracker_id
@@ -593,26 +608,26 @@ def _make_typevar(
593
608
return _lookup_class_or_track (class_tracker_id , tv )
594
609
595
610
596
def _decompose_typevar(obj, config):
    """Collect the pieces of a TypeVar as the argument tuple for its rebuild.

    The tuple holds, in order: name, bound, constraints, covariant flag,
    contravariant flag, and the tracker id obtained with
    ``config.id_generator``.
    """
    attribute_names = (
        "__name__",
        "__bound__",
        "__constraints__",
        "__covariant__",
        "__contravariant__",
    )
    # Read the attributes first, then resolve the tracker id last, matching
    # the order of the tuple.
    parts = tuple(getattr(obj, attr) for attr in attribute_names)
    return parts + (_get_or_create_tracker_id(obj, config.id_generator),)
605
620
606
621
607
def _typevar_reduce(obj, config):
    """Reduce a TypeVar either by reference or by value.

    Returns ``(getattr, (module, name))`` when the TypeVar can be looked up
    in a module that is not registered for pickling by value; otherwise
    returns ``(_make_typevar, args)`` so it is rebuilt from its parts.
    """
    # TypeVar instances require the module information hence why we
    # are not using the _should_pickle_by_reference directly
    located = _lookup_module_and_qualname(obj, name=obj.__name__)

    by_reference = (
        located is not None
        and not _is_registered_pickle_by_value(located[0])
    )
    if by_reference:
        return (getattr, located)
    return (_make_typevar, _decompose_typevar(obj, config))
618
633
@@ -656,7 +671,7 @@ def _make_dict_items(obj, is_ordered=False):
656
671
# -------------------------------------------------
657
672
658
673
659
- def _class_getnewargs (obj ):
674
+ def _class_getnewargs (obj , config ):
660
675
type_kwargs = {}
661
676
if "__module__" in obj .__dict__ :
662
677
type_kwargs ["__module__" ] = obj .__module__
@@ -670,20 +685,20 @@ def _class_getnewargs(obj):
670
685
obj .__name__ ,
671
686
_get_bases (obj ),
672
687
type_kwargs ,
673
- _get_or_create_tracker_id (obj ),
688
+ _get_or_create_tracker_id (obj , config . id_generator ),
674
689
None ,
675
690
)
676
691
677
692
678
def _enum_getnewargs(obj, config):
    """Build the argument tuple used to recreate an Enum class.

    The tuple holds, in order: bases, name, qualified name, a name-to-value
    mapping of the members, module, the tracker id obtained with
    ``config.id_generator``, and a trailing ``None``.
    """
    members = {}
    for member in obj:
        members[member.name] = member.value

    return (
        obj.__bases__,
        obj.__name__,
        obj.__qualname__,
        members,
        obj.__module__,
        _get_or_create_tracker_id(obj, config.id_generator),
        None,
    )
689
704
@@ -1028,7 +1043,7 @@ def _weakset_reduce(obj):
1028
1043
return weakref .WeakSet , (list (obj ), )
1029
1044
1030
1045
1031
- def _dynamic_class_reduce (obj ):
1046
+ def _dynamic_class_reduce (obj , config ):
1032
1047
"""Save a class that can't be referenced as a module attribute.
1033
1048
1034
1049
This method is used to serialize classes that are defined inside
@@ -1038,24 +1053,28 @@ def _dynamic_class_reduce(obj):
1038
1053
if Enum is not None and issubclass (obj , Enum ):
1039
1054
return (
1040
1055
_make_skeleton_enum ,
1041
- _enum_getnewargs (obj ),
1056
+ _enum_getnewargs (obj , config ),
1042
1057
_enum_getstate (obj ),
1043
1058
None ,
1044
1059
None ,
1045
- _class_setstate ,
1060
+ functools .partial (
1061
+ _class_setstate ,
1062
+ skip_reset_dynamic_type_state = config .skip_reset_dynamic_type_state ),
1046
1063
)
1047
1064
else :
1048
1065
return (
1049
1066
_make_skeleton_class ,
1050
- _class_getnewargs (obj ),
1067
+ _class_getnewargs (obj , config ),
1051
1068
_class_getstate (obj ),
1052
1069
None ,
1053
1070
None ,
1054
- _class_setstate ,
1071
+ functools .partial (
1072
+ _class_setstate ,
1073
+ skip_reset_dynamic_type_state = config .skip_reset_dynamic_type_state ),
1055
1074
)
1056
1075
1057
1076
1058
- def _class_reduce (obj ):
1077
+ def _class_reduce (obj , config ):
1059
1078
"""Select the reducer depending on the dynamic nature of the class obj."""
1060
1079
if obj is type (None ): # noqa
1061
1080
return type , (None , )
@@ -1066,7 +1085,7 @@ def _class_reduce(obj):
1066
1085
elif obj in _BUILTIN_TYPE_NAMES :
1067
1086
return _builtin_type , (_BUILTIN_TYPE_NAMES [obj ], )
1068
1087
elif not _should_pickle_by_reference (obj ):
1069
- return _dynamic_class_reduce (obj )
1088
+ return _dynamic_class_reduce (obj , config )
1070
1089
return NotImplemented
1071
1090
1072
1091
@@ -1150,14 +1169,12 @@ def _function_setstate(obj, state):
1150
1169
setattr (obj , k , v )
1151
1170
1152
1171
1153
- def _class_setstate (obj , state ):
1154
- # This breaks the ability to modify the state of a dynamic type in the main
1155
- # process wth the assumption that the type is updatable in the child process.
1172
+ def _class_setstate (obj , state , skip_reset_dynamic_type_state ):
1173
+ # Lock while potentially modifying class state.
1156
1174
with _DYNAMIC_CLASS_TRACKER_LOCK :
1157
- if obj in _DYNAMIC_CLASS_STATE_TRACKER_BY_CLASS :
1175
+ if skip_reset_dynamic_type_state and obj in _DYNAMIC_CLASS_STATE_TRACKER_BY_CLASS :
1158
1176
return obj
1159
1177
_DYNAMIC_CLASS_STATE_TRACKER_BY_CLASS [obj ] = True
1160
-
1161
1178
state , slotstate = state
1162
1179
registry = None
1163
1180
for attrname , attr in state .items ():
@@ -1229,7 +1246,6 @@ class Pickler(pickle.Pickler):
1229
1246
_dispatch_table [types .MethodType ] = _method_reduce
1230
1247
_dispatch_table [types .MappingProxyType ] = _mappingproxy_reduce
1231
1248
_dispatch_table [weakref .WeakSet ] = _weakset_reduce
1232
- _dispatch_table [typing .TypeVar ] = _typevar_reduce
1233
1249
_dispatch_table [_collections_abc .dict_keys ] = _dict_keys_reduce
1234
1250
_dispatch_table [_collections_abc .dict_values ] = _dict_values_reduce
1235
1251
_dispatch_table [_collections_abc .dict_items ] = _dict_items_reduce
@@ -1309,7 +1325,8 @@ def dump(self, obj):
1309
1325
else :
1310
1326
raise
1311
1327
1312
- def __init__ (self , file , protocol = None , buffer_callback = None ):
1328
+ def __init__ (
1329
+ self , file , protocol = None , buffer_callback = None , config = DEFAULT_CONFIG ):
1313
1330
if protocol is None :
1314
1331
protocol = DEFAULT_PROTOCOL
1315
1332
super ().__init__ (file , protocol = protocol , buffer_callback = buffer_callback )
@@ -1318,6 +1335,7 @@ def __init__(self, file, protocol=None, buffer_callback=None):
1318
1335
# their global namespace at unpickling time.
1319
1336
self .globals_ref = {}
1320
1337
self .proto = int (protocol )
1338
+ self .config = config
1321
1339
1322
1340
if not PYPY :
1323
1341
# pickle.Pickler is the C implementation of the CPython pickler and
@@ -1384,7 +1402,9 @@ def reducer_override(self, obj):
1384
1402
is_anyclass = False
1385
1403
1386
1404
if is_anyclass :
1387
- return _class_reduce (obj )
1405
+ return _class_reduce (obj , self .config )
1406
+ elif isinstance (obj , typing .TypeVar ): # Add this check
1407
+ return _typevar_reduce (obj , self .config )
1388
1408
elif isinstance (obj , types .FunctionType ):
1389
1409
return self ._function_reduce (obj )
1390
1410
else :
@@ -1454,12 +1474,20 @@ def save_global(self, obj, name=None, pack=struct.pack):
1454
1474
if name is not None :
1455
1475
super ().save_global (obj , name = name )
1456
1476
elif not _should_pickle_by_reference (obj , name = name ):
1457
- self ._save_reduce_pickle5 (* _dynamic_class_reduce (obj ), obj = obj )
1477
+ self ._save_reduce_pickle5 (
1478
+ * _dynamic_class_reduce (obj , self .config ), obj = obj )
1458
1479
else :
1459
1480
super ().save_global (obj , name = name )
1460
1481
1461
1482
dispatch [type ] = save_global
1462
1483
1484
+ def save_typevar (self , obj , name = None ):
1485
+ """Handle TypeVar objects with access to config."""
1486
+ return self ._save_reduce_pickle5 (
1487
+ * _typevar_reduce (obj , self .config ), obj = obj )
1488
+
1489
+ dispatch [typing .TypeVar ] = save_typevar
1490
+
1463
1491
def save_function (self , obj , name = None ):
1464
1492
"""Registered with the dispatch to handle all function types.
1465
1493
@@ -1505,7 +1533,7 @@ def save_pypy_builtin_func(self, obj):
1505
1533
# Shorthands similar to pickle.dump/pickle.dumps
1506
1534
1507
1535
1508
- def dump (obj , file , protocol = None , buffer_callback = None ):
1536
+ def dump (obj , file , protocol = None , buffer_callback = None , config = DEFAULT_CONFIG ):
1509
1537
"""Serialize obj as bytes streamed into file
1510
1538
1511
1539
protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
@@ -1518,10 +1546,12 @@ def dump(obj, file, protocol=None, buffer_callback=None):
1518
1546
implementation details that can change from one Python version to the
1519
1547
next).
1520
1548
"""
1521
- Pickler (file , protocol = protocol , buffer_callback = buffer_callback ).dump (obj )
1549
+ Pickler (
1550
+ file , protocol = protocol , buffer_callback = buffer_callback ,
1551
+ config = config ).dump (obj )
1522
1552
1523
1553
1524
- def dumps (obj , protocol = None , buffer_callback = None ):
1554
+ def dumps (obj , protocol = None , buffer_callback = None , config = DEFAULT_CONFIG ):
1525
1555
"""Serialize obj as a string of bytes allocated in memory
1526
1556
1527
1557
protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
@@ -1535,7 +1565,8 @@ def dumps(obj, protocol=None, buffer_callback=None):
1535
1565
next).
1536
1566
"""
1537
1567
with io .BytesIO () as file :
1538
- cp = Pickler (file , protocol = protocol , buffer_callback = buffer_callback )
1568
+ cp = Pickler (
1569
+ file , protocol = protocol , buffer_callback = buffer_callback , config = config )
1539
1570
cp .dump (obj )
1540
1571
return file .getvalue ()
1541
1572
0 commit comments