Skip to content

Commit d026973

Browse files
huxinHuXin0817
authored and committed
Support vector distance operation
1 parent b79ac11 commit d026973

File tree

9 files changed

+164
-1
lines changed

9 files changed

+164
-1
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@ TAGS_sorted_by_file
1313
[._]ss[a-gi-z]
1414
[._]sw[a-p]
1515

16-
# Configuration files for Visual Studio Code
16+
# Configuration files for IDE
1717
.vscode/
18+
.idea/
1819

1920
build/
21+
cmake-build-*/
2022
datatool/
2123
datatools/
2224
__pycache__/

client/mysql.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1172,6 +1172,7 @@ static COMMANDS commands[] = {
11721172
{"FROM_VECTOR", 0, nullptr, false, ""},
11731173
{"VECTOR_TO_STRING", 0, nullptr, false, ""},
11741174
{"VECTOR_DIM", 0, nullptr, false, ""},
1175+
{"DISTANCE", 0, nullptr, false, ""},
11751176
{"UCASE", 0, nullptr, false, ""},
11761177
{"UNCOMPRESS", 0, nullptr, false, ""},
11771178
{"UNCOMPRESSED_LENGTH", 0, nullptr, false, ""},

share/messages_to_clients.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10430,6 +10430,15 @@ ER_EXCEEDS_VECTOR_MAX_DIMENSIONS
1043010430
ER_TO_VECTOR_CONVERSION
1043110431
eng "Data cannot be converted to a valid vector: '%.*s'"
1043210432

10433+
ER_VECTOR_INVALID_DATA
10434+
eng "Invalid vector data provided to function %s."
10435+
10436+
ER_VECTOR_DIM_NO_EQ
10437+
eng "Vector dim not equal: %d != %d"
10438+
10439+
ER_UNKNOWN_DISTANCE_TYPE
10440+
eng "Unknown distance type: '%.*s'"
10441+
1043310442
OBSOLETE_ER_EXTERNAL_UNSUPPORTED_INDEX_ALGORITHM
1043410443
eng "This storage engine ignores the %s index algorithm."
1043510444

sql/item_create.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1658,6 +1658,7 @@ static const std::pair<const char *, Create_func *> func_array[] = {
16581658
{"FROM_VECTOR", SQL_FN(Item_func_from_vector, 1)},
16591659
{"VECTOR_TO_STRING", SQL_FN(Item_func_from_vector, 1)},
16601660
{"VECTOR_DIM", SQL_FN(Item_func_vector_dim, 1)},
1661+
{"DISTANCE", SQL_FN(Item_func_vector_distance, 3)},
16611662
{"UCASE", SQL_FN(Item_func_upper, 1)},
16621663
{"UNCOMPRESS", SQL_FN(Item_func_uncompress, 1)},
16631664
{"UNCOMPRESSED_LENGTH", SQL_FN(Item_func_uncompressed_length, 1)},

sql/item_func.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -844,6 +844,10 @@ class Item_real_func : public Item_func {
844844
set_data_type_double();
845845
}
846846

847+
/// Three-argument constructor: hands the position and the three argument
/// items to Item_func, then calls set_data_type_double().
Item_real_func(const POS &pos, Item *a, Item *b, Item *c)
    : Item_func(pos, a, b, c) {
  set_data_type_double();
}
850+
847851
/// List constructor: hands the argument list to Item_func, then calls
/// set_data_type_double().
explicit Item_real_func(mem_root_deque<Item *> *list)
    : Item_func(list) {
  set_data_type_double();
}

sql/item_strfunc.cc

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4326,6 +4326,74 @@ String *Item_func_from_vector::val_str_ascii(String *str) {
43264326
return &buffer;
43274327
}
43284328

4329+
double Item_func_vector_distance::val_real() {
4330+
assert(fixed);
4331+
4332+
String tmp_value1;
4333+
String tmp_value2;
4334+
String tmp_value3;
4335+
String *res1 = args[0]->val_str(&tmp_value1);
4336+
String *res2 = args[1]->val_str(&tmp_value2);
4337+
String *res3 = args[2]->val_str(&tmp_value3);
4338+
4339+
if ((null_value =
4340+
(!res1 || args[0]->null_value || !res2 || args[1]->null_value))) {
4341+
assert(is_nullable());
4342+
return 0.0;
4343+
}
4344+
4345+
if (res1 == nullptr || res2 == nullptr) {
4346+
my_error(ER_VECTOR_INVALID_DATA, MYF(0), func_name());
4347+
return error_real();
4348+
}
4349+
4350+
if (res3 == nullptr || res3->ptr() == nullptr) {
4351+
return error_real();
4352+
}
4353+
4354+
uint32 dimensions1 = get_dimensions(res1->length(), Field_vector::precision);
4355+
if (dimensions1 == UINT32_MAX) {
4356+
my_error(ER_TO_VECTOR_CONVERSION, MYF(0), res1->length(), res1->ptr());
4357+
return error_real();
4358+
}
4359+
4360+
uint32 dimensions2 = get_dimensions(res2->length(), Field_vector::precision);
4361+
if (dimensions2 == UINT32_MAX) {
4362+
my_error(ER_TO_VECTOR_CONVERSION, MYF(0), res2->length(), res2->ptr());
4363+
return error_real();
4364+
}
4365+
4366+
if (dimensions1 != dimensions2) {
4367+
my_error(ER_VECTOR_DIM_NO_EQ, MYF(0), dimensions1, dimensions2);
4368+
return error_real();
4369+
}
4370+
4371+
float distance = 0.0;
4372+
bool success = true;
4373+
4374+
// COSINE, DOT, and EUCLIDEAN
4375+
if (res3->length() == 3 && memcmp(res3->ptr(), "DOT", 3) == 0) {
4376+
success =
4377+
vector_dot_distance(res1->ptr(), res2->ptr(), dimensions1, &distance);
4378+
} else if (res3->length() == 6 && memcmp(res3->ptr(), "COSINE", 6) == 0) {
4379+
success = vector_cosine_distance(res1->ptr(), res2->ptr(), dimensions1,
4380+
&distance);
4381+
} else if (res3->length() == 9 && memcmp(res3->ptr(), "EUCLIDEAN", 9) == 0) {
4382+
success = vector_euclidean_distance(res1->ptr(), res2->ptr(), dimensions1,
4383+
&distance);
4384+
} else {
4385+
my_error(ER_UNKNOWN_DISTANCE_TYPE, MYF(0), res3->length(), res3->ptr());
4386+
return error_real();
4387+
}
4388+
4389+
if (!success) {
4390+
// ex. Division by zero
4391+
return error_real();
4392+
}
4393+
4394+
return distance;
4395+
}
4396+
43294397
String *Item_func_uncompress::val_str(String *str) {
43304398
assert(fixed);
43314399
String *res = args[0]->val_str(str);

sql/item_strfunc.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,6 +1326,19 @@ class Item_func_from_vector final : public Item_str_ascii_func {
13261326
String *val_str_ascii(String *str) override;
13271327
};
13281328

1329+
class Item_func_vector_distance final : public Item_real_func {
1330+
public:
1331+
Item_func_vector_distance(const POS &pos, Item *ilist1, Item *ilist2,
1332+
Item *ilist3)
1333+
: Item_real_func(pos, ilist1, ilist2, ilist3) {
1334+
set_nullable(true);
1335+
}
1336+
1337+
double val_real() override;
1338+
1339+
const char *func_name() const override { return "vector_distance"; }
1340+
};
1341+
13291342
class Item_func_uncompress final : public Item_str_func {
13301343
String buffer;
13311344

vector-common/vector_conversion.cc

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
#include <sys/types.h>
2525
#include <cerrno>
26+
#include <cfloat>
2627
#include <cmath>
2728
#include <cstddef>
2829
#include <cstdio>
#include <cstring>
@@ -170,4 +171,59 @@ bool from_vector_to_string(const char *input, uint32_t input_dims, char *output,
170171

171172
*max_output_len = total_length;
172173
return false;
174+
}
175+
176+
/**
  Read one float element from a packed vector buffer.

  The bytes are copied out with memcpy instead of dereferencing a cast
  pointer: the buffer is a plain char array that carries no alignment
  guarantee for float, and memcpy is the well-defined way to reinterpret
  its bytes as a float.

  @param input  start of a buffer holding consecutive float values
  @param index  zero-based element index; the caller must keep it in range

  @return the float stored at position index
*/
static float vector_data_at(const char *input, uint32_t index) {
  float value;
  std::memcpy(&value, input + static_cast<size_t>(index) * sizeof(float),
              sizeof(value));
  return value;
}
179+
180+
bool vector_euclidean_distance(const char *input1, const char *input2,
181+
uint32_t input_dims, float *result) {
182+
float distance = 0;
183+
for (uint32_t i = 0; i < input_dims; i++) {
184+
float d1 = vector_data_at(input1, i);
185+
float d2 = vector_data_at(input2, i);
186+
float dif = d1 - d2;
187+
distance += dif * dif;
188+
}
189+
190+
*result = sqrt(distance);
191+
return true;
192+
}
193+
194+
bool vector_cosine_distance(const char *input1, const char *input2,
195+
uint32_t input_dims, float *result) {
196+
float dot_product = 0;
197+
float norm1 = 0;
198+
float norm2 = 0;
199+
200+
for (uint32_t i = 0; i < input_dims; i++) {
201+
float d1 = vector_data_at(input1, i);
202+
float d2 = vector_data_at(input2, i);
203+
dot_product += d1 * d2;
204+
norm1 += d1 * d1;
205+
norm2 += d2 * d2;
206+
}
207+
208+
if (norm1 < FLT_EPSILON || norm2 < FLT_EPSILON) {
209+
return false;
210+
}
211+
212+
float cos_sim = dot_product / (sqrt(norm1) * sqrt(norm2));
213+
*result = 1.0 - cos_sim;
214+
return true;
215+
}
216+
217+
bool vector_dot_distance(const char *input1, const char *input2,
218+
uint32_t input_dims, float *result) {
219+
float dot_product = 0;
220+
221+
for (uint32_t i = 0; i < input_dims; i++) {
222+
float d1 = vector_data_at(input1, i);
223+
float d2 = vector_data_at(input2, i);
224+
dot_product += d1 * d2;
225+
}
226+
227+
*result = dot_product;
228+
return true;
173229
}

vector-common/vector_conversion.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,12 @@ bool from_string_to_vector(const CHARSET_INFO *cs, const char *input,
3232

3333
bool from_vector_to_string(const char *input, uint32_t input_dims, char *output,
3434
uint32_t *max_output_len);
35+
36+
/**
  Euclidean distance between two vectors: the square root of the sum of
  squared element differences.

  @param input1      first vector, read as input_dims consecutive floats
  @param input2      second vector, read as input_dims consecutive floats
  @param input_dims  number of elements in each vector
  @param[out] result receives the distance

  @return always true
*/
bool vector_euclidean_distance(const char *input1, const char *input2,
37+
uint32_t input_dims, float *result);
38+
39+
/**
  Cosine distance between two vectors: one minus their cosine similarity.

  @param input1      first vector, read as input_dims consecutive floats
  @param input2      second vector, read as input_dims consecutive floats
  @param input_dims  number of elements in each vector
  @param[out] result receives the distance; not written on failure

  @return true on success, false when either vector has a (near-)zero norm
          so the cosine cannot be formed
*/
bool vector_cosine_distance(const char *input1, const char *input2,
40+
uint32_t input_dims, float *result);
41+
42+
/**
  Dot product of two vectors: the sum of element-wise products, stored
  without any further transformation.

  @param input1      first vector, read as input_dims consecutive floats
  @param input2      second vector, read as input_dims consecutive floats
  @param input_dims  number of elements in each vector
  @param[out] result receives the dot product

  @return always true
*/
bool vector_dot_distance(const char *input1, const char *input2,
43+
uint32_t input_dims, float *result);

0 commit comments

Comments
 (0)