
Commit a37f92e (parent 4a20eaf)

New feature: Embedding Dispatcher for parallel embedding

Built on the multiprocessing module, the dispatcher provides true parallel embedding; the number of workers can be adjusted at cache initialization. Access to the caching logic is now asynchronous, using the asyncio module.

15 files changed (+393, −212 lines)
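The dispatcher itself is not shown on this page, so the sketch below only illustrates the pattern the commit message describes: a multiprocessing-backed pool of embedding workers whose blocking calls are awaited from asyncio. The class name EmbeddingDispatcher, the helper functions, and the use of sentence_transformers are assumptions for illustration, not code from this commit.

# Illustrative sketch only -- not the dispatcher shipped in this commit.
import asyncio
from concurrent.futures import ProcessPoolExecutor  # multiprocessing-backed pool

_model = None  # one embedding model instance per worker process


def _init_worker(model_name: str) -> None:
    # Runs once in each worker process; loads the model there so workers embed in parallel.
    global _model
    from sentence_transformers import SentenceTransformer  # assumed embedding backend
    _model = SentenceTransformer(model_name)


def _embed_in_worker(text: str):
    # Executed inside a worker process; returns a plain list so the result can be pickled.
    return _model.encode(text).tolist()


class EmbeddingDispatcher:
    """Fans embedding requests out to a pool of worker processes (hypothetical name)."""

    def __init__(self, model_name: str, workers_num: int = 2):
        self._pool = ProcessPoolExecutor(
            max_workers=workers_num,
            initializer=_init_worker,
            initargs=(model_name,),
        )

    async def embed(self, text: str):
        # Bridge the blocking pool call into asyncio without blocking the event loop.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self._pool, _embed_in_worker, text)

    def shutdown(self) -> None:
        self._pool.shutdown()


async def _demo():
    dispatcher = EmbeddingDispatcher("sentence-transformers/all-mpnet-base-v2", workers_num=2)
    vectors = await asyncio.gather(*(dispatcher.embed(q) for q in ["hello", "world"]))
    print(len(vectors), len(vectors[0]))
    dispatcher.shutdown()


if __name__ == "__main__":
    asyncio.run(_demo())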

fastapi4modelcache.py
Lines changed: 30 additions & 45 deletions

@@ -1,61 +1,46 @@
 # -*- coding: utf-8 -*-
+import asyncio
+from contextlib import asynccontextmanager
 import uvicorn
 import json
-from fastapi import FastAPI, Request, HTTPException
+from fastapi.responses import JSONResponse
+from fastapi import FastAPI, Request
 from modelcache.cache import Cache
-
-# Create a FastAPI instance
-app = FastAPI()
-
-cache = Cache.init("mysql", "milvus")
+from modelcache.embedding import EmbeddingModel
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    global cache
+    cache, _ = await Cache.init(
+        sql_storage="mysql",
+        vector_storage="milvus",
+        embedding_model=EmbeddingModel.HUGGINGFACE_ALL_MPNET_BASE_V2,
+        embedding_workers_num=2
+    )
+    yield
+
+app = FastAPI(lifespan=lifespan)
+cache: Cache = None
 
 @app.get("/welcome")
 async def first_fastapi():
     return "hello, modelcache!"
 
-
 @app.post("/modelcache")
 async def user_backend(request: Request):
-    try:
-        raw_body = await request.body()
-        # Parse the string into a JSON object
-        if isinstance(raw_body, bytes):
-            raw_body = raw_body.decode("utf-8")
-        if isinstance(raw_body, str):
-            try:
-                # Try to parse the string as a JSON object
-                request_data = json.loads(raw_body)
-            except json.JSONDecodeError as e:
-                # If it cannot be parsed, return a format error
-                result = {"errorCode": 101, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '',
-                          "answer": ''}
-                cache.save_query_info(result, model='', query='', delta_time_log=0)
-                raise HTTPException(status_code=101, detail="Invalid JSON format")
-        else:
-            request_data = raw_body
 
-        # Make sure request_data is a dict object
-        if isinstance(request_data, str):
-            try:
-                request_data = json.loads(request_data)
-            except json.JSONDecodeError:
-                raise HTTPException(status_code=101, detail="Invalid JSON format")
-
-        return cache.handle_request(request_data)
+    try:
+        request_data = await request.json()
+    except Exception:
+        result = {"errorCode": 400, "errorDesc": "bad request", "cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''}
+        return JSONResponse(status_code=400, content=result)
 
+    try:
+        return await cache.handle_request(request_data)
     except Exception as e:
-        request_data = raw_body if 'raw_body' in locals() else None
-        result = {
-            "errorCode": 103,
-            "errorDesc": str(e),
-            "cacheHit": False,
-            "delta_time": 0,
-            "hit_query": '',
-            "answer": '',
-            "para_dict": request_data
-        }
-        return result
+        result = {"errorCode": 500, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''}
+        cache.save_query_resp(result, model='', query='', delta_time=0)
+        return JSONResponse(status_code=500, content=result)
 
-# TODO: could instead be started from the command line with `uvicorn your_module_name:app --host 0.0.0.0 --port 5000 --reload`
 if __name__ == '__main__':
-    uvicorn.run(app, host='0.0.0.0', port=5000)
+    uvicorn.run(app, host='0.0.0.0', port=5000, loop="asyncio", http="httptools")
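Once the service is running, a client call against the new endpoint might look like the sketch below. The payload layout (type, scope, query) follows the query format described in the ModelCache documentation; treat the exact field names and the placeholder model name as assumptions rather than something this diff defines.

# Hedged example: payload fields mirror the ModelCache docs' query format (an assumption here).
import json
import requests

url = "http://127.0.0.1:5000/modelcache"
payload = {
    "type": "query",                      # or "insert" to write to the cache
    "scope": {"model": "CODEGPT-1008"},   # placeholder model scope
    "query": [{"role": "user", "content": "hello, modelcache"}],
}
resp = requests.post(url, json=payload, timeout=10)
print(resp.status_code, json.dumps(resp.json(), ensure_ascii=False, indent=2))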

fastapi4modelcache_demo.py
Lines changed: 30 additions & 45 deletions

@@ -1,61 +1,46 @@
 # -*- coding: utf-8 -*-
+import asyncio
+from contextlib import asynccontextmanager
 import uvicorn
 import json
-from fastapi import FastAPI, Request, HTTPException
-
+from fastapi.responses import JSONResponse
+from fastapi import FastAPI, Request
 from modelcache.cache import Cache
-
-# Create a FastAPI instance
-app = FastAPI()
-
-cache = Cache.init("sqlite", "faiss")
+from modelcache.embedding import EmbeddingModel
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    global cache
+    cache, _ = await Cache.init(
+        sql_storage="sqlite",
+        vector_storage="faiss",
+        embedding_model=EmbeddingModel.HUGGINGFACE_ALL_MPNET_BASE_V2,
+        embedding_workers_num=2
+    )
+    yield
+
+app = FastAPI(lifespan=lifespan)
+cache: Cache = None
 
 @app.get("/welcome")
 async def first_fastapi():
     return "hello, modelcache!"
 
 @app.post("/modelcache")
 async def user_backend(request: Request):
-    try:
-        raw_body = await request.body()
-        # Parse the string into a JSON object
-        if isinstance(raw_body, bytes):
-            raw_body = raw_body.decode("utf-8")
-        if isinstance(raw_body, str):
-            try:
-                # Try to parse the string as a JSON object
-                request_data = json.loads(raw_body)
-            except json.JSONDecodeError as e:
-                # If it cannot be parsed, return a format error
-                result = {"errorCode": 101, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '',
-                          "answer": ''}
-                cache.save_query_info(result, model='', query='', delta_time_log=0)
-                raise HTTPException(status_code=101, detail="Invalid JSON format")
-        else:
-            request_data = raw_body
 
-        # Make sure request_data is a dict object
-        if isinstance(request_data, str):
-            try:
-                request_data = json.loads(request_data)
-            except json.JSONDecodeError:
-                raise HTTPException(status_code=101, detail="Invalid JSON format")
-
-        return cache.handle_request(request_data)
+    try:
+        request_data = await request.json()
+    except Exception:
+        result = {"errorCode": 400, "errorDesc": "bad request", "cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''}
+        return JSONResponse(status_code=400, content=result)
 
+    try:
+        return await cache.handle_request(request_data)
     except Exception as e:
-        request_data = raw_body if 'raw_body' in locals() else None
-        result = {
-            "errorCode": 103,
-            "errorDesc": str(e),
-            "cacheHit": False,
-            "delta_time": 0,
-            "hit_query": '',
-            "answer": '',
-            "para_dict": request_data
-        }
-        return result
+        result = {"errorCode": 500, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''}
+        cache.save_query_resp(result, model='', query='', delta_time=0)
+        return JSONResponse(status_code=500, content=result)
 
-# TODO: could instead be started from the command line with `uvicorn your_module_name:app --host 0.0.0.0 --port 5000 --reload`
 if __name__ == '__main__':
-    uvicorn.run(app, host='0.0.0.0', port=5000)
+    uvicorn.run(app, host='0.0.0.0', port=5000, loop="asyncio", http="httptools")

flask4modelcache.py
Lines changed: 36 additions & 22 deletions

@@ -1,34 +1,48 @@
 # -*- coding: utf-8 -*-
-from flask import Flask, request
-import json
+import asyncio
+
+from flask import Flask, request, jsonify
 from modelcache.cache import Cache
+from modelcache.embedding import EmbeddingModel
+
 
-# Create a Flask instance
-app = Flask(__name__)
+async def main():
 
-cache = Cache.init("mysql","milvus")
+    # Create a Flask instance
+    app = Flask(__name__)
 
-@app.route('/welcome')
-def first_flask(): # view function
-    return 'hello, modelcache!'
+    cache,loop = await Cache.init(
+        sql_storage="mysql",
+        vector_storage="milvus",
+        embedding_model=EmbeddingModel.HUGGINGFACE_ALL_MPNET_BASE_V2,
+        embedding_workers_num=2
+    )
 
+    @app.route('/welcome')
+    def first_flask(): # view function
+        return 'hello, modelcache!'
 
-@app.route('/modelcache', methods=['GET', 'POST'])
-def user_backend():
-    param_dict = {}
-    try:
-        if request.method == 'POST':
+
+    @app.post('/modelcache')
+    def user_backend():
+        try:
             param_dict = request.json
-        elif request.method == 'GET':
-            param_dict = request.args
+        except Exception:
+            result = {"errorCode": 400, "errorDesc": "bad request", "cacheHit": False, "delta_time": 0, "hit_query": '',"answer": ''}
+            return jsonify(result), 400
+
+        try:
+            result = asyncio.run_coroutine_threadsafe(
+                cache.handle_request(param_dict), loop
+            ).result()
+            return jsonify(result), 200
+        except Exception as e:
+            result = {"errorCode": 500, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '',"answer": ''}
+            cache.save_query_resp(result, model='', query='', delta_time=0)
+            return jsonify(result), 500
 
-        return json.dumps(cache.handle_request(param_dict))
-    except Exception as e:
-        result = {"errorCode": 101, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '',
-                  "answer": ''}
-        cache.save_query_resp(result, model='', query='', delta_time=0)
-        return json.dumps(result)
+    await asyncio.to_thread(app.run, host='0.0.0.0', port=5000)
 
 
 if __name__ == '__main__':
-    app.run(host='0.0.0.0', port=5000)
+    asyncio.run(main())
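The Flask handler above is synchronous, while cache.handle_request is a coroutine owned by the event loop returned from Cache.init; asyncio.run_coroutine_threadsafe submits the coroutine to that loop from the request thread and blocks on the result. Below is a self-contained sketch of the same bridge pattern, with illustrative names only (fake_handle_request stands in for the cache call).

# Self-contained sketch of the sync-to-async bridge used above (illustrative names only).
import asyncio
import threading


async def fake_handle_request(data: dict) -> dict:
    # Stand-in for cache.handle_request: an async operation owned by the loop thread.
    await asyncio.sleep(0.1)
    return {"echo": data}


def start_background_loop() -> asyncio.AbstractEventLoop:
    # Run an event loop forever in a dedicated thread, like the loop returned by Cache.init.
    loop = asyncio.new_event_loop()
    threading.Thread(target=loop.run_forever, daemon=True).start()
    return loop


if __name__ == "__main__":
    loop = start_background_loop()
    # From any synchronous thread (e.g. a Flask request handler), submit the coroutine
    # to the background loop and block until the result is ready.
    future = asyncio.run_coroutine_threadsafe(fake_handle_request({"q": "hi"}), loop)
    print(future.result(timeout=5))
    loop.call_soon_threadsafe(loop.stop)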

flask4modelcache_demo.py
Lines changed: 36 additions & 22 deletions

@@ -1,34 +1,48 @@
 # -*- coding: utf-8 -*-
-from flask import Flask, request
-import json
+import asyncio
+
+from flask import Flask, request, jsonify
 from modelcache.cache import Cache
+from modelcache.embedding import EmbeddingModel
+
 
-# Create a Flask instance
-app = Flask(__name__)
+async def main():
 
-cache = Cache.init("sqlite","faiss")
+    # Create a Flask instance
+    app = Flask(__name__)
 
-@app.route('/welcome')
-def first_flask(): # view function
-    return 'hello, modelcache!'
+    cache,loop = await Cache.init(
+        sql_storage="sqlite",
+        vector_storage="faiss",
+        embedding_model=EmbeddingModel.HUGGINGFACE_ALL_MPNET_BASE_V2,
+        embedding_workers_num=2
+    )
 
+    @app.route('/welcome')
+    def first_flask(): # view function
+        return 'hello, modelcache!'
 
-@app.route('/modelcache', methods=['GET', 'POST'])
-def user_backend():
-    param_dict = {}
-    try:
-        if request.method == 'POST':
+
+    @app.post('/modelcache')
+    def user_backend():
+        try:
             param_dict = request.json
-        elif request.method == 'GET':
-            param_dict = request.args
+        except Exception:
+            result = {"errorCode": 400, "errorDesc": "bad request", "cacheHit": False, "delta_time": 0, "hit_query": '',"answer": ''}
+            return jsonify(result), 400
+
+        try:
+            result = asyncio.run_coroutine_threadsafe(
+                cache.handle_request(param_dict), loop
+            ).result()
+            return jsonify(result), 200
+        except Exception as e:
+            result = {"errorCode": 500, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '',"answer": ''}
+            cache.save_query_resp(result, model='', query='', delta_time=0)
+            return jsonify(result), 500
 
-        return json.dumps(cache.handle_request(param_dict))
-    except Exception as e:
-        result = {"errorCode": 101, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '',
-                  "answer": ''}
-        cache.save_query_resp(result, model='', query='', delta_time=0)
-        return json.dumps(result)
+    await asyncio.to_thread(app.run, host='0.0.0.0', port=5000)
 
 
 if __name__ == '__main__':
-    app.run(host='0.0.0.0', port=5000)
+    asyncio.run(main())

model/text2vec-base-chinese/logs.txt
Lines changed: 19 additions & 1 deletion

@@ -1 +1,19 @@
-
+Epoch:0 Valid| corr: 0.794410
+Epoch:0 Valid| corr: 0.691819
+Epoch:1 Valid| corr: 0.722749
+Epoch:2 Valid| corr: 0.735054
+Epoch:3 Valid| corr: 0.738295
+Epoch:4 Valid| corr: 0.739411
+Test | corr: 0.679971
+Epoch:0 Valid| corr: 0.817416
+Epoch:1 Valid| corr: 0.832376
+Epoch:2 Valid| corr: 0.842308
+Epoch:3 Valid| corr: 0.843520
+Epoch:4 Valid| corr: 0.841837
+Test | corr: 0.793495
+Epoch:0 Valid| corr: 0.814648
+Epoch:1 Valid| corr: 0.831609
+Epoch:2 Valid| corr: 0.841678
+Epoch:3 Valid| corr: 0.842387
+Epoch:4 Valid| corr: 0.841435
+Test | corr: 0.794840

modelcache/adapter/adapter.py
Lines changed: 4 additions & 4 deletions

@@ -10,11 +10,11 @@ class ChatCompletion(object):
     """Openai ChatCompletion Wrapper"""
 
     @classmethod
-    def create_query(cls, *args, **kwargs):
+    async def create_query(cls, *args, **kwargs):
         def cache_data_convert(cache_data, cache_query):
             return construct_resp_from_cache(cache_data, cache_query)
         try:
-            return adapt_query(
+            return await adapt_query(
                 cache_data_convert,
                 *args,
                 **kwargs
@@ -24,9 +24,9 @@ def cache_data_convert(cache_data, cache_query):
             return str(e)
 
     @classmethod
-    def create_insert(cls, *args, **kwargs):
+    async def create_insert(cls, *args, **kwargs):
         try:
-            return adapt_insert(
+            return await adapt_insert(
                 *args,
                 **kwargs
             )
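Since create_query and create_insert are now coroutines, call sites have to await them (or submit them to an event loop from synchronous code). A minimal sketch, with the argument list elided because it is not shown in this diff:

# Sketch only: the adapter's real parameters are not shown in this diff, so they stay elided.
from modelcache.adapter.adapter import ChatCompletion


async def cached_lookup(**kwargs):
    # create_query is now a coroutine and must be awaited by the caller.
    return await ChatCompletion.create_query(**kwargs)

# From synchronous code, drive it through an event loop instead of calling it directly,
# e.g. asyncio.run(cached_lookup(...)) or run_coroutine_threadsafe(...) as in the Flask server above.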
