Skip to content

Commit 15da47d

Browse files
Yuval-Roth, olgaoznovich, omerdor001, adiaybgu
committed
Optimizations and fixes
* Updates to EmbeddingDispatcher: Added catch for exceptions in worker, set the processes to run in high priority * Offloaded some CPU-intensive and blocking code in adapt_insert and adapt_query to a background thread instead of having it run on the main asyncio event-loop * Fixed not inserting into memory cache after memory cache miss. * Fixes in WTINYLFU memory cache class. * Replaced hardcoded similarity threshold in cosine similarity with dynamic value Co-authored-by: olgaoznovich <ol.oznovich@gmail.com> Co-authored-by: Yuval-Roth <rothyuv@post.bgu.ac.il> Co-authored-by: omerdor001 <omerdo@post.bgu.ac.il> Co-authored-by: adiaybgu <adiay@post.bgu.ac.il>
1 parent d8afc32 commit 15da47d

File tree

13 files changed

+116
-76
lines changed

13 files changed

+116
-76
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ celerybeat.pid
9393

9494
# Environments
9595
.env
96-
.venv
96+
.venv*
9797
env/
9898
venv/
9999
ENV/

modelcache/adapter/adapter.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ async def create_insert(cls, *args, **kwargs):
3535
return str(e)
3636

