
Commit f8d4e72

Project-wide refactoring: create a single point of initialization and entry for the caching logic

* Separated API logic from cache and initialization logic, deduplicating the cache-init, parsing, and request-handling code that previously lived in 'fastapi4modelcache.py', 'fastapi4modelcache_demo.py', 'flask4modelcache.py' and 'flask4modelcache_demo.py'; all of that logic now lives in the Cache class.
* Cache initialization is now modular and extensible, contained in the static Cache.init() function. Added EmbeddingModel and MetricType enums that control which logic runs in the adapters; both are configured at init time (see the sketch below).
* All cache configuration is now held inside the Cache object instead of being spread across the project.
* Deduplicated classes, interfaces and factory methods, and removed some unused dead code.
* Moved factory methods into their respective interfaces instead of leaving them as static global functions with arbitrary names.
1 parent bd2caeb commit f8d4e72
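For orientation, here is a minimal sketch of the new single entry point as it appears in the two FastAPI diffs below. The Cache.init(...) backend arguments and the handle_request() / save_query_info() calls are taken directly from the diffs; the embedding_model / metric_type keyword arguments are only a guess at where the new EmbeddingModel and MetricType enums plug in, not a confirmed signature.

# Minimal sketch, not the actual modelcache/cache.py implementation.
import json
from fastapi import FastAPI, Request
from modelcache.cache import Cache

app = FastAPI()

# One call replaces the per-service config parsing, data-manager wiring and cache.init() blocks.
cache = Cache.init("mysql", "milvus")   # the demo service uses Cache.init("sqlite", "faiss")
# Hypothetical enum usage (names from the commit message, signature assumed):
# cache = Cache.init("mysql", "milvus", embedding_model=EmbeddingModel.DATA2VEC, metric_type=MetricType.COSINE)

@app.post("/modelcache")
async def user_backend(request: Request):
    request_data = json.loads(await request.body())
    # query / insert / remove / register dispatching now lives inside the Cache class
    return cache.handle_request(request_data)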


45 files changed: +831 −1293 lines

fastapi4modelcache.py

Lines changed: 6 additions & 138 deletions
@@ -1,80 +1,19 @@
 # -*- coding: utf-8 -*-
-import time
 import uvicorn
-import asyncio
-import logging
-import configparser
 import json
 from fastapi import FastAPI, Request, HTTPException
-from pydantic import BaseModel
-from concurrent.futures import ThreadPoolExecutor
-from starlette.responses import PlainTextResponse
-import functools
-
-from modelcache import cache
-from modelcache.adapter import adapter
-from modelcache.manager import CacheBase, VectorBase, get_data_manager
-from modelcache.similarity_evaluation.distance import SearchDistanceEvaluation
-from modelcache.processor.pre import query_multi_splicing
-from modelcache.processor.pre import insert_multi_splicing
-from modelcache.utils.model_filter import model_blacklist_filter
-from modelcache.embedding import Data2VecAudio
+from modelcache.cache import Cache

 # Create a FastAPI instance
 app = FastAPI()

-class RequestData(BaseModel):
-    type: str
-    scope: dict = None
-    query: str = None
-    chat_info: dict = None
-    remove_type: str = None
-    id_list: list = []
-
-data2vec = Data2VecAudio()
-mysql_config = configparser.ConfigParser()
-mysql_config.read('modelcache/config/mysql_config.ini')
-
-milvus_config = configparser.ConfigParser()
-milvus_config.read('modelcache/config/milvus_config.ini')
-
-# redis_config = configparser.ConfigParser()
-# redis_config.read('modelcache/config/redis_config.ini')
-
-# Initialize the data manager
-data_manager = get_data_manager(
-    CacheBase("mysql", config=mysql_config),
-    VectorBase("milvus", dimension=data2vec.dimension, milvus_config=milvus_config)
-)
-
-# # Initialize the data manager with Redis
-# data_manager = get_data_manager(
-#     CacheBase("mysql", config=mysql_config),
-#     VectorBase("redis", dimension=data2vec.dimension, redis_config=redis_config)
-# )
+cache = Cache.init("mysql", "milvus")

-cache.init(
-    embedding_func=data2vec.to_embeddings,
-    data_manager=data_manager,
-    similarity_evaluation=SearchDistanceEvaluation(),
-    query_pre_embedding_func=query_multi_splicing,
-    insert_pre_embedding_func=insert_multi_splicing,
-)
-
-executor = ThreadPoolExecutor(max_workers=6)
-
-# Save query info asynchronously
-async def save_query_info(result, model, query, delta_time_log):
-    loop = asyncio.get_running_loop()
-    func = functools.partial(cache.data_manager.save_query_resp, result, model=model, query=json.dumps(query, ensure_ascii=False), delta_time=delta_time_log)
-    await loop.run_in_executor(None, func)
-
-
-
-@app.get("/welcome", response_class=PlainTextResponse)
+@app.get("/welcome")
 async def first_fastapi():
     return "hello, modelcache!"

+
 @app.post("/modelcache")
 async def user_backend(request: Request):
     try:
@@ -90,7 +29,7 @@ async def user_backend(request: Request):
             # If parsing fails, return a format error
             result = {"errorCode": 101, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '',
                       "answer": ''}
-            asyncio.create_task(save_query_info(result, model='', query='', delta_time_log=0))
+            cache.save_query_info(result, model='', query='', delta_time_log=0)
             raise HTTPException(status_code=101, detail="Invalid JSON format")
         else:
             request_data = raw_body
@@ -102,19 +41,7 @@ async def user_backend(request: Request):
             except json.JSONDecodeError:
                 raise HTTPException(status_code=101, detail="Invalid JSON format")

-        request_type = request_data.get('type')
-        model = None
-        if 'scope' in request_data:
-            model = request_data['scope'].get('model', '').replace('-', '_').replace('.', '_')
-        query = request_data.get('query')
-        chat_info = request_data.get('chat_info')
-
-        if not request_type or request_type not in ['query', 'insert', 'remove', 'register']:
-            result = {"errorCode": 102,
-                      "errorDesc": "type exception, should one of ['query', 'insert', 'remove', 'register']",
-                      "cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''}
-            asyncio.create_task(save_query_info(result, model=model, query='', delta_time_log=0))
-            raise HTTPException(status_code=102, detail="Type exception, should be one of ['query', 'insert', 'remove', 'register']")
+        return cache.handle_request(request_data)

     except Exception as e:
         request_data = raw_body if 'raw_body' in locals() else None
@@ -129,65 +56,6 @@ async def user_backend(request: Request):
         }
         return result

-
-    # model filter
-    filter_resp = model_blacklist_filter(model, request_type)
-    if isinstance(filter_resp, dict):
-        return filter_resp
-
-    if request_type == 'query':
-        try:
-            start_time = time.time()
-            response = adapter.ChatCompletion.create_query(scope={"model": model}, query=query)
-            delta_time = f"{round(time.time() - start_time, 2)}s"
-
-            if response is None:
-                result = {"errorCode": 0, "errorDesc": '', "cacheHit": False, "delta_time": delta_time, "hit_query": '', "answer": ''}
-            elif response in ['adapt_query_exception']:
-                result = {"errorCode": 201, "errorDesc": response, "cacheHit": False, "delta_time": delta_time,
-                          "hit_query": '', "answer": ''}
-            else:
-                answer = response['data']
-                hit_query = response['hitQuery']
-                result = {"errorCode": 0, "errorDesc": '', "cacheHit": True, "delta_time": delta_time, "hit_query": hit_query, "answer": answer}
-
-            delta_time_log = round(time.time() - start_time, 2)
-            asyncio.create_task(save_query_info(result, model, query, delta_time_log))
-            return result
-        except Exception as e:
-            result = {"errorCode": 202, "errorDesc": str(e), "cacheHit": False, "delta_time": 0,
-                      "hit_query": '', "answer": ''}
-            logging.info(f'result: {str(result)}')
-            return result
-
-    if request_type == 'insert':
-        try:
-            response = adapter.ChatCompletion.create_insert(model=model, chat_info=chat_info)
-            if response == 'success':
-                return {"errorCode": 0, "errorDesc": "", "writeStatus": "success"}
-            else:
-                return {"errorCode": 301, "errorDesc": response, "writeStatus": "exception"}
-        except Exception as e:
-            return {"errorCode": 303, "errorDesc": str(e), "writeStatus": "exception"}
-
-    if request_type == 'remove':
-        response = adapter.ChatCompletion.create_remove(model=model, remove_type=request_data.get("remove_type"), id_list=request_data.get("id_list"))
-        if not isinstance(response, dict):
-            return {"errorCode": 401, "errorDesc": "", "response": response, "removeStatus": "exception"}
-
-        state = response.get('status')
-        if state == 'success':
-            return {"errorCode": 0, "errorDesc": "", "response": response, "writeStatus": "success"}
-        else:
-            return {"errorCode": 402, "errorDesc": "", "response": response, "writeStatus": "exception"}
-
-    if request_type == 'register':
-        response = adapter.ChatCompletion.create_register(model=model)
-        if response in ['create_success', 'already_exists']:
-            return {"errorCode": 0, "errorDesc": "", "response": response, "writeStatus": "success"}
-        else:
-            return {"errorCode": 502, "errorDesc": "", "response": response, "writeStatus": "exception"}
-
 # TODO: could instead be launched from the command line with `uvicorn your_module_name:app --host 0.0.0.0 --port 5000 --reload`
 if __name__ == '__main__':
     uvicorn.run(app, host='0.0.0.0', port=5000)
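The deleted handler above is the logic this commit consolidates: the model blacklist filter, the per-type dispatch to adapter.ChatCompletion.create_query/insert/remove/register, and the thread-pooled write of query statistics. A rough reconstruction of the consolidated helpers, pieced together from that deleted code, might look like the following; the real Cache class in modelcache/cache.py may be organized differently.

# Hypothetical reconstruction -- illustrates what moved into the Cache class, not the actual code.
import json
from concurrent.futures import ThreadPoolExecutor


class Cache:
    data_manager = None                      # wired up by Cache.init(); holds the scalar + vector stores
    _executor = ThreadPoolExecutor(max_workers=6)

    def save_query_info(self, result, model, query, delta_time_log):
        # Replaces the asyncio/functools wrapper each service used to define for itself.
        self._executor.submit(
            self.data_manager.save_query_resp, result, model=model,
            query=json.dumps(query, ensure_ascii=False), delta_time=delta_time_log)

    def handle_request(self, request_data):
        request_type = request_data.get('type')
        scope = request_data.get('scope') or {}
        model = scope.get('model', '').replace('-', '_').replace('.', '_')
        if request_type not in ('query', 'insert', 'remove', 'register'):
            return {"errorCode": 102, "errorDesc": "type exception", "cacheHit": False,
                    "delta_time": 0, "hit_query": '', "answer": ''}
        # The query/insert/remove/register branches deleted above run here, each delegating
        # to adapter.ChatCompletion.create_*() and recording stats via save_query_info().
        raise NotImplementedError("see modelcache/cache.py for the real dispatch")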

fastapi4modelcache_demo.py

Lines changed: 10 additions & 111 deletions
@@ -1,59 +1,16 @@
 # -*- coding: utf-8 -*-
-import time
 import uvicorn
-import asyncio
-import logging
-# import configparser
 import json
 from fastapi import FastAPI, Request, HTTPException
-from pydantic import BaseModel
-from concurrent.futures import ThreadPoolExecutor
-from starlette.responses import PlainTextResponse
-import functools

-from modelcache import cache
-from modelcache.adapter import adapter
-from modelcache.manager import CacheBase, VectorBase, get_data_manager
-from modelcache.similarity_evaluation.distance import SearchDistanceEvaluation
-from modelcache.processor.pre import query_multi_splicing
-from modelcache.processor.pre import insert_multi_splicing
-from modelcache.utils.model_filter import model_blacklist_filter
-from modelcache.embedding import Data2VecAudio
+from modelcache.cache import Cache

 # Create a FastAPI instance
 app = FastAPI()

-class RequestData(BaseModel):
-    type: str
-    scope: dict = None
-    query: str = None
-    chat_info: list = None
-    remove_type: str = None
-    id_list: list = []
+cache = Cache.init("sqlite", "faiss")

-data2vec = Data2VecAudio()
-
-data_manager = get_data_manager(CacheBase("sqlite"), VectorBase("faiss", dimension=data2vec.dimension))
-
-cache.init(
-    embedding_func=data2vec.to_embeddings,
-    data_manager=data_manager,
-    similarity_evaluation=SearchDistanceEvaluation(),
-    query_pre_embedding_func=query_multi_splicing,
-    insert_pre_embedding_func=insert_multi_splicing,
-)
-
-executor = ThreadPoolExecutor(max_workers=6)
-
-# Save query info asynchronously
-async def save_query_info_fastapi(result, model, query, delta_time_log):
-    loop = asyncio.get_running_loop()
-    func = functools.partial(cache.data_manager.save_query_resp, result, model=model, query=json.dumps(query, ensure_ascii=False), delta_time=delta_time_log)
-    await loop.run_in_executor(None, func)
-
-
-
-@app.get("/welcome", response_class=PlainTextResponse)
+@app.get("/welcome")
 async def first_fastapi():
     return "hello, modelcache!"

@@ -68,9 +25,12 @@ async def user_backend(request: Request):
         try:
             # Try to parse the string as a JSON object
             request_data = json.loads(raw_body)
-        except json.JSONDecodeError:
+        except json.JSONDecodeError as e:
             # If parsing fails, return a format error
-            raise HTTPException(status_code=400, detail="Invalid JSON format")
+            result = {"errorCode": 101, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '',
+                      "answer": ''}
+            cache.save_query_info(result, model='', query='', delta_time_log=0)
+            raise HTTPException(status_code=101, detail="Invalid JSON format")
         else:
             request_data = raw_body

@@ -79,17 +39,9 @@ async def user_backend(request: Request):
             try:
                 request_data = json.loads(request_data)
             except json.JSONDecodeError:
-                raise HTTPException(status_code=400, detail="Invalid JSON format")
-
-        request_type = request_data.get('type')
-        model = None
-        if 'scope' in request_data:
-            model = request_data['scope'].get('model', '').replace('-', '_').replace('.', '_')
-        query = request_data.get('query')
-        chat_info = request_data.get('chat_info')
+                raise HTTPException(status_code=101, detail="Invalid JSON format")

-        if not request_type or request_type not in ['query', 'insert', 'remove', 'detox']:
-            raise HTTPException(status_code=400, detail="Type exception, should be one of ['query', 'insert', 'remove', 'detox']")
+        return cache.handle_request(request_data)

     except Exception as e:
         request_data = raw_body if 'raw_body' in locals() else None
@@ -104,59 +56,6 @@ async def user_backend(request: Request):
         }
         return result

-
-    # model filter
-    filter_resp = model_blacklist_filter(model, request_type)
-    if isinstance(filter_resp, dict):
-        return filter_resp
-
-    if request_type == 'query':
-        try:
-            start_time = time.time()
-            response = adapter.ChatCompletion.create_query(scope={"model": model}, query=query)
-            delta_time = f"{round(time.time() - start_time, 2)}s"
-
-            if response is None:
-                result = {"errorCode": 0, "errorDesc": '', "cacheHit": False, "delta_time": delta_time, "hit_query": '', "answer": ''}
-            elif response in ['adapt_query_exception']:
-            # elif isinstance(response, str):
-                result = {"errorCode": 201, "errorDesc": response, "cacheHit": False, "delta_time": delta_time,
-                          "hit_query": '', "answer": ''}
-            else:
-                answer = response['data']
-                hit_query = response['hitQuery']
-                result = {"errorCode": 0, "errorDesc": '', "cacheHit": True, "delta_time": delta_time, "hit_query": hit_query, "answer": answer}
-
-            delta_time_log = round(time.time() - start_time, 2)
-            asyncio.create_task(save_query_info_fastapi(result, model, query, delta_time_log))
-            return result
-        except Exception as e:
-            result = {"errorCode": 202, "errorDesc": str(e), "cacheHit": False, "delta_time": 0,
-                      "hit_query": '', "answer": ''}
-            logging.info(f'result: {str(result)}')
-            return result
-
-    if request_type == 'insert':
-        try:
-            response = adapter.ChatCompletion.create_insert(model=model, chat_info=chat_info)
-            if response == 'success':
-                return {"errorCode": 0, "errorDesc": "", "writeStatus": "success"}
-            else:
-                return {"errorCode": 301, "errorDesc": response, "writeStatus": "exception"}
-        except Exception as e:
-            return {"errorCode": 303, "errorDesc": str(e), "writeStatus": "exception"}
-
-    if request_type == 'remove':
-        response = adapter.ChatCompletion.create_remove(model=model, remove_type=request_data.get("remove_type"), id_list=request_data.get("id_list"))
-        if not isinstance(response, dict):
-            return {"errorCode": 401, "errorDesc": "", "response": response, "removeStatus": "exception"}
-
-        state = response.get('status')
-        if state == 'success':
-            return {"errorCode": 0, "errorDesc": "", "response": response, "writeStatus": "success"}
-        else:
-            return {"errorCode": 402, "errorDesc": "", "response": response, "writeStatus": "exception"}
-
 # TODO: could instead be launched from the command line with `uvicorn your_module_name:app --host 0.0.0.0 --port 5000 --reload`
 if __name__ == '__main__':
     uvicorn.run(app, host='0.0.0.0', port=5000)
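From a caller's point of view nothing changes: both services still expose POST /modelcache on port 5000 and accept the same type/scope/query/chat_info payloads that the deleted handlers parsed. A quick smoke test against a locally running demo instance could look like this; the requests dependency, the model name, and the exact message shapes are illustrative, not mandated by the commit.

# Smoke test against a locally running service (uvicorn on port 5000, as configured above).
import requests

URL = "http://127.0.0.1:5000/modelcache"

# 'insert' writes a question/answer pair into the cache for a given model scope.
insert_payload = {
    "type": "insert",
    "scope": {"model": "CODEGPT-1008"},  # example model name; '-' and '.' get normalized to '_'
    "chat_info": [{"query": [{"role": "user", "content": "hello"}], "answer": "hi, modelcache"}],
}
print(requests.post(URL, json=insert_payload).json())

# 'query' looks the same question up and reports cacheHit / hit_query / answer.
query_payload = {
    "type": "query",
    "scope": {"model": "CODEGPT-1008"},
    "query": [{"role": "user", "content": "hello"}],
}
print(requests.post(URL, json=query_payload).json())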
