Skip to content

Commit cdba51f

Browse files
authored
Merge pull request #59 from powerli2002/add-es
Add feature: support for Elasticsearch as a scalar database
2 parents 5cb5061 + 7901e49 commit cdba51f

File tree

7 files changed

+413
-1
lines changed

7 files changed

+413
-1
lines changed

flask4modelcache.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,16 @@ def response_hitquery(cache_resp):
3838
# Vector store (Milvus) connection settings.
milvus_config = configparser.ConfigParser()
milvus_config.read('modelcache/config/milvus_config.ini')

# Scalar store (Elasticsearch) endpoint/credentials for answers and query logs.
es_config = configparser.ConfigParser()
es_config.read('modelcache/config/elasticsearch_config.ini')

# Alternative backends, kept for easy switching but disabled by default:
# redis_config = configparser.ConfigParser()
# redis_config.read('modelcache/config/redis_config.ini')

# chromadb_config = configparser.ConfigParser()
# chromadb_config.read('modelcache/config/chromadb_config.ini')

# Wire the cache: scalar data in Elasticsearch, embeddings in Milvus.
data_manager = get_data_manager(CacheBase("elasticsearch", config=es_config),
                                VectorBase("milvus", dimension=data2vec.dimension, milvus_config=milvus_config))
4952

5053

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[elasticsearch]
2+
host = ''
3+
port = ''
4+
user = ''
5+
password = ''

modelcache/manager/scalar_data/manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ def get(name, **kwargs):
2727
from modelcache.manager.scalar_data.sql_storage_sqlite import SQLStorage
2828
sql_url = kwargs.get("sql_url", SQL_URL[name])
2929
cache_base = SQLStorage(db_type=name, url=sql_url)
30+
elif name == 'elasticsearch':
31+
from modelcache.manager.scalar_data.sql_storage_es import SQLStorage
32+
config = kwargs.get("config")
33+
cache_base = SQLStorage(db_type=name, config=config)
3034
else:
3135
raise NotFoundError("cache store", name)
3236
return cache_base
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
# -*- coding: utf-8 -*-
import json
import time
from typing import List, Optional

from elasticsearch import Elasticsearch, helpers
from snowflake import SnowflakeGenerator

