Skip to content

Commit 3cfb4e2

Browse files
V0.2.0 (#15)
* non determinstic and dep defaults * collection func * example output * example output * examples and non-deterministic tests * test file rename * test file rename * test deterministic * interval bugfix
1 parent 555b045 commit 3cfb4e2

File tree

11 files changed

+73
-59
lines changed

11 files changed

+73
-59
lines changed

tsml/feature_based/_catch22.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class Catch22Classifier(ClassifierMixin, BaseTimeSeriesEstimator):
5252
while to process for large values.
5353
replace_nans : bool, optional, default=True
5454
Replace NaN or inf values from the Catch22 transform with 0.
55-
use_pycatch22 : bool, optional, default=False
55+
use_pycatch22 : bool, optional, default=True
5656
Wraps the C based pycatch22 implementation for tsml.
5757
(https://github.com/DynamicsAndNeuralSystems/pycatch22). This requires the
5858
``pycatch22`` package to be installed if True.
@@ -125,7 +125,7 @@ def __init__(
125125
catch24=True,
126126
outlier_norm=False,
127127
replace_nans=True,
128-
use_pycatch22=False,
128+
use_pycatch22=True,
129129
estimator=None,
130130
random_state=None,
131131
n_jobs=1,
@@ -385,8 +385,8 @@ class Catch22Regressor(RegressorMixin, BaseTimeSeriesEstimator):
385385
>>> reg.fit(X, y)
386386
Catch22Regressor(...)
387387
>>> reg.predict(X)
388-
array([0.42955043, 1.31287811, 1.03757454, 0.68456511, 0.61327938,
389-
1.2048977 , 0.56586089, 1.1263876 ])
388+
array([0.44505834, 1.28376726, 1.09799075, 0.64209462, 0.59410108,
389+
1.1746538 , 0.70590611, 1.13361721])
390390
"""
391391

392392
def __init__(

tsml/feature_based/_fpca.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -74,17 +74,6 @@ class FPCAClassifier(ClassifierMixin, BaseTimeSeriesEstimator):
7474
--------
7575
FPCATransformer
7676
FPCARegressor
77-
78-
Examples
79-
--------
80-
>>> from tsml.feature_based import FPCAClassifier
81-
>>> from tsml.utils.testing import generate_3d_test_data
82-
>>> X, y = generate_3d_test_data(n_samples=8, series_length=10, random_state=0)
83-
>>> clf = FPCAClassifier(random_state=0, n_components=6)
84-
>>> clf.fit(X, y)
85-
FPCAClassifier(...)
86-
>>> clf.predict(X)
87-
array([0, 1, 1, 0, 0, 1, 0, 1])
8877
"""
8978

9079
def __init__(
@@ -302,19 +291,6 @@ class FPCARegressor(RegressorMixin, BaseTimeSeriesEstimator):
302291
--------
303292
FPCATransformer
304293
FPCAClassifier
305-
306-
Examples
307-
--------
308-
>>> from tsml.feature_based import FPCARegressor
309-
>>> from tsml.utils.testing import generate_3d_test_data
310-
>>> X, y = generate_3d_test_data(n_samples=8, series_length=10,
311-
... regression_target=True, random_state=0)
312-
>>> reg = FPCARegressor(random_state=0, n_components=6)
313-
>>> reg.fit(X, y)
314-
FPCARegressor(...)
315-
>>> reg.predict(X)
316-
array([0.31804196, 1.4151935 , 1.06572351, 0.68621331, 0.56749254,
317-
1.26541066, 0.52730157, 1.09266818])
318294
"""
319295

320296
def __init__(

tsml/hybrid/_rist.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,13 @@ class RISTClassifier(ClassifierMixin, BaseTimeSeriesEstimator):
6767
A list or tuple of transformers will extract intervals from
6868
all transformations concatenate the output. Including None in the list or tuple
6969
will use the series as is for interval extraction.
70+
use_pycatch22 : bool, optional, default=True
71+
Wraps the C based pycatch22 implementation for aeon.
72+
(https://github.com/DynamicsAndNeuralSystems/pycatch22). This requires the
73+
``pycatch22`` package to be installed if True.
74+
use_pyfftw : bool, default=True
75+
Whether to use the pyfftw library for FFT calculations. Requires the pyfftw
76+
package to be installed.
7077
estimator : sklearn classifier, default=None
7178
An sklearn estimator to be built using the transformed data. Defaults to an
7279
ExtraTreesClassifier with 200 trees.
@@ -117,17 +124,17 @@ def __init__(
117124
n_intervals=None,
118125
n_shapelets=None,
119126
series_transformers="default",
120-
use_pyfftw=False,
121-
use_pycatch22=False,
127+
use_pycatch22=True,
128+
use_pyfftw=True,
122129
estimator=None,
123130
n_jobs=1,
124131
random_state=None,
125132
):
126133
self.n_intervals = n_intervals
127134
self.n_shapelets = n_shapelets
128135
self.series_transformers = series_transformers
129-
self.use_pyfftw = use_pyfftw
130136
self.use_pycatch22 = use_pycatch22
137+
self.use_pyfftw = use_pyfftw
131138
self.estimator = estimator
132139
self.random_state = random_state
133140
self.n_jobs = n_jobs
@@ -251,6 +258,7 @@ def predict_proba(self, X: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
251258
def _more_tags(self) -> dict:
252259
return {
253260
"optional_dependency": self.use_pycatch22 or self.use_pyfftw,
261+
"non_deterministic": True,
254262
}
255263

256264
@classmethod
@@ -315,6 +323,13 @@ class RISTRegressor(RegressorMixin, BaseTimeSeriesEstimator):
315323
A list or tuple of transformers will extract intervals from
316324
all transformations concatenate the output. Including None in the list or tuple
317325
will use the series as is for interval extraction.
326+
use_pycatch22 : bool, optional, default=True
327+
Wraps the C based pycatch22 implementation for aeon.
328+
(https://github.com/DynamicsAndNeuralSystems/pycatch22). This requires the
329+
``pycatch22`` package to be installed if True.
330+
use_pyfftw : bool, default=True
331+
Whether to use the pyfftw library for FFT calculations. Requires the pyfftw
332+
package to be installed.
318333
estimator : sklearn classifier, default=None
319334
An sklearn estimator to be built using the transformed data. Defaults to an
320335
ExtraTreesRegressor with 200 trees.
@@ -361,17 +376,17 @@ def __init__(
361376
n_intervals=None,
362377
n_shapelets=None,
363378
series_transformers="default",
364-
use_pyfftw=False,
365-
use_pycatch22=False,
379+
use_pycatch22=True,
380+
use_pyfftw=True,
366381
estimator=None,
367382
n_jobs=1,
368383
random_state=None,
369384
):
370385
self.n_intervals = n_intervals
371386
self.n_shapelets = n_shapelets
372387
self.series_transformers = series_transformers
373-
self.use_pyfftw = use_pyfftw
374388
self.use_pycatch22 = use_pycatch22
389+
self.use_pyfftw = use_pyfftw
375390
self.estimator = estimator
376391
self.random_state = random_state
377392
self.n_jobs = n_jobs
@@ -458,6 +473,7 @@ def predict(self, X: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
458473
def _more_tags(self) -> dict:
459474
return {
460475
"optional_dependency": self.use_pycatch22 or self.use_pyfftw,
476+
"non_deterministic": True,
461477
}
462478

463479
@classmethod

tsml/shapelet_based/_rdst.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,7 @@ class RDSTClassifier(ClassifierMixin, BaseTimeSeriesEstimator):
119119
>>> clf = RDSTClassifier(random_state=0)
120120
>>> clf.fit(X, y)
121121
RDSTClassifier(...)
122-
>>> clf.predict(X)
123-
array([0, 1, 1, 0, 0, 1, 0, 1])
122+
>>> pred = clf.predict(X)
124123
"""
125124

126125
def __init__(
@@ -292,6 +291,9 @@ def predict_proba(self, X: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
292291
dists[i, self.class_dictionary_[preds[i]]] = 1
293292
return dists
294293

294+
def _more_tags(self) -> dict:
295+
return {"non_deterministic": True}
296+
295297
@classmethod
296298
def get_test_params(
297299
cls, parameter_set: Union[str, None] = None
@@ -401,9 +403,7 @@ class RDSTRegressor(RegressorMixin, BaseTimeSeriesEstimator):
401403
>>> reg = RDSTRegressor(random_state=0)
402404
>>> reg.fit(X, y)
403405
RDSTRegressor(...)
404-
>>> reg.predict(X)
405-
array([0.31798367, 1.41426266, 1.06414746, 0.69247204, 0.56660161,
406-
1.26538904, 0.52324829, 1.09394045])
406+
>>> pred = reg.predict(X)
407407
"""
408408

409409
def __init__(
@@ -516,6 +516,9 @@ def predict(self, X: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
516516

517517
return self._estimator.predict(X_t)
518518

519+
def _more_tags(self) -> dict:
520+
return {"non_deterministic": True}
521+
519522
@classmethod
520523
def get_test_params(
521524
cls, parameter_set: Union[str, None] = None

tsml/tests/estimator_checks.py renamed to tsml/tests/test_estimator_checks.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# -*- coding: utf-8 -*-
21
"""Checks for all estimators in tsml."""
32

43
__author__ = ["MatthewMiddlehurst"]
@@ -17,7 +16,7 @@
1716
check_set_params,
1817
)
1918

20-
import tsml.tests._sklearn_checks as patched_checks
19+
import tsml.tests.test_estimators_sklearn as patched_checks
2120
import tsml.utils.testing as test_utils
2221
from tsml.base import _clone_estimator
2322
from tsml.utils._tags import _safe_tags
@@ -32,6 +31,7 @@ def _yield_all_time_series_checks(estimator):
3231
warnings.warn(
3332
f"Explicit SKIP via _skip_test tag for estimator {name}.",
3433
SkipTestWarning,
34+
stacklevel=2,
3535
)
3636
return
3737

@@ -56,7 +56,6 @@ def _yield_all_time_series_checks(estimator):
5656

5757

5858
def _yield_checks(estimator):
59-
"""sklearn"""
6059
tags = _safe_tags(estimator)
6160

6261
yield check_no_attributes_set_in_init
@@ -198,11 +197,11 @@ def check_estimator_input_types(name, estimator_orig):
198197
# test a single function with this priority
199198
def _get_func(est):
200199
if hasattr(est, "predict_proba"):
201-
return getattr(est, "predict_proba")
200+
return est.predict_proba
202201
elif hasattr(est, "predict"):
203-
return getattr(est, "predict")
202+
return est.predict
204203
elif hasattr(est, "transform"):
205-
return getattr(est, "transform")
204+
return est.transform
206205

207206
X, y = test_utils.generate_3d_test_data()
208207
first_result = None
@@ -251,31 +250,37 @@ def _get_func(est):
251250

252251
@ignore_warnings(category=FutureWarning)
253252
def check_fit3d_predict2d(name, estimator_orig):
253+
"""Todo."""
254254
pass
255255

256256

257257
@ignore_warnings(category=FutureWarning)
258258
def check_estimator_cannot_handle_multivariate_data(name, estimator_orig):
259+
"""Todo."""
259260
pass
260261

261262

262263
@ignore_warnings(category=FutureWarning)
263264
def check_estimator_handles_multivariate_data(name, estimator_orig):
265+
"""Todo."""
264266
pass
265267

266268

267269
@ignore_warnings(category=FutureWarning)
268270
def check_estimator_cannot_handle_unequal_data(name, estimator_orig):
271+
"""Todo."""
269272
pass
270273

271274

272275
@ignore_warnings(category=FutureWarning)
273276
def check_estimator_handles_unequal_data(name, estimator_orig):
277+
"""Todo."""
274278
pass
275279

276280

277281
@ignore_warnings(category=FutureWarning)
278282
def check_n_features_unequal(name, estimator_orig):
283+
"""Todo."""
279284
pass
280285

281286

tsml/tests/_sklearn_checks.py renamed to tsml/tests/test_estimators_sklearn.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# -*- coding: utf-8 -*-
21
"""Patched estimator checks originating from scikit-learn."""
32

43
__author__ = ["MatthewMiddlehurst"]
@@ -955,6 +954,7 @@ def check_supervised_y_2d(name, estimator_orig):
955954
warnings.simplefilter("always", DataConversionWarning)
956955
warnings.simplefilter("ignore", RuntimeWarning)
957956
estimator.fit(X, y[:, np.newaxis])
957+
958958
y_pred_2d = estimator.predict(X)
959959
msg = "expected 1 DataConversionWarning, got: %s" % ", ".join(
960960
[str(w_x) for w_x in w]
@@ -966,6 +966,9 @@ def check_supervised_y_2d(name, estimator_orig):
966966
" was passed when a 1d array was expected" in msg
967967
)
968968

969+
if _safe_tags(estimator_orig, key="non_deterministic"):
970+
raise SkipTest(name + " is non deterministic")
971+
969972
assert_allclose(y_pred.ravel(), y_pred_2d.ravel())
970973

971974

@@ -1077,6 +1080,10 @@ def check_regressors_int(name, regressor_orig):
10771080
pred1 = regressor_1.predict(X)
10781081
regressor_2.fit(X, y.astype(float))
10791082
pred2 = regressor_2.predict(X)
1083+
1084+
if _safe_tags(regressor_orig, key="non_deterministic"):
1085+
raise SkipTest(name + " is non deterministic")
1086+
10801087
assert_allclose(pred1, pred2, atol=1e-2, err_msg=name)
10811088

10821089

@@ -1223,6 +1230,10 @@ def _check_estimators_data_not_an_array(name, estimator_orig, X, y, obj_type):
12231230
pred1 = estimator_1.predict(X_)
12241231
estimator_2.fit(X, y)
12251232
pred2 = estimator_2.predict(X)
1233+
1234+
if _safe_tags(estimator_orig, key="non_deterministic"):
1235+
raise SkipTest(name + " is non deterministic")
1236+
12261237
assert_allclose(pred1, pred2, atol=1e-2, err_msg=name)
12271238

12281239

tsml/transformations/_catch22.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ class Catch22Transformer(TransformerMixin, BaseTimeSeriesEstimator):
8282
while to process for large values.
8383
replace_nans : bool, optional, default=False
8484
Replace NaN or inf values from the Catch22 transform with 0.
85-
use_pycatch22 : bool, optional, default=False
85+
use_pycatch22 : bool, optional, default=True
8686
Wraps the C based pycatch22 implementation for tsml.
8787
(https://github.com/DynamicsAndNeuralSystems/pycatch22). This requires the
8888
``pycatch22`` package to be installed if True.
@@ -126,12 +126,12 @@ class Catch22Transformer(TransformerMixin, BaseTimeSeriesEstimator):
126126
>>> tnf.fit(X)
127127
Catch22Transformer(...)
128128
>>> print(tnf.transform(X)[0])
129-
[1.15639532e+00 1.31700575e+00 3.00000000e+00 2.00000000e-01
130-
0.00000000e+00 1.00000000e+00 2.00000000e+00 1.10933565e-32
131-
1.96349541e+00 5.10744398e-01 2.33853577e-01 3.89048349e-01
132-
2.00000000e+00 1.00000000e+00 4.00000000e+00 1.88915916e+00
133-
1.00000000e+00 1.70859420e-01 0.00000000e+00 0.00000000e+00
134-
2.46913580e-02 0.00000000e+00]
129+
[6.27596874e-02 3.53871087e-01 4.00000000e+00 7.00000000e-01
130+
2.00000000e-01 5.66227710e-01 2.00000000e+00 3.08148791e-34
131+
1.96349541e+00 9.99913411e-01 1.39251594e+00 3.89048349e-01
132+
2.00000000e+00 1.00000000e+00 3.00000000e+00 2.04319187e+00
133+
1.00000000e+00 2.44474814e-01 0.00000000e+00 0.00000000e+00
134+
8.23045267e-03 0.00000000e+00]
135135
"""
136136

137137
def __init__(
@@ -140,7 +140,7 @@ def __init__(
140140
catch24=False,
141141
outlier_norm=False,
142142
replace_nans=False,
143-
use_pycatch22=False,
143+
use_pycatch22=True,
144144
n_jobs=1,
145145
parallel_backend=None,
146146
):

tsml/transformations/_interval_extraction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ def set_features_to_transform(self, arr, raise_error=True):
465465
else:
466466
length += 1
467467

468-
if len(arr) != length * self.n_intervals or not all(
468+
if len(arr) != length * self.n_intervals_ or not all(
469469
isinstance(b, bool) for b in arr
470470
):
471471
if raise_error:

tsml/transformations/_shapelet_transform.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1004,7 +1004,10 @@ def _check_input_params(self):
10041004
self.threshold_percentiles_ = np.asarray(self.threshold_percentiles_)
10051005

10061006
def _more_tags(self) -> dict:
1007-
return {"requires_y": True}
1007+
return {
1008+
"requires_y": True,
1009+
"non_deterministic": True,
1010+
}
10081011

10091012
@classmethod
10101013
def get_test_params(cls, parameter_set="default"):

tsml/transformations/tests/test_interval_extraction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,4 @@ def test_supervised_transformers():
5353
)
5454
X_t = sit.fit_transform(X, y)
5555

56-
assert X_t.shape == (X.shape[0], 7)
56+
assert X_t.shape == (X.shape[0], 8)

0 commit comments

Comments
 (0)