11"""Denoising source separation."""
22# Authors: Nicolas Barascud <nicolas.barascud@gmail.com>
33# Maciej Szul <maciej.szul@isc.cnrs.fr>
4+ from pathlib import Path
5+
46import numpy as np
57from numpy .lib .stride_tricks import sliding_window_view
68from scipy import linalg
@@ -264,7 +266,7 @@ def dss_line(X, fline, sfreq, nremove=1, nfft=1024, nkeep=None, blocksize=None,
264266
265267
266268def dss_line_iter (data , fline , sfreq , win_sz = 10 , spot_sz = 2.5 ,
267- nfft = 512 , show = False , prefix = "dss_iter " , n_iter_max = 100 ):
269+ nfft = 512 , show = False , dirname = None , extension = ".png " , n_iter_max = 100 ):
268270 """Remove power line artifact iteratively.
269271
270272 This method applies dss_line() until the artifact has been smoothed out
@@ -288,9 +290,12 @@ def dss_line_iter(data, fline, sfreq, win_sz=10, spot_sz=2.5,
288290 FFT size for the internal PSD calculation (default=512).
289291 show: bool
290292 Produce a visual output of each iteration (default=False).
291- prefix : str
292- Path and first part of the visualisation output file
293- "{prefix}_{iteration number}.png" (default="dss_iter").
293+ dirname: str
294+ Path to the directory where visual outputs are saved when show is 'True'.
295+ If 'None', does not save the outputs. (default=None)
296+ extension: str
297+ Extension of the images filenames. Must be compatible with plt.savefig()
298+ function. (default=".png")
294299 n_iter_max : int
295300 Maximum number of iterations (default=100).
296301
@@ -312,7 +317,8 @@ def nan_basic_interp(array):
312317 freq_sp = [fline - spot_sz , fline + spot_sz ]
313318 freq , psd = welch (data , fs = sfreq , nfft = nfft , axis = 0 )
314319
315- freq_rn_ix = np .logical_and (freq >= freq_rn [0 ], freq <= freq_rn [1 ])
320+ freq_rn_ix = np .logical_and (freq >= freq_rn [0 ],
321+ freq <= freq_rn [1 ])
316322 freq_used = freq [freq_rn_ix ]
317323 freq_sp_ix = np .logical_and (freq_used >= freq_sp [0 ],
318324 freq_used <= freq_sp [1 ])
@@ -357,26 +363,37 @@ def nan_basic_interp(array):
357363 y = mean_sens [freq_rn_ix ]
358364 ax .flat [0 ].plot (freq_used , y )
359365 ax .flat [0 ].set_title ("Mean PSD across trials" )
366+ ax .flat [0 ].set_xlabel ("Frequency (Hz)" )
367+ ax .flat [0 ].set_ylabel ("Power" )
360368
361- ax .flat [1 ].plot (freq_used , mean_psd_tf , c = "gray" )
362- ax .flat [1 ].plot (freq_used , mean_psd , c = "blue" )
363- ax .flat [1 ].plot (freq_used , clean_fit_line , c = "red" )
369+ ax .flat [1 ].plot (freq_used , mean_psd_tf , c = "gray" ,
370+ label = "Interpolated mean PSD" )
371+ ax .flat [1 ].plot (freq_used , mean_psd , c = "blue" , label = "Mean PSD" )
372+ ax .flat [1 ].plot (freq_used , clean_fit_line , c = "red" , label = "Fitted polynomial" )
364373 ax .flat [1 ].set_title ("Mean PSD across trials and sensors" )
374+ ax .flat [1 ].set_xlabel ("Frequency (Hz)" )
375+ ax .flat [1 ].set_ylabel ("Power" )
376+ ax .flat [1 ].legend ()
365377
366378 tf_ix = np .where (freq_used <= fline )[0 ][- 1 ]
367- ax .flat [2 ].plot (residuals , freq_used )
379+ ax .flat [2 ].plot (freq_used , residuals )
368380 color = "green"
369381 if mean_score <= 0 :
370382 color = "red"
371- ax .flat [2 ].scatter (residuals [tf_ix ], freq_used [tf_ix ], c = color )
383+ ax .flat [2 ].scatter (freq_used [tf_ix ], residuals [tf_ix ], c = color )
372384 ax .flat [2 ].set_title ("Residuals" )
385+ ax .flat [2 ].set_xlabel ("Frequency (Hz)" )
386+ ax .flat [2 ].set_ylabel ("Power" )
373387
374388 ax .flat [3 ].plot (np .arange (iterations + 1 ), aggr_resid , marker = "o" )
375- ax .flat [3 ].set_title ("Iterations" )
389+ ax .flat [3 ].set_title ("Aggregated residuals" )
390+ ax .flat [3 ].set_xlabel ("Iteration" )
391+ ax .flat [3 ].set_ylabel ("Power" )
376392
377393 plt .tight_layout ()
378- plt .savefig (f"{ prefix } _{ iterations :03} .png" )
379- plt .close ("all" )
394+ if dirname is not None :
395+ plt .savefig (Path (dirname ) / f"dss_iter_{ iterations :03} { extension } " )
396+ plt .show ()
380397
381398 if mean_score <= 0 :
382399 break
0 commit comments