1
1
import os
2
- import argparse
3
2
import shutil
4
3
import socket
5
4
import joblib
6
5
import tempfile
7
6
import subprocess
8
7
9
- from typing import Union , Callable
8
+ from typing import Union , Callable , Optional
10
9
11
10
12
11
class MagicPickle :
@@ -27,6 +26,9 @@ def __init__(
27
26
local_hostname_or_func : Union [str , Callable [[], bool ]],
28
27
verbose : bool = True ,
29
28
compress : Union [bool , int ] = True ,
29
+ local_store_cache : Optional [str ] = os .path .join (
30
+ tempfile .gettempdir (), "magicpickle_cache"
31
+ ),
30
32
):
31
33
"""
32
34
Parameters
@@ -36,6 +38,8 @@ def __init__(
36
38
verbose
37
39
compress
38
40
compression level for joblib.dump
41
+ local_store_cache
42
+ persistent cache file used, triggered when empty string passed into prompt on local
39
43
"""
40
44
if callable (local_hostname_or_func ):
41
45
self .is_local = local_hostname_or_func ()
@@ -45,6 +49,7 @@ def __init__(
45
49
46
50
self .verbose = verbose
47
51
self .compress = compress
52
+ self .local_store_cache = local_store_cache
48
53
49
54
if self .verbose :
50
55
print (f"MagicPickle is_local: { self .is_local } " )
@@ -58,19 +63,33 @@ def __enter__(self):
58
63
print (f"MagicPickle tmpdir: { self .tmpdir .name } " )
59
64
60
65
if self .is_local :
66
+ print ("Press enter to load from cache" )
61
67
command = input ("Enter wormhole command: " ).strip ()
62
- # of the form wormhole receive 89-ohio-buzzard
63
- assert (
64
- command .startswith ("wormhole receive" ) and len (command .split ()) == 3
65
- ), "Invalid command received"
66
- code = command .split ()[- 1 ]
67
- command = f"wormhole receive --accept-file { code } "
68
- subprocess .run (command .split (), cwd = self .tmpdir .name , check = True )
69
-
70
- assert os .path .exists (
71
- self .store_path
72
- ), f"store not found in { self .tmpdir .name } "
73
- self .store = joblib .load (self .store_path )
68
+ if command == "" :
69
+ # try to load from cache
70
+ assert self .local_store_cache is not None , "local_store_cache is None"
71
+ assert os .path .exists (
72
+ self .local_store_cache
73
+ ), f"cache not found in { self .local_store_cache } "
74
+
75
+ self .store = joblib .load (self .local_store_cache )
76
+ else :
77
+ # of the form wormhole receive 89-ohio-buzzard
78
+ assert (
79
+ command .startswith ("wormhole receive" ) and len (command .split ()) == 3
80
+ ), "Invalid command received"
81
+ code = command .split ()[- 1 ]
82
+ command = f"wormhole receive --accept-file { code } "
83
+ subprocess .run (command .split (), cwd = self .tmpdir .name , check = True )
84
+
85
+ assert os .path .exists (
86
+ self .store_path
87
+ ), f"store not found in { self .tmpdir .name } "
88
+
89
+ if self .local_store_cache is not None :
90
+ shutil .copy (self .store_path , self .local_store_cache )
91
+
92
+ self .store = joblib .load (self .store_path )
74
93
else :
75
94
self .store = []
76
95
0 commit comments