"""GRAIL classifier.

See the original implementation here:
https://github.com/TheDatumOrg/grail-python
"""
| 6 | + |
| 7 | +import os |
| 8 | +import sys |
| 9 | + |
| 10 | +import numpy as np |
| 11 | +from sklearn.base import ClassifierMixin |
| 12 | +from sklearn.model_selection import GridSearchCV |
| 13 | +from sklearn.svm import SVC |
| 14 | +from sklearn.utils.multiclass import check_classification_targets |
| 15 | +from sklearn.utils.validation import check_is_fitted |
| 16 | + |
| 17 | +from tsml.base import BaseTimeSeriesEstimator |
| 18 | +from tsml.utils.validation import _check_optional_dependency |
| 19 | + |
| 20 | + |
class GRAILClassifier(ClassifierMixin, BaseTimeSeriesEstimator):
    """GRAIL classifier.

    Each series is mapped to a fixed-length representation built from SINK
    kernel values against a dictionary of k-Shape centroids, projected through
    the eigendecomposition of the dictionary's own SINK kernel matrix. The
    transformed data is then classified with either a linear SVM or GRAIL's
    kNN search.

    Parameters
    ----------
    classifier : str, default="svm"
        Classifier applied to the GRAIL representation. ``"svm"`` fits a linear
        ``SVC`` whose ``C`` is chosen by ``GridSearchCV``; ``"knn"`` stores the
        transformed training data and votes over the 5 neighbours returned by
        ``GRAIL.kNN.kNN``. Any other value raises ``ValueError`` in ``fit``.

    Attributes
    ----------
    n_instances_ : int
        Number of training series seen in ``fit``.
    n_timepoints_ : int
        Length of each training series.
    classes_ : np.ndarray
        Unique class labels seen in ``fit``.
    n_classes_ : int
        Number of unique class labels.
    class_dictionary_ : dict
        Maps each class label to its index in ``classes_``.

    Examples
    --------
    >>> from tsml.datasets import load_minimal_chinatown
    >>> from tsml.distance_based import GRAILClassifier
    >>> X, y = load_minimal_chinatown()
    >>> clf = GRAILClassifier()
    >>> clf.fit(X, y)
    GRAILClassifier(...)
    >>> preds = clf.predict(X)
    """

    def __init__(self, classifier="svm"):
        self.classifier = classifier

        _check_optional_dependency("grailts", "GRAIL", self)

        super().__init__()

    def fit(self, X, y):
        """Fit the estimator to training data.

        Parameters
        ----------
        X : 2D np.ndarray of shape (n_instances, n_timepoints)
            The training data.
        y : 1D np.ndarray of shape (n_instances)
            The class labels for fitting, indices correspond to instance indices in X

        Returns
        -------
        self :
            Reference to self.

        Raises
        ------
        ValueError
            If ``classifier`` is neither ``"svm"`` nor ``"knn"`` (only checked
            when more than one class is present in ``y``).
        """
        X, y = self._validate_data(X=X, y=y, ensure_min_samples=2)
        X = self._convert_X(X)

        check_classification_targets(y)

        self.n_instances_, self.n_timepoints_ = X.shape
        self.classes_, class_count = np.unique(y, return_counts=True)
        self.n_classes_ = self.classes_.shape[0]
        self.class_dictionary_ = {
            class_val: index for index, class_val in enumerate(self.classes_)
        }

        # Nothing to learn with a single class; predict_proba short-circuits too.
        if self.n_classes_ == 1:
            return self

        # Reject a bad option before paying for the representation fit.
        if self.classifier not in ("svm", "knn"):
            raise ValueError("classifier must be 'svm' or 'knn'")

        # Representation size: 40% of the training set, clamped to [3, 100].
        self._d = min(max(int(self.n_instances_ * 0.4), 3), 100)

        (
            Xt,
            self._Dictionary,
            self._gamma,
            self._eigenvecMatrix,
            self._inVa,
        ) = self._modified_GRAIL_rep_fit(X, self._d)

        if self.classifier == "svm":
            # NOTE(review): when the rarest class has a single member this gives
            # cv=1, which scikit-learn's k-fold splitters reject (n_splits must
            # be at least 2) - confirm how such data should be handled.
            self._clf = GridSearchCV(
                SVC(kernel="linear", probability=True),
                param_grid={"C": [i**2 for i in np.arange(-10, 20, 0.11)]},
                cv=min(min(class_count), 5),
            )
            self._clf.fit(Xt, y)
        else:
            # "knn": the neighbour search happens at predict time, so only the
            # transformed training data and labels need to be kept.
            self._train_Xt = Xt
            self._train_y = y

        return self

    def predict(self, X):
        """Predicts labels for sequences in X.

        Parameters
        ----------
        X : 2D np.array of shape (n_instances, n_timepoints)
            The testing data.

        Returns
        -------
        y : array-like of shape (n_instances)
            Predicted class labels.
        """
        probas = self.predict_proba(X)
        # Pick, per row, the class with the highest probability.
        return self.classes_[np.argmax(probas, axis=1)]

    def predict_proba(self, X):
        """Predicts labels probabilities for sequences in X.

        Parameters
        ----------
        X : 2D np.array of shape (n_instances, n_timepoints)
            The testing data.

        Returns
        -------
        y : array-like of shape (n_instances, n_classes_)
            Predicted probabilities using the ordering in classes_.

        Raises
        ------
        ValueError
            If ``classifier`` is neither ``"svm"`` nor ``"knn"``.
        """
        check_is_fitted(self)

        # Validate first so that X.shape below is defined for list-like input.
        X = self._validate_data(X=X, reset=False)
        X = self._convert_X(X)

        # treat case of single class seen in fit
        if self.n_classes_ == 1:
            return np.repeat([[1]], X.shape[0], axis=0)

        Xt = self._modified_GRAIL_rep_predict(
            X, self._d, self._Dictionary, self._gamma, self._eigenvecMatrix, self._inVa
        )

        if self.classifier == "svm":
            probas = self._clf.predict_proba(Xt)
        elif self.classifier == "knn":
            from GRAIL.kNN import kNN

            k = 5
            neighbors, _, _ = kNN(
                self._train_Xt,
                Xt,
                method="ED",
                k=k,
                representation=None,
                pq_method="opq",
            )

            # Each of the k neighbours casts one vote for its training label;
            # the vote shares are reported as class probabilities.
            probas = np.zeros((len(X), self.n_classes_))
            for i, case in enumerate(neighbors):
                for j in range(k):
                    probas[i, self.class_dictionary_[self._train_y[case[j]]]] += 1
                probas[i] /= k
        else:
            raise ValueError("classifier must be 'svm' or 'knn'")

        return probas

    @staticmethod
    def _modified_GRAIL_rep_fit(
        X,
        d,
        r=20,
        GV=None,
        fourier_coeff=-1,
        e=-1,
        eigenvecMatrix=None,
        inVa=None,
        gamma=None,
        initialization_method="k-shape++",
    ):
        """Fit GRAIL representation.

        A modified version of the GRAIL_rep function from GRAIL.

        Parameters
        ----------
        X : 2D np.ndarray of shape (n_instances, n_timepoints)
            Series to build the dictionary from and to transform.
        d : int
            Number of dictionary entries requested from k-Shape; also the
            number of columns of the returned representation.
        r : int, default=20
            Forwarded to ``gamma_select`` when ``gamma`` is None.
        GV : list or None, default=None
            Candidate values forwarded to ``gamma_select``; defaults to the
            integers 1 to 20.
        fourier_coeff : int, default=-1
            Forwarded unchanged to ``SINK``.
        e : int, default=-1
            Forwarded unchanged to ``SINK``.
        eigenvecMatrix : np.ndarray or None, default=None
            Precomputed eigenvector matrix. Recomputed from the dictionary only
            when both ``eigenvecMatrix`` and ``inVa`` are None.
        inVa : np.ndarray or None, default=None
            Precomputed diagonal matrix of eigenvalues raised to the power -0.5.
        gamma : float or None, default=None
            SINK kernel parameter; selected with ``gamma_select`` when None.
        initialization_method : str, default="k-shape++"
            One of ``"partition"``, ``"centroid_uniform"`` or ``"k-shape++"``.

        Returns
        -------
        Zexact : np.ndarray of shape (n_instances, d)
            The GRAIL representation of ``X``.
        Dictionary : np.ndarray
            Centroids returned by the k-Shape routine.
        gamma : float
            The SINK parameter that was used.
        eigenvecMatrix : np.ndarray
            Eigenvectors of the dictionary kernel matrix.
        inVa : np.ndarray
            Diagonal matrix of eigenvalues raised to the power -0.5.

        Raises
        ------
        GRAIL.exceptions.InitializationMethodNotFound
            If ``initialization_method`` is not one of the supported values.
        """
        from contextlib import redirect_stdout

        from GRAIL import exceptions
        from GRAIL.GRAIL_core import CheckNaNInfComplex, gamma_select
        from GRAIL.kshape import kshape_with_centroid_initialize, matlab_kshape
        from GRAIL.SINK import SINK

        n = X.shape[0]

        # Silence anything printed while building the dictionary. The context
        # managers close the devnull handle and restore sys.stdout even when
        # the clustering (or the unknown-method branch) raises.
        with open(os.devnull, "w") as devnull, redirect_stdout(devnull):
            if initialization_method == "partition":
                _, Dictionary = matlab_kshape(X, d)
            elif initialization_method == "centroid_uniform":
                _, Dictionary = kshape_with_centroid_initialize(X, d, is_pp=False)
            elif initialization_method == "k-shape++":
                _, Dictionary = kshape_with_centroid_initialize(X, d, is_pp=True)
            else:
                raise exceptions.InitializationMethodNotFound

        if gamma is None:
            if GV is None:
                GV = list(range(1, 21))

            _, gamma = gamma_select(Dictionary, GV, r)

        # Kernel value of every series against every dictionary entry.
        E = np.zeros((n, d))
        for i in range(n):
            for j in range(d):
                E[i, j] = SINK(X[i, :], Dictionary[j, :], gamma, fourier_coeff, e)

        if eigenvecMatrix is None and inVa is None:
            # Kernel matrix of the dictionary against itself.
            W = np.zeros((d, d))
            for i in range(d):
                for j in range(d):
                    W[i, j] = SINK(
                        Dictionary[i, :], Dictionary[j, :], gamma, fourier_coeff, e
                    )

            eigenvalvector, eigenvecMatrix = np.linalg.eigh(W)
            inVa = np.diag(np.power(eigenvalvector, -0.5))

        Zexact = E @ eigenvecMatrix @ inVa
        Zexact = CheckNaNInfComplex(Zexact)
        Zexact = np.real(Zexact)

        return Zexact, Dictionary, gamma, eigenvecMatrix, inVa

    @staticmethod
    def _modified_GRAIL_rep_predict(
        X, d, Dictionary, gamma, eigenvecMatrix, inVa, f=0.99, fourier_coeff=-1, e=-1
    ):
        """Predict GRAIL representation.

        A modified version of the GRAIL_rep function from GRAIL.

        Parameters
        ----------
        X : 2D np.ndarray of shape (n_instances, n_timepoints)
            Series to transform.
        d : int
            Number of dictionary entries to compare each series against.
        Dictionary : np.ndarray
            Dictionary produced when fitting the representation.
        gamma : float
            SINK parameter produced when fitting the representation.
        eigenvecMatrix : np.ndarray
            Eigenvector matrix produced when fitting the representation.
        inVa : np.ndarray
            Scaling matrix produced when fitting the representation.
        f : float, default=0.99
            Not used by this function; kept in the signature.
        fourier_coeff : int, default=-1
            Forwarded unchanged to ``SINK``.
        e : int, default=-1
            Forwarded unchanged to ``SINK``.

        Returns
        -------
        Zexact : np.ndarray of shape (n_instances, d)
            The GRAIL representation of ``X``.
        """
        from GRAIL.GRAIL_core import CheckNaNInfComplex
        from GRAIL.SINK import SINK

        n = X.shape[0]

        # Kernel value of every series against every dictionary entry.
        E = np.zeros((n, d))
        for i in range(n):
            for j in range(d):
                E[i, j] = SINK(X[i, :], Dictionary[j, :], gamma, fourier_coeff, e)

        Zexact = E @ eigenvecMatrix @ inVa
        Zexact = CheckNaNInfComplex(Zexact)
        Zexact = np.real(Zexact)

        return Zexact

    def _more_tags(self) -> dict:
        """Return the estimator tags that differ from the base defaults."""
        return {
            "X_types": ["2darray"],
            "optional_dependency": True,
            "univariate_only": True,
            "non_deterministic": True,
        }
0 commit comments