From 496d348dcabfaa726810b0f5325ff10a663f6dc7 Mon Sep 17 00:00:00 2001 From: mathleur Date: Fri, 31 Oct 2025 10:58:46 +0100 Subject: [PATCH] add k nearest neighbour in rust quadtree --- rust/Cargo.toml | 1 + rust/src/quadtree_mod.rs | 66 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+) diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 92a20b24..d9b2a222 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" [dependencies] pyo3 = { version = "0.20", features = ["extension-module"] } geo = { version = "0.30"} +ordered-float = "4.2" [lib] name = "polytope_rs" diff --git a/rust/src/quadtree_mod.rs b/rust/src/quadtree_mod.rs index 52f77349..7ec9e6af 100644 --- a/rust/src/quadtree_mod.rs +++ b/rust/src/quadtree_mod.rs @@ -11,6 +11,10 @@ use pyo3::exceptions::PyRuntimeError; use crate::slicing_tools::{is_contained_in, slice_in_two}; use crate::distance::{dist2, box_dist2}; +use std::collections::BinaryHeap; +use std::cmp::Reverse; +use ordered_float::OrderedFloat; + #[derive(Debug)] @@ -55,6 +59,24 @@ impl QuadTree { } } + + fn k_nearest_neighbor(&self, query: (f64, f64), k: usize, quadtree_points: Vec<(f64, f64)>) -> Option> { + if self.nodes.is_empty() { + return None; + } + // let mut heap = BinaryHeap::new(); + let mut heap = BinaryHeap::, usize)>>::new(); + self.knn_search(0, query, k, &mut heap, &quadtree_points); + + // keep only point indexes from distance heap and sort from nearest to farthest + let mut results: Vec<_> = heap.into_sorted_vec() + .into_iter() + .map(|Reverse((OrderedFloat(_d2), idx))| idx) + .collect(); + + Some(results) + } + fn nearest_neighbor(&self, query: (f64, f64), quadtree_points: Vec<(f64, f64)>) -> Option { if self.nodes.is_empty() { return None; @@ -156,6 +178,50 @@ impl QuadTree { points.into_iter().map(|(x, y)| [x,y]).collect() } + fn knn_search( + &self, + node_idx: usize, + query: (f64, f64), + k: usize, + heap: &mut BinaryHeap, usize)>>, // min-heap of distances + quadtree_points: &Vec<(f64, f64)>, + ) { + let node = &self.nodes[node_idx]; + + // use farthest distance in the current heap to prune + let prune_dist2 = if heap.len() < k { + f64::INFINITY + } else { + heap.peek().unwrap().0 .0 .into_inner() + }; + + // if this node is farther than the k-th current best, ignore + if box_dist2(node.center, node.size, query) > prune_dist2 { + return; + } + + // compare distance of points inside leaf node + if let Some(point_indices) = &node.points { + for &pi in point_indices { + let p = quadtree_points[pi]; + let d2 = dist2(p, query); + + if heap.len() < k { + heap.push(Reverse((OrderedFloat(d2), pi))); + } else if d2 < heap.peek().unwrap().0 .0 .into_inner(){ + heap.pop(); + heap.push(Reverse((OrderedFloat(d2), pi))); + } + } + return; + } + + // recurse into children + for &child_idx in &node.children { + self.knn_search(child_idx, query, k, heap, quadtree_points); + } + } + fn nn_search( &self, node_idx: usize,