Skip to content

Commit 518f0a1

Browse files
Grail classifier (#25)
* fixes * commed mypy * c22 * transform no longer requires y * fixes * fixes * fixes * grail * grail * comment
1 parent 138b6bd commit 518f0a1

File tree

2 files changed

+266
-0
lines changed

2 files changed

+266
-0
lines changed

tsml/distance_based/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
"""Distance-based estimators."""
22

33
__all__ = [
4+
"GRAILClassifier",
45
"ProximityForestClassifier",
56
"MPDistClassifier",
67
]
78

9+
from tsml.distance_based._grail import GRAILClassifier
810
from tsml.distance_based._mpdist import MPDistClassifier
911
from tsml.distance_based._pf import ProximityForestClassifier

tsml/distance_based/_grail.py

Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
"""GRAIL classifier.
2+
3+
See the original implementation here:
4+
https://github.com/TheDatumOrg/grail-python
5+
"""
6+
7+
import os
8+
import sys
9+
10+
import numpy as np
11+
from sklearn.base import ClassifierMixin
12+
from sklearn.model_selection import GridSearchCV
13+
from sklearn.svm import SVC
14+
from sklearn.utils.multiclass import check_classification_targets
15+
from sklearn.utils.validation import check_is_fitted
16+
17+
from tsml.base import BaseTimeSeriesEstimator
18+
from tsml.utils.validation import _check_optional_dependency
19+
20+
21+
class GRAILClassifier(ClassifierMixin, BaseTimeSeriesEstimator):
22+
"""
23+
GRAIL classifier.
24+
25+
Examples
26+
--------
27+
>>> from tsml.datasets import load_minimal_chinatown
28+
>>> from tsml.distance_based import GRAILClassifier
29+
>>> X, y = load_minimal_chinatown()
30+
>>> clf = GRAILClassifier()
31+
>>> clf.fit(X, y)
32+
GRAILClassifier(...)
33+
>>> preds = clf.predict(X)
34+
"""
35+
36+
def __init__(self, classifier="svm"):
37+
self.classifier = classifier
38+
39+
_check_optional_dependency("grailts", "GRAIL", self)
40+
41+
super(GRAILClassifier, self).__init__()
42+
43+
def fit(self, X, y):
44+
"""Fit the estimator to training data.
45+
46+
Parameters
47+
----------
48+
X : 2D np.ndarray of shape (n_instances, n_timepoints)
49+
The training data.
50+
y : 1D np.ndarray of shape (n_instances)
51+
The class labels for fitting, indices correspond to instance indices in X
52+
53+
Returns
54+
-------
55+
self :
56+
Reference to self.
57+
"""
58+
X, y = self._validate_data(X=X, y=y, ensure_min_samples=2)
59+
X = self._convert_X(X)
60+
61+
check_classification_targets(y)
62+
63+
self.n_instances_, self.n_timepoints_ = X.shape
64+
self.classes_, class_count = np.unique(y, return_counts=True)
65+
self.n_classes_ = self.classes_.shape[0]
66+
self.class_dictionary_ = {}
67+
for index, class_val in enumerate(self.classes_):
68+
self.class_dictionary_[class_val] = index
69+
70+
if self.n_classes_ == 1:
71+
return self
72+
73+
self._d = int(self.n_instances_ * 0.4)
74+
if self._d > 100:
75+
self._d = 100
76+
elif self._d < 3:
77+
self._d = 3
78+
79+
(
80+
Xt,
81+
self._Dictionary,
82+
self._gamma,
83+
self._eigenvecMatrix,
84+
self._inVa,
85+
) = self._modified_GRAIL_rep_fit(X, self._d)
86+
87+
if self.classifier == "svm":
88+
self._clf = GridSearchCV(
89+
SVC(kernel="linear", probability=True),
90+
param_grid={"C": [i**2 for i in np.arange(-10, 20, 0.11)]},
91+
cv=min(min(class_count), 5),
92+
)
93+
self._clf.fit(Xt, y)
94+
elif self.classifier == "knn":
95+
self._train_Xt = Xt
96+
self._train_y = y
97+
else:
98+
raise ValueError("classifier must be 'svm' or 'knn'")
99+
100+
return self
101+
102+
def predict(self, X):
103+
"""Predicts labels for sequences in X.
104+
105+
Parameters
106+
----------
107+
X : 2D np.array of shape (n_instances, n_timepoints)
108+
The testing data.
109+
110+
Returns
111+
-------
112+
y : array-like of shape (n_instances)
113+
Predicted class labels.
114+
"""
115+
return np.array(
116+
[self.classes_[int(np.argmax(prob))] for prob in self.predict_proba(X)]
117+
)
118+
119+
def predict_proba(self, X):
120+
"""Predicts labels probabilities for sequences in X.
121+
122+
Parameters
123+
----------
124+
X : 2D np.array of shape (n_instances, n_timepoints)
125+
The testing data.
126+
127+
Returns
128+
-------
129+
y : array-like of shape (n_instances, n_classes_)
130+
Predicted probabilities using the ordering in classes_.
131+
"""
132+
check_is_fitted(self)
133+
134+
# treat case of single class seen in fit
135+
if self.n_classes_ == 1:
136+
return np.repeat([[1]], X.shape[0], axis=0)
137+
138+
X = self._validate_data(X=X, reset=False)
139+
X = self._convert_X(X)
140+
141+
Xt = self._modified_GRAIL_rep_predict(
142+
X, self._d, self._Dictionary, self._gamma, self._eigenvecMatrix, self._inVa
143+
)
144+
145+
if self.classifier == "svm":
146+
probas = self._clf.predict_proba(Xt)
147+
elif self.classifier == "knn":
148+
from GRAIL.kNN import kNN
149+
150+
k = 5
151+
neighbors, _, _ = kNN(
152+
self._train_Xt,
153+
Xt,
154+
method="ED",
155+
k=k,
156+
representation=None,
157+
pq_method="opq",
158+
)
159+
160+
probas = np.zeros((len(X), self.n_classes_))
161+
for i, case in enumerate(neighbors):
162+
for j in range(k):
163+
probas[i, self.class_dictionary_[self._train_y[case[j]]]] += 1
164+
probas[i] /= k
165+
else:
166+
raise ValueError("classifier must be 'svm' or 'knn'")
167+
168+
return probas
169+
170+
@staticmethod
171+
def _modified_GRAIL_rep_fit(
172+
X,
173+
d,
174+
r=20,
175+
GV=None,
176+
fourier_coeff=-1,
177+
e=-1,
178+
eigenvecMatrix=None,
179+
inVa=None,
180+
gamma=None,
181+
initialization_method="k-shape++",
182+
):
183+
"""Fit GRAIL representation.
184+
185+
A modified version of the GRAIL_rep function from GRAIL.
186+
"""
187+
from GRAIL import exceptions
188+
from GRAIL.GRAIL_core import CheckNaNInfComplex, gamma_select
189+
from GRAIL.kshape import kshape_with_centroid_initialize, matlab_kshape
190+
from GRAIL.SINK import SINK
191+
192+
old_stdout = sys.stdout
193+
sys.stdout = open(os.devnull, "w")
194+
195+
n = X.shape[0]
196+
if initialization_method == "partition":
197+
[_, Dictionary] = matlab_kshape(X, d)
198+
elif initialization_method == "centroid_uniform":
199+
[_, Dictionary] = kshape_with_centroid_initialize(X, d, is_pp=False)
200+
elif initialization_method == "k-shape++":
201+
[_, Dictionary] = kshape_with_centroid_initialize(X, d, is_pp=True)
202+
else:
203+
raise exceptions.InitializationMethodNotFound
204+
205+
sys.stdout = old_stdout
206+
207+
if gamma is None:
208+
if GV is None:
209+
GV = [*range(1, 21)]
210+
211+
[_, gamma] = gamma_select(Dictionary, GV, r)
212+
213+
E = np.zeros((n, d))
214+
for i in range(n):
215+
for j in range(d):
216+
E[i, j] = SINK(X[i, :], Dictionary[j, :], gamma, fourier_coeff, e)
217+
218+
if eigenvecMatrix is None and inVa is None:
219+
W = np.zeros((d, d))
220+
for i in range(d):
221+
for j in range(d):
222+
W[i, j] = SINK(
223+
Dictionary[i, :], Dictionary[j, :], gamma, fourier_coeff, e
224+
)
225+
226+
[eigenvalvector, eigenvecMatrix] = np.linalg.eigh(W)
227+
inVa = np.diag(np.power(eigenvalvector, -0.5))
228+
229+
Zexact = E @ eigenvecMatrix @ inVa
230+
Zexact = CheckNaNInfComplex(Zexact)
231+
Zexact = np.real(Zexact)
232+
233+
return Zexact, Dictionary, gamma, eigenvecMatrix, inVa
234+
235+
@staticmethod
236+
def _modified_GRAIL_rep_predict(
237+
X, d, Dictionary, gamma, eigenvecMatrix, inVa, f=0.99, fourier_coeff=-1, e=-1
238+
):
239+
"""Predict GRAIL representation.
240+
241+
A modified version of the GRAIL_rep function from GRAIL.
242+
"""
243+
from GRAIL.GRAIL_core import CheckNaNInfComplex
244+
from GRAIL.SINK import SINK
245+
246+
n = X.shape[0]
247+
E = np.zeros((n, d))
248+
for i in range(n):
249+
for j in range(d):
250+
E[i, j] = SINK(X[i, :], Dictionary[j, :], gamma, fourier_coeff, e)
251+
252+
Zexact = E @ eigenvecMatrix @ inVa
253+
Zexact = CheckNaNInfComplex(Zexact)
254+
Zexact = np.real(Zexact)
255+
256+
return Zexact
257+
258+
def _more_tags(self) -> dict:
259+
return {
260+
"X_types": ["2darray"],
261+
"optional_dependency": True,
262+
"univariate_only": True,
263+
"non_deterministic": True,
264+
}

0 commit comments

Comments
 (0)