-
Notifications
You must be signed in to change notification settings - Fork 1.2k
replace vector_search with faiss #1185
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
687c135
20b3549
76de244
b15822d
f043a40
c35a1f0
b16af28
533c002
ac1fc81
d181b0a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,23 +17,23 @@ | |
__dir__ = os.path.dirname(os.path.abspath(__file__)) | ||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../'))) | ||
|
||
import copy | ||
import cv2 | ||
import faiss | ||
import numpy as np | ||
from tqdm import tqdm | ||
import pickle | ||
|
||
from python.predict_rec import RecPredictor | ||
from vector_search import Graph_Index | ||
|
||
from utils import logger | ||
from utils import config | ||
|
||
|
||
def split_datafile(data_file, image_root, delimiter="\t"): | ||
''' | ||
data_file: image path and info, which can be splitted by spacer | ||
data_file: image path and info, which can be splitted by spacer | ||
image_root: image path root | ||
delimiter: delimiter | ||
delimiter: delimiter | ||
''' | ||
gallery_images = [] | ||
gallery_docs = [] | ||
|
@@ -45,9 +45,8 @@ def split_datafile(data_file, image_root, delimiter="\t"): | |
assert text_num >= 2, f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}" | ||
image_file = os.path.join(image_root, line[0]) | ||
|
||
image_doc = line[1] | ||
gallery_images.append(image_file) | ||
gallery_docs.append(image_doc) | ||
gallery_docs.append(ori_line.strip()) | ||
|
||
return gallery_images, gallery_docs | ||
|
||
|
@@ -64,9 +63,91 @@ def build(self, config): | |
''' | ||
build index from scratch | ||
''' | ||
operation_method = config.get("index_operation", "new").lower() | ||
|
||
gallery_images, gallery_docs = split_datafile( | ||
config['data_file'], config['image_root'], config['delimiter']) | ||
|
||
# when remove data in index, do not need extract fatures | ||
if operation_method != "remove": | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这一段约有100行代码,可以考虑拆分成子函数或者是类的形式 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 感觉这块功能比较统一。拆成小函数,其他方法也没法调用,感觉不太必要拆分 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK。那加点注释吧,不然这么长的代码,可读性会差一些 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已加 |
||
gallery_features = self._extract_features(gallery_images, config) | ||
|
||
assert operation_method in [ | ||
"new", "remove", "append" | ||
], "Only append, remove and new operation are supported" | ||
|
||
# vector.index: faiss index file | ||
# id_map.pkl: use this file to map id to image_doc | ||
if operation_method in ["remove", "append"]: | ||
# if remove or append, vector.index and id_map.pkl must exist | ||
assert os.path.join( | ||
config["index_dir"], "vector.index" | ||
), "The vector.index dose not exist in {} when 'index_operation' is not None".format( | ||
config["index_dir"]) | ||
assert os.path.join( | ||
config["index_dir"], "id_map.pkl" | ||
), "The id_map.pkl dose not exist in {} when 'index_operation' is not None".format( | ||
config["index_dir"]) | ||
index = faiss.read_index( | ||
os.path.join(config["index_dir"], "vector.index")) | ||
with open(os.path.join(config["index_dir"], "id_map.pkl"), | ||
'rb') as fd: | ||
ids = pickle.load(fd) | ||
assert index.ntotal == len(ids.keys( | ||
)), "data number in index is not equal in in id_map" | ||
else: | ||
if not os.path.exists(config["index_dir"]): | ||
os.makedirs(config["index_dir"], exist_ok=True) | ||
index_method = config.get("index_method", "HNSW32") | ||
|
||
# if IVF method, cal ivf number automaticlly | ||
if index_method == "IVF": | ||
index_method = index_method + str( | ||
min(int(len(gallery_images) // 8), 65536)) + ",Flat" | ||
dist_type = faiss.METRIC_INNER_PRODUCT if config[ | ||
"dist_type"] == "IP" else faiss.METRIC_L2 | ||
index = faiss.index_factory(config["embedding_size"], index_method, | ||
dist_type) | ||
index = faiss.IndexIDMap2(index) | ||
ids = {} | ||
|
||
if config["index_method"] == "HNSW32": | ||
logger.warning( | ||
"The HNSW32 method dose not support 'remove' operation") | ||
|
||
if operation_method != "remove": | ||
# calculate id for new data | ||
start_id = max(ids.keys()) + 1 if ids else 0 | ||
ids_now = ( | ||
np.arange(0, len(gallery_images)) + start_id).astype(np.int64) | ||
|
||
# only train when new index file | ||
if operation_method == "new": | ||
index.train(gallery_features) | ||
index.add_with_ids(gallery_features, ids_now) | ||
|
||
for i, d in zip(list(ids_now), gallery_docs): | ||
ids[i] = d | ||
else: | ||
if config["index_method"] == "HNSW32": | ||
raise RuntimeError( | ||
"The index_method: HNSW32 dose not support 'remove' operation" | ||
) | ||
# remove ids in id_map, remove index data in faiss index | ||
remove_ids = list( | ||
filter(lambda k: ids.get(k) in gallery_docs, ids.keys())) | ||
remove_ids = np.asarray(remove_ids) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove的功能最好加个注释,按照何种方式进行remove; 是指定图片index,还是指定类别 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. new、remove、append,输入都是一致的,基于data_file_list,这一块的使用说明,会在后面的使用文档中添加 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK |
||
index.remove_ids(remove_ids) | ||
for k in remove_ids: | ||
del ids[k] | ||
|
||
# store faiss index file and id_map file | ||
faiss.write_index(index, | ||
os.path.join(config["index_dir"], "vector.index")) | ||
with open(os.path.join(config["index_dir"], "id_map.pkl"), 'wb') as fd: | ||
pickle.dump(ids, fd) | ||
|
||
def _extract_features(self, gallery_images, config): | ||
# extract gallery features | ||
gallery_features = np.zeros( | ||
[len(gallery_images), config['embedding_size']], dtype=np.float32) | ||
|
@@ -91,19 +172,11 @@ def build(self, config): | |
rec_feat = self.rec_predictor.predict(batch_img) | ||
gallery_features[-len(batch_img):, :] = rec_feat | ||
batch_img = [] | ||
|
||
# train index | ||
self.Searcher = Graph_Index(dist_type=config['dist_type']) | ||
self.Searcher.build( | ||
gallery_vectors=gallery_features, | ||
gallery_docs=gallery_docs, | ||
pq_size=config['pq_size'], | ||
index_path=config['index_path'], | ||
append_index=config["append_index"]) | ||
return gallery_features | ||
|
||
|
||
def main(config): | ||
system_builder = GalleryBuilder(config) | ||
GalleryBuilder(config) | ||
return | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ocr_line未定义,CI会不过
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
42行有定义
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
image_doc = line[1]. 这行对应的语句在哪儿;逻辑上是不是有问题
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
原来的逻辑是你写的这行语句。这个imge_doc可以是字符串。但是faiss中只有id的概念,不是doc,只能是int64类型。我代码里自动给image id,id 跟doc的对应用了自己写的文件对应。这一行,保存了所有的文件信息,imge_path img_doc,同时也是为了remove操作