Skip to content

Commit aead2c1

Browse files
committed
Add feature : add chromadb support as a vector database
1 parent d29299d commit aead2c1

File tree

10 files changed

+225
-6
lines changed

10 files changed

+225
-6
lines changed

flask4modelcache.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,16 @@ def response_hitquery(cache_resp):
4141
# redis_config = configparser.ConfigParser()
4242
# redis_config.read('modelcache/config/redis_config.ini')
4343

44+
# chromadb_config = configparser.ConfigParser()
45+
# chromadb_config.read('modelcache/config/chromadb_config.ini')
4446

4547
data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
4648
VectorBase("milvus", dimension=data2vec.dimension, milvus_config=milvus_config))
4749

50+
51+
# data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
52+
# VectorBase("chromadb", dimension=data2vec.dimension, chromadb_config=chromadb_config))
53+
4854
# data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
4955
# VectorBase("redis", dimension=data2vec.dimension, redis_config=redis_config))
5056

modelcache/config/chromadb_config.ini

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[chromadb]
2+
persist_directory=''
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from typing import List
2+
3+
import numpy as np
4+
import logging
5+
from modelcache.manager.vector_data.base import VectorBase, VectorData
6+
from modelcache.utils import import_chromadb, import_torch
7+
8+
import_torch()
9+
import_chromadb()
10+
11+
import chromadb
12+
13+
14+
class Chromadb(VectorBase):
15+
16+
def __init__(
17+
self,
18+
persist_directory="./chromadb",
19+
top_k: int = 1,
20+
):
21+
self.collection_name = "modelcache"
22+
self.top_k = top_k
23+
24+
self._client = chromadb.PersistentClient(path=persist_directory)
25+
self._collection = None
26+
27+
def mul_add(self, datas: List[VectorData], model=None):
28+
collection_name_model = self.collection_name + '_' + model
29+
self._collection = self._client.get_or_create_collection(name=collection_name_model)
30+
31+
data_array, id_array = map(list, zip(*((data.data.tolist(), str(data.id)) for data in datas)))
32+
self._collection.add(embeddings=data_array, ids=id_array)
33+
34+
def search(self, data: np.ndarray, top_k: int = -1, model=None):
35+
collection_name_model = self.collection_name + '_' + model
36+
self._collection = self._client.get_or_create_collection(name=collection_name_model)
37+
38+
if self._collection.count() == 0:
39+
return []
40+
if top_k == -1:
41+
top_k = self.top_k
42+
results = self._collection.query(
43+
query_embeddings=[data.tolist()],
44+
n_results=top_k,
45+
include=["distances"],
46+
)
47+
return list(zip(results["distances"][0], [int(x) for x in results["ids"][0]]))
48+
49+
def rebuild(self, ids=None):
50+
pass
51+
52+
def delete(self, ids, model=None):
53+
try:
54+
collection_name_model = self.collection_name + '_' + model
55+
self._collection = self._client.get_or_create_collection(name=collection_name_model)
56+
# 查询集合中实际存在的 ID
57+
ids_str = [str(x) for x in ids]
58+
existing_ids = set(self._collection.get(ids=ids_str).ids)
59+
60+
# 删除存在的 ID
61+
if existing_ids:
62+
self._collection.delete(list(existing_ids))
63+
64+
# 返回实际删除的条目数量
65+
return len(existing_ids)
66+
67+
except Exception as e:
68+
logging.error('Error during deletion: {}'.format(e))
69+
raise ValueError(str(e))
70+
71+
def rebuild_col(self, model):
72+
collection_name_model = self.collection_name + '_' + model
73+
74+
# 检查集合是否存在,如果存在则删除
75+
collections = self._client.list_collections()
76+
if any(col.name == collection_name_model for col in collections):
77+
self._client.delete_collection(collection_name_model)
78+
else:
79+
return 'model collection not found, please check!'
80+
81+
try:
82+
self._client.create_collection(collection_name_model)
83+
except Exception as e:
84+
logging.info(f'rebuild_collection: {e}')
85+
raise ValueError(str(e))
86+
87+
def flush(self):
88+
# chroma无flush方法
89+
pass
90+
91+
def close(self):
92+
# chroma无flush方法
93+
pass

modelcache/manager/vector_data/manager.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,11 @@ def get(name, **kwargs):
102102
elif name == "chromadb":
103103
from modelcache.manager.vector_data.chroma import Chromadb
104104

105-
client_settings = kwargs.get("client_settings", None)
106-
persist_directory = kwargs.get("persist_directory", None)
107-
collection_name = kwargs.get("collection_name", COLLECTION_NAME)
105+
chromadb_config = kwargs.get("chromadb_config", None)
106+
persist_directory = chromadb_config.get('chromadb','persist_directory')
107+
108108
vector_base = Chromadb(
109-
client_settings=client_settings,
110109
persist_directory=persist_directory,
111-
collection_name=collection_name,
112110
top_k=top_k,
113111
)
114112
elif name == "hnswlib":

modelcache/utils/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,7 @@ def import_pillow():
7373

