diff --git a/muted-tests.yml b/muted-tests.yml index 255cb9b60cae8..7d5fff8a0866c 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -533,6 +533,9 @@ tests: - class: org.elasticsearch.xpack.esql.qa.single_node.EsqlSpecIT method: test {csv-spec:spatial.ConvertFromStringParseError} issue: https://github.com/elastic/elasticsearch/issues/132558 +- class: org.elasticsearch.xpack.esql.qa.mixed.EsqlClientYamlIT + method: test {p0=esql/40_unsupported_types/nested declared inline} + issue: https://github.com/elastic/elasticsearch/issues/126606 # Examples: # diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java index 85fa02aecaaef..887a71734d447 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java @@ -49,6 +49,7 @@ record CmdLineArgs( float filterSelectivity, long seed, VectorSimilarityFunction vectorSpace, + int rawVectorSize, int quantizeBits, VectorEncoding vectorEncoding, int dimensions, @@ -75,6 +76,7 @@ record CmdLineArgs( static final ParseField FORCE_MERGE_FIELD = new ParseField("force_merge"); static final ParseField VECTOR_SPACE_FIELD = new ParseField("vector_space"); static final ParseField QUANTIZE_BITS_FIELD = new ParseField("quantize_bits"); + static final ParseField RAW_VECTOR_SIZE_FIELD = new ParseField("raw_vector_size"); static final ParseField VECTOR_ENCODING_FIELD = new ParseField("vector_encoding"); static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); static final ParseField EARLY_TERMINATION_FIELD = new ParseField("early_termination"); @@ -108,6 +110,7 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException { PARSER.declareBoolean(Builder::setReindex, REINDEX_FIELD); PARSER.declareBoolean(Builder::setForceMerge, FORCE_MERGE_FIELD); PARSER.declareString(Builder::setVectorSpace, VECTOR_SPACE_FIELD); + 
PARSER.declareInt(Builder::setRawVectorSize, RAW_VECTOR_SIZE_FIELD); PARSER.declareInt(Builder::setQuantizeBits, QUANTIZE_BITS_FIELD); PARSER.declareString(Builder::setVectorEncoding, VECTOR_ENCODING_FIELD); PARSER.declareInt(Builder::setDimensions, DIMENSIONS_FIELD); @@ -143,6 +146,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(REINDEX_FIELD.getPreferredName(), reindex); builder.field(FORCE_MERGE_FIELD.getPreferredName(), forceMerge); builder.field(VECTOR_SPACE_FIELD.getPreferredName(), vectorSpace.name().toLowerCase(Locale.ROOT)); + builder.field(RAW_VECTOR_SIZE_FIELD.getPreferredName(), rawVectorSize); builder.field(QUANTIZE_BITS_FIELD.getPreferredName(), quantizeBits); builder.field(VECTOR_ENCODING_FIELD.getPreferredName(), vectorEncoding.name().toLowerCase(Locale.ROOT)); builder.field(DIMENSIONS_FIELD.getPreferredName(), dimensions); @@ -176,6 +180,7 @@ static class Builder { private boolean reindex = false; private boolean forceMerge = false; private VectorSimilarityFunction vectorSpace = VectorSimilarityFunction.EUCLIDEAN; + private int rawVectorSize = 32; private int quantizeBits = 8; private VectorEncoding vectorEncoding = VectorEncoding.FLOAT32; private int dimensions; @@ -278,6 +283,11 @@ public Builder setVectorSpace(String vectorSpace) { return this; } + public Builder setRawVectorSize(int rawVectorSize) { + this.rawVectorSize = rawVectorSize; + return this; + } + public Builder setQuantizeBits(int quantizeBits) { this.quantizeBits = quantizeBits; return this; @@ -343,6 +353,7 @@ public CmdLineArgs build() { filterSelectivity, seed, vectorSpace, + rawVectorSize, quantizeBits, vectorEncoding, dimensions, diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java index b0695137dfef4..b36314ed3a5a0 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java +++ 
b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java @@ -30,6 +30,8 @@ import org.elasticsearch.index.codec.vectors.IVFVectorsFormat; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es92.ES92BinaryQuantizedBFloat16VectorsFormat; +import org.elasticsearch.index.codec.vectors.es92.ES92HnswBinaryQuantizedBFloat16VectorsFormat; import org.elasticsearch.logging.Level; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; @@ -105,9 +107,17 @@ static Codec createCodec(CmdLineArgs args) { } else { if (args.quantizeBits() == 1) { if (args.indexType() == IndexType.FLAT) { - format = new ES818BinaryQuantizedVectorsFormat(); + if (args.rawVectorSize() == 16) { + format = new ES92BinaryQuantizedBFloat16VectorsFormat(); + } else { + format = new ES818BinaryQuantizedVectorsFormat(); + } } else { - format = new ES818HnswBinaryQuantizedVectorsFormat(args.hnswM(), args.hnswEfConstruction(), 1, null); + if (args.rawVectorSize() == 16) { + format = new ES92HnswBinaryQuantizedBFloat16VectorsFormat(args.hnswM(), args.hnswEfConstruction(), 1, null); + } else { + format = new ES818HnswBinaryQuantizedVectorsFormat(args.hnswM(), args.hnswEfConstruction(), 1, null); + } } } else if (args.quantizeBits() < 32) { if (args.indexType() == IndexType.FLAT) { diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml new file mode 100644 index 0000000000000..700d3cc33b51e --- /dev/null +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw_bfloat16.yml @@ -0,0 +1,580 @@ +setup: + - requires: + cluster_features: "mapper.bbq_bfloat16" + reason: 'bfloat16 
needs to be supported' + - do: + indices.create: + index: bbq_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_hnsw + + - do: + index: + index: bbq_hnsw + id: "1" + body: + vector: [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, + 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, + 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, + -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, + -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, + -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, + -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, + -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_hnsw + + - do: + index: + index: bbq_hnsw + id: "2" + body: + vector: [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, + -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, + -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, + -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, + -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_hnsw + + - do: + index: + index: bbq_hnsw + id: "3" + body: + name: rabbit.jpg + vector: [0.139, 0.178, -0.117, 0.399, 0.014, -0.139, 0.347, -0.33 , + 0.139, 0.34 , -0.052, -0.052, -0.249, 0.327, -0.288, 0.049, + 0.464, 0.338, 0.516, 0.247, -0.104, 0.259, -0.209, -0.246, + -0.11 , 0.323, 0.091, 0.442, -0.254, 0.195, -0.109, -0.058, + -0.279, 0.402, -0.107, 0.308, -0.273, 0.019, 0.082, 0.399, + -0.658, -0.03 , 0.276, 0.041, 
0.187, -0.331, 0.165, 0.017, + 0.171, -0.203, -0.198, 0.115, -0.007, 0.337, -0.444, 0.615, + -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_hnsw + + - do: + indices.forcemerge: + index: bbq_hnsw + max_num_segments: 1 + + - do: + indices.refresh: { } +--- +"Test knn search": + - requires: + capabilities: + - method: POST + path: /_search + capabilities: [ optimized_scalar_quantization_bbq ] + test_runner_features: capabilities + reason: "BBQ scoring improved and changed with optimized_scalar_quantization_bbq" + - do: + search: + index: bbq_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.1._id: "3" } + - match: { hits.hits.2._id: "2" } +--- +"Vector rescoring has same scoring as exact search for kNN section": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore_oversample] + - skip: + features: "headers" + + # Rescore + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + 
-0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + rescore_vector: + oversample: 1.5 + + # Get rescoring scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + # Exact knn via script score + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } + +--- +"Test bad quantization parameters": + - do: + catch: bad_request + indices.create: + index: bad_bbq_hnsw + body: + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: false + 
index_options: + type: bbq_hnsw +--- +"Test few dimensions fail indexing": + - do: + catch: bad_request + indices.create: + index: bad_bbq_hnsw + body: + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 42 + index: true + index_options: + type: bbq_hnsw + + - do: + indices.create: + index: dynamic_dim_bbq_hnsw + body: + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + index: true + similarity: l2_norm + index_options: + type: bbq_hnsw + + - do: + catch: bad_request + index: + index: dynamic_dim_bbq_hnsw + body: + vector: [1.0, 2.0, 3.0, 4.0, 5.0] + + - do: + index: + index: dynamic_dim_bbq_hnsw + body: + vector: [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0] +--- +"Test index configured rescore vector": + - requires: + cluster_features: ["mapper.dense_vector.rescore_vector"] + reason: Needs rescore_vector feature + - skip: + features: "headers" + - do: + indices.create: + index: bbq_rescore_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_hnsw + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: bbq_rescore_hnsw + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, 
-0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: 
rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_hnsw + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } +--- +"Test index configured rescore vector updateable and settable to 0": + - requires: + cluster_features: ["mapper.dense_vector.rescore_zero_vector"] + reason: Needs rescore_zero_vector feature + + - do: + indices.create: + index: bbq_rescore_0_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + index_options: + type: bbq_hnsw + rescore_vector: + oversample: 0 + + - do: + indices.create: + index: bbq_rescore_update_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + index_options: + type: bbq_hnsw + rescore_vector: + oversample: 1 + + - do: + indices.put_mapping: + index: bbq_rescore_update_hnsw + body: + properties: + vector: + type: dense_vector + element_type: bfloat16 + 
index_options: + type: bbq_hnsw + rescore_vector: + oversample: 0 + + - do: + indices.get_mapping: + index: bbq_rescore_update_hnsw + + - match: { .bbq_rescore_update_hnsw.mappings.properties.vector.index_options.rescore_vector.oversample: 0 } +--- +"Test index configured rescore vector score consistency": + - requires: + cluster_features: ["mapper.dense_vector.rescore_zero_vector"] + reason: Needs rescore_zero_vector feature + - skip: + features: "headers" + - do: + indices.create: + index: bbq_rescore_zero_hnsw + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_hnsw + rescore_vector: + oversample: 0 + + - do: + bulk: + index: bbq_rescore_zero_hnsw + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 
, 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_zero_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: raw_score0 } + - set: { hits.hits.1._score: raw_score1 } + - set: { hits.hits.2._score: raw_score2 } + + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_zero_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + rescore_vector: + 
oversample: 2 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: override_score0 } + - set: { hits.hits.1._score: override_score1 } + - set: { hits.hits.2._score: override_score2 } + + - do: + indices.put_mapping: + index: bbq_rescore_zero_hnsw + body: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_hnsw + rescore_vector: + oversample: 2 + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_zero_hnsw + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: default_rescore0 } + - set: { hits.hits.1._score: default_rescore1 } + - set: { hits.hits.2._score: default_rescore2 } + + - do: + indices.put_mapping: + index: bbq_rescore_zero_hnsw + body: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_hnsw + rescore_vector: + oversample: 0 + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_zero_hnsw + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 
1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $override_score0 } + - match: { hits.hits.0._score: $default_rescore0 } + - match: { hits.hits.1._score: $override_score1 } + - match: { hits.hits.1._score: $default_rescore1 } + - match: { hits.hits.2._score: $override_score2 } + - match: { hits.hits.2._score: $default_rescore2 } + +--- +"default oversample value": + - requires: + cluster_features: ["mapper.dense_vector.default_oversample_value_for_bbq"] + reason: "Needs default_oversample_value_for_bbq feature" + - do: + indices.get_mapping: + index: bbq_hnsw + + - match: { bbq_hnsw.mappings.properties.vector.index_options.rescore_vector.oversample: 3.0 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml new file mode 100644 index 0000000000000..82f059b05eb09 --- /dev/null +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat_bfloat16.yml @@ -0,0 +1,512 @@ +setup: + - requires: + cluster_features: "mapper.bbq_bfloat16" + reason: 'bfloat16 needs to be supported' + - do: + indices.create: + index: bbq_flat + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + 
vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_flat + + - do: + index: + index: bbq_flat + id: "1" + body: + vector: [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, + 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, + 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, + -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, + -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, + -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, + -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, + -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_flat + + - do: + index: + index: bbq_flat + id: "2" + body: + vector: [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, + -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, + -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, + -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, + -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] + # Flush in order to provoke a merge later + - do: + indices.flush: + index: bbq_flat + + - do: + index: + index: bbq_flat + id: "3" + body: + vector: [0.139, 0.178, -0.117, 0.399, 0.014, -0.139, 0.347, -0.33 , + 0.139, 0.34 , -0.052, -0.052, -0.249, 0.327, -0.288, 0.049, + 0.464, 0.338, 0.516, 0.247, -0.104, 0.259, -0.209, -0.246, + -0.11 , 0.323, 0.091, 0.442, -0.254, 0.195, -0.109, -0.058, + -0.279, 0.402, -0.107, 0.308, -0.273, 0.019, 0.082, 0.399, + -0.658, -0.03 , 0.276, 0.041, 0.187, -0.331, 0.165, 0.017, + 0.171, -0.203, -0.198, 0.115, -0.007, 0.337, -0.444, 0.615, + -0.657, 1.285, 0.2 , -0.062, 0.038, 0.089, -0.068, -0.058] + # Flush in 
order to provoke a merge later + - do: + indices.flush: + index: bbq_flat + + - do: + indices.forcemerge: + index: bbq_flat + max_num_segments: 1 +--- +"Test knn search": + - requires: + capabilities: + - method: POST + path: /_search + capabilities: [ optimized_scalar_quantization_bbq ] + test_runner_features: capabilities + reason: "BBQ scoring improved and changed with optimized_scalar_quantization_bbq" + - do: + search: + index: bbq_flat + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.1._id: "3" } + - match: { hits.hits.2._id: "2" } +--- +"Vector rescoring has same scoring as exact search for kNN section": + - requires: + reason: 'Quantized vector rescoring is required' + test_runner_features: [capabilities] + capabilities: + - method: GET + path: /_search + capabilities: [knn_quantized_vector_rescore_oversample] + - skip: + features: "headers" + + # Rescore + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_flat + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17, + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , 
-0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + rescore_vector: + oversample: 1.5 + + # Get rescoring scores - hit ordering may change depending on how things are distributed + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + # Exact knn via script score + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_flat + body: + query: + script_score: + query: { match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17, + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } + +--- +"Test bad parameters": + - do: + catch: bad_request + indices.create: + index: bad_bbq_flat + body: + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: true + index_options: + type: bbq_flat + m: 42 +--- +"Test bad raw vector size": + - do: + catch: bad_request + indices.create: + index: bad_bbq_flat + body: + mappings: + properties: + vector: + type: 
dense_vector + dims: 64 + index: true + index_options: + type: bbq_flat + raw_vector_size: 25 +--- +"Test few dimensions fail indexing": + # verify index creation fails + - do: + catch: bad_request + indices.create: + index: bad_bbq_flat + body: + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 42 + index: true + similarity: l2_norm + index_options: + type: bbq_flat + + # verify dynamic dimension fails + - do: + indices.create: + index: dynamic_dim_bbq_flat + body: + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + index: true + similarity: l2_norm + index_options: + type: bbq_flat + + # verify index fails for odd dim vector + - do: + catch: bad_request + index: + index: dynamic_dim_bbq_flat + body: + vector: [1.0, 2.0, 3.0, 4.0, 5.0] + + # verify that we can index an even dim vector after the odd dim vector failure + - do: + index: + index: dynamic_dim_bbq_flat + body: + vector: [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0] +--- +"Test index configured rescore vector": + - requires: + cluster_features: ["mapper.dense_vector.rescore_vector"] + reason: Needs rescore_vector feature + - skip: + features: "headers" + - do: + indices.create: + index: bbq_rescore_flat + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_flat + rescore_vector: + oversample: 1.5 + + - do: + bulk: + index: bbq_rescore_flat + refresh: true + body: | + { "index": {"_id": "1"}} + { "vector": [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, 0.176, 0.531, -0.375, 0.334, -0.046, 
0.078, -0.349, 0.272, 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] } + { "index": {"_id": "2"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + { "index": {"_id": "3"}} + { "vector": [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_flat + body: + knn: + field: vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, 
-0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + k: 3 + num_candidates: 3 + + - match: { hits.total: 3 } + - set: { hits.hits.0._score: rescore_score0 } + - set: { hits.hits.1._score: rescore_score1 } + - set: { hits.hits.2._score: rescore_score2 } + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + index: bbq_rescore_flat + body: + query: + script_score: + query: {match_all: {} } + script: + source: "double similarity = dotProduct(params.query_vector, 'vector'); return similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1" + params: + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + + # Compare scores as hit IDs may change depending on how things are distributed + - match: { hits.total: 3 } + - match: { hits.hits.0._score: $rescore_score0 } + - match: { hits.hits.1._score: $rescore_score1 } + - match: { hits.hits.2._score: $rescore_score2 } + +--- +"default oversample value": + - requires: + cluster_features: ["mapper.dense_vector.default_oversample_value_for_bbq"] + reason: "Needs default_oversample_value_for_bbq feature" + - do: + indices.get_mapping: + index: bbq_flat + + - match: { bbq_flat.mappings.properties.vector.index_options.rescore_vector.oversample: 3.0 } +--- +"Test nested queries": + - do: + indices.create: + index: bbq_flat_nested + body: + settings: + index: + number_of_shards: 1 + mappings: + properties: + name: + type: keyword + nested: + type: nested + properties: + paragraph_id: + type: keyword + 
vector: + type: dense_vector + element_type: bfloat16 + dims: 64 + index: true + similarity: max_inner_product + index_options: + type: bbq_flat + + - do: + index: + index: bbq_flat_nested + id: "1" + body: + nested: + - paragraph_id: "1" + vector: [ 0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, + 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, + 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, + -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, + -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, + -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, + -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, + -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45 ] + - paragraph_id: "2" + vector: [ 0.7, 0.2 , 0.205, 0.63 , 0.032, 0.201, 0.167, 0.313, + 0.176, 0.1, 0.375, 0.334, 0.046, 0.078, 0.349, 0.272, + 0.307, 0.083, 0.504, 0.255, 0.404, 0.289, 0.226, 0.132, + 0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , 0.265, + 0.285, 0.336, 0.272, 0.369, -0.282, 0.086, 0.132, 0.475, + 0.224, 0.203, 0.439, 0.064, 0.246, 0.396, 0.297, 0.242, + 0.224, 0.203, 0.439, 0.064, 0.246, 0.396, 0.297, 0.242, + 0.028, 0.321, 0.022, 0.009, 0.001 , 0.031, -0.533, 0.45] + - do: + index: + index: bbq_flat_nested + id: "2" + body: + nested: + - paragraph_id: 0 + vector: [ 0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, + -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, + -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, + -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, + -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27, -0.013 ] + - paragraph_id: 2 + vector: [ 0.196, 0.514, 0.039, 0.555, 0.042, 0.242, 0.463, -0.348, + -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 
0.278, 0.119, -0.057, 0.333, -0.289, 0.438, + -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, 0.138, + -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, + -0.602, 0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, 0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27, 0.013 ] + - paragraph_id: 3 + vector: [ 0.196, 0.514, 0.039, 0.555, 0.042, 0.242, 0.463, -0.348, + 0.08 , 0.442, -0.067, -0.05 , 0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, 0.057, 0.333, -0.289, -0.438, + -0.014, 0.361, -0.169, 0.292, 0.229, 0.123, 0.031, -0.138, + -0.139, 0.315, -0.216, 0.322, 0.445, -0.059, 0.071, 0.429, + -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, 0.209, -0.153, -0.27, -0.013 ] + + - do: + index: + index: bbq_flat_nested + id: "3" + body: + nested: + - paragraph_id: 0 + vector: [ 0.139, 0.178, -0.117, 0.399, 0.014, -0.139, 0.347, -0.33 , + 0.139, 0.34 , -0.052, -0.052, -0.249, 0.327, -0.288, 0.049, + 0.464, 0.338, 0.516, 0.247, -0.104, 0.259, -0.209, -0.246, + -0.11 , 0.323, 0.091, 0.442, -0.254, 0.195, -0.109, -0.058, + -0.279, 0.402, -0.107, 0.308, -0.273, 0.019, 0.082, 0.399, + -0.658, -0.03 , 0.276, 0.041, 0.187, -0.331, 0.165, 0.017, + 0.171, -0.203, -0.198, 0.115, -0.007, 0.337, -0.444, 0.615, + -0.657, 1.285, 0.2 , -0.062, 0.038, 0.089, -0.068, -0.058 ] + + - do: + indices.flush: + index: bbq_flat_nested + + - do: + indices.forcemerge: + index: bbq_flat_nested + max_num_segments: 1 + + - do: + search: + index: bbq_flat_nested + body: + query: + nested: + path: nested + query: + knn: + field: nested.vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, 
-0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + num_candidates: 3 + k: 2 + + - match: {hits.hits.0._id: "3"} + + - do: + search: + index: bbq_flat_nested + body: + knn: + field: nested.vector + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] + num_candidates: 3 + k: 2 + + - match: {hits.hits.0._id: "3"} diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index 6ac0ce7ba99a8..21d255aece93f 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -459,6 +459,8 @@ org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat, org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat, org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat, + org.elasticsearch.index.codec.vectors.es92.ES92BinaryQuantizedBFloat16VectorsFormat, + org.elasticsearch.index.codec.vectors.es92.ES92HnswBinaryQuantizedBFloat16VectorsFormat, org.elasticsearch.index.codec.vectors.IVFVectorsFormat; provides org.apache.lucene.codecs.Codec @@ -485,5 +487,6 @@ exports org.elasticsearch.index.codec.perfield; exports org.elasticsearch.index.codec.vectors to org.elasticsearch.test.knn; exports org.elasticsearch.index.codec.vectors.es818 to org.elasticsearch.test.knn; + exports org.elasticsearch.index.codec.vectors.es92 to 
org.elasticsearch.test.knn; exports org.elasticsearch.inference.telemetry; } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/MergeReaderWrapper.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/MergeReaderWrapper.java similarity index 93% rename from server/src/main/java/org/elasticsearch/index/codec/vectors/es818/MergeReaderWrapper.java rename to server/src/main/java/org/elasticsearch/index/codec/vectors/MergeReaderWrapper.java index e26784ecfdd96..afc2ca77dda2e 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/MergeReaderWrapper.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/MergeReaderWrapper.java @@ -7,7 +7,7 @@ * License v3.0 only", or the "Server Side Public License, v 1". */ -package org.elasticsearch.index.codec.vectors.es818; +package org.elasticsearch.index.codec.vectors; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.index.ByteVectorValues; @@ -23,12 +23,12 @@ import java.util.Collection; import java.util.Map; -class MergeReaderWrapper extends FlatVectorsReader { +public class MergeReaderWrapper extends FlatVectorsReader { private final FlatVectorsReader mainReader; private final FlatVectorsReader mergeReader; - protected MergeReaderWrapper(FlatVectorsReader mainReader, FlatVectorsReader mergeReader) { + public MergeReaderWrapper(FlatVectorsReader mainReader, FlatVectorsReader mergeReader) { super(mainReader.getFlatVectorScorer()); this.mainReader = mainReader; this.mergeReader = mergeReader; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/DirectIOLucene99FlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/DirectIOLucene99FlatVectorsFormat.java index 8e328b5c500ad..c907bac2a2e13 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/DirectIOLucene99FlatVectorsFormat.java +++ 
b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/DirectIOLucene99FlatVectorsFormat.java @@ -31,6 +31,7 @@ import org.apache.lucene.store.IOContext; import org.apache.lucene.store.MergeInfo; import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.index.codec.vectors.MergeReaderWrapper; import org.elasticsearch.index.store.FsDirectoryFactory; import java.io.IOException; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java index e59b9ba7e6535..2d3699d915f81 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java @@ -71,7 +71,7 @@ public class ES818BinaryQuantizedVectorsReader extends FlatVectorsReader { private final ES818BinaryFlatVectorsScorer vectorScorer; @SuppressWarnings("this-escape") - ES818BinaryQuantizedVectorsReader( + public ES818BinaryQuantizedVectorsReader( SegmentReadState state, FlatVectorsReader rawVectorsReader, ES818BinaryFlatVectorsScorer vectorsScorer diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java index 22520567f2954..abceb15f62ad5 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java @@ -85,7 +85,7 @@ public class ES818BinaryQuantizedVectorsWriter extends FlatVectorsWriter { * @param vectorsScorer the scorer to use for scoring vectors */ @SuppressWarnings("this-escape") - protected ES818BinaryQuantizedVectorsWriter( + public 
ES818BinaryQuantizedVectorsWriter( ES818BinaryFlatVectorsScorer vectorsScorer, FlatVectorsWriter rawVectorDelegate, SegmentWriteState state diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/BFloat16.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/BFloat16.java new file mode 100644 index 0000000000000..639ea6e8eed5a --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/BFloat16.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors.es92; + +import org.apache.lucene.util.BitUtil; + +import java.nio.ByteOrder; +import java.nio.ShortBuffer; + +class BFloat16 { + + public static final int BYTES = Short.BYTES; + + public static short floatToBFloat16(float f) { + // this rounds towards 0 + // zero - zero exp, zero fraction + // denormal - zero exp, non-zero fraction + // infinity - all-1 exp, zero fraction + // NaN - all-1 exp, non-zero fraction + // the Float.NaN constant is 0x7fc0_0000, so this won't turn the most common NaN values into + // infinities + return (short) (Float.floatToIntBits(f) >>> 16); + } + + public static float bFloat16ToFloat(short bf) { + return Float.intBitsToFloat(bf << 16); + } + + public static void floatToBFloat16(float[] floats, ShortBuffer bFloats) { + assert bFloats.remaining() == floats.length; + assert bFloats.order() == ByteOrder.LITTLE_ENDIAN; + for (float v : floats) { + bFloats.put(floatToBFloat16(v)); + } + } + + public static void bFloat16ToFloat(byte[] bfBytes, float[] 
floats) { + assert floats.length * 2 == bfBytes.length; + for (int i = 0; i < floats.length; i++) { + floats[i] = bFloat16ToFloat((short) BitUtil.VH_LE_SHORT.get(bfBytes, i * 2)); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsFormat.java new file mode 100644 index 0000000000000..7097d71ecc2fe --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsFormat.java @@ -0,0 +1,129 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es92; + +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.store.FlushInfo; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.MergeInfo; +import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.index.codec.vectors.MergeReaderWrapper; +import org.elasticsearch.index.codec.vectors.es818.DirectIOHint; +import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; +import org.elasticsearch.index.store.FsDirectoryFactory; + +import java.io.IOException; +import java.util.Set; + +public final class ES92BFloat16FlatVectorsFormat extends FlatVectorsFormat { + + static final String NAME = "ES92BFloat16FlatVectorsFormat"; + static final String META_CODEC_NAME = "ES92BFloat16FlatVectorsFormatMeta"; + static final String VECTOR_DATA_CODEC_NAME = "ES92BFloat16FlatVectorsFormatData"; + static final String META_EXTENSION = "vemf"; + static final String VECTOR_DATA_EXTENSION = "vec"; + + public static final int VERSION_START = 0; + public static final int VERSION_CURRENT = VERSION_START; + + static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16; + private final FlatVectorsScorer vectorsScorer; + + public ES92BFloat16FlatVectorsFormat(FlatVectorsScorer vectorsScorer) { + super(NAME); + this.vectorsScorer = vectorsScorer; + } + + @Override + public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new ES92BFloat16FlatVectorsWriter(state, vectorsScorer); + } + + static boolean shouldUseDirectIO(SegmentReadState state) { + return ES818BinaryQuantizedVectorsFormat.USE_DIRECT_IO && FsDirectoryFactory.isHybridFs(state.directory); + } + + 
@Override + public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { + if (shouldUseDirectIO(state) && state.context.context() == IOContext.Context.DEFAULT) { + // only override the context for the random-access use case + SegmentReadState directIOState = new SegmentReadState( + state.directory, + state.segmentInfo, + state.fieldInfos, + new DirectIOContext(state.context.hints()), + state.segmentSuffix + ); + // Use mmap for merges and direct I/O for searches. + // TODO: Open the mmap file with sequential access instead of random (current behavior). + return new MergeReaderWrapper( + new ES92BFloat16FlatVectorsReader(directIOState, vectorsScorer), + new ES92BFloat16FlatVectorsReader(state, vectorsScorer) + ); + } else { + return new ES92BFloat16FlatVectorsReader(state, vectorsScorer); + } + } + + @Override + public String toString() { + return "ES92BFloat16FlatVectorsFormat(" + "vectorsScorer=" + vectorsScorer + ')'; + } + + static class DirectIOContext implements IOContext { + + final Set<FileOpenHint> hints; + + DirectIOContext(Set<FileOpenHint> hints) { + // always add DirectIOHint to the hints given + this.hints = Sets.union(hints, Set.of(DirectIOHint.INSTANCE)); + } + + @Override + public Context context() { + return Context.DEFAULT; + } + + @Override + public MergeInfo mergeInfo() { + return null; + } + + @Override + public FlushInfo flushInfo() { + return null; + } + + @Override + public Set<FileOpenHint> hints() { + return hints; + } + + @Override + public IOContext withHints(FileOpenHint... 
hints) { + return new DirectIOContext(Set.of(hints)); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsReader.java new file mode 100644 index 0000000000000..99d7302e4bc56 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsReader.java @@ -0,0 +1,328 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es92; + +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.CorruptIndexException; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.hppc.IntObjectHashMap; +import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.store.DataAccessHint; +import org.apache.lucene.store.FileDataHint; +import org.apache.lucene.store.FileTypeHint; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.ReadAdvice; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.elasticsearch.core.IOUtils; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Map; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSimilarityFunction; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding; + +public final class ES92BFloat16FlatVectorsReader extends FlatVectorsReader { + + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(ES92BFloat16FlatVectorsReader.class); + + private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>(); + private final IndexInput vectorData; + private final FieldInfos fieldInfos; + + public ES92BFloat16FlatVectorsReader(SegmentReadState state, FlatVectorsScorer scorer) 
throws IOException { + super(scorer); + int versionMeta = readMetadata(state); + this.fieldInfos = state.fieldInfos; + boolean success = false; + try { + vectorData = openDataInput( + state, + versionMeta, + ES92BFloat16FlatVectorsFormat.VECTOR_DATA_EXTENSION, + ES92BFloat16FlatVectorsFormat.VECTOR_DATA_CODEC_NAME, + // Flat formats are used to randomly access vectors from their node ID that is stored + // in the HNSW graph. + state.context.withHints(FileTypeHint.DATA, FileDataHint.KNN_VECTORS, DataAccessHint.RANDOM) + ); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + private int readMetadata(SegmentReadState state) throws IOException { + String metaFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + ES92BFloat16FlatVectorsFormat.META_EXTENSION + ); + int versionMeta = -1; + try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) { + Throwable priorE = null; + try { + versionMeta = CodecUtil.checkIndexHeader( + meta, + ES92BFloat16FlatVectorsFormat.META_CODEC_NAME, + ES92BFloat16FlatVectorsFormat.VERSION_START, + ES92BFloat16FlatVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + readFields(meta, state.fieldInfos); + } catch (Throwable exception) { + priorE = exception; + } finally { + CodecUtil.checkFooter(meta, priorE); + } + } + return versionMeta; + } + + private static IndexInput openDataInput( + SegmentReadState state, + int versionMeta, + String fileExtension, + String codecName, + IOContext context + ) throws IOException { + String fileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension); + IndexInput in = state.directory.openInput(fileName, context); + boolean success = false; + try { + int versionVectorData = CodecUtil.checkIndexHeader( + in, + codecName, + ES92BFloat16FlatVectorsFormat.VERSION_START, + 
ES92BFloat16FlatVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + if (versionMeta != versionVectorData) { + throw new CorruptIndexException( + "Format versions mismatch: meta=" + versionMeta + ", " + codecName + "=" + versionVectorData, + in + ); + } + CodecUtil.retrieveChecksum(in); + success = true; + return in; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(in); + } + } + } + + private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException { + for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) { + FieldInfo info = infos.fieldInfo(fieldNumber); + if (info == null) { + throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); + } + FieldEntry fieldEntry = FieldEntry.create(meta, info); + fields.put(info.number, fieldEntry); + } + } + + @Override + public long ramBytesUsed() { + return ES92BFloat16FlatVectorsReader.SHALLOW_SIZE + fields.ramBytesUsed(); + } + + @Override + public Map<String, Long> getOffHeapByteSize(FieldInfo fieldInfo) { + final FieldEntry entry = getFieldEntryOrThrow(fieldInfo.name); + return Map.of(ES92BFloat16FlatVectorsFormat.VECTOR_DATA_EXTENSION, entry.vectorDataLength()); + } + + @Override + public void checkIntegrity() throws IOException { + CodecUtil.checksumEntireFile(vectorData); + } + + @Override + public FlatVectorsReader getMergeInstance() { + try { + // Update the read advice since vectors are guaranteed to be accessed sequentially for merge + this.vectorData.updateReadAdvice(ReadAdvice.SEQUENTIAL); + return this; + } catch (IOException exception) { + throw new UncheckedIOException(exception); + } + } + + private FieldEntry getFieldEntryOrThrow(String field) { + final FieldInfo info = fieldInfos.fieldInfo(field); + final FieldEntry entry; + if (info == null || (entry = fields.get(info.number)) == null) { + throw new IllegalArgumentException("field=\"" + field + "\" not found"); + } + return entry; + 
} + + private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) { + final FieldEntry fieldEntry = getFieldEntryOrThrow(field); + if (fieldEntry.vectorEncoding != expectedEncoding) { + throw new IllegalArgumentException( + "field=\"" + field + "\" is encoded as: " + fieldEntry.vectorEncoding + " expected: " + expectedEncoding + ); + } + return fieldEntry; + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32); + return OffHeapBFloat16VectorValues.load( + fieldEntry.similarityFunction, + vectorScorer, + fieldEntry.ordToDoc, + fieldEntry.vectorEncoding, + fieldEntry.dimension, + fieldEntry.size, + fieldEntry.vectorDataOffset, + fieldEntry.vectorDataLength, + vectorData + ); + } + + @Override + public ByteVectorValues getByteVectorValues(String field) throws IOException { + throw new IllegalStateException(field + " only supports float vectors"); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { + final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32); + return vectorScorer.getRandomVectorScorer( + fieldEntry.similarityFunction, + OffHeapBFloat16VectorValues.load( + fieldEntry.similarityFunction, + vectorScorer, + fieldEntry.ordToDoc, + fieldEntry.vectorEncoding, + fieldEntry.dimension, + fieldEntry.size, + fieldEntry.vectorDataOffset, + fieldEntry.vectorDataLength, + vectorData + ), + target + ); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException { + throw new UnsupportedOperationException(field + " only supports float vectors"); + } + + @Override + public void finishMerge() throws IOException { + // This makes sure that the access pattern hint is reverted back since HNSW implementation + // needs it + this.vectorData.updateReadAdvice(ReadAdvice.RANDOM); + } + + @Override + 
public void close() throws IOException { + IOUtils.close(vectorData); + } + + private record FieldEntry( + VectorSimilarityFunction similarityFunction, + VectorEncoding vectorEncoding, + long vectorDataOffset, + long vectorDataLength, + int dimension, + int size, + OrdToDocDISIReaderConfiguration ordToDoc, + FieldInfo info + ) { + + FieldEntry { + if (vectorEncoding == VectorEncoding.BYTE) { + throw new IllegalStateException( + "Incorrect vector encoding for field=\"" + info.name + "\"; " + vectorEncoding + " not supported" + ); + } + + if (similarityFunction != info.getVectorSimilarityFunction()) { + throw new IllegalStateException( + "Inconsistent vector similarity function for field=\"" + + info.name + + "\"; " + + similarityFunction + + " != " + + info.getVectorSimilarityFunction() + ); + } + int infoVectorDimension = info.getVectorDimension(); + if (infoVectorDimension != dimension) { + throw new IllegalStateException( + "Inconsistent vector dimension for field=\"" + info.name + "\"; " + infoVectorDimension + " != " + dimension + ); + } + + int byteSize = BFloat16.BYTES; + long vectorBytes = Math.multiplyExact((long) infoVectorDimension, byteSize); + long numBytes = Math.multiplyExact(vectorBytes, size); + if (numBytes != vectorDataLength) { + throw new IllegalStateException( + "Vector data length " + + vectorDataLength + + " not matching size=" + + size + + " * dim=" + + dimension + + " * byteSize=" + + byteSize + + " = " + + numBytes + ); + } + } + + static FieldEntry create(IndexInput input, FieldInfo info) throws IOException { + final VectorEncoding vectorEncoding = readVectorEncoding(input); + final VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); + final var vectorDataOffset = input.readVLong(); + final var vectorDataLength = input.readVLong(); + final var dimension = input.readVInt(); + final var size = input.readInt(); + final var ordToDoc = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size); + return new 
FieldEntry(similarityFunction, vectorEncoding, vectorDataOffset, vectorDataLength, dimension, size, ordToDoc, info); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsWriter.java new file mode 100644 index 0000000000000..b40765b051538 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BFloat16FlatVectorsWriter.java @@ -0,0 +1,433 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es92; + +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.store.DataAccessHint; +import org.apache.lucene.store.FileDataHint; +import org.apache.lucene.store.FileTypeHint; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; +import org.elasticsearch.core.IOUtils; + +import java.io.Closeable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.List; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.elasticsearch.index.codec.vectors.es92.ES92BFloat16FlatVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; + +public final class ES92BFloat16FlatVectorsWriter extends FlatVectorsWriter { + + private static final long SHALLOW_RAM_BYTES_USED = 
RamUsageEstimator.shallowSizeOfInstance(ES92BFloat16FlatVectorsWriter.class); + + private final SegmentWriteState segmentWriteState; + private final IndexOutput meta, vectorData; + + private final List> fields = new ArrayList<>(); + private boolean finished; + + public ES92BFloat16FlatVectorsWriter(SegmentWriteState state, FlatVectorsScorer scorer) throws IOException { + super(scorer); + segmentWriteState = state; + String metaFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + ES92BFloat16FlatVectorsFormat.META_EXTENSION + ); + + String vectorDataFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + ES92BFloat16FlatVectorsFormat.VECTOR_DATA_EXTENSION + ); + + boolean success = false; + try { + meta = state.directory.createOutput(metaFileName, state.context); + vectorData = state.directory.createOutput(vectorDataFileName, state.context); + + CodecUtil.writeIndexHeader( + meta, + ES92BFloat16FlatVectorsFormat.META_CODEC_NAME, + ES92BFloat16FlatVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + CodecUtil.writeIndexHeader( + vectorData, + ES92BFloat16FlatVectorsFormat.VECTOR_DATA_CODEC_NAME, + ES92BFloat16FlatVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + @Override + public FlatFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { + FieldWriter newField = FieldWriter.create(fieldInfo); + fields.add(newField); + return newField; + } + + @Override + public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { + for (FieldWriter field : fields) { + if (sortMap == null) { + writeField(field, maxDoc); + } else { + writeSortingField(field, maxDoc, sortMap); + } + field.finish(); + } + } + + @Override + public void finish() throws IOException { + if (finished) { + throw 
new IllegalStateException("already finished"); + } + finished = true; + if (meta != null) { + // write end of fields marker + meta.writeInt(-1); + CodecUtil.writeFooter(meta); + } + if (vectorData != null) { + CodecUtil.writeFooter(vectorData); + } + } + + @Override + public long ramBytesUsed() { + long total = SHALLOW_RAM_BYTES_USED; + for (FieldWriter field : fields) { + total += field.ramBytesUsed(); + } + return total; + } + + private void writeField(FieldWriter fieldData, int maxDoc) throws IOException { + // write vector values + long vectorDataOffset = vectorData.alignFilePointer(BFloat16.BYTES); + switch (fieldData.fieldInfo.getVectorEncoding()) { + case FLOAT32 -> writeBFloat16Vectors(fieldData); + case BYTE -> throw new IllegalStateException( + "Incorrect encoding for field " + fieldData.fieldInfo.name + ": " + VectorEncoding.BYTE + ); + } + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + + writeMeta(fieldData.fieldInfo, maxDoc, vectorDataOffset, vectorDataLength, fieldData.docsWithField); + } + + private void writeBFloat16Vectors(FieldWriter fieldData) throws IOException { + final ByteBuffer buffer = ByteBuffer.allocate(fieldData.dim * BFloat16.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (Object v : fieldData.vectors) { + BFloat16.floatToBFloat16((float[]) v, buffer.asShortBuffer()); + vectorData.writeBytes(buffer.array(), buffer.array().length); + } + } + + private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocMap sortMap) throws IOException { + final int[] ordMap = new int[fieldData.docsWithField.cardinality()]; // new ord to old ord + + DocsWithFieldSet newDocsWithField = new DocsWithFieldSet(); + mapOldOrdToNewOrd(fieldData.docsWithField, sortMap, null, ordMap, newDocsWithField); + + // write vector values + long vectorDataOffset = switch (fieldData.fieldInfo.getVectorEncoding()) { + case FLOAT32 -> writeSortedBFloat16Vectors(fieldData, ordMap); + case BYTE -> throw new IllegalStateException( + 
"Incorrect encoding for field " + fieldData.fieldInfo.name + ": " + VectorEncoding.BYTE + ); + }; + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + + writeMeta(fieldData.fieldInfo, maxDoc, vectorDataOffset, vectorDataLength, newDocsWithField); + } + + private long writeSortedBFloat16Vectors(FieldWriter fieldData, int[] ordMap) throws IOException { + long vectorDataOffset = vectorData.alignFilePointer(BFloat16.BYTES); + final ByteBuffer buffer = ByteBuffer.allocate(fieldData.dim * BFloat16.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (int ordinal : ordMap) { + float[] vector = (float[]) fieldData.vectors.get(ordinal); + BFloat16.floatToBFloat16(vector, buffer.asShortBuffer()); + vectorData.writeBytes(buffer.array(), buffer.array().length); + } + return vectorDataOffset; + } + + @Override + public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + // Since we know we will not be searching for additional indexing, we can just write the + // the vectors directly to the new segment. 
+ long vectorDataOffset = vectorData.alignFilePointer(BFloat16.BYTES); + // No need to use temporary file as we don't have to re-open for reading + DocsWithFieldSet docsWithField = switch (fieldInfo.getVectorEncoding()) { + case FLOAT32 -> writeVectorData(vectorData, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)); + case BYTE -> throw new IllegalStateException("Incorrect encoding for field " + fieldInfo.name + ": " + VectorEncoding.BYTE); + }; + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + writeMeta(fieldInfo, segmentWriteState.segmentInfo.maxDoc(), vectorDataOffset, vectorDataLength, docsWithField); + } + + @Override + public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + long vectorDataOffset = vectorData.alignFilePointer(BFloat16.BYTES); + IndexOutput tempVectorData = segmentWriteState.directory.createTempOutput(vectorData.getName(), "temp", segmentWriteState.context); + IndexInput vectorDataInput = null; + boolean success = false; + try { + // write the vector data to a temporary file + DocsWithFieldSet docsWithField = switch (fieldInfo.getVectorEncoding()) { + case FLOAT32 -> writeVectorData(tempVectorData, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)); + case BYTE -> throw new UnsupportedOperationException("ES92BFloat16FlatVectorsWriter only supports float vectors"); + }; + CodecUtil.writeFooter(tempVectorData); + IOUtils.close(tempVectorData); + + // This temp file will be accessed in a random-access fashion to construct the HNSW graph. + // Note: don't use the context from the state, which is a flush/merge context, not expecting + // to perform random reads. 
+ vectorDataInput = segmentWriteState.directory.openInput( + tempVectorData.getName(), + IOContext.DEFAULT.withHints(FileTypeHint.DATA, FileDataHint.KNN_VECTORS, DataAccessHint.RANDOM) + ); + // copy the temporary file vectors to the actual data file + vectorData.copyBytes(vectorDataInput, vectorDataInput.length() - CodecUtil.footerLength()); + CodecUtil.retrieveChecksum(vectorDataInput); + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + writeMeta(fieldInfo, segmentWriteState.segmentInfo.maxDoc(), vectorDataOffset, vectorDataLength, docsWithField); + success = true; + final IndexInput finalVectorDataInput = vectorDataInput; + final RandomVectorScorerSupplier randomVectorScorerSupplier = vectorsScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), + new OffHeapFloatVectorValues.DenseOffHeapVectorValues( + fieldInfo.getVectorDimension(), + docsWithField.cardinality(), + finalVectorDataInput, + fieldInfo.getVectorDimension() * BFloat16.BYTES, + vectorsScorer, + fieldInfo.getVectorSimilarityFunction() + ) + ); + return new FlatCloseableRandomVectorScorerSupplier(() -> { + IOUtils.close(finalVectorDataInput); + segmentWriteState.directory.deleteFile(tempVectorData.getName()); + }, docsWithField.cardinality(), randomVectorScorerSupplier); + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(vectorDataInput, tempVectorData); + try { + segmentWriteState.directory.deleteFile(tempVectorData.getName()); + } catch (Exception e) { + // ignore + } + } + } + } + + private void writeMeta(FieldInfo field, int maxDoc, long vectorDataOffset, long vectorDataLength, DocsWithFieldSet docsWithField) + throws IOException { + meta.writeInt(field.number); + meta.writeInt(field.getVectorEncoding().ordinal()); + meta.writeInt(field.getVectorSimilarityFunction().ordinal()); + meta.writeVLong(vectorDataOffset); + meta.writeVLong(vectorDataLength); + meta.writeVInt(field.getVectorDimension()); + + // write docIDs + int 
count = docsWithField.cardinality(); + meta.writeInt(count); + OrdToDocDISIReaderConfiguration.writeStoredMeta(DIRECT_MONOTONIC_BLOCK_SHIFT, meta, vectorData, count, maxDoc, docsWithField); + } + + /** + * Writes the vector values to the output and returns a set of documents that contains vectors. + */ + private static DocsWithFieldSet writeVectorData(IndexOutput output, FloatVectorValues floatVectorValues) throws IOException { + DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + ByteBuffer buffer = ByteBuffer.allocate(floatVectorValues.dimension() * BFloat16.BYTES).order(ByteOrder.LITTLE_ENDIAN); + KnnVectorValues.DocIndexIterator iter = floatVectorValues.iterator(); + for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { + // write vector + float[] value = floatVectorValues.vectorValue(iter.index()); + BFloat16.floatToBFloat16(value, buffer.asShortBuffer()); + output.writeBytes(buffer.array(), buffer.limit()); + docsWithField.add(docV); + } + return docsWithField; + } + + @Override + public void close() throws IOException { + IOUtils.close(meta, vectorData); + } + + private abstract static class FieldWriter extends FlatFieldVectorsWriter { + private static final long SHALLOW_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(FieldWriter.class); + private final FieldInfo fieldInfo; + private final int dim; + private final DocsWithFieldSet docsWithField; + private final List vectors; + private boolean finished; + + private int lastDocID = -1; + + static FieldWriter create(FieldInfo fieldInfo) { + int dim = fieldInfo.getVectorDimension(); + return switch (fieldInfo.getVectorEncoding()) { + case FLOAT32 -> new ES92BFloat16FlatVectorsWriter.FieldWriter(fieldInfo) { + @Override + public float[] copyValue(float[] value) { + return ArrayUtil.copyOfSubArray(value, 0, dim); + } + }; + case BYTE -> throw new IllegalStateException("Incorrect encoding for field " + fieldInfo.name + ": " + VectorEncoding.BYTE); + }; + } + + 
FieldWriter(FieldInfo fieldInfo) { + super(); + this.fieldInfo = fieldInfo; + this.dim = fieldInfo.getVectorDimension(); + this.docsWithField = new DocsWithFieldSet(); + vectors = new ArrayList<>(); + } + + @Override + public void addValue(int docID, T vectorValue) throws IOException { + if (finished) { + throw new IllegalStateException("already finished, cannot add more values"); + } + if (docID == lastDocID) { + throw new IllegalArgumentException( + "VectorValuesField \"" + + fieldInfo.name + + "\" appears more than once in this document (only one value is allowed per field)" + ); + } + assert docID > lastDocID; + T copy = copyValue(vectorValue); + docsWithField.add(docID); + vectors.add(copy); + lastDocID = docID; + } + + @Override + public long ramBytesUsed() { + long size = SHALLOW_RAM_BYTES_USED; + if (vectors.size() == 0) return size; + + int byteSize = fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32 + ? BFloat16.BYTES + : fieldInfo.getVectorEncoding().byteSize; + + return size + docsWithField.ramBytesUsed() + (long) vectors.size() * (RamUsageEstimator.NUM_BYTES_OBJECT_REF + + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER) + (long) vectors.size() * fieldInfo.getVectorDimension() * byteSize; + } + + @Override + public List getVectors() { + return vectors; + } + + @Override + public DocsWithFieldSet getDocsWithFieldSet() { + return docsWithField; + } + + @Override + public void finish() throws IOException { + if (finished) { + return; + } + this.finished = true; + } + + @Override + public boolean isFinished() { + return finished; + } + } + + static final class FlatCloseableRandomVectorScorerSupplier implements CloseableRandomVectorScorerSupplier { + + private final RandomVectorScorerSupplier supplier; + private final Closeable onClose; + private final int numVectors; + + FlatCloseableRandomVectorScorerSupplier(Closeable onClose, int numVectors, RandomVectorScorerSupplier supplier) { + this.onClose = onClose; + this.supplier = supplier; + this.numVectors = 
numVectors; + } + + @Override + public UpdateableRandomVectorScorer scorer() throws IOException { + return supplier.scorer(); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return supplier.copy(); + } + + @Override + public void close() throws IOException { + onClose.close(); + } + + @Override + public int totalVectorCount() { + return numVectors; + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormat.java new file mode 100644 index 0000000000000..c029cecb2200b --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormat.java @@ -0,0 +1,135 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es92; + +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; +import org.elasticsearch.index.codec.vectors.es818.ES818BinaryFlatVectorsScorer; +import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsReader; +import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsWriter; + +import java.io.IOException; + +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT; + +/** + * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 + * Codec for encoding/decoding binary quantized vectors The binary quantization format used here + * is a per-vector optimized scalar quantization. Also see {@link + * OptimizedScalarQuantizer}. Some of key features are: + * + *
+ * <ul>
+ *   <li>Estimating the distance between two vectors using their centroid normalized distance. This
+ *       requires some additional corrective factors, but allows for centroid normalization to occur.
+ *   <li>Optimized scalar quantization to bit level of centroid normalized vectors.
+ *   <li>Asymmetric quantization of vectors, where query vectors are quantized to half-byte
+ *       precision (normalized to the centroid) and then compared directly against the single bit
+ *       quantized vectors in the index.
+ *   <li>Transforming the half-byte quantized query vectors in such a way that the comparison with
+ *       single bit vectors can be done with bit arithmetic.
+ * </ul>
+ * + * The format is stored in two files: + * + *

+ * <h2>.veb (vector data) file</h2>
+ * + *

+ * <p>Stores the binary quantized vectors in a flat format. Additionally, it stores each vector's
+ *     corrective factors. At the end of the file, additional information is stored for vector ordinal
+ *     to centroid ordinal mapping and sparse vector information.
+ *

+ * <ul>
+ *   <li>For each vector:
+ *       <ul>
+ *         <li><b>[byte]</b> the binary quantized values, each byte holds 8 bits.
+ *         <li><b>[float]</b> the optimized quantiles and an additional similarity dependent corrective factor.
+ *         <li><b>short</b> the sum of the quantized components
+ *       </ul>
+ *   <li>After the vectors, sparse vector information keeping track of monotonic blocks.
+ * </ul>
+ *

+ * <h2>.vemb (vector metadata) file</h2>
+ * + *

+ * <p>Stores the metadata for the vectors. This includes the number of vectors, the number of
+ *     dimensions, and file offset information.
+ *

+ * <ul>
+ *   <li><b>int</b> the field number
+ *   <li><b>int</b> the vector encoding ordinal
+ *   <li><b>int</b> the vector similarity ordinal
+ *   <li><b>vint</b> the vector dimensions
+ *   <li><b>vlong</b> the offset to the vector data in the .veb file
+ *   <li><b>vlong</b> the length of the vector data in the .veb file
+ *   <li><b>vint</b> the number of vectors
+ *   <li><b>[float]</b> the centroid
+ *   <li><b>float</b> the centroid square magnitude
+ *   <li>The sparse vector information, if required, mapping vector ordinal to doc ID
+ * </ul>
+ */ +public class ES92BinaryQuantizedBFloat16VectorsFormat extends FlatVectorsFormat { + + public static final String BINARIZED_VECTOR_COMPONENT = "BVEC"; + public static final String NAME = "ES92BinaryQuantizedBFloat16VectorsFormat"; + + static final int VERSION_START = 0; + static final int VERSION_CURRENT = VERSION_START; + static final String META_CODEC_NAME = "ES92BinaryQuantizedBFloat16VectorsFormatMeta"; + static final String VECTOR_DATA_CODEC_NAME = "ES92BinaryQuantizedBFloat16VectorsFormatData"; + static final String META_EXTENSION = "vemb"; + static final String VECTOR_DATA_EXTENSION = "veb"; + static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16; + + private static final FlatVectorsFormat rawVectorFormat = new ES92BFloat16FlatVectorsFormat( + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + ); + + private static final ES818BinaryFlatVectorsScorer scorer = new ES818BinaryFlatVectorsScorer( + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + ); + + /** Creates a new instance with the default number of vectors per cluster. 
*/ + public ES92BinaryQuantizedBFloat16VectorsFormat() { + super(NAME); + } + + @Override + public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new ES818BinaryQuantizedVectorsWriter(scorer, rawVectorFormat.fieldsWriter(state), state); + } + + @Override + public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new ES818BinaryQuantizedVectorsReader(state, rawVectorFormat.fieldsReader(state), scorer); + } + + @Override + public int getMaxDimensions(String fieldName) { + return MAX_DIMS_COUNT; + } + + @Override + public String toString() { + return "ES92BinaryQuantizedBFloat16VectorsFormat(name=" + NAME + ", flatVectorScorer=" + scorer + ")"; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormat.java new file mode 100644 index 0000000000000..38b54d1520f9f --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormat.java @@ -0,0 +1,145 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es92; + +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.search.TaskExecutor; +import org.apache.lucene.util.hnsw.HnswGraph; + +import java.io.IOException; +import java.util.concurrent.ExecutorService; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT; + +/** + * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 + */ +public class ES92HnswBinaryQuantizedBFloat16VectorsFormat extends KnnVectorsFormat { + + public static final String NAME = "ES92HnswBinaryQuantizedBFloat16VectorsFormat"; + + /** + * Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to + * {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details. 
+ */ + private final int maxConn; + + /** + * The number of candidate neighbors to track while searching the graph for each newly inserted + * node. Defaults to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link HnswGraph} + * for details. + */ + private final int beamWidth; + + /** The format for storing, reading, merging vectors on disk */ + private static final FlatVectorsFormat flatVectorsFormat = new ES92BinaryQuantizedBFloat16VectorsFormat(); + + private final int numMergeWorkers; + private final TaskExecutor mergeExec; + + /** Constructs a format using default graph construction parameters */ + public ES92HnswBinaryQuantizedBFloat16VectorsFormat() { + this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null); + } + + /** + * Constructs a format using the given graph construction parameters. + * + * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. + */ + public ES92HnswBinaryQuantizedBFloat16VectorsFormat(int maxConn, int beamWidth) { + this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null); + } + + /** + * Constructs a format using the given graph construction parameters and scalar quantization. + * + * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. + * @param numMergeWorkers number of workers (threads) that will be used when doing merge. 
If + * larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec + * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are + * generated by this format to do the merge + */ + public ES92HnswBinaryQuantizedBFloat16VectorsFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) { + super(NAME); + if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) { + throw new IllegalArgumentException( + "maxConn must be positive and less than or equal to " + MAXIMUM_MAX_CONN + "; maxConn=" + maxConn + ); + } + if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) { + throw new IllegalArgumentException( + "beamWidth must be positive and less than or equal to " + MAXIMUM_BEAM_WIDTH + "; beamWidth=" + beamWidth + ); + } + this.maxConn = maxConn; + this.beamWidth = beamWidth; + if (numMergeWorkers == 1 && mergeExec != null) { + throw new IllegalArgumentException("No executor service is needed as we'll use single thread to merge"); + } + this.numMergeWorkers = numMergeWorkers; + if (mergeExec != null) { + this.mergeExec = new TaskExecutor(mergeExec); + } else { + this.mergeExec = null; + } + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), numMergeWorkers, mergeExec); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state)); + } + + @Override + public int getMaxDimensions(String fieldName) { + return MAX_DIMS_COUNT; + } + + @Override + public String toString() { + return "ES92HnswBinaryQuantizedBFloat16VectorsFormat(name=ES92HnswBinaryQuantizedBFloat16VectorsFormat, maxConn=" + + maxConn + + ", beamWidth=" + + beamWidth + + ", flatVectorFormat=" + + flatVectorsFormat + + ")"; + } +} diff --git 
a/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/OffHeapBFloat16VectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/OffHeapBFloat16VectorValues.java new file mode 100644 index 0000000000000..f793c94d47a54 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es92/OffHeapBFloat16VectorValues.java @@ -0,0 +1,311 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es92; + +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene90.IndexedDISI; +import org.apache.lucene.codecs.lucene95.HasIndexSlice; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.packed.DirectMonotonicReader; + +import java.io.IOException; + +abstract class OffHeapBFloat16VectorValues extends FloatVectorValues implements HasIndexSlice { + + protected final int dimension; + protected final int size; + protected final IndexInput slice; + protected final int byteSize; + protected int lastOrd = -1; + protected final byte[] bfloatBytes; + protected final float[] value; + protected final VectorSimilarityFunction similarityFunction; + protected final FlatVectorsScorer flatVectorsScorer; + + OffHeapBFloat16VectorValues( + int dimension, + int size, + IndexInput slice, + int byteSize, + FlatVectorsScorer flatVectorsScorer, + VectorSimilarityFunction similarityFunction + ) { + this.dimension = dimension; + this.size = size; + this.slice = slice; + this.byteSize = byteSize; + this.similarityFunction = similarityFunction; + this.flatVectorsScorer = flatVectorsScorer; + bfloatBytes = new byte[dimension * 2]; + value = new float[dimension]; + } + + @Override + public int dimension() { + return dimension; + } + + @Override + public int size() { + return size; + } + + @Override + public IndexInput getSlice() { + return slice; + } + + @Override + public float[] vectorValue(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + 
return value; + } + slice.seek((long) targetOrd * byteSize); + // no readShorts() method + slice.readBytes(bfloatBytes, 0, bfloatBytes.length); + BFloat16.bFloat16ToFloat(bfloatBytes, value); + lastOrd = targetOrd; + return value; + } + + static OffHeapBFloat16VectorValues load( + VectorSimilarityFunction vectorSimilarityFunction, + FlatVectorsScorer flatVectorsScorer, + OrdToDocDISIReaderConfiguration configuration, + VectorEncoding vectorEncoding, + int dimension, + int size, + long vectorDataOffset, + long vectorDataLength, + IndexInput vectorData + ) throws IOException { + if (configuration.isEmpty() || vectorEncoding != VectorEncoding.FLOAT32) { + return new EmptyOffHeapVectorValues(dimension, flatVectorsScorer, vectorSimilarityFunction); + } + IndexInput bytesSlice = vectorData.slice("vector-data", vectorDataOffset, vectorDataLength); + int byteSize = dimension * BFloat16.BYTES; + if (configuration.isDense()) { + return new DenseOffHeapVectorValues(dimension, size, bytesSlice, byteSize, flatVectorsScorer, vectorSimilarityFunction); + } else { + return new SparseOffHeapVectorValues( + configuration, + vectorData, + bytesSlice, + dimension, + size, + byteSize, + flatVectorsScorer, + vectorSimilarityFunction + ); + } + } + + /** + * Dense vector values that are stored off-heap. This is the most common case when every doc has a + * vector. 
+ */ + static class DenseOffHeapVectorValues extends OffHeapBFloat16VectorValues { + + DenseOffHeapVectorValues( + int dimension, + int size, + IndexInput slice, + int byteSize, + FlatVectorsScorer flatVectorsScorer, + VectorSimilarityFunction similarityFunction + ) { + super(dimension, size, slice, byteSize, flatVectorsScorer, similarityFunction); + } + + @Override + public DenseOffHeapVectorValues copy() throws IOException { + return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction); + } + + @Override + public int ordToDoc(int ord) { + return ord; + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return acceptDocs; + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + DenseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); + RandomVectorScorer randomVectorScorer = flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query); + return new VectorScorer() { + @Override + public float score() throws IOException { + return randomVectorScorer.score(iterator.docID()); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + }; + } + } + + private static class SparseOffHeapVectorValues extends OffHeapBFloat16VectorValues { + private final DirectMonotonicReader ordToDoc; + private final IndexedDISI disi; + // dataIn is retained to create fresh IndexedDISI / ordToDoc readers in #copy() + private final IndexInput dataIn; + private final OrdToDocDISIReaderConfiguration configuration; + + SparseOffHeapVectorValues( + OrdToDocDISIReaderConfiguration configuration, + IndexInput dataIn, + IndexInput slice, + int dimension, + int size, + int byteSize, + FlatVectorsScorer flatVectorsScorer, + VectorSimilarityFunction similarityFunction + ) throws IOException { + + super(dimension, size, slice, byteSize, flatVectorsScorer, 
similarityFunction); + this.configuration = configuration; + this.dataIn = dataIn; + this.ordToDoc = configuration.getDirectMonotonicReader(dataIn); + this.disi = configuration.getIndexedDISI(dataIn); + } + + @Override + public SparseOffHeapVectorValues copy() throws IOException { + return new SparseOffHeapVectorValues( + configuration, + dataIn, + slice.clone(), + dimension, + size, + byteSize, + flatVectorsScorer, + similarityFunction + ); + } + + @Override + public int ordToDoc(int ord) { + return (int) ordToDoc.get(ord); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + if (acceptDocs == null) { + return null; + } + return new Bits() { + @Override + public boolean get(int index) { + return acceptDocs.get(ordToDoc(index)); + } + + @Override + public int length() { + return size; + } + }; + } + + @Override + public DocIndexIterator iterator() { + return IndexedDISI.asDocIndexIterator(disi); + } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + SparseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); + RandomVectorScorer randomVectorScorer = flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query); + return new VectorScorer() { + @Override + public float score() throws IOException { + return randomVectorScorer.score(iterator.index()); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + }; + } + } + + private static class EmptyOffHeapVectorValues extends OffHeapBFloat16VectorValues { + + EmptyOffHeapVectorValues(int dimension, FlatVectorsScorer flatVectorsScorer, VectorSimilarityFunction similarityFunction) { + super(dimension, 0, null, 0, flatVectorsScorer, similarityFunction); + } + + @Override + public int dimension() { + return super.dimension(); + } + + @Override + public int size() { + return 0; + } + + @Override + public EmptyOffHeapVectorValues copy() { + throw new UnsupportedOperationException(); + } + + @Override + public float[] 
vectorValue(int targetOrd) { + throw new UnsupportedOperationException(); + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return null; + } + + @Override + public VectorScorer scorer(float[] query) { + return null; + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java index 7ba2dfb9a69f5..be52ced30bff5 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java @@ -47,6 +47,7 @@ public class MapperFeatures implements FeatureSpecification { static final NodeFeature BBQ_DISK_SUPPORT = new NodeFeature("mapper.bbq_disk_support"); static final NodeFeature SEARCH_LOAD_PER_SHARD = new NodeFeature("mapper.search_load_per_shard"); static final NodeFeature PATTERNED_TEXT = new NodeFeature("mapper.patterned_text"); + public static final NodeFeature BBQ_BFLOAT16 = new NodeFeature("mapper.bbq_bfloat16"); @Override public Set getTestFeatures() { @@ -80,7 +81,8 @@ public Set getTestFeatures() { BBQ_DISK_SUPPORT, SEARCH_LOAD_PER_SHARD, SPARSE_VECTOR_INDEX_OPTIONS_FEATURE, - PATTERNED_TEXT + PATTERNED_TEXT, + BBQ_BFLOAT16 ); } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index c9c14d027ebfd..b63a023140e58 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -58,6 +58,8 @@ import org.elasticsearch.index.codec.vectors.IVFVectorsFormat; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; import 
org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es92.ES92BinaryQuantizedBFloat16VectorsFormat; +import org.elasticsearch.index.codec.vectors.es92.ES92HnswBinaryQuantizedBFloat16VectorsFormat; import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.mapper.ArraySourceValueFetcher; @@ -1120,6 +1122,182 @@ public void checkDimensions(Integer dvDims, int qvDims) { ); } } + }, + + BFLOAT16 { + + @Override + public String toString() { + return "bfloat16"; + } + + @Override + public void writeValue(ByteBuffer byteBuffer, float value) { + byteBuffer.putShort((short) (Float.floatToIntBits(value) >>> 16)); + } + + @Override + public void readAndWriteValue(ByteBuffer byteBuffer, XContentBuilder b) throws IOException { + b.value(Float.intBitsToFloat(byteBuffer.getShort() << 16)); + } + + private KnnFloatVectorField createKnnVectorField(String name, float[] vector, VectorSimilarityFunction function) { + if (vector == null) { + throw new IllegalArgumentException("vector value must not be null"); + } + FieldType denseVectorFieldType = new FieldType(); + denseVectorFieldType.setVectorAttributes(vector.length, VectorEncoding.FLOAT32, function); + denseVectorFieldType.freeze(); + return new KnnFloatVectorField(name, vector, denseVectorFieldType); + } + + @Override + IndexFieldData.Builder fielddataBuilder(DenseVectorFieldType denseVectorFieldType, FieldDataContext fieldDataContext) { + return new VectorIndexFieldData.Builder( + denseVectorFieldType.name(), + CoreValuesSourceType.KEYWORD, + denseVectorFieldType.indexVersionCreated, + this, + denseVectorFieldType.dims, + denseVectorFieldType.indexed, + denseVectorFieldType.indexVersionCreated.onOrAfter(NORMALIZE_COSINE) + && denseVectorFieldType.indexed + && denseVectorFieldType.similarity.equals(VectorSimilarity.COSINE) ? 
r -> new FilterLeafReader(r) { + @Override + public CacheHelper getCoreCacheHelper() { + return r.getCoreCacheHelper(); + } + + @Override + public CacheHelper getReaderCacheHelper() { + return r.getReaderCacheHelper(); + } + + @Override + public FloatVectorValues getFloatVectorValues(String fieldName) throws IOException { + FloatVectorValues values = in.getFloatVectorValues(fieldName); + if (values == null) { + return null; + } + return new DenormalizedCosineFloatVectorValues( + values, + in.getNumericDocValues(fieldName + COSINE_MAGNITUDE_FIELD_SUFFIX) + ); + } + } : r -> r + ); + } + + @Override + StringBuilder checkVectorErrors(float[] vector) { + return checkNanAndInfinite(vector); + } + + @Override + void checkVectorMagnitude( + VectorSimilarity similarity, + Function appender, + float squaredMagnitude + ) { + StringBuilder errorBuilder = null; + + if (Float.isNaN(squaredMagnitude) || Float.isInfinite(squaredMagnitude)) { + errorBuilder = new StringBuilder( + "NaN or Infinite magnitude detected, this usually means the vector values are too extreme to fit within a float." + ); + } + if (errorBuilder != null) { + throw new IllegalArgumentException(appender.apply(errorBuilder).toString()); + } + + if (similarity == VectorSimilarity.DOT_PRODUCT && isNotUnitVector(squaredMagnitude)) { + errorBuilder = new StringBuilder( + "The [" + VectorSimilarity.DOT_PRODUCT + "] similarity can only be used with unit-length vectors." + ); + } else if (similarity == VectorSimilarity.COSINE && Math.sqrt(squaredMagnitude) == 0.0f) { + errorBuilder = new StringBuilder( + "The [" + VectorSimilarity.COSINE + "] similarity does not support vectors with zero magnitude." 
+ ); + } + + if (errorBuilder != null) { + throw new IllegalArgumentException(appender.apply(errorBuilder).toString()); + } + } + + @Override + public double computeSquaredMagnitude(VectorData vectorData) { + return VectorUtil.dotProduct(vectorData.asFloatVector(), vectorData.asFloatVector()); + } + + @Override + public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException { + int index = 0; + float[] vector = new float[fieldMapper.fieldType().dims]; + float squaredMagnitude = 0; + for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) { + fieldMapper.checkDimensionExceeded(index, context); + ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser()); + + float value = context.parser().floatValue(true); + vector[index++] = value; + squaredMagnitude += value * value; + } + fieldMapper.checkDimensionMatches(index, context); + checkVectorBounds(vector); + checkVectorMagnitude(fieldMapper.fieldType().similarity, errorFloatElementsAppender(vector), squaredMagnitude); + if (fieldMapper.indexCreatedVersion.onOrAfter(NORMALIZE_COSINE) + && fieldMapper.fieldType().similarity.equals(VectorSimilarity.COSINE) + && isNotUnitVector(squaredMagnitude)) { + float length = (float) Math.sqrt(squaredMagnitude); + for (int i = 0; i < vector.length; i++) { + vector[i] /= length; + } + final String fieldName = fieldMapper.fieldType().name() + COSINE_MAGNITUDE_FIELD_SUFFIX; + Field magnitudeField = new FloatDocValuesField(fieldName, length); + context.doc().addWithKey(fieldName, magnitudeField); + } + Field field = createKnnVectorField( + fieldMapper.fieldType().name(), + vector, + fieldMapper.fieldType().similarity.vectorSimilarityFunction(fieldMapper.indexCreatedVersion, this) + ); + context.doc().addWithKey(fieldMapper.fieldType().name(), field); + } + + @Override + public VectorData parseKnnVector( + DocumentParserContext context, + int dims, + IntBooleanConsumer 
dimChecker, + VectorSimilarity similarity + ) throws IOException { + int index = 0; + float squaredMagnitude = 0; + float[] vector = new float[dims]; + for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) { + dimChecker.accept(index, false); + ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser()); + float value = context.parser().floatValue(true); + vector[index] = value; + squaredMagnitude += value * value; + index++; + } + dimChecker.accept(index, true); + checkVectorBounds(vector); + checkVectorMagnitude(similarity, errorFloatElementsAppender(vector), squaredMagnitude); + return VectorData.fromFloats(vector); + } + + @Override + public int getNumBytes(int dimensions) { + return dimensions * Short.BYTES; + } + + @Override + public ByteBuffer createByteBuffer(IndexVersion indexVersion, int numBytes) { + return ByteBuffer.wrap(new byte[numBytes]).order(ByteOrder.LITTLE_ENDIAN); + } }; public abstract void writeValue(ByteBuffer byteBuffer, float value); @@ -1285,7 +1463,9 @@ public static ElementType fromString(String name) { ElementType.FLOAT.toString(), ElementType.FLOAT, ElementType.BIT.toString(), - ElementType.BIT + ElementType.BIT, + ElementType.BFLOAT16.toString(), + ElementType.BFLOAT16 ); public enum VectorSimilarity { @@ -1293,7 +1473,7 @@ public enum VectorSimilarity { @Override float score(float similarity, ElementType elementType, int dim) { return switch (elementType) { - case BYTE, FLOAT -> 1f / (1f + similarity * similarity); + case BYTE, FLOAT, BFLOAT16 -> 1f / (1f + similarity * similarity); case BIT -> (dim - similarity) / dim; }; } @@ -1308,7 +1488,7 @@ public VectorSimilarityFunction vectorSimilarityFunction(IndexVersion indexVersi float score(float similarity, ElementType elementType, int dim) { assert elementType != ElementType.BIT; return switch (elementType) { - case BYTE, FLOAT -> (1 + similarity) / 2f; + case BYTE, FLOAT, BFLOAT16 -> (1 + similarity) / 2f; default -> 
throw new IllegalArgumentException("Unsupported element type [" + elementType + "]"); }; } @@ -1325,7 +1505,7 @@ public VectorSimilarityFunction vectorSimilarityFunction(IndexVersion indexVersi float score(float similarity, ElementType elementType, int dim) { return switch (elementType) { case BYTE -> 0.5f + similarity / (float) (dim * (1 << 15)); - case FLOAT -> (1 + similarity) / 2f; + case FLOAT, BFLOAT16 -> (1 + similarity) / 2f; default -> throw new IllegalArgumentException("Unsupported element type [" + elementType + "]"); }; } @@ -1339,7 +1519,7 @@ public VectorSimilarityFunction vectorSimilarityFunction(IndexVersion indexVersi @Override float score(float similarity, ElementType elementType, int dim) { return switch (elementType) { - case BYTE, FLOAT -> similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1; + case BYTE, FLOAT, BFLOAT16 -> similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1; default -> throw new IllegalArgumentException("Unsupported element type [" + elementType + "]"); }; } @@ -1643,7 +1823,7 @@ public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map new ES818HnswBinaryQuantizedVectorsFormat(m, efConstruction); + case BFLOAT16 -> new ES92HnswBinaryQuantizedBFloat16VectorsFormat(m, efConstruction); + default -> throw new AssertionError(); + }; } @Override @@ -2249,8 +2432,11 @@ static class BBQFlatIndexOptions extends QuantizedIndexOptions { @Override KnnVectorsFormat getVectorsFormat(ElementType elementType) { - assert elementType == ElementType.FLOAT; - return new ES818BinaryQuantizedVectorsFormat(); + return switch (elementType) { + case FLOAT -> new ES818BinaryQuantizedVectorsFormat(); + case BFLOAT16 -> new ES92BinaryQuantizedBFloat16VectorsFormat(); + default -> throw new AssertionError(); + }; } @Override @@ -2482,7 +2668,7 @@ public Query createExactKnnQuery(VectorData queryVector, Float vectorSimilarity) } Query knnQuery = switch (elementType) { case BYTE -> 
createExactKnnByteQuery(queryVector.asByteVector()); - case FLOAT -> createExactKnnFloatQuery(queryVector.asFloatVector()); + case FLOAT, BFLOAT16 -> createExactKnnFloatQuery(queryVector.asFloatVector()); case BIT -> createExactKnnBitQuery(queryVector.asByteVector()); }; if (vectorSimilarity != null) { @@ -2555,7 +2741,7 @@ public Query createKnnQuery( knnSearchStrategy, hnswEarlyTermination ); - case FLOAT -> createKnnFloatQuery( + case FLOAT, BFLOAT16 -> createKnnFloatQuery( queryVector.asFloatVector(), k, numCands, diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java index e44202d353629..fc99438c6d076 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java @@ -69,14 +69,14 @@ public DocValuesScriptFieldFactory getScriptFieldFactory(String name) { if (indexed) { return switch (elementType) { case BYTE -> new ByteKnnDenseVectorDocValuesField(reader.getByteVectorValues(field), name, dims); - case FLOAT -> new KnnDenseVectorDocValuesField(reader.getFloatVectorValues(field), name, dims); + case FLOAT, BFLOAT16 -> new KnnDenseVectorDocValuesField(reader.getFloatVectorValues(field), name, dims); case BIT -> new BitKnnDenseVectorDocValuesField(reader.getByteVectorValues(field), name, dims); }; } else { BinaryDocValues values = DocValues.getBinary(reader, field); return switch (elementType) { case BYTE -> new ByteBinaryDenseVectorDocValuesField(values, name, elementType, dims); - case FLOAT -> new BinaryDenseVectorDocValuesField(values, name, elementType, dims, indexVersion); + case FLOAT, BFLOAT16 -> new BinaryDenseVectorDocValuesField(values, name, elementType, dims, indexVersion); case BIT -> new BitBinaryDenseVectorDocValuesField(values, name, elementType, dims); }; } @@ -138,7 +138,7 @@ public 
Object nextValue() { return vectorValue; } }; - case FLOAT -> new FormattedDocValues() { + case FLOAT, BFLOAT16 -> new FormattedDocValues() { float[] vector = new float[dims]; private FloatVectorValues floatVectorValues; // use when indexed private KnnVectorValues.DocIndexIterator iterator; // use when indexed diff --git a/server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java b/server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java index d319c8d8b40aa..878cb42bdcce0 100644 --- a/server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java +++ b/server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java @@ -209,7 +209,7 @@ public L1Norm(ScoreScript scoreScript, Object queryVector, String fieldName) { } throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName()); } - case FLOAT -> { + case FLOAT, BFLOAT16 -> { if (queryVector instanceof List) { yield new FloatL1Norm(scoreScript, field, (List) queryVector); } @@ -319,7 +319,7 @@ public L2Norm(ScoreScript scoreScript, Object queryVector, String fieldName) { } throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName()); } - case FLOAT -> { + case FLOAT, BFLOAT16 -> { if (queryVector instanceof List) { yield new FloatL2Norm(scoreScript, field, (List) queryVector); } @@ -477,7 +477,7 @@ public DotProduct(ScoreScript scoreScript, Object queryVector, String fieldName) } throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName()); } - case FLOAT -> { + case FLOAT, BFLOAT16 -> { if (queryVector instanceof List) { yield new FloatDotProduct(scoreScript, field, (List) queryVector); } @@ -546,7 +546,7 @@ public CosineSimilarity(ScoreScript scoreScript, Object queryVector, String fiel } throw new IllegalArgumentException("Unsupported input object for byte vectors: " + 
queryVector.getClass().getName()); } - case FLOAT -> { + case FLOAT, BFLOAT16 -> { if (queryVector instanceof List) { yield new FloatCosineSimilarity(scoreScript, field, (List) queryVector); } diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16RankVectorsDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16RankVectorsDocValuesField.java new file mode 100644 index 0000000000000..21326faf1b6ab --- /dev/null +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/BFloat16RankVectorsDocValuesField.java @@ -0,0 +1,158 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.script.field.vectors; + +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; +import org.elasticsearch.index.mapper.vectors.RankVectorsScriptDocValues; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.ShortBuffer; +import java.util.Iterator; + +public class BFloat16RankVectorsDocValuesField extends RankVectorsDocValuesField { + + private final BinaryDocValues input; + private final BinaryDocValues magnitudes; + private boolean decoded; + private final int dims; + private BytesRef value; + private BytesRef magnitudesValue; + private BFloat16VectorIterator vectorValues; + private int numVectors; + private float[] buffer; + + public BFloat16RankVectorsDocValuesField( + BinaryDocValues input, + BinaryDocValues magnitudes, + String name, + ElementType elementType, + int dims + ) { + super(name, elementType); + this.input = input; + this.magnitudes = magnitudes; + this.dims = dims; + this.buffer = new float[dims]; + } + + @Override + public void setNextDocId(int docId) throws IOException { + decoded = false; + if (input.advanceExact(docId)) { + boolean magnitudesFound = magnitudes.advanceExact(docId); + assert magnitudesFound; + + value = input.binaryValue(); + assert value.length % (Short.BYTES * dims) == 0; + numVectors = value.length / (Short.BYTES * dims); + magnitudesValue = magnitudes.binaryValue(); + assert magnitudesValue.length == (Float.BYTES * numVectors); + } else { + value = null; + magnitudesValue = null; + numVectors = 0; + } + } + + @Override + public RankVectorsScriptDocValues toScriptDocValues() { + return new RankVectorsScriptDocValues(this, dims); + } + + @Override + public boolean isEmpty() { + return value == null; + } + + @Override + public RankVectors get() { + if (isEmpty()) { + return RankVectors.EMPTY; + } + decodeVectorIfNecessary(); + 
return new FloatRankVectors(vectorValues, magnitudesValue, numVectors, dims); + } + + @Override + public RankVectors get(RankVectors defaultValue) { + if (isEmpty()) { + return defaultValue; + } + decodeVectorIfNecessary(); + return new FloatRankVectors(vectorValues, magnitudesValue, numVectors, dims); + } + + @Override + public RankVectors getInternal() { + return get(null); + } + + @Override + public int size() { + return value == null ? 0 : value.length / (Short.BYTES * dims); + } + + private void decodeVectorIfNecessary() { + if (decoded == false && value != null) { + vectorValues = new BFloat16VectorIterator(value, buffer, numVectors); + decoded = true; + } + } + + public static class BFloat16VectorIterator implements VectorIterator { + private final float[] buffer; + private final ShortBuffer vectorValues; + private final BytesRef vectorValueBytesRef; + private final int size; + private int idx = 0; + + public BFloat16VectorIterator(BytesRef vectorValues, float[] buffer, int size) { + assert vectorValues.length == (buffer.length * Short.BYTES * size); + this.vectorValueBytesRef = vectorValues; + this.vectorValues = ByteBuffer.wrap(vectorValues.bytes, vectorValues.offset, vectorValues.length) + .order(ByteOrder.LITTLE_ENDIAN) + .asShortBuffer(); + this.size = size; + this.buffer = buffer; + } + + @Override + public boolean hasNext() { + return idx < size; + } + + @Override + public float[] next() { + if (hasNext() == false) { + throw new IllegalArgumentException("No more elements in the iterator"); + } + for (int i = 0; i < buffer.length; i++) { + buffer[i] = Float.intBitsToFloat(vectorValues.get() << 16); + } + idx++; + return buffer; + } + + @Override + public Iterator copy() { + return new BFloat16VectorIterator(vectorValueBytesRef, new float[buffer.length], size); + } + + @Override + public void reset() { + idx = 0; + vectorValues.rewind(); + } + } +} diff --git a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat 
b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index 14e68029abc3b..8ca8b2e3c3a7f 100644 --- a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -7,4 +7,6 @@ org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat +org.elasticsearch.index.codec.vectors.es92.ES92BinaryQuantizedBFloat16VectorsFormat +org.elasticsearch.index.codec.vectors.es92.ES92HnswBinaryQuantizedBFloat16VectorsFormat org.elasticsearch.index.codec.vectors.IVFVectorsFormat diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormatTests.java new file mode 100644 index 0000000000000..fa3aa9c25f72b --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92BinaryQuantizedBFloat16VectorsFormatTests.java @@ -0,0 +1,307 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es92; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.CodecReader; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.SoftDeletesRetentionMergePolicy; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.misc.store.DirectIODirectory; +import org.apache.lucene.search.FieldExistsQuery; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.CheckJoinIndex; +import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; 
+import org.apache.lucene.search.join.QueryBitSetProducer; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.FSDirectory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.store.MMapDirectory; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.tests.store.MockDirectoryWrapper; +import org.apache.lucene.tests.util.TestUtil; +import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.IndexModule; +import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.index.shard.ShardPath; +import org.elasticsearch.index.store.FsDirectoryFactory; +import org.elasticsearch.test.IndexSettingsModule; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.OptionalLong; + +import static java.lang.String.format; +import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; + +public class ES92BinaryQuantizedBFloat16VectorsFormatTests extends BaseKnnVectorsFormatTestCase { + + static { + LogConfigurator.loadLog4jPlugins(); + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + + static final Codec codec = TestUtil.alwaysKnnVectorsFormat(new ES92BinaryQuantizedBFloat16VectorsFormat()); + + @Override + protected Codec getCodec() { + return codec; + } + + static String encodeInts(int[] i) { + return Arrays.toString(i); + } + + static BitSetProducer parentFilter(IndexReader r) throws IOException { + // Create a filter that defines "parent" 
documents in the index + BitSetProducer parentsFilter = new QueryBitSetProducer(new TermQuery(new Term("docType", "_parent"))); + CheckJoinIndex.check(r, parentsFilter); + return parentsFilter; + } + + Document makeParent(int[] children) { + Document parent = new Document(); + parent.add(newStringField("docType", "_parent", Field.Store.NO)); + parent.add(newStringField("id", encodeInts(children), Field.Store.YES)); + return parent; + } + + public void testEmptyDiversifiedChildSearch() throws Exception { + String fieldName = "field"; + int dims = random().nextInt(4, 65); + float[] vector = randomVector(dims); + VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.EUCLIDEAN; + try (Directory d = newDirectory()) { + IndexWriterConfig iwc = newIndexWriterConfig().setCodec(codec); + iwc.setMergePolicy(new SoftDeletesRetentionMergePolicy("soft_delete", MatchAllDocsQuery::new, iwc.getMergePolicy())); + try (IndexWriter w = new IndexWriter(d, iwc)) { + List toAdd = new ArrayList<>(); + for (int j = 1; j <= 5; j++) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField(fieldName, vector, similarityFunction)); + doc.add(newStringField("id", Integer.toString(j), Field.Store.YES)); + toAdd.add(doc); + } + toAdd.add(makeParent(new int[] { 1, 2, 3, 4, 5 })); + w.addDocuments(toAdd); + w.addDocuments(List.of(makeParent(new int[] { 6, 7, 8, 9, 10 }))); + w.deleteDocuments(new FieldExistsQuery(fieldName), new TermQuery(new Term("id", encodeInts(new int[] { 1, 2, 3, 4, 5 })))); + w.flush(); + w.commit(); + w.forceMerge(1); + try (IndexReader reader = DirectoryReader.open(w)) { + IndexSearcher searcher = new IndexSearcher(reader); + BitSetProducer parentFilter = parentFilter(searcher.getIndexReader()); + Query query = new DiversifyingChildrenFloatKnnVectorQuery(fieldName, vector, null, 1, parentFilter); + assertTrue(searcher.search(query, 1).scoreDocs.length == 0); + } + } + + } + } + + public void testSearch() throws Exception { + String fieldName = 
"field"; + int numVectors = random().nextInt(99, 500); + int dims = random().nextInt(4, 65); + float[] vector = randomVector(dims); + VectorSimilarityFunction similarityFunction = randomSimilarity(); + KnnFloatVectorField knnField = new KnnFloatVectorField(fieldName, vector, similarityFunction); + IndexWriterConfig iwc = newIndexWriterConfig(); + try (Directory dir = newDirectory()) { + try (IndexWriter w = new IndexWriter(dir, iwc)) { + for (int i = 0; i < numVectors; i++) { + Document doc = new Document(); + knnField.setVectorValue(randomVector(dims)); + doc.add(knnField); + w.addDocument(doc); + } + w.commit(); + + try (IndexReader reader = DirectoryReader.open(w)) { + IndexSearcher searcher = new IndexSearcher(reader); + final int k = random().nextInt(5, 50); + float[] queryVector = randomVector(dims); + Query q = new KnnFloatVectorQuery(fieldName, queryVector, k); + TopDocs collectedDocs = searcher.search(q, k); + assertEquals(k, collectedDocs.totalHits.value()); + assertEquals(TotalHits.Relation.EQUAL_TO, collectedDocs.totalHits.relation()); + } + } + } + } + + public void testToString() { + FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new ES92BinaryQuantizedBFloat16VectorsFormat(); + } + }; + String expectedPattern = "ES92BinaryQuantizedBFloat16VectorsFormat(" + + "name=ES92BinaryQuantizedBFloat16VectorsFormat, " + + "flatVectorScorer=ES818BinaryFlatVectorsScorer(nonQuantizedDelegate=%s()))"; + var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); + var memSegScorer = format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); + assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); + } + + @Override + public void testRandomWithUpdatesAndGraph() { + // graph not supported + } + + @Override + public void testSearchWithVisitedLimit() { + // visited limit is not respected, as it 
is brute force search + } + + // bfloat16 makes the results of these tests slightly out of bounds + @Override + public void testWriterRamEstimate() throws Exception {} + + @Override + public void testRandom() throws Exception {} + + @Override + public void testVectorValuesReportCorrectDocs() throws Exception {} + + @Override + public void testSparseVectors() throws Exception {} + + public void testSimpleOffHeapSize() throws IOException { + try (Directory dir = newDirectory()) { + testSimpleOffHeapSizeImpl(dir, newIndexWriterConfig(), true); + } + } + + public void testSimpleOffHeapSizeFSDir() throws IOException { + checkDirectIOSupported(); + var config = newIndexWriterConfig().setUseCompoundFile(false); // avoid compound files to allow directIO + try (Directory dir = newFSDirectory()) { + testSimpleOffHeapSizeImpl(dir, config, false); + } + } + + public void testSimpleOffHeapSizeMMapDir() throws IOException { + try (Directory dir = newMMapDirectory()) { + testSimpleOffHeapSizeImpl(dir, newIndexWriterConfig(), true); + } + } + + public void testSimpleOffHeapSizeImpl(Directory dir, IndexWriterConfig config, boolean expectVecOffHeap) throws IOException { + float[] vector = randomVector(random().nextInt(12, 500)); + try (IndexWriter w = new IndexWriter(dir, config)) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField("f", vector, DOT_PRODUCT)); + w.addDocument(doc); + w.commit(); + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + if (r instanceof CodecReader codecReader) { + KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); + if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { + knnVectorsReader = fieldsReader.getFieldReader("f"); + } + var fieldInfo = r.getFieldInfos().fieldInfo("f"); + var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); + assertEquals(expectVecOffHeap ? 
2 : 1, offHeap.size()); + assertTrue(offHeap.get("veb") > 0L); + if (expectVecOffHeap) { + assertEquals(vector.length * BFloat16.BYTES, (long) offHeap.get("vec")); + } + } + } + } + } + + static Directory newMMapDirectory() throws IOException { + Directory dir = new MMapDirectory(createTempDir("ES92BinaryQuantizedBFloat16VectorsFormatTests")); + if (random().nextBoolean()) { + dir = new MockDirectoryWrapper(random(), dir); + } + return dir; + } + + private Directory newFSDirectory() throws IOException { + Settings settings = Settings.builder() + .put(IndexModule.INDEX_STORE_TYPE_SETTING.getKey(), IndexModule.Type.HYBRIDFS.name().toLowerCase(Locale.ROOT)) + .build(); + IndexSettings idxSettings = IndexSettingsModule.newIndexSettings("foo", settings); + Path tempDir = createTempDir().resolve(idxSettings.getUUID()).resolve("0"); + Files.createDirectories(tempDir); + ShardPath path = new ShardPath(false, tempDir, tempDir, new ShardId(idxSettings.getIndex(), 0)); + Directory dir = (new FsDirectoryFactory()).newDirectory(idxSettings, path); + if (random().nextBoolean()) { + dir = new MockDirectoryWrapper(random(), dir); + } + return dir; + } + + static void checkDirectIOSupported() { + assumeTrue("Direct IO is not enabled", ES818BinaryQuantizedVectorsFormat.USE_DIRECT_IO); + + Path path = createTempDir("directIOProbe"); + try (Directory dir = open(path); IndexOutput out = dir.createOutput("out", IOContext.DEFAULT)) { + out.writeString("test"); + } catch (IOException e) { + assumeNoException("test requires a filesystem that supports Direct IO", e); + } + } + + static DirectIODirectory open(Path path) throws IOException { + return new DirectIODirectory(FSDirectory.open(path)) { + @Override + protected boolean useDirectIO(String name, IOContext context, OptionalLong fileLength) { + return true; + } + }; + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormatTests.java 
b/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormatTests.java new file mode 100644 index 0000000000000..73d732e5fcc3e --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es92/ES92HnswBinaryQuantizedBFloat16VectorsFormatTests.java @@ -0,0 +1,259 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es92; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.CodecReader; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.misc.store.DirectIODirectory; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.FSDirectory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.store.MMapDirectory; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.tests.store.MockDirectoryWrapper; +import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.SameThreadExecutorService; +import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.IndexModule; +import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.index.shard.ShardPath; +import 
org.elasticsearch.index.store.FsDirectoryFactory; +import org.elasticsearch.test.IndexSettingsModule; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.Locale; +import java.util.OptionalLong; + +import static java.lang.String.format; +import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; + +public class ES92HnswBinaryQuantizedBFloat16VectorsFormatTests extends BaseKnnVectorsFormatTestCase { + + static { + LogConfigurator.loadLog4jPlugins(); + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + + static final Codec codec = TestUtil.alwaysKnnVectorsFormat(new ES92HnswBinaryQuantizedBFloat16VectorsFormat()); + + @Override + protected Codec getCodec() { + return codec; + } + + public void testToString() { + FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new ES92HnswBinaryQuantizedBFloat16VectorsFormat(10, 20, 1, null); + } + }; + String expectedPattern = + "ES92HnswBinaryQuantizedBFloat16VectorsFormat(name=ES92HnswBinaryQuantizedBFloat16VectorsFormat, maxConn=10, beamWidth=20," + + " flatVectorFormat=ES92BinaryQuantizedBFloat16VectorsFormat(name=ES92BinaryQuantizedBFloat16VectorsFormat," + + " flatVectorScorer=ES818BinaryFlatVectorsScorer(nonQuantizedDelegate=%s())))"; + + var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); + var memSegScorer = format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); + assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); + } + + public void testSingleVectorCase() throws Exception { + float[] vector = randomVector(random().nextInt(12, 500)); + for 
(VectorSimilarityFunction similarityFunction : VectorSimilarityFunction.values()) { + try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField("f", vector, similarityFunction)); + w.addDocument(doc); + w.commit(); + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + FloatVectorValues vectorValues = r.getFloatVectorValues("f"); + KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator(); + assert (vectorValues.size() == 1); + while (docIndexIterator.nextDoc() != NO_MORE_DOCS) { + assertArrayEquals(vector, vectorValues.vectorValue(docIndexIterator.index()), 0.01f); + } + float[] randomVector = randomVector(vector.length); + float trueScore = similarityFunction.compare(vector, randomVector); + TopDocs td = r.searchNearestVectors("f", randomVector, 1, null, Integer.MAX_VALUE); + assertEquals(1, td.totalHits.value()); + assertTrue(td.scoreDocs[0].score >= 0); + // When it's the only vector in a segment, the score should be very close to the true score + assertEquals(trueScore, td.scoreDocs[0].score, 0.01f); + } + } + } + } + + public void testLimits() { + expectThrows(IllegalArgumentException.class, () -> new ES92HnswBinaryQuantizedBFloat16VectorsFormat(-1, 20)); + expectThrows(IllegalArgumentException.class, () -> new ES92HnswBinaryQuantizedBFloat16VectorsFormat(0, 20)); + expectThrows(IllegalArgumentException.class, () -> new ES92HnswBinaryQuantizedBFloat16VectorsFormat(20, 0)); + expectThrows(IllegalArgumentException.class, () -> new ES92HnswBinaryQuantizedBFloat16VectorsFormat(20, -1)); + expectThrows(IllegalArgumentException.class, () -> new ES92HnswBinaryQuantizedBFloat16VectorsFormat(512 + 1, 20)); + expectThrows(IllegalArgumentException.class, () -> new ES92HnswBinaryQuantizedBFloat16VectorsFormat(20, 3201)); + expectThrows( + IllegalArgumentException.class, + () -> new 
ES818HnswBinaryQuantizedVectorsFormat(20, 100, 1, new SameThreadExecutorService()) + ); + } + + // Ensures that all expected vector similarity functions are translatable in the format. + public void testVectorSimilarityFuncs() { + // This does not necessarily have to be all similarity functions, but + // differences should be considered carefully. + var expectedValues = Arrays.stream(VectorSimilarityFunction.values()).toList(); + assertEquals(Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS, expectedValues); + } + + // bfloat16 makes the results of these tests slightly out of bounds + @Override + public void testWriterRamEstimate() throws Exception {} + + @Override + public void testRandom() throws Exception {} + + @Override + public void testRandomWithUpdatesAndGraph() throws Exception {} + + @Override + public void testVectorValuesReportCorrectDocs() throws Exception {} + + @Override + public void testSparseVectors() throws Exception {} + + public void testSimpleOffHeapSize() throws IOException { + try (Directory dir = newDirectory()) { + testSimpleOffHeapSizeImpl(dir, newIndexWriterConfig(), true); + } + } + + public void testSimpleOffHeapSizeFSDir() throws IOException { + checkDirectIOSupported(); + var config = newIndexWriterConfig().setUseCompoundFile(false); // avoid compound files to allow directIO + try (Directory dir = newFSDirectory()) { + testSimpleOffHeapSizeImpl(dir, config, false); + } + } + + public void testSimpleOffHeapSizeMMapDir() throws IOException { + try (Directory dir = newMMapDirectory()) { + testSimpleOffHeapSizeImpl(dir, newIndexWriterConfig(), true); + } + } + + public void testSimpleOffHeapSizeImpl(Directory dir, IndexWriterConfig config, boolean expectVecOffHeap) throws IOException { + float[] vector = randomVector(random().nextInt(12, 500)); + try (IndexWriter w = new IndexWriter(dir, config)) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField("f", vector, DOT_PRODUCT)); + w.addDocument(doc); + w.commit(); + try 
(IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + if (r instanceof CodecReader codecReader) { + KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); + if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { + knnVectorsReader = fieldsReader.getFieldReader("f"); + } + var fieldInfo = r.getFieldInfos().fieldInfo("f"); + var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); + assertEquals(expectVecOffHeap ? 3 : 2, offHeap.size()); + assertEquals(1L, (long) offHeap.get("vex")); + assertTrue(offHeap.get("veb") > 0L); + if (expectVecOffHeap) { + assertEquals(vector.length * BFloat16.BYTES, (long) offHeap.get("vec")); + } + } + } + } + } + + static Directory newMMapDirectory() throws IOException { + Directory dir = new MMapDirectory(createTempDir("ES92HnswBinaryQuantizedBFloat16VectorsFormatTests")); + if (random().nextBoolean()) { + dir = new MockDirectoryWrapper(random(), dir); + } + return dir; + } + + private Directory newFSDirectory() throws IOException { + Settings settings = Settings.builder() + .put(IndexModule.INDEX_STORE_TYPE_SETTING.getKey(), IndexModule.Type.HYBRIDFS.name().toLowerCase(Locale.ROOT)) + .build(); + IndexSettings idxSettings = IndexSettingsModule.newIndexSettings("foo", settings); + Path tempDir = createTempDir().resolve(idxSettings.getUUID()).resolve("0"); + Files.createDirectories(tempDir); + ShardPath path = new ShardPath(false, tempDir, tempDir, new ShardId(idxSettings.getIndex(), 0)); + Directory dir = (new FsDirectoryFactory()).newDirectory(idxSettings, path); + if (random().nextBoolean()) { + dir = new MockDirectoryWrapper(random(), dir); + } + return dir; + } + + static void checkDirectIOSupported() { + assumeTrue("Direct IO is not enabled", ES818BinaryQuantizedVectorsFormat.USE_DIRECT_IO); + + Path path = createTempDir("directIOProbe"); + try (Directory dir = open(path); IndexOutput out = dir.createOutput("out", IOContext.DEFAULT)) { + 
out.writeString("test"); + } catch (IOException e) { + assumeNoException("test requires a filesystem that supports Direct IO", e); + } + } + + static DirectIODirectory open(Path path) throws IOException { + return new DirectIODirectory(FSDirectory.open(path)) { + @Override + protected boolean useDirectIO(String name, IOContext context, OptionalLong fileLength) { + return true; + } + }; + } +} diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTestUtils.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTestUtils.java index 9478508da88d0..c3e3ed0c93c98 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTestUtils.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTestUtils.java @@ -22,14 +22,14 @@ private DenseVectorFieldMapperTestUtils() {} public static List getSupportedSimilarities(DenseVectorFieldMapper.ElementType elementType) { return switch (elementType) { - case FLOAT, BYTE -> List.of(SimilarityMeasure.values()); + case FLOAT, BFLOAT16, BYTE -> List.of(SimilarityMeasure.values()); case BIT -> List.of(SimilarityMeasure.L2_NORM); }; } public static int getEmbeddingLength(DenseVectorFieldMapper.ElementType elementType, int dimensions) { return switch (elementType) { - case FLOAT, BYTE -> dimensions; + case FLOAT, BFLOAT16, BYTE -> dimensions; case BIT -> { assert dimensions % Byte.SIZE == 0; yield dimensions / Byte.SIZE; @@ -43,7 +43,7 @@ public static int randomCompatibleDimensions(DenseVectorFieldMapper.ElementType } return switch (elementType) { - case FLOAT, BYTE -> RandomNumbers.randomIntBetween(random(), 1, max); + case FLOAT, BFLOAT16, BYTE -> RandomNumbers.randomIntBetween(random(), 1, max); case BIT -> { if (max < 8) { throw new IllegalArgumentException("max must be at least 8 for bit vectors"); diff --git 
a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index 02ef40eeda0ca..8e50c56c64725 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -2479,7 +2479,7 @@ protected Object generateRandomInputValue(MappedFieldType ft) { DenseVectorFieldType vectorFieldType = (DenseVectorFieldType) ft; return switch (vectorFieldType.getElementType()) { case BYTE -> randomByteArrayOfLength(vectorFieldType.getVectorDimensions()); - case FLOAT -> { + case FLOAT, BFLOAT16 -> { float[] floats = new float[vectorFieldType.getVectorDimensions()]; float magnitude = 0; for (int i = 0; i < floats.length; i++) { @@ -3080,7 +3080,7 @@ public SyntheticSourceExample example(int maxValues) throws IOException { Object value = switch (elementType) { case BYTE, BIT: yield randomList(dims, dims, ESTestCase::randomByte); - case FLOAT: + case FLOAT, BFLOAT16: yield randomList(dims, dims, ESTestCase::randomFloat); }; return new SyntheticSourceExample(value, value, this::mapping); diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index a8d9b1259cb41..cf85e8cea49b5 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -243,7 +243,7 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que filterQuery, expectedStrategy ); - case FLOAT -> new ESKnnFloatVectorQuery( + case FLOAT, BFLOAT16 -> new ESKnnFloatVectorQuery( VECTOR_FIELD, queryBuilder.queryVector().asFloatVector(), k, 
@@ -264,7 +264,7 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que yield new DenseVectorQuery.Bytes(queryBuilder.queryVector().asByteVector(), VECTOR_FIELD); } } - case FLOAT -> { + case FLOAT, BFLOAT16 -> { if (filterQuery != null) { yield new BooleanQuery.Builder().add( new DenseVectorQuery.Floats(queryBuilder.queryVector().asFloatVector(), VECTOR_FIELD), diff --git a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java index 8e62e18cf02c1..06b05f6fa0f11 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/RescoreKnnVectorQueryTests.java @@ -37,6 +37,8 @@ import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es92.ES92BinaryQuantizedBFloat16VectorsFormat; +import org.elasticsearch.index.codec.vectors.es92.ES92HnswBinaryQuantizedBFloat16VectorsFormat; import org.elasticsearch.index.codec.zstd.Zstd814StoredFieldsFormat; import org.elasticsearch.index.mapper.vectors.VectorSimilarityFloatValueSource; import org.elasticsearch.search.profile.query.QueryProfiler; @@ -215,6 +217,8 @@ private static void addRandomDocuments(int numDocs, Directory d, int numDims) th IndexWriterConfig iwc = new IndexWriterConfig(); // Pick codec from quantized vector formats to ensure scores use real scores when using knn rescore KnnVectorsFormat format = randomFrom( + new ES92BinaryQuantizedBFloat16VectorsFormat(), + new ES92HnswBinaryQuantizedBFloat16VectorsFormat(), new ES818BinaryQuantizedVectorsFormat(), new ES818HnswBinaryQuantizedVectorsFormat(), new ES813Int8FlatVectorFormat(), 
diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index 044af0ab1d37d..ca35823c452b8 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -255,7 +255,7 @@ private static List generateEmbedding(String input, int dimensions, Dense // Copied from DenseVectorFieldMapperTestUtils due to dependency restrictions private static int getEmbeddingLength(DenseVectorFieldMapper.ElementType elementType, int dimensions) { return switch (elementType) { - case FLOAT, BYTE -> dimensions; + case FLOAT, BYTE, BFLOAT16 -> dimensions; case BIT -> { assert dimensions % Byte.SIZE == 0; yield dimensions / Byte.SIZE; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java index 6e1407beab1d8..b59854528758e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java @@ -96,7 +96,7 @@ public TextEmbeddingResults responseBody(SageMakerModel model, InvokeEndpoint return switch (model.apiServiceSettings().elementType()) { case BIT -> TextEmbeddingBinary.PARSER.apply(p, null); case BYTE -> TextEmbeddingBytes.PARSER.apply(p, null); - case FLOAT -> 
TextEmbeddingFloat.PARSER.apply(p, null); + case FLOAT, BFLOAT16 -> TextEmbeddingFloat.PARSER.apply(p, null); }; } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java index 175c3e90f798d..9bc1736a85c7b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java @@ -269,7 +269,7 @@ private static SemanticTextField randomSemanticText( ) throws IOException { ChunkedInference results = switch (model.getTaskType()) { case TEXT_EMBEDDING -> switch (model.getServiceSettings().elementType()) { - case FLOAT -> randomChunkedInferenceEmbeddingFloat(model, inputs); + case FLOAT, BFLOAT16 -> randomChunkedInferenceEmbeddingFloat(model, inputs); case BYTE, BIT -> randomChunkedInferenceEmbeddingByte(model, inputs); }; case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs, false); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index d1499f4009d0a..fb6b5b85c9414 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -190,7 +190,7 @@ public static ChunkedInferenceEmbedding randomChunkedInferenceEmbedding(Model mo return switch (model.getTaskType()) { case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs); case TEXT_EMBEDDING -> switch 
(model.getServiceSettings().elementType()) { - case FLOAT -> randomChunkedInferenceEmbeddingFloat(model, inputs); + case FLOAT, BFLOAT16 -> randomChunkedInferenceEmbeddingFloat(model, inputs); case BIT, BYTE -> randomChunkedInferenceEmbeddingByte(model, inputs); }; default -> throw new AssertionError("invalid task type: " + model.getTaskType().name()); @@ -222,7 +222,7 @@ public static ChunkedInferenceEmbedding randomChunkedInferenceEmbeddingByte(Mode public static ChunkedInferenceEmbedding randomChunkedInferenceEmbeddingFloat(Model model, List inputs) { DenseVectorFieldMapper.ElementType elementType = model.getServiceSettings().elementType(); int embeddingLength = DenseVectorFieldMapperTestUtils.getEmbeddingLength(elementType, model.getServiceSettings().dimensions()); - assert elementType == DenseVectorFieldMapper.ElementType.FLOAT; + assert elementType == DenseVectorFieldMapper.ElementType.FLOAT || elementType == DenseVectorFieldMapper.ElementType.BFLOAT16; List chunks = new ArrayList<>(); for (String input : inputs) { @@ -272,7 +272,7 @@ public static SemanticTextField randomSemanticText( ) throws IOException { ChunkedInference results = switch (model.getTaskType()) { case TEXT_EMBEDDING -> switch (model.getServiceSettings().elementType()) { - case FLOAT -> randomChunkedInferenceEmbeddingFloat(model, inputs); + case FLOAT, BFLOAT16 -> randomChunkedInferenceEmbeddingFloat(model, inputs); case BIT, BYTE -> randomChunkedInferenceEmbeddingByte(model, inputs); }; case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs); @@ -415,7 +415,7 @@ public static ChunkedInference toChunkedResult( ChunkedInference.TextOffset offset = createOffset(useLegacyFormat, entryChunk, matchedText); double[] values = parseDenseVector(entryChunk.rawEmbeddings(), embeddingLength, field.contentType()); EmbeddingResults.Embedding embedding = switch (elementType) { - case FLOAT -> new TextEmbeddingFloatResults.Embedding(FloatConversionUtils.floatArrayOf(values)); + case FLOAT, 
BFLOAT16 -> new TextEmbeddingFloatResults.Embedding(FloatConversionUtils.floatArrayOf(values)); case BYTE, BIT -> new TextEmbeddingByteResults.Embedding(byteArrayOf(values)); }; chunks.add(new EmbeddingResults.Chunk(embedding, offset)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java index a7eda7112723e..6d54a4dd9eaca 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java @@ -267,7 +267,7 @@ private void assertTextEmbeddingLuceneQuery(Query query) { Query innerQuery = assertOuterBooleanQuery(query); Class expectedKnnQueryClass = switch (denseVectorElementType) { - case FLOAT -> KnnFloatVectorQuery.class; + case FLOAT, BFLOAT16 -> KnnFloatVectorQuery.class; case BYTE, BIT -> KnnByteVectorQuery.class; }; assertThat(innerQuery, instanceOf(expectedKnnQueryClass)); diff --git a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsDVLeafFieldData.java b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsDVLeafFieldData.java index b858b935c1483..60d6bfc5586ee 100644 --- a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsDVLeafFieldData.java +++ b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsDVLeafFieldData.java @@ -16,6 +16,7 @@ import org.elasticsearch.index.fielddata.SortedBinaryDocValues; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.script.field.DocValuesScriptFieldFactory; +import org.elasticsearch.script.field.vectors.BFloat16RankVectorsDocValuesField; import 
org.elasticsearch.script.field.vectors.BitRankVectorsDocValuesField; import org.elasticsearch.script.field.vectors.ByteRankVectorsDocValuesField; import org.elasticsearch.script.field.vectors.FloatRankVectorsDocValuesField; @@ -24,7 +25,6 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; final class RankVectorsDVLeafFieldData implements LeafFieldData { @@ -123,7 +123,47 @@ public Object nextValue() { VectorIterator iterator = new FloatRankVectorsDocValuesField.FloatVectorIterator(ref, vector, numVecs); while (iterator.hasNext()) { float[] v = iterator.next(); - vectors.add(Arrays.copyOf(v, v.length)); + vectors.add(v.clone()); + } + return vectors; + } + }; + case BFLOAT16 -> new FormattedDocValues() { + private final float[] vector = new float[dims]; + private BytesRef ref = null; + private int numVecs = -1; + private final BinaryDocValues binary; + { + try { + binary = DocValues.getBinary(reader, field); + } catch (IOException e) { + throw new IllegalStateException("Cannot load doc values", e); + } + } + + @Override + public boolean advanceExact(int docId) throws IOException { + if (binary == null || binary.advanceExact(docId) == false) { + return false; + } + ref = binary.binaryValue(); + assert ref.length % (Short.BYTES * dims) == 0; + numVecs = ref.length / (Short.BYTES * dims); + return true; + } + + @Override + public int docValueCount() { + return 1; + } + + @Override + public Object nextValue() { + List vectors = new ArrayList<>(numVecs); + VectorIterator iterator = new BFloat16RankVectorsDocValuesField.BFloat16VectorIterator(ref, vector, numVecs); + while (iterator.hasNext()) { + float[] v = iterator.next(); + vectors.add(v.clone()); } return vectors; } @@ -140,6 +180,7 @@ public DocValuesScriptFieldFactory getScriptFieldFactory(String name) { case BYTE -> new ByteRankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims); case FLOAT -> new FloatRankVectorsDocValuesField(values, 
magnitudeValues, name, elementType, dims); case BIT -> new BitRankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims); + case BFLOAT16 -> new BFloat16RankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims); }; } catch (IOException e) { throw new IllegalStateException("Cannot load doc values for multi-vector field!", e); diff --git a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/script/RankVectorsScoreScriptUtils.java b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/script/RankVectorsScoreScriptUtils.java index 3d692db08ff9e..df47d43c6a002 100644 --- a/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/script/RankVectorsScoreScriptUtils.java +++ b/x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/script/RankVectorsScoreScriptUtils.java @@ -351,7 +351,7 @@ public MaxSimDotProduct(ScoreScript scoreScript, Object queryVector, String fiel yield new MaxSimByteDotProduct(scoreScript, field, bytesOrList.list); } } - case FLOAT -> { + case FLOAT, BFLOAT16 -> { if (queryVector instanceof List) { yield new MaxSimFloatDotProduct(scoreScript, field, (List>) queryVector); } diff --git a/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java b/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java index 6f80d7b46f563..3e3ee2fd67336 100644 --- a/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java +++ b/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java @@ -417,7 +417,7 @@ protected Object generateRandomInputValue(MappedFieldType ft) { } yield vectors; } - case FLOAT -> { + case FLOAT, BFLOAT16 -> { float[][] vectors = new 
float[numVectors][vectorFieldType.getVectorDimensions()]; for (int i = 0; i < numVectors; i++) { for (int j = 0; j < vectorFieldType.getVectorDimensions(); j++) { @@ -473,7 +473,7 @@ public SyntheticSourceExample example(int maxValues) { Object value = switch (elementType) { case BYTE, BIT: yield randomList(numVecs, numVecs, () -> randomList(dims, dims, ESTestCase::randomByte)); - case FLOAT: + case FLOAT, BFLOAT16: yield randomList(numVecs, numVecs, () -> randomList(dims, dims, ESTestCase::randomFloat)); }; return new SyntheticSourceExample(value, value, this::mapping); diff --git a/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsScriptDocValuesTests.java b/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsScriptDocValuesTests.java index f0b00849557bd..8080762f2d492 100644 --- a/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsScriptDocValuesTests.java +++ b/x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsScriptDocValuesTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; import org.elasticsearch.index.mapper.vectors.RankVectorsScriptDocValues; +import org.elasticsearch.script.field.vectors.BFloat16RankVectorsDocValuesField; import org.elasticsearch.script.field.vectors.ByteRankVectorsDocValuesField; import org.elasticsearch.script.field.vectors.FloatRankVectorsDocValuesField; import org.elasticsearch.script.field.vectors.RankVectors; @@ -51,6 +52,36 @@ public void testFloatGetVectorValueAndGetMagnitude() throws IOException { } } + public void testBFloat16GetVectorValueAndGetMagnitude() throws IOException { + int dims = 3; + float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; + float[][] expectedMagnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 
2.2361f } }; + + BinaryDocValues docValues = wrap(vectors, ElementType.BFLOAT16); + BinaryDocValues magnitudeValues = wrap(expectedMagnitudes); + RankVectorsDocValuesField field = new BFloat16RankVectorsDocValuesField( + docValues, + magnitudeValues, + "test", + ElementType.BFLOAT16, + dims + ); + RankVectorsScriptDocValues scriptDocValues = field.toScriptDocValues(); + for (int i = 0; i < vectors.length; i++) { + field.setNextDocId(i); + assertEquals(vectors[i].length, field.size()); + assertEquals(dims, scriptDocValues.dims()); + Iterator iterator = scriptDocValues.getVectorValues(); + float[] magnitudes = scriptDocValues.getMagnitudes(); + assertEquals(expectedMagnitudes[i].length, magnitudes.length); + for (int j = 0; j < vectors[i].length; j++) { + assertTrue(iterator.hasNext()); + assertArrayEquals(vectors[i][j], iterator.next(), 0.0001f); + assertEquals(expectedMagnitudes[i][j], magnitudes[j], 0.0001f); + } + } + } + public void testByteGetVectorValueAndGetMagnitude() throws IOException { int dims = 3; float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; @@ -97,6 +128,36 @@ public void testFloatMetadataAndIterator() throws IOException { assertEquals(dv, RankVectors.EMPTY); } + public void testBFloat16MetadataAndIterator() throws IOException { + int dims = 3; + float[][][] vectors = new float[][][] { + fill(new float[3][dims], ElementType.BFLOAT16), + fill(new float[2][dims], ElementType.BFLOAT16) }; + float[][] magnitudes = new float[][] { new float[3], new float[2] }; + BinaryDocValues docValues = wrap(vectors, ElementType.BFLOAT16); + BinaryDocValues magnitudeValues = wrap(magnitudes); + + RankVectorsDocValuesField field = new BFloat16RankVectorsDocValuesField( + docValues, + magnitudeValues, + "test", + ElementType.BFLOAT16, + dims + ); + for (int i = 0; i < vectors.length; i++) { + field.setNextDocId(i); + RankVectors dv = field.get(); + assertEquals(vectors[i].length, dv.size()); + assertFalse(dv.isEmpty()); + 
assertEquals(dims, dv.getDims()); + UnsupportedOperationException e = expectThrows(UnsupportedOperationException.class, field::iterator); + assertEquals("Cannot iterate over single valued rank_vectors field, use get() instead", e.getMessage()); + } + field.setNextDocId(vectors.length); + RankVectors dv = field.get(); + assertEquals(dv, RankVectors.EMPTY); + } + public void testByteMetadataAndIterator() throws IOException { int dims = 3; float[][][] vectors = new float[][][] { fill(new float[3][dims], ElementType.BYTE), fill(new float[2][dims], ElementType.BYTE) }; @@ -145,6 +206,30 @@ public void testFloatMissingValues() throws IOException { assertEquals("A document doesn't have a value for a rank-vectors field!", e.getMessage()); } + public void testBFloat16MissingValues() throws IOException { + int dims = 3; + float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; + float[][] magnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 2.2361f } }; + BinaryDocValues docValues = wrap(vectors, ElementType.BFLOAT16); + BinaryDocValues magnitudeValues = wrap(magnitudes); + RankVectorsDocValuesField field = new BFloat16RankVectorsDocValuesField( + docValues, + magnitudeValues, + "test", + ElementType.BFLOAT16, + dims + ); + RankVectorsScriptDocValues scriptDocValues = field.toScriptDocValues(); + + field.setNextDocId(3); + assertEquals(0, field.size()); + Exception e = expectThrows(IllegalArgumentException.class, scriptDocValues::getVectorValues); + assertEquals("A document doesn't have a value for a rank-vectors field!", e.getMessage()); + + e = expectThrows(IllegalArgumentException.class, scriptDocValues::getMagnitudes); + assertEquals("A document doesn't have a value for a rank-vectors field!", e.getMessage()); + } + public void testByteMissingValues() throws IOException { int dims = 3; float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; @@ -183,6 +268,32 @@ public void testFloatGetFunctionIsNotAccessible() throws
IOException { ); } + public void testBFloat16GetFunctionIsNotAccessible() throws IOException { + int dims = 3; + float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; + float[][] magnitudes = { { 1.7320f, 2.4495f, 3.3166f }, { 2.2361f } }; + BinaryDocValues docValues = wrap(vectors, ElementType.BFLOAT16); + BinaryDocValues magnitudeValues = wrap(magnitudes); + RankVectorsDocValuesField field = new BFloat16RankVectorsDocValuesField( + docValues, + magnitudeValues, + "test", + ElementType.BFLOAT16, + dims + ); + RankVectorsScriptDocValues scriptDocValues = field.toScriptDocValues(); + + field.setNextDocId(0); + Exception e = expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0)); + assertThat( + e.getMessage(), + containsString( + "accessing a rank-vectors field's value through 'get' or 'value' is not supported," + + " use 'vectorValues' or 'magnitudes' instead." + ) + ); + } + public void testByteGetFunctionIsNotAccessible() throws IOException { int dims = 3; float[][][] vectors = { { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } }, { { 1, 0, 2 } } }; @@ -304,12 +415,11 @@ public static BytesRef mockEncodeDenseVector(float[][] values, ElementType eleme ByteBuffer byteBuffer = elementType.createByteBuffer(indexVersion, numBytes * values.length); for (float[] vector : values) { for (float value : vector) { - if (elementType == ElementType.FLOAT) { - byteBuffer.putFloat(value); - } else if (elementType == ElementType.BYTE || elementType == ElementType.BIT) { - byteBuffer.put((byte) value); - } else { - throw new IllegalStateException("unknown element_type [" + elementType + "]"); + switch (elementType) { + case FLOAT -> byteBuffer.putFloat(value); + case BFLOAT16 -> byteBuffer.putShort((short) (Float.floatToIntBits(value) >>> 16)); + case BYTE, BIT -> byteBuffer.put((byte) value); + default -> throw new IllegalStateException("unknown element_type [" + elementType + "]"); } } } diff --git 
a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java index f21525103d37d..51cb6a0b58439 100644 --- a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java +++ b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/SemanticTextUpgradeIT.java @@ -17,6 +17,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper; +import org.elasticsearch.index.mapper.MapperFeatures; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTestUtils; import org.elasticsearch.index.query.NestedQueryBuilder; @@ -85,12 +86,23 @@ public static Iterable parameters() { public void testSemanticTextOperations() throws Exception { switch (CLUSTER_TYPE) { - case OLD -> createAndPopulateIndex(); + case OLD -> { + checkSupportsBFloat16(); + createAndPopulateIndex(); + } case MIXED, UPGRADED -> performIndexQueryHighlightOps(); default -> throw new UnsupportedOperationException("Unknown cluster type [" + CLUSTER_TYPE + "]"); } } + private static void checkSupportsBFloat16() { + assumeTrue( + "The old cluster needs to support bfloat16 if it is used", + DENSE_MODEL.getServiceSettings().elementType() != DenseVectorFieldMapper.ElementType.BFLOAT16 + || clusterHasFeature(MapperFeatures.BBQ_BFLOAT16) + ); + } + private void createAndPopulateIndex() throws IOException { final String indexName = getIndexName(); final String mapping = Strings.format("""