Skip to content

Commit 81b4715

Browse files
Added custom pre embedding func and some other improvements
Co-authored-by: olgaoznovich <ol.oznovich@gmail.com>
Co-authored-by: Yuval-Roth <rothyuv@post.bgu.ac.il>
1 parent b474d15 commit 81b4715

File tree

3 files changed

+15
-7
lines changed

3 files changed

+15
-7
lines changed

flask4modelcache.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -10,8 +10,7 @@
1010
from modelcache.manager.vector_data import manager
1111
from modelcache.manager import CacheBase, VectorBase, get_data_manager, data_manager
1212
from modelcache.similarity_evaluation.distance import SearchDistanceEvaluation
13-
from modelcache.processor.pre import query_multi_splicing
14-
from modelcache.processor.pre import insert_multi_splicing
13+
from modelcache.processor.pre import query_multi_splicing,insert_multi_splicing, query_with_role
1514
from concurrent.futures import ThreadPoolExecutor
1615
from modelcache.utils.model_filter import model_blacklist_filter
1716
from modelcache.embedding import Data2VecAudio
@@ -36,13 +35,17 @@ def response_hitquery(cache_resp):
3635

3736
if manager.MPNet_base:
3837
mpnet_base = MPNet_Base()
39-
embedding_func = lambda x: mpnet_base.embedding_func(x)
38+
embedding_func = mpnet_base.to_embeddings
4039
dimension = mpnet_base.dimension
4140
data_manager.NORMALIZE = False
41+
query_pre_embedding_func=query_with_role
42+
insert_pre_embedding_func=query_with_role
4243
else:
4344
data2vec = Data2VecAudio()
4445
embedding_func = data2vec.to_embeddings
4546
dimension = data2vec.dimension
47+
query_pre_embedding_func=query_multi_splicing
48+
insert_pre_embedding_func=insert_multi_splicing
4649

4750
mysql_config = configparser.ConfigParser()
4851
mysql_config.read('modelcache/config/mysql_config.ini')
@@ -95,8 +98,8 @@ def response_hitquery(cache_resp):
9598
embedding_func=embedding_func,
9699
data_manager=data_manager,
97100
similarity_evaluation=SearchDistanceEvaluation(),
98-
query_pre_embedding_func=query_multi_splicing,
99-
insert_pre_embedding_func=insert_multi_splicing,
101+
query_pre_embedding_func=query_pre_embedding_func,
102+
insert_pre_embedding_func=insert_pre_embedding_func,
100103
)
101104

102105
global executor

modelcache/embedding/mpnet_base.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@ def __init__(self):
55
self.dimension = 768
66
self.model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
77

8-
def embedding_func(self, *args, **kwargs):
8+
def to_embeddings(self, *args, **kwargs):
99
if not args:
1010
raise ValueError("No word provided for embedding.")
1111
embeddings = self.model.encode(args)

modelcache/processor/pre.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44

55

66
def insert_last_content(data: Dict[str, Any], **_: Dict[str, Any]) -> Any:
7-
return data.get("chat_info")[-1]["query"]
7+
return data.get("query")[-1]["content"]
88

99

1010
def query_last_content(data: Dict[str, Any], **_: Dict[str, Any]) -> Any:
@@ -67,6 +67,11 @@ def insert_multi_splicing(data: Dict[str, Any], **_: Dict[str, Any]) -> Any:
6767
insert_query_list = data['query']
6868
return multi_splicing(insert_query_list)
6969

70+
def query_with_role(data: Dict[str, Any], **_: Dict[str, Any]) -> Any:
71+
query = data["query"][-1]
72+
content = query["content"]
73+
role = query["role"]
74+
return role+": "+content
7075

7176
def multi_splicing(data_list) -> Any:
7277
result_str = ""

0 commit comments

Comments
 (0)