Skip to content

Commit 682179b

Browse files
committed
Add feature : add timm support
1 parent 2d692ae commit 682179b

File tree

3 files changed

+207
-4
lines changed

3 files changed

+207
-4
lines changed

modelcache/manager/scalar_data/sql_storage_es.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,11 @@ def create(self):
6262
}
6363
}
6464

65-
if not self.client.indices.exists(index="modelcache_llm_answer"):
66-
self.client.indices.create(index="modelcache_llm_answer", body=answer_index_body)
65+
if not self.client.indices.exists(index=self.ans_index):
66+
self.client.indices.create(index=self.ans_index, body=answer_index_body)
6767

68-
if not self.client.indices.exists(index="modelcache_query_log"):
69-
self.client.indices.create(index="modelcache_query_log", body=log_index_body)
68+
if not self.client.indices.exists(index=self.log_index):
69+
self.client.indices.create(index=self.log_index, body=log_index_body)
7070

7171
def _insert(self, data: List) -> str or None:
7272
doc = {
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[elasticsearch]
2+
host = ''
3+
port = ''
4+
user = ''
5+
password = ''
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
# -*- coding: utf-8 -*-
2+
import json
3+
from typing import List
4+
from elasticsearch import Elasticsearch, helpers
5+
from modelcache.manager.scalar_data.base import CacheStorage, CacheData
6+
import time
7+
from snowflake import SnowflakeGenerator
8+
9+
10+
class SQLStorage(CacheStorage):
    """Elasticsearch-backed scalar storage for cached answers and query logs.

    Answers are stored in the ``open_cache_mm_answer`` index and query logs
    in the ``open_cache_mm_query_log`` index. Both indices are created on
    construction if they do not exist yet.
    """

    # Credentials that were previously hard-coded; kept as the fallback so
    # deployments whose config leaves ``user``/``password`` blank still work.
    _DEFAULT_USER = 'esuser'
    _DEFAULT_PASSWORD = 'password'
    # Values that count as "not configured" (the shipped template uses '').
    _BLANK_VALUES = ("", "''", '""')

    def __init__(
        self,
        db_type: str = "elasticsearch",
        config=None
    ):
        """Connect to Elasticsearch and make sure both indices exist.

        Args:
            db_type: Storage backend name. Not read by this class.
            config: Object exposing ``get(section, option)`` for the
                ``elasticsearch`` section (``host``, ``port`` and optionally
                ``user``/``password``). Presumably a
                ``configparser.ConfigParser`` -- verify against the caller.

        Raises:
            ValueError: If ``config`` is not provided.
        """
        if config is None:
            raise ValueError("SQLStorage requires a config with an [elasticsearch] section")

        self.host = config.get('elasticsearch', 'host')
        self.port = int(config.get('elasticsearch', 'port'))
        user = self._config_value(config, 'user', self._DEFAULT_USER)
        password = self._config_value(config, 'password', self._DEFAULT_PASSWORD)
        self.client = Elasticsearch(
            hosts=[{"host": self.host, "port": self.port}],
            timeout=30,
            http_auth=(user, password)
        )

        self.log_index = "open_cache_mm_query_log"
        self.ans_index = "open_cache_mm_answer"
        self.create()
        # Machine id used by the snowflake algorithm; distributed systems that
        # share the same database must configure a different id per instance.
        self.instance_id = 1
        # Generator that yields snowflake ids for new answer documents.
        self.snowflake_id = SnowflakeGenerator(self.instance_id)

    @classmethod
    def _config_value(cls, config, option, default):
        """Return ``[elasticsearch] option`` from config, or ``default`` if unset/blank."""
        import configparser

        try:
            raw = config.get('elasticsearch', option)
        except configparser.Error:
            # Section or option missing: behave as before this key existed.
            return default
        if raw is None or raw.strip() in cls._BLANK_VALUES:
            return default
        return raw

    def create(self):
        """Create the answer and query-log indices if they are missing."""
        date_field = {"type": "date", "format": "strict_date_optional_time||epoch_millis"}

        answer_index_body = {
            "mappings": {
                "properties": {
                    "gmt_create": dict(date_field),
                    "gmt_modified": dict(date_field),
                    "question": {"type": "text"},
                    "answer": {"type": "text"},
                    "answer_type": {"type": "integer"},
                    "hit_count": {"type": "integer"},
                    "model": {"type": "keyword"},
                    "image_url": {"type": "text"},
                    "image_id": {"type": "text"},
                    "is_deleted": {"type": "integer"},
                }
            }
        }

        log_index_body = {
            "mappings": {
                "properties": {
                    "gmt_create": dict(date_field),
                    "gmt_modified": dict(date_field),
                    "error_code": {"type": "integer"},
                    "error_desc": {"type": "text"},
                    "cache_hit": {"type": "keyword"},
                    "delta_time": {"type": "float"},
                    "model": {"type": "keyword"},
                    "query": {"type": "text"},
                    "hit_query": {"type": "text"},
                    "answer": {"type": "text"}
                }
            }
        }

        if not self.client.indices.exists(index=self.ans_index):
            self.client.indices.create(index=self.ans_index, body=answer_index_body)

        if not self.client.indices.exists(index=self.log_index):
            self.client.indices.create(index=self.log_index, body=log_index_body)

    def _insert(self, data: List) -> "int | None":
        """Index one answer document under a fresh snowflake id.

        Args:
            data: Positional record ``[answer, question, image_url, image_id, model]``.

        Returns:
            The document id as an int, or ``None`` if indexing failed
            (the error is printed, not raised).
        """
        doc = {
            "answer": data[0],
            "question": data[1],
            "image_url": data[2],
            "image_id": data[3],
            "model": data[4],
            "answer_type": 0,
            "hit_count": 0,
            "is_deleted": 0
        }

        try:
            response = self.client.index(
                index=self.ans_index,
                id=next(self.snowflake_id),
                body=doc,
            )
            return int(response['_id'])
        except Exception as e:
            # Best-effort insert: report and let the caller skip this record.
            print(f"Failed to insert document: {e}")
            return None

    def batch_insert(self, all_data: List[List]) -> List[int]:
        """Insert every record and return the ids of those that succeeded.

        Args:
            all_data: Records in the layout expected by ``_insert``.

        Returns:
            Ids of the successfully indexed documents, in input order.
        """
        successful_ids = []
        for data in all_data:
            _id = self._insert(data)
            if _id is not None:
                successful_ids.append(_id)
        # Manually refresh after the batch so the new documents are searchable.
        self.client.indices.refresh(index=self.ans_index)

        return successful_ids

    def insert_query_resp(self, query_resp, **kwargs):
        """Write one query/response record to the log index.

        Args:
            query_resp: Mapping with ``errorCode``, ``errorDesc``, ``cacheHit``,
                ``hit_query`` and ``answer`` entries (read with ``.get``).
            **kwargs: ``model``, ``query`` and ``delta_time`` of the request.
        """
        hit_query = query_resp.get('hit_query')
        if isinstance(hit_query, list):
            # The field is mapped as text, so a list is stored as a JSON string.
            hit_query = json.dumps(hit_query, ensure_ascii=False)

        doc = {
            "error_code": query_resp.get('errorCode'),
            "error_desc": query_resp.get('errorDesc'),
            "cache_hit": query_resp.get('cacheHit'),
            "model": kwargs.get('model'),
            "query": kwargs.get('query'),
            "delta_time": kwargs.get('delta_time'),
            "hit_query": hit_query,
            "answer": query_resp.get('answer'),
            "hit_count": 0,
            "is_deleted": 0
        }
        self.client.index(index=self.log_index, body=doc)

    def get_data_by_id(self, key: int):
        """Fetch one answer document.

        Args:
            key: Document id in the answer index.

        Returns:
            ``[question, image_url, image_id, answer, model]``, or ``None`` if
            the lookup failed (the error is printed, not raised).
        """
        try:
            response = self.client.get(
                index=self.ans_index,
                id=key,
                _source=['question', 'image_url', 'image_id', 'answer', 'model'],
            )
            source = response["_source"]
            return [
                source.get('question'),
                source.get('image_url'),
                source.get('image_id'),
                source.get('answer'),
                source.get('model')
            ]
        except Exception as e:
            print(e)
            return None

    def update_hit_count_by_id(self, primary_id: int):
        """Increment ``hit_count`` of the given answer document by one."""
        self.client.update(
            index=self.ans_index,
            id=primary_id,
            body={"script": {"source": "ctx._source.hit_count += 1"}}
        )

    def get_ids(self, deleted=True):
        """Return the ids of all answer documents in the given deletion state.

        Args:
            deleted: ``True`` selects documents with ``is_deleted == 1``,
                ``False`` those with ``is_deleted == 0``.

        Returns:
            List of document ids as returned by Elasticsearch (strings).
        """
        query = {
            # Only the ids are needed, so skip fetching document bodies.
            "_source": False,
            "query": {
                "term": {"is_deleted": 1 if deleted else 0}
            }
        }
        # A plain search returns only the first page of hits; scan scrolls
        # through the complete result set.
        hits = helpers.scan(self.client, index=self.ans_index, query=query)
        return [hit["_id"] for hit in hits]

    def mark_deleted(self, keys):
        """Flag the given answer documents as deleted.

        Args:
            keys: Iterable of document ids.

        Returns:
            Number of documents updated.
        """
        actions = [
            {
                "_op_type": "update",
                "_index": self.ans_index,
                "_id": key,
                "doc": {"is_deleted": 1}
            }
            for key in keys
        ]
        responses = helpers.bulk(self.client, actions)
        return responses[0]  # number of updated documents

    def model_deleted(self, model_name):
        """Delete every answer document of ``model_name``; return the count deleted."""
        query = {
            "query": {
                "term": {"model": model_name}
            }
        }

        response = self.client.delete_by_query(index=self.ans_index, body=query)
        return response["deleted"]

    def clear_deleted_data(self):
        """Physically remove documents flagged ``is_deleted == 1``; return the count deleted."""
        query = {
            "query": {
                "term": {"is_deleted": 1}
            }
        }
        response = self.client.delete_by_query(index=self.ans_index, body=query)
        return response["deleted"]

    def count(self, state: int = 0, is_all: bool = False):
        """Count answer documents.

        Args:
            state: ``is_deleted`` value to match when ``is_all`` is false.
            is_all: Count every document regardless of ``is_deleted``.

        Returns:
            Number of matching documents.
        """
        if is_all:
            query = {"query": {"match_all": {}}}
        else:
            query = {"query": {"term": {"is_deleted": state}}}
        response = self.client.count(index=self.ans_index, body=query)
        return response["count"]

    def close(self):
        """Close the underlying Elasticsearch client."""
        self.client.close()

    def count_answers(self):
        """Return the total number of documents in the answer index."""
        return self.count(is_all=True)

0 commit comments

Comments
 (0)