Skip to content

Commit a903c1c

Browse files
committed
Simplify logic for the future
1 parent 3bcd002 commit a903c1c

File tree

4 files changed

+18
-33
lines changed

4 files changed

+18
-33
lines changed

sklearn/tree/_oblique_tree.pxd

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,12 @@ cdef class ObliqueTree(Tree):
4242
self,
4343
const DTYPE_t[:, :] X_ndarray,
4444
SIZE_t sample_index,
45-
Node *node,
46-
SIZE_t node_id
45+
Node *node
4746
) nogil
4847
cdef void _compute_feature_importances(
4948
self,
5049
DOUBLE_t* importance_data,
51-
Node* node,
52-
SIZE_t node_id
50+
Node* node
5351
) nogil
5452

5553
cpdef cnp.ndarray get_projection_matrix(self)

sklearn/tree/_oblique_tree.pyx

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ cdef class ObliqueTree(Tree):
246246
self.proj_vec_indices[node_id] = deref(deref(oblique_split_node).proj_vec_indices)
247247
return 1
248248

249-
cdef DTYPE_t _compute_feature(self, const DTYPE_t[:, :] X_ndarray, SIZE_t sample_index, Node *node, SIZE_t node_id) nogil:
249+
cdef DTYPE_t _compute_feature(self, const DTYPE_t[:, :] X_ndarray, SIZE_t sample_index, Node *node) nogil:
250250
"""Compute feature from a given data matrix, X.
251251
252252
In oblique-aligned trees, this is the projection of X.
@@ -257,6 +257,9 @@ cdef class ObliqueTree(Tree):
257257
cdef SIZE_t j = 0
258258
cdef SIZE_t feature_index
259259
cdef SIZE_t n_features = self.n_features
260+
261+
# get the index of the node
262+
cdef SIZE_t node_id = node - self.nodes
260263

261264
# cdef SIZE_t n_projections = proj_vec_indices.size()
262265
# compute projection of the data based on trained tree
@@ -274,7 +277,7 @@ cdef class ObliqueTree(Tree):
274277
return proj_feat
275278

276279
cdef void _compute_feature_importances(self, DOUBLE_t* importance_data,
277-
Node* node, SIZE_t node_id) nogil:
280+
Node* node) nogil:
278281
"""Compute feature importances from a Node in the Tree.
279282
280283
Wrapped in a private function to allow subclassing that
@@ -284,6 +287,9 @@ cdef class ObliqueTree(Tree):
284287
cdef Node* left
285288
cdef Node* right
286289

290+
# get the index of the node
291+
cdef SIZE_t node_id = node - self.nodes
292+
287293
left = &nodes[node.left_child]
288294
right = &nodes[node.right_child]
289295

sklearn/tree/_tree.pxd

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,12 @@ cdef class BaseTree:
8787
self,
8888
const DTYPE_t[:, :] X_ndarray,
8989
SIZE_t sample_index,
90-
Node *node,
91-
SIZE_t node_id
90+
Node *node
9291
) nogil
9392
cdef void _compute_feature_importances(
9493
self,
9594
DOUBLE_t* importance_data,
9695
Node* node,
97-
SIZE_t node_id
9896
) nogil
9997

10098
cdef class Tree(BaseTree):

sklearn/tree/_tree.pyx

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -623,7 +623,7 @@ cdef class BaseTree:
623623

624624
cdef DTYPE_t _compute_feature(self, const DTYPE_t[:, :] X_ndarray,
625625
SIZE_t sample_index,
626-
Node *node, SIZE_t node_id) nogil:
626+
Node *node) nogil:
627627
"""Compute feature from a given data matrix, X.
628628
629629
In axis-aligned trees, this is simply the value in the column of X
@@ -725,29 +725,23 @@ cdef class BaseTree:
725725
cdef Node* node = NULL
726726
cdef SIZE_t i = 0
727727

728-
# to keep track of the current ID of each node
729-
cdef SIZE_t node_id = 0
730-
731728
# the feature value
732729
cdef DTYPE_t feature_value = 0
733730

734731
with nogil:
735732
for i in range(n_samples):
736733
node = self.nodes
737-
node_id = 0
738734

739735
# While node not a leaf
740736
while node.left_child != _TREE_LEAF:
741737
# ... and node.right_child != _TREE_LEAF:
742738

743739
# compute the feature value to compare against threshold
744-
feature_value = self._compute_feature(X_ndarray, i, node, node_id)
740+
feature_value = self._compute_feature(X_ndarray, i, node)
745741
if feature_value <= node.threshold:
746-
node_id = node.left_child
747-
node = &self.nodes[node_id]
742+
node = &self.nodes[node.left_child]
748743
else:
749-
node_id = node.right_child
750-
node = &self.nodes[node_id]
744+
node = &self.nodes[node.right_child]
751745

752746
out_ptr[i] = <SIZE_t>(node - self.nodes) # node offset
753747
return out
@@ -861,9 +855,6 @@ cdef class BaseTree:
861855
cdef Node* node = NULL
862856
cdef SIZE_t i = 0
863857

864-
# to keep track of the current ID of each node
865-
cdef SIZE_t node_id = 0
866-
867858
# the feature index
868859
cdef DOUBLE_t feature
869860

@@ -879,12 +870,10 @@ cdef class BaseTree:
879870
indptr_ptr[i + 1] += 1
880871

881872
# compute the feature value to compare against threshold
882-
feature = self._compute_feature(X_ndarray, i, node, node_id)
873+
feature = self._compute_feature(X_ndarray, i, node)
883874
if feature <= node.threshold:
884-
node_id = node.left_child
885875
node = &self.nodes[node.left_child]
886876
else:
887-
node_id = node.right_child
888877
node = &self.nodes[node.right_child]
889878

890879
# Add the leave node
@@ -1030,20 +1019,14 @@ cdef class BaseTree:
10301019
importances = np.zeros((self.n_features,))
10311020
cdef DOUBLE_t* importance_data = <DOUBLE_t*>importances.data
10321021

1033-
# to keep track of the current ID of each node
1034-
cdef SIZE_t node_id = 0
1035-
10361022
with nogil:
1037-
node_id = 0
1038-
10391023
while node != end_node:
10401024
if node.left_child != _TREE_LEAF:
10411025
# ... and node.right_child != _TREE_LEAF:
10421026
self._compute_feature_importances(
1043-
importance_data, node, node_id)
1027+
importance_data, node)
10441028

10451029
node += 1
1046-
node_id += 1
10471030

10481031
importances /= nodes[0].weighted_n_node_samples
10491032

@@ -1057,7 +1040,7 @@ cdef class BaseTree:
10571040
return importances
10581041

10591042
cdef void _compute_feature_importances(self, DOUBLE_t* importance_data,
1060-
Node* node, SIZE_t node_id) nogil:
1043+
Node* node) nogil:
10611044
"""Compute feature importances from a Node in the Tree.
10621045
10631046
Wrapped in a private function to allow subclassing that

0 commit comments

Comments
 (0)