Commit bd2caeb

adiaybguomerdor001 authored and committed

Added feature: Eviction logic and new memory cache types.

Connected the memory caching logic to the main logic, because it was not connected up until now.

Co-authored-by: omerdor001 <omerdo@post.bgu.ac.il>
Co-authored-by: adiaybgu <adiay@post.bgu.ac.il>
1 parent 81b4715 commit bd2caeb

File tree

13 files changed: +385 -46 lines


flask4modelcache.py

Lines changed: 3 additions & 1 deletion
@@ -84,7 +84,9 @@ def response_hitquery(cache_resp):
         "ANNOY": {"metric_type": "COSINE", "params": {"search_k": 10}},
         "AUTOINDEX": {"metric_type": "COSINE", "params": {}},
     } if manager.MPNet_base else None
-    )
+    ),
+    eviction='WTINYLFU',
+    max_size=100000
 )
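The two new keyword arguments select the in-memory eviction policy ('WTINYLFU', i.e. Window-TinyLFU) and cap the RAM cache at 100000 entries. They presumably end up as the policy and maxsize arguments of the MemoryCacheEviction instance created in data_manager.py below; a minimal sketch of that forwarding, where the helper name and the clean_size default are assumptions rather than part of this commit:

from typing import Optional
from modelcache.manager.eviction.memory_cache import MemoryCacheEviction  # import added in data_manager.py below

# Hypothetical helper, only to show where eviction= and max_size= would land.
def build_memory_cache(eviction: str = 'WTINYLFU', max_size: int = 100000,
                       clean_size: Optional[int] = None, on_evict=None) -> MemoryCacheEviction:
    # The clean_size default (10% of max_size) is an assumption; the commit does not show it.
    clean_size = clean_size if clean_size is not None else max(1, max_size // 10)
    return MemoryCacheEviction(policy=eviction, maxsize=max_size,
                               clean_size=clean_size, on_evict=on_evict)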

modelcache/adapter/adapter.py

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ def create_insert(cls, *args, **kwargs):
                 **kwargs
             )
         except Exception as e:
+            print(e)
             return str(e)
 
     @classmethod

modelcache/adapter/adapter_query.py

Lines changed: 12 additions & 12 deletions
@@ -91,7 +91,7 @@ def adapt_query(cache_data_convert, *args, **kwargs):
         for cache_data in cache_data_list:
             primary_id = cache_data[1]
             ret = chat_cache.data_manager.get_scalar_data(
-                cache_data, extra_param=context.get("get_scalar_data", None)
+                cache_data, extra_param=context.get("get_scalar_data", None),model=model
             )
             if ret is None:
                 continue
@@ -124,27 +124,27 @@ def adapt_query(cache_data_convert, *args, **kwargs):
 
             if len(pre_embedding_data) <= 256:
                 if rank_threshold <= rank:
-                    cache_answers.append((rank, ret[1]))
-                    cache_questions.append((rank, ret[0]))
+                    cache_answers.append((rank, ret[0]))
+                    cache_questions.append((rank, ret[1]))
                     cache_ids.append((rank, primary_id))
             else:
                 if rank_threshold_long <= rank:
-                    cache_answers.append((rank, ret[1]))
-                    cache_questions.append((rank, ret[0]))
+                    cache_answers.append((rank, ret[0]))
+                    cache_questions.append((rank, ret[1]))
                     cache_ids.append((rank, primary_id))
     else:
         # Fall back to the original logic when no reranker is used
         for cache_data in cache_data_list:
             primary_id = cache_data[1]
             ret = chat_cache.data_manager.get_scalar_data(
-                cache_data, extra_param=context.get("get_scalar_data", None)
+                cache_data, extra_param=context.get("get_scalar_data", None),model=model
             )
             if ret is None:
                 continue
 
             if manager.MPNet_base:
-                cache_answers.append((cosine_similarity, ret[1]))
-                cache_questions.append((cosine_similarity, ret[0]))
+                cache_answers.append((cosine_similarity, ret[0]))
+                cache_questions.append((cosine_similarity, ret[1]))
                 cache_ids.append((cosine_similarity, primary_id))
             else:
                 if "deps" in context and hasattr(ret.question, "deps"):
@@ -178,13 +178,13 @@ def adapt_query(cache_data_convert, *args, **kwargs):
 
             if len(pre_embedding_data) <= 256:
                 if rank_threshold <= rank:
-                    cache_answers.append((rank, ret[1]))
-                    cache_questions.append((rank, ret[0]))
+                    cache_answers.append((rank, ret[0]))
+                    cache_questions.append((rank, ret[1]))
                     cache_ids.append((rank, primary_id))
             else:
                 if rank_threshold_long <= rank:
-                    cache_answers.append((rank, ret[1]))
-                    cache_questions.append((rank, ret[0]))
+                    cache_answers.append((rank, ret[0]))
+                    cache_questions.append((rank, ret[1]))
                     cache_ids.append((rank, primary_id))
 
     cache_answers = sorted(cache_answers, key=lambda x: x[0], reverse=True)
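The swapped indices appear to follow the row layout that get_scalar_data can now return from the in-memory cache: import_data in data_manager.py below builds each row as [ans, question, embedding_data, model] and puts that same row into the eviction cache, so on a RAM hit ret[0] is the answer and ret[1] is the question. A toy illustration with invented values:

# Row layout mirroring cache_datas in data_manager.import_data (values invented).
ret = [
    "Paris is the capital of France.",   # ret[0] -> answer
    "What is the capital of France?",    # ret[1] -> question
    None,                                # embedding data (omitted here)
    "mpnet_base",                        # model name
]

rank = 0.93
cache_answers = [(rank, ret[0])]    # (rank, answer), as in the updated code
cache_questions = [(rank, ret[1])]  # (rank, question)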

modelcache/embedding/mpnet_base.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ def __init__(self):
 
     def to_embeddings(self, *args, **kwargs):
         if not args:
-            raise ValueError("No word provided for embedding.")
+            raise ValueError("No data provided for embedding.")
         embeddings = self.model.encode(args)
         return embeddings[0] if len(args) == 1 else embeddings

modelcache/manager/data_manager.py

Lines changed: 56 additions & 13 deletions
@@ -22,6 +22,7 @@
 from modelcache.manager.object_data.base import ObjectBase
 from modelcache.manager.eviction import EvictionBase
 from modelcache.manager.eviction_manager import EvictionManager
+from modelcache.manager.eviction.memory_cache import MemoryCacheEviction
 from modelcache.utils.log import modelcache_log
 
 NORMALIZE = True
@@ -38,9 +39,7 @@ def save_query_resp(self, query_resp_dict, **kwargs):
         pass
 
     @abstractmethod
-    def import_data(
-        self, questions: List[Any], answers: List[Any], embedding_datas: List[Any], model:Any
-    ):
+    def import_data(self, questions: List[Any], answers: List[Any], embedding_datas: List[Any], model:Any):
         pass
 
     @abstractmethod
@@ -162,10 +161,18 @@ def __init__(
         self.v = v
         self.o = o
 
+        # added
+        self.eviction_base = MemoryCacheEviction(
+            policy=policy,
+            maxsize=max_size,
+            clean_size=clean_size,
+            on_evict=self._evict_ids)
+
     def save(self, questions: List[any], answers: List[any], embedding_datas: List[any], **kwargs):
         model = kwargs.pop("model", None)
         self.import_data(questions, answers, embedding_datas, model)
 
+
     def save_query_resp(self, query_resp_dict, **kwargs):
         save_query_start_time = time.time()
         self.s.insert_query_resp(query_resp_dict, **kwargs)
@@ -217,14 +224,20 @@ def import_data(
             cache_datas.append([ans, question, embedding_data, model])
 
         ids = self.s.batch_insert(cache_datas)
-        datas_ = [VectorData(id=ids[i], data=embedding_data.astype("float32")) for i, embedding_data in enumerate(embedding_datas)]
-        self.v.mul_add(
-            datas_,
-            model
-
-        )
+        datas = []
+        for i, embedding_data in enumerate(embedding_datas):
+            _id = ids[i]
+            datas.append(VectorData(id=_id, data=embedding_data.astype("float32")))
+            self.eviction_base.put([(_id, cache_datas[i])],model=model)
+        self.v.mul_add(datas,model)
 
     def get_scalar_data(self, res_data, **kwargs) -> Optional[CacheData]:
+        model = kwargs.pop("model")
+        # Get data from RAM cache
+        _id = res_data[1]
+        cache_hit = self.eviction_base.get(_id, model=model)
+        if cache_hit is not None:
+            return cache_hit
         cache_data = self.s.get_data_by_id(res_data[1])
         if cache_data is None:
             return None
@@ -244,8 +257,10 @@ def search(self, embedding_data, **kwargs):
         return self.v.search(data=embedding_data, top_k=top_k, model=model)
 
     def delete(self, id_list, **kwargs):
-        model = kwargs.pop("model", None)
+        model = kwargs.pop("model")
         try:
+            for id in id_list:
+                self.eviction_base.get_cache(model).pop(id, None)  # Remove from in-memory LRU too
             v_delete_count = self.v.delete(ids=id_list, model=model)
         except Exception as e:
             return {'status': 'failed', 'milvus': 'delete milvus data failed, please check! e: {}'.format(e),
@@ -262,23 +277,51 @@ def delete(self, id_list, **kwargs):
     def create_index(self, model, **kwargs):
         return self.v.create(model)
 
-    def truncate(self, model_name):
+    def truncate(self, model):
+        # drop memory cache data
+        self.eviction_base.clear(model)
+
         # drop vector base data
         try:
-            vector_resp = self.v.rebuild_col(model_name)
+            vector_resp = self.v.rebuild_col(model)
         except Exception as e:
             return {'status': 'failed', 'VectorDB': 'truncate VectorDB data failed, please check! e: {}'.format(e),
                     'ScalarDB': 'unexecuted'}
         if vector_resp:
             return {'status': 'failed', 'VectorDB': vector_resp, 'ScalarDB': 'unexecuted'}
         # drop scalar base data
         try:
-            delete_count = self.s.model_deleted(model_name)
+            delete_count = self.s.model_deleted(model)
         except Exception as e:
             return {'status': 'failed', 'VectorDB': 'rebuild',
                     'ScalarDB': 'truncate scalar data failed, please check! e: {}'.format(e)}
         return {'status': 'success', 'VectorDB': 'rebuild', 'ScalarDB': 'delete_count: ' + str(delete_count)}
 
+    # added
+    def _evict_ids(self, ids, **kwargs):
+        model = kwargs.get("model")
+        if not ids or any(i is None for i in ids):
+            modelcache_log.warning("Skipping eviction for invalid IDs: %s", ids)
+            return
+
+        if isinstance(ids, str):
+            ids = [ids]
+
+        for _id in ids:
+            self.eviction_base.get_cache(model).pop(_id, None)
+
+        try:
+            self.s.mark_deleted(ids)
+            modelcache_log.info("Evicted from scalar storage: %s", ids)
+        except Exception as e:
+            modelcache_log.error("Failed to delete from scalar storage: %s", str(e))
+
+        try:
+            self.v.delete(ids, model=model)
+            modelcache_log.info("Evicted from vector storage (model=%s): %s", model, ids)
+        except Exception as e:
+            modelcache_log.error("Failed to delete from vector storage (model=%s): %s", model, str(e))
+
     def flush(self):
         self.s.flush()
         self.v.flush()
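The new _evict_ids hook means an eviction from the RAM cache cascades into the scalar and vector stores. A self-contained sketch of that flow with stub backends; the stub classes and the example ids are invented, but the call sequence (drop from the per-model memory cache, then s.mark_deleted, then v.delete) mirrors _evict_ids above:

import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("evict-demo")

class StubScalarStore:
    def mark_deleted(self, ids):            # stands in for self.s.mark_deleted(ids)
        log.info("scalar: marked deleted %s", ids)

class StubVectorStore:
    def delete(self, ids, model=None):      # stands in for self.v.delete(ids, model=model)
        log.info("vector: deleted %s for model=%s", ids, model)

class EvictionDemo:
    """Only the _evict_ids-style cascade; everything else is stubbed."""
    def __init__(self):
        self.s, self.v = StubScalarStore(), StubVectorStore()
        self.ram = {"llama2": {101: ["ans", "q"], 102: ["ans2", "q2"]}}  # per-model RAM cache

    def _evict_ids(self, ids, **kwargs):
        model = kwargs.get("model")
        if isinstance(ids, str):
            ids = [ids]
        for _id in ids:
            self.ram.get(model, {}).pop(_id, None)   # drop from the in-memory cache
        self.s.mark_deleted(ids)                     # then from scalar storage
        self.v.delete(ids, model=model)              # then from vector storage

EvictionDemo()._evict_ids([101], model="llama2")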
New file (not named in this view)

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
from cachetools import Cache
from collections import OrderedDict


class ARC(Cache):
    """
    Adaptive Replacement Cache (ARC) implementation with an on_evict callback.
    Balances recency and frequency via two active lists (T1, T2) and two ghost lists (B1, B2).
    Calls on_evict([key]) whenever an item is evicted from the active cache.
    """

    def __init__(self, maxsize, getsizeof=None, on_evict=None):
        """
        Args:
            maxsize (int): Maximum cache size.
            getsizeof (callable, optional): Sizing function for items.
            on_evict (callable, optional): Callback called as on_evict([key]) when a key is evicted.
        """
        super().__init__(maxsize, getsizeof)
        self.t1 = OrderedDict()
        self.t2 = OrderedDict()
        self.b1 = OrderedDict()
        self.b2 = OrderedDict()
        self.p = 0  # Adaptive target for T1 size.
        self.on_evict = on_evict

    def __len__(self):
        return len(self.t1) + len(self.t2)

    def __contains__(self, key):
        return key in self.t1 or key in self.t2

    def _evict_internal(self):
        """
        Evicts items from T1 or T2 if the cache is over capacity, and prunes the ghost lists.
        Calls on_evict for each evicted key.
        """
        # Evict from T1 or T2 if active cache > maxsize
        while len(self.t1) + len(self.t2) > self.maxsize:
            if len(self.t1) > self.p or (len(self.t1) == 0 and len(self.t2) > 0):
                key, value = self.t1.popitem(last=False)
                self.b1[key] = value
                if self.on_evict:
                    self.on_evict([key])
            else:
                key, value = self.t2.popitem(last=False)
                self.b2[key] = value
                if self.on_evict:
                    self.on_evict([key])
        # Prune ghost lists to their max lengths
        while len(self.b1) > (self.maxsize - self.p):
            self.b1.popitem(last=False)
        while len(self.b2) > self.p:
            self.b2.popitem(last=False)

    def __setitem__(self, key, value):
        # Remove from all lists before re-inserting
        for l in (self.t1, self.t2, self.b1, self.b2):
            l.pop(key, None)
        self.t1[key] = value
        self.t1.move_to_end(key)
        self._evict_internal()

    def __getitem__(self, key):
        # Case 1: Hit in T1 → promote to T2
        if key in self.t1:
            value = self.t1.pop(key)
            self.t2[key] = value
            self.t2.move_to_end(key)
            self.p = max(0, self.p - 1)
            self._evict_internal()
            return value
        # Case 2: Hit in T2 → refresh in T2
        if key in self.t2:
            value = self.t2.pop(key)
            self.t2[key] = value
            self.t2.move_to_end(key)
            self.p = min(self.maxsize, self.p + 1)
            self._evict_internal()
            return value
        # Case 3: Hit in B1 (ghost) → fetch and promote to T2
        if key in self.b1:
            self.b1.pop(key)
            self.p = min(self.maxsize, self.p + 1)
            self._evict_internal()
            value = super().__missing__(key)
            self.t2[key] = value
            self.t2.move_to_end(key)
            return value
        # Case 4: Hit in B2 (ghost) → fetch and promote to T2
        if key in self.b2:
            self.b2.pop(key)
            self.p = max(0, self.p - 1)
            self._evict_internal()
            value = super().__missing__(key)
            self.t2[key] = value
            self.t2.move_to_end(key)
            return value
        # Case 5: Cold miss → handled by the Cache base class (calls __setitem__ after __missing__)
        return super().__getitem__(key)

    def __missing__(self, key):
        """
        Override this in a subclass, or rely on direct assignment (cache[key] = value).
        """
        raise KeyError(key)

    def pop(self, key, default=None):
        """
        Remove key from all lists.
        """
        for l in (self.t1, self.t2, self.b1, self.b2):
            if key in l:
                return l.pop(key)
        return default

    def clear(self):
        self.t1.clear()
        self.t2.clear()
        self.b1.clear()
        self.b2.clear()
        self.p = 0
        super().clear()

    def __iter__(self):
        yield from self.t1
        yield from self.t2

    def __repr__(self):
        return (f"ARC(maxsize={self.maxsize}, p={self.p}, len={len(self)}, "
                f"t1_len={len(self.t1)}, t2_len={len(self.t2)}, "
                f"b1_len={len(self.b1)}, b2_len={len(self.b2)})")

modelcache/manager/eviction/base.py

Lines changed: 2 additions & 2 deletions
@@ -9,11 +9,11 @@ class EvictionBase(metaclass=ABCMeta):
     """
 
     @abstractmethod
-    def put(self, objs: List[Any]):
+    def put(self, objs: List[Any], model:str):
        pass
 
     @abstractmethod
-    def get(self, obj: Any):
+    def get(self, obj: Any, model:str):
        pass
 
     @property
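With the model argument added, every eviction backend now scopes its operations to a model, matching how data_manager.py calls put and get. An illustrative, minimal class matching the new signatures (a plain dict per model, no real eviction policy, and not the project's implementation); any further abstract members of EvictionBase, such as the property cut off at the end of this hunk, are out of scope here:

from typing import Any, Dict, List

class PerModelDictEviction:
    """Illustrative only: matches the new put/get signatures with a plain dict per model."""

    def __init__(self) -> None:
        self._caches: Dict[str, Dict[Any, Any]] = {}

    def put(self, objs: List[Any], model: str) -> None:
        cache = self._caches.setdefault(model, {})
        for key, value in objs:        # objs are (id, row) pairs, as in data_manager.import_data
            cache[key] = value

    def get(self, obj: Any, model: str) -> Any:
        return self._caches.get(model, {}).get(obj)   # None on a miss, which is what get_scalar_data checks for

eviction_base = PerModelDictEviction()
eviction_base.put([(42, ["an answer", "a question"])], model="demo-model")
print(eviction_base.get(42, model="demo-model"))    # ['an answer', 'a question']
print(eviction_base.get(42, model="other-model"))   # None: caches are scoped per model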
