Skip to content

Commit b2b1a09

Browse files
committed
add caching on local side to prevent extraneous computations
1 parent 39bab00 commit b2b1a09

File tree

2 files changed

+34
-15
lines changed

2 files changed

+34
-15
lines changed

magicpickle/magicpickle.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import os
2-
import argparse
32
import shutil
43
import socket
54
import joblib
65
import tempfile
76
import subprocess
87

9-
from typing import Union, Callable
8+
from typing import Union, Callable, Optional
109

1110

1211
class MagicPickle:
@@ -27,6 +26,9 @@ def __init__(
2726
local_hostname_or_func: Union[str, Callable[[], bool]],
2827
verbose: bool = True,
2928
compress: Union[bool, int] = True,
29+
local_store_cache: Optional[str] = os.path.join(
30+
tempfile.gettempdir(), "magicpickle_cache"
31+
),
3032
):
3133
"""
3234
Parameters
@@ -36,6 +38,8 @@ def __init__(
3638
verbose
3739
compress
3840
compression level for joblib.dump
41+
local_store_cache
42+
persistent cache file used, triggered when empty string passed into prompt on local
3943
"""
4044
if callable(local_hostname_or_func):
4145
self.is_local = local_hostname_or_func()
@@ -45,6 +49,7 @@ def __init__(
4549

4650
self.verbose = verbose
4751
self.compress = compress
52+
self.local_store_cache = local_store_cache
4853

4954
if self.verbose:
5055
print(f"MagicPickle is_local: {self.is_local}")
@@ -58,19 +63,33 @@ def __enter__(self):
5863
print(f"MagicPickle tmpdir: {self.tmpdir.name}")
5964

6065
if self.is_local:
66+
print("Press enter to load from cache")
6167
command = input("Enter wormhole command: ").strip()
62-
# of the form wormhole receive 89-ohio-buzzard
63-
assert (
64-
command.startswith("wormhole receive") and len(command.split()) == 3
65-
), "Invalid command received"
66-
code = command.split()[-1]
67-
command = f"wormhole receive --accept-file {code}"
68-
subprocess.run(command.split(), cwd=self.tmpdir.name, check=True)
69-
70-
assert os.path.exists(
71-
self.store_path
72-
), f"store not found in {self.tmpdir.name}"
73-
self.store = joblib.load(self.store_path)
68+
if command == "":
69+
# try to load from cache
70+
assert self.local_store_cache is not None, "local_store_cache is None"
71+
assert os.path.exists(
72+
self.local_store_cache
73+
), f"cache not found in {self.local_store_cache}"
74+
75+
self.store = joblib.load(self.local_store_cache)
76+
else:
77+
# of the form wormhole receive 89-ohio-buzzard
78+
assert (
79+
command.startswith("wormhole receive") and len(command.split()) == 3
80+
), "Invalid command received"
81+
code = command.split()[-1]
82+
command = f"wormhole receive --accept-file {code}"
83+
subprocess.run(command.split(), cwd=self.tmpdir.name, check=True)
84+
85+
assert os.path.exists(
86+
self.store_path
87+
), f"store not found in {self.tmpdir.name}"
88+
89+
if self.local_store_cache is not None:
90+
shutil.copy(self.store_path, self.local_store_cache)
91+
92+
self.store = joblib.load(self.store_path)
7493
else:
7594
self.store = []
7695

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "magicpickle"
3-
version = "0.0.4"
3+
version = "0.0.5"
44
description = "A wrapper around magic-wormhole and joblib to send pickled objects across the internet."
55
readme = "README.md"
66
license = {text = "MIT License"}

0 commit comments

Comments
 (0)