Skip to content

Commit 5a9a826

Browse files
sappelhoffnbara
andauthored
[FIX] ASR fixes (#90)
Co-authored-by: Nicolas Barascud <nicolas.barascud@gmail.com>
1 parent e44c4bb commit 5a9a826

File tree

2 files changed

+26
-16
lines changed

2 files changed

+26
-16
lines changed

meegkit/asr.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,12 @@ class ASR:
2121
component-based artifact removal method for removing transient or
2222
large-amplitude artifacts in multi-channel EEG recordings [1]_.
2323
24+
The key parameter of the method is ``cutoff``.
25+
2426
Parameters
2527
----------
2628
sfreq : float
2729
Sampling rate of the data, in Hz.
28-
29-
The following are optional parameters (the key parameter of the method is
30-
the ``cutoff``):
31-
3230
cutoff: float
3331
Standard deviation cutoff for rejection. X portions whose variance
3432
is larger than this threshold relative to the calibration data are
@@ -58,16 +56,16 @@ class ASR:
5856
method : {'riemann', 'euclid'}
5957
Method to use. If riemann, use the riemannian-modified version of
6058
ASR [2]_.
61-
memory : float
62-
Memory size (s), regulates the number of covariance matrices to store.
63-
estimator : str in {'scm', 'lwf', 'oas', 'mcd'}
59+
memory : float | None
60+
Memory size (samples), regulates the number of covariance matrices to
61+
store.
62+
If None (default), will use twice the sampling frequency.
63+
estimator : {'scm', 'lwf', 'oas', 'mcd'}
6464
Covariance estimator (default: 'scm' which computes the sample
6565
covariance). Use 'lwf' if you need regularization (requires pyriemann).
6666
6767
Attributes
6868
----------
69-
``state_`` : dict
70-
Initial state of the ASR filter.
7169
``zi_``: array, shape=(n_channels, filter_order)
7270
Filter initial conditions.
7371
``ab_``: 2-tuple
@@ -98,9 +96,9 @@ class ASR:
9896
9997
"""
10098

101-
def __init__(self, sfreq=250, cutoff=5, blocksize=100, win_len=0.5,
99+
def __init__(self, *, sfreq=250, cutoff=5, blocksize=100, win_len=0.5,
102100
win_overlap=0.66, max_dropout_fraction=0.1,
103-
min_clean_fraction=0.25, name="asrfilter", method="euclid",
101+
min_clean_fraction=0.25, method="euclid", memory=None,
104102
estimator="scm", **kwargs):
105103

106104
if pyriemann is None and method == "riemann":
@@ -115,7 +113,10 @@ def __init__(self, sfreq=250, cutoff=5, blocksize=100, win_len=0.5,
115113
self.min_clean_fraction = min_clean_fraction
116114
self.max_bad_chans = 0.3
117115
self.method = method
118-
self.memory = int(2 * sfreq) # smoothing window for covariances
116+
if memory is None:
117+
self.memory = int(2 * sfreq) # smoothing window for covariances
118+
else:
119+
self.memory = memory
119120
self.sample_weight = np.geomspace(0.05, 1, num=self.memory + 1)
120121
self.sfreq = sfreq
121122
self.estimator = estimator
@@ -141,10 +142,10 @@ def fit(self, X, y=None, **kwargs):
141142
"""Calibration for the Artifact Subspace Reconstruction method.
142143
143144
The input to this data is a multi-channel time series of calibration
144-
data. In typical uses the calibration data is clean resting EEG data of
145-
data if the fraction of artifact content is below the breakdown point
145+
data. In typical uses the calibration data is clean resting EEG data.
146+
The fraction of artifact content should be below the breakdown point
146147
of the robust statistics used for estimation (50% theoretical, ~30%
147-
practical). If the data has a proportion of more than 30-50% artifacts
148+
practical). If the data has a proportion of more than 30-50% artifacts,
148149
then bad time windows should be removed beforehand. This data is used
149150
to estimate the thresholds that are used by the ASR processing function
150151
to identify and remove artifact components.
@@ -164,6 +165,12 @@ def fit(self, X, y=None, **kwargs):
164165
reasonably clean not less than 30 seconds (this method is typically
165166
used with 1 minute or more).
166167
168+
Returns
169+
-------
170+
clean : array, shape=(n_channels, n_samples)
171+
Dataset with bad time periods removed.
172+
sample_mask : boolean array, shape=(1, n_samples)
173+
Mask of retained samples (logical array).
167174
"""
168175
if X.ndim == 3:
169176
X = X.squeeze()
@@ -468,6 +475,9 @@ def asr_calibrate(X, sfreq, cutoff=5, blocksize=100, win_len=0.5,
468475
estimation (default=0.25).
469476
method : {'euclid', 'riemann'}
470477
Metric to compute the covariance matrix average.
478+
estimator : {'scm', 'lwf', 'oas', 'mcd'}
479+
Covariance estimator (default: 'scm' which computes the sample
480+
covariance). Use 'lwf' if you need regularization (requires pyriemann).
471481
472482
Returns
473483
-------

tests/test_asr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def test_asr_class(method, reref, show=False):
193193
blah = ASR(method=method, estimator="scm")
194194
blah.fit(raw2[:, train_idx])
195195

196-
asr = ASR(method=method, estimator="lwf")
196+
asr = ASR(method=method, estimator="lwf", memory=int(2 * sfreq))
197197
asr.fit(raw2[:, train_idx])
198198
else:
199199
asr = ASR(method=method, estimator="scm")

0 commit comments

Comments
 (0)