1
1
# -*- coding: utf-8 -*-
2
- import time
3
2
import uvicorn
4
- import asyncio
5
- import logging
6
- import configparser
7
3
import json
8
4
from fastapi import FastAPI , Request , HTTPException
9
- from pydantic import BaseModel
10
- from concurrent .futures import ThreadPoolExecutor
11
- from starlette .responses import PlainTextResponse
12
- import functools
13
-
14
- from modelcache import cache
15
- from modelcache .adapter import adapter
16
- from modelcache .manager import CacheBase , VectorBase , get_data_manager
17
- from modelcache .similarity_evaluation .distance import SearchDistanceEvaluation
18
- from modelcache .processor .pre import query_multi_splicing
19
- from modelcache .processor .pre import insert_multi_splicing
20
- from modelcache .utils .model_filter import model_blacklist_filter
21
- from modelcache .embedding import Data2VecAudio
5
+ from modelcache .cache import Cache
22
6
23
7
# Create the FastAPI application instance
24
8
app = FastAPI ()
25
9
26
- class RequestData (BaseModel ):
27
- type : str
28
- scope : dict = None
29
- query : str = None
30
- chat_info : dict = None
31
- remove_type : str = None
32
- id_list : list = []
33
-
34
- data2vec = Data2VecAudio ()
35
- mysql_config = configparser .ConfigParser ()
36
- mysql_config .read ('modelcache/config/mysql_config.ini' )
37
-
38
- milvus_config = configparser .ConfigParser ()
39
- milvus_config .read ('modelcache/config/milvus_config.ini' )
40
-
41
- # redis_config = configparser.ConfigParser()
42
- # redis_config.read('modelcache/config/redis_config.ini')
43
-
44
- # 初始化datamanager
45
- data_manager = get_data_manager (
46
- CacheBase ("mysql" , config = mysql_config ),
47
- VectorBase ("milvus" , dimension = data2vec .dimension , milvus_config = milvus_config )
48
- )
49
-
50
- # # 使用redis初始化datamanager
51
- # data_manager = get_data_manager(
52
- # CacheBase("mysql", config=mysql_config),
53
- # VectorBase("redis", dimension=data2vec.dimension, redis_config=redis_config)
54
- # )
10
+ cache = Cache .init ("mysql" , "milvus" )
55
11
56
- cache .init (
57
- embedding_func = data2vec .to_embeddings ,
58
- data_manager = data_manager ,
59
- similarity_evaluation = SearchDistanceEvaluation (),
60
- query_pre_embedding_func = query_multi_splicing ,
61
- insert_pre_embedding_func = insert_multi_splicing ,
62
- )
63
-
64
- executor = ThreadPoolExecutor (max_workers = 6 )
65
-
66
- # 异步保存查询信息
67
- async def save_query_info (result , model , query , delta_time_log ):
68
- loop = asyncio .get_running_loop ()
69
- func = functools .partial (cache .data_manager .save_query_resp , result , model = model , query = json .dumps (query , ensure_ascii = False ), delta_time = delta_time_log )
70
- await loop .run_in_executor (None , func )
71
-
72
-
73
-
74
- @app .get ("/welcome" , response_class = PlainTextResponse )
12
+ @app .get ("/welcome" )
75
13
async def first_fastapi ():
76
14
return "hello, modelcache!"
77
15
16
+
78
17
@app .post ("/modelcache" )
79
18
async def user_backend (request : Request ):
80
19
try :
@@ -90,7 +29,7 @@ async def user_backend(request: Request):
90
29
# If the request body cannot be parsed, report a format error
91
30
result = {"errorCode" : 101 , "errorDesc" : str (e ), "cacheHit" : False , "delta_time" : 0 , "hit_query" : '' ,
92
31
"answer" : '' }
93
- asyncio . create_task ( save_query_info (result , model = '' , query = '' , delta_time_log = 0 ) )
32
+ cache . save_query_info (result , model = '' , query = '' , delta_time_log = 0 )
94
33
raise HTTPException (status_code = 101 , detail = "Invalid JSON format" )
95
34
else :
96
35
request_data = raw_body
@@ -102,19 +41,7 @@ async def user_backend(request: Request):
102
41
except json .JSONDecodeError :
103
42
raise HTTPException (status_code = 101 , detail = "Invalid JSON format" )
104
43
105
- request_type = request_data .get ('type' )
106
- model = None
107
- if 'scope' in request_data :
108
- model = request_data ['scope' ].get ('model' , '' ).replace ('-' , '_' ).replace ('.' , '_' )
109
- query = request_data .get ('query' )
110
- chat_info = request_data .get ('chat_info' )
111
-
112
- if not request_type or request_type not in ['query' , 'insert' , 'remove' , 'register' ]:
113
- result = {"errorCode" : 102 ,
114
- "errorDesc" : "type exception, should one of ['query', 'insert', 'remove', 'register']" ,
115
- "cacheHit" : False , "delta_time" : 0 , "hit_query" : '' , "answer" : '' }
116
- asyncio .create_task (save_query_info (result , model = model , query = '' , delta_time_log = 0 ))
117
- raise HTTPException (status_code = 102 , detail = "Type exception, should be one of ['query', 'insert', 'remove', 'register']" )
44
+ return cache .handle_request (request_data )
118
45
119
46
except Exception as e :
120
47
request_data = raw_body if 'raw_body' in locals () else None
@@ -129,65 +56,6 @@ async def user_backend(request: Request):
129
56
}
130
57
return result
131
58
132
-
133
- # model filter
134
- filter_resp = model_blacklist_filter (model , request_type )
135
- if isinstance (filter_resp , dict ):
136
- return filter_resp
137
-
138
- if request_type == 'query' :
139
- try :
140
- start_time = time .time ()
141
- response = adapter .ChatCompletion .create_query (scope = {"model" : model }, query = query )
142
- delta_time = f"{ round (time .time () - start_time , 2 )} s"
143
-
144
- if response is None :
145
- result = {"errorCode" : 0 , "errorDesc" : '' , "cacheHit" : False , "delta_time" : delta_time , "hit_query" : '' , "answer" : '' }
146
- elif response in ['adapt_query_exception' ]:
147
- result = {"errorCode" : 201 , "errorDesc" : response , "cacheHit" : False , "delta_time" : delta_time ,
148
- "hit_query" : '' , "answer" : '' }
149
- else :
150
- answer = response ['data' ]
151
- hit_query = response ['hitQuery' ]
152
- result = {"errorCode" : 0 , "errorDesc" : '' , "cacheHit" : True , "delta_time" : delta_time , "hit_query" : hit_query , "answer" : answer }
153
-
154
- delta_time_log = round (time .time () - start_time , 2 )
155
- asyncio .create_task (save_query_info (result , model , query , delta_time_log ))
156
- return result
157
- except Exception as e :
158
- result = {"errorCode" : 202 , "errorDesc" : str (e ), "cacheHit" : False , "delta_time" : 0 ,
159
- "hit_query" : '' , "answer" : '' }
160
- logging .info (f'result: { str (result )} ' )
161
- return result
162
-
163
- if request_type == 'insert' :
164
- try :
165
- response = adapter .ChatCompletion .create_insert (model = model , chat_info = chat_info )
166
- if response == 'success' :
167
- return {"errorCode" : 0 , "errorDesc" : "" , "writeStatus" : "success" }
168
- else :
169
- return {"errorCode" : 301 , "errorDesc" : response , "writeStatus" : "exception" }
170
- except Exception as e :
171
- return {"errorCode" : 303 , "errorDesc" : str (e ), "writeStatus" : "exception" }
172
-
173
- if request_type == 'remove' :
174
- response = adapter .ChatCompletion .create_remove (model = model , remove_type = request_data .get ("remove_type" ), id_list = request_data .get ("id_list" ))
175
- if not isinstance (response , dict ):
176
- return {"errorCode" : 401 , "errorDesc" : "" , "response" : response , "removeStatus" : "exception" }
177
-
178
- state = response .get ('status' )
179
- if state == 'success' :
180
- return {"errorCode" : 0 , "errorDesc" : "" , "response" : response , "writeStatus" : "success" }
181
- else :
182
- return {"errorCode" : 402 , "errorDesc" : "" , "response" : response , "writeStatus" : "exception" }
183
-
184
- if request_type == 'register' :
185
- response = adapter .ChatCompletion .create_register (model = model )
186
- if response in ['create_success' , 'already_exists' ]:
187
- return {"errorCode" : 0 , "errorDesc" : "" , "response" : response , "writeStatus" : "success" }
188
- else :
189
- return {"errorCode" : 502 , "errorDesc" : "" , "response" : response , "writeStatus" : "exception" }
190
-
191
59
# TODO: this could instead be started from the command line with `uvicorn your_module_name:app --host 0.0.0.0 --port 5000 --reload`
192
60
if __name__ == '__main__' :
193
61
uvicorn .run (app , host = '0.0.0.0' , port = 5000 )
0 commit comments