Skip to content

Commit bc3c01a

Browse files
interval series transformer
1 parent 1e3de78 commit bc3c01a

File tree

2 files changed

+146
-30
lines changed

2 files changed

+146
-30
lines changed

tsml/interval_based/_interval_pipelines.py

Lines changed: 126 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
import numpy as np
1717
from sklearn.base import ClassifierMixin, RegressorMixin
1818
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
19-
from sklearn.utils.validation import check_is_fitted
19+
from sklearn.ensemble._base import _set_random_states
20+
from sklearn.utils.validation import check_is_fitted, check_random_state
2021

2122
from tsml.base import BaseTimeSeriesEstimator, _clone_estimator
2223
from tsml.transformations._interval_extraction import (
@@ -47,6 +48,13 @@ class RandomIntervalClassifier(ClassifierMixin, BaseTimeSeriesEstimator):
4748
of said transformers and functions, default=None
4849
Transformers and functions used to extract features from selected intervals.
4950
If None, defaults to [mean, median, min, max, std, 25% quantile, 75% quantile]
51+
series_transformers : TransformerMixin, list, tuple, or None, default=None
52+
The transformers to apply to the series before extracting intervals.
53+
If None, use the series as is.
54+
55+
A list or tuple of transformers will extract intervals from
56+
all transformations and concatenate the output. Including None in the list or tuple
57+
will use the series as is for interval extraction.
5058
dilation : int, list or None, default=None
5159
Add dilation to extracted intervals. No dilation is added if None or 1. If a
5260
list of ints, a random dilation value is selected from the list for each
@@ -110,6 +118,7 @@ def __init__(
110118
min_interval_length=3,
111119
max_interval_length=np.inf,
112120
features=None,
121+
series_transformers=None,
113122
dilation=None,
114123
estimator=None,
115124
n_jobs=1,
@@ -120,6 +129,7 @@ def __init__(
120129
self.min_interval_length = min_interval_length
121130
self.max_interval_length = max_interval_length
122131
self.features = features
132+
self.series_transformers = series_transformers
123133
self.dilation = dilation
124134
self.estimator = estimator
125135
self.random_state = random_state
@@ -159,17 +169,42 @@ def fit(self, X: Union[np.ndarray, List[np.ndarray]], y: np.ndarray) -> object:
159169
return self
160170

161171
self._n_jobs = check_n_jobs(self.n_jobs)
172+
rng = check_random_state(self.random_state)
162173

163-
self._transformer = RandomIntervalTransformer(
164-
n_intervals=self.n_intervals,
165-
min_interval_length=self.min_interval_length,
166-
max_interval_length=self.max_interval_length,
167-
features=self.features,
168-
dilation=self.dilation,
169-
random_state=self.random_state,
170-
n_jobs=self._n_jobs,
171-
parallel_backend=self.parallel_backend,
172-
)
174+
if isinstance(self.series_transformers, (list, tuple)):
175+
self._series_transformers = [
176+
None if st is None else _clone_estimator(st, random_state=rng)
177+
for st in self.series_transformers
178+
]
179+
else:
180+
self._series_transformers = [
181+
None
182+
if self.series_transformers is None
183+
else _clone_estimator(self.series_transformers, random_state=rng)
184+
]
185+
186+
X_t = np.empty((X.shape[0], 0))
187+
self._transformers = []
188+
for st in self._series_transformers:
189+
if st is not None:
190+
s = st.fit_transform(X, y)
191+
else:
192+
s = X
193+
194+
ct = RandomIntervalTransformer(
195+
n_intervals=self.n_intervals,
196+
min_interval_length=self.min_interval_length,
197+
max_interval_length=self.max_interval_length,
198+
features=self.features,
199+
dilation=self.dilation,
200+
n_jobs=self._n_jobs,
201+
parallel_backend=self.parallel_backend,
202+
)
203+
_set_random_states(ct, rng)
204+
self._transformers.append(ct)
205+
t = ct.fit_transform(s, y)
206+
207+
X_t = np.hstack((X_t, t))
173208

174209
self._estimator = _clone_estimator(
175210
RandomForestClassifier(n_estimators=200)
@@ -182,7 +217,6 @@ def fit(self, X: Union[np.ndarray, List[np.ndarray]], y: np.ndarray) -> object:
182217
if m is not None:
183218
self._estimator.n_jobs = self._n_jobs
184219

185-
X_t = self._transformer.fit_transform(X, y)
186220
self._estimator.fit(X_t, y)
187221

188222
return self
@@ -209,7 +243,17 @@ def predict(self, X: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
209243
X = self._validate_data(X=X, reset=False, ensure_min_series_length=3)
210244
X = self._convert_X(X)
211245

212-
return self._estimator.predict(self._transformer.transform(X))
246+
X_t = np.empty((X.shape[0], 0))
247+
for i, st in enumerate(self._series_transformers):
248+
if st is not None:
249+
s = st.transform(X)
250+
else:
251+
s = X
252+
253+
t = self._transformers[i].transform(s)
254+
X_t = np.hstack((X_t, t))
255+
256+
return self._estimator.predict(X_t)
213257

214258
def predict_proba(self, X: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
215259
"""Predicts labels probabilities for sequences in X.
@@ -233,12 +277,22 @@ def predict_proba(self, X: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
233277
X = self._validate_data(X=X, reset=False, ensure_min_series_length=3)
234278
X = self._convert_X(X)
235279

280+
X_t = np.empty((X.shape[0], 0))
281+
for i, st in enumerate(self._series_transformers):
282+
if st is not None:
283+
s = st.transform(X)
284+
else:
285+
s = X
286+
287+
t = self._transformers[i].transform(s)
288+
X_t = np.hstack((X_t, t))
289+
236290
m = getattr(self._estimator, "predict_proba", None)
237291
if callable(m):
238-
return self._estimator.predict_proba(self._transformer.transform(X))
292+
return self._estimator.predict_proba(X_t)
239293
else:
240294
dists = np.zeros((X.shape[0], self.n_classes_))
241-
preds = self._estimator.predict(self._transformer.transform(X))
295+
preds = self._estimator.predict(X_t)
242296
for i in range(0, X.shape[0]):
243297
dists[i, self.class_dictionary_[preds[i]]] = 1
244298
return dists
@@ -290,6 +344,13 @@ class RandomIntervalRegressor(RegressorMixin, BaseTimeSeriesEstimator):
290344
of said transformers and functions, default=None
291345
Transformers and functions used to extract features from selected intervals.
292346
If None, defaults to [mean, median, min, max, std, 25% quantile, 75% quantile]
347+
series_transformers : TransformerMixin, list, tuple, or None, default=None
348+
The transformers to apply to the series before extracting intervals.
349+
If None, use the series as is.
350+
351+
A list or tuple of transformers will extract intervals from
352+
all transformations and concatenate the output. Including None in the list or tuple
353+
will use the series as is for interval extraction.
293354
dilation : int, list or None, default=None
294355
Add dilation to extracted intervals. No dilation is added if None or 1. If a
295356
list of ints, a random dilation value is selected from the list for each
@@ -338,8 +399,8 @@ class RandomIntervalRegressor(RegressorMixin, BaseTimeSeriesEstimator):
338399
>>> reg.fit(X, y)
339400
RandomIntervalRegressor(...)
340401
>>> reg.predict(X)
341-
array([0.46836751, 1.32023847, 1.13355919, 0.63979608, 0.58309353,
342-
1.18197903, 0.57859747, 1.0772939 ])
402+
array([0.44924979, 1.31424037, 1.11951504, 0.63780969, 0.58123516,
403+
1.17135463, 0.56450198, 1.10128837])
343404
"""
344405

345406
def __init__(
@@ -348,6 +409,7 @@ def __init__(
348409
min_interval_length=3,
349410
max_interval_length=np.inf,
350411
features=None,
412+
series_transformers=None,
351413
dilation=None,
352414
estimator=None,
353415
n_jobs=1,
@@ -358,6 +420,7 @@ def __init__(
358420
self.min_interval_length = min_interval_length
359421
self.max_interval_length = max_interval_length
360422
self.features = features
423+
self.series_transformers = series_transformers
361424
self.dilation = dilation
362425
self.estimator = estimator
363426
self.random_state = random_state
@@ -389,17 +452,42 @@ def fit(self, X: Union[np.ndarray, List[np.ndarray]], y: np.ndarray) -> object:
389452
self.n_instances_, self.n_channels_, self.n_timepoints_ = X.shape
390453

391454
self._n_jobs = check_n_jobs(self.n_jobs)
455+
rng = check_random_state(self.random_state)
392456

393-
self._transformer = RandomIntervalTransformer(
394-
n_intervals=self.n_intervals,
395-
min_interval_length=self.min_interval_length,
396-
max_interval_length=self.max_interval_length,
397-
features=self.features,
398-
dilation=self.dilation,
399-
random_state=self.random_state,
400-
n_jobs=self._n_jobs,
401-
parallel_backend=self.parallel_backend,
402-
)
457+
if isinstance(self.series_transformers, (list, tuple)):
458+
self._series_transformers = [
459+
None if st is None else _clone_estimator(st, random_state=rng)
460+
for st in self.series_transformers
461+
]
462+
else:
463+
self._series_transformers = [
464+
None
465+
if self.series_transformers is None
466+
else _clone_estimator(self.series_transformers, random_state=rng)
467+
]
468+
469+
X_t = np.empty((X.shape[0], 0))
470+
self._transformers = []
471+
for st in self._series_transformers:
472+
if st is not None:
473+
s = st.fit_transform(X, y)
474+
else:
475+
s = X
476+
477+
ct = RandomIntervalTransformer(
478+
n_intervals=self.n_intervals,
479+
min_interval_length=self.min_interval_length,
480+
max_interval_length=self.max_interval_length,
481+
features=self.features,
482+
dilation=self.dilation,
483+
n_jobs=self._n_jobs,
484+
parallel_backend=self.parallel_backend,
485+
)
486+
_set_random_states(ct, rng)
487+
self._transformers.append(ct)
488+
t = ct.fit_transform(s, y)
489+
490+
X_t = np.hstack((X_t, t))
403491

404492
self._estimator = _clone_estimator(
405493
RandomForestRegressor(n_estimators=200)
@@ -412,7 +500,6 @@ def fit(self, X: Union[np.ndarray, List[np.ndarray]], y: np.ndarray) -> object:
412500
if m is not None:
413501
self._estimator.n_jobs = self._n_jobs
414502

415-
X_t = self._transformer.fit_transform(X, y)
416503
self._estimator.fit(X_t, y)
417504

418505
return self
@@ -435,7 +522,17 @@ def predict(self, X: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
435522
X = self._validate_data(X=X, reset=False, ensure_min_series_length=3)
436523
X = self._convert_X(X)
437524

438-
return self._estimator.predict(self._transformer.transform(X))
525+
X_t = np.empty((X.shape[0], 0))
526+
for i, st in enumerate(self._series_transformers):
527+
if st is not None:
528+
s = st.transform(X)
529+
else:
530+
s = X
531+
532+
t = self._transformers[i].transform(s)
533+
X_t = np.hstack((X_t, t))
534+
535+
return self._estimator.predict(X_t)
439536

440537
@classmethod
441538
def get_test_params(

tsml/interval_based/tests/test_interval_pipelines.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Tests for the interval pipeline classes."""
22

33
from tsml.interval_based import RandomIntervalClassifier
4+
from tsml.transformations import FunctionTransformer
5+
from tsml.utils.numba_functions.general import first_order_differences_3d
46
from tsml.utils.testing import generate_3d_test_data
57

68

@@ -16,4 +18,21 @@ def interval_func(X):
1618
)
1719
est.fit(X, y)
1820

19-
assert est._transformer._n_intervals == 2
21+
assert est._transformers[0]._n_intervals == 2
22+
23+
24+
def test_random_interval_series_transform_callable():
25+
"""Test RandomIntervalClassifier with a series transformer."""
26+
X, y = generate_3d_test_data()
27+
28+
est = RandomIntervalClassifier(
29+
n_intervals=2,
30+
series_transformers=[
31+
None,
32+
FunctionTransformer(func=first_order_differences_3d, validate=False),
33+
],
34+
)
35+
est.fit(X, y)
36+
est.predict_proba(X)
37+
38+
assert len(est._transformers) == 2

0 commit comments

Comments
 (0)