Skip to content

Commit 1e3de78

Browse files
interval callable
1 parent 3cfb4e2 commit 1e3de78

File tree

3 files changed

+37
-10
lines changed

3 files changed

+37
-10
lines changed

tsml/interval_based/_interval_pipelines.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,10 @@ class RandomIntervalClassifier(ClassifierMixin, BaseTimeSeriesEstimator):
3535
3636
Parameters
3737
----------
38-
n_intervals : int, default=100,
38+
n_intervals : int or callable, default=100,
3939
The number of intervals of random length, position and dimension to be
40-
extracted.
40+
extracted. Input should be an int or a function that takes a 3D np.ndarray
41+
input and returns an int.
4142
min_interval_length : int, default=3
4243
The minimum length of extracted intervals. Minimum value of 3.
4344
max_interval_length : int, default=3
@@ -74,9 +75,9 @@ class RandomIntervalClassifier(ClassifierMixin, BaseTimeSeriesEstimator):
7475
----------
7576
n_instances_ : int
7677
The number of train cases in the training set.
77-
n_channels : int
78+
n_channels_ : int
7879
The number of dimensions per case in the training set.
79-
n_timepoints : int
80+
n_timepoints_ : int
8081
The length of each series in the training set.
8182
n_classes_ : int
8283
Number of classes. Extracted from the data.
@@ -277,9 +278,10 @@ class RandomIntervalRegressor(RegressorMixin, BaseTimeSeriesEstimator):
277278
278279
Parameters
279280
----------
280-
n_intervals : int, default=100,
281+
n_intervals : int or callable, default=100,
281282
The number of intervals of random length, position and dimension to be
282-
extracted.
283+
extracted. Input should be an int or a function that takes a 3D np.ndarray
284+
input and returns an int.
283285
min_interval_length : int, default=3
284286
The minimum length of extracted intervals. Minimum value of 3.
285287
max_interval_length : int, default=3
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""Tests for the interval pipeline classes."""
2+
3+
from tsml.interval_based import RandomIntervalClassifier
4+
from tsml.utils.testing import generate_3d_test_data
5+
6+
7+
def test_random_interval_callable():
8+
"""Test RandomIntervalClassifier with a callable n_intervals."""
9+
X, y = generate_3d_test_data()
10+
11+
def interval_func(X):
12+
return int(X.shape[2] / 5)
13+
14+
est = RandomIntervalClassifier(
15+
n_intervals=interval_func,
16+
)
17+
est.fit(X, y)
18+
19+
assert est._transformer._n_intervals == 2

tsml/transformations/_interval_extraction.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,10 @@ class RandomIntervalTransformer(TransformerMixin, BaseTimeSeriesEstimator):
4848
4949
Parameters
5050
----------
51-
n_intervals : int, default=100,
51+
n_intervals : int or callable, default=100,
5252
The number of intervals of random length, position and dimension to be
53-
extracted.
53+
extracted. Input should be an int or a function that takes a 3D np.ndarray
54+
input and returns an int.
5455
min_interval_length : int, default=3
5556
The minimum length of extracted intervals. Minimum value of 3.
5657
max_interval_length : int, default=3
@@ -150,7 +151,7 @@ def fit_transform(self, X, y=None):
150151
rng.randint(np.iinfo(np.int32).max),
151152
True,
152153
)
153-
for _ in range(self.n_intervals)
154+
for _ in range(self._n_intervals)
154155
)
155156

156157
(
@@ -176,7 +177,7 @@ def fit_transform(self, X, y=None):
176177
removed_idx.append(i)
177178

178179
Xt = transformed_intervals[0]
179-
for i in range(1, self.n_intervals):
180+
for i in range(1, self._n_intervals):
180181
if i not in removed_idx:
181182
Xt = np.hstack((Xt, transformed_intervals[i]))
182183

@@ -261,6 +262,11 @@ def _fit_setup(self, X):
261262

262263
self.n_instances_, self.n_dims_, self.series_length_ = X.shape
263264

265+
if callable(self.n_intervals):
266+
self._n_intervals = self.n_intervals(X)
267+
else:
268+
self._n_intervals = self.n_intervals
269+
264270
self._min_interval_length = self.min_interval_length
265271
if self.min_interval_length < 3:
266272
self._min_interval_length = 3

0 commit comments

Comments
 (0)