7474
def import_redis():
7575
_check_library("redis")
76+
77+
78+
def import_chromadb():
79+
_check_library("chromadb", package="chromadb")
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[chromadb]
2+
persist_directory=./chromadb
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
from typing import List
2+
3+
import numpy as np
4+
import logging
5+
from modelcache_mm.manager.vector_data.base import VectorBase, VectorData
6+
from modelcache_mm.utils import import_chromadb, import_torch
7+
from modelcache_mm.utils.index_util import get_mm_index_name
8+
9+
import_torch()
10+
import_chromadb()
11+
12+
import chromadb
13+
14+
15+
class Chromadb(VectorBase):
16+
17+
def __init__(
18+
self,
19+
persist_directory="./chromadb",
20+
top_k: int = 1,
21+
):
22+
# self.collection_name = "modelcache"
23+
self.top_k = top_k
24+
25+
self._client = chromadb.PersistentClient(path=persist_directory)
26+
self._collection = None
27+
28+
def create(self, model=None, mm_type=None):
29+
try:
30+
collection_name_model = get_mm_index_name(model, mm_type)
31+
# collection_name_model = self.collection_name + '_' + model
32+
self._client.get_or_create_collection(name=collection_name_model)
33+
except Exception as e:
34+
raise ValueError(str(e))
35+
36+
def add(self, datas: List[VectorData], model=None, mm_type=None):
37+
collection_name_model = get_mm_index_name(model, mm_type)
38+
self._collection = self._client.get_or_create_collection(name=collection_name_model)
39+
40+
data_array, id_array = map(list, zip(*((data.data.tolist(), str(data.id)) for data in datas)))
41+
self._collection.add(embeddings=data_array, ids=id_array)
42+
43+
def search(self, data: np.ndarray, top_k: int = -1, model=None, mm_type='mm'):
44+
collection_name_model = get_mm_index_name(model, mm_type)
45+
self._collection = self._client.get_or_create_collection(name=collection_name_model)
46+
47+
if self._collection.count() == 0:
48+
return []
49+
if top_k == -1:
50+
top_k = self.top_k
51+
results = self._collection.query(
52+
query_embeddings=[data.tolist()],
53+
n_results=top_k,
54+
include=["distances"],
55+
)
56+
return list(zip(results["distances"][0], [int(x) for x in results["ids"][0]]))
57+
58+
def delete(self, ids, model=None, mm_type=None):
59+
try:
60+
collection_name_model = get_mm_index_name(model, mm_type)
61+
self._collection = self._client.get_or_create_collection(name=collection_name_model)
62+
# 查询集合中实际存在的 ID
63+
ids_str = [str(x) for x in ids]
64+
existing_ids = set(self._collection.get(ids=ids_str).ids)
65+
66+
# 删除存在的 ID
67+
if existing_ids:
68+
self._collection.delete(list(existing_ids))
69+
70+
# 返回实际删除的条目数量
71+
return len(existing_ids)
72+
73+
except Exception as e:
74+
logging.error('Error during deletion: {}'.format(e))
75+
raise ValueError(str(e))
76+
77+
def rebuild_idx(self, model, mm_type=None):
78+
collection_name_model = get_mm_index_name(model, mm_type)
79+
80+
# 检查集合是否存在,如果存在则删除
81+
collections = self._client.list_collections()
82+
if any(col.name == collection_name_model for col in collections):
83+
self._client.delete_collection(collection_name_model)
84+
else:
85+
return 'model collection not found, please check!'
86+
87+
try:
88+
self._client.create_collection(collection_name_model)
89+
except Exception as e:
90+
logging.info(f'rebuild_collection: {e}')
91+
raise ValueError(str(e))
92+
93+
def rebuild(self, ids=None):
94+
pass
95+
96+
def flush(self):
97+
pass
98+
99+
def close(self):
100+
pass

modelcache_mm/manager/vector_data/manager.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,15 @@ def get(name, **kwargs):
108108
dimension=dimension,
109109
top_k=top_k
110110
)
111+
elif name == "chromadb":
112+
from modelcache_mm.manager.vector_data.chroma import Chromadb
113+
114+
chromadb_config = kwargs.get("chromadb_config", None)
115+
persist_directory = chromadb_config.get('chromadb', 'persist_directory')
116+
vector_base = Chromadb(
117+
persist_directory=persist_directory,
118+
top_k=top_k,
119+
)
111120
else:
112121
raise NotFoundError("vector store", name)
113122
return vector_base

modelcache_mm/utils/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,7 @@ def import_pillow():
7373

7474
def import_redis():
7575
_check_library("redis")
76+
77+
78+
def import_chromadb():
79+
_check_library("chromadb", package="chromadb")

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,5 @@ faiss-cpu==1.7.4
1313
redis==5.0.1
1414
modelscope==1.14.0
1515
fastapi==0.115.5
16-
uvicorn==0.32.0
16+
uvicorn==0.32.0
17+
chromadb==0.5.23

0 commit comments

Comments
 (0)