Skip to content

Commit 559abfc

Browse files
authored
Merge pull request #87 from Yuval-Roth/main
feat: big performance improvements, new features and refactorings
2 parents e053e0d + a13ec49 commit 559abfc

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

63 files changed

+2925
-1476
lines changed

.gitignore

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -93,7 +93,7 @@ celerybeat.pid
9393

9494
# Environments
9595
.env
96-
.venv
96+
.venv*
9797
env/
9898
venv/
9999
ENV/
@@ -142,7 +142,7 @@ dmypy.json
142142
**/multicache_serving.py
143143
**/modelcache_serving.py
144144

145-
**/model/
145+
**/model/text2vec-base-chinese
146146

147147
/data/milvus/db
148148
/data/mysql/db

data/mysql/init/init.sql

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@ CREATE DATABASE IF NOT EXISTS `modelcache`;
33
USE `modelcache`;
44

55
CREATE TABLE IF NOT EXISTS `modelcache_llm_answer` (
6-
`id` bigint(20) unsigned NOT NULL AUTO_INCREMENT comment '主键',
6+
`id` CHAR(36) comment '主键',
77
`gmt_create` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP comment '创建时间',
88
`gmt_modified` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP comment '修改时间',
99
`question` text NOT NULL comment 'question',

docker-compose.yaml

Lines changed: 28 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
version: 'Beta'
1+
name: "modelcache"
22
services:
33
mysql:
44
image: mysql:8.0.23
@@ -14,12 +14,12 @@ services:
1414
- ./data/mysql/db:/var/lib/mysql
1515
- ./data/mysql/my.cnf:/etc/mysql/conf.d/my.cnf
1616
- ./data/mysql/init:/docker-entrypoint-initdb.d
17-
restart: on-failure
17+
# restart: on-failure
1818
networks:
1919
- modelcache
2020

2121
milvus:
22-
image: milvusdb/milvus:v2.5.0-beta
22+
image: milvusdb/milvus:v2.5.10
2323
container_name: milvus
2424
security_opt:
2525
- seccomp:unconfined
@@ -36,35 +36,35 @@ services:
3636
- 19530:19530
3737
- 9091:9091
3838
- 2379:2379
39-
healthcheck:
40-
test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"]
41-
interval: 30s
42-
start_period: 90s
43-
timeout: 20s
44-
retries: 3
39+
# healthcheck:
40+
# test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"]
41+
# interval: 30s
42+
# start_period: 90s
43+
# timeout: 20s
44+
# retries: 3
4545
networks:
4646
- modelcache
47-
restart: on-failure
47+
# restart: on-failure
4848
command: milvus run standalone
4949

50-
modelcache:
51-
build:
52-
context: .
53-
dockerfile: Dockerfile
54-
container_name: modelcache
55-
image: modelcache:0.1.0
56-
ports:
57-
- 5000:5000
58-
volumes:
59-
- ./model:/home/user/model
60-
- ./modelcache:/home/user/modelcache
61-
- ./modelcache_mm:/home/user/modelcache_mm
62-
- ./fastapi4modelcache.py:/home/user/fastapi4modelcache.py
63-
networks:
64-
- modelcache
65-
restart: on-failure
66-
command: sh -c "uvicorn fastapi4modelcache:app --reload --reload-dir /home/user --port=5000 --host=0.0.0.0"
50+
# modelcache:
51+
# build:
52+
# context: .
53+
# dockerfile: Dockerfile
54+
# container_name: modelcache
55+
# image: modelcache:0.1.0
56+
# ports:
57+
# - 5000:5000
58+
# volumes:
59+
# - ./model:/home/user/model
60+
# - ./modelcache:/home/user/modelcache
61+
# - ./modelcache_mm:/home/user/modelcache_mm
62+
# - ./fastapi4modelcache.py:/home/user/fastapi4modelcache.py
63+
# networks:
64+
# - modelcache
65+
# restart: on-failure
66+
# command: sh -c "uvicorn fastapi4modelcache:app --reload --reload-dir /home/user --port=5000 --host=0.0.0.0"
6767

6868
networks:
6969
modelcache:
70-
external: true
70+
driver: bridge

fastapi4modelcache.py

Lines changed: 33 additions & 180 deletions
Original file line number | Diff line number | Diff line change
@@ -1,193 +1,46 @@
11
# -*- coding: utf-8 -*-
2-
import time
3-
import uvicorn
42
import asyncio
5-
import logging
6-
import configparser
3+
from contextlib import asynccontextmanager
4+
import uvicorn
75
import json
8-
from fastapi import FastAPI, Request, HTTPException
9-
from pydantic import BaseModel
10-
from concurrent.futures import ThreadPoolExecutor
11-
from starlette.responses import PlainTextResponse
12-
import functools
13-
14-
from modelcache import cache
15-
from modelcache.adapter import adapter
16-
from modelcache.manager import CacheBase, VectorBase, get_data_manager
17-
from modelcache.similarity_evaluation.distance import SearchDistanceEvaluation
18-
from modelcache.processor.pre import query_multi_splicing
19-
from modelcache.processor.pre import insert_multi_splicing
20-
from modelcache.utils.model_filter import model_blacklist_filter
21-
from modelcache.embedding import Data2VecAudio
22-
23-
#创建一个FastAPI实例
24-
app = FastAPI()
25-
26-
class RequestData(BaseModel):
27-
type: str
28-
scope: dict = None
29-
query: str = None
30-
chat_info: dict = None
31-
remove_type: str = None
32-
id_list: list = []
33-
34-
data2vec = Data2VecAudio()
35-
mysql_config = configparser.ConfigParser()
36-
mysql_config.read('modelcache/config/mysql_config.ini')
37-
38-
milvus_config = configparser.ConfigParser()
39-
milvus_config.read('modelcache/config/milvus_config.ini')
40-
41-
# redis_config = configparser.ConfigParser()
42-
# redis_config.read('modelcache/config/redis_config.ini')
43-
44-
# 初始化datamanager
45-
data_manager = get_data_manager(
46-
CacheBase("mysql", config=mysql_config),
47-
VectorBase("milvus", dimension=data2vec.dimension, milvus_config=milvus_config)
48-
)
49-
50-
# # 使用redis初始化datamanager
51-
# data_manager = get_data_manager(
52-
# CacheBase("mysql", config=mysql_config),
53-
# VectorBase("redis", dimension=data2vec.dimension, redis_config=redis_config)
54-
# )
55-
56-
cache.init(
57-
embedding_func=data2vec.to_embeddings,
58-
data_manager=data_manager,
59-
similarity_evaluation=SearchDistanceEvaluation(),
60-
query_pre_embedding_func=query_multi_splicing,
61-
insert_pre_embedding_func=insert_multi_splicing,
62-
)
63-
64-
executor = ThreadPoolExecutor(max_workers=6)
65-
66-
# 异步保存查询信息
67-
async def save_query_info(result, model, query, delta_time_log):
68-
loop = asyncio.get_running_loop()
69-
func = functools.partial(cache.data_manager.save_query_resp, result, model=model, query=json.dumps(query, ensure_ascii=False), delta_time=delta_time_log)
70-
await loop.run_in_executor(None, func)
71-
72-
73-
74-
@app.get("/welcome", response_class=PlainTextResponse)
6+
from fastapi.responses import JSONResponse
7+
from fastapi import FastAPI, Request
8+
from modelcache.cache import Cache
9+
from modelcache.embedding import EmbeddingModel
10+
11+
@asynccontextmanager
12+
async def lifespan(app: FastAPI):
13+
global cache
14+
cache, _ = await Cache.init(
15+
sql_storage="mysql",
16+
vector_storage="milvus",
17+
embedding_model=EmbeddingModel.HUGGINGFACE_ALL_MPNET_BASE_V2,
18+
embedding_workers_num=2
19+
)
20+
yield
21+
22+
app = FastAPI(lifespan=lifespan)
23+
cache: Cache = None
24+
25+
@app.get("/welcome")
7526
async def first_fastapi():
7627
return "hello, modelcache!"
7728

7829
@app.post("/modelcache")
7930
async def user_backend(request: Request):
80-
try:
81-
raw_body = await request.body()
82-
# 解析字符串为JSON对象
83-
if isinstance(raw_body, bytes):
84-
raw_body = raw_body.decode("utf-8")
85-
if isinstance(raw_body, str):
86-
try:
87-
# 尝试将字符串解析为JSON对象
88-
request_data = json.loads(raw_body)
89-
except json.JSONDecodeError as e:
90-
# 如果无法解析,返回格式错误
91-
result = {"errorCode": 101, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '',
92-
"answer": ''}
93-
asyncio.create_task(save_query_info(result, model='', query='', delta_time_log=0))
94-
raise HTTPException(status_code=101, detail="Invalid JSON format")
95-
else:
96-
request_data = raw_body
97-
98-
# 确保request_data是字典对象
99-
if isinstance(request_data, str):
100-
try:
101-
request_data = json.loads(request_data)
102-
except json.JSONDecodeError:
103-
raise HTTPException(status_code=101, detail="Invalid JSON format")
104-
105-
request_type = request_data.get('type')
106-
model = None
107-
if 'scope' in request_data:
108-
model = request_data['scope'].get('model', '').replace('-', '_').replace('.', '_')
109-
query = request_data.get('query')
110-
chat_info = request_data.get('chat_info')
11131

112-
if not request_type or request_type not in ['query', 'insert', 'remove', 'register']:
113-
result = {"errorCode": 102,
114-
"errorDesc": "type exception, should one of ['query', 'insert', 'remove', 'register']",
115-
"cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''}
116-
asyncio.create_task(save_query_info(result, model=model, query='', delta_time_log=0))
117-
raise HTTPException(status_code=102, detail="Type exception, should be one of ['query', 'insert', 'remove', 'register']")
32+
try:
33+
request_data = await request.json()
34+
except Exception:
35+
result = {"errorCode": 400, "errorDesc": "bad request", "cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''}
36+
return JSONResponse(status_code=400, content=result)
11837

38+
try:
39+
return await cache.handle_request(request_data)
11940
except Exception as e:
120-
request_data = raw_body if 'raw_body' in locals() else None
121-
result = {
122-
"errorCode": 103,
123-
"errorDesc": str(e),
124-
"cacheHit": False,
125-
"delta_time": 0,
126-
"hit_query": '',
127-
"answer": '',
128-
"para_dict": request_data
129-
}
130-
return result
131-
132-
133-
# model filter
134-
filter_resp = model_blacklist_filter(model, request_type)
135-
if isinstance(filter_resp, dict):
136-
return filter_resp
137-
138-
if request_type == 'query':
139-
try:
140-
start_time = time.time()
141-
response = adapter.ChatCompletion.create_query(scope={"model": model}, query=query)
142-
delta_time = f"{round(time.time() - start_time, 2)}s"
143-
144-
if response is None:
145-
result = {"errorCode": 0, "errorDesc": '', "cacheHit": False, "delta_time": delta_time, "hit_query": '', "answer": ''}
146-
elif response in ['adapt_query_exception']:
147-
result = {"errorCode": 201, "errorDesc": response, "cacheHit": False, "delta_time": delta_time,
148-
"hit_query": '', "answer": ''}
149-
else:
150-
answer = response['data']
151-
hit_query = response['hitQuery']
152-
result = {"errorCode": 0, "errorDesc": '', "cacheHit": True, "delta_time": delta_time, "hit_query": hit_query, "answer": answer}
153-
154-
delta_time_log = round(time.time() - start_time, 2)
155-
asyncio.create_task(save_query_info(result, model, query, delta_time_log))
156-
return result
157-
except Exception as e:
158-
result = {"errorCode": 202, "errorDesc": str(e), "cacheHit": False, "delta_time": 0,
159-
"hit_query": '', "answer": ''}
160-
logging.info(f'result: {str(result)}')
161-
return result
162-
163-
if request_type == 'insert':
164-
try:
165-
response = adapter.ChatCompletion.create_insert(model=model, chat_info=chat_info)
166-
if response == 'success':
167-
return {"errorCode": 0, "errorDesc": "", "writeStatus": "success"}
168-
else:
169-
return {"errorCode": 301, "errorDesc": response, "writeStatus": "exception"}
170-
except Exception as e:
171-
return {"errorCode": 303, "errorDesc": str(e), "writeStatus": "exception"}
172-
173-
if request_type == 'remove':
174-
response = adapter.ChatCompletion.create_remove(model=model, remove_type=request_data.get("remove_type"), id_list=request_data.get("id_list"))
175-
if not isinstance(response, dict):
176-
return {"errorCode": 401, "errorDesc": "", "response": response, "removeStatus": "exception"}
177-
178-
state = response.get('status')
179-
if state == 'success':
180-
return {"errorCode": 0, "errorDesc": "", "response": response, "writeStatus": "success"}
181-
else:
182-
return {"errorCode": 402, "errorDesc": "", "response": response, "writeStatus": "exception"}
183-
184-
if request_type == 'register':
185-
response = adapter.ChatCompletion.create_register(model=model)
186-
if response in ['create_success', 'already_exists']:
187-
return {"errorCode": 0, "errorDesc": "", "response": response, "writeStatus": "success"}
188-
else:
189-
return {"errorCode": 502, "errorDesc": "", "response": response, "writeStatus": "exception"}
41+
result = {"errorCode": 500, "errorDesc": str(e), "cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''}
42+
cache.save_query_resp(result, model='', query='', delta_time=0)
43+
return JSONResponse(status_code=500, content=result)
19044

191-
# TODO: 可以修改为在命令行中使用`uvicorn your_module_name:app --host 0.0.0.0 --port 5000 --reload`的命令启动
19245
if __name__ == '__main__':
193-
uvicorn.run(app, host='0.0.0.0', port=5000)
46+
uvicorn.run(app, host='0.0.0.0', port=5000, loop="asyncio", http="httptools")

0 commit comments

Comments
 (0)