from modelcache.utils.error import NotInitError
from modelcache.utils.time import time_cal
from modelcache.processor.pre import multi_analysis
+from FlagEmbedding import FlagReranker
+
+USE_RERANKER = True  # If True, enable the reranker; otherwise use the original logic

def adapt_query(cache_data_convert, *args, **kwargs):
    chat_cache = kwargs.pop("cache_obj", cache)
@@ -74,53 +76,102 @@ def adapt_query(cache_data_convert, *args, **kwargs):
    if rank_pre < rank_threshold:
        return

-    for cache_data in cache_data_list:
-        primary_id = cache_data[1]
-        ret = chat_cache.data_manager.get_scalar_data(
-            cache_data, extra_param=context.get("get_scalar_data", None)
-        )
-        if ret is None:
-            continue
+    if USE_RERANKER:
+        reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=False)
+        for cache_data in cache_data_list:
+            primary_id = cache_data[1]
+            ret = chat_cache.data_manager.get_scalar_data(
+                cache_data, extra_param=context.get("get_scalar_data", None)
+            )
+            if ret is None:
+                continue

-        if "deps" in context and hasattr(ret.question, "deps"):
-            eval_query_data = {
-                "question": context["deps"][0]["data"],
-                "embedding": None
-            }
-            eval_cache_data = {
-                "question": ret.question.deps[0].data,
-                "answer": ret.answers[0].answer,
-                "search_result": cache_data,
-                "embedding": None,
-            }
-        else:
-            eval_query_data = {
-                "question": pre_embedding_data,
-                "embedding": embedding_data,
-            }
+            rank = reranker.compute_score([pre_embedding_data, ret[0]], normalize=True)

-            eval_cache_data = {
-                "question": ret[0],
-                "answer": ret[1],
-                "search_result": cache_data,
-                "embedding": None
-            }
-        rank = chat_cache.similarity_evaluation.evaluation(
-            eval_query_data,
-            eval_cache_data,
-            extra_param=context.get("evaluation_func", None),
-        )
+            if "deps" in context and hasattr(ret.question, "deps"):
+                eval_query_data = {
+                    "question": context["deps"][0]["data"],
+                    "embedding": None
+                }
+                eval_cache_data = {
+                    "question": ret.question.deps[0].data,
+                    "answer": ret.answers[0].answer,
+                    "search_result": cache_data,
+                    "embedding": None,
+                }
+            else:
+                eval_query_data = {
+                    "question": pre_embedding_data,
+                    "embedding": embedding_data,
+                }
+
+                eval_cache_data = {
+                    "question": ret[0],
+                    "answer": ret[1],
+                    "search_result": cache_data,
+                    "embedding": None
+                }
+
+            if len(pre_embedding_data) <= 256:
+                if rank_threshold <= rank:
+                    cache_answers.append((rank, ret[1]))
+                    cache_questions.append((rank, ret[0]))
+                    cache_ids.append((rank, primary_id))
+            else:
+                if rank_threshold_long <= rank:
+                    cache_answers.append((rank, ret[1]))
+                    cache_questions.append((rank, ret[0]))
+                    cache_ids.append((rank, primary_id))
+    else:
+        # When the reranker is disabled, fall back to the original logic
+        for cache_data in cache_data_list:
+            primary_id = cache_data[1]
+            ret = chat_cache.data_manager.get_scalar_data(
+                cache_data, extra_param=context.get("get_scalar_data", None)
+            )
+            if ret is None:
+                continue
+
+            if "deps" in context and hasattr(ret.question, "deps"):
+                eval_query_data = {
+                    "question": context["deps"][0]["data"],
+                    "embedding": None
+                }
+                eval_cache_data = {
+                    "question": ret.question.deps[0].data,
+                    "answer": ret.answers[0].answer,
+                    "search_result": cache_data,
+                    "embedding": None,
+                }
+            else:
+                eval_query_data = {
+                    "question": pre_embedding_data,
+                    "embedding": embedding_data,
+                }
+
+                eval_cache_data = {
+                    "question": ret[0],
+                    "answer": ret[1],
+                    "search_result": cache_data,
+                    "embedding": None
+                }
+            rank = chat_cache.similarity_evaluation.evaluation(
+                eval_query_data,
+                eval_cache_data,
+                extra_param=context.get("evaluation_func", None),
+            )
+
+            if len(pre_embedding_data) <= 256:
+                if rank_threshold <= rank:
+                    cache_answers.append((rank, ret[1]))
+                    cache_questions.append((rank, ret[0]))
+                    cache_ids.append((rank, primary_id))
+            else:
+                if rank_threshold_long <= rank:
+                    cache_answers.append((rank, ret[1]))
+                    cache_questions.append((rank, ret[0]))
+                    cache_ids.append((rank, primary_id))

-        if len(pre_embedding_data) <= 256:
-            if rank_threshold <= rank:
-                cache_answers.append((rank, ret[1]))
-                cache_questions.append((rank, ret[0]))
-                cache_ids.append((rank, primary_id))
-        else:
-            if rank_threshold_long <= rank:
-                cache_answers.append((rank, ret[1]))
-                cache_questions.append((rank, ret[0]))
-                cache_ids.append((rank, primary_id))

    cache_answers = sorted(cache_answers, key=lambda x: x[0], reverse=True)
    cache_questions = sorted(cache_questions, key=lambda x: x[0], reverse=True)
    cache_ids = sorted(cache_ids, key=lambda x: x[0], reverse=True)
@@ -141,4 +192,4 @@ def adapt_query(cache_data_convert, *args, **kwargs):
            logging.info('update_hit_count except, please check!')

    chat_cache.report.hint_cache()
-    return cache_data_convert(return_message, return_query)
+    return cache_data_convert(return_message, return_query)
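
For reference, a minimal standalone sketch of the reranker call this diff introduces. The query strings below are hypothetical stand-ins for pre_embedding_data and ret[0], and it assumes the FlagEmbedding package is installed (pip install FlagEmbedding). With normalize=True, compute_score passes the cross-encoder's raw relevance logit through a sigmoid, yielding a score in (0, 1) that the diff then compares against rank_threshold for short queries (len(pre_embedding_data) <= 256) or rank_threshold_long for longer ones:

from FlagEmbedding import FlagReranker

# Same cross-encoder model the diff loads; use_fp16=True would trade a
# little accuracy for speed on GPU.
reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=False)

# Hypothetical stand-ins for the incoming query and a cached question.
query = "How do I reset my password?"
cached_question = "How can I change my password?"

# normalize=True maps the raw logit through a sigmoid into (0, 1).
score = reranker.compute_score([query, cached_question], normalize=True)
print(score)  # near-duplicate questions score close to 1.0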