Skip to content

Commit ba2d037

Browse files
eliasbenyahianbaraCopilot
authored
[ENH] Annotations in dss_line_iter plots and specification of a saving path (#89)
Co-authored-by: Nicolas Barascud <nicolas.barascud@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 5a9a826 commit ba2d037

File tree

4 files changed

+64
-54
lines changed

4 files changed

+64
-54
lines changed

examples/example_dss_line.ipynb

Lines changed: 26 additions & 34 deletions
Large diffs are not rendered by default.

examples/example_dss_line.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,12 @@
6868
print(data.shape) # n_samples, n_chans, n_trials
6969

7070
# Apply dss_line(), removing only one component
71-
out1, _ = dss.dss_line(data, fline, sfreq, nremove=1, nfft=400)
71+
out1, _ = dss.dss_line(data, fline, sfreq, nfft=400, nremove=1)
7272

7373
###############################################################################
7474
# Now try dss_line_iter(). This applies dss_line() repeatedly until the
7575
# artifact is gone
76-
out2, iterations = dss.dss_line_iter(data, fline, sfreq, nfft=400)
76+
out2, iterations = dss.dss_line_iter(data, fline, sfreq, nfft=400, show=True)
7777
print(f"Removed {iterations} components")
7878

7979
###############################################################################

meegkit/dss.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Denoising source separation."""
22
# Authors: Nicolas Barascud <nicolas.barascud@gmail.com>
33
# Maciej Szul <maciej.szul@isc.cnrs.fr>
4+
from pathlib import Path
5+
46
import numpy as np
57
from numpy.lib.stride_tricks import sliding_window_view
68
from scipy import linalg
@@ -264,7 +266,7 @@ def dss_line(X, fline, sfreq, nremove=1, nfft=1024, nkeep=None, blocksize=None,
264266

265267

266268
def dss_line_iter(data, fline, sfreq, win_sz=10, spot_sz=2.5,
267-
nfft=512, show=False, prefix="dss_iter", n_iter_max=100):
269+
nfft=512, show=False, dirname=None, extension=".png", n_iter_max=100):
268270
"""Remove power line artifact iteratively.
269271
270272
This method applies dss_line() until the artifact has been smoothed out
@@ -288,9 +290,12 @@ def dss_line_iter(data, fline, sfreq, win_sz=10, spot_sz=2.5,
288290
FFT size for the internal PSD calculation (default=512).
289291
show: bool
290292
Produce a visual output of each iteration (default=False).
291-
prefix : str
292-
Path and first part of the visualisation output file
293-
"{prefix}_{iteration number}.png" (default="dss_iter").
293+
dirname: str
294+
Path to the directory where visual outputs are saved when show is 'True'.
295+
If 'None', does not save the outputs. (default=None)
296+
extension: str
297+
Extension of the images filenames. Must be compatible with plt.savefig()
298+
function. (default=".png")
294299
n_iter_max : int
295300
Maximum number of iterations (default=100).
296301
@@ -312,7 +317,8 @@ def nan_basic_interp(array):
312317
freq_sp = [fline - spot_sz, fline + spot_sz]
313318
freq, psd = welch(data, fs=sfreq, nfft=nfft, axis=0)
314319

315-
freq_rn_ix = np.logical_and(freq >= freq_rn[0], freq <= freq_rn[1])
320+
freq_rn_ix = np.logical_and(freq >= freq_rn[0],
321+
freq <= freq_rn[1])
316322
freq_used = freq[freq_rn_ix]
317323
freq_sp_ix = np.logical_and(freq_used >= freq_sp[0],
318324
freq_used <= freq_sp[1])
@@ -357,26 +363,37 @@ def nan_basic_interp(array):
357363
y = mean_sens[freq_rn_ix]
358364
ax.flat[0].plot(freq_used, y)
359365
ax.flat[0].set_title("Mean PSD across trials")
366+
ax.flat[0].set_xlabel("Frequency (Hz)")
367+
ax.flat[0].set_ylabel("Power")
360368

361-
ax.flat[1].plot(freq_used, mean_psd_tf, c="gray")
362-
ax.flat[1].plot(freq_used, mean_psd, c="blue")
363-
ax.flat[1].plot(freq_used, clean_fit_line, c="red")
369+
ax.flat[1].plot(freq_used, mean_psd_tf, c="gray",
370+
label="Interpolated mean PSD")
371+
ax.flat[1].plot(freq_used, mean_psd, c="blue", label="Mean PSD")
372+
ax.flat[1].plot(freq_used, clean_fit_line, c="red", label="Fitted polynomial")
364373
ax.flat[1].set_title("Mean PSD across trials and sensors")
374+
ax.flat[1].set_xlabel("Frequency (Hz)")
375+
ax.flat[1].set_ylabel("Power")
376+
ax.flat[1].legend()
365377

366378
tf_ix = np.where(freq_used <= fline)[0][-1]
367-
ax.flat[2].plot(residuals, freq_used)
379+
ax.flat[2].plot(freq_used, residuals)
368380
color = "green"
369381
if mean_score <= 0:
370382
color = "red"
371-
ax.flat[2].scatter(residuals[tf_ix], freq_used[tf_ix], c=color)
383+
ax.flat[2].scatter(freq_used[tf_ix], residuals[tf_ix], c=color)
372384
ax.flat[2].set_title("Residuals")
385+
ax.flat[2].set_xlabel("Frequency (Hz)")
386+
ax.flat[2].set_ylabel("Power")
373387

374388
ax.flat[3].plot(np.arange(iterations + 1), aggr_resid, marker="o")
375-
ax.flat[3].set_title("Iterations")
389+
ax.flat[3].set_title("Aggregated residuals")
390+
ax.flat[3].set_xlabel("Iteration")
391+
ax.flat[3].set_ylabel("Power")
376392

377393
plt.tight_layout()
378-
plt.savefig(f"{prefix}_{iterations:03}.png")
379-
plt.close("all")
394+
if dirname is not None:
395+
plt.savefig(Path(dirname) / f"dss_iter_{iterations:03}{extension}")
396+
plt.show()
380397

381398
if mean_score <= 0:
382399
break

tests/test_dss.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def test_dss_line_iter():
132132
# # time x channel x trial sf=200 fline=50
133133

134134
sr = 200
135-
fline = 25
135+
fline = 50
136136
n_samples = 9000
137137
n_chans = 10
138138

@@ -147,9 +147,8 @@ def test_dss_line_iter():
147147
show=False, n_iter_max=2)
148148

149149
with TemporaryDirectory() as tmpdir:
150-
out, _ = dss.dss_line_iter(x, fline + .5, sr,
151-
prefix=os.path.join(tmpdir, "dss_iter_"),
152-
show=True)
150+
out, _ = dss.dss_line_iter(x, fline + 1, sr,
151+
show=True, dirname=tmpdir)
153152

154153
def _plot(before, after):
155154
f, ax = plt.subplots(1, 2, sharey=True)
@@ -171,7 +170,9 @@ def _plot(before, after):
171170
# # Test n_trials > 1 TODO
172171
x, _ = create_line_data(n_samples, n_chans=n_chans, n_trials=2,
173172
noise_dim=10, SNR=2, fline=fline / sr)
174-
out, _ = dss.dss_line_iter(x, fline, sr, show=False)
173+
with TemporaryDirectory() as tmpdir:
174+
out, _ = dss.dss_line_iter(x, fline, sr,
175+
show=True, dirname=tmpdir)
175176
plt.close("all")
176177

177178

0 commit comments

Comments
 (0)