Commit 5cb5061

Merge pull request #63 from charleschile/reranker
bge-reranker-v2-m3 Reranker
2 parents 3dad91c + 992f3b9 commit 5cb5061

File tree

1 file changed (+96 -45 lines)

modelcache/adapter/adapter_query.py

Lines changed: 96 additions & 45 deletions
@@ -5,7 +5,9 @@
 from modelcache.utils.error import NotInitError
 from modelcache.utils.time import time_cal
 from modelcache.processor.pre import multi_analysis
+from FlagEmbedding import FlagReranker
 
+USE_RERANKER = True  # If True, enable the reranker; otherwise use the original logic
 
 def adapt_query(cache_data_convert, *args, **kwargs):
     chat_cache = kwargs.pop("cache_obj", cache)
@@ -74,53 +76,102 @@ def adapt_query(cache_data_convert, *args, **kwargs):
     if rank_pre < rank_threshold:
         return
 
-    for cache_data in cache_data_list:
-        primary_id = cache_data[1]
-        ret = chat_cache.data_manager.get_scalar_data(
-            cache_data, extra_param=context.get("get_scalar_data", None)
-        )
-        if ret is None:
-            continue
+    if USE_RERANKER:
+        reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=False)
+        for cache_data in cache_data_list:
+            primary_id = cache_data[1]
+            ret = chat_cache.data_manager.get_scalar_data(
+                cache_data, extra_param=context.get("get_scalar_data", None)
+            )
+            if ret is None:
+                continue
 
-        if "deps" in context and hasattr(ret.question, "deps"):
-            eval_query_data = {
-                "question": context["deps"][0]["data"],
-                "embedding": None
-            }
-            eval_cache_data = {
-                "question": ret.question.deps[0].data,
-                "answer": ret.answers[0].answer,
-                "search_result": cache_data,
-                "embedding": None,
-            }
-        else:
-            eval_query_data = {
-                "question": pre_embedding_data,
-                "embedding": embedding_data,
-            }
+            rank = reranker.compute_score([pre_embedding_data, ret[0]], normalize=True)
 
-            eval_cache_data = {
-                "question": ret[0],
-                "answer": ret[1],
-                "search_result": cache_data,
-                "embedding": None
-            }
-        rank = chat_cache.similarity_evaluation.evaluation(
-            eval_query_data,
-            eval_cache_data,
-            extra_param=context.get("evaluation_func", None),
-        )
+            if "deps" in context and hasattr(ret.question, "deps"):
+                eval_query_data = {
+                    "question": context["deps"][0]["data"],
+                    "embedding": None
+                }
+                eval_cache_data = {
+                    "question": ret.question.deps[0].data,
+                    "answer": ret.answers[0].answer,
+                    "search_result": cache_data,
+                    "embedding": None,
+                }
+            else:
+                eval_query_data = {
+                    "question": pre_embedding_data,
+                    "embedding": embedding_data,
+                }
+
+                eval_cache_data = {
+                    "question": ret[0],
+                    "answer": ret[1],
+                    "search_result": cache_data,
+                    "embedding": None
+                }
+
+            if len(pre_embedding_data) <= 256:
+                if rank_threshold <= rank:
+                    cache_answers.append((rank, ret[1]))
+                    cache_questions.append((rank, ret[0]))
+                    cache_ids.append((rank, primary_id))
+            else:
+                if rank_threshold_long <= rank:
+                    cache_answers.append((rank, ret[1]))
+                    cache_questions.append((rank, ret[0]))
+                    cache_ids.append((rank, primary_id))
+    else:
+        # When the reranker is disabled, fall back to the original logic
+        for cache_data in cache_data_list:
+            primary_id = cache_data[1]
+            ret = chat_cache.data_manager.get_scalar_data(
+                cache_data, extra_param=context.get("get_scalar_data", None)
+            )
+            if ret is None:
+                continue
+
+            if "deps" in context and hasattr(ret.question, "deps"):
+                eval_query_data = {
+                    "question": context["deps"][0]["data"],
+                    "embedding": None
+                }
+                eval_cache_data = {
+                    "question": ret.question.deps[0].data,
+                    "answer": ret.answers[0].answer,
+                    "search_result": cache_data,
+                    "embedding": None,
+                }
+            else:
+                eval_query_data = {
+                    "question": pre_embedding_data,
+                    "embedding": embedding_data,
+                }
+
+                eval_cache_data = {
+                    "question": ret[0],
+                    "answer": ret[1],
+                    "search_result": cache_data,
+                    "embedding": None
+                }
+            rank = chat_cache.similarity_evaluation.evaluation(
+                eval_query_data,
+                eval_cache_data,
+                extra_param=context.get("evaluation_func", None),
+            )
+
+            if len(pre_embedding_data) <= 256:
+                if rank_threshold <= rank:
+                    cache_answers.append((rank, ret[1]))
+                    cache_questions.append((rank, ret[0]))
+                    cache_ids.append((rank, primary_id))
+            else:
+                if rank_threshold_long <= rank:
+                    cache_answers.append((rank, ret[1]))
+                    cache_questions.append((rank, ret[0]))
+                    cache_ids.append((rank, primary_id))
 
-        if len(pre_embedding_data) <= 256:
-            if rank_threshold <= rank:
-                cache_answers.append((rank, ret[1]))
-                cache_questions.append((rank, ret[0]))
-                cache_ids.append((rank, primary_id))
-        else:
-            if rank_threshold_long <= rank:
-                cache_answers.append((rank, ret[1]))
-                cache_questions.append((rank, ret[0]))
-                cache_ids.append((rank, primary_id))
     cache_answers = sorted(cache_answers, key=lambda x: x[0], reverse=True)
     cache_questions = sorted(cache_questions, key=lambda x: x[0], reverse=True)
     cache_ids = sorted(cache_ids, key=lambda x: x[0], reverse=True)
@@ -141,4 +192,4 @@ def adapt_query(cache_data_convert, *args, **kwargs):
             logging.info('update_hit_count except, please check!')
 
         chat_cache.report.hint_cache()
-        return cache_data_convert(return_message, return_query)
+        return cache_data_convert(return_message, return_query)
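
For context, the reranking call this commit introduces can be tried on its own. Below is a minimal sketch, assuming only what the diff itself uses (FlagEmbedding's FlagReranker with the BAAI/bge-reranker-v2-m3 checkpoint); the query and candidate strings are invented for illustration.

from FlagEmbedding import FlagReranker

# Loads BAAI/bge-reranker-v2-m3 (downloaded on first use), as in the diff.
reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=False)

# Hypothetical query and cached questions, standing in for
# pre_embedding_data and ret[0] in adapt_query.
query = 'how do I reset my password'
candidates = [
    'steps to reset a forgotten password',
    'how to change my email address',
]

# compute_score takes a [query, passage] pair; normalize=True pushes the
# raw relevance logit through a sigmoid so the score lands in [0, 1],
# which is what makes it comparable to rank_threshold in adapt_query.
for cand in candidates:
    score = reranker.compute_score([query, cand], normalize=True)
    print(cand, score)

One design note: as committed, adapt_query builds the FlagReranker inside the query path, so the model is loaded on every cache lookup; hoisting the instance to module scope next to the USE_RERANKER flag would be a natural refinement.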
