1
1
import enum
2
- import functools
3
2
import heapq
3
+ from abc import ABC , abstractmethod
4
4
from operator import itemgetter
5
5
from typing import TYPE_CHECKING , NamedTuple
6
6
@@ -16,9 +16,25 @@ class TestGroup(NamedTuple):
16
16
duration : float
17
17
18
18
19
- def least_duration (
20
- splits : int , items : "List[nodes.Item]" , durations : "Dict[str, float]"
21
- ) -> "List[TestGroup]" :
19
+ class AlgorithmBase (ABC ):
20
+ """Abstract base class for the algorithm implementations."""
21
+
22
+ @abstractmethod
23
+ def __call__ (
24
+ self , splits : int , items : "List[nodes.Item]" , durations : "Dict[str, float]"
25
+ ) -> "List[TestGroup]" :
26
+ pass
27
+
28
+ def __hash__ (self ) -> int :
29
+ return hash (self .__class__ .__name__ )
30
+
31
+ def __eq__ (self , other : object ) -> bool :
32
+ if not isinstance (other , AlgorithmBase ):
33
+ return NotImplemented
34
+ return self .__class__ .__name__ == other .__class__ .__name__
35
+
36
+
37
+ class LeastDurationAlgorithm (AlgorithmBase ):
22
38
"""
23
39
Split tests into groups by runtime.
24
40
It walks the test items, starting with the test with largest duration.
@@ -34,60 +50,65 @@ def least_duration(
34
50
:return:
35
51
List of groups
36
52
"""
37
- items_with_durations = _get_items_with_durations (items , durations )
38
53
39
- # add index of item in list
40
- items_with_durations_indexed = [
41
- ( * tup , i ) for i , tup in enumerate ( items_with_durations )
42
- ]
54
+ def __call__ (
55
+ self , splits : int , items : "List[nodes.Item]" , durations : "Dict[str, float]"
56
+ ) -> "List[TestGroup]" :
57
+ items_with_durations = _get_items_with_durations ( items , durations )
43
58
44
- # Sort by name to ensure it's always the same order
45
- items_with_durations_indexed = sorted (
46
- items_with_durations_indexed , key = lambda tup : str (tup [0 ])
47
- )
48
-
49
- # sort by duration in descending order (longest tests first)
50
- sorted_items_with_durations = sorted (
51
- items_with_durations_indexed , key = lambda tup : tup [1 ], reverse = True
52
- )
53
-
54
- selected : List [List [Tuple [nodes .Item , int ]]] = [[] for _ in range (splits )]
55
- deselected : List [List [nodes .Item ]] = [[] for _ in range (splits )]
56
- duration : List [float ] = [0 for _ in range (splits )]
57
-
58
- # create a heap of the form (summed_durations, group_index)
59
- heap : List [Tuple [float , int ]] = [(0 , i ) for i in range (splits )]
60
- heapq .heapify (heap )
61
- for item , item_duration , original_index in sorted_items_with_durations :
62
- # get group with smallest sum
63
- summed_durations , group_idx = heapq .heappop (heap )
64
- new_group_durations = summed_durations + item_duration
65
-
66
- # store assignment
67
- selected [group_idx ].append ((item , original_index ))
68
- duration [group_idx ] = new_group_durations
69
- for i in range (splits ):
70
- if i != group_idx :
71
- deselected [i ].append (item )
72
-
73
- # store new duration - in case of ties it sorts by the group_idx
74
- heapq .heappush (heap , (new_group_durations , group_idx ))
75
-
76
- groups = []
77
- for i in range (splits ):
78
- # sort the items by their original index to maintain relative ordering
79
- # we don't care about the order of deselected items
80
- s = [
81
- item for item , original_index in sorted (selected [i ], key = lambda tup : tup [1 ])
59
+ # add index of item in list
60
+ items_with_durations_indexed = [
61
+ (* tup , i ) for i , tup in enumerate (items_with_durations )
82
62
]
83
- group = TestGroup (selected = s , deselected = deselected [i ], duration = duration [i ])
84
- groups .append (group )
85
- return groups
86
-
87
63
88
- def duration_based_chunks (
89
- splits : int , items : "List[nodes.Item]" , durations : "Dict[str, float]"
90
- ) -> "List[TestGroup]" :
64
+ # Sort by name to ensure it's always the same order
65
+ items_with_durations_indexed = sorted (
66
+ items_with_durations_indexed , key = lambda tup : str (tup [0 ])
67
+ )
68
+
69
+ # sort by duration in descending order (longest tests first)
70
+ sorted_items_with_durations = sorted (
71
+ items_with_durations_indexed , key = lambda tup : tup [1 ], reverse = True
72
+ )
73
+
74
+ selected : List [List [Tuple [nodes .Item , int ]]] = [[] for _ in range (splits )]
75
+ deselected : List [List [nodes .Item ]] = [[] for _ in range (splits )]
76
+ duration : List [float ] = [0 for _ in range (splits )]
77
+
78
+ # create a heap of the form (summed_durations, group_index)
79
+ heap : List [Tuple [float , int ]] = [(0 , i ) for i in range (splits )]
80
+ heapq .heapify (heap )
81
+ for item , item_duration , original_index in sorted_items_with_durations :
82
+ # get group with smallest sum
83
+ summed_durations , group_idx = heapq .heappop (heap )
84
+ new_group_durations = summed_durations + item_duration
85
+
86
+ # store assignment
87
+ selected [group_idx ].append ((item , original_index ))
88
+ duration [group_idx ] = new_group_durations
89
+ for i in range (splits ):
90
+ if i != group_idx :
91
+ deselected [i ].append (item )
92
+
93
+ # store new duration - in case of ties it sorts by the group_idx
94
+ heapq .heappush (heap , (new_group_durations , group_idx ))
95
+
96
+ groups = []
97
+ for i in range (splits ):
98
+ # sort the items by their original index to maintain relative ordering
99
+ # we don't care about the order of deselected items
100
+ s = [
101
+ item
102
+ for item , original_index in sorted (selected [i ], key = lambda tup : tup [1 ])
103
+ ]
104
+ group = TestGroup (
105
+ selected = s , deselected = deselected [i ], duration = duration [i ]
106
+ )
107
+ groups .append (group )
108
+ return groups
109
+
110
+
111
+ class DurationBasedChunksAlgorithm (AlgorithmBase ):
91
112
"""
92
113
Split tests into groups by runtime.
93
114
Ensures tests are split into non-overlapping groups.
@@ -99,28 +120,34 @@ def duration_based_chunks(
99
120
:param durations: Our cached test runtimes. Assumes contains timings only of relevant tests
100
121
:return: List of TestGroup
101
122
"""
102
- items_with_durations = _get_items_with_durations (items , durations )
103
- time_per_group = sum (map (itemgetter (1 ), items_with_durations )) / splits
104
-
105
- selected : List [List [nodes .Item ]] = [[] for i in range (splits )]
106
- deselected : List [List [nodes .Item ]] = [[] for i in range (splits )]
107
- duration : List [float ] = [0 for i in range (splits )]
108
123
109
- group_idx = 0
110
- for item , item_duration in items_with_durations :
111
- if duration [group_idx ] >= time_per_group :
112
- group_idx += 1
113
-
114
- selected [group_idx ].append (item )
115
- for i in range (splits ):
116
- if i != group_idx :
117
- deselected [i ].append (item )
118
- duration [group_idx ] += item_duration
119
-
120
- return [
121
- TestGroup (selected = selected [i ], deselected = deselected [i ], duration = duration [i ])
122
- for i in range (splits )
123
- ]
124
+ def __call__ (
125
+ self , splits : int , items : "List[nodes.Item]" , durations : "Dict[str, float]"
126
+ ) -> "List[TestGroup]" :
127
+ items_with_durations = _get_items_with_durations (items , durations )
128
+ time_per_group = sum (map (itemgetter (1 ), items_with_durations )) / splits
129
+
130
+ selected : List [List [nodes .Item ]] = [[] for i in range (splits )]
131
+ deselected : List [List [nodes .Item ]] = [[] for i in range (splits )]
132
+ duration : List [float ] = [0 for i in range (splits )]
133
+
134
+ group_idx = 0
135
+ for item , item_duration in items_with_durations :
136
+ if duration [group_idx ] >= time_per_group :
137
+ group_idx += 1
138
+
139
+ selected [group_idx ].append (item )
140
+ for i in range (splits ):
141
+ if i != group_idx :
142
+ deselected [i ].append (item )
143
+ duration [group_idx ] += item_duration
144
+
145
+ return [
146
+ TestGroup (
147
+ selected = selected [i ], deselected = deselected [i ], duration = duration [i ]
148
+ )
149
+ for i in range (splits )
150
+ ]
124
151
125
152
126
153
def _get_items_with_durations (
@@ -153,9 +180,8 @@ def _remove_irrelevant_durations(
153
180
154
181
155
182
class Algorithms (enum .Enum ):
156
- # values have to be wrapped inside functools.partial to avoid them being considered method definitions
157
- duration_based_chunks = functools .partial (duration_based_chunks )
158
- least_duration = functools .partial (least_duration )
183
+ duration_based_chunks = DurationBasedChunksAlgorithm ()
184
+ least_duration = LeastDurationAlgorithm ()
159
185
160
186
@staticmethod
161
187
def names () -> "List[str]" :
0 commit comments