Skip to content

Commit 95cf36e

Browse files
committed
Canonicalise LearnerND hull seeding and add regression test
1 parent 4206a06 commit 95cf36e

File tree

4 files changed

+92
-2053
lines changed

4 files changed

+92
-2053
lines changed

adaptive/learner/learnerND.py

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,8 @@ def __init__(self, func, bounds, loss_per_simplex=None):
376376
# been returned has not been deleted. This checking is done by
377377
# _pop_highest_existing_simplex
378378
self._simplex_queue = SortedKeyList(key=_simplex_evaluation_priority)
379-
self._skipped_bound_points: set[tuple[float, ...]] = set()
379+
self._next_bound_idx = 0
380+
self._bound_match_tol = 1e-10
380381

381382
def new(self) -> LearnerND:
382383
"""Create a new learner with the same function and bounds."""
@@ -489,6 +490,7 @@ def load_dataframe( # type: ignore[override]
489490
self.function = partial_function_from_dataframe(
490491
self.function, df, function_prefix
491492
)
493+
self._next_bound_idx = 0
492494

493495
@property
494496
def bounds_are_done(self):
@@ -556,6 +558,27 @@ def _simplex_exists(self, simplex):
556558
simplex = tuple(sorted(simplex))
557559
return simplex in self.tri.simplices
558560

561+
def _is_known_point(self, point):
562+
point = tuple(map(float, point))
563+
if point in self.data or point in self.pending_points:
564+
return True
565+
566+
tolerances = [
567+
max(self._bound_match_tol, self._bound_match_tol * (hi - lo))
568+
for lo, hi in self._bbox
569+
]
570+
571+
def _close(other):
572+
return all(abs(a - b) <= tol for (a, b, tol) in zip(point, other, tolerances))
573+
574+
for existing in self.data.keys():
575+
if _close(existing):
576+
return True
577+
for existing in self.pending_points:
578+
if _close(existing):
579+
return True
580+
return False
581+
559582
def inside_bounds(self, point):
560583
"""Check whether a point is inside the bounds."""
561584
if self._interior is not None:
@@ -633,24 +656,18 @@ def ask(self, n, tell_pending=True):
633656

634657
def _ask_bound_point(self):
635658
# get the next bound point that is still available
636-
while True:
637-
new_point = next(
638-
p
639-
for p in self._bounds_points
640-
if p not in self.data
641-
and p not in self.pending_points
642-
and p not in self._skipped_bound_points
643-
)
644-
try:
645-
self.tell_pending(new_point)
646-
except ValueError as exc:
647-
if str(exc) == "Point already in triangulation.":
648-
self.pending_points.discard(new_point)
649-
self._skipped_bound_points.add(new_point)
650-
continue
651-
raise
659+
while self._next_bound_idx < len(self._bounds_points):
660+
new_point = self._bounds_points[self._next_bound_idx]
661+
self._next_bound_idx += 1
662+
663+
if self._is_known_point(new_point):
664+
continue
665+
666+
self.tell_pending(new_point)
652667
return new_point, np.inf
653668

669+
raise StopIteration
670+
654671
def _ask_point_without_known_simplices(self):
655672
assert not self._bounds_available
656673
# pick a random point inside the bounds
@@ -715,20 +732,13 @@ def _ask_best_point(self):
715732
@property
716733
def _bounds_available(self):
717734
return any(
718-
(
719-
p not in self.pending_points
720-
and p not in self.data
721-
and p not in self._skipped_bound_points
722-
)
723-
for p in self._bounds_points
735+
not self._is_known_point(p)
736+
for p in self._bounds_points[self._next_bound_idx :]
724737
)
725738

726739
def _ask(self):
727740
if self._bounds_available:
728-
try:
729-
return self._ask_bound_point() # O(1)
730-
except StopIteration:
731-
pass
741+
return self._ask_bound_point() # O(1)
732742

733743
if self.tri is None:
734744
# All bound points are pending or have been evaluated, but we do not

adaptive/tests/data/issue_470_boundaries.json

Lines changed: 0 additions & 1 deletion
This file was deleted.

0 commit comments

Comments
 (0)