Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions spacv/base_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,10 @@ def split(self, XYs):
indices = XYs.index.values

for test_indices, train_excluded in self._iter_test_indices(XYs):
# Exclude the training indices within buffer
# Combine training indices within buffer with test indices
train_excluded = np.concatenate([test_indices, train_excluded])
train_index = np.setdiff1d(
np.union1d(
indices,
train_excluded
), np.intersect1d(indices, train_excluded)
)
# Exclude test indices and training indices within buffer to get final training indices
train_index = np.setdiff1d(indices, train_excluded)
if len(train_index) < 1:
raise ValueError(
"Training set is empty. Try lowering buffer_radius to include more training instances."
Expand All @@ -68,11 +64,11 @@ def _remove_buffered_indices(self, XYs, test_indices, buffer_radius, geometry_bu
geometry_buffer = convert_geodataframe(geometry_buffer)
deadzone_points = gpd.sjoin(candidate_deadzone, geometry_buffer)
train_exclude = deadzone_points.loc[~deadzone_points.index.isin(test_indices)].index.values
return test_indices, train_exclude
else:
# Yield empty array because no training data removed in dead zone when buffer is zero
_ = np.empty([], dtype=np.int)
return test_indices, _
# Yield empty array (with the same dimensions as test_indices) because no training data removed in dead zone
# when buffer is zero
train_exclude = np.empty([0] * test_indices.ndim, dtype=np.int)
return test_indices, train_exclude

@abstractmethod
def _iter_test_indices(self, XYs):
Expand Down