Skip to content

Commit 17dbc97

Browse files
wip on different input vectors
1 parent 3a45b54 commit 17dbc97

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

engine/clients/pgvector/search.py

Lines changed: 21 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,26 @@ def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]:
3737
if "hnsw_ef" in cls.search_params:
3838
cls.cur.execute(f"SET hnsw.ef_search = {cls.search_params['hnsw_ef']}")
3939

40+
# Ensure vector is in the correct format for pgvector
41+
try:
42+
if isinstance(vector, bytes):
43+
# If vector is bytes, it might be serialized - try to convert
44+
# First try to interpret as float32 bytes
45+
try:
46+
import struct
47+
num_floats = len(vector) // 4 # 4 bytes per float32
48+
vector_array = np.array(struct.unpack(f'{num_floats}f', vector), dtype=np.float32)
49+
except struct.error:
50+
# If that fails, try to decode as numpy array
51+
vector_array = np.frombuffer(vector, dtype=np.float32)
52+
elif isinstance(vector, np.ndarray):
53+
vector_array = vector.astype(np.float32)
54+
else:
55+
# Convert list to numpy array
56+
vector_array = np.array(vector, dtype=np.float32)
57+
except Exception as e:
58+
raise ValueError(f"Failed to convert vector to proper format. Vector type: {type(vector)}, Error: {e}")
59+
4060
if cls.distance == Distance.COSINE:
4161
query = f"SELECT id, embedding <=> %s AS _score FROM items ORDER BY _score LIMIT {top};"
4262
elif cls.distance == Distance.L2:
@@ -46,7 +66,7 @@ def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]:
4666

4767
cls.cur.execute(
4868
query,
49-
(np.array(vector),),
69+
(vector_array,),
5070
)
5171
return cls.cur.fetchall()
5272

engine/clients/pgvector/upload.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -50,11 +50,14 @@ def init_client(cls, host, distance, connection_params, upload_params):
5050
def upload_batch(
5151
cls, ids: List[int], vectors: List[list], metadata: Optional[List[dict]]
5252
):
53-
vectors = np.array(vectors)
54-
5553
# Copy is faster than insert
5654
with cls.cur.copy("COPY items (id, embedding) FROM STDIN") as copy:
57-
for i, embedding in zip(ids, vectors):
55+
for i, vector in zip(ids, vectors):
56+
# Convert vector to numpy array with float32 dtype for pgvector
57+
if isinstance(vector, np.ndarray):
58+
embedding = vector.astype(np.float32)
59+
else:
60+
embedding = np.array(vector, dtype=np.float32)
5861
copy.write_row((i, embedding))
5962

6063
@classmethod

0 commit comments

Comments
 (0)