Skip to content

Commit 22f01c5

Browse files
William Grantbraxtonmckee
authored andcommitted
Ensure that serializer can use the 'name' pickle protocol.
Pickle supports a protocol where __reduce__returns a string giving the global name. Implementing this behaviour lets us serialize numpy ufuncs. Also adjust installInflightFunctions to handle new load behaviour, fix an instability caused by not leaving LoadedModule objects in memory, and adjust alternative test.
1 parent 2782704 commit 22f01c5

8 files changed

+113
-19
lines changed

typed_python/SerializationContext.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,24 @@
3131
import types
3232
import traceback
3333
import logging
34+
import numpy
35+
import pickle
3436

3537

3638
_badModuleCache = set()
3739

3840

41+
def pickledByStr(module_name: str, name: str) -> None:
42+
"""Generate the object given the module_name and name.
43+
44+
This mimics pickle's behavior when given a string from __reduce__. The
45+
string is interpreted as the name of a global variable, and pickle.whichmodules
46+
is used to search the module namespace, generating module_name.
47+
"""
48+
module = importlib.import_module(module_name)
49+
return getattr(module, name)
50+
51+
3952
def createFunctionWithLocalsAndGlobals(code, globals):
4053
if globals is None:
4154
globals = {}
@@ -708,26 +721,30 @@ def walkCodeObject(code):
708721
return (createFunctionWithLocalsAndGlobals, args, representation)
709722

710723
if not isinstance(inst, type) and hasattr(type(inst), '__reduce_ex__'):
711-
res = inst.__reduce_ex__(4)
724+
if isinstance(inst, numpy.ufunc):
725+
res = inst.__name__
726+
else:
727+
res = inst.__reduce_ex__(4)
712728

713-
# pickle supports a protocol where __reduce__ can return a string
714-
# giving a global name. We'll already find that separately, so we
715-
# don't want to handle it here. We ought to look at this in more detail
716-
# however
729+
# mimic pickle's behaviour when a string is received.
717730
if isinstance(res, str):
718-
return None
731+
name_tuple = (inst, res)
732+
module_name = pickle.whichmodule(*name_tuple)
733+
res = (pickledByStr, (module_name, res,), pickledByStr)
719734

720735
return res
721736

722737
if not isinstance(inst, type) and hasattr(type(inst), '__reduce__'):
723-
res = inst.__reduce__()
738+
if isinstance(inst, numpy.ufunc):
739+
res = inst.__name__
740+
else:
741+
res = inst.__reduce()
724742

725-
# pickle supports a protocol where __reduce__ can return a string
726-
# giving a global name. We'll already find that separately, so we
727-
# don't want to handle it here. We ought to look at this in more detail
728-
# however
743+
# mimic pickle's behaviour when a string is received.
729744
if isinstance(res, str):
730-
return None
745+
name_tuple = (inst, res)
746+
module_name = pickle.whichmodule(*name_tuple)
747+
res = (pickledByStr, (module_name, res,), pickledByStr)
731748

732749
return res
733750

@@ -736,6 +753,9 @@ def walkCodeObject(code):
736753
def setInstanceStateFromRepresentation(
737754
self, instance, representation=None, itemIt=None, kvPairIt=None, setStateFun=None
738755
):
756+
if representation is pickledByStr:
757+
return
758+
739759
if representation is reconstructTypeFunctionType:
740760
return
741761

typed_python/compiler/global_variable_definition.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,5 +87,4 @@ def __eq__(self, other):
8787
return self.name == other.name and self.type == other.type and self.metadata == other.metadata
8888

8989
def __str__(self):
90-
metadata_str = str(self.metadata) if len(str(self.metadata)) < 100 else str(self.metadata)[:100] + "..."
91-
return f"GlobalVariableDefinition(name={self.name}, type={self.type}, metadata={metadata_str})"
90+
return f"GlobalVariableDefinition(name={self.name}, type={self.type}, metadata={pad(str(self.metadata))})"

typed_python/compiler/llvm_compiler_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from typed_python.compiler.module_definition import ModuleDefinition
2121
from typed_python.compiler.global_variable_definition import GlobalVariableMetadata
2222

23+
from typed_python.test_util import evaluateExprInFreshProcess
24+
2325
import pytest
2426
import ctypes
2527

@@ -131,3 +133,28 @@ def test_create_binary_shared_object():
131133
pointers[0].set(5)
132134

133135
assert loaded.functionPointers['__test_f_2']() == 5
136+
137+
138+
@pytest.mark.skipif('sys.platform=="darwin"')
139+
def test_loaded_modules_persist():
140+
"""
141+
Make sure that loaded modules are persisted in the converter state.
142+
143+
We have to maintain these references to avoid surprise segfaults - if this test fails,
144+
it should be because the GlobalVariableDefinition memory management has been refactored.
145+
"""
146+
147+
# compile a module
148+
xmodule = "\n".join([
149+
"@Entrypoint",
150+
"def f(x):",
151+
" return x + 1",
152+
"@Entrypoint",
153+
"def g(x):",
154+
" return f(x) * 100",
155+
"g(1000)",
156+
"def get_loaded_modules():",
157+
" return len(Runtime.singleton().converter.loadedUncachedModules)"
158+
])
159+
VERSION1 = {'x.py': xmodule}
160+
assert evaluateExprInFreshProcess(VERSION1, 'x.get_loaded_modules()') == 1

