Skip to content

Commit a13ec49

Browse files
Yuval-Roth, olgaoznovich, omerdor001, adiaybgu
committed
Added documentation
Co-authored-by: olgaoznovich <ol.oznovich@gmail.com> Co-authored-by: Yuval-Roth <rothyuv@post.bgu.ac.il> Co-authored-by: omerdor001 <omerdo@post.bgu.ac.il> Co-authored-by: adiaybgu <adiay@post.bgu.ac.il>
1 parent 15da47d commit a13ec49

File tree

10 files changed

+287
-70
lines changed

10 files changed

+287
-70
lines changed

modelcache/adapter/adapter_insert.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,31 @@ async def adapt_insert(*args, **kwargs):
99
chat_cache = kwargs.pop("cache_obj")
1010
model = kwargs.pop("model", None)
1111
require_object_store = kwargs.pop("require_object_store", False)
12+
13+
# Validate object store availability if required
1214
if require_object_store:
1315
assert chat_cache.data_manager.o, "Object store is required for adapter."
16+
1417
context = kwargs.pop("cache_context", {})
1518
chat_info = kwargs.pop("chat_info", [])
1619

17-
pre_embedding_data_list = []
18-
embedding_futures_list = []
19-
llm_data_list = []
20+
# Initialize collections for parallel processing
21+
pre_embedding_data_list = [] # Preprocessed data ready for embedding
22+
embedding_futures_list = [] # Async embedding generation tasks
23+
llm_data_list = [] # Extracted LLM response data
2024

25+
# Process each chat entry and prepare for parallel embedding generation
2126
for row in chat_info:
27+
# Preprocess chat data using configured preprocessing function
2228
pre_embedding_data = chat_cache.insert_pre_embedding_func(
2329
row,
2430
extra_param=context.get("pre_embedding_func", None),
2531
prompts=chat_cache.prompts,
2632
)
2733
pre_embedding_data_list.append(pre_embedding_data)
28-
llm_data_list.append(row['answer'])
34+
llm_data_list.append(row['answer']) # Extract answer text for storage
35+
36+
# Create async embedding generation task with performance monitoring
2937
embedding_future = time_cal(
3038
chat_cache.embedding_func,
3139
func_name="embedding",
@@ -34,8 +42,10 @@ async def adapt_insert(*args, **kwargs):
3442
)(pre_embedding_data)
3543
embedding_futures_list.append(embedding_future)
3644

45+
# Wait for all embedding generation tasks to complete in parallel
3746
embedding_data_list = await asyncio.gather(*embedding_futures_list)
3847

48+
# Save all processed data to the data manager asynchronously
3949
await asyncio.to_thread(
4050
chat_cache.data_manager.save,
4151
pre_embedding_data_list,

modelcache/adapter/adapter_query.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,21 @@
88
USE_RERANKER = False # 如果为 True 则启用 reranker,否则使用原有逻辑
99

1010
async def adapt_query(cache_data_convert, *args, **kwargs):
11+
# Extract query parameters
1112
chat_cache = kwargs.pop("cache_obj")
1213
scope = kwargs.pop("scope")
1314
model = scope['model']
1415
context = kwargs.pop("cache_context", {})
1516
cache_factor = kwargs.pop("cache_factor", 1.0)
17+
18+
# Preprocess query for embedding generation
1619
pre_embedding_data = chat_cache.query_pre_embedding_func(
1720
kwargs,
1821
extra_param=context.get("pre_embedding_func", None),
1922
prompts=chat_cache.prompts,
2023
)
24+
25+
# Generate embedding with performance monitoring
2126
embedding_data = await time_cal(
2227
chat_cache.embedding_func,
2328
func_name="embedding",
@@ -39,24 +44,29 @@ async def adapt_query(cache_data_convert, *args, **kwargs):
3944
model=model
4045
)
4146

47+
# Initialize result containers
4248
cache_answers = []
4349
cache_questions = []
4450
cache_ids = []
4551
cosine_similarity = None
4652

53+
# Similarity evaluation based on metric type
4754
if chat_cache.similarity_metric_type == MetricType.COSINE:
4855
cosine_similarity = cache_data_list[0][0]
4956
# This code uses the built-in cosine similarity evaluation in milvus
5057
if cosine_similarity < chat_cache.similarity_threshold:
51-
return None
58+
return None # No suitable match found
59+
5260
elif chat_cache.similarity_metric_type == MetricType.L2:
53-
## this is the code that uses L2 for similarity evaluation
61+
# this is the code that uses L2 for similarity evaluation
5462
similarity_threshold = chat_cache.similarity_threshold
5563
similarity_threshold_long = chat_cache.similarity_threshold_long
5664

5765
min_rank, max_rank = chat_cache.similarity_evaluation.range()
5866
rank_threshold = (max_rank - min_rank) * similarity_threshold * cache_factor
5967
rank_threshold_long = (max_rank - min_rank) * similarity_threshold_long * cache_factor
68+
69+
# Clamp thresholds to valid range
6070
rank_threshold = (
6171
max_rank
6272
if rank_threshold > max_rank
@@ -71,6 +81,8 @@ async def adapt_query(cache_data_convert, *args, **kwargs):
7181
if rank_threshold_long < min_rank
7282
else rank_threshold_long
7383
)
84+
85+
# Evaluate similarity score
7486
if cache_data_list is None or len(cache_data_list) == 0:
7587
rank_pre = -1.0
7688
else:
@@ -81,12 +93,13 @@ async def adapt_query(cache_data_convert, *args, **kwargs):
8193
extra_param=context.get("evaluation_func", None),
8294
)
8395
if rank_pre < rank_threshold:
84-
return None
96+
return None # Similarity too low
8597
else:
8698
raise ValueError(
8799
f"Unsupported similarity metric type: {chat_cache.similarity_metric_type}"
88100
)
89101