3737
@classmethod
38-
def create_remove(cls, *args, **kwargs):
38+
async def create_remove(cls, *args, **kwargs):
3939
try:
40-
return adapt_remove(
40+
return await adapt_remove(
4141
*args,
4242
**kwargs
4343
)
@@ -46,9 +46,9 @@ def create_remove(cls, *args, **kwargs):
4646
return str(e)
4747

4848
@classmethod
49-
def create_register(cls, *args, **kwargs):
49+
async def create_register(cls, *args, **kwargs):
5050
try:
51-
return adapt_register(
51+
return await adapt_register(
5252
*args,
5353
**kwargs
5454
)

modelcache/adapter/adapter_insert.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ async def adapt_insert(*args, **kwargs):
1616

1717
pre_embedding_data_list = []
1818
embedding_futures_list = []
19-
# embedding_data_list = []
2019
llm_data_list = []
2120

2221
for row in chat_info:
@@ -37,7 +36,8 @@ async def adapt_insert(*args, **kwargs):
3736

3837
embedding_data_list = await asyncio.gather(*embedding_futures_list)
3938

40-
chat_cache.data_manager.save(
39+
await asyncio.to_thread(
40+
chat_cache.data_manager.save,
4141
pre_embedding_data_list,
4242
llm_data_list,
4343
embedding_data_list,

modelcache/adapter/adapter_query.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# -*- coding: utf-8 -*-
2+
import asyncio
23
import logging
34
from modelcache.embedding import MetricType
45
from modelcache.utils.time import time_cal
@@ -24,17 +25,20 @@ async def adapt_query(cache_data_convert, *args, **kwargs):
2425
cache_obj=chat_cache
2526
)(pre_embedding_data)
2627

27-
cache_data_list = time_cal(
28+
search_time_cal = time_cal(
2829
chat_cache.data_manager.search,
2930
func_name="vector_search",
3031
report_func=chat_cache.report.search,
3132
cache_obj=chat_cache
32-
)(
33+
)
34+
cache_data_list = await asyncio.to_thread(
35+
search_time_cal,
3336
embedding_data,
3437
extra_param=context.get("search_func", None),
3538
top_k=kwargs.pop("top_k", -1),
3639
model=model
3740
)
41+
3842
cache_answers = []
3943
cache_questions = []
4044
cache_ids = []
@@ -43,7 +47,7 @@ async def adapt_query(cache_data_convert, *args, **kwargs):
4347
if chat_cache.similarity_metric_type == MetricType.COSINE:
4448
cosine_similarity = cache_data_list[0][0]
4549
# This code uses the built-in cosine similarity evaluation in milvus
46-
if cosine_similarity < 0.9:
50+
if cosine_similarity < chat_cache.similarity_threshold:
4751
return None
4852
elif chat_cache.similarity_metric_type == MetricType.L2:
4953
## this is the code that uses L2 for similarity evaluation
@@ -87,8 +91,9 @@ async def adapt_query(cache_data_convert, *args, **kwargs):
8791
reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=False)
8892
for cache_data in cache_data_list:
8993
primary_id = cache_data[1]
90-
ret = chat_cache.data_manager.get_scalar_data(
91-
cache_data, extra_param=context.get("get_scalar_data", None),model=model
94+
ret = await asyncio.to_thread(
95+
chat_cache.data_manager.get_scalar_data,
96+
cache_data, extra_param=context.get("get_scalar_data", None), model=model
9297
)
9398
if ret is None:
9499
continue
@@ -133,8 +138,9 @@ async def adapt_query(cache_data_convert, *args, **kwargs):
133138
# 不使用 reranker 时,走原来的逻辑
134139
for cache_data in cache_data_list:
135140
primary_id = cache_data[1]
136-
ret = chat_cache.data_manager.get_scalar_data(
137-
cache_data, extra_param=context.get("get_scalar_data", None),model=model
141+
ret = await asyncio.to_thread(
142+
chat_cache.data_manager.get_scalar_data,
143+
cache_data, extra_param=context.get("get_scalar_data", None), model=model
138144
)
139145
if ret is None:
140146
continue
@@ -204,7 +210,7 @@ async def adapt_query(cache_data_convert, *args, **kwargs):
204210
)
205211
# 更新命中次数
206212
try:
207-
chat_cache.data_manager.update_hit_count(return_id)
213+
asyncio.create_task(asyncio.to_thread(chat_cache.data_manager.update_hit_count,return_id))
208214
except Exception:
209215
logging.info('update_hit_count except, please check!')
210216

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
# -*- coding: utf-8 -*-
2+
import asyncio
23

34

4-
def adapt_register(*args, **kwargs):
5+
async def adapt_register(*args, **kwargs):
56
chat_cache = kwargs.pop("cache_obj")
67
model = kwargs.pop("model", None)
78
if model is None or len(model) == 0:
89
return ValueError('')
910

10-
register_resp = chat_cache.data_manager.create_index(model)
11+
register_resp = await asyncio.to_thread(
12+
chat_cache.data_manager.create_index,
13+
model
14+
)
15+
1116
return register_resp

modelcache/adapter/adapter_remove.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
# -*- coding: utf-8 -*-
2-
from modelcache.utils.error import NotInitError, RemoveError
2+
import asyncio
33

4+
from modelcache.utils.error import RemoveError
45

5-
def adapt_remove(*args, **kwargs):
6+
7+
async def adapt_remove(*args, **kwargs):
68
chat_cache = kwargs.pop("cache_obj")
79
model = kwargs.pop("model", None)
810
remove_type = kwargs.pop("remove_type", None)
@@ -13,9 +15,15 @@ def adapt_remove(*args, **kwargs):
1315
# delete data
1416
if remove_type == 'delete_by_id':
1517
id_list = kwargs.pop("id_list", [])
16-
resp = chat_cache.data_manager.delete(id_list, model=model)
18+
resp = await asyncio.to_thread(
19+
chat_cache.data_manager.delete,
20+
id_list, model=model
21+
)
1722
elif remove_type == 'truncate_by_model':
18-
resp = chat_cache.data_manager.truncate(model)
23+
resp = await asyncio.to_thread(
24+
chat_cache.data_manager.truncate,
25+
model
26+
)
1927
else:
2028
# resp = "remove_type_error"
2129
raise RemoveError()

modelcache/cache.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,6 @@
2727
#==================== Cache class definition =========================#
2828
#=====================================================================#
2929

30-
executor = ThreadPoolExecutor(max_workers=2)
31-
32-
def response_text(cache_resp):
33-
return cache_resp['data']
34-
35-
def response_hitquery(cache_resp):
36-
return cache_resp['hitQuery']
3730

3831
# noinspection PyMethodMayBeStatic
3932
class Cache:
@@ -80,11 +73,16 @@ def close():
8073
modelcache_log.error(e)
8174

8275
def save_query_resp(self, query_resp_dict, **kwargs):
83-
self.data_manager.save_query_resp(query_resp_dict, **kwargs)
76+
asyncio.create_task(asyncio.to_thread(
77+
self.data_manager.save_query_resp,
78+
query_resp_dict, **kwargs
79+
))
8480

8581
def save_query_info(self,result, model, query, delta_time_log):
86-
self.data_manager.save_query_resp(result, model=model, query=json.dumps(query, ensure_ascii=False),
87-
delta_time=delta_time_log)
82+
asyncio.create_task(asyncio.to_thread(
83+
self.data_manager.save_query_resp,
84+
result, model=model, query=json.dumps(query, ensure_ascii=False), delta_time=delta_time_log
85+
))
8886

8987
async def handle_request(self, param_dict: dict):
9088
# param parsing
@@ -103,7 +101,7 @@ async def handle_request(self, param_dict: dict):
103101
result = {"errorCode": 102,
104102
"errorDesc": "type exception, should one of ['query', 'insert', 'remove', 'register']",
105103
"cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''}
106-
self.data_manager.save_query_resp(result, model=model, query='', delta_time=0)
104+
self.save_query_resp(result, model=model, query='', delta_time=0)
107105
return result
108106
except Exception as e:
109107
return {"errorCode": 103, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '',
@@ -120,14 +118,14 @@ async def handle_request(self, param_dict: dict):
120118
elif request_type == 'insert':
121119
return await self.handle_insert(chat_info, model)
122120
elif request_type == 'remove':
123-
return self.handle_remove(model, param_dict)
121+
return await self.handle_remove(model, param_dict)
124122
elif request_type == 'register':
125-
return self.handle_register(model)
123+
return await self.handle_register(model)
126124
else:
127125
return {"errorCode": 400, "errorDesc": "bad request"}
128126

129-
def handle_register(self, model):
130-
response = adapter.ChatCompletion.create_register(
127+
async def handle_register(self, model):
128+
response = await adapter.ChatCompletion.create_register(
131129
model=model,
132130
cache_obj=self
133131
)
@@ -137,10 +135,10 @@ def handle_register(self, model):
137135
result = {"errorCode": 502, "errorDesc": "", "response": response, "writeStatus": "exception"}
138136
return result
139137

140-
def handle_remove(self, model, param_dict):
138+
async def handle_remove(self, model, param_dict):
141139
remove_type = param_dict.get("remove_type")
142140
id_list = param_dict.get("id_list", [])
143-
response = adapter.ChatCompletion.create_remove(
141+
response = await adapter.ChatCompletion.create_remove(
144142
model=model,
145143
remove_type=remove_type,
146144
id_list=id_list,
@@ -191,12 +189,12 @@ async def handle_query(self, model, query):
191189
result = {"errorCode": 201, "errorDesc": response, "cacheHit": False, "delta_time": delta_time,
192190
"hit_query": '', "answer": ''}
193191
else:
194-
answer = response_text(response)
195-
hit_query = response_hitquery(response)
192+
answer = response['data']
193+
hit_query = response['hitQuery']
196194
result = {"errorCode": 0, "errorDesc": '', "cacheHit": True, "delta_time": delta_time,
197195
"hit_query": hit_query, "answer": answer}
198196
delta_time_log = round(time.time() - start_time, 2)
199-
executor.submit(self.save_query_info, result, model, query, delta_time_log)
197+
self.save_query_info(result, model, query, delta_time_log)
200198
except Exception as e:
201199
result = {"errorCode": 202, "errorDesc": str(e), "cacheHit": False, "delta_time": 0,
202200
"hit_query": '', "answer": ''}
@@ -265,7 +263,9 @@ async def init(
265263
#==================================================#
266264

267265
# switching based on embedding_model
268-
if embedding_model == EmbeddingModel.HUGGINGFACE_ALL_MPNET_BASE_V2:
266+
if (embedding_model == EmbeddingModel.HUGGINGFACE_ALL_MPNET_BASE_V2
267+
or embedding_model == EmbeddingModel.HUGGINGFACE_ALL_MINILM_L6_V2
268+
or embedding_model == EmbeddingModel.HUGGINGFACE_ALL_MINILM_L12_V2):
269269
query_pre_embedding_func = query_with_role
270270
insert_pre_embedding_func = query_with_role
271271
post_process_messages_func = first
@@ -287,8 +287,8 @@ async def init(
287287

288288
# add more configurations for other embedding models as needed
289289
else:
290-
modelcache_log.error(f"Please add configuration for {embedding_model} in modelcache/__init__.py.")
291-
raise CacheError(f"Please add configuration for {embedding_model} in modelcache/__init__.py.")
290+
modelcache_log.error(f"Please add configuration for {embedding_model} in modelcache/cache.py.")
291+
raise CacheError(f"Please add configuration for {embedding_model} in modelcache/cache.py.")
292292

293293
# ====================== Data manager ==============================#
294294

@@ -300,7 +300,7 @@ async def init(
300300
config=vector_config,
301301
metric_type=similarity_metric_type,
302302
),
303-
eviction='ARC',
303+
memory_cache_policy='ARC',
304304
max_size=10000,
305305
normalize=normalize,
306306
)

modelcache/embedding/base.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
# -*- coding: utf-8 -*-
22
from abc import abstractmethod, ABCMeta
33

4+
from modelcache.utils.error import CacheError
45
from modelcache.utils.lazy_import import LazyImport
56
from enum import Enum
7+
8+
from modelcache.utils.log import modelcache_log
9+
610
huggingface = LazyImport("huggingface", globals(), "modelcache.embedding.huggingface")
711
data2vec = LazyImport("data2vec", globals(), "modelcache.embedding.data2vec")
812
llmEmb = LazyImport("llmEmb", globals(), "modelcache.embedding.llmEmb")
@@ -21,7 +25,7 @@ class EmbeddingModel(Enum):
2125
HUGGINGFACE_ALL_MPNET_BASE_V2 = {"dimension":768, "model_path":"sentence-transformers/all-mpnet-base-v2"}
2226
HUGGINGFACE_ALL_MINILM_L6_V2 = {"dimension":384, "model_path":"sentence-transformers/all-MiniLM-L6-v2"}
2327
HUGGINGFACE_ALL_MINILM_L12_V2 = {"dimension":384, "model_path":"sentence-transformers/all-MiniLM-L12-v2"}
24-
DATA2VEC_AUDIO = {"dimension":None, "model_path":"model/text2vec-base-chinese/"}
28+
DATA2VEC_AUDIO = {"dimension":768, "model_path":"model/text2vec-base-chinese/"}
2529
LLM_EMB2VEC_AUDIO = {"dimension":None, "model_path":None}
2630
FASTTEXT = {"dimension":None, "model_path":None}
2731
PADDLE_NLP = {"dimension":None, "model_path":None}
@@ -68,6 +72,14 @@ def get(model:EmbeddingModel, **kwargs):
6872
model_path = kwargs.pop("model_path","sentence-transformers/all-mpnet-base-v2")
6973
return huggingface.Huggingface(model_path)
7074

75+
elif model == EmbeddingModel.HUGGINGFACE_ALL_MINILM_L6_V2:
76+
model_path = kwargs.pop("model_path","sentence-transformers/all-MiniLM-L6-v2")
77+
return huggingface.Huggingface(model_path)
78+
79+
elif model == EmbeddingModel.HUGGINGFACE_ALL_MINILM_L12_V2:
80+
model_path = kwargs.pop("model_path","sentence-transformers/all-MiniLM-L12-v2")
81+
return huggingface.Huggingface(model_path)
82+
7183
elif model == EmbeddingModel.DATA2VEC_AUDIO:
7284
model_path = kwargs.pop("model_path","model/text2vec-base-chinese/")
7385
return data2vec.Data2VecAudio(model_path)
@@ -99,5 +111,7 @@ def get(model:EmbeddingModel, **kwargs):
99111
return bge_m3.BgeM3Embedding(model_path)
100112

101113
else:
102-
raise ValueError(f"Unsupported embedding model: {model}")
114+
modelcache_log.error(f"Please add configuration for {model} in modelcache/embedding/base.py.")
115+
raise CacheError(f"Please add configuration for {model} in modelcache/embedding/base.py.")
116+
103117

modelcache/embedding/embedding_dispatcher.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import threading
33
import uuid
44
import asyncio
5+
import psutil
56
from asyncio import Future, AbstractEventLoop
67

78
from modelcache.embedding import EmbeddingModel
@@ -11,13 +12,18 @@
1112
def worker_func(embedding_model: EmbeddingModel, model_path, task_queue, result_queue, worker_id):
1213
base_embedding = BaseEmbedding.get(embedding_model, model_path=model_path)
1314
print(f"Embedding worker {worker_id} started.")
14-
while True:
15-
job_id, data = task_queue.get()
16-
try:
17-
result = base_embedding.to_embeddings(data)
18-
except Exception as e:
19-
result = e
20-
result_queue.put((job_id, result))
15+
try:
16+
while True:
17+
job_id, data = task_queue.get()
18+
try:
19+
result = base_embedding.to_embeddings(data)
20+
except Exception as e:
21+
result = e
22+
result_queue.put((job_id, result))
23+
except KeyboardInterrupt:
24+
print(f"Embedding worker {worker_id} stopped.")
25+
except Exception as e:
26+
print(f"Embedding worker {worker_id} encountered an error: {e}")
2127

2228

2329
class EmbeddingDispatcher:
@@ -46,6 +52,7 @@ def __init__(
4652
)
4753
p.daemon = True
4854
p.start()
55+
psutil.Process(p.pid).nice(psutil.HIGH_PRIORITY_CLASS)
4956
self.workers.append(p)
5057

5158
def _start_result_collector_thread(self):

0 commit comments

Comments (0)