10
10
from modelcache .manager .vector_data import manager
11
11
from modelcache .manager import CacheBase , VectorBase , get_data_manager , data_manager
12
12
from modelcache .similarity_evaluation .distance import SearchDistanceEvaluation
13
- from modelcache .processor .pre import query_multi_splicing
14
- from modelcache .processor .pre import insert_multi_splicing
13
+ from modelcache .processor .pre import query_multi_splicing ,insert_multi_splicing , query_with_role
15
14
from concurrent .futures import ThreadPoolExecutor
16
15
from modelcache .utils .model_filter import model_blacklist_filter
17
16
from modelcache .embedding import Data2VecAudio
@@ -36,13 +35,17 @@ def response_hitquery(cache_resp):
36
35
37
36
if manager .MPNet_base :
38
37
mpnet_base = MPNet_Base ()
39
- embedding_func = lambda x : mpnet_base .embedding_func ( x )
38
+ embedding_func = mpnet_base .to_embeddings
40
39
dimension = mpnet_base .dimension
41
40
data_manager .NORMALIZE = False
41
+ query_pre_embedding_func = query_with_role
42
+ insert_pre_embedding_func = query_with_role
42
43
else :
43
44
data2vec = Data2VecAudio ()
44
45
embedding_func = data2vec .to_embeddings
45
46
dimension = data2vec .dimension
47
+ query_pre_embedding_func = query_multi_splicing
48
+ insert_pre_embedding_func = insert_multi_splicing
46
49
47
50
mysql_config = configparser .ConfigParser ()
48
51
mysql_config .read ('modelcache/config/mysql_config.ini' )
@@ -95,8 +98,8 @@ def response_hitquery(cache_resp):
95
98
embedding_func = embedding_func ,
96
99
data_manager = data_manager ,
97
100
similarity_evaluation = SearchDistanceEvaluation (),
98
- query_pre_embedding_func = query_multi_splicing ,
99
- insert_pre_embedding_func = insert_multi_splicing ,
101
+ query_pre_embedding_func = query_pre_embedding_func ,
102
+ insert_pre_embedding_func = insert_pre_embedding_func ,
100
103
)
101
104
102
105
global executor
0 commit comments