102+
# Process search results with optional reranking
90103
if USE_RERANKER:
91104
reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=False)
92105
for cache_data in cache_data_list:
@@ -116,7 +129,6 @@ async def adapt_query(cache_data_convert, *args, **kwargs):
116129
"question": pre_embedding_data,
117130
"embedding": embedding_data,
118131
}
119-
120132
eval_cache_data = {
121133
"question": ret[0],
122134
"answer": ret[1],
@@ -135,9 +147,10 @@ async def adapt_query(cache_data_convert, *args, **kwargs):
135147
cache_questions.append((rank, ret[1]))
136148
cache_ids.append((rank, primary_id))
137149
else:
138-
# 不使用 reranker 时,走原来的逻辑
150+
# Original logic without reranking
139151
for cache_data in cache_data_list:
140152
primary_id = cache_data[1]
153+
# Retrieve full cache entry data
141154
ret = await asyncio.to_thread(
142155
chat_cache.data_manager.get_scalar_data,
143156
cache_data, extra_param=context.get("get_scalar_data", None), model=model
@@ -150,6 +163,7 @@ async def adapt_query(cache_data_convert, *args, **kwargs):
150163
cache_answers.append((cosine_similarity, ret[0]))
151164
cache_questions.append((cosine_similarity, ret[1]))
152165
cache_ids.append((cosine_similarity, primary_id))
166+
153167
elif chat_cache.similarity_metric_type == MetricType.L2:
154168
if "deps" in context and hasattr(ret.question, "deps"):
155169
eval_query_data = {
@@ -167,13 +181,14 @@ async def adapt_query(cache_data_convert, *args, **kwargs):
167181
"question": pre_embedding_data,
168182
"embedding": embedding_data,
169183
}
170-
171184
eval_cache_data = {
172185
"question": ret[0],
173186
"answer": ret[1],
174187
"search_result": cache_data,
175188
"embedding": None
176189
}
190+
191+
# Evaluate similarity for this specific result
177192
rank = chat_cache.similarity_evaluation.evaluation(
178193
eval_query_data,
179194
eval_cache_data,
@@ -195,6 +210,7 @@ async def adapt_query(cache_data_convert, *args, **kwargs):
195210
f"Unsupported similarity metric type: {chat_cache.similarity_metric_type}"
196211
)
197212

213+
# Sort results by similarity score (highest first)
198214
cache_answers = sorted(cache_answers, key=lambda x: x[0], reverse=True)
199215
cache_questions = sorted(cache_questions, key=lambda x: x[0], reverse=True)
200216
cache_ids = sorted(cache_ids, key=lambda x: x[0], reverse=True)
@@ -208,12 +224,14 @@ async def adapt_query(cache_data_convert, *args, **kwargs):
208224
return_id = chat_cache.post_process_messages_func(
209225
[t[1] for t in cache_ids]
210226
)
211-
# 更新命中次数
227+
228+
# Update hit count for analytics (async to avoid blocking)
212229
try:
213230
asyncio.create_task(asyncio.to_thread(chat_cache.data_manager.update_hit_count,return_id))
214231
except Exception:
215232
logging.info('update_hit_count except, please check!')
216233

234+
# Record cache hit for reporting
217235
chat_cache.report.hint_cache()
218236
return cache_data_convert(return_message, return_query)
219237
return None

0 commit comments

Comments (0)