Skip to content

Commit 14a930f

Browse files
committed
enable to use l2 index for queries using l2sq
1 parent 78f65ec commit 14a930f

File tree

6 files changed

+23
-17
lines changed

6 files changed

+23
-17
lines changed

pkg/sql/plan/apply_indices_hnsw.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ func (builder *QueryBuilder) applyIndicesForSortUsingHnsw(nodeID int32, projNode
7272
return nodeID, nil
7373
}
7474

75-
if opType != metric.DistFuncOpTypes[distFnExpr.Func.ObjName] {
75+
origFuncName := distFnExpr.Func.ObjName
76+
if opType != metric.DistFuncOpTypes[origFuncName] {
7677
return nodeID, nil
7778
}
7879

@@ -93,12 +94,13 @@ func (builder *QueryBuilder) applyIndicesForSortUsingHnsw(nodeID int32, projNode
9394
}
9495

9596
// generate JSON by fmt.Sprintf instead of sonic.Marshal for performance
96-
tblCfgStr := fmt.Sprintf(`{"db": "%s", "src": "%s", "metadata":"%s", "index":"%s", "threads_search": %d}`,
97+
tblCfgStr := fmt.Sprintf(`{"db": "%s", "src": "%s", "metadata":"%s", "index":"%s", "threads_search": %d, "orig_func_name": "%s"}`,
9798
scanNode.ObjRef.SchemaName,
9899
scanNode.TableDef.Name,
99100
metaDef.IndexTableName,
100101
idxDef.IndexTableName,
101-
nThread.(int64))
102+
nThread.(int64),
103+
origFuncName)
102104

103105
// JOIN between source table and hnsw_search table function
104106
tableFuncTag := builder.genNewTag()

pkg/sql/plan/apply_indices_ivfflat.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@ func (builder *QueryBuilder) applyIndicesForSortUsingIvfflat(nodeID int32, projN
7373
return nodeID, nil
7474
}
7575

76-
if opType != metric.DistFuncOpTypes[distFnExpr.Func.ObjName] {
76+
origFuncName := distFnExpr.Func.ObjName
77+
if opType != metric.DistFuncOpTypes[origFuncName] {
78+
fmt.Println("IVFFlat: Unsupported distance function")
7779
return nodeID, nil
7880
}
7981

@@ -108,7 +110,7 @@ func (builder *QueryBuilder) applyIndicesForSortUsingIvfflat(nodeID int32, projN
108110
params := idxDef.IndexAlgoParams
109111

110112
tblCfgStr := fmt.Sprintf(`{"db": "%s", "src": "%s", "metadata":"%s", "index":"%s", "threads_search": %d,
111-
"entries": "%s", "nprobe" : %d, "pktype" : %d, "pkey" : "%s", "part" : "%s", "parttype" : %d}`,
113+
"entries": "%s", "nprobe" : %d, "pktype" : %d, "pkey" : "%s", "part" : "%s", "parttype" : %d, "orig_func_name": "%s"}`,
112114
scanNode.ObjRef.SchemaName,
113115
scanNode.TableDef.Name,
114116
metaDef.IndexTableName,
@@ -119,7 +121,8 @@ func (builder *QueryBuilder) applyIndicesForSortUsingIvfflat(nodeID int32, projN
119121
pkType.Id,
120122
scanNode.TableDef.Pkey.PkeyColName,
121123
keyPart,
122-
partType.Id)
124+
partType.Id,
125+
origFuncName)
123126

124127
// JOIN between source table and hnsw_search table function
125128
tableFuncTag := builder.genNewTag()

pkg/vectorindex/hnsw/search.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ func (s *HnswSearch[T]) Search(sqlproc *sqlexec.SqlProcess, anyquery any, rt vec
133133
return nil, nil, moerr.NewInternalError(sqlproc.GetContext(), "heap return key is not int64")
134134
}
135135
reskeys = append(reskeys, sr.Id)
136-
sr.Distance = metric.DistanceTransformHnsw(sr.Distance, s.Idxcfg.OpType, s.Idxcfg.Usearch.Metric)
136+
sr.Distance = metric.DistanceTransformHnsw(sr.Distance, metric.DistFuncNameToMetricType[s.Tblcfg.OrigFuncName], s.Idxcfg.Usearch.Metric)
137137
resdistances = append(resdistances, sr.Distance)
138138
}
139139

@@ -247,6 +247,6 @@ func (s *HnswSearch[T]) Load(sqlproc *sqlexec.SqlProcess) error {
247247

248248
// check config and update some parameters such as ef_search
249249
func (s *HnswSearch[T]) UpdateConfig(newalgo cache.VectorIndexSearchIf) error {
250-
250+
s.Tblcfg.OrigFuncName = newalgo.(*HnswSearch[T]).Tblcfg.OrigFuncName
251251
return nil
252252
}

pkg/vectorindex/ivfflat/search.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ func (idx *IvfflatSearchIndex[T]) Search(
207207
resid = append(resid, pk)
208208

209209
dist := vector.GetFixedAtNoTypeCheck[float64](bat.Vecs[1], i)
210-
dist = metric.DistanceTransformIvfflat(dist, idxcfg.OpType, metric.MetricType(idxcfg.Ivfflat.Metric))
210+
dist = metric.DistanceTransformIvfflat(dist, metric.DistFuncNameToMetricType[tblcfg.OrigFuncName], metric.MetricType(idxcfg.Ivfflat.Metric))
211211
distances = append(distances, dist)
212212
}
213213
}
@@ -269,6 +269,6 @@ func (s *IvfflatSearch[T]) Load(sqlproc *sqlexec.SqlProcess) error {
269269

270270
// check config and update some parameters such as ef_search
271271
func (s *IvfflatSearch[T]) UpdateConfig(newalgo cache.VectorIndexSearchIf) error {
272-
272+
s.Tblcfg.OrigFuncName = newalgo.(*IvfflatSearch[T]).Tblcfg.OrigFuncName
273273
return nil
274274
}

pkg/vectorindex/metric/types.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ const (
5555
var (
5656
DistFuncOpTypes = map[string]string{
5757
DistFn_L2Distance: OpType_L2Distance,
58-
DistFn_L2sqDistance: OpType_L2sqDistance,
58+
DistFn_L2sqDistance: OpType_L2Distance,
5959
DistFn_InnerProduct: OpType_InnerProduct,
6060
DistFn_CosineDistance: OpType_CosineDistance,
6161
}
@@ -120,17 +120,17 @@ func MaxFloat[T types.RealNumbers]() T {
120120
}
121121
}
122122

123-
func DistanceTransformHnsw(dist float64, optype string, metric usearch.Metric) float64 {
124-
if optype == OpType_L2Distance && metric == usearch.L2sq {
125-
// metric is l2sq but optype is l2_distance
123+
func DistanceTransformHnsw(dist float64, origMetricType MetricType, metricType usearch.Metric) float64 {
124+
if origMetricType == Metric_L2Distance && metricType == usearch.L2sq {
125+
// metric is l2sq but origin is l2_distance
126126
return math.Sqrt(dist)
127127
}
128128
return dist
129129
}
130130

131-
func DistanceTransformIvfflat(dist float64, optype string, metric MetricType) float64 {
132-
if optype == OpType_L2Distance && metric == Metric_L2sqDistance {
133-
// metric is l2sq but optype is l2_distance
131+
func DistanceTransformIvfflat(dist float64, origMetricType, metricType MetricType) float64 {
132+
if origMetricType == Metric_L2Distance && metricType == Metric_L2sqDistance {
133+
// metric is l2sq but origin is l2_distance
134134
return math.Sqrt(dist)
135135
}
136136
return dist

pkg/vectorindex/types.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ type IndexTableConfig struct {
5454
IndexTable string `json:"index"`
5555
PKey string `json:"pkey"`
5656
KeyPart string `json:"part"`
57+
OrigFuncName string `json:"orig_func_name"`
5758
ThreadsBuild int64 `json:"threads_build"`
5859
ThreadsSearch int64 `json:"threads_search"`
5960
IndexCapacity int64 `json:"index_capacity"`

0 commit comments

Comments
 (0)