from modelcache.manager.scalar_data.base import CacheStorage, CacheData
8+
9+
10+
class SQLStorage(CacheStorage):
    """Scalar storage for ModelCache backed by Elasticsearch.

    The name ``SQLStorage`` is kept only for interface parity with the other
    scalar backends dispatched in ``modelcache.manager.scalar_data.manager``.
    Data lives in two indices rather than SQL tables:

      * ``modelcache_llm_answer`` -- cached question/answer documents
      * ``modelcache_query_log``  -- one document per answered query
    """

    def __init__(
            self,
            db_type: str = "elasticsearch",
            config=None
    ):
        """Connect to Elasticsearch and ensure the indices exist.

        :param db_type: backend name; always ``"elasticsearch"`` here.
        :param config: ``configparser.ConfigParser`` loaded from
            ``modelcache/config/elasticsearch_config.ini`` (section
            ``[elasticsearch]`` with host/port/user/password keys).
        """
        self.host = config.get('elasticsearch', 'host')
        self.port = int(config.get('elasticsearch', 'port'))
        self.client = Elasticsearch(
            hosts=[{"host": self.host, "port": self.port}],
            timeout=30,
            # Bug fix: credentials were hard-coded as ('esuser', 'password')
            # even though the ini file supplies 'user' and 'password' keys,
            # so configured credentials were silently ignored.
            http_auth=(
                config.get('elasticsearch', 'user'),
                config.get('elasticsearch', 'password'),
            ),
        )

        self.log_index = "modelcache_query_log"
        self.ans_index = "modelcache_llm_answer"
        self.create()
        # Machine id for the snowflake id generator.  Distributed workers
        # sharing one cluster must each be configured with a distinct id
        # so generated document ids stay unique.
        self.instance_id = 1
        self.snowflake_id = SnowflakeGenerator(self.instance_id)

    def create(self):
        """Create the answer and log indices (with mappings) if missing."""
        answer_index_body = {
            "mappings": {
                "properties": {
                    "gmt_create": {"type": "date", "format": "strict_date_optional_time||epoch_millis"},
                    "gmt_modified": {"type": "date", "format": "strict_date_optional_time||epoch_millis"},
                    "question": {"type": "text"},
                    "answer": {"type": "text"},
                    "answer_type": {"type": "integer"},
                    "hit_count": {"type": "integer"},
                    "model": {"type": "keyword"},
                    "embedding_data": {"type": "binary"},
                    "is_deleted": {"type": "integer"},
                }
            }
        }

        log_index_body = {
            "mappings": {
                "properties": {
                    "gmt_create": {"type": "date", "format": "strict_date_optional_time||epoch_millis"},
                    "gmt_modified": {"type": "date", "format": "strict_date_optional_time||epoch_millis"},
                    "error_code": {"type": "integer"},
                    "error_desc": {"type": "text"},
                    "cache_hit": {"type": "keyword"},
                    "delta_time": {"type": "float"},
                    "model": {"type": "keyword"},
                    "query": {"type": "text"},
                    "hit_query": {"type": "text"},
                    "answer": {"type": "text"}
                }
            }
        }

        if not self.client.indices.exists(index=self.ans_index):
            self.client.indices.create(index=self.ans_index, body=answer_index_body)

        if not self.client.indices.exists(index=self.log_index):
            self.client.indices.create(index=self.log_index, body=log_index_body)

    def _insert(self, data: List) -> Optional[int]:
        """Index one answer document; return its snowflake id, or None on failure.

        *data* is positional: ``[answer, question, embedding, model]``.
        (Original annotation ``str or None`` evaluated to plain ``str`` and
        the method actually returns an int.)
        """
        doc = {
            "answer": data[0],
            "question": data[1],
            # numpy arrays are not JSON-serializable; store a plain list.
            "embedding_data": data[2].tolist() if hasattr(data[2], "tolist") else data[2],
            "model": data[3],
            "answer_type": 0,
            "hit_count": 0,
            "is_deleted": 0
        }

        try:
            response = self.client.index(
                index=self.ans_index,
                id=next(self.snowflake_id),
                body=doc,
            )
            return int(response['_id'])
        except Exception as e:
            # Best effort: report and let the caller skip this row.
            print(f"Failed to insert document: {e}")
            return None

    def batch_insert(self, all_data: List[List]) -> List[int]:
        """Insert many answer rows; return the ids of the successful ones."""
        successful_ids = []
        for data in all_data:
            _id = self._insert(data)
            if _id is not None:
                successful_ids.append(_id)
        # Refresh once after the whole batch so documents become searchable.
        self.client.indices.refresh(index=self.ans_index)

        return successful_ids

    def insert_query_resp(self, query_resp, **kwargs):
        """Log one query/response event into the query-log index.

        ``hit_query`` may arrive as a list; serialize it to JSON so the
        text field always receives a string.
        """
        doc = {
            "error_code": query_resp.get('errorCode'),
            "error_desc": query_resp.get('errorDesc'),
            "cache_hit": query_resp.get('cacheHit'),
            "model": kwargs.get('model'),
            "query": kwargs.get('query'),
            "delta_time": kwargs.get('delta_time'),
            "hit_query": json.dumps(query_resp.get('hit_query'), ensure_ascii=False) if isinstance(
                query_resp.get('hit_query'), list) else query_resp.get('hit_query'),
            "answer": query_resp.get('answer'),
            "hit_count": 0,
            "is_deleted": 0

        }
        self.client.index(index=self.log_index, body=doc)

    def get_data_by_id(self, key: int):
        """Fetch [question, answer, embedding_data, model] for a cached row.

        Returns None (implicitly) if the document is missing or the request
        fails; errors are printed, matching the backend's best-effort style.
        """
        try:
            response = self.client.get(index=self.ans_index, id=key, _source=['question', 'answer', 'embedding_data', 'model'])
            source = response["_source"]
            result = [
                source.get('question'),
                source.get('answer'),
                source.get('embedding_data'),
                source.get('model')
            ]
            return result
        except Exception as e:
            print(e)

    def update_hit_count_by_id(self, primary_id: int):
        """Atomically increment hit_count on one answer document."""
        self.client.update(
            index=self.ans_index,
            id=primary_id,
            body={"script": {"source": "ctx._source.hit_count += 1"}}
        )

    def get_ids(self, deleted=True):
        """Return ids of rows whose is_deleted flag matches *deleted*."""
        query = {
            "query": {
                "term": {"is_deleted": 1 if deleted else 0}
            },
            # Bug fix: ES `search` returns only 10 hits by default, so
            # cleanup jobs saw at most 10 candidates.  Raise the cap to the
            # default max_result_window.
            "size": 10000,
        }
        response = self.client.search(index=self.ans_index, body=query)
        return [hit["_id"] for hit in response["hits"]["hits"]]

    def mark_deleted(self, keys):
        """Soft-delete the given ids; return the number of updated docs."""
        actions = [
            {
                "_op_type": "update",
                "_index": self.ans_index,
                "_id": key,
                "doc": {"is_deleted": 1}
            }
            for key in keys
        ]
        responses = helpers.bulk(self.client, actions)
        # helpers.bulk returns (success_count, errors).
        return responses[0]

    def model_deleted(self, model_name):
        """Hard-delete every answer row for *model_name*; return the count."""
        query = {
            "query": {
                "term": {"model": model_name}
            }
        }

        response = self.client.delete_by_query(index=self.ans_index, body=query)
        return response["deleted"]

    def clear_deleted_data(self):
        """Physically remove soft-deleted rows; return the count removed."""
        query = {
            "query": {
                "term": {"is_deleted": 1}
            }
        }
        response = self.client.delete_by_query(index=self.ans_index, body=query)
        return response["deleted"]

    def count(self, state: int = 0, is_all: bool = False):
        """Count answer rows; all rows when *is_all*, else by is_deleted state."""
        query = {"query": {"match_all": {}}} if is_all else {"query": {"term": {"is_deleted": state}}}
        response = self.client.count(index=self.ans_index, body=query)
        return response["count"]

    def close(self):
        """Release the underlying Elasticsearch connection pool."""
        self.client.close()

    def count_answers(self):
        """Count all documents in the answer index."""
        query = {"query": {"match_all": {}}}
        response = self.client.count(index=self.ans_index, body=query)
        return response["count"]
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[elasticsearch]
2+
host = ''
3+
port = ''
4+
user = ''
5+
password = ''

0 commit comments

Comments
 (0)