typed_python/compiler/loaded_module.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ def __init__(self, functionPointers, serializedGlobalVariableDefinitions):
2828

2929
self.functionPointers[ModuleDefinition.GET_GLOBAL_VARIABLES_NAME](self.pointers.pointerUnsafe(0))
3030

31+
self.installedGlobalVariableDefinitions = {}
32+
3133
@staticmethod
3234
def validateGlobalVariables(serializedGlobalVariableDefinitions: Dict[str, bytes]) -> bool:
3335
"""Check that each global variable definition is sensible.
@@ -83,6 +85,8 @@ def linkGlobalVariables(self, variable_names: List[str] = None) -> None:
8385

8486
meta = SerializationContext().deserialize(self.orderedDefs[i]).metadata
8587

88+
self.installedGlobalVariableDefinitions[i] = meta
89+
8690
if meta.matches.StringConstant:
8791
self.pointers[i].cast(str).initialize(meta.value)
8892

typed_python/compiler/python_to_native_converter.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,12 @@ def __init__(self, llvmCompiler, compilerCache):
125125
self.llvmCompiler = llvmCompiler
126126
self.compilerCache = compilerCache
127127

128+
# all LoadedModule objects that we have created. We need to keep them alive so
129+
# that any python metadata objects the've created stay alive as well. Ultimately, this
130+
# may not be the place we put these objects (for instance, you could imagine a
131+
# 'dummy' compiler cache or something). But for now, we need to keep them alive.
132+
self.loadedUncachedModules = []
133+
128134
# if True, then insert additional code to check for undefined behavior.
129135
self.generateDebugChecks = False
130136

@@ -191,6 +197,7 @@ def buildAndLinkNewModule(self):
191197
if self.compilerCache is None:
192198
loadedModule = self.llvmCompiler.buildModule(targets)
193199
loadedModule.linkGlobalVariables()
200+
self.loadedUncachedModules.append(loadedModule)
194201
return
195202

196203
# get a set of function names that we depend on
@@ -926,7 +933,11 @@ def _installInflightFunctions(self):
926933
outboundTargets = []
927934
for outboundFuncId in self._dependencies.getNamesDependedOn(identifier):
928935
name = self._link_name_for_identity[outboundFuncId]
929-
outboundTargets.append(self._targets[name])
936+
target = self.getTarget(name)
937+
if target is not None:
938+
outboundTargets.append(target)
939+
else:
940+
raise RuntimeError(f'dependency not found for {name}.')
930941

931942
nativeFunction, actual_output_type = self._inflight_definitions.get(identifier)
932943

typed_python/compiler/tests/numpy_interaction_test.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typed_python import ListOf, Entrypoint
1+
from typed_python import ListOf, Entrypoint, SerializationContext
22
import numpy
33
import numpy.linalg
44

@@ -44,3 +44,12 @@ def test_listof_from_sliced_numpy_array():
4444
y = x[::2]
4545

4646
assert ListOf(int)(y) == [0, 2]
47+
48+
49+
def test_can_serialize_numpy_ufunc():
50+
assert numpy.sin == SerializationContext().deserialize(SerializationContext().serialize(numpy.sin))
51+
52+
53+
def test_can_serialize_numpy_array():
54+
x = numpy.ones(10)
55+
assert (x == SerializationContext().deserialize(SerializationContext().serialize(x))).all()

typed_python/compiler/tests/type_of_instances_compilation_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@ def typeOfArg(x: C):
1717

1818
def test_type_of_alternative_is_specific():
1919
for members in [{}, {'a': int}]:
20-
A = Alternative("A", A=members)
20+
Alt = Alternative("Alt", A=members)
2121

2222
@Entrypoint
23-
def typeOfArg(x: A):
23+
def typeOfArg(x: Alt):
2424
return type(x)
2525

26-
assert typeOfArg(A.A()) is A.A
26+
assert typeOfArg(Alt.A()) is Alt.A
2727

2828

2929
def test_type_of_concrete_alternative_is_specific():

typed_python/types_serialization_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3061,3 +3061,27 @@ def f(self):
30613061
print(x)
30623062
# TODO: make this True
30633063
# assert x[0].f.__closure__[0].cell_contents is x
3064+
3065+
def test_serialize_pyobj_with_custom_reduce(self):
3066+
class CustomReduceObject:
3067+
def __reduce__(self):
3068+
return 'CustomReduceObject'
3069+
3070+
assert CustomReduceObject == SerializationContext().deserialize(SerializationContext().serialize(CustomReduceObject))
3071+
3072+
def test_serialize_pyobj_in_MRTG_with_custom_reduce(self):
3073+
def getX():
3074+
class InnerCustomReduceObject:
3075+
def __reduce__(self):
3076+
return 'InnerCustomReduceObject'
3077+
3078+
def f(self):
3079+
return x
3080+
3081+
x = (InnerCustomReduceObject, InnerCustomReduceObject)
3082+
3083+
return x
3084+
3085+
x = callFunctionInFreshProcess(getX, (), showStdout=True)
3086+
3087+
assert x == SerializationContext().deserialize(SerializationContext().serialize(x))

0 commit comments

Comments
 (0)