Skip to content

Commit 64bc4ca

Browse files
committed
Canonicalise LearnerND hull seeding and add regression test
1 parent 4206a06 commit 64bc4ca

File tree

4 files changed

+85
-2053
lines changed

4 files changed

+85
-2053
lines changed

adaptive/learner/learnerND.py

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,8 @@ def __init__(self, func, bounds, loss_per_simplex=None):
376376
# been returned has not been deleted. This checking is done by
377377
# _pop_highest_existing_simplex
378378
self._simplex_queue = SortedKeyList(key=_simplex_evaluation_priority)
379-
self._skipped_bound_points: set[tuple[float, ...]] = set()
379+
self._next_bound_idx = 0
380+
self._bound_match_tol = 1e-10
380381

381382
def new(self) -> LearnerND:
382383
"""Create a new learner with the same function and bounds."""
@@ -489,6 +490,7 @@ def load_dataframe( # type: ignore[override]
489490
self.function = partial_function_from_dataframe(
490491
self.function, df, function_prefix
491492
)
493+
self._next_bound_idx = 0
492494

493495
@property
494496
def bounds_are_done(self):
@@ -556,6 +558,20 @@ def _simplex_exists(self, simplex):
556558
simplex = tuple(sorted(simplex))
557559
return simplex in self.tri.simplices
558560

561+
def _is_known_point(self, point):
562+
point = tuple(map(float, point))
563+
if point in self.data or point in self.pending_points:
564+
return True
565+
566+
tol = self._bound_match_tol
567+
for existing in self.data.keys():
568+
if all(abs(a - b) <= tol for a, b in zip(point, existing)):
569+
return True
570+
for existing in self.pending_points:
571+
if all(abs(a - b) <= tol for a, b in zip(point, existing)):
572+
return True
573+
return False
574+
559575
def inside_bounds(self, point):
560576
"""Check whether a point is inside the bounds."""
561577
if self._interior is not None:
@@ -633,24 +649,18 @@ def ask(self, n, tell_pending=True):
633649

634650
def _ask_bound_point(self):
635651
# get the next bound point that is still available
636-
while True:
637-
new_point = next(
638-
p
639-
for p in self._bounds_points
640-
if p not in self.data
641-
and p not in self.pending_points
642-
and p not in self._skipped_bound_points
643-
)
644-
try:
645-
self.tell_pending(new_point)
646-
except ValueError as exc:
647-
if str(exc) == "Point already in triangulation.":
648-
self.pending_points.discard(new_point)
649-
self._skipped_bound_points.add(new_point)
650-
continue
651-
raise
652+
while self._next_bound_idx < len(self._bounds_points):
653+
new_point = self._bounds_points[self._next_bound_idx]
654+
self._next_bound_idx += 1
655+
656+
if self._is_known_point(new_point):
657+
continue
658+
659+
self.tell_pending(new_point)
652660
return new_point, np.inf
653661

662+
raise StopIteration
663+
654664
def _ask_point_without_known_simplices(self):
655665
assert not self._bounds_available
656666
# pick a random point inside the bounds
@@ -715,20 +725,13 @@ def _ask_best_point(self):
715725
@property
716726
def _bounds_available(self):
717727
return any(
718-
(
719-
p not in self.pending_points
720-
and p not in self.data
721-
and p not in self._skipped_bound_points
722-
)
723-
for p in self._bounds_points
728+
not self._is_known_point(p)
729+
for p in self._bounds_points[self._next_bound_idx :]
724730
)
725731

726732
def _ask(self):
727733
if self._bounds_available:
728-
try:
729-
return self._ask_bound_point() # O(1)
730-
except StopIteration:
731-
pass
734+
return self._ask_bound_point() # O(1)
732735

733736
if self.tri is None:
734737
# All bound points are pending or have been evaluated, but we do not

adaptive/tests/data/issue_470_boundaries.json

Lines changed: 0 additions & 1 deletion
This file was deleted.

0 commit comments

Comments
 (0)