Skip to content

Commit 4a20eaf

Browse files
committed
New feature: websocket4modelcache
A WebSocket-based API for the ModelCache system. The goal of this API is to avoid the overhead of creating a new HTTP connection for every request and to allow faster querying.
1 parent f8d4e72 commit 4a20eaf

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@ snowflake-id==1.0.2
2020
flagembedding==1.3.4
2121
cryptography==45.0.2
2222
sentence-transformers==4.1.0
23+
websockets==15.0.1

websocket4modelcache.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# -*- coding: utf-8 -*-
2+
import asyncio
3+
import websockets
4+
import json
5+
from modelcache.cache import Cache
6+
7+
# Initialize the cache
8+
# Module-level cache instance, created once at import time and used by the
# request-processing coroutine for both lookups and error logging.
# NOTE(review): the two arguments presumably select the scalar and vector
# storage backends -- confirm against Cache.init.
cache = Cache.init("mysql", "milvus")
9+
10+
11+
async def handle_client(websocket):
    """Serve one WebSocket connection: validate each message and dispatch it.

    Each incoming message must be a JSON object with a truthy ``requestId``
    and a truthy ``payload``. Valid requests are scheduled as separate tasks
    running ``process_and_respond``; anything else is answered with a
    generic 400 error and the loop moves on to the next message.

    Args:
        websocket: Connection object for this client. It is iterated for
            incoming messages and its ``send`` coroutine is used for replies.
    """
    bad_request = json.dumps({"errorCode": 400, "errorDesc": "bad request"})

    # The event loop only keeps weak references to tasks, so hold strong
    # references here until each task finishes; otherwise a task could be
    # garbage-collected before it has sent its reply.
    in_flight = set()

    try:
        async for message in websocket:
            # Parse JSON. ValueError covers json.JSONDecodeError as well as
            # the UnicodeDecodeError raised for undecodable binary frames.
            try:
                param_dict = json.loads(message)
            except ValueError:
                await websocket.send(bad_request)
                continue

            # Valid JSON that is not an object (list, number, string, null)
            # has no .get(); reject it instead of crashing the connection.
            if not isinstance(param_dict, dict):
                await websocket.send(bad_request)
                continue

            request_id = param_dict.get("requestId")
            request_payload = param_dict.get("payload")
            if not request_id or not request_payload:
                await websocket.send(bad_request)
                continue

            task = asyncio.create_task(
                process_and_respond(websocket, request_id, request_payload)
            )
            in_flight.add(task)
            task.add_done_callback(in_flight.discard)
    finally:
        # Keep the references alive until outstanding requests complete.
        # asyncio.wait does not consume task exceptions, so failures are
        # still reported by the event loop exactly as before.
        if in_flight:
            await asyncio.wait(set(in_flight))
26+
27+
28+
async def process_and_respond(websocket, request_id, request_payload):
    """Run one cache request and send back a reply tagged with its request id.

    On success the reply is ``{"requestId": ..., "result": ...}``. If the
    cache call (or serialising its result) raises, an error record with
    ``errorCode`` 102 is passed to ``cache.save_query_resp`` and sent to the
    client with the ``requestId`` added, so the client can tell which of its
    outstanding requests failed.

    Args:
        websocket: Connection to reply on; only its ``send`` coroutine is used.
        request_id: Identifier supplied by the client, echoed in the reply.
        request_payload: Request body passed unchanged to
            ``cache.handle_request``.

    Raises:
        Whatever ``websocket.send`` raises (for example when the client has
        already disconnected). A transport failure is deliberately not
        treated as a cache error and is not retried.
    """
    try:
        result = cache.handle_request(request_payload)
        # Serialise inside the try so an unserialisable result is reported
        # to the client as an error rather than killing the task silently.
        reply = json.dumps({"requestId": request_id, "result": result})
    except Exception as e:
        error_result = {
            "errorCode": 102,
            "errorDesc": str(e),
            "cacheHit": False,
            "delta_time": 0,
            "hit_query": '',
            "answer": '',
        }
        # The stored record keeps its original shape; requestId is added
        # only to the copy that goes over the wire.
        cache.save_query_resp(error_result, model='', query='', delta_time=0)
        reply = json.dumps({"requestId": request_id, **error_result})

    # Sent outside the try block: a failed send must not be logged as a
    # cache error, nor trigger a second send on the same broken connection.
    await websocket.send(reply)
37+
38+
39+
async def main(host: str = "0.0.0.0", port: int = 5000) -> None:
    """Start the WebSocket server and keep it running until cancelled.

    Args:
        host: Address to bind the listening socket to. The default,
            "0.0.0.0", matches the previously hard-coded value.
        port: TCP port to listen on. Defaults to 5000, the previously
            hard-coded value.
    """
    print(f"WebSocket server starting on ws://{host}:{port}")
    async with websockets.serve(handle_client, host, port):
        # A future that is never resolved: awaiting it keeps the server
        # context open until the surrounding task is cancelled.
        await asyncio.Future()  # Run forever
43+
44+
# Start the server only when this file is executed as a script, so that
# importing the module elsewhere does not block on the event loop.
if __name__ == "__main__":
    asyncio.run(main())

0 commit comments

Comments
 (0)