1
1
"""Denoising source separation."""
2
2
# Authors: Nicolas Barascud <nicolas.barascud@gmail.com>
3
3
# Maciej Szul <maciej.szul@isc.cnrs.fr>
4
+ from pathlib import Path
5
+
4
6
import numpy as np
5
7
from numpy .lib .stride_tricks import sliding_window_view
6
8
from scipy import linalg
@@ -264,7 +266,7 @@ def dss_line(X, fline, sfreq, nremove=1, nfft=1024, nkeep=None, blocksize=None,
264
266
265
267
266
268
def dss_line_iter (data , fline , sfreq , win_sz = 10 , spot_sz = 2.5 ,
267
- nfft = 512 , show = False , prefix = "dss_iter " , n_iter_max = 100 ):
269
+ nfft = 512 , show = False , dirname = None , extension = ".png " , n_iter_max = 100 ):
268
270
"""Remove power line artifact iteratively.
269
271
270
272
This method applies dss_line() until the artifact has been smoothed out
@@ -288,9 +290,12 @@ def dss_line_iter(data, fline, sfreq, win_sz=10, spot_sz=2.5,
288
290
FFT size for the internal PSD calculation (default=512).
289
291
show: bool
290
292
Produce a visual output of each iteration (default=False).
291
- prefix : str
292
- Path and first part of the visualisation output file
293
- "{prefix}_{iteration number}.png" (default="dss_iter").
293
+ dirname: str
294
+ Path to the directory where visual outputs are saved when show is 'True'.
295
+ If 'None', does not save the outputs. (default=None)
296
+ extension: str
297
+ Extension of the images filenames. Must be compatible with plt.savefig()
298
+ function. (default=".png")
294
299
n_iter_max : int
295
300
Maximum number of iterations (default=100).
296
301
@@ -312,7 +317,8 @@ def nan_basic_interp(array):
312
317
freq_sp = [fline - spot_sz , fline + spot_sz ]
313
318
freq , psd = welch (data , fs = sfreq , nfft = nfft , axis = 0 )
314
319
315
- freq_rn_ix = np .logical_and (freq >= freq_rn [0 ], freq <= freq_rn [1 ])
320
+ freq_rn_ix = np .logical_and (freq >= freq_rn [0 ],
321
+ freq <= freq_rn [1 ])
316
322
freq_used = freq [freq_rn_ix ]
317
323
freq_sp_ix = np .logical_and (freq_used >= freq_sp [0 ],
318
324
freq_used <= freq_sp [1 ])
@@ -357,26 +363,37 @@ def nan_basic_interp(array):
357
363
y = mean_sens [freq_rn_ix ]
358
364
ax .flat [0 ].plot (freq_used , y )
359
365
ax .flat [0 ].set_title ("Mean PSD across trials" )
366
+ ax .flat [0 ].set_xlabel ("Frequency (Hz)" )
367
+ ax .flat [0 ].set_ylabel ("Power" )
360
368
361
- ax .flat [1 ].plot (freq_used , mean_psd_tf , c = "gray" )
362
- ax .flat [1 ].plot (freq_used , mean_psd , c = "blue" )
363
- ax .flat [1 ].plot (freq_used , clean_fit_line , c = "red" )
369
+ ax .flat [1 ].plot (freq_used , mean_psd_tf , c = "gray" ,
370
+ label = "Interpolated mean PSD" )
371
+ ax .flat [1 ].plot (freq_used , mean_psd , c = "blue" , label = "Mean PSD" )
372
+ ax .flat [1 ].plot (freq_used , clean_fit_line , c = "red" , label = "Fitted polynomial" )
364
373
ax .flat [1 ].set_title ("Mean PSD across trials and sensors" )
374
+ ax .flat [1 ].set_xlabel ("Frequency (Hz)" )
375
+ ax .flat [1 ].set_ylabel ("Power" )
376
+ ax .flat [1 ].legend ()
365
377
366
378
tf_ix = np .where (freq_used <= fline )[0 ][- 1 ]
367
- ax .flat [2 ].plot (residuals , freq_used )
379
+ ax .flat [2 ].plot (freq_used , residuals )
368
380
color = "green"
369
381
if mean_score <= 0 :
370
382
color = "red"
371
- ax .flat [2 ].scatter (residuals [tf_ix ], freq_used [tf_ix ], c = color )
383
+ ax .flat [2 ].scatter (freq_used [tf_ix ], residuals [tf_ix ], c = color )
372
384
ax .flat [2 ].set_title ("Residuals" )
385
+ ax .flat [2 ].set_xlabel ("Frequency (Hz)" )
386
+ ax .flat [2 ].set_ylabel ("Power" )
373
387
374
388
ax .flat [3 ].plot (np .arange (iterations + 1 ), aggr_resid , marker = "o" )
375
- ax .flat [3 ].set_title ("Iterations" )
389
+ ax .flat [3 ].set_title ("Aggregated residuals" )
390
+ ax .flat [3 ].set_xlabel ("Iteration" )
391
+ ax .flat [3 ].set_ylabel ("Power" )
376
392
377
393
plt .tight_layout ()
378
- plt .savefig (f"{ prefix } _{ iterations :03} .png" )
379
- plt .close ("all" )
394
+ if dirname is not None :
395
+ plt .savefig (Path (dirname ) / f"dss_iter_{ iterations :03} { extension } " )
396
+ plt .show ()
380
397
381
398
if mean_score <= 0 :
382
399
break
0 commit comments