Skip to content

Commit 752a814

Browse files
committed
use index_info to cover other data except index
1 parent 04f26a1 commit 752a814

File tree

1 file changed

+14
-7
lines changed
  • paddlex/inference/components/retrieval

1 file changed

+14
-7
lines changed

paddlex/inference/components/retrieval/faiss.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,12 @@ class IndexData:
2929
IDMAP_FN = "id_map"
3030
IDMAP_SUFFIX = ".yaml"
3131

32-
def __init__(self, index, id_map, metric_type, index_type):
32+
def __init__(self, index, index_info):
3333
self._index = index
34-
self._id_map = id_map
35-
self._metric_type = metric_type
36-
self._index_type = index_type
34+
self._index_info = index_info
35+
self._id_map = index_info["id_map"]
36+
self._metric_type = index_info["metric_type"]
37+
self._index_type = index_info["index_type"]
3738

3839
@property
3940
def index(self):
@@ -260,7 +261,9 @@ def build(
260261
index, ids = cls._add_gallery(
261262
metric_type, index, ids, features, gallery_docs, mode="new"
262263
)
263-
return IndexData(index, ids, metric_type, index_type)
264+
return IndexData(
265+
index, {"id_map": ids, "metric_type": metric_type, "index_type": index_type}
266+
)
264267

265268
@classmethod
266269
def remove(
@@ -288,7 +291,9 @@ def remove(
288291
# remove ids in id_map, remove index data in faiss index
289292
index.remove_ids(remove_ids)
290293
ids = {k: v for k, v in ids.items() if k not in remove_ids}
291-
return IndexData(index, ids, metric_type, index_type)
294+
return IndexData(
295+
index, {"id_map": ids, "metric_type": metric_type, "index_type": index_type}
296+
)
292297

293298
@classmethod
294299
def append(cls, gallery_imgs, gallery_label, predict_func, index):
@@ -310,7 +315,9 @@ def append(cls, gallery_imgs, gallery_label, predict_func, index):
310315
index, ids = cls._add_gallery(
311316
metric_type, index, ids, features, gallery_docs, mode="append"
312317
)
313-
return IndexData(index, ids, metric_type, index_type)
318+
return IndexData(
319+
index, {"id_map": ids, "metric_type": metric_type, "index_type": index_type}
320+
)
314321

315322
@classmethod
316323
def _add_gallery(

0 commit comments

Comments
 (0)