Skip to content

Commit b10caf4

Browse files
[MNT] Fixes, Py12 and release 0.4.0 (#46)
* fixes * commed mypy * c22 * transform no longer requires y * fixes * fixes * fixes * v0.3.0 * yml * py 12 and fixes * readme * pyfftw bound * fda bound again * mrsqm * skip pyfftw docs * fix
1 parent 00d9f8b commit b10caf4

File tree

21 files changed

+171
-169
lines changed

21 files changed

+171
-169
lines changed

.github/workflows/precommit_autoupdate.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,4 @@ jobs:
3434
branch: pre-commit-hooks-update
3535
title: "[MNT] Automated `pre-commit` hook update"
3636
body: "Automated weekly update to `.pre-commit-config.yaml` hook versions."
37-
labels: maintenance, full pre-commit, no changelog
37+
labels: maintenance, full pre-commit

.github/workflows/release.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ jobs:
5252
strategy:
5353
matrix:
5454
os: [ ubuntu-22.04, macOS-12, windows-2022 ]
55-
python-version: [ "3.8", "3.9", "3.10", "3.11" ]
55+
python-version: [ "3.8", "3.9", "3.10", "3.11", "3.12" ]
5656

5757
steps:
5858
- uses: actions/checkout@v4

.github/workflows/tests.yml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jobs:
2828
fail-fast: false
2929
matrix:
3030
os: [ ubuntu-22.04, macOS-12, windows-2022 ]
31-
python-version: [ "3.8", "3.9", "3.10", "3.11" ]
31+
python-version: [ "3.8", "3.9", "3.10", "3.11", "3.12" ]
3232

3333
steps:
3434
- uses: actions/checkout@v4
@@ -38,7 +38,11 @@ jobs:
3838
python-version: ${{ matrix.python-version }}
3939

4040
- name: Install
41-
run: python -m pip install .[dev,all_extras,unstable_extras]
41+
uses: nick-fields/retry@v2
42+
with:
43+
timeout_minutes: 30
44+
max_attempts: 3
45+
command: python -m pip install .[dev,all_extras,unstable_extras]
4246

4347
- name: Tests
4448
run: python -m pytest

.pre-commit-config.yaml

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ repos:
22
- repo: https://github.com/pre-commit/pre-commit-hooks
33
rev: v4.5.0
44
hooks:
5+
- id: check-ast
56
- id: check-added-large-files
67
args: ["--maxkb=10000"]
7-
- id: check-ast
88
- id: check-case-conflict
99
- id: check-docstring-first
1010
- id: check-merge-conflict
@@ -21,26 +21,44 @@ repos:
2121
- id: requirements-txt-fixer
2222
- id: trailing-whitespace
2323

24+
- repo: https://github.com/lk16/detect-missing-init
25+
rev: v0.1.6
26+
hooks:
27+
- id: detect-missing-init
28+
args: [ "--create", "--python-folders", "tsml" ]
29+
30+
- repo: https://github.com/astral-sh/ruff-pre-commit
31+
rev: v0.4.1
32+
hooks:
33+
- id: ruff
34+
args: [ "--fix" ]
35+
36+
- repo: https://github.com/asottile/pyupgrade
37+
rev: v3.15.2
38+
hooks:
39+
- id: pyupgrade
40+
args: [ "--py38-plus" ]
41+
2442
- repo: https://github.com/pycqa/isort
2543
rev: 5.13.2
2644
hooks:
2745
- id: isort
2846
name: isort (python)
2947
args: [ "--profile=black", "--multi-line=3" ]
3048

31-
- repo: https://github.com/psf/black
32-
rev: 23.12.1
33-
hooks:
34-
- id: black
35-
language_version: python3
36-
3749
- repo: https://github.com/pycqa/flake8
3850
rev: 7.0.0
3951
hooks:
4052
- id: flake8
4153
additional_dependencies: [ flake8-bugbear, flake8-print, Flake8-pyproject ]
4254
args: [ "--max-line-length=88", "--extend-ignore=E203" ]
4355

56+
- repo: https://github.com/psf/black
57+
rev: 23.12.1
58+
hooks:
59+
- id: black
60+
language_version: python3
61+
4462
- repo: https://github.com/nbQA-dev/nbQA
4563
rev: 1.7.1
4664
hooks:
@@ -54,20 +72,6 @@ repos:
5472
additional_dependencies: [ flake8 ]
5573
args: [ "--nbqa-dont-skip-bad-cells", "--extend-ignore=E402,E203", "--max-line-length=88" ]
5674

57-
- repo: https://github.com/pycqa/pydocstyle
58-
rev: 6.3.0
59-
hooks:
60-
- id: pydocstyle
61-
args: ["--convention=numpy"]
62-
additional_dependencies: [ toml, tomli ]
63-
64-
# - repo: https://github.com/pre-commit/mirrors-mypy
65-
# rev: v1.8.0
66-
# hooks:
67-
# - id: mypy
68-
# files: tsml/
69-
# additional_dependencies: [ pytest ]
70-
7175
- repo: https://github.com/mgedmin/check-manifest
7276
rev: "0.49"
7377
hooks:

README.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313

1414
A toolkit for time series machine learning algorithms.
1515

16-
The current release of `tsml` is v0.3.0.
16+
Please see [`tsml_eval`](https://github.com/time-series-machine-learning/tsml-eval) and
17+
[`aeon`](https://github.com/aeon-toolkit/aeon) for more developed packages. This package
18+
is more of a sandbox for testing out new ideas and algorithms. It may contain some
19+
algorithms and implementations that are not available in the other toolkits.
20+
21+
The current release of `tsml` is v0.4.0.
1722

1823
## Installation
1924

pyproject.toml

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "tsml"
7-
version = "0.3.0"
7+
version = "0.4.0"
88
description = "A toolkit for time series machine learning algorithms."
99
authors = [
1010
{name = "Matthew Middlehurst", email = "m.b.middlehurst@soton.ac.uk"},
@@ -13,7 +13,7 @@ maintainers = [
1313
{name = "Matthew Middlehurst", email = "m.b.middlehurst@soton.ac.uk"},
1414
]
1515
readme = "README.md"
16-
requires-python = ">=3.8,<3.12"
16+
requires-python = ">=3.8,<3.13"
1717
keywords = [
1818
"data-science",
1919
"machine-learning",
@@ -36,6 +36,7 @@ classifiers = [
3636
"Programming Language :: Python :: 3.9",
3737
"Programming Language :: Python :: 3.10",
3838
"Programming Language :: Python :: 3.11",
39+
"Programming Language :: Python :: 3.12",
3940
]
4041
dependencies = [
4142
"numba>=0.55.0",
@@ -47,25 +48,26 @@ dependencies = [
4748

4849
[project.optional-dependencies]
4950
all_extras = [
50-
"pyfftw>=0.12.0",
51-
"statsmodels>=0.12.1",
52-
"wildboar>=1.1.0",
51+
"grailts",
5352
"scikit-fda>=0.7.0",
53+
"statsmodels>=0.12.1",
5454
"stumpy>=1.6.0",
55-
"grailts",
55+
"wildboar>=1.1.0",
56+
5657
# temp
57-
"fdasrsf<=2.5.2", # currently above this breaks on some OS
58+
"fdasrsf<=2.5.2", # temporary, currently above this breaks on some OS
5859
]
5960
unstable_extras = [
60-
"mrsqm>=0.0.1 ; platform_system == 'Darwin'", # requires gcc and fftw to be installed for Windows and some other OS (see http://www.fftw.org/index.html)
61+
"mrsqm>=0.0.1 ; platform_system == 'Darwin' and python_version < '3.12'", # requires gcc and fftw to be installed for Windows and some other OS (see http://www.fftw.org/index.html)
6162
"pycatch22<=0.4.3", # Known to fail installation on some setups
63+
"pyfftw>=0.12.0; python_version < '3.12'", # requires fftw to be installed for Windows and some other OS (see http://www.fftw.org/index.html)
6264
]
6365
dev = [
6466
"pre-commit",
6567
"pytest",
6668
"pytest-randomly",
6769
"pytest-timeout",
68-
"pytest-xdist",
70+
"pytest-xdist[psutil]",
6971
"pytest-cov",
7072
"pytest-rerunfailures",
7173
]
@@ -74,7 +76,7 @@ binder = [
7476
"jupyterlab",
7577
]
7678
docs = [
77-
"sphinx<8.0.0",
79+
"sphinx<7.3.0",
7880
"sphinx-design",
7981
"sphinx-version-warning",
8082
"sphinx_issues",
@@ -106,13 +108,21 @@ ignore = [
106108
"local/**",
107109
]
108110

111+
[tool.ruff.lint]
112+
select = ["D"]
113+
114+
[tool.ruff.lint.pydocstyle]
115+
convention = "numpy"
116+
109117
[tool.pytest.ini_options]
110118
testpaths = "tsml"
111119
addopts = '''
120+
--doctest-modules
112121
--durations 20
113122
--timeout 600
114123
--showlocals
115-
--doctest-modules
116-
--numprocesses auto
124+
--numprocesses logical
125+
--dist worksteal
117126
--reruns 2
127+
--only-rerun "crashed while running"
118128
'''

tsml/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""tsml."""
22

3-
__version__ = "0.3.0"
3+
__version__ = "0.4.0"

tsml/distance_based/_grail.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import os
88
import sys
9+
import warnings
910

1011
import numpy as np
1112
from sklearn.base import ClassifierMixin
@@ -38,7 +39,7 @@ def __init__(self, classifier="svm"):
3839

3940
_check_optional_dependency("grailts", "GRAIL", self)
4041

41-
super(GRAILClassifier, self).__init__()
42+
super().__init__()
4243

4344
def fit(self, X, y):
4445
"""Fit the estimator to training data.
@@ -85,11 +86,19 @@ def fit(self, X, y):
8586
) = self._modified_GRAIL_rep_fit(X, self._d)
8687

8788
if self.classifier == "svm":
88-
self._clf = GridSearchCV(
89-
SVC(kernel="linear", probability=True),
90-
param_grid={"C": [i**2 for i in np.arange(-10, 20, 0.11)]},
91-
cv=min(min(class_count), 5),
92-
)
89+
cv = min(min(class_count), 5)
90+
if cv == 1:
91+
warnings.warn(
92+
"Only one class was found in y, so no cross-validation.",
93+
stacklevel=2,
94+
)
95+
self._clf = SVC(kernel="linear", probability=True, C=1)
96+
else:
97+
self._clf = GridSearchCV(
98+
SVC(kernel="linear", probability=True),
99+
param_grid={"C": [i**2 for i in np.arange(-10, 20, 0.11)]},
100+
cv=min(min(class_count), 5),
101+
)
93102
self._clf.fit(Xt, y)
94103
elif self.classifier == "knn":
95104
self._train_Xt = Xt

tsml/hybrid/_rist.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,10 @@ class RISTClassifier(ClassifierMixin, BaseTimeSeriesEstimator):
112112
>>> from tsml.hybrid import RISTClassifier
113113
>>> from tsml.utils.testing import generate_3d_test_data
114114
>>> X, y = generate_3d_test_data(n_samples=8, series_length=10, random_state=0)
115-
>>> clf = RISTClassifier(random_state=0)
116-
>>> clf.fit(X, y)
115+
>>> clf = RISTClassifier(random_state=0) # doctest: +SKIP
116+
>>> clf.fit(X, y) # doctest: +SKIP
117117
RISTClassifier(...)
118-
>>> clf.predict(X)
118+
>>> clf.predict(X) # doctest: +SKIP
119119
array([0, 1, 1, 0, 0, 1, 0, 1])
120120
"""
121121

@@ -144,7 +144,7 @@ def __init__(
144144
if use_pyfftw:
145145
_check_optional_dependency("pyfftw", "pyfftw", self)
146146

147-
super(RISTClassifier, self).__init__()
147+
super().__init__()
148148

149149
def fit(self, X: Union[np.ndarray, List[np.ndarray]], y: np.ndarray) -> object:
150150
"""Fit the estimator to training data.
@@ -363,10 +363,10 @@ class RISTRegressor(RegressorMixin, BaseTimeSeriesEstimator):
363363
>>> from tsml.utils.testing import generate_3d_test_data
364364
>>> X, y = generate_3d_test_data(n_samples=8, series_length=10,
365365
... regression_target=True, random_state=0)
366-
>>> reg = RISTRegressor(random_state=0)
367-
>>> reg.fit(X, y)
366+
>>> reg = RISTRegressor(random_state=0) # doctest: +SKIP
367+
>>> reg.fit(X, y) # doctest: +SKIP
368368
RISTRegressor(...)
369-
>>> reg.predict(X)
369+
>>> reg.predict(X) # doctest: +SKIP
370370
array([0.31798318, 1.41426301, 1.06414747, 0.6924721 , 0.56660146,
371371
1.26538944, 0.52324808, 1.0939405 ])
372372
"""
@@ -396,7 +396,7 @@ def __init__(
396396
if use_pyfftw:
397397
_check_optional_dependency("pyfftw", "pyfftw", self)
398398

399-
super(RISTRegressor, self).__init__()
399+
super().__init__()
400400

401401
def fit(self, X: Union[np.ndarray, List[np.ndarray]], y: np.ndarray) -> object:
402402
"""Fit the estimator to training data.

tsml/interval_based/_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ class BaseIntervalForest(BaseTimeSeriesEstimator, metaclass=ABCMeta):
129129
contract_max_n_estimators : int, default=500
130130
Max number of estimators when time_limit_in_minutes is set.
131131
save_transformed_data : bool, default=False
132-
Save the data transformed in fit for use in _get_train_probs.
132+
Save the data transformed in fit.
133133
random_state : int, RandomState instance or None, default=None
134134
If `int`, random_state is the seed used by the random number generator;
135135
If `RandomState` instance, random_state is the random number generator;
@@ -211,7 +211,7 @@ def __init__(
211211
self.n_jobs = n_jobs
212212
self.parallel_backend = parallel_backend
213213

214-
super(BaseIntervalForest, self).__init__()
214+
super().__init__()
215215

216216
# if subsampling attributes, an interval_features transformer must contain a
217217
# parameter name from transformer_feature_selection and an attribute name

0 commit comments

Comments
 (0)