> getSettings() {
+ if (GPU_FORMAT.isEnabled()) {
+ return List.of(VECTORS_INDEXING_USE_GPU_SETTING);
+ } else {
+ return List.of();
+ }
+ }
+
+ @Override
+ public VectorsFormatProvider getVectorsFormatProvider() {
+ return (indexSettings, indexOptions) -> {
+ if (GPU_FORMAT.isEnabled()) {
+ GpuMode gpuMode = indexSettings.getValue(VECTORS_INDEXING_USE_GPU_SETTING);
+ if (gpuMode == GpuMode.TRUE) {
+ if (vectorIndexTypeSupported(indexOptions.getType()) == false) {
+ throw new IllegalArgumentException(
+ "[index.vectors.indexing.use_gpu] doesn't support [index_options.type] of [" + indexOptions.getType() + "]."
+ );
+ }
+ if (GPUSupport.isSupported(true) == false) {
+ throw new IllegalArgumentException(
+ "[index.vectors.indexing.use_gpu] was set to [true], but GPU resources are not accessible on the node."
+ );
+ }
+ return getVectorsFormat(indexOptions);
+ }
+ if (gpuMode == GpuMode.AUTO && vectorIndexTypeSupported(indexOptions.getType()) && GPUSupport.isSupported(false)) {
+ return getVectorsFormat(indexOptions);
+ }
+ }
+ return null;
+ };
+ }
+
+ /** Tells whether the given index type is one of the two types handled here: HNSW or INT8_HNSW. */
+ private boolean vectorIndexTypeSupported(DenseVectorFieldMapper.VectorIndexType type) {
+     if (type == DenseVectorFieldMapper.VectorIndexType.HNSW) {
+         return true;
+     }
+     return type == DenseVectorFieldMapper.VectorIndexType.INT8_HNSW;
+ }
+
+ private static KnnVectorsFormat getVectorsFormat(DenseVectorFieldMapper.DenseVectorIndexOptions indexOptions) {
+ if (indexOptions.getType() == DenseVectorFieldMapper.VectorIndexType.HNSW) {
+ DenseVectorFieldMapper.HnswIndexOptions hnswIndexOptions = (DenseVectorFieldMapper.HnswIndexOptions) indexOptions;
+ int efConstruction = hnswIndexOptions.efConstruction();
+ if (efConstruction == HnswGraphBuilder.DEFAULT_BEAM_WIDTH) {
+ efConstruction = ES92GpuHnswVectorsFormat.DEFAULT_BEAM_WIDTH; // default value for GPU graph construction is 128
+ }
+ return new ES92GpuHnswVectorsFormat(hnswIndexOptions.m(), efConstruction);
+ } else if (indexOptions.getType() == DenseVectorFieldMapper.VectorIndexType.INT8_HNSW) {
+ DenseVectorFieldMapper.Int8HnswIndexOptions int8HnswIndexOptions = (DenseVectorFieldMapper.Int8HnswIndexOptions) indexOptions;
+ int efConstruction = int8HnswIndexOptions.efConstruction();
+ if (efConstruction == HnswGraphBuilder.DEFAULT_BEAM_WIDTH) {
+ efConstruction = ES92GpuHnswVectorsFormat.DEFAULT_BEAM_WIDTH; // default value for GPU graph construction is 128
+ }
+ return new ES92GpuHnswSQVectorsFormat(
+ int8HnswIndexOptions.m(),
+ efConstruction,
+ int8HnswIndexOptions.confidenceInterval(),
+ 7,
+ false
+ );
+ } else {
+ throw new IllegalArgumentException(
+ "GPU vector indexing is not supported on this vector type: [" + indexOptions.getType() + "]"
+ );
+ }
+ }
+}
diff --git a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/GPUSupport.java b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/GPUSupport.java
new file mode 100644
index 0000000000000..c21bda894790a
--- /dev/null
+++ b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/GPUSupport.java
@@ -0,0 +1,119 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.gpu;
+
+import com.nvidia.cuvs.CuVSResources;
+import com.nvidia.cuvs.GPUInfoProvider;
+import com.nvidia.cuvs.spi.CuVSProvider;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.logging.LogManager;
+import org.elasticsearch.logging.Logger;
+
+public class GPUSupport {
+
+ private static final Logger LOG = LogManager.getLogger(GPUSupport.class);
+
+ // Set the minimum at 7.5GB: 8GB GPUs (which are our targeted minimum) report less than that via the API
+ private static final long MIN_DEVICE_MEMORY_IN_BYTES = 8053063680L;
+
+ /**
+  * Tells whether the platform supports cuvs.
+  *
+  * A platform is supported when at least one available GPU meets both the minimum compute capability
+  * and the minimum device memory.
+  *
+  * @param logError whether the reasons for a negative answer (and the compatible GPU found) are logged
+  * @return {@code true} as soon as one compatible GPU is found, {@code false} otherwise or on any failure
+  */
+ public static boolean isSupported(boolean logError) {
+     try {
+         var availableGPUs = CuVSProvider.provider().gpuInfoProvider().availableGPUs();
+         if (availableGPUs.isEmpty()) {
+             if (logError) {
+                 LOG.warn("No GPU found");
+             }
+             return false;
+         }
+
+         for (var gpu : availableGPUs) {
+             var major = gpu.computeCapabilityMajor();
+             var minor = gpu.computeCapabilityMinor();
+             final boolean capabilityTooLow = major < GPUInfoProvider.MIN_COMPUTE_CAPABILITY_MAJOR
+                 || (major == GPUInfoProvider.MIN_COMPUTE_CAPABILITY_MAJOR && minor < GPUInfoProvider.MIN_COMPUTE_CAPABILITY_MINOR);
+             if (capabilityTooLow) {
+                 if (logError) {
+                     LOG.warn(
+                         "GPU [{}] does not have the minimum compute capabilities (required: [{}.{}], found: [{}.{}])",
+                         gpu.name(),
+                         GPUInfoProvider.MIN_COMPUTE_CAPABILITY_MAJOR,
+                         GPUInfoProvider.MIN_COMPUTE_CAPABILITY_MINOR,
+                         major,
+                         minor
+                     );
+                 }
+                 continue;
+             }
+             if (gpu.totalDeviceMemoryInBytes() < MIN_DEVICE_MEMORY_IN_BYTES) {
+                 if (logError) {
+                     LOG.warn(
+                         "GPU [{}] does not have minimum memory required (required: [{}], found: [{}])",
+                         gpu.name(),
+                         MIN_DEVICE_MEMORY_IN_BYTES,
+                         gpu.totalDeviceMemoryInBytes()
+                     );
+                 }
+                 continue;
+             }
+             // first GPU passing both checks settles the answer
+             if (logError) {
+                 LOG.info("Found compatible GPU [{}] (id: [{}])", gpu.name(), gpu.gpuId());
+             }
+             return true;
+         }
+     } catch (UnsupportedOperationException uoe) {
+         if (logError) {
+             String detail = uoe.getMessage();
+             if (detail == null) {
+                 // no message available: describe the runtime environment instead
+                 detail = Strings.format(
+                     "runtime Java version [%d], OS [%s], arch [%s]",
+                     Runtime.version().feature(),
+                     System.getProperty("os.name"),
+                     System.getProperty("os.arch")
+                 );
+             }
+             LOG.warn("GPU based vector indexing is not supported on this platform; " + detail);
+         }
+     } catch (Throwable t) {
+         if (logError) {
+             // a failed static initializer is reported through its cause
+             final Throwable reported = t instanceof ExceptionInInitializerError ex ? ex.getCause() : t;
+             LOG.warn("Exception occurred during creation of cuvs resources", reported);
+         }
+     }
+     return false;
+ }
+
+ /** Returns a resources if supported, otherwise null. */
+ public static CuVSResources cuVSResourcesOrNull(boolean logError) {
+ try {
+ var resources = CuVSResources.create();
+ return resources;
+ } catch (UnsupportedOperationException uoe) {
+ if (logError) {
+ String msg = "";
+ if (uoe.getMessage() == null) {
+ msg = "Runtime Java version: " + Runtime.version().feature();
+ } else {
+ msg = ": " + uoe.getMessage();
+ }
+ LOG.warn("GPU based vector indexing is not supported on this platform or java version; " + msg);
+ }
+ } catch (Throwable t) {
+ if (logError) {
+ if (t instanceof ExceptionInInitializerError ex) {
+ t = ex.getCause();
+ }
+ LOG.warn("Exception occurred during creation of cuvs resources", t);
+ }
+ }
+ return null;
+ }
+}
diff --git a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManager.java b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManager.java
new file mode 100644
index 0000000000000..44240a848268b
--- /dev/null
+++ b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManager.java
@@ -0,0 +1,270 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.gpu.codec;
+
+import com.nvidia.cuvs.CuVSMatrix;
+import com.nvidia.cuvs.CuVSResources;
+import com.nvidia.cuvs.GPUInfoProvider;
+import com.nvidia.cuvs.spi.CuVSProvider;
+
+import org.elasticsearch.core.Strings;
+import org.elasticsearch.logging.LogManager;
+import org.elasticsearch.logging.Logger;
+import org.elasticsearch.xpack.gpu.GPUSupport;
+
+import java.nio.file.Path;
+import java.util.Objects;
+import java.util.concurrent.locks.Condition;
+import java.util.concurrent.locks.ReentrantLock;
+
+/**
+ * A manager of {@link com.nvidia.cuvs.CuVSResources}. There is one manager per GPU.
+ *
+ * All access to GPU resources is mediated through a manager. A manager helps coordinate usage threads to:
+ *
+ * - ensure single-threaded access to any particular resource at a time
+ * - Control the total number of concurrent operations that may be performed on a GPU
+ * - Pool resources, to avoid frequent creation and destruction, which are expensive operations.
+ *
+ *
+ * Fundamentally, a resource is used in compute and memory bound operations. The former occurs prior to the latter, e.g.
+ * index build (compute), followed by a copy/process of the newly built index (memory). The manager allows the resource
+ * user to indicate that compute is complete before releasing the resources. This can help improve parallelism of compute
+ * on the GPU - allowing the next compute operation to proceed before releasing the resources.
+ *
+ */
+public interface CuVSResourceManager {
+
+ /**
+ * Acquires a resource from the manager.
+ *
+ * <p>A manager can use the given parameters, numVectors and dims, to estimate the potential
+ * effect on GPU memory and compute usage to determine whether to give out
+ * another resource or wait for a resources to be returned before giving out another.
+ *
+ * @param numVectors the number of vectors the caller intends to process
+ * @param dims the number of dimensions of each vector
+ * @param dataType the element type of the vectors
+ * @return the acquired resource
+ * @throws InterruptedException if the thread is interrupted while waiting for a resource
+ * (the pooling implementation may wait)
+ */
+ // The pooling implementation uses numVectors, dims and dataType to estimate the GPU memory required.
+ // Further GPU metadata (generation, etc) could be taken into account as well, e.g. when acquiring
+ // for 10M x 1536 dims, or 100,000 x 128 dims, to give out a resources or not.
+ ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataType dataType) throws InterruptedException;
+
+ /** Marks the resources as finished with regard to compute. */
+ void finishedComputation(ManagedCuVSResources resources);
+
+ /** Returns the given resource to the manager. */
+ void release(ManagedCuVSResources resources);
+
+ /** Shuts down the manager, releasing all open resources. */
+ void shutdown();
+
+ /** Returns the system-wide pooling manager. */
+ static CuVSResourceManager pooling() {
+ return PoolingCuVSResourceManager.Holder.INSTANCE;
+ }
+
+ /**
+ * A manager that maintains a pool of resources.
+ */
+ class PoolingCuVSResourceManager implements CuVSResourceManager {
+
+ static final Logger logger = LogManager.getLogger(CuVSResourceManager.class);
+
+ /** A multiplier on input data to account for intermediate and output data size required while processing it */
+ static final double GPU_COMPUTATION_MEMORY_FACTOR = 2.0;
+ static final int MAX_RESOURCES = 4;
+
+ static class Holder {
+ static final PoolingCuVSResourceManager INSTANCE = new PoolingCuVSResourceManager(
+ MAX_RESOURCES,
+ CuVSProvider.provider().gpuInfoProvider()
+ );
+ }
+
+ private final ManagedCuVSResources[] pool;
+ private final int capacity;
+ private final GPUInfoProvider gpuInfoProvider;
+ private int createdCount;
+
+ ReentrantLock lock = new ReentrantLock();
+ Condition enoughResourcesCondition = lock.newCondition();
+
+ public PoolingCuVSResourceManager(int capacity, GPUInfoProvider gpuInfoProvider) {
+ if (capacity < 1 || capacity > MAX_RESOURCES) {
+ throw new IllegalArgumentException("Resource count must be between 1 and " + MAX_RESOURCES);
+ }
+ this.capacity = capacity;
+ this.gpuInfoProvider = gpuInfoProvider;
+ this.pool = new ManagedCuVSResources[MAX_RESOURCES];
+ }
+
+ private ManagedCuVSResources getResourceFromPool() {
+ for (int i = 0; i < createdCount; ++i) {
+ var res = pool[i];
+ if (res.locked == false) {
+ return res;
+ }
+ }
+ if (createdCount < capacity) {
+ var res = new ManagedCuVSResources(Objects.requireNonNull(createNew()));
+ pool[createdCount++] = res;
+ return res;
+ }
+ return null;
+ }
+
+ /** Counts the resources created so far that are currently marked as locked. */
+ private int numLockedResources() {
+     int count = 0;
+     for (int slot = 0; slot < createdCount; slot++) {
+         count += pool[slot].locked ? 1 : 0;
+     }
+     return count;
+ }
+
+ /**
+  * Acquires a resource from the pool, waiting until one can be handed out.
+  *
+  * A resource is handed out when one is unlocked (or can still be created) and either no resource at all
+  * is currently locked, or the estimated memory requirement fits in the free device memory. Otherwise the
+  * caller waits on the condition that {@code release} signals and the checks are repeated.
+  *
+  * @throws IllegalArgumentException if the estimated memory requirement exceeds the total device memory
+  * @throws InterruptedException if the thread is interrupted while waiting
+  */
+ @Override
+ public ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataType dataType) throws InterruptedException {
+     // Take the lock before entering the try block: should lock() fail, the finally clause must not
+     // call unlock() on a lock this thread does not hold.
+     lock.lock();
+     try {
+         boolean allConditionsMet = false;
+         ManagedCuVSResources res = null;
+         while (allConditionsMet == false) {
+             res = getResourceFromPool();
+
+             final boolean enoughMemory;
+             if (res != null) {
+                 long requiredMemoryInBytes = estimateRequiredMemory(numVectors, dims, dataType);
+                 logger.debug(
+                     "Estimated memory for [{}] vectors, [{}] dims of type [{}] is [{} B]",
+                     numVectors,
+                     dims,
+                     dataType.name(),
+                     requiredMemoryInBytes
+                 );
+
+                 // Check immutable constraints
+                 long totalDeviceMemoryInBytes = gpuInfoProvider.getCurrentInfo(res).totalDeviceMemoryInBytes();
+                 if (requiredMemoryInBytes > totalDeviceMemoryInBytes) {
+                     String message = Strings.format(
+                         "Requested GPU memory for [%d] vectors, [%d] dims is greater than the GPU total memory [%d B]",
+                         numVectors,
+                         dims,
+                         totalDeviceMemoryInBytes
+                     );
+                     logger.error(message);
+                     throw new IllegalArgumentException(message);
+                 }
+
+                 // If no resource in the pool is locked, short circuit to avoid livelock
+                 if (numLockedResources() == 0) {
+                     logger.debug("No resources currently locked, proceeding");
+                     break;
+                 }
+
+                 // Check resources availability
+                 long freeDeviceMemoryInBytes = gpuInfoProvider.getCurrentInfo(res).freeDeviceMemoryInBytes();
+                 enoughMemory = requiredMemoryInBytes <= freeDeviceMemoryInBytes;
+                 logger.debug("Free device memory [{} B], enoughMemory[{}]", freeDeviceMemoryInBytes, enoughMemory);
+             } else {
+                 logger.debug("No resources available in pool");
+                 enoughMemory = false;
+             }
+             // TODO: add enoughComputation / enoughComputationCondition here
+             allConditionsMet = enoughMemory; // && enoughComputation
+             if (allConditionsMet == false) {
+                 enoughResourcesCondition.await();
+             }
+         }
+         // res is non-null here: the loop only ends through the break above or with enoughMemory == true
+         res.locked = true;
+         return res;
+     } finally {
+         lock.unlock();
+     }
+ }
+
+ private long estimateRequiredMemory(int numVectors, int dims, CuVSMatrix.DataType dataType) {
+ int elementTypeBytes = switch (dataType) {
+ case FLOAT -> Float.BYTES;
+ case INT, UINT -> Integer.BYTES;
+ case BYTE -> Byte.BYTES;
+ };
+ return (long) (GPU_COMPUTATION_MEMORY_FACTOR * numVectors * dims * elementTypeBytes);
+ }
+
+ // visible for testing
+ // Creates the underlying resources through GPUSupport, with error logging enabled. The result may be
+ // null when creation fails; getResourceFromPool rejects a null result via Objects.requireNonNull.
+ protected CuVSResources createNew() {
+ return GPUSupport.cuVSResourcesOrNull(true);
+ }
+
+ /**
+ * Records that the compute phase on the given resources is over. This implementation only logs the
+ * event; the resources stay locked until they are released.
+ */
+ @Override
+ public void finishedComputation(ManagedCuVSResources resources) {
+ logger.debug("Computation finished");
+ // currently does nothing, but could allow acquire to return possibly blocked resources
+ // enoughResourcesCondition.signalAll()
+ }
+
+ /**
+  * Returns the given resources to the pool by marking them unlocked, then wakes up every thread waiting
+  * in {@code acquire} so that it re-evaluates its conditions.
+  */
+ @Override
+ public void release(ManagedCuVSResources resources) {
+     logger.debug("Releasing resources to pool");
+     // Take the lock before entering the try block: should lock() fail, the finally clause must not
+     // call unlock() on a lock this thread does not hold.
+     lock.lock();
+     try {
+         assert resources.locked;
+         resources.locked = false;
+         enoughResourcesCondition.signalAll();
+     } finally {
+         lock.unlock();
+     }
+ }
+
+ @Override
+ public void shutdown() {
+ for (int i = 0; i < createdCount; ++i) {
+ var res = pool[i];
+ assert res != null;
+ res.delegate.close();
+ }
+ }
+ }
+
+ /**
+ * A managed resource. Cannot be closed.
+ *
+ * Wraps a delegate {@link CuVSResources} and forwards access() and deviceId() to it. The {@code locked}
+ * flag records whether the resource is currently handed out; the pooling manager reads and writes it.
+ */
+ final class ManagedCuVSResources implements CuVSResources {
+
+ // the wrapped resources; closed by the pooling manager on shutdown
+ final CuVSResources delegate;
+ // true while the resource is handed out to a user
+ boolean locked = false;
+
+ ManagedCuVSResources(CuVSResources resources) {
+ this.delegate = resources;
+ }
+
+ @Override
+ public ScopedAccess access() {
+ return delegate.access();
+ }
+
+ @Override
+ public int deviceId() {
+ return delegate.deviceId();
+ }
+
+ /** Always throws: the lifecycle of a managed resource belongs to its manager. */
+ @Override
+ public void close() {
+ throw new UnsupportedOperationException("this resource is managed, cannot be closed by clients");
+ }
+
+ // NOTE(review): returns null rather than delegating to the wrapped resources - confirm this is intended
+ @Override
+ public Path tempDirectory() {
+ return null;
+ }
+
+ @Override
+ public String toString() {
+ return "ManagedCuVSResources[delegate=" + delegate + "]";
+ }
+ }
+}
diff --git a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/DatasetUtils.java b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/DatasetUtils.java
new file mode 100644
index 0000000000000..3a9fcb2c68cd8
--- /dev/null
+++ b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/DatasetUtils.java
@@ -0,0 +1,28 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.gpu.codec;
+
+import com.nvidia.cuvs.CuVSMatrix;
+
+import org.apache.lucene.store.MemorySegmentAccessInput;
+
+import java.io.IOException;
+
+ /** Creates {@link CuVSMatrix} datasets over vector data exposed by a {@link MemorySegmentAccessInput}. */
+ public interface DatasetUtils {
+
+ /** Returns the instance provided by DatasetUtilsImpl. */
+ static DatasetUtils getInstance() {
+ return DatasetUtilsImpl.getInstance();
+ }
+
+ /**
+ * Returns a Dataset over the vectors of type {@code dataType} in the input.
+ *
+ * @param input the input holding the vector data
+ * @param numVectors the number of vectors expected in the input
+ * @param dims the number of dimensions of each vector
+ * @param dataType the element type of the vectors
+ */
+ CuVSMatrix fromInput(MemorySegmentAccessInput input, int numVectors, int dims, CuVSMatrix.DataType dataType) throws IOException;
+
+ /**
+ * Returns a Dataset over an input slice
+ *
+ * @param input the input holding the vector data
+ * @param pos the start of the slice within the input
+ * @param len the length of the slice
+ * @param numVectors the number of vectors expected in the slice
+ * @param dims the number of dimensions of each vector
+ * @param dataType the element type of the vectors
+ */
+ CuVSMatrix fromSlice(MemorySegmentAccessInput input, long pos, long len, int numVectors, int dims, CuVSMatrix.DataType dataType)
+ throws IOException;
+ }
diff --git a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/DatasetUtilsImpl.java b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/DatasetUtilsImpl.java
new file mode 100644
index 0000000000000..0dfb0960cebbe
--- /dev/null
+++ b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/DatasetUtilsImpl.java
@@ -0,0 +1,93 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.gpu.codec;
+
+import com.nvidia.cuvs.CuVSMatrix;
+import com.nvidia.cuvs.spi.CuVSProvider;
+
+import org.apache.lucene.store.MemorySegmentAccessInput;
+
+import java.io.IOException;
+import java.lang.foreign.MemorySegment;
+import java.lang.invoke.MethodHandle;
+
+public class DatasetUtilsImpl implements DatasetUtils {
+
+ private static final DatasetUtils INSTANCE = new DatasetUtilsImpl();
+
+ private static final MethodHandle createDataset$mh = CuVSProvider.provider().newNativeMatrixBuilder();
+
+ static DatasetUtils getInstance() {
+ return INSTANCE;
+ }
+
+ static CuVSMatrix fromMemorySegment(MemorySegment memorySegment, int size, int dimensions, CuVSMatrix.DataType dataType) {
+ try {
+ return (CuVSMatrix) createDataset$mh.invokeExact(memorySegment, size, dimensions, dataType);
+ } catch (Throwable e) {
+ if (e instanceof Error err) {
+ throw err;
+ } else if (e instanceof RuntimeException re) {
+ throw re;
+ } else {
+ throw new RuntimeException(e);
+ }
+ }
+ }
+
+ private DatasetUtilsImpl() {}
+
+ /**
+  * Creates a matrix over the whole input, from offset 0 to {@code input.length()}.
+  *
+  * @throws IllegalArgumentException if {@code numVectors} or {@code dims} is negative
+  */
+ @Override
+ public CuVSMatrix fromInput(MemorySegmentAccessInput input, int numVectors, int dims, CuVSMatrix.DataType dataType) throws IOException {
+     final boolean negativeShape = numVectors < 0 || dims < 0;
+     if (negativeShape) {
+         throwIllegalArgumentException(numVectors, dims);
+     }
+     final long wholeLength = input.length();
+     return createCuVSMatrix(input, 0L, wholeLength, numVectors, dims, dataType);
+ }
+
+ /**
+  * Creates a matrix over the slice of the input starting at {@code pos} and spanning {@code len} bytes.
+  *
+  * @throws IllegalArgumentException if {@code pos}, {@code len}, {@code numVectors} or {@code dims} is negative
+  */
+ @Override
+ public CuVSMatrix fromSlice(MemorySegmentAccessInput input, long pos, long len, int numVectors, int dims, CuVSMatrix.DataType dataType)
+     throws IOException {
+     if (pos < 0 || len < 0) {
+         throw new IllegalArgumentException("pos and len must be positive");
+     }
+     // Same shape validation as fromInput: a negative count would make the expected byte size negative
+     // and so slip through the segment size check.
+     if (numVectors < 0 || dims < 0) {
+         throwIllegalArgumentException(numVectors, dims);
+     }
+     return createCuVSMatrix(input, pos, len, numVectors, dims, dataType);
+ }
+
+ private static CuVSMatrix createCuVSMatrix(
+ MemorySegmentAccessInput input,
+ long pos,
+ long len,
+ int numVectors,
+ int dims,
+ CuVSMatrix.DataType dataType
+ ) throws IOException {
+ MemorySegment ms = input.segmentSliceOrNull(pos, len);
+ assert ms != null; // TODO: this can be null if larger than 16GB or ...
+ final int byteSize = dataType == CuVSMatrix.DataType.FLOAT ? Float.BYTES : Byte.BYTES;
+ if (((long) numVectors * dims * byteSize) > ms.byteSize()) {
+ throwIllegalArgumentException(ms, numVectors, dims);
+ }
+ return fromMemorySegment(ms, numVectors, dims, dataType);
+ }
+
+ static void throwIllegalArgumentException(MemorySegment ms, int numVectors, int dims) {
+ var s = "segment of size [" + ms.byteSize() + "] too small for expected " + numVectors + " float vectors of " + dims + " dims";
+ throw new IllegalArgumentException(s);
+ }
+
+ static void throwIllegalArgumentException(int numVectors, int dims) {
+ String s;
+ if (numVectors < 0) {
+ s = "negative number of vectors: " + numVectors;
+ } else {
+ s = "negative vector dims: " + dims;
+ }
+ throw new IllegalArgumentException(s);
+ }
+}
diff --git a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ES92GpuHnswSQVectorsFormat.java b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ES92GpuHnswSQVectorsFormat.java
new file mode 100644
index 0000000000000..b62766fb39c3a
--- /dev/null
+++ b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ES92GpuHnswSQVectorsFormat.java
@@ -0,0 +1,97 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.gpu.codec;
+
+import org.apache.lucene.codecs.KnnVectorsFormat;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.codecs.KnnVectorsWriter;
+import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
+import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.SegmentWriteState;
+import org.elasticsearch.index.codec.vectors.ES814ScalarQuantizedVectorsFormat;
+
+import java.io.IOException;
+import java.util.function.Supplier;
+
+import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;
+import static org.elasticsearch.xpack.gpu.codec.ES92GpuHnswVectorsFormat.DEFAULT_BEAM_WIDTH;
+import static org.elasticsearch.xpack.gpu.codec.ES92GpuHnswVectorsFormat.DEFAULT_MAX_CONN;
+
+/**
+ * Codec format for GPU-accelerated scalar quantized HNSW vector indexes.
+ * HNSW graph is built on GPU, while scalar quantization and search is performed on CPU.
+ */
+public class ES92GpuHnswSQVectorsFormat extends KnnVectorsFormat {
+ public static final String NAME = "Lucene99HnswVectorsFormat";
+ static final int MAXIMUM_MAX_CONN = 512;
+ static final int MAXIMUM_BEAM_WIDTH = 3200;
+ private final int maxConn;
+ private final int beamWidth;
+
+ /** The format for storing, reading, merging vectors on disk */
+ private final FlatVectorsFormat flatVectorsFormat;
+ private final Supplier cuVSResourceManagerSupplier;
+
+ public ES92GpuHnswSQVectorsFormat() {
+ this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, null, 7, false);
+ }
+
+ public ES92GpuHnswSQVectorsFormat(int maxConn, int beamWidth, Float confidenceInterval, int bits, boolean compress) {
+ super(NAME);
+ this.cuVSResourceManagerSupplier = CuVSResourceManager::pooling;
+ if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
+ throw new IllegalArgumentException(
+ "maxConn must be positive and less than or equal to " + MAXIMUM_MAX_CONN + "; maxConn=" + maxConn
+ );
+ }
+ if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) {
+ throw new IllegalArgumentException(
+ "beamWidth must be positive and less than or equal to " + MAXIMUM_BEAM_WIDTH + "; beamWidth=" + beamWidth
+ );
+ }
+ this.maxConn = maxConn;
+ this.beamWidth = beamWidth;
+ this.flatVectorsFormat = new ES814ScalarQuantizedVectorsFormat(confidenceInterval, bits, compress);
+ }
+
+ @Override
+ public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
+ return new ES92GpuHnswVectorsWriter(
+ cuVSResourceManagerSupplier.get(),
+ state,
+ maxConn,
+ beamWidth,
+ flatVectorsFormat.fieldsWriter(state)
+ );
+ }
+
+ @Override
+ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
+ return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
+ }
+
+ @Override
+ public int getMaxDimensions(String fieldName) {
+ return MAX_DIMS_COUNT;
+ }
+
+ @Override
+ public String toString() {
+ return NAME
+ + "(name="
+ + NAME
+ + ", maxConn="
+ + maxConn
+ + ", beamWidth="
+ + beamWidth
+ + ", flatVectorFormat="
+ + flatVectorsFormat
+ + ")";
+ }
+}
diff --git a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ES92GpuHnswVectorsFormat.java b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ES92GpuHnswVectorsFormat.java
new file mode 100644
index 0000000000000..8761b9e12f22a
--- /dev/null
+++ b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ES92GpuHnswVectorsFormat.java
@@ -0,0 +1,102 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.gpu.codec;
+
+import org.apache.lucene.codecs.KnnVectorsFormat;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.codecs.KnnVectorsWriter;
+import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
+import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
+import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
+import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.SegmentWriteState;
+
+import java.io.IOException;
+import java.util.function.Supplier;
+
+import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;
+
+/**
+ * Codec format for GPU-accelerated vector indexes. This format is designed to
+ * leverage GPU processing capabilities for vector search operations.
+ */
+public class ES92GpuHnswVectorsFormat extends KnnVectorsFormat {
+ public static final String NAME = "Lucene99HnswVectorsFormat";
+ public static final int VERSION_GROUPVARINT = 1;
+
+ static final String LUCENE99_HNSW_META_CODEC_NAME = "Lucene99HnswVectorsFormatMeta";
+ static final String LUCENE99_HNSW_VECTOR_INDEX_CODEC_NAME = "Lucene99HnswVectorsFormatIndex";
+ static final String LUCENE99_HNSW_META_EXTENSION = "vem";
+ static final String LUCENE99_HNSW_VECTOR_INDEX_EXTENSION = "vex";
+ static final int LUCENE99_VERSION_CURRENT = VERSION_GROUPVARINT;
+
+ static final int DEFAULT_MAX_CONN = 16; // graph degree
+ public static final int DEFAULT_BEAM_WIDTH = 128; // intermediate graph degree
+ static final int MIN_NUM_VECTORS_FOR_GPU_BUILD = 2;
+
+ private static final FlatVectorsFormat flatVectorsFormat = new Lucene99FlatVectorsFormat(
+ FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
+ );
+
+ // How many nodes each node in the graph is connected to in the final graph
+ private final int maxConn;
+ // Intermediate graph degree, the number of connections for each node before pruning
+ private final int beamWidth;
+ private final Supplier cuVSResourceManagerSupplier;
+
+ public ES92GpuHnswVectorsFormat() {
+ this(CuVSResourceManager::pooling, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH);
+ }
+
+ public ES92GpuHnswVectorsFormat(int maxConn, int beamWidth) {
+ this(CuVSResourceManager::pooling, maxConn, beamWidth);
+ };
+
+ public ES92GpuHnswVectorsFormat(Supplier cuVSResourceManagerSupplier, int maxConn, int beamWidth) {
+ super(NAME);
+ this.cuVSResourceManagerSupplier = cuVSResourceManagerSupplier;
+ this.maxConn = maxConn;
+ this.beamWidth = beamWidth;
+ }
+
+ @Override
+ public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
+ return new ES92GpuHnswVectorsWriter(
+ cuVSResourceManagerSupplier.get(),
+ state,
+ maxConn,
+ beamWidth,
+ flatVectorsFormat.fieldsWriter(state)
+ );
+ }
+
+ @Override
+ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
+ return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
+ }
+
+ @Override
+ public int getMaxDimensions(String fieldName) {
+ return MAX_DIMS_COUNT;
+ }
+
+ @Override
+ public String toString() {
+ return NAME
+ + "(name="
+ + NAME
+ + ", maxConn="
+ + maxConn
+ + ", beamWidth="
+ + beamWidth
+ + ", flatVectorFormat="
+ + flatVectorsFormat.getName()
+ + ")";
+ }
+}
diff --git a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ES92GpuHnswVectorsWriter.java b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ES92GpuHnswVectorsWriter.java
new file mode 100644
index 0000000000000..f848f715f913b
--- /dev/null
+++ b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ES92GpuHnswVectorsWriter.java
@@ -0,0 +1,683 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.gpu.codec;
+
+import com.nvidia.cuvs.CagraIndex;
+import com.nvidia.cuvs.CagraIndexParams;
+import com.nvidia.cuvs.CuVSMatrix;
+
+import org.apache.lucene.codecs.CodecUtil;
+import org.apache.lucene.codecs.KnnFieldVectorsWriter;
+import org.apache.lucene.codecs.KnnVectorsWriter;
+import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
+import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
+import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsWriter;
+import org.apache.lucene.index.ByteVectorValues;
+import org.apache.lucene.index.DocsWithFieldSet;
+import org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.IndexFileNames;
+import org.apache.lucene.index.KnnVectorValues;
+import org.apache.lucene.index.MergeState;
+import org.apache.lucene.index.SegmentWriteState;
+import org.apache.lucene.index.Sorter;
+import org.apache.lucene.index.VectorEncoding;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.store.FilterIndexInput;
+import org.apache.lucene.store.IOContext;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.store.MemorySegmentAccessInput;
+import org.apache.lucene.util.RamUsageEstimator;
+import org.apache.lucene.util.hnsw.HnswGraph;
+import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
+import org.apache.lucene.util.packed.DirectMonotonicWriter;
+import org.apache.lucene.util.quantization.ScalarQuantizer;
+import org.elasticsearch.core.IOUtils;
+import org.elasticsearch.core.SuppressForbidden;
+import org.elasticsearch.index.codec.vectors.ES814ScalarQuantizedVectorsFormat;
+import org.elasticsearch.logging.LogManager;
+import org.elasticsearch.logging.Logger;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
+import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsWriter.mergeAndRecalculateQuantiles;
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+import static org.elasticsearch.xpack.gpu.codec.ES92GpuHnswVectorsFormat.LUCENE99_HNSW_META_CODEC_NAME;
+import static org.elasticsearch.xpack.gpu.codec.ES92GpuHnswVectorsFormat.LUCENE99_HNSW_META_EXTENSION;
+import static org.elasticsearch.xpack.gpu.codec.ES92GpuHnswVectorsFormat.LUCENE99_HNSW_VECTOR_INDEX_CODEC_NAME;
+import static org.elasticsearch.xpack.gpu.codec.ES92GpuHnswVectorsFormat.LUCENE99_HNSW_VECTOR_INDEX_EXTENSION;
+import static org.elasticsearch.xpack.gpu.codec.ES92GpuHnswVectorsFormat.LUCENE99_VERSION_CURRENT;
+import static org.elasticsearch.xpack.gpu.codec.ES92GpuHnswVectorsFormat.MIN_NUM_VECTORS_FOR_GPU_BUILD;
+
+/**
+ * Writer that builds an NVIDIA CAGRA graph on GPU and then writes it into the Lucene99 HNSW format,
+ * so that it can be searched on CPU with Lucene99HnswVectorsReader.
+ */
+final class ES92GpuHnswVectorsWriter extends KnnVectorsWriter {
+ private static final Logger logger = LogManager.getLogger(ES92GpuHnswVectorsWriter.class);
+ private static final long SHALLOW_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(ES92GpuHnswVectorsWriter.class);
+ // Block shift handed to DirectMonotonicWriter when the per-node graph offsets are written in writeMeta
+ private static final int LUCENE99_HNSW_DIRECT_MONOTONIC_BLOCK_SHIFT = 16;
+
+ // Source of the GPU resources acquired and released around each graph build
+ private final CuVSResourceManager cuVSResourceManager;
+ private final SegmentWriteState segmentWriteState;
+ // Outputs for the HNSW metadata file and the HNSW graph (vector index) file, both created in the constructor
+ private final IndexOutput meta, vectorIndex;
+ // Graph degree passed to the CAGRA index params
+ private final int M;
+ // Intermediate graph degree passed to the CAGRA index params
+ private final int beamWidth;
+ // Delegate that stores the vectors themselves; flush/merge/finish/close are forwarded to it
+ private final FlatVectorsWriter flatVectorWriter;
+
+ // One entry per field registered through addField
+ private final List fields = new ArrayList<>();
+ // Set by finish() so that a second call is rejected
+ private boolean finished;
+ // BYTE when the flat writer is the ES814 scalar-quantized writer, FLOAT otherwise; selects the merge-time data path
+ private final CuVSMatrix.DataType dataType;
+
+ /**
+  * Creates the writer, opens the HNSW metadata and vector-index outputs for the segment and writes
+  * their index headers. If opening or header writing fails, everything opened so far (including the
+  * flat vectors writer) is closed before the exception propagates.
+  *
+  * @param cuVSResourceManager provider of GPU resources; must not be null
+  * @param state               write state of the segment being written
+  * @param M                   graph degree used for the GPU graph build
+  * @param beamWidth           intermediate graph degree used for the GPU graph build
+  * @param flatVectorWriter    delegate storing the vectors; its concrete type decides the data type used during merges
+  */
+ ES92GpuHnswVectorsWriter(
+     CuVSResourceManager cuVSResourceManager,
+     SegmentWriteState state,
+     int M,
+     int beamWidth,
+     FlatVectorsWriter flatVectorWriter
+ ) throws IOException {
+     assert cuVSResourceManager != null : "CuVSResourceManager must not be null";
+     this.cuVSResourceManager = cuVSResourceManager;
+     this.M = M;
+     this.beamWidth = beamWidth;
+     this.flatVectorWriter = flatVectorWriter;
+     // A scalar-quantized flat writer means merges build the graph from byte vectors; otherwise from floats
+     if (flatVectorWriter instanceof ES814ScalarQuantizedVectorsFormat.ES814ScalarQuantizedVectorsWriter) {
+         dataType = CuVSMatrix.DataType.BYTE;
+     } else {
+         assert flatVectorWriter instanceof Lucene99FlatVectorsWriter;
+         dataType = CuVSMatrix.DataType.FLOAT;
+     }
+     this.segmentWriteState = state;
+     String metaFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, LUCENE99_HNSW_META_EXTENSION);
+     String indexDataFileName = IndexFileNames.segmentFileName(
+         state.segmentInfo.name,
+         state.segmentSuffix,
+         LUCENE99_HNSW_VECTOR_INDEX_EXTENSION
+     );
+     boolean success = false;
+     try {
+         meta = state.directory.createOutput(metaFileName, state.context);
+         vectorIndex = state.directory.createOutput(indexDataFileName, state.context);
+         CodecUtil.writeIndexHeader(
+             meta,
+             LUCENE99_HNSW_META_CODEC_NAME,
+             LUCENE99_VERSION_CURRENT,
+             state.segmentInfo.getId(),
+             state.segmentSuffix
+         );
+         CodecUtil.writeIndexHeader(
+             vectorIndex,
+             LUCENE99_HNSW_VECTOR_INDEX_CODEC_NAME,
+             LUCENE99_VERSION_CURRENT,
+             state.segmentInfo.getId(),
+             state.segmentSuffix
+         );
+         success = true;
+     } finally {
+         if (success == false) {
+             // Best-effort cleanup of whatever was opened; the original failure keeps propagating
+             IOUtils.closeWhileHandlingException(this);
+         }
+     }
+ }
+
+ /**
+ * Registers a field with this writer. Only FLOAT32-encoded fields are accepted; any other encoding is
+ * rejected with an {@link IllegalArgumentException}. The field is also registered with the flat vectors
+ * writer, and the returned per-field writer forwards values to that delegate.
+ */
+ @Override
+ public KnnFieldVectorsWriter> addField(FieldInfo fieldInfo) throws IOException {
+ if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32) == false) {
+ throw new IllegalArgumentException(
+ "Field [" + fieldInfo.name + "] must have FLOAT32 encoding, got: " + fieldInfo.getVectorEncoding()
+ );
+ }
+ @SuppressWarnings("unchecked")
+ FlatFieldVectorsWriter flatFieldWriter = (FlatFieldVectorsWriter) flatVectorWriter.addField(fieldInfo);
+ FieldWriter newField = new FieldWriter(flatFieldWriter, fieldInfo);
+ // Remembered so that flush() and ramBytesUsed() can iterate over all fields of this segment
+ fields.add(newField);
+ return newField;
+ }
+
+ /**
+  * Flushes vector data and associated data to disk.
+  *
+  * This method and the private helpers it calls only need to support FLOAT32.
+  * For FlatFieldVectorsWriter we only need to support float[] during flush: during indexing users provide float[], and pass floats to
+  * FlatFieldVectorsWriter, even when we have a BYTE dataType (i.e. an "int8_hnsw" type).
+  * During merging, we use quantized data, so we need to support byte[] too (see {@link ES92GpuHnswVectorsWriter#mergeOneField}),
+  * but not here.
+  * That's how our other current formats work: use floats during indexing, and quantized data to build graph during merging.
+  *
+  * @param maxDoc  forwarded to the flat vectors writer
+  * @param sortMap forwarded to the flat vectors writer; the graph build does not apply it yet (see TODO)
+  * @throws IOException if writing fails; an I/O failure from the graph build is propagated unchanged, while an
+  *                     interruption or any other failure is wrapped (the interrupt flag is restored first)
+  */
+ @Override
+ // TODO: fix sorted index case
+ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
+     flatVectorWriter.flush(maxDoc, sortMap);
+     try {
+         flushFieldsWithoutMemoryMappedFile(sortMap);
+     } catch (IOException e) {
+         // Already the declared failure type: propagate as-is rather than wrapping it a second time
+         throw e;
+     } catch (InterruptedException e) {
+         // Keep the interruption visible to callers further up the stack
+         Thread.currentThread().interrupt();
+         throw new IOException("Failed to flush GPU index: ", e);
+     } catch (Throwable t) {
+         throw new IOException("Failed to flush GPU index: ", t);
+     }
+ }
+
+ /**
+  * Writes the graph of every registered field from the float vectors still buffered by its flat field writer.
+  * Fields with fewer than {@code MIN_NUM_VECTORS_FOR_GPU_BUILD} vectors get a mock graph; all others are copied
+  * vector by vector into a device matrix and indexed on the GPU. GPU resources are released even when the build fails.
+  *
+  * @param sortMap passed through to the per-field helpers, which do not apply it yet
+  */
+ private void flushFieldsWithoutMemoryMappedFile(Sorter.DocMap sortMap) throws IOException, InterruptedException {
+     // No tmp file written, or the file cannot be mmapped
+     for (FieldWriter field : fields) {
+         var fieldInfo = field.fieldInfo;
+         var vectors = field.flatFieldVectorsWriter.getVectors();
+         var numVectors = vectors.size();
+
+         if (numVectors < MIN_NUM_VECTORS_FOR_GPU_BUILD) {
+             if (logger.isDebugEnabled()) {
+                 logger.debug(
+                     "Skip building CAGRA index; vectors length {} < {} (min for GPU)",
+                     numVectors,
+                     MIN_NUM_VECTORS_FOR_GPU_BUILD
+                 );
+             }
+             // Will not be indexed on the GPU
+             flushFieldWithMockGraph(fieldInfo, numVectors, sortMap);
+             continue;
+         }
+
+         var dimension = fieldInfo.getVectorDimension();
+         var cuVSResources = cuVSResourceManager.acquire(numVectors, dimension, CuVSMatrix.DataType.FLOAT);
+         try {
+             var builder = CuVSMatrix.deviceBuilder(cuVSResources, numVectors, dimension, CuVSMatrix.DataType.FLOAT);
+             for (var vector : vectors) {
+                 builder.addVector(vector);
+             }
+             try (var dataset = builder.build()) {
+                 flushFieldWithGpuGraph(cuVSResources, fieldInfo, dataset, sortMap);
+             }
+         } finally {
+             cuVSResourceManager.release(cuVSResources);
+         }
+     }
+ }
+
+ /**
+  * Writes a mock graph plus metadata for a field that is too small for a GPU build.
+  * The sort map is accepted but not applied yet: sorted and unsorted segments take the same path.
+  */
+ private void flushFieldWithMockGraph(FieldInfo fieldInfo, int numVectors, Sorter.DocMap sortMap) throws IOException {
+     // TODO: use sortMap
+     generateMockGraphAndWriteMeta(fieldInfo, numVectors);
+ }
+
+ /**
+  * Builds the GPU graph for a field and writes it together with its metadata.
+  * The sort map is accepted but not applied yet: sorted and unsorted segments take the same path.
+  */
+ private void flushFieldWithGpuGraph(
+     CuVSResourceManager.ManagedCuVSResources resources,
+     FieldInfo fieldInfo,
+     CuVSMatrix dataset,
+     Sorter.DocMap sortMap
+ ) throws IOException {
+     // TODO: use sortMap
+     generateGpuGraphAndWriteMeta(resources, fieldInfo, dataset);
+ }
+
+ /**
+ * Finishes the flat vectors writer and terminates both HNSW files: the metadata file gets the
+ * end-of-fields marker followed by a footer, the vector index file gets a footer.
+ *
+ * @throws IllegalStateException if called more than once
+ */
+ @Override
+ public void finish() throws IOException {
+ if (finished) {
+ throw new IllegalStateException("already finished");
+ }
+ finished = true;
+ flatVectorWriter.finish();
+
+ if (meta != null) {
+ // write end of fields marker
+ meta.writeInt(-1);
+ CodecUtil.writeFooter(meta);
+ }
+ if (vectorIndex != null) {
+ CodecUtil.writeFooter(vectorIndex);
+ }
+ }
+
+ /** Returns the shallow size of this writer plus the reported usage of every registered field writer. */
+ @Override
+ public long ramBytesUsed() {
+ long total = SHALLOW_RAM_BYTES_USED;
+ for (FieldWriter field : fields) {
+ // the field tracks the delegate field usage
+ total += field.ramBytesUsed();
+ }
+ return total;
+ }
+
+ /**
+ * Builds a GPU index over {@code dataset}, serializes its graph into the vector index file and then
+ * writes the field's metadata entry. An {@link IOException} is propagated unchanged; any other failure
+ * is wrapped in an {@link IOException}.
+ */
+ private void generateGpuGraphAndWriteMeta(
+ CuVSResourceManager.ManagedCuVSResources cuVSResources,
+ FieldInfo fieldInfo,
+ CuVSMatrix dataset
+ ) throws IOException {
+ try {
+ assert dataset.size() >= MIN_NUM_VECTORS_FOR_GPU_BUILD;
+
+ // Remember where this field's graph starts so that its offset and length can be recorded in the metadata
+ long vectorIndexOffset = vectorIndex.getFilePointer();
+ int[][] graphLevelNodeOffsets = new int[1][];
+ final HnswGraph graph;
+ // The GPU index is only needed while its graph is being serialized; it is closed right after
+ try (var index = buildGPUIndex(cuVSResources, fieldInfo.getVectorSimilarityFunction(), dataset)) {
+ assert index != null : "GPU index should be built for field: " + fieldInfo.name;
+ graph = writeGraph(index.getGraph(), graphLevelNodeOffsets);
+ }
+ long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
+ writeMeta(fieldInfo, vectorIndexOffset, vectorIndexLength, (int) dataset.size(), graph, graphLevelNodeOffsets);
+ } catch (IOException e) {
+ throw e;
+ } catch (Throwable t) {
+ throw new IOException("Failed to write GPU index: ", t);
+ }
+ }
+
+ /**
+ * Writes a fully connected mock graph for {@code datasetSize} vectors into the vector index file and then
+ * writes the field's metadata entry; no GPU is involved. An {@link IOException} is propagated unchanged;
+ * any other failure is wrapped in an {@link IOException}.
+ */
+ private void generateMockGraphAndWriteMeta(FieldInfo fieldInfo, int datasetSize) throws IOException {
+ try {
+ long vectorIndexOffset = vectorIndex.getFilePointer();
+ int[][] graphLevelNodeOffsets = new int[1][];
+ // graph is null when datasetSize is 0; writeMeta handles that case explicitly
+ final HnswGraph graph = writeMockGraph(datasetSize, graphLevelNodeOffsets);
+ long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
+ writeMeta(fieldInfo, vectorIndexOffset, vectorIndexLength, datasetSize, graph, graphLevelNodeOffsets);
+ } catch (IOException e) {
+ throw e;
+ } catch (Throwable t) {
+ throw new IOException("Failed to write GPU index: ", t);
+ }
+ }
+
+ /**
+  * Builds a CAGRA index over {@code dataset} with the NN-descent build algorithm, using {@code M} as graph
+  * degree and {@code beamWidth} as intermediate graph degree. After a successful build the resource manager
+  * is told that the computation on {@code cuVSResources} has finished.
+  *
+  * @param cuVSResources      resources the build runs on
+  * @param similarityFunction Lucene similarity, mapped to the corresponding cuVS distance type
+  * @param dataset            vectors to index
+  * @return the built index; the caller is responsible for closing it
+  */
+ private CagraIndex buildGPUIndex(
+     CuVSResourceManager.ManagedCuVSResources cuVSResources,
+     VectorSimilarityFunction similarityFunction,
+     CuVSMatrix dataset
+ ) throws Throwable {
+     CagraIndexParams.CuvsDistanceType distanceType = switch (similarityFunction) {
+         case EUCLIDEAN -> CagraIndexParams.CuvsDistanceType.L2Expanded;
+         case DOT_PRODUCT, MAXIMUM_INNER_PRODUCT -> CagraIndexParams.CuvsDistanceType.InnerProduct;
+         case COSINE -> CagraIndexParams.CuvsDistanceType.CosineExpanded;
+     };
+
+     // TODO: expose cagra index params for algorithm, NNDescentNumIterations
+     CagraIndexParams params = new CagraIndexParams.Builder().withNumWriterThreads(1) // TODO: how many CPU threads we can use?
+         .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT)
+         .withGraphDegree(M)
+         .withIntermediateGraphDegree(beamWidth)
+         .withMetric(distanceType)
+         .build();
+
+     long startTime = System.nanoTime();
+     var indexBuilder = CagraIndex.newBuilder(cuVSResources).withDataset(dataset).withIndexParams(params);
+     var index = indexBuilder.build();
+     cuVSResourceManager.finishedComputation(cuVSResources);
+     if (logger.isDebugEnabled()) {
+         logger.debug("CAGRA index created in: {} ms; #num vectors: {}", (System.nanoTime() - startTime) / 1_000_000.0, dataset.size());
+     }
+     return index;
+ }
+
+ /**
+  * Serializes the CAGRA adjacency matrix into the vector index file, one node after another.
+  * For each node the neighbor ids are sorted, duplicates are dropped and the remaining ids are delta-encoded;
+  * the neighbor count is written as a vInt followed by the deltas as group-vInts.
+  *
+  * @param cagraGraph       matrix with one row of neighbor ids per node
+  * @param levelNodeOffsets output parameter; slot 0 receives the number of bytes written for each node
+  * @return a mock graph describing the single written level (node count and maximum degree)
+  */
+ private HnswGraph writeGraph(CuVSMatrix cagraGraph, int[][] levelNodeOffsets) throws IOException {
+     long startTime = System.nanoTime();
+
+     int maxElementCount = (int) cagraGraph.size();
+     int maxGraphDegree = (int) cagraGraph.columns();
+     int[] neighbors = new int[maxGraphDegree];
+
+     levelNodeOffsets[0] = new int[maxElementCount];
+     // write the cagra graph to the Lucene vectorIndex file
+     int[] scratch = new int[maxGraphDegree];
+     for (int node = 0; node < maxElementCount; node++) {
+         cagraGraph.getRow(node).toArray(neighbors);
+
+         // write to the Lucene vectorIndex file
+         long offsetStart = vectorIndex.getFilePointer();
+         Arrays.sort(neighbors);
+         int actualSize = 0;
+         if (maxGraphDegree > 0) {
+             // The first entry is stored as an absolute id, every following one as a delta to its predecessor
+             scratch[0] = neighbors[0];
+             actualSize = 1;
+         }
+         for (int i = 1; i < maxGraphDegree; i++) {
+             assert neighbors[i] < maxElementCount : "node too large: " + neighbors[i] + ">=" + maxElementCount;
+             if (neighbors[i - 1] == neighbors[i]) {
+                 // duplicate neighbor, skip it
+                 continue;
+             }
+             scratch[actualSize++] = neighbors[i] - neighbors[i - 1];
+         }
+         // Write the size after duplicates are removed
+         vectorIndex.writeVInt(actualSize);
+         vectorIndex.writeGroupVInts(scratch, actualSize);
+         levelNodeOffsets[0][node] = Math.toIntExact(vectorIndex.getFilePointer() - offsetStart);
+     }
+     if (logger.isDebugEnabled()) {
+         logger.debug("cagra_hnsw index serialized to Lucene HNSW in: {} ms", (System.nanoTime() - startTime) / 1_000_000.0);
+     }
+     return createMockGraph(maxElementCount, maxGraphDegree);
+ }
+
+ // create a mock graph where every node is connected to every other node
+ private HnswGraph writeMockGraph(int elementCount, int[][] levelNodeOffsets) throws IOException {
+ if (elementCount == 0) {
+ return null;
+ }
+ int nodeDegree = elementCount - 1;
+ levelNodeOffsets[0] = new int[elementCount];
+
+ int[] neighbors = new int[nodeDegree];
+ int[] scratch = new int[nodeDegree];
+ for (int node = 0; node < elementCount; node++) {
+ if (nodeDegree > 0) {
+ for (int j = 0; j < nodeDegree; j++) {
+ neighbors[j] = j < node ? j : j + 1; // skip self
+ }
+ scratch[0] = neighbors[0];
+ for (int i = 1; i < nodeDegree; i++) {
+ scratch[i] = neighbors[i] - neighbors[i - 1];
+ }
+ }
+
+ long offsetStart = vectorIndex.getFilePointer();
+ vectorIndex.writeVInt(nodeDegree);
+ vectorIndex.writeGroupVInts(scratch, nodeDegree);
+ levelNodeOffsets[0][node] = Math.toIntExact(vectorIndex.getFilePointer() - offsetStart);
+ }
+ return createMockGraph(elementCount, nodeDegree);
+ }
+
+ /**
+ * Creates a single-level {@link HnswGraph} stub that only reports shape information (size, number of
+ * levels, maximum degree and the nodes on a level). All traversal methods throw
+ * {@link UnsupportedOperationException}.
+ *
+ * @param elementCount number of nodes reported by {@code size()}
+ * @param graphDegree value reported by {@code maxConn()}
+ */
+ private static HnswGraph createMockGraph(int elementCount, int graphDegree) {
+ return new HnswGraph() {
+ @Override
+ public int nextNeighbor() {
+ throw new UnsupportedOperationException("Not supported on a mock graph");
+ }
+
+ @Override
+ public void seek(int level, int target) {
+ throw new UnsupportedOperationException("Not supported on a mock graph");
+ }
+
+ @Override
+ public int size() {
+ return elementCount;
+ }
+
+ @Override
+ public int numLevels() {
+ return 1;
+ }
+
+ @Override
+ public int maxConn() {
+ return graphDegree;
+ }
+
+ @Override
+ public int entryNode() {
+ throw new UnsupportedOperationException("Not supported on a mock graph");
+ }
+
+ @Override
+ public int neighborCount() {
+ throw new UnsupportedOperationException("Not supported on a mock graph");
+ }
+
+ @Override
+ public NodesIterator getNodesOnLevel(int level) {
+ // The level argument is ignored: the iterator always covers all size() nodes
+ return new ArrayNodesIterator(size());
+ }
+ };
+ }
+
+ /** Deletes {@code fileName} from {@code dir} through Lucene's {@code IOUtils#deleteFilesIgnoringExceptions}. */
+ @SuppressForbidden(reason = "require usage of Lucene's IOUtils#deleteFilesIgnoringExceptions(...)")
+ private static void deleteFilesIgnoringExceptions(Directory dir, String fileName) {
+ org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(dir, fileName);
+ }
+
+ /**
+  * Merges one field. The flat vectors writer merges the vectors first; the merged vectors (quantized bytes
+  * for the BYTE data type, floats otherwise) are then dumped into a temporary file from which the graph is built:
+  * <ul>
+  *   <li>fewer than {@code MIN_NUM_VECTORS_FOR_GPU_BUILD} vectors: a mock graph is written;</li>
+  *   <li>the temporary file is accessible as memory segments: the GPU build reads directly from it;</li>
+  *   <li>otherwise: vectors are read back one by one into a device matrix before the GPU build.</li>
+  * </ul>
+  * The temporary file is always deleted, and GPU resources are always released.
+  *
+  * @throws IOException if merging fails; I/O failures are propagated unchanged, any other failure is wrapped
+  *                     (the interrupt flag is restored first when the cause is an interruption)
+  */
+ // TODO check with deleted documents
+ // TODO: fix sorted index case
+ @Override
+ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
+     flatVectorWriter.mergeOneField(fieldInfo, mergeState);
+     final int numVectors;
+     String tempRawVectorsFileName = null;
+     boolean success = false;
+     // save merged vector values to a temp file
+     try (IndexOutput out = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "vec_", IOContext.DEFAULT)) {
+         tempRawVectorsFileName = out.getName();
+         if (dataType == CuVSMatrix.DataType.BYTE) {
+             numVectors = writeByteVectorValues(out, getMergedByteVectorValues(fieldInfo, mergeState));
+         } else {
+             numVectors = writeFloatVectorValues(fieldInfo, out, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState));
+         }
+         CodecUtil.writeFooter(out);
+         success = true;
+     } finally {
+         if (success == false && tempRawVectorsFileName != null) {
+             deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName);
+         }
+     }
+     final int dimension = fieldInfo.getVectorDimension();
+     try (IndexInput in = mergeState.segmentInfo.dir.openInput(tempRawVectorsFileName, IOContext.DEFAULT)) {
+         var input = FilterIndexInput.unwrapOnlyTest(in);
+
+         if (numVectors < MIN_NUM_VECTORS_FOR_GPU_BUILD) {
+             // we don't really need real value for vectors here,
+             // we just build a mock graph where every node is connected to every other node
+             generateMockGraphAndWriteMeta(fieldInfo, numVectors);
+         } else if (input instanceof MemorySegmentAccessInput memorySegmentAccessInput) {
+             // Direct access to mmapped file.
+             // Resources are acquired before the dataset is created so that a failing acquire cannot leak the dataset;
+             // try-with-resources closes the dataset first and the finally block then always releases the resources.
+             var cuVSResources = cuVSResourceManager.acquire(numVectors, dimension, dataType);
+             try (var dataset = DatasetUtils.getInstance().fromInput(memorySegmentAccessInput, numVectors, dimension, dataType)) {
+                 generateGpuGraphAndWriteMeta(cuVSResources, fieldInfo, dataset);
+             } finally {
+                 cuVSResourceManager.release(cuVSResources);
+             }
+         } else {
+             logger.debug(
+                 () -> "Cannot mmap merged raw vectors temporary file. IndexInput type [" + input.getClass().getSimpleName() + "]"
+             );
+
+             var cuVSResources = cuVSResourceManager.acquire(numVectors, dimension, dataType);
+             try {
+                 // Read vector-by-vector
+                 var builder = CuVSMatrix.deviceBuilder(cuVSResources, numVectors, dimension, dataType);
+
+                 // During merging, we use quantized data, so we need to support byte[] too.
+                 // That's how our current formats work: use floats during indexing, and quantized data to build a graph
+                 // during merging.
+                 if (dataType == CuVSMatrix.DataType.FLOAT) {
+                     float[] vector = new float[dimension];
+                     for (int i = 0; i < numVectors; ++i) {
+                         input.readFloats(vector, 0, dimension);
+                         builder.addVector(vector);
+                     }
+                 } else {
+                     assert dataType == CuVSMatrix.DataType.BYTE;
+                     byte[] vector = new byte[dimension];
+                     for (int i = 0; i < numVectors; ++i) {
+                         input.readBytes(vector, 0, dimension);
+                         builder.addVector(vector);
+                     }
+                 }
+                 try (var dataset = builder.build()) {
+                     generateGpuGraphAndWriteMeta(cuVSResources, fieldInfo, dataset);
+                 }
+             } finally {
+                 cuVSResourceManager.release(cuVSResources);
+             }
+         }
+     } catch (IOException e) {
+         // Already the declared failure type: propagate as-is rather than wrapping it a second time
+         throw e;
+     } catch (Throwable t) {
+         if (t instanceof InterruptedException) {
+             // Keep the interruption visible to callers further up the stack
+             Thread.currentThread().interrupt();
+         }
+         throw new IOException("Failed to merge GPU index: ", t);
+     } finally {
+         deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName);
+     }
+ }
+
+ /**
+ * Returns the merged vectors of {@code fieldInfo} quantized to bytes. The quantizer is recomputed for the
+ * merged segment with a fixed 7 bits and no explicit confidence interval, then applied while merging.
+ */
+ private ByteVectorValues getMergedByteVectorValues(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
+ // TODO: expose confidence interval from the format
+ final byte bits = 7;
+ final Float confidenceInterval = null;
+ ScalarQuantizer quantizer = mergeAndRecalculateQuantiles(mergeState, fieldInfo, confidenceInterval, bits);
+ return MergedQuantizedVectorValues.mergeQuantizedByteVectorValues(fieldInfo, mergeState, quantizer);
+ }
+
+ /**
+  * Appends every vector of {@code vectorValues} to {@code out} as raw bytes, in iteration order.
+  *
+  * @return the number of vectors written
+  */
+ private static int writeByteVectorValues(IndexOutput out, ByteVectorValues vectorValues) throws IOException {
+     final KnnVectorValues.DocIndexIterator docs = vectorValues.iterator();
+     int written = 0;
+     while (docs.nextDoc() != NO_MORE_DOCS) {
+         written++;
+         final byte[] bytes = vectorValues.vectorValue(docs.index());
+         out.writeBytes(bytes, bytes.length);
+     }
+     return written;
+ }
+
+ /**
+  * Appends every vector of {@code floatVectorValues} to {@code out} as little-endian floats, in iteration order.
+  * A single buffer sized for one vector of the field's dimension is reused for all vectors.
+  *
+  * @return the number of vectors written
+  */
+ private static int writeFloatVectorValues(FieldInfo fieldInfo, IndexOutput out, FloatVectorValues floatVectorValues)
+     throws IOException {
+     final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
+     final byte[] encoded = buffer.array();
+     final KnnVectorValues.DocIndexIterator docs = floatVectorValues.iterator();
+     int written = 0;
+     while (docs.nextDoc() != NO_MORE_DOCS) {
+         written++;
+         final float[] values = floatVectorValues.vectorValue(docs.index());
+         // A fresh float view starts at position 0 each time, so every vector overwrites the previous one
+         buffer.asFloatBuffer().put(values);
+         out.writeBytes(encoded, encoded.length);
+     }
+     return written;
+ }
+
+ /**
+ * Writes the metadata entry of one field to the meta file: field number, encoding and similarity ordinals,
+ * offset and length of the field's graph in the vector index file, dimension and vector count, followed by
+ * the graph description (maximum connections, number of levels, the nodes of every level above 0, and the
+ * per-node offsets encoded with a {@link DirectMonotonicWriter}).
+ * NOTE(review): presumably this mirrors the layout read by Lucene99HnswVectorsReader - confirm against the reader.
+ *
+ * @param field field being described
+ * @param vectorIndexOffset start of the field's graph in the vector index file
+ * @param vectorIndexLength number of bytes the field's graph occupies
+ * @param count number of vectors of the field
+ * @param graph written graph, or {@code null} if no graph was written
+ * @param graphLevelNodeOffsets per level, the number of bytes written for each node
+ */
+ private void writeMeta(
+ FieldInfo field,
+ long vectorIndexOffset,
+ long vectorIndexLength,
+ int count,
+ HnswGraph graph,
+ int[][] graphLevelNodeOffsets
+ ) throws IOException {
+ meta.writeInt(field.number);
+ meta.writeInt(field.getVectorEncoding().ordinal());
+ meta.writeInt(distFuncToOrd(field.getVectorSimilarityFunction()));
+ meta.writeVLong(vectorIndexOffset);
+ meta.writeVLong(vectorIndexLength);
+ meta.writeVInt(field.getVectorDimension());
+ meta.writeInt(count);
+ // write graph nodes on each level
+ if (graph == null) {
+ // No graph: record the configured degree and zero levels
+ meta.writeVInt(M);
+ meta.writeVInt(0);
+ } else {
+ meta.writeVInt(graph.maxConn());
+ meta.writeVInt(graph.numLevels());
+ long valueCount = 0;
+
+ for (int level = 0; level < graph.numLevels(); level++) {
+ NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
+ valueCount += nodesOnLevel.size();
+ if (level > 0) {
+ int[] nol = new int[nodesOnLevel.size()];
+ int numberConsumed = nodesOnLevel.consume(nol);
+ Arrays.sort(nol);
+ assert numberConsumed == nodesOnLevel.size();
+ meta.writeVInt(nol.length); // number of nodes on a level
+ // Delta-encode in place, back to front so that each predecessor is still the absolute id
+ for (int i = nodesOnLevel.size() - 1; i > 0; --i) {
+ nol[i] -= nol[i - 1];
+ }
+ for (int n : nol) {
+ assert n >= 0 : "delta encoding for nodes failed; expected nodes to be sorted";
+ meta.writeVInt(n);
+ }
+ } else {
+ assert nodesOnLevel.size() == count : "Level 0 expects to have all nodes";
+ }
+ }
+ long start = vectorIndex.getFilePointer();
+ meta.writeLong(start);
+ meta.writeVInt(LUCENE99_HNSW_DIRECT_MONOTONIC_BLOCK_SHIFT);
+ final DirectMonotonicWriter memoryOffsetsWriter = DirectMonotonicWriter.getInstance(
+ meta,
+ vectorIndex,
+ valueCount,
+ LUCENE99_HNSW_DIRECT_MONOTONIC_BLOCK_SHIFT
+ );
+ // Turn the per-node byte counts into running start offsets
+ long cumulativeOffsetSum = 0;
+ for (int[] levelOffsets : graphLevelNodeOffsets) {
+ for (int v : levelOffsets) {
+ memoryOffsetsWriter.add(cumulativeOffsetSum);
+ cumulativeOffsetSum += v;
+ }
+ }
+ memoryOffsetsWriter.finish();
+ meta.writeLong(vectorIndex.getFilePointer() - start);
+ }
+ }
+
+ /** Closes the metadata output, the vector index output and the flat vectors writer. */
+ @Override
+ public void close() throws IOException {
+ IOUtils.close(meta, vectorIndex, flatVectorWriter);
+ }
+
+ /**
+  * Returns the position of {@code func} in {@code SIMILARITY_FUNCTIONS}.
+  *
+  * @throws IllegalArgumentException if the function is not part of that list
+  */
+ static int distFuncToOrd(VectorSimilarityFunction func) {
+     int position = 0;
+     for (VectorSimilarityFunction candidate : SIMILARITY_FUNCTIONS) {
+         if (candidate.equals(func)) {
+             return (byte) position;
+         }
+         position++;
+     }
+     throw new IllegalArgumentException("invalid distance function: " + func);
+ }
+
+ /**
+ * Per-field writer that forwards every value to the flat field writer and rejects a second value
+ * for the same document.
+ */
+ private static class FieldWriter extends KnnFieldVectorsWriter {
+ private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(FieldWriter.class);
+
+ private final FieldInfo fieldInfo;
+ // Document id of the most recently added value; -1 until the first value arrives
+ private int lastDocID = -1;
+ private final FlatFieldVectorsWriter flatFieldVectorsWriter;
+
+ /** Creates a writer for {@code fieldInfo}; {@code flatFieldVectorsWriter} must not be null. */
+ FieldWriter(FlatFieldVectorsWriter flatFieldVectorsWriter, FieldInfo fieldInfo) {
+ this.fieldInfo = fieldInfo;
+ this.flatFieldVectorsWriter = Objects.requireNonNull(flatFieldVectorsWriter);
+ }
+
+ /**
+ * Forwards the value to the flat field writer.
+ *
+ * @throws IllegalArgumentException if {@code docID} equals the document id of the previous call
+ */
+ @Override
+ public void addValue(int docID, float[] vectorValue) throws IOException {
+ if (docID == lastDocID) {
+ throw new IllegalArgumentException(
+ "VectorValuesField \""
+ + fieldInfo.name
+ + "\" appears more than once in this document (only one value is allowed per field)"
+ );
+ }
+ flatFieldVectorsWriter.addValue(docID, vectorValue);
+ lastDocID = docID;
+ }
+
+ /** Returns the delegate's set of documents that have a value for this field. */
+ public DocsWithFieldSet getDocsWithFieldSet() {
+ return flatFieldVectorsWriter.getDocsWithFieldSet();
+ }
+
+ /** Not supported by this writer. */
+ @Override
+ public float[] copyValue(float[] vectorValue) {
+ throw new UnsupportedOperationException();
+ }
+
+ /** Returns the shallow size of this object plus the delegate's reported usage. */
+ @Override
+ public long ramBytesUsed() {
+ return SHALLOW_SIZE + flatFieldVectorsWriter.ramBytesUsed();
+ }
+ }
+}
diff --git a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/MergedQuantizedVectorValues.java b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/MergedQuantizedVectorValues.java
new file mode 100644
index 0000000000000..4d3d5013dd381
--- /dev/null
+++ b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/MergedQuantizedVectorValues.java
@@ -0,0 +1,372 @@
+/*
+ * @notice
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * Modifications copyright (C) 2025 Elasticsearch B.V.
+ */
+
+package org.elasticsearch.xpack.gpu.codec;
+
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
+import org.apache.lucene.index.DocIDMerger;
+import org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.KnnVectorValues;
+import org.apache.lucene.index.MergeState;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.VectorScorer;
+import org.apache.lucene.util.VectorUtil;
+import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
+import org.apache.lucene.util.quantization.QuantizedVectorsReader;
+import org.apache.lucene.util.quantization.ScalarQuantizer;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.apache.lucene.codecs.KnnVectorsWriter.MergedVectorValues.hasVectorValues;
+
+/**
+ * A copy from Lucene99ScalarQuantizedVectorsWriter to access mergeQuantizedByteVectorValues
+ * during segment merge.
+ */
+class MergedQuantizedVectorValues extends QuantizedByteVectorValues {
+ // Factor used by shouldRequantize to derive the tolerated drift between existing and merged quantiles
+ private static final float REQUANTIZATION_LIMIT = 0.2f;
+
+ // One entry per merged segment that has vectors for the field
+ private final List subs;
+ // Produces the subs' documents in merged order
+ private final DocIDMerger docIdMerger;
+ // Sum of the sizes of all subs
+ private final int size;
+ // Sub that supplied the document returned by the most recent nextDoc() call; null once exhausted
+ private QuantizedByteVectorValueSub current;
+
+ /**
+ * Creates the merged view over {@code subs}. The doc-id merger honors {@code mergeState.needsIndexSort},
+ * and the total size is the sum of the sizes of all subs.
+ */
+ private MergedQuantizedVectorValues(List subs, MergeState mergeState) throws IOException {
+ this.subs = subs;
+ docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort);
+ int totalSize = 0;
+ for (QuantizedByteVectorValueSub sub : subs) {
+ totalSize += sub.values.size();
+ }
+ size = totalSize;
+ }
+
+ /**
+ * Returns the quantized vector of the document the iterator is currently positioned on.
+ * The {@code ord} argument is not used: the value is read from the current sub at that sub's own index,
+ * so this is only meaningful right after a successful {@code nextDoc()}.
+ */
+ @Override
+ public byte[] vectorValue(int ord) throws IOException {
+ return current.values.vectorValue(current.index());
+ }
+
+ /** Returns a new iterator over the merged documents; it advances the doc-id merger shared by this instance. */
+ @Override
+ public DocIndexIterator iterator() {
+     final DocIndexIterator merged = new CompositeIterator();
+     return merged;
+ }
+
+ /** Returns the total number of vectors across all subs, as computed at construction time. */
+ @Override
+ public int size() {
+ return size;
+ }
+
+ /**
+ * Returns the dimension reported by the first sub.
+ * NOTE(review): this reads {@code subs.get(0)}, so it presumably relies on at least one sub being present - confirm callers.
+ */
+ @Override
+ public int dimension() {
+ return subs.get(0).values.dimension();
+ }
+
+ /**
+ * Returns the score correction constant of the document the iterator is currently positioned on.
+ * The {@code ord} argument is not used: the value is read from the current sub at that sub's own index.
+ */
+ @Override
+ public float getScoreCorrectionConstant(int ord) throws IOException {
+ return current.values.getScoreCorrectionConstant(current.index());
+ }
+
+ /**
+  * Forward-only iterator over the merged documents. Each {@code nextDoc()} call pulls the next sub from the
+  * enclosing instance's doc-id merger, records it as the current sub and hands out consecutive ordinals.
+  */
+ private class CompositeIterator extends DocIndexIterator {
+     // Both start before the first document
+     private int docId = -1;
+     private int ord = -1;
+
+     @Override
+     public int index() {
+         return ord;
+     }
+
+     @Override
+     public int docID() {
+         return docId;
+     }
+
+     @Override
+     public int nextDoc() throws IOException {
+         current = docIdMerger.next();
+         if (current != null) {
+             docId = current.mappedDocID;
+             ++ord;
+             return docId;
+         }
+         // Exhausted: both the document id and the ordinal become the sentinel
+         docId = NO_MORE_DOCS;
+         ord = NO_MORE_DOCS;
+         return docId;
+     }
+
+     /** Not supported; this iterator can only be stepped with {@code nextDoc()}. */
+     @Override
+     public int advance(int target) throws IOException {
+         throw new UnsupportedOperationException();
+     }
+
+     @Override
+     public long cost() {
+         return size;
+     }
+ }
+
+ private static QuantizedVectorsReader getQuantizedKnnVectorsReader(KnnVectorsReader vectorsReader, String fieldName) {
+ if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
+ vectorsReader = candidateReader.getFieldReader(fieldName);
+ }
+ if (vectorsReader instanceof QuantizedVectorsReader reader) {
+ return reader;
+ }
+ return null;
+ }
+
+ /**
+ * Builds the merged, quantized view of {@code fieldInfo} over all segments of {@code mergeState} that have
+ * vectors for the field. For each such segment the vectors are either quantized from the float values with
+ * {@code scalarQuantizer}, or the already quantized bytes are reused with corrected offsets.
+ *
+ * @param fieldInfo field to merge; must have vector values
+ * @param mergeState state of the running merge
+ * @param scalarQuantizer quantizer computed for the merged segment
+ */
+ static MergedQuantizedVectorValues mergeQuantizedByteVectorValues(
+ FieldInfo fieldInfo,
+ MergeState mergeState,
+ ScalarQuantizer scalarQuantizer
+ ) throws IOException {
+ assert fieldInfo != null && fieldInfo.hasVectorValues();
+
+ List subs = new ArrayList<>();
+ for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
+ if (hasVectorValues(mergeState.fieldInfos[i], fieldInfo.name)) {
+ QuantizedVectorsReader reader = getQuantizedKnnVectorsReader(mergeState.knnVectorsReaders[i], fieldInfo.name);
+ assert scalarQuantizer != null;
+ final QuantizedByteVectorValueSub sub;
+ // Either our quantization parameters are way different than the merged ones
+ // Or we have never been quantized.
+ if (reader == null || reader.getQuantizationState(fieldInfo.name) == null
+ // For smaller `bits` values, we should always recalculate the quantiles
+ // TODO: this is very conservative, could we reuse information for even int4
+ // quantization?
+ || scalarQuantizer.getBits() <= 4
+ || shouldRequantize(reader.getQuantizationState(fieldInfo.name), scalarQuantizer)) {
+ FloatVectorValues toQuantize = mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name);
+ // For cosine similarity the float values are wrapped in a normalizing view before quantization
+ if (fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE) {
+ toQuantize = new NormalizedFloatVectorValues(toQuantize);
+ }
+ sub = new QuantizedByteVectorValueSub(
+ mergeState.docMaps[i],
+ new QuantizedFloatVectorValues(toQuantize, fieldInfo.getVectorSimilarityFunction(), scalarQuantizer)
+ );
+ } else {
+ // Existing quantization is close enough: reuse the stored bytes
+ sub = new QuantizedByteVectorValueSub(
+ mergeState.docMaps[i],
+ new OffsetCorrectedQuantizedByteVectorValues(
+ reader.getQuantizedVectorValues(fieldInfo.name),
+ fieldInfo.getVectorSimilarityFunction(),
+ scalarQuantizer,
+ reader.getQuantizationState(fieldInfo.name)
+ )
+ );
+ }
+ subs.add(sub);
+ }
+ }
+ return new MergedQuantizedVectorValues(subs, mergeState);
+ }
+
+ private static boolean shouldRequantize(ScalarQuantizer existingQuantiles, ScalarQuantizer newQuantiles) {
+ float tol = REQUANTIZATION_LIMIT * (newQuantiles.getUpperQuantile() - newQuantiles.getLowerQuantile()) / 128f;
+ if (Math.abs(existingQuantiles.getUpperQuantile() - newQuantiles.getUpperQuantile()) > tol) {
+ return true;
+ }
+ return Math.abs(existingQuantiles.getLowerQuantile() - newQuantiles.getLowerQuantile()) > tol;
+ }
+
+ /** Doc-id merger sub for one segment: pairs the segment's quantized vectors with a single iterator over them. */
+ private static class QuantizedByteVectorValueSub extends DocIDMerger.Sub {
+ private final QuantizedByteVectorValues values;
+ // Created once in the constructor and advanced by nextDoc()
+ private final KnnVectorValues.DocIndexIterator iterator;
+
+ /** Creates the sub; the iterator obtained from {@code values} is expected to be unpositioned. */
+ QuantizedByteVectorValueSub(MergeState.DocMap docMap, QuantizedByteVectorValues values) {
+ super(docMap);
+ this.values = values;
+ iterator = values.iterator();
+ assert iterator.docID() == -1;
+ }
+
+ /** Advances this segment's iterator and returns its document id. */
+ @Override
+ public int nextDoc() throws IOException {
+ return iterator.nextDoc();
+ }
+
+ /** Returns the index the segment's iterator is currently positioned on. */
+ public int index() {
+ return iterator.index();
+ }
+ }
+
+    /** Quantizes the wrapped float vectors on demand, keeping only the most recently requested ordinal's bytes and offset. */
+    private static class QuantizedFloatVectorValues extends QuantizedByteVectorValues {
+        private final FloatVectorValues values;
+        private final ScalarQuantizer quantizer;
+        private final VectorSimilarityFunction vectorSimilarityFunction;
+        private final byte[] quantizedVector; // single buffer, overwritten each time a new ordinal is quantized
+        private int lastOrd = -1; // ordinal whose quantized form currently sits in quantizedVector
+        private float offsetValue = 0f;
+
+        QuantizedFloatVectorValues(FloatVectorValues values, VectorSimilarityFunction vectorSimilarityFunction, ScalarQuantizer quantizer) {
+            this.values = values;
+            this.quantizer = quantizer;
+            this.vectorSimilarityFunction = vectorSimilarityFunction;
+            this.quantizedVector = new byte[values.dimension()];
+        }
+
+        @Override
+        public byte[] vectorValue(int ord) throws IOException {
+            if (ord != lastOrd) {
+                offsetValue = quantize(ord);
+                lastOrd = ord;
+            }
+            return quantizedVector;
+        }
+
+        private float quantize(int ord) throws IOException {
+            return quantizer.quantize(values.vectorValue(ord), quantizedVector, vectorSimilarityFunction);
+        }
+
+        @Override
+        public float getScoreCorrectionConstant(int ord) {
+            if (ord == lastOrd) {
+                return offsetValue;
+            }
+            throw new IllegalStateException(
+                "attempt to retrieve score correction for different ord " + ord + " than the quantization was done for: " + lastOrd
+            );
+        }
+
+        @Override
+        public int dimension() {
+            return values.dimension();
+        }
+
+        @Override
+        public int size() {
+            return values.size();
+        }
+
+        @Override
+        public int ordToDoc(int ord) {
+            return values.ordToDoc(ord);
+        }
+
+        @Override
+        public DocIndexIterator iterator() {
+            return values.iterator();
+        }
+
+        @Override
+        public VectorScorer scorer(float[] target) throws IOException {
+            throw new UnsupportedOperationException();
+        }
+    }
+
+    private static final class NormalizedFloatVectorValues extends FloatVectorValues {
+        private final FloatVectorValues values;
+        private final float[] normalizedVector;
+
+        NormalizedFloatVectorValues(FloatVectorValues values) {
+            this.values = values;
+            this.normalizedVector = new float[values.dimension()];
+        }
+        // Copies the requested vector into this instance's own buffer and l2-normalizes that copy; the buffer is reused across calls.
+        @Override
+        public float[] vectorValue(int ord) throws IOException {
+            System.arraycopy(values.vectorValue(ord), 0, normalizedVector, 0, normalizedVector.length);
+            VectorUtil.l2normalize(normalizedVector);
+            return normalizedVector;
+        }
+
+        @Override
+        public NormalizedFloatVectorValues copy() throws IOException {
+            return new NormalizedFloatVectorValues(values.copy());
+        }
+
+        @Override
+        public int dimension() {
+            return values.dimension();
+        }
+
+        @Override
+        public int size() {
+            return values.size();
+        }
+
+        @Override
+        public int ordToDoc(int ord) {
+            return values.ordToDoc(ord);
+        }
+
+        @Override
+        public DocIndexIterator iterator() {
+            return values.iterator();
+        }
+    }
+
+    private static final class OffsetCorrectedQuantizedByteVectorValues extends QuantizedByteVectorValues {
+        private final QuantizedByteVectorValues in;
+        private final VectorSimilarityFunction vectorSimilarityFunction;
+        private final ScalarQuantizer scalarQuantizer, oldScalarQuantizer;
+
+        OffsetCorrectedQuantizedByteVectorValues(
+            QuantizedByteVectorValues in,
+            VectorSimilarityFunction vectorSimilarityFunction,
+            ScalarQuantizer scalarQuantizer,
+            ScalarQuantizer oldScalarQuantizer
+        ) {
+            this.in = in;
+            this.vectorSimilarityFunction = vectorSimilarityFunction;
+            this.scalarQuantizer = scalarQuantizer;
+            this.oldScalarQuantizer = oldScalarQuantizer;
+        }
+        // Stored bytes are served unchanged; only the corrective offset is recalculated by scalarQuantizer, given oldScalarQuantizer.
+        @Override
+        public byte[] vectorValue(int ord) throws IOException {
+            return in.vectorValue(ord);
+        }
+
+        @Override
+        public float getScoreCorrectionConstant(int ord) throws IOException {
+            return scalarQuantizer.recalculateCorrectiveOffset(in.vectorValue(ord), oldScalarQuantizer, vectorSimilarityFunction);
+        }
+
+        @Override
+        public int dimension() {
+            return in.dimension();
+        }
+
+        @Override
+        public int size() {
+            return in.size();
+        }
+
+        @Override
+        public int ordToDoc(int ord) {
+            return in.ordToDoc(ord);
+        }
+
+        @Override
+        public DocIndexIterator iterator() {
+            return in.iterator();
+        }
+    }
+}
diff --git a/x-pack/plugin/gpu/src/main/plugin-metadata/entitlement-policy.yaml b/x-pack/plugin/gpu/src/main/plugin-metadata/entitlement-policy.yaml
new file mode 100644
index 0000000000000..d0c571b8538b2
--- /dev/null
+++ b/x-pack/plugin/gpu/src/main/plugin-metadata/entitlement-policy.yaml
@@ -0,0 +1,2 @@
+com.nvidia.cuvs:
+ - load_native_libraries
diff --git a/x-pack/plugin/gpu/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/x-pack/plugin/gpu/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat
new file mode 100644
index 0000000000000..7aa308150b6de
--- /dev/null
+++ b/x-pack/plugin/gpu/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat
@@ -0,0 +1,3 @@
+
+org.elasticsearch.xpack.gpu.codec.ES92GpuHnswVectorsFormat
+org.elasticsearch.xpack.gpu.codec.ES92GpuHnswSQVectorsFormat
diff --git a/x-pack/plugin/gpu/src/main/resources/META-INF/services/org.elasticsearch.features.FeatureSpecification b/x-pack/plugin/gpu/src/main/resources/META-INF/services/org.elasticsearch.features.FeatureSpecification
new file mode 100644
index 0000000000000..63e111db1dd79
--- /dev/null
+++ b/x-pack/plugin/gpu/src/main/resources/META-INF/services/org.elasticsearch.features.FeatureSpecification
@@ -0,0 +1,8 @@
+#
+# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+# or more contributor license agreements. Licensed under the Elastic License
+# 2.0; you may not use this file except in compliance with the Elastic License
+# 2.0.
+#
+
+org.elasticsearch.xpack.gpu.GPUFeatures
diff --git a/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManagerTests.java b/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManagerTests.java
new file mode 100644
index 0000000000000..b466f37cbe9c9
--- /dev/null
+++ b/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManagerTests.java
@@ -0,0 +1,235 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.gpu.codec;
+
+import com.nvidia.cuvs.CuVSMatrix;
+import com.nvidia.cuvs.CuVSResources;
+import com.nvidia.cuvs.CuVSResourcesInfo;
+import com.nvidia.cuvs.GPUInfo;
+import com.nvidia.cuvs.GPUInfoProvider;
+
+import org.elasticsearch.logging.LogManager;
+import org.elasticsearch.logging.Logger;
+import org.elasticsearch.test.ESTestCase;
+
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.LongSupplier;
+
+import static org.hamcrest.Matchers.anyOf;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.not;
+
+public class CuVSResourceManagerTests extends ESTestCase {
+
+ private static final Logger log = LogManager.getLogger(CuVSResourceManagerTests.class);
+
+ public static final long TOTAL_DEVICE_MEMORY_IN_BYTES = 256L * 1024 * 1024;
+
+    public void testBasic() throws InterruptedException {
+        var mgr = new MockPoolingCuVSResourceManager(2);
+        // Two identical rounds: after both resources are released, the manager must hand out
+        // the same two resources (ids 0 and 1) again, in the same order.
+        // acquire(0, 0, ...) records a zero-byte allocation in the mock, so memory plays no role here.
+        for (int round = 0; round < 2; round++) {
+            var first = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT);
+            var second = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT);
+            assertThat(first.toString(), containsString("id=0"));
+            assertThat(second.toString(), containsString("id=1"));
+            mgr.release(first);
+            mgr.release(second);
+        }
+
+        mgr.shutdown();
+    }
+
+    public void testBlocking() throws Exception {
+        var mgr = new MockPoolingCuVSResourceManager(2);
+        var res1 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT);
+        var res2 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT);
+        // Both resources of the capacity-2 manager are now taken; a third acquire is expected to wait until one is released.
+        AtomicReference<CuVSResources> holder = new AtomicReference<>();
+        Thread t = new Thread(() -> {
+            try {
+                var res3 = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT);
+                holder.set(res3);
+            } catch (InterruptedException e) {
+                throw new AssertionError(e);
+            }
+        });
+        t.start();
+        Thread.sleep(1_000); // give the worker time to reach acquire
+        assertNull(holder.get()); // nothing published yet, so the worker has not acquired
+        mgr.release(randomFrom(res1, res2));
+        t.join();
+        assertThat(holder.get().toString(), anyOf(containsString("id=0"), containsString("id=1")));
+        mgr.shutdown();
+    }
+
+    public void testBlockingOnInsufficientMemory() throws Exception {
+        var mgr = new MockPoolingCuVSResourceManager(2);
+        var res1 = mgr.acquire(16 * 1024, 1024, CuVSMatrix.DataType.FLOAT);
+        // Only one of two resources is taken, so the slightly larger request below is expected to wait on memory, not capacity.
+        AtomicReference<CuVSResources> holder = new AtomicReference<>();
+        Thread t = new Thread(() -> {
+            try {
+                var res2 = mgr.acquire((16 * 1024) + 1, 1024, CuVSMatrix.DataType.FLOAT);
+                holder.set(res2);
+            } catch (InterruptedException e) {
+                throw new AssertionError(e);
+            }
+        });
+        t.start();
+        Thread.sleep(1_000); // give the worker time to reach acquire
+        assertNull(holder.get()); // nothing published yet, so the worker has not acquired
+        mgr.release(res1);
+        t.join();
+        assertThat(holder.get().toString(), anyOf(containsString("id=0"), containsString("id=1")));
+        mgr.shutdown();
+    }
+
+    public void testNotBlockingOnSufficientMemory() throws Exception {
+        var mgr = new MockPoolingCuVSResourceManager(2);
+        var res1 = mgr.acquire(16 * 1024, 1024, CuVSMatrix.DataType.FLOAT);
+        // The second request is one vector smaller than the first and is expected to be granted while res1 is still held.
+        AtomicReference<CuVSResources> holder = new AtomicReference<>();
+        Thread t = new Thread(() -> {
+            try {
+                var res2 = mgr.acquire((16 * 1024) - 1, 1024, CuVSMatrix.DataType.FLOAT);
+                holder.set(res2);
+            } catch (InterruptedException e) {
+                throw new AssertionError(e);
+            }
+        });
+        t.start();
+        t.join(5_000); // bounded wait: a worker that is still blocked shows up as a null holder below
+        assertNotNull(holder.get());
+        assertThat(holder.get().toString(), not(equalTo(res1.toString()))); // must be a different resource than res1
+        mgr.shutdown();
+    }
+
+    public void testManagedResIsNotClosable() throws Exception {
+        var mgr = new MockPoolingCuVSResourceManager(1);
+        var managed = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT);
+        assertThrows(UnsupportedOperationException.class, managed::close); // close() on a managed resource must be rejected
+        mgr.release(managed);
+        mgr.shutdown();
+    }
+
+    public void testDoubleRelease() throws InterruptedException {
+        var mgr = new MockPoolingCuVSResourceManager(2);
+        var first = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT);
+        var second = mgr.acquire(0, 0, CuVSMatrix.DataType.FLOAT);
+        mgr.release(first);
+        mgr.release(second);
+        assertThrows(AssertionError.class, () -> mgr.release(randomFrom(first, second))); // a second release must trip an assertion
+        mgr.shutdown();
+    }
+
+    static class MockPoolingCuVSResourceManager extends CuVSResourceManager.PoolingCuVSResourceManager {
+        // Simulated reservations in bytes, one entry per acquire; reported free memory is the device total minus their sum.
+        private final List<Long> allocations;
+        private final AtomicInteger idGenerator = new AtomicInteger();
+
+        MockPoolingCuVSResourceManager(int capacity) {
+            this(capacity, new ArrayList<>());
+        }
+
+        private MockPoolingCuVSResourceManager(int capacity, List<Long> allocationList) {
+            super(capacity, new MockGPUInfoProvider(() -> freeMemoryFunction(allocationList)));
+            this.allocations = allocationList;
+        }
+
+        private static long freeMemoryFunction(List<Long> allocations) {
+            return TOTAL_DEVICE_MEMORY_IN_BYTES - allocations.stream().mapToLong(Long::longValue).sum();
+        }
+
+        @Override
+        protected CuVSResources createNew() {
+            return new MockCuVSResources(idGenerator.getAndIncrement());
+        }
+
+        @Override
+        public ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataType dataType) throws InterruptedException {
+            var res = super.acquire(numVectors, dims, dataType);
+            long memory = (long) ((long) numVectors * dims * Float.BYTES // long math: the int product could overflow
+                * CuVSResourceManager.PoolingCuVSResourceManager.GPU_COMPUTATION_MEMORY_FACTOR);
+            allocations.add(memory);
+            log.info("Added [{}]", memory);
+            return res;
+        }
+        // Drops the most recently recorded reservation (whichever resource is released), then delegates to the parent release.
+        @Override
+        public void release(ManagedCuVSResources resources) {
+            if (allocations.isEmpty() == false) {
+                var x = allocations.removeLast();
+                log.info("Removed [{}]", x);
+            }
+            super.release(resources);
+        }
+    }
+
+    static class MockCuVSResources implements CuVSResources {
+        // Stand-in identified only by its id: access() and tempDirectory() are unsupported, close() does nothing.
+        final int id;
+
+        MockCuVSResources(int id) {
+            this.id = id;
+        }
+
+        @Override
+        public String toString() {
+            return "MockCuVSResources[id=" + this.id + "]";
+        }
+
+        @Override
+        public int deviceId() {
+            return 0;
+        }
+
+        @Override
+        public void close() {}
+
+        @Override
+        public ScopedAccess access() {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public Path tempDirectory() {
+            throw new UnsupportedOperationException();
+        }
+    }
+
+    private static class MockGPUInfoProvider implements GPUInfoProvider {
+        private final LongSupplier freeMemorySupplier;
+
+        MockGPUInfoProvider(LongSupplier freeMemorySupplier) {
+            this.freeMemorySupplier = freeMemorySupplier;
+        }
+
+        @Override
+        public List<GPUInfo> availableGPUs() {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public List<GPUInfo> compatibleGPUs() {
+            throw new UnsupportedOperationException();
+        }
+        // Builds the info from the supplier's current value and TOTAL_DEVICE_MEMORY_IN_BYTES; the resources argument is not consulted.
+        @Override
+        public CuVSResourcesInfo getCurrentInfo(CuVSResources cuVSResources) {
+            return new CuVSResourcesInfo(freeMemorySupplier.getAsLong(), TOTAL_DEVICE_MEMORY_IN_BYTES);
+        }
+    }
+}
diff --git a/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/DatasetUtilsTests.java b/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/DatasetUtilsTests.java
new file mode 100644
index 0000000000000..6c43843dbd830
--- /dev/null
+++ b/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/DatasetUtilsTests.java
@@ -0,0 +1,78 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.gpu.codec;
+
+import com.nvidia.cuvs.CuVSMatrix;
+
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.store.IOContext;
+import org.apache.lucene.store.MMapDirectory;
+import org.apache.lucene.store.MemorySegmentAccessInput;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.gpu.GPUSupport;
+import org.junit.Before;
+
+import java.lang.foreign.MemorySegment;
+import java.lang.foreign.ValueLayout;
+import java.nio.ByteOrder;
+
+import static java.lang.foreign.ValueLayout.JAVA_FLOAT_UNALIGNED;
+
+public class DatasetUtilsTests extends ESTestCase {
+    // Every test is gated by the assumptions in setup(): JDK feature version >= 22 and GPUSupport.isSupported(false).
+    DatasetUtils datasetUtils;
+
+    @Before
+    public void setup() { // TODO: abstract out setup into a common GPUTestCase
+        assumeTrue("cuvs runtime only supported on 22 or greater, your JDK is " + Runtime.version(), Runtime.version().feature() >= 22);
+        assumeTrue("cuvs not supported", GPUSupport.isSupported(false));
+        datasetUtils = DatasetUtils.getInstance();
+    }
+
+    static final ValueLayout.OfFloat JAVA_FLOAT_LE = ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN);
+
+    public void testBasic() throws Exception {
+        try (Directory dir = new MMapDirectory(createTempDir("testBasic"))) {
+            int numVecs = randomIntBetween(1, 100);
+            int dims = randomIntBetween(128, 2049);
+            // Write numVecs vectors back to back, each as dims little-endian floats.
+            try (var out = dir.createOutput("vector.data", IOContext.DEFAULT)) {
+                var ba = new byte[dims * Float.BYTES];
+                var seg = MemorySegment.ofArray(ba);
+                for (int v = 0; v < numVecs; v++) {
+                    var src = MemorySegment.ofArray(randomVector(dims));
+                    MemorySegment.copy(src, JAVA_FLOAT_UNALIGNED, 0L, seg, JAVA_FLOAT_LE, 0L, dims); // copy the whole vector
+                    out.writeBytes(ba, 0, ba.length);
+                }
+            }
+            try (
+                var in = dir.openInput("vector.data", IOContext.DEFAULT);
+                var dataset = datasetUtils.fromInput((MemorySegmentAccessInput) in, numVecs, dims, CuVSMatrix.DataType.FLOAT)
+            ) {
+                assertEquals(numVecs, dataset.size());
+                assertEquals(dims, dataset.columns());
+            }
+        }
+    }
+
+    static final Class<IllegalArgumentException> IAE = IllegalArgumentException.class;
+
+    public void testIllegal() {
+        MemorySegmentAccessInput in = null; // TODO: make this non-null
+        expectThrows(IAE, () -> datasetUtils.fromInput(in, -1, 1, CuVSMatrix.DataType.FLOAT));
+        expectThrows(IAE, () -> datasetUtils.fromInput(in, 1, -1, CuVSMatrix.DataType.FLOAT));
+    }
+
+    float[] randomVector(int dims) {
+        float[] fa = new float[dims];
+        for (int i = 0; i < dims; ++i) {
+            fa[i] = random().nextFloat();
+        }
+        return fa;
+    }
+}
diff --git a/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/ES92GpuHnswSQVectorsFormatTests.java b/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/ES92GpuHnswSQVectorsFormatTests.java
new file mode 100644
index 0000000000000..f1c13b15795c5
--- /dev/null
+++ b/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/ES92GpuHnswSQVectorsFormatTests.java
@@ -0,0 +1,79 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.gpu.codec;
+
+import org.apache.lucene.codecs.Codec;
+import org.apache.lucene.index.VectorEncoding;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
+import org.apache.lucene.tests.util.LuceneTestCase;
+import org.apache.lucene.tests.util.TestUtil;
+import org.elasticsearch.common.logging.LogConfigurator;
+import org.elasticsearch.xpack.gpu.GPUSupport;
+import org.junit.BeforeClass;
+
+@LuceneTestCase.SuppressSysoutChecks(bugUrl = "https://github.com/rapidsai/cuvs/issues/1310") // CuVS prints tons of logs to stdout
+public class ES92GpuHnswSQVectorsFormatTests extends BaseKnnVectorsFormatTestCase {
+    static {
+        LogConfigurator.loadLog4jPlugins();
+        LogConfigurator.configureESLogging(); // native access requires logging to be initialized
+    }
+
+    static Codec codec;
+
+    @BeforeClass
+    public static void beforeClass() {
+        assumeTrue("cuvs not supported", GPUSupport.isSupported(false));
+        codec = TestUtil.alwaysKnnVectorsFormat(new ES92GpuHnswSQVectorsFormat());
+    }
+
+    @Override
+    protected Codec getCodec() {
+        return codec;
+    }
+
+    @Override
+    protected VectorSimilarityFunction randomSimilarity() {
+        final VectorSimilarityFunction[] candidates = VectorSimilarityFunction.values();
+        return candidates[random().nextInt(VectorSimilarityFunction.values().length)];
+    }
+
+    @Override
+    protected VectorEncoding randomVectorEncoding() {
+        return VectorEncoding.FLOAT32;
+    }
+
+    @Override
+    public void testRandomBytes() {
+        // Intentionally empty: no bytes support
+    }
+
+    @Override
+    public void testSortedIndexBytes() {
+        // Intentionally empty: no bytes support
+    }
+
+    @Override
+    public void testByteVectorScorerIteration() {
+        // Intentionally empty: no bytes support
+    }
+
+    @Override
+    public void testEmptyByteVectorData() {
+        // Intentionally empty: no bytes support
+    }
+
+    @Override
+    public void testMergingWithDifferentByteKnnFields() {
+        // Intentionally empty: no bytes support
+    }
+
+    @Override
+    public void testMismatchedFields() {
+        // Intentionally empty: no bytes support
+    }
+}
diff --git a/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/ES92GpuHnswVectorsFormatTests.java b/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/ES92GpuHnswVectorsFormatTests.java
new file mode 100644
index 0000000000000..e7ce310d15d9b
--- /dev/null
+++ b/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/ES92GpuHnswVectorsFormatTests.java
@@ -0,0 +1,81 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.gpu.codec;
+
+import org.apache.lucene.codecs.Codec;
+import org.apache.lucene.index.VectorEncoding;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
+import org.apache.lucene.tests.util.LuceneTestCase;
+import org.apache.lucene.tests.util.TestUtil;
+import org.elasticsearch.common.logging.LogConfigurator;
+import org.elasticsearch.xpack.gpu.GPUSupport;
+import org.junit.BeforeClass;
+
+// CuVS prints tons of logs to stdout
+@LuceneTestCase.SuppressSysoutChecks(bugUrl = "https://github.com/rapidsai/cuvs/issues/1310")
+public class ES92GpuHnswVectorsFormatTests extends BaseKnnVectorsFormatTestCase {
+
+    static {
+        LogConfigurator.loadLog4jPlugins();
+        LogConfigurator.configureESLogging(); // native access requires logging to be initialized
+    }
+
+    static Codec codec;
+
+    @BeforeClass
+    public static void beforeClass() {
+        assumeTrue("cuvs not supported", GPUSupport.isSupported(false));
+        codec = TestUtil.alwaysKnnVectorsFormat(new ES92GpuHnswVectorsFormat());
+    }
+
+    @Override
+    protected Codec getCodec() {
+        return codec;
+    }
+
+    @Override
+    protected VectorSimilarityFunction randomSimilarity() {
+        final VectorSimilarityFunction[] candidates = VectorSimilarityFunction.values();
+        return candidates[random().nextInt(VectorSimilarityFunction.values().length)];
+    }
+
+    @Override
+    protected VectorEncoding randomVectorEncoding() {
+        return VectorEncoding.FLOAT32;
+    }
+
+    @Override
+    public void testRandomBytes() throws Exception {
+        // Intentionally empty: no bytes support
+    }
+
+    @Override
+    public void testSortedIndexBytes() throws Exception {
+        // Intentionally empty: no bytes support
+    }
+
+    @Override
+    public void testByteVectorScorerIteration() throws Exception {
+        // Intentionally empty: no bytes support
+    }
+
+    @Override
+    public void testEmptyByteVectorData() throws Exception {
+        // Intentionally empty: no bytes support
+    }
+
+    @Override
+    public void testMergingWithDifferentByteKnnFields() throws Exception {
+        // Intentionally empty: no bytes support
+    }
+
+    @Override
+    public void testMismatchedFields() throws Exception {
+        // Intentionally empty: no bytes support
+    }
+}
diff --git a/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/GPUDenseVectorFieldMapperTests.java b/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/GPUDenseVectorFieldMapperTests.java
new file mode 100644
index 0000000000000..2648691d03eec
--- /dev/null
+++ b/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/GPUDenseVectorFieldMapperTests.java
@@ -0,0 +1,84 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.gpu.codec;
+
+import org.apache.lucene.codecs.Codec;
+import org.apache.lucene.codecs.KnnVectorsFormat;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.index.codec.CodecService;
+import org.elasticsearch.index.codec.LegacyPerFieldMapperCodec;
+import org.elasticsearch.index.codec.PerFieldMapperCodec;
+import org.elasticsearch.index.mapper.MapperService;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTests;
+import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.xpack.gpu.GPUPlugin;
+import org.elasticsearch.xpack.gpu.GPUSupport;
+import org.junit.BeforeClass;
+
+import java.io.IOException;
+import java.util.Collection;
+import java.util.Collections;
+
+import static org.hamcrest.Matchers.instanceOf;
+
+public class GPUDenseVectorFieldMapperTests extends DenseVectorFieldMapperTests {
+    // Runs the inherited dense_vector mapper tests with GPUPlugin installed; gated on GPUSupport.isSupported(false).
+    @BeforeClass
+    public static void setup() {
+        assumeTrue("cuvs not supported", GPUSupport.isSupported(false));
+    }
+
+    @Override
+    protected Collection<Plugin> getPlugins() {
+        var plugin = new GPUPlugin();
+        return Collections.singletonList(plugin);
+    }
+
+    @Override
+    public void testKnnVectorsFormat() throws IOException {
+        // TODO improve test with custom parameters
+        KnnVectorsFormat knnVectorsFormat = getKnnVectorsFormat("hnsw");
+        String expectedStr = "Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, "
+            + "maxConn=16, beamWidth=128, flatVectorFormat=Lucene99FlatVectorsFormat)";
+        assertEquals(expectedStr, knnVectorsFormat.toString());
+    }
+
+    @Override
+    public void testKnnQuantizedHNSWVectorsFormat() throws IOException {
+        // TODO improve the test with custom parameters
+        KnnVectorsFormat knnVectorsFormat = getKnnVectorsFormat("int8_hnsw");
+        String expectedStr = "Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, "
+            + "maxConn=16, beamWidth=128, flatVectorFormat=ES814ScalarQuantizedVectorsFormat";
+        assertTrue(knnVectorsFormat.toString().startsWith(expectedStr));
+    }
+    // Maps a dense_vector field with the given index_options type and returns the KNN format the "default" codec uses for "field".
+    private KnnVectorsFormat getKnnVectorsFormat(String indexOptionsType) throws IOException {
+        final int dims = randomIntBetween(128, 4096);
+        MapperService mapperService = createMapperService(fieldMapping(b -> {
+            b.field("type", "dense_vector");
+            b.field("dims", dims);
+            b.field("index", true);
+            b.field("similarity", "dot_product");
+            b.startObject("index_options");
+            b.field("type", indexOptionsType);
+            b.endObject();
+        }));
+        CodecService codecService = new CodecService(mapperService, BigArrays.NON_RECYCLING_INSTANCE);
+        Codec codec = codecService.codec("default");
+        if (CodecService.ZSTD_STORED_FIELDS_FEATURE_FLAG) {
+            assertThat(codec, instanceOf(PerFieldMapperCodec.class));
+            return ((PerFieldMapperCodec) codec).getKnnVectorsFormatForField("field");
+        } else {
+            if (codec instanceof CodecService.DeduplicateFieldInfosCodec deduplicateFieldInfosCodec) {
+                codec = deduplicateFieldInfosCodec.delegate();
+            }
+            assertThat(codec, instanceOf(LegacyPerFieldMapperCodec.class));
+            return ((LegacyPerFieldMapperCodec) codec).getKnnVectorsFormatForField("field");
+        }
+    }
+}
diff --git a/x-pack/plugin/gpu/src/yamlRestTest/java/org/elasticsearch/xpack/gpu/GPUClientYamlTestSuiteIT.java b/x-pack/plugin/gpu/src/yamlRestTest/java/org/elasticsearch/xpack/gpu/GPUClientYamlTestSuiteIT.java
new file mode 100644
index 0000000000000..c4e7e936b0111
--- /dev/null
+++ b/x-pack/plugin/gpu/src/yamlRestTest/java/org/elasticsearch/xpack/gpu/GPUClientYamlTestSuiteIT.java
@@ -0,0 +1,54 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.gpu;
+
+import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+
+import org.elasticsearch.test.cluster.ElasticsearchCluster;
+import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate;
+import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase;
+import org.junit.BeforeClass;
+import org.junit.ClassRule;
+
+public class GPUClientYamlTestSuiteIT extends ESClientYamlSuiteTestCase {
+
+    @BeforeClass
+    public static void setup() {
+        assumeTrue("cuvs not supported", GPUSupport.isSupported(false)); // assumption fails unless GPUSupport reports support
+    }
+
+ @ClassRule
+ public static ElasticsearchCluster cluster = createCluster();
+
+    private static ElasticsearchCluster createCluster() {
+        var clusterBuilder = ElasticsearchCluster.local()
+            .nodes(1)
+            .module("gpu")
+            .setting("xpack.license.self_generated.type", "trial")
+            .setting("xpack.security.enabled", "false");
+        // Pass this JVM's LD_LIBRARY_PATH, when set, on to the cluster builder's environment.
+        var ldLibraryPath = System.getenv("LD_LIBRARY_PATH");
+        if (ldLibraryPath != null) {
+            clusterBuilder.environment("LD_LIBRARY_PATH", ldLibraryPath);
+        }
+        return clusterBuilder.build();
+    }
+
+    public GPUClientYamlTestSuiteIT(final ClientYamlTestCandidate testCandidate) {
+        super(testCandidate); // the candidate is handed unchanged to ESClientYamlSuiteTestCase
+    }
+
+ @ParametersFactory
+ public static Iterable