From 0765454f81d9f429b0b3a1162d0ea750ffd191f4 Mon Sep 17 00:00:00 2001 From: Genuster Date: Mon, 14 Jul 2025 14:11:36 +0300 Subject: [PATCH 01/23] topomap and scree plots --- mne/viz/decoding/__init__.py | 5 ++ mne/viz/decoding/ged.py | 140 +++++++++++++++++++++++++++++++++++ 2 files changed, 145 insertions(+) create mode 100644 mne/viz/decoding/__init__.py create mode 100644 mne/viz/decoding/ged.py diff --git a/mne/viz/decoding/__init__.py b/mne/viz/decoding/__init__.py new file mode 100644 index 00000000000..96b99a5e10f --- /dev/null +++ b/mne/viz/decoding/__init__.py @@ -0,0 +1,5 @@ +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +"""Decoding visualization routines.""" diff --git a/mne/viz/decoding/ged.py b/mne/viz/decoding/ged.py new file mode 100644 index 00000000000..83f2883afbc --- /dev/null +++ b/mne/viz/decoding/ged.py @@ -0,0 +1,140 @@ +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +import copy as cp + +import matplotlib.pyplot as plt +import numpy as np + +from ...defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT +from ...evoked import EvokedArray + + +def _plot_model( + model_array, + info, + model="inverse", + components=None, + *, + ch_type=None, + scalings=None, + sensors=True, + show_names=False, + mask=None, + mask_params=None, + contours=6, + outlines="head", + sphere=None, + image_interp=_INTERPOLATION_DEFAULT, + extrapolate=_EXTRAPOLATE_DEFAULT, + border=_BORDER_DEFAULT, + res=64, + size=1, + cmap="RdBu_r", + vlim=(None, None), + cnorm=None, + colorbar=True, + cbar_fmt="%3.1f", + units=None, + axes=None, + name_format=None, + nrows=1, + ncols="auto", + show=True, +): + if name_format is None: + name_format = f"{model}%01d" + + if units is None: + units = "AU" + if components is None: + # n_components are rows + components = np.arange(model_array.shape[0]) + + # set sampling frequency to have 1 component per time point + info = cp.deepcopy(info) + with info._unlock(): + info["sfreq"] = 1.0 + # create an evoked + filters_evk = EvokedArray(model_array.T, info, tmin=0) + # the call plot_topomap + fig = filters_evk.plot_topomap( + times=components, + average=None, + ch_type=ch_type, + scalings=scalings, + proj=False, + sensors=sensors, + show_names=show_names, + mask=mask, + mask_params=mask_params, + contours=contours, + outlines=outlines, + sphere=sphere, + image_interp=image_interp, + extrapolate=extrapolate, + border=border, + res=res, + size=size, + cmap=cmap, + vlim=vlim, + cnorm=cnorm, + colorbar=colorbar, + cbar_fmt=cbar_fmt, + units=units, + axes=axes, + time_format=name_format, + nrows=nrows, + ncols=ncols, + show=show, + ) + return fig + + +def _plot_scree( + evals, + title="Scree plot", + add_cumul_evals=True, + plt_style="seaborn-v0_8-whitegrid", + ax=None, +): + cumul_evals = np.cumsum(evals) + n_components = len(evals) + component_numbers = np.arange(n_components) + + # check available styles with plt.style.available + with plt.style.context(plt_style): + if ax is None: + fig, ax = plt.subplots(figsize=(12, 7), layout="constrained") + else: + fig = None + + # plot individual eigenvalues + color_line = "cornflowerblue" + ax.set_xlabel("Component Index", fontsize=18) + ax.set_ylabel("Eigenvalue", color=color_line, fontsize=18) + ax.set_ylabel("Cumulative Eigenvalues", color=color_line, fontsize=18) + ax.plot(component_numbers, evals, color=color_line, marker="o", markersize=10) + ax.tick_params(axis="y", labelcolor=color_line, labelsize=16) + + if add_cumul_evals: + # plot cumulative eigenvalues + ax2 = ax.twinx() + ax2.grid(False) + color_line = "firebrick" + ax2.set_ylabel("Cumulative Eigenvalues", color=color_line, fontsize=18) + ax2.plot( + component_numbers, + cumul_evals, + color=color_line, + marker="o", + markersize=6, + ) + ax2.tick_params(axis="y", labelcolor=color_line, labelsize=16) + ax2.set_ylim(0) + + if fig: + fig.suptitle(title, fontsize=22, fontweight="bold") + + return fig From 5a7f5bd804f6e96a088d046f58949a75604fc4b9 Mon Sep 17 00:00:00 2001 From: Genuster Date: Tue, 15 Jul 2025 17:27:32 +0300 Subject: [PATCH 02/23] Add SpatialFilter visualization class --- doc/api/visualization.rst | 15 ++ mne/decoding/base.py | 28 +++ mne/viz/__init__.pyi | 2 + mne/viz/decoding/__init__.py | 1 + mne/viz/decoding/ged.py | 358 +++++++++++++++++++++++++++++++++-- 5 files changed, 388 insertions(+), 16 deletions(-) diff --git a/doc/api/visualization.rst b/doc/api/visualization.rst index 280ed51f590..f40065bf6aa 100644 --- a/doc/api/visualization.rst +++ b/doc/api/visualization.rst @@ -103,6 +103,21 @@ Eyetracking plot_gaze +Decoding +-------- + +.. currentmodule:: mne.viz.decoding + +:py:mod:`mne.viz.decoding`: + +.. automodule:: mne.viz.decoding + :no-members: + :no-inherited-members: +.. autosummary:: + :toctree: ../generated/ + + SpatialFilter + UI Events --------- diff --git a/mne/decoding/base.py b/mne/decoding/base.py index 16687e91f07..8ffed339141 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -26,6 +26,7 @@ from ..parallel import parallel_func from ..utils import _check_option, _pl, _validate_type, logger, pinv, verbose, warn +from ..viz import SpatialFilter from ._ged import _handle_restr_mat, _is_cov_symm_pos_semidef, _smart_ajd, _smart_ged from ._mod_ged import _no_op_mod from .transformer import MNETransformerMixin @@ -193,6 +194,33 @@ def transform(self, X): X = pick_filters @ X return X + def get_spatial_filter(self, info): + """Create a SpatialFilter object. + + Creates a :class:`mne.viz.SpatialFilter` object from the fitted + generalized eigendecomposition or other linear models. + This object can be used to visualize the spatial filters, + patterns, and eigenvalues. + + Parameters + ---------- + info : instance of mne.Info + The measurement info object for plotting topomaps. + + Returns + ------- + sp_filter : instance of SpatialFilter + The spatial filter object + """ + sp_filter = SpatialFilter( + info, + evecs=self.filters_.T, + evals=self.evals_, + patterns=self.patterns_, + patterns_method="pinv", + ) + return sp_filter + def _subset_multi_components(self, name="filters"): # The shape of stored filters and patterns is # is (n_classes, n_evecs, n_chs) diff --git a/mne/viz/__init__.pyi b/mne/viz/__init__.pyi index c58ad7d0e54..272c81254e2 100644 --- a/mne/viz/__init__.pyi +++ b/mne/viz/__init__.pyi @@ -3,6 +3,7 @@ __all__ = [ "ClickableImage", "EvokedField", "Figure3D", + "SpatialFilter", "_RAW_CLIP_DEF", "_get_plot_ch_type", "_get_presser", @@ -118,6 +119,7 @@ from .backends.renderer import ( use_3d_backend, ) from .circle import circular_layout, plot_channel_labels_circle +from .decoding.ged import SpatialFilter from .epochs import plot_drop_log, plot_epochs, plot_epochs_image, plot_epochs_psd from .evoked import ( plot_compare_evokeds, diff --git a/mne/viz/decoding/__init__.py b/mne/viz/decoding/__init__.py index 96b99a5e10f..43b144a2892 100644 --- a/mne/viz/decoding/__init__.py +++ b/mne/viz/decoding/__init__.py @@ -3,3 +3,4 @@ # Copyright the MNE-Python contributors. """Decoding visualization routines.""" +from .ged import SpatialFilter diff --git a/mne/viz/decoding/ged.py b/mne/viz/decoding/ged.py index 83f2883afbc..0fcef9e915b 100644 --- a/mne/viz/decoding/ged.py +++ b/mne/viz/decoding/ged.py @@ -9,12 +9,12 @@ from ...defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT from ...evoked import EvokedArray +from ...utils import _check_option, pinv def _plot_model( model_array, info, - model="inverse", components=None, *, ch_type=None, @@ -43,9 +43,6 @@ def _plot_model( ncols="auto", show=True, ): - if name_format is None: - name_format = f"{model}%01d" - if units is None: units = "AU" if components is None: @@ -57,9 +54,9 @@ def _plot_model( with info._unlock(): info["sfreq"] = 1.0 # create an evoked - filters_evk = EvokedArray(model_array.T, info, tmin=0) + model_evk = EvokedArray(model_array.T, info, tmin=0) # the call plot_topomap - fig = filters_evk.plot_topomap( + fig = model_evk.plot_topomap( times=components, average=None, ch_type=ch_type, @@ -97,30 +94,29 @@ def _plot_scree( title="Scree plot", add_cumul_evals=True, plt_style="seaborn-v0_8-whitegrid", - ax=None, + axes=None, ): cumul_evals = np.cumsum(evals) n_components = len(evals) component_numbers = np.arange(n_components) - # check available styles with plt.style.available with plt.style.context(plt_style): - if ax is None: - fig, ax = plt.subplots(figsize=(12, 7), layout="constrained") + if axes is None: + fig, axes = plt.subplots(figsize=(12, 7), layout="constrained") else: fig = None # plot individual eigenvalues color_line = "cornflowerblue" - ax.set_xlabel("Component Index", fontsize=18) - ax.set_ylabel("Eigenvalue", color=color_line, fontsize=18) - ax.set_ylabel("Cumulative Eigenvalues", color=color_line, fontsize=18) - ax.plot(component_numbers, evals, color=color_line, marker="o", markersize=10) - ax.tick_params(axis="y", labelcolor=color_line, labelsize=16) + axes.set_xlabel("Component Index", fontsize=18) + axes.set_ylabel("Eigenvalue", color=color_line, fontsize=18) + axes.set_ylabel("Cumulative Eigenvalues", color=color_line, fontsize=18) + axes.plot(component_numbers, evals, color=color_line, marker="o", markersize=10) + axes.tick_params(axis="y", labelcolor=color_line, labelsize=16) if add_cumul_evals: # plot cumulative eigenvalues - ax2 = ax.twinx() + ax2 = axes.twinx() ax2.grid(False) color_line = "firebrick" ax2.set_ylabel("Cumulative Eigenvalues", color=color_line, fontsize=18) @@ -138,3 +134,333 @@ def _plot_scree( fig.suptitle(title, fontsize=22, fontweight="bold") return fig + + +class SpatialFilter: + r"""Visualization container for spatial filter weights and patterns. + + This object is obtained either by generalized eigendecomposition (GED) algorithms + such as :class:`mne.decoding.CSP`, :class:`mne.decoding.SPoC`, + :class:`mne.decoding.SSD`, :class:`mne.decoding.XdawnTransformer` or by + :class:`mne.decoding.LinearModel` wrapping linear models like SVM or Logit. + The objects stores the filters that projects sensor data to a reduced component + space, and the corresponding patterns (obtained by pseudoinverse in GED case or + Haufe's trickin case of `mne.decoding.LinearModel`). It can also be directly + initialized using filters from other transformers (e.g. PyRiemann). + + Parameters + ---------- + info : instance of Info + The measurement info containing channel topography. + evecs : ndarray, shape (n_channels, n_components) + The eigenvectors of the decomposition (transposed filters). + evals : ndarray, shape (n_components,) | None + The eigenvalues of the decomposition. Defaults to ``None``. + patterns : ndarray, shape (n_components, n_channels) | None + The patterns of the decomposition. If None, they will be computed + from the eigenvectors using pseudoinverse. Defaults to ``None``. + patterns_method : str + The method used to compute the patterns. Can be ``'pinv'`` or ``'haufe'``. + If `patterns` is None, it will be set to ``'pinv'``. Defaults to ``'pinv'``. + + Attributes + ---------- + info : instance of Info + The measurement info. + filters : ndarray, shape (n_components, n_channels) + The spatial filters (unmixing matrix). Applying these filters to the data + gives the component time series. + patterns : ndarray, shape (n_components, n_channels) + The spatial patterns (forward model). These represent the scalp + topography of each component. + evals : ndarray, shape (n_components,) + The eigenvalues associated with each component. + patterns_method : str + The method used to compute the patterns from the filters. + + Notes + ----- + The spatial filters and patterns are stored with shape + ``(n_components, n_channels)``. + + Filters and patterns are related by the following equation: + + .. math:: + \\mathbf{A} = \\mathbf{W}^{-1} + + where :math:`\\mathbf{A}` is the matrix of patterns (the mixing matrix) and + :math:`\\mathbf{W}` is the matrix of filters (the unmixing matrix). + + For a detailed discussion on the difference between filters and patterns for GED + see :footcite:`Cohen2022` and :footcite:`HaufeEtAl2014` for linear models in + general. + + Notes + ----- + .. versionadded:: 1.11 + + References + ---------- + .. footbibliography:: + """ + + def __init__( + self, + info, + evecs, + evals=None, + patterns=None, + patterns_method="pinv", + ): + _check_option( + "patterns_method", + patterns_method, + ("pinv", "haufe"), + ) + self.info = info + self.evals = evals + self.filters = evecs.T + + if patterns is None: + self.patterns = pinv(evecs) + self.patterns_method = "pinv" + else: + self.patterns = patterns + self.patterns_method = patterns_method + + def plot_filters( + self, + components=None, + *, + ch_type=None, + scalings=None, + sensors=True, + show_names=False, + mask=None, + mask_params=None, + contours=6, + outlines="head", + sphere=None, + image_interp=_INTERPOLATION_DEFAULT, + extrapolate=_EXTRAPOLATE_DEFAULT, + border=_BORDER_DEFAULT, + res=64, + size=1, + cmap="RdBu_r", + vlim=(None, None), + cnorm=None, + colorbar=True, + cbar_fmt="%3.1f", + units=None, + axes=None, + name_format="Filter%01d", + nrows=1, + ncols="auto", + show=True, + ): + """Plot topographic maps of model filters. + + Parameters + ---------- + components : float | array of float + Indices of filters to plot. If None, all filters will be plotted. + Defaults to None + %(ch_type_topomap)s + %(scalings_topomap)s + %(sensors_topomap)s + %(show_names_topomap)s + %(mask_evoked_topomap)s + %(mask_params_topomap)s + %(contours_topomap)s + %(outlines_topomap)s + %(sphere_topomap_auto)s + %(image_interp_topomap)s + %(extrapolate_topomap)s + %(border_topomap)s + %(res_topomap)s + %(size_topomap)s + %(cmap_topomap)s + %(vlim_plot_topomap_psd)s + %(cnorm)s + %(colorbar_topomap)s + %(cbar_fmt_topomap)s + %(units_topomap_evoked)s + %(axes_evoked_plot_topomap)s + name_format : str + String format for topomap values. Defaults to ``'Filter%%01d'``. + %(nrows_ncols_topomap)s + %(show)s + + Returns + ------- + fig : instance of matplotlib.figure.Figure + The figure. + """ + _plot_model( + self.filters, + self.info, + components=components, + ch_type=ch_type, + scalings=scalings, + sensors=sensors, + show_names=show_names, + mask=mask, + mask_params=mask_params, + contours=contours, + outlines=outlines, + sphere=sphere, + image_interp=image_interp, + extrapolate=extrapolate, + border=border, + res=res, + size=size, + cmap=cmap, + vlim=vlim, + cnorm=cnorm, + colorbar=colorbar, + cbar_fmt=cbar_fmt, + units=units, + axes=axes, + name_format=name_format, + nrows=nrows, + ncols=ncols, + show=show, + ) + + def plot_patterns( + self, + components=None, + *, + ch_type=None, + scalings=None, + sensors=True, + show_names=False, + mask=None, + mask_params=None, + contours=6, + outlines="head", + sphere=None, + image_interp=_INTERPOLATION_DEFAULT, + extrapolate=_EXTRAPOLATE_DEFAULT, + border=_BORDER_DEFAULT, + res=64, + size=1, + cmap="RdBu_r", + vlim=(None, None), + cnorm=None, + colorbar=True, + cbar_fmt="%3.1f", + units=None, + axes=None, + name_format="Pattern%01d", + nrows=1, + ncols="auto", + show=True, + ): + """Plot topographic maps of model patterns. + + Parameters + ---------- + components : float | array of float + Indices of patterns to plot. If None, all patterns will be plotted. + Defaults to None + %(ch_type_topomap)s + %(scalings_topomap)s + %(sensors_topomap)s + %(show_names_topomap)s + %(mask_evoked_topomap)s + %(mask_params_topomap)s + %(contours_topomap)s + %(outlines_topomap)s + %(sphere_topomap_auto)s + %(image_interp_topomap)s + %(extrapolate_topomap)s + %(border_topomap)s + %(res_topomap)s + %(size_topomap)s + %(cmap_topomap)s + %(vlim_plot_topomap_psd)s + %(cnorm)s + %(colorbar_topomap)s + %(cbar_fmt_topomap)s + %(units_topomap_evoked)s + %(axes_evoked_plot_topomap)s + name_format : str + String format for topomap values. Defaults to ``'Pattern%%01d'``. + %(nrows_ncols_topomap)s + %(show)s + + Returns + ------- + fig : instance of matplotlib.figure.Figure + The figure. + """ + _plot_model( + self.patterns, + self.info, + components=components, + ch_type=ch_type, + scalings=scalings, + sensors=sensors, + show_names=show_names, + mask=mask, + mask_params=mask_params, + contours=contours, + outlines=outlines, + sphere=sphere, + image_interp=image_interp, + extrapolate=extrapolate, + border=border, + res=res, + size=size, + cmap=cmap, + vlim=vlim, + cnorm=cnorm, + colorbar=colorbar, + cbar_fmt=cbar_fmt, + units=units, + axes=axes, + name_format=name_format, + nrows=nrows, + ncols=ncols, + show=show, + ) + + def plot_scree( + self, + title="Scree plot", + add_cumul_evals=True, + plt_style="seaborn-v0_8-whitegrid", + axes=None, + ): + """Plot scree for GED eigenvalues. + + Parameters + ---------- + title : str + Title for the plot. Defaults to ``'Scree plot'``. + add_cumul_evals : bool + Whether to add second line and y-axis for cumulative eigenvalues. + Defaults to ``True``. + plt_style : str + Matplotlib plot style. + Check available styles with plt.style.available. + Defaults to ``'seaborn-v0_8-whitegrid'``. + axes : instance of Axes | None + The matplotlib axes to plot to. Defaults to ``None``. + + Returns + ------- + fig : instance of matplotlib.figure.Figure + The figure. + """ + if self.evals is None: + raise ValueError("Can't plot scree if eigenvalues are not provided.") + fig = _plot_scree( + self.evals, + title=title, + add_cumul_evals=add_cumul_evals, + plt_style=plt_style, + axes=axes, + ) + return fig From 614cd1b1e0c3de659c8068afdabb61aac8b5dac2 Mon Sep 17 00:00:00 2001 From: Genuster Date: Tue, 15 Jul 2025 19:02:48 +0300 Subject: [PATCH 03/23] Some doc-related fixes --- doc/api/visualization.rst | 16 +--------------- mne/decoding/base.py | 4 ++-- mne/viz/__init__.pyi | 2 +- mne/viz/decoding/ged.py | 28 +++++++++++++++------------- 4 files changed, 19 insertions(+), 31 deletions(-) diff --git a/doc/api/visualization.rst b/doc/api/visualization.rst index f40065bf6aa..39381b62a99 100644 --- a/doc/api/visualization.rst +++ b/doc/api/visualization.rst @@ -17,6 +17,7 @@ Visualization ClickableImage EvokedField Figure3D + SpatialFilter add_background_image centers_to_edges compare_fiff @@ -103,21 +104,6 @@ Eyetracking plot_gaze -Decoding --------- - -.. currentmodule:: mne.viz.decoding - -:py:mod:`mne.viz.decoding`: - -.. automodule:: mne.viz.decoding - :no-members: - :no-inherited-members: -.. autosummary:: - :toctree: ../generated/ - - SpatialFilter - UI Events --------- diff --git a/mne/decoding/base.py b/mne/decoding/base.py index 8ffed339141..69819c14c24 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -197,7 +197,7 @@ def transform(self, X): def get_spatial_filter(self, info): """Create a SpatialFilter object. - Creates a :class:`mne.viz.SpatialFilter` object from the fitted + Creates a `mne.viz.SpatialFilter` object from the fitted generalized eigendecomposition or other linear models. This object can be used to visualize the spatial filters, patterns, and eigenvalues. @@ -210,7 +210,7 @@ def get_spatial_filter(self, info): Returns ------- sp_filter : instance of SpatialFilter - The spatial filter object + The spatial filter object. """ sp_filter = SpatialFilter( info, diff --git a/mne/viz/__init__.pyi b/mne/viz/__init__.pyi index 272c81254e2..3ec5954ba88 100644 --- a/mne/viz/__init__.pyi +++ b/mne/viz/__init__.pyi @@ -119,7 +119,7 @@ from .backends.renderer import ( use_3d_backend, ) from .circle import circular_layout, plot_channel_labels_circle -from .decoding.ged import SpatialFilter +from .decoding import SpatialFilter from .epochs import plot_drop_log, plot_epochs, plot_epochs_image, plot_epochs_psd from .evoked import ( plot_compare_evokeds, diff --git a/mne/viz/decoding/ged.py b/mne/viz/decoding/ged.py index 0fcef9e915b..532ee6bf27f 100644 --- a/mne/viz/decoding/ged.py +++ b/mne/viz/decoding/ged.py @@ -9,7 +9,7 @@ from ...defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT from ...evoked import EvokedArray -from ...utils import _check_option, pinv +from ...utils import _check_option, fill_doc, pinv def _plot_model( @@ -140,9 +140,9 @@ class SpatialFilter: r"""Visualization container for spatial filter weights and patterns. This object is obtained either by generalized eigendecomposition (GED) algorithms - such as :class:`mne.decoding.CSP`, :class:`mne.decoding.SPoC`, - :class:`mne.decoding.SSD`, :class:`mne.decoding.XdawnTransformer` or by - :class:`mne.decoding.LinearModel` wrapping linear models like SVM or Logit. + such as `mne.decoding.CSP`, `mne.decoding.SPoC`, `mne.decoding.SSD`, + `mne.decoding.XdawnTransformer` or by `mne.decoding.LinearModel` + wrapping linear models like SVM or Logit. The objects stores the filters that projects sensor data to a reduced component space, and the corresponding patterns (obtained by pseudoinverse in GED case or Haufe's trickin case of `mne.decoding.LinearModel`). It can also be directly @@ -195,8 +195,6 @@ class SpatialFilter: see :footcite:`Cohen2022` and :footcite:`HaufeEtAl2014` for linear models in general. - Notes - ----- .. versionadded:: 1.11 References @@ -228,6 +226,7 @@ def __init__( self.patterns = patterns self.patterns_method = patterns_method + @fill_doc def plot_filters( self, components=None, @@ -264,7 +263,7 @@ def plot_filters( ---------- components : float | array of float Indices of filters to plot. If None, all filters will be plotted. - Defaults to None + Defaults to None. %(ch_type_topomap)s %(scalings_topomap)s %(sensors_topomap)s @@ -294,9 +293,9 @@ def plot_filters( Returns ------- fig : instance of matplotlib.figure.Figure - The figure. + The figure. """ - _plot_model( + fig = _plot_model( self.filters, self.info, components=components, @@ -326,7 +325,9 @@ def plot_filters( ncols=ncols, show=show, ) + return fig + @fill_doc def plot_patterns( self, components=None, @@ -363,7 +364,7 @@ def plot_patterns( ---------- components : float | array of float Indices of patterns to plot. If None, all patterns will be plotted. - Defaults to None + Defaults to None. %(ch_type_topomap)s %(scalings_topomap)s %(sensors_topomap)s @@ -393,9 +394,9 @@ def plot_patterns( Returns ------- fig : instance of matplotlib.figure.Figure - The figure. + The figure. """ - _plot_model( + fig = _plot_model( self.patterns, self.info, components=components, @@ -425,6 +426,7 @@ def plot_patterns( ncols=ncols, show=show, ) + return fig def plot_scree( self, @@ -452,7 +454,7 @@ def plot_scree( Returns ------- fig : instance of matplotlib.figure.Figure - The figure. + The figure. """ if self.evals is None: raise ValueError("Can't plot scree if eigenvalues are not provided.") From c3feb724c46452ffc976b1050bc3b8e7778c310d Mon Sep 17 00:00:00 2001 From: Genuster Date: Fri, 18 Jul 2025 14:10:40 +0300 Subject: [PATCH 04/23] more imrovements --- doc/_includes/ged.rst | 2 +- mne/decoding/base.py | 34 +++++- mne/viz/decoding/ged.py | 256 ++++++++++++++++++++++++++-------------- 3 files changed, 199 insertions(+), 93 deletions(-) diff --git a/doc/_includes/ged.rst b/doc/_includes/ged.rst index 8f5fc17131c..5146fef5ffa 100644 --- a/doc/_includes/ged.rst +++ b/doc/_includes/ged.rst @@ -14,7 +14,7 @@ This section describes the mathematical formulation and application of Generalized Eigendecomposition (GED), often used in spatial filtering and source separation algorithms, such as :class:`mne.decoding.CSP`, :class:`mne.decoding.SPoC`, :class:`mne.decoding.SSD` and -:class:`mne.preprocessing.Xdawn`. +:class:`mne.decoding.XdawnTransformer`. The core principle of GED is to find a set of channel weights (spatial filter) that maximizes the ratio of signal power between two data features. diff --git a/mne/decoding/base.py b/mne/decoding/base.py index 69819c14c24..c54cc5d120b 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -197,8 +197,8 @@ def transform(self, X): def get_spatial_filter(self, info): """Create a SpatialFilter object. - Creates a `mne.viz.SpatialFilter` object from the fitted - generalized eigendecomposition or other linear models. + Creates an `mne.viz.SpatialFilter` object from the fitted + generalized eigendecomposition. This object can be used to visualize the spatial filters, patterns, and eigenvalues. @@ -209,12 +209,13 @@ def get_spatial_filter(self, info): Returns ------- - sp_filter : instance of SpatialFilter + sp_filter : instance of mne.viz.SpatialFilter The spatial filter object. """ + check_is_fitted(self, ["filters_", "patterns_", "evals_"]) sp_filter = SpatialFilter( info, - evecs=self.filters_.T, + evecs=self.filters_, evals=self.evals_, patterns=self.patterns_, patterns_method="pinv", @@ -443,6 +444,31 @@ def filters_(self): filters = filters[0] return filters + def get_spatial_filter(self, info): + """Create a SpatialFilter object. + + Creates an `mne.viz.SpatialFilter` object from the linear model. + This object can be used to visualize model weights and patterns. + + Parameters + ---------- + info : instance of mne.Info + The measurement info object for plotting topomaps. + + Returns + ------- + sp_filter : instance of mne.viz.SpatialFilter + The spatial filter object. + """ + check_is_fitted(self, ["filters_", "patterns_"]) + sp_filter = SpatialFilter( + info, + evecs=self.filters_.T, + patterns=self.patterns_, + patterns_method="haufe", + ) + return sp_filter + def _set_cv(cv, estimator=None, X=None, y=None): """Set the default CV depending on whether clf is classifier/regressor.""" diff --git a/mne/viz/decoding/ged.py b/mne/viz/decoding/ged.py index 532ee6bf27f..229a70ae53e 100644 --- a/mne/viz/decoding/ged.py +++ b/mne/viz/decoding/ged.py @@ -9,7 +9,7 @@ from ...defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT from ...evoked import EvokedArray -from ...utils import _check_option, fill_doc, pinv +from ...utils import _check_option, fill_doc def _plot_model( @@ -45,48 +45,120 @@ def _plot_model( ): if units is None: units = "AU" + n_comps = model_array.shape[-2] if components is None: - # n_components are rows - components = np.arange(model_array.shape[0]) + components = np.arange(n_comps) # set sampling frequency to have 1 component per time point info = cp.deepcopy(info) with info._unlock(): info["sfreq"] = 1.0 # create an evoked - model_evk = EvokedArray(model_array.T, info, tmin=0) - # the call plot_topomap - fig = model_evk.plot_topomap( - times=components, - average=None, - ch_type=ch_type, - scalings=scalings, - proj=False, - sensors=sensors, - show_names=show_names, - mask=mask, - mask_params=mask_params, - contours=contours, - outlines=outlines, - sphere=sphere, - image_interp=image_interp, - extrapolate=extrapolate, - border=border, - res=res, - size=size, - cmap=cmap, - vlim=vlim, - cnorm=cnorm, - colorbar=colorbar, - cbar_fmt=cbar_fmt, - units=units, - axes=axes, - time_format=name_format, - nrows=nrows, - ncols=ncols, - show=show, + + if model_array.ndim == 3: + n_classes = model_array.shape[0] + figs = list() + for class_idx in range(n_classes): + model_evk = EvokedArray(model_array[class_idx].T, info, tmin=0) + fig = model_evk.plot_topomap( + times=components, + average=None, + ch_type=ch_type, + scalings=scalings, + proj=False, + sensors=sensors, + show_names=show_names, + mask=mask, + mask_params=mask_params, + contours=contours, + outlines=outlines, + sphere=sphere, + image_interp=image_interp, + extrapolate=extrapolate, + border=border, + res=res, + size=size, + cmap=cmap, + vlim=vlim, + cnorm=cnorm, + colorbar=colorbar, + cbar_fmt=cbar_fmt, + units=units, + axes=axes[class_idx], + time_format=name_format, + nrows=nrows, + ncols=ncols, + show=show, + ) + figs.append(fig) + return figs + else: + model_evk = EvokedArray(model_array.T, info, tmin=0) + fig = model_evk.plot_topomap( + times=components, + average=None, + ch_type=ch_type, + scalings=scalings, + proj=False, + sensors=sensors, + show_names=show_names, + mask=mask, + mask_params=mask_params, + contours=contours, + outlines=outlines, + sphere=sphere, + image_interp=image_interp, + extrapolate=extrapolate, + border=border, + res=res, + size=size, + cmap=cmap, + vlim=vlim, + cnorm=cnorm, + colorbar=colorbar, + cbar_fmt=cbar_fmt, + units=units, + axes=axes, + time_format=name_format, + nrows=nrows, + ncols=ncols, + show=show, + ) + return fig + + +def _plot_scree_per_class(evals, add_cumul_evals, axes): + component_numbers = np.arange(len(evals)) + cumul_evals = np.cumsum(evals) if add_cumul_evals else None + # plot individual eigenvalues + color_line = "cornflowerblue" + axes.set_xlabel("Component Index", fontsize=18) + axes.set_ylabel("Eigenvalue", color=color_line, fontsize=18) + axes.set_ylabel("Cumulative Eigenvalues", color=color_line, fontsize=18) + axes.plot( + component_numbers, + evals, + color=color_line, + marker="o", + markersize=10, ) - return fig + axes.tick_params(axis="y", labelcolor=color_line, labelsize=16) + + if add_cumul_evals: + # plot cumulative eigenvalues + ax2 = axes.twinx() + ax2.grid(False) + color_line = "firebrick" + ax2.set_ylabel("Cumulative Eigenvalues", color=color_line, fontsize=18) + ax2.plot( + component_numbers, + cumul_evals, + color=color_line, + marker="o", + markersize=6, + ) + ax2.tick_params(axis="y", labelcolor=color_line, labelsize=16) + ax2.set_ylim(0) def _plot_scree( @@ -96,67 +168,57 @@ def _plot_scree( plt_style="seaborn-v0_8-whitegrid", axes=None, ): - cumul_evals = np.cumsum(evals) - n_components = len(evals) - component_numbers = np.arange(n_components) - with plt.style.context(plt_style): - if axes is None: - fig, axes = plt.subplots(figsize=(12, 7), layout="constrained") + if evals.ndim == 2: + n_classes = evals.shape[0] + if axes is None: + fig, axes = plt.subplots( + nrows=n_classes, + ncols=1, + figsize=(12, 7 * n_classes), + layout="constrained", + ) + else: + assert len(axes) == n_classes + fig = None + for class_idx in range(n_classes): + _plot_scree_per_class( + evals[class_idx], add_cumul_evals, axes[class_idx] + ) else: - fig = None - - # plot individual eigenvalues - color_line = "cornflowerblue" - axes.set_xlabel("Component Index", fontsize=18) - axes.set_ylabel("Eigenvalue", color=color_line, fontsize=18) - axes.set_ylabel("Cumulative Eigenvalues", color=color_line, fontsize=18) - axes.plot(component_numbers, evals, color=color_line, marker="o", markersize=10) - axes.tick_params(axis="y", labelcolor=color_line, labelsize=16) - - if add_cumul_evals: - # plot cumulative eigenvalues - ax2 = axes.twinx() - ax2.grid(False) - color_line = "firebrick" - ax2.set_ylabel("Cumulative Eigenvalues", color=color_line, fontsize=18) - ax2.plot( - component_numbers, - cumul_evals, - color=color_line, - marker="o", - markersize=6, - ) - ax2.tick_params(axis="y", labelcolor=color_line, labelsize=16) - ax2.set_ylim(0) + if axes is None: + fig, axes = plt.subplots(figsize=(12, 7), layout="constrained") + else: + fig = None + _plot_scree_per_class(evals, add_cumul_evals, axes) if fig: fig.suptitle(title, fontsize=22, fontweight="bold") - return fig class SpatialFilter: - r"""Visualization container for spatial filter weights and patterns. + r"""Visualization container for spatial filter weights (evecs) and patterns. This object is obtained either by generalized eigendecomposition (GED) algorithms such as `mne.decoding.CSP`, `mne.decoding.SPoC`, `mne.decoding.SSD`, - `mne.decoding.XdawnTransformer` or by `mne.decoding.LinearModel` + `mne.decoding.XdawnTransformer` or by `mne.decoding.LinearModel`, wrapping linear models like SVM or Logit. - The objects stores the filters that projects sensor data to a reduced component + The object stores the filters that projects sensor data to a reduced component space, and the corresponding patterns (obtained by pseudoinverse in GED case or - Haufe's trickin case of `mne.decoding.LinearModel`). It can also be directly - initialized using filters from other transformers (e.g. PyRiemann). + Haufe's trick in case of `mne.decoding.LinearModel`). It can also be directly + initialized using filters from other transformers (e.g. PyRiemann), + but make sure that the dimensions match. Parameters ---------- info : instance of Info The measurement info containing channel topography. - evecs : ndarray, shape (n_channels, n_components) - The eigenvectors of the decomposition (transposed filters). - evals : ndarray, shape (n_components,) | None + filters : ndarray, shape ((n_classes), n_components, n_channels) + The spatial filters (transposed eigenvectors of the decomposition). + evals : ndarray, shape ((n_classes), n_components) | None The eigenvalues of the decomposition. Defaults to ``None``. - patterns : ndarray, shape (n_components, n_channels) | None + patterns : ndarray, shape ((n_classes), n_components, n_channels) | None The patterns of the decomposition. If None, they will be computed from the eigenvectors using pseudoinverse. Defaults to ``None``. patterns_method : str @@ -171,8 +233,8 @@ class SpatialFilter: The spatial filters (unmixing matrix). Applying these filters to the data gives the component time series. patterns : ndarray, shape (n_components, n_channels) - The spatial patterns (forward model). These represent the scalp - topography of each component. + The spatial patterns (mixing matrix/forward model). + These represent the scalp topography of each component. evals : ndarray, shape (n_components,) The eigenvalues associated with each component. patterns_method : str @@ -186,14 +248,14 @@ class SpatialFilter: Filters and patterns are related by the following equation: .. math:: - \\mathbf{A} = \\mathbf{W}^{-1} + \mathbf{A} = \mathbf{W}^{-1} - where :math:`\\mathbf{A}` is the matrix of patterns (the mixing matrix) and - :math:`\\mathbf{W}` is the matrix of filters (the unmixing matrix). + where :math:`\mathbf{A}` is the matrix of patterns (the mixing matrix) and + :math:`\mathbf{W}` is the matrix of filters (the unmixing matrix). For a detailed discussion on the difference between filters and patterns for GED - see :footcite:`Cohen2022` and :footcite:`HaufeEtAl2014` for linear models in - general. + see :footcite:`Cohen2022` and for linear models in + general see :footcite:`HaufeEtAl2014`. .. versionadded:: 1.11 @@ -205,7 +267,8 @@ class SpatialFilter: def __init__( self, info, - evecs, + filters, + *, evals=None, patterns=None, patterns_method="pinv", @@ -217,15 +280,32 @@ def __init__( ) self.info = info self.evals = evals - self.filters = evecs.T - + self.filters = filters + n_comps, n_chs = self.filters.shape[-2:] if patterns is None: - self.patterns = pinv(evecs) + # XXX Using numpy's pinv here to handle 3D case seamlessly + # Perhaps mne.linalg.pinv can be improved to handle 3D also + # Then it could be changed here to be consistent with + # GEDTransformer + self.patterns = np.linalg.pinv(filters) self.patterns_method = "pinv" else: self.patterns = patterns self.patterns_method = patterns_method + if n_comps > n_chs: + raise ValueError( + "Number of components can't be greater " + "than number of channels in filters," + "perhaps the provided matrix is transposed?" + ) + if self.filters.shape != self.patterns.shape: + raise ValueError( + f"Shape mismatch between filters and patterns." + f"Filters are {self.filters.shape}," + f"while patterns are {self.patterns.shape}" + ) + @fill_doc def plot_filters( self, @@ -457,7 +537,7 @@ def plot_scree( The figure. """ if self.evals is None: - raise ValueError("Can't plot scree if eigenvalues are not provided.") + raise AttributeError("Can't plot scree if eigenvalues are not provided.") fig = _plot_scree( self.evals, title=title, From 3266d16b71f7b06d402e7a27089639a4ced1ea71 Mon Sep 17 00:00:00 2001 From: Genuster Date: Fri, 18 Jul 2025 14:29:57 +0300 Subject: [PATCH 05/23] nest EvokedArray import --- mne/viz/decoding/ged.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mne/viz/decoding/ged.py b/mne/viz/decoding/ged.py index 229a70ae53e..caeea9754bd 100644 --- a/mne/viz/decoding/ged.py +++ b/mne/viz/decoding/ged.py @@ -8,7 +8,6 @@ import numpy as np from ...defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT -from ...evoked import EvokedArray from ...utils import _check_option, fill_doc @@ -43,6 +42,8 @@ def _plot_model( ncols="auto", show=True, ): + from ...evoked import EvokedArray + if units is None: units = "AU" n_comps = model_array.shape[-2] From d720dc25aaefaeb47c562e312b8f6f51297fde80 Mon Sep 17 00:00:00 2001 From: Genuster Date: Fri, 18 Jul 2025 14:58:42 +0300 Subject: [PATCH 06/23] add some minor tests and fixes --- mne/decoding/base.py | 4 ++-- mne/decoding/tests/test_base.py | 8 ++++++++ mne/decoding/tests/test_ged.py | 23 ++++++++++++++++++++++- mne/viz/decoding/ged.py | 6 ++++-- mne/viz/decoding/tests/test_ged.py | 18 ++++++++++++++++++ 5 files changed, 54 insertions(+), 5 deletions(-) create mode 100644 mne/viz/decoding/tests/test_ged.py diff --git a/mne/decoding/base.py b/mne/decoding/base.py index c54cc5d120b..b09ecf56bd1 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -215,7 +215,7 @@ def get_spatial_filter(self, info): check_is_fitted(self, ["filters_", "patterns_", "evals_"]) sp_filter = SpatialFilter( info, - evecs=self.filters_, + filters=self.filters_, evals=self.evals_, patterns=self.patterns_, patterns_method="pinv", @@ -463,7 +463,7 @@ def get_spatial_filter(self, info): check_is_fitted(self, ["filters_", "patterns_"]) sp_filter = SpatialFilter( info, - evecs=self.filters_.T, + filters=self.filters_, patterns=self.patterns_, patterns_method="haufe", ) diff --git a/mne/decoding/tests/test_base.py b/mne/decoding/tests/test_base.py index f17d4328279..9b8008203ed 100644 --- a/mne/decoding/tests/test_base.py +++ b/mne/decoding/tests/test_base.py @@ -418,6 +418,14 @@ def test_linearmodel(): wrong_y = rng.rand(n, n_features, 99) clf.fit(X, wrong_y) + # check get_spatial_filter + info = create_info(n_features, 1000.0, "eeg") + sp_filter = clf.get_spatial_filter(info) + assert sp_filter.patterns_method == "haufe" + np.testing.assert_array_equal(sp_filter.filters, clf.filters_) + np.testing.assert_array_equal(sp_filter.patterns, clf.patterns_) + assert sp_filter.evals is None + def test_cross_val_multiscore(): """Test cross_val_multiscore for computing scores on decoding over time.""" diff --git a/mne/decoding/tests/test_ged.py b/mne/decoding/tests/test_ged.py index 5f413f81aee..d407db8c2dc 100644 --- a/mne/decoding/tests/test_ged.py +++ b/mne/decoding/tests/test_ged.py @@ -145,7 +145,7 @@ def test_sklearn_compliance(estimator, check): check(estimator) -def _get_X_y(event_id): +def _get_X_y(event_id, return_info=False): raw = read_raw(raw_fname, preload=False) events = read_events(event_name) picks = pick_types( @@ -166,6 +166,8 @@ def _get_X_y(event_id): ) X = epochs.get_data(copy=False, units=dict(eeg="uV", grad="fT/cm", mag="fT")) y = epochs.events[:, -1] + if return_info: + return X, y, epochs.info return X, y @@ -386,3 +388,22 @@ def test__no_op_mod(): assert evals is evals_no_op assert evecs is evecs_no_op assert sorter_no_op is None + + +def test_get_spatial_filter(): + """Test instantiation of spatial filter.""" + event_id = dict(aud_l=1, vis_l=3) + X, y, info = _get_X_y(event_id, return_info=True) + + ged = _GEDTransformer( + n_components=4, + cov_callable=_mock_cov_callable, + mod_ged_callable=_mock_mod_ged_callable, + restr_type="restricting", + ) + ged.fit(X, y) + sp_filter = ged.get_spatial_filter(info) + assert sp_filter.patterns_method == "pinv" + np.testing.assert_array_equal(sp_filter.filters, ged.filters_) + np.testing.assert_array_equal(sp_filter.patterns, ged.patterns_) + np.testing.assert_array_equal(sp_filter.evals, ged.evals_) diff --git a/mne/viz/decoding/ged.py b/mne/viz/decoding/ged.py index caeea9754bd..f063deae347 100644 --- a/mne/viz/decoding/ged.py +++ b/mne/viz/decoding/ged.py @@ -294,10 +294,12 @@ def __init__( self.patterns = patterns self.patterns_method = patterns_method - if n_comps > n_chs: + # In case of multi-target classification in LinearModel + # number of targets can be greater than number of channels. + if patterns_method != "haufe" and n_comps > n_chs: raise ValueError( "Number of components can't be greater " - "than number of channels in filters," + "than number of channels in filters, " "perhaps the provided matrix is transposed?" ) if self.filters.shape != self.patterns.shape: diff --git a/mne/viz/decoding/tests/test_ged.py b/mne/viz/decoding/tests/test_ged.py new file mode 100644 index 00000000000..379194ef5ed --- /dev/null +++ b/mne/viz/decoding/tests/test_ged.py @@ -0,0 +1,18 @@ +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +import numpy as np +import pytest + +from mne import create_info +from mne.viz import SpatialFilter + + +def test_plot_scree_raises(): + """Tests that plot_scree can't plot without evals.""" + info = create_info(2, 1000.0, "eeg") + filters = np.array([[1, 2], [3, 4]]) + sp_filter = SpatialFilter(info, filters, evals=None) + with pytest.raises(AttributeError): + sp_filter.plot_scree() From d2fdc2dfaa36d40250700b8e857dbb001a6c3a12 Mon Sep 17 00:00:00 2001 From: Genuster Date: Sat, 19 Jul 2025 15:16:21 +0300 Subject: [PATCH 07/23] make get_coef to work with an arbitrary named step --- mne/decoding/base.py | 61 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 57 insertions(+), 4 deletions(-) diff --git a/mne/decoding/base.py b/mne/decoding/base.py index b09ecf56bd1..842c50b578b 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -559,8 +559,43 @@ def _get_inverse_funcs(estimator, terminal=True): return inverse_func +def _get_inverse_funcs_before_step(estimator, step_name): + """Get the inverse_transform methods for all steps before a target step.""" + # in case step_name is nested with __ + parts = step_name.split("__") + inverse_funcs = list() + current_pipeline = estimator + for i, part_name in enumerate(parts): + is_last_part = i == len(parts) - 1 + all_names = [name for name, _ in current_pipeline.steps] + if part_name not in all_names: + raise ValueError(f"Step '{part_name}' not found.") + part_idx = all_names.index(part_name) + + # get all preceding steps for the current step + for prec_name, prec_step in current_pipeline.steps[:part_idx]: + if hasattr(prec_step, "inverse_transform"): + inverse_funcs.append(prec_step.inverse_transform) + else: + warn( + f"Preceding step '{prec_name}' is not invertible " + f"and will be skipped." + ) + + next_estimator = current_pipeline.named_steps[part_name] + # check if pipeline + if hasattr(next_estimator, "steps"): + current_pipeline = next_estimator + # if not pipeline and not last part - wrong + elif not is_last_part: + raise ValueError(f"Step '{part_name}' is not a pipeline.") + return inverse_funcs + + @verbose -def get_coef(estimator, attr="filters_", inverse_transform=False, *, verbose=None): +def get_coef( + estimator, attr="filters_", inverse_transform=False, *, step_name=None, verbose=None +): """Retrieve the coefficients of an estimator ending with a Linear Model. This is typically useful to retrieve "spatial filters" or "spatial @@ -576,6 +611,13 @@ def get_coef(estimator, attr="filters_", inverse_transform=False, *, verbose=Non inverse_transform : bool If True, returns the coefficients after inverse transforming them with the transformer steps of the estimator. + step_name : str + Name of the sklearn's pipeline step to get the coef from. + If inverse_transform is True, the inverse transformations + will be applied using transformers before this step. + If None, the last step will be used. Defaults to None. + + .. versionadded:: 1.11 %(verbose)s Returns @@ -590,8 +632,14 @@ def get_coef(estimator, attr="filters_", inverse_transform=False, *, verbose=Non # Get the coefficients of the last estimator in case of nested pipeline est = estimator logger.debug(f"Getting coefficients from estimator: {est.__class__.__name__}") - while hasattr(est, "steps"): - est = est.steps[-1][1] + + if step_name is not None: + if not hasattr(estimator, "named_steps"): + raise ValueError("'step_name' can only be used with a Pipeline estimator.") + est = est.named_steps[step_name] + else: + while hasattr(est, "steps"): + est = est.steps[-1][1] squeeze_first_dim = False @@ -620,9 +668,14 @@ def get_coef(estimator, attr="filters_", inverse_transform=False, *, verbose=Non raise ValueError( "inverse_transform can only be applied onto pipeline estimators." ) + if step_name is None: + inverse_funcs = _get_inverse_funcs(estimator) + else: + inverse_funcs = _get_inverse_funcs_before_step(estimator, step_name) + # The inverse_transform parameter will call this method on any # estimator contained in the pipeline, in reverse order. - for inverse_func in _get_inverse_funcs(estimator)[::-1]: + for inverse_func in inverse_funcs[::-1]: logger.debug(f" Applying inverse transformation: {inverse_func}.") coef = inverse_func(coef) From 4ac29d5a2339179783c7c13bed6ad1816095f69d Mon Sep 17 00:00:00 2001 From: Genuster Date: Sat, 19 Jul 2025 18:02:30 +0300 Subject: [PATCH 08/23] clean get_coef's inverse_transform test --- mne/decoding/base.py | 4 ++-- mne/decoding/tests/test_base.py | 31 +++++++++++++++---------------- 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/mne/decoding/base.py b/mne/decoding/base.py index 842c50b578b..71823b4fa8a 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -635,8 +635,8 @@ def get_coef( if step_name is not None: if not hasattr(estimator, "named_steps"): - raise ValueError("'step_name' can only be used with a Pipeline estimator.") - est = est.named_steps[step_name] + raise ValueError("'step_name' can only be used with a pipeline estimator.") + est = est.get_params(deep=True)[step_name] else: while hasattr(est, "steps"): est = est.steps[-1][1] diff --git a/mne/decoding/tests/test_base.py b/mne/decoding/tests/test_base.py index 9b8008203ed..c6fce5deb20 100644 --- a/mne/decoding/tests/test_base.py +++ b/mne/decoding/tests/test_base.py @@ -239,24 +239,13 @@ def transform(self, X): @pytest.mark.parametrize("inverse", (True, False)) -@pytest.mark.parametrize( - "Scale, kwargs", - [ - (Scaler, dict(info=None, scalings="mean")), - (_Noop, dict()), - ], -) -def test_get_coef_inverse_transform(inverse, Scale, kwargs): +def test_get_coef_inverse_transform(inverse): """Test get_coef with and without inverse_transform.""" lm_regression = LinearModel(Ridge()) X, y, A = _make_data(n_samples=1000, n_features=3, n_targets=1) # Check with search_light and combination of preprocessing ending with sl: - # slider = SlidingEstimator(make_pipeline(StandardScaler(), lm_regression)) - # XXX : line above should work but does not as only last step is - # used in get_coef ... - slider = SlidingEstimator(make_pipeline(lm_regression)) + clf = SlidingEstimator(make_pipeline(StandardScaler(), lm_regression)) X = np.transpose([X, -X], [1, 2, 0]) # invert X across 2 time samples - clf = make_pipeline(Scale(**kwargs), slider) clf.fit(X, y) patterns = get_coef(clf, "patterns_", inverse) filters = get_coef(clf, "filters_", inverse) @@ -265,10 +254,20 @@ def test_get_coef_inverse_transform(inverse, Scale, kwargs): assert_equal(patterns[0, 0], -patterns[0, 1]) for t in [0, 1]: filters_t = get_coef( - clf.named_steps["slidingestimator"].estimators_[t], "filters_", False + clf.estimators_[t], + "filters_", + inverse, + verbose=False, + ) + assert_array_equal(filters_t, filters[:, t]) + + with pytest.raises(ValueError, match=r"pipeline estimator"): + _ = get_coef( + clf, + "filters_", + inverse, + step_name="slidingestimator__pipeline__linearmodel", ) - if Scale is _Noop: - assert_array_equal(filters_t, filters[:, t]) @pytest.mark.parametrize("n_features", [1, 5]) From 37cd069b0a9ae9db29ce1f8a64e0df1c5afaa3b2 Mon Sep 17 00:00:00 2001 From: Genuster Date: Sat, 19 Jul 2025 18:52:07 +0300 Subject: [PATCH 09/23] make instantiation of the spatial filter from ged/linear model a standalone function --- mne/decoding/base.py | 56 +------------------------ mne/viz/decoding/ged.py | 93 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 92 insertions(+), 57 deletions(-) diff --git a/mne/decoding/base.py b/mne/decoding/base.py index 71823b4fa8a..b1a329d0852 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -26,7 +26,6 @@ from ..parallel import parallel_func from ..utils import _check_option, _pl, _validate_type, logger, pinv, verbose, warn -from ..viz import SpatialFilter from ._ged import _handle_restr_mat, _is_cov_symm_pos_semidef, _smart_ajd, _smart_ged from ._mod_ged import _no_op_mod from .transformer import MNETransformerMixin @@ -194,34 +193,6 @@ def transform(self, X): X = pick_filters @ X return X - def get_spatial_filter(self, info): - """Create a SpatialFilter object. - - Creates an `mne.viz.SpatialFilter` object from the fitted - generalized eigendecomposition. - This object can be used to visualize the spatial filters, - patterns, and eigenvalues. - - Parameters - ---------- - info : instance of mne.Info - The measurement info object for plotting topomaps. - - Returns - ------- - sp_filter : instance of mne.viz.SpatialFilter - The spatial filter object. - """ - check_is_fitted(self, ["filters_", "patterns_", "evals_"]) - sp_filter = SpatialFilter( - info, - filters=self.filters_, - evals=self.evals_, - patterns=self.patterns_, - patterns_method="pinv", - ) - return sp_filter - def _subset_multi_components(self, name="filters"): # The shape of stored filters and patterns is # is (n_classes, n_evecs, n_chs) @@ -444,31 +415,6 @@ def filters_(self): filters = filters[0] return filters - def get_spatial_filter(self, info): - """Create a SpatialFilter object. - - Creates an `mne.viz.SpatialFilter` object from the linear model. - This object can be used to visualize model weights and patterns. - - Parameters - ---------- - info : instance of mne.Info - The measurement info object for plotting topomaps. - - Returns - ------- - sp_filter : instance of mne.viz.SpatialFilter - The spatial filter object. - """ - check_is_fitted(self, ["filters_", "patterns_"]) - sp_filter = SpatialFilter( - info, - filters=self.filters_, - patterns=self.patterns_, - patterns_method="haufe", - ) - return sp_filter - def _set_cv(cv, estimator=None, X=None, y=None): """Set the default CV depending on whether clf is classifier/regressor.""" @@ -611,7 +557,7 @@ def get_coef( inverse_transform : bool If True, returns the coefficients after inverse transforming them with the transformer steps of the estimator. - step_name : str + step_name : str | None Name of the sklearn's pipeline step to get the coef from. If inverse_transform is True, the inverse transformations will be applied using transformers before this step. diff --git a/mne/viz/decoding/ged.py b/mne/viz/decoding/ged.py index f063deae347..10040441040 100644 --- a/mne/viz/decoding/ged.py +++ b/mne/viz/decoding/ged.py @@ -7,6 +7,8 @@ import matplotlib.pyplot as plt import numpy as np +from ...decoding import get_coef +from ...decoding.base import LinearModel, _GEDTransformer from ...defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT from ...utils import _check_option, fill_doc @@ -198,6 +200,93 @@ def _plot_scree( return fig +@fill_doc +def get_spatial_filter_from_estimator( + estimator, + info, + *, + inverse_transform=False, + step_name=None, + get_coefs=("filters_", "patterns_", "evals_"), + patterns_method=None, + verbose=None, +): + """Instantiate a :class:`mne.viz.SpatialFilter` object. + + Creates object from the fitted generalized eigendecomposition + transformers or :class":`mne.decoding.LinearModel`. + This object can be used to visualize spatial filters, + patterns, and eigenvalues. + + Parameters + ---------- + estimator : instance of sklearn.BaseEstimator + Sklearn-based estimator or meta-estimator from which to initialize + spatial filter. Use ``step_name`` to select relevant transformer + from the pipeline object (works with nested names using ``__`` syntax). + info : instance of mne.Info + The measurement info object for plotting topomaps. + inverse_transform : bool + If True, returns filters and patterns after inverse transforming them with + the transformer steps of the estimator. Defaults to False. + step_name : str | None + Name of the sklearn's pipeline step to get the coefs from. + If inverse_transform is True, the inverse transformations + will be applied using transformers before this step. + If None, the last step will be used. Defaults to None. + get_coefs : tuple + The names of the coefficient attributes to retrieve, can include + ``'filters_'``, ``'patterns_'`` and ``'evals_'``. + If step is GEDTransformer, will use all. + if step is LinearModel will only use ``'filters_'`` and ``'patterns_'``. + Defaults to (``'filters_'``, ``'patterns_'``, ``'evals_'``). + patterns_method : str + The method used to compute the patterns. Can be None, ``'pinv'`` or ``'haufe'``. + It will be set automatically to ``'pinv'`` if step is GEDTransformer, + or to ``'haufe'`` if step is LinearModel. Defaults to None. + %(verbose)s + + Returns + ------- + sp_filter : instance of mne.viz.SpatialFilter + The spatial filter object. + + See Also + -------- + SpatialFilter + """ + for coef in get_coefs: + if coef not in ("filters_", "patterns_", "evals_"): + raise ValueError( + f"'get_coefs' can only include 'filters_', " + f"'patterns_' and 'evals_', but got {coef}." + ) + if step_name is not None: + model = estimator.get_params()[step_name] + else: + model = estimator + if isinstance(model, LinearModel): + patterns_method = "haufe" + get_coefs = ["filters_", "patterns_"] + elif isinstance(model, _GEDTransformer): + patterns_method = "pinv" + get_coefs = ["filters_", "patterns_", "evals_"] + + coefs = { + coef[:-1]: get_coef( + estimator, + coef, + inverse_transform=False if coef == "evals_" else inverse_transform, + step_name=step_name, + verbose=verbose, + ) + for coef in get_coefs + } + + sp_filter = SpatialFilter(info, patterns_method=patterns_method, **coefs) + return sp_filter + + class SpatialFilter: r"""Visualization container for spatial filter weights (evecs) and patterns. @@ -221,7 +310,7 @@ class SpatialFilter: The eigenvalues of the decomposition. Defaults to ``None``. patterns : ndarray, shape ((n_classes), n_components, n_channels) | None The patterns of the decomposition. If None, they will be computed - from the eigenvectors using pseudoinverse. Defaults to ``None``. + from the filters using pseudoinverse. Defaults to ``None``. patterns_method : str The method used to compute the patterns. Can be ``'pinv'`` or ``'haufe'``. If `patterns` is None, it will be set to ``'pinv'``. Defaults to ``'pinv'``. @@ -288,7 +377,7 @@ def __init__( # Perhaps mne.linalg.pinv can be improved to handle 3D also # Then it could be changed here to be consistent with # GEDTransformer - self.patterns = np.linalg.pinv(filters) + self.patterns = np.linalg.pinv(filters.T) self.patterns_method = "pinv" else: self.patterns = patterns From 7d78df6fba3c8b72dc80535a5679e182950711ba Mon Sep 17 00:00:00 2001 From: Genuster Date: Sat, 19 Jul 2025 19:01:40 +0300 Subject: [PATCH 10/23] fix tests and nesting --- mne/decoding/tests/test_base.py | 17 +++++++++-------- mne/decoding/tests/test_ged.py | 3 ++- mne/viz/decoding/ged.py | 5 +++-- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/mne/decoding/tests/test_base.py b/mne/decoding/tests/test_base.py index c6fce5deb20..21bd6ee8f34 100644 --- a/mne/decoding/tests/test_base.py +++ b/mne/decoding/tests/test_base.py @@ -50,6 +50,7 @@ ) from mne.decoding.search_light import SlidingEstimator from mne.utils import check_version +from mne.viz.decoding.ged import get_spatial_filter_from_estimator def _make_data(n_samples=1000, n_features=5, n_targets=3): @@ -228,14 +229,14 @@ def inverse_transform(self, X): assert patterns[0] != patterns_inv[0] -class _Noop(BaseEstimator, TransformerMixin): - def fit(self, X, y=None): - return self +# class _Noop(BaseEstimator, TransformerMixin): +# def fit(self, X, y=None): +# return self - def transform(self, X): - return X.copy() +# def transform(self, X): +# return X.copy() - inverse_transform = transform +# inverse_transform = transform @pytest.mark.parametrize("inverse", (True, False)) @@ -417,9 +418,9 @@ def test_linearmodel(): wrong_y = rng.rand(n, n_features, 99) clf.fit(X, wrong_y) - # check get_spatial_filter + # check get_spatial_filter_from_estimator info = create_info(n_features, 1000.0, "eeg") - sp_filter = clf.get_spatial_filter(info) + sp_filter = get_spatial_filter_from_estimator(clf, info) assert sp_filter.patterns_method == "haufe" np.testing.assert_array_equal(sp_filter.filters, clf.filters_) np.testing.assert_array_equal(sp_filter.patterns, clf.patterns_) diff --git a/mne/decoding/tests/test_ged.py b/mne/decoding/tests/test_ged.py index d407db8c2dc..bd0fc2af7cd 100644 --- a/mne/decoding/tests/test_ged.py +++ b/mne/decoding/tests/test_ged.py @@ -29,6 +29,7 @@ from mne.decoding._mod_ged import _no_op_mod from mne.decoding.base import _GEDTransformer from mne.io import read_raw +from mne.viz.decoding.ged import get_spatial_filter_from_estimator data_dir = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = data_dir / "test_raw.fif" @@ -402,7 +403,7 @@ def test_get_spatial_filter(): restr_type="restricting", ) ged.fit(X, y) - sp_filter = ged.get_spatial_filter(info) + sp_filter = get_spatial_filter_from_estimator(ged, info) assert sp_filter.patterns_method == "pinv" np.testing.assert_array_equal(sp_filter.filters, ged.filters_) np.testing.assert_array_equal(sp_filter.patterns, ged.patterns_) diff --git a/mne/viz/decoding/ged.py b/mne/viz/decoding/ged.py index 10040441040..9d0121b23e8 100644 --- a/mne/viz/decoding/ged.py +++ b/mne/viz/decoding/ged.py @@ -7,8 +7,6 @@ import matplotlib.pyplot as plt import numpy as np -from ...decoding import get_coef -from ...decoding.base import LinearModel, _GEDTransformer from ...defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT from ...utils import _check_option, fill_doc @@ -255,6 +253,9 @@ def get_spatial_filter_from_estimator( -------- SpatialFilter """ + from ...decoding import get_coef + from ...decoding.base import LinearModel, _GEDTransformer + for coef in get_coefs: if coef not in ("filters_", "patterns_", "evals_"): raise ValueError( From 874af243a777a6d31576f5e2ada2717394479cca Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 28 Jul 2025 19:44:46 +0000 Subject: [PATCH 11/23] [pre-commit.ci] pre-commit autoupdate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.12.4 → v0.12.5](https://github.com/astral-sh/ruff-pre-commit/compare/v0.12.4...v0.12.5) --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dc288a4e343..9b1d120d9ae 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ repos: # Ruff mne - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.12.4 + rev: v0.12.5 hooks: - id: ruff-check name: ruff lint mne From f3da23a26756c303a8fe4e9338810eddb757da4f Mon Sep 17 00:00:00 2001 From: Genuster Date: Tue, 29 Jul 2025 18:19:55 +0300 Subject: [PATCH 12/23] add test for get_coef with step_name --- mne/decoding/base.py | 7 +++-- mne/decoding/tests/test_base.py | 56 +++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 3 deletions(-) diff --git a/mne/decoding/base.py b/mne/decoding/base.py index b1a329d0852..d7e65869a19 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -514,8 +514,6 @@ def _get_inverse_funcs_before_step(estimator, step_name): for i, part_name in enumerate(parts): is_last_part = i == len(parts) - 1 all_names = [name for name, _ in current_pipeline.steps] - if part_name not in all_names: - raise ValueError(f"Step '{part_name}' not found.") part_idx = all_names.index(part_name) # get all preceding steps for the current step @@ -582,7 +580,10 @@ def get_coef( if step_name is not None: if not hasattr(estimator, "named_steps"): raise ValueError("'step_name' can only be used with a pipeline estimator.") - est = est.get_params(deep=True)[step_name] + try: + est = est.get_params(deep=True)[step_name] + except KeyError: + raise ValueError(f"Step '{step_name}' is not part of the pipeline.") else: while hasattr(est, "steps"): est = est.steps[-1][1] diff --git a/mne/decoding/tests/test_base.py b/mne/decoding/tests/test_base.py index 21bd6ee8f34..698e002144b 100644 --- a/mne/decoding/tests/test_base.py +++ b/mne/decoding/tests/test_base.py @@ -28,6 +28,7 @@ is_classifier, is_regressor, ) +from sklearn.decomposition import PCA from sklearn.linear_model import LinearRegression, LogisticRegression, Ridge from sklearn.model_selection import ( GridSearchCV, @@ -271,6 +272,61 @@ def test_get_coef_inverse_transform(inverse): ) +def test_get_coef_inverse_step_name(): + """Test get_coef with inverse_transform=True and a specific step_name.""" + X, y, _ = _make_data(n_samples=100, n_features=5, n_targets=1) + + # Test with a simple pipeline + pipe = make_pipeline( + StandardScaler(), PCA(n_components=3), LinearModel(Ridge(alpha=1)) + ) + pipe.fit(X, y) + + coef_inv_actual = get_coef( + pipe, attr="patterns_", inverse_transform=True, step_name="linearmodel" + ) + # Reshape your data using array.reshape(1, -1) if it contains a single sample. + coef_raw = pipe.named_steps["linearmodel"].patterns_.reshape(1, -1) + coef_inv_desired = pipe.named_steps["pca"].inverse_transform(coef_raw) + coef_inv_desired = pipe.named_steps["standardscaler"].inverse_transform( + coef_inv_desired + ) + + assert coef_inv_actual.shape == (X.shape[1],) + # Reshape your data using array.reshape(1, -1) if it contains a single sample. + assert_array_almost_equal(coef_inv_actual.reshape(1, -1), coef_inv_desired) + + # Test with a nested pipeline to check __ parsing + inner_pipe = make_pipeline(PCA(n_components=3), LinearModel(Ridge())) + nested_pipe = make_pipeline(StandardScaler(), inner_pipe) + nested_pipe.fit(X, y) + target_step_name = "pipeline__linearmodel" + coef_nested_inv_actual = get_coef( + nested_pipe, + attr="patterns_", + inverse_transform=True, + step_name=target_step_name, + ) + linearmodel = nested_pipe.named_steps["pipeline"].named_steps["linearmodel"] + pca = nested_pipe.named_steps["pipeline"].named_steps["pca"] + scaler = nested_pipe.named_steps["standardscaler"] + + coef_nested_raw = linearmodel.patterns_.reshape(1, -1) + coef_nested_inv_desired = pca.inverse_transform(coef_nested_raw) + coef_nested_inv_desired = scaler.inverse_transform(coef_nested_inv_desired) + + assert coef_nested_inv_actual.shape == (X.shape[1],) + assert_array_almost_equal( + coef_nested_inv_actual.reshape(1, -1), coef_nested_inv_desired + ) + + # Test error case + with pytest.raises(ValueError, match="i_do_not_exist"): + get_coef( + pipe, attr="patterns_", inverse_transform=True, step_name="i_do_not_exist" + ) + + @pytest.mark.parametrize("n_features", [1, 5]) @pytest.mark.parametrize("n_targets", [1, 3]) def test_get_coef_multiclass(n_features, n_targets): From 842764ed7123d49f2949475b5522c323305dd9db Mon Sep 17 00:00:00 2001 From: Genuster Date: Tue, 29 Jul 2025 19:38:34 +0300 Subject: [PATCH 13/23] add some initialziation and plotting tests --- mne/decoding/tests/test_base.py | 9 -- mne/decoding/tests/test_ged.py | 24 +---- mne/viz/__init__.pyi | 3 +- mne/viz/decoding/ged.py | 4 + mne/viz/decoding/tests/test_ged.py | 140 +++++++++++++++++++++++++++-- 5 files changed, 138 insertions(+), 42 deletions(-) diff --git a/mne/decoding/tests/test_base.py b/mne/decoding/tests/test_base.py index 698e002144b..c3f3637fad9 100644 --- a/mne/decoding/tests/test_base.py +++ b/mne/decoding/tests/test_base.py @@ -51,7 +51,6 @@ ) from mne.decoding.search_light import SlidingEstimator from mne.utils import check_version -from mne.viz.decoding.ged import get_spatial_filter_from_estimator def _make_data(n_samples=1000, n_features=5, n_targets=3): @@ -474,14 +473,6 @@ def test_linearmodel(): wrong_y = rng.rand(n, n_features, 99) clf.fit(X, wrong_y) - # check get_spatial_filter_from_estimator - info = create_info(n_features, 1000.0, "eeg") - sp_filter = get_spatial_filter_from_estimator(clf, info) - assert sp_filter.patterns_method == "haufe" - np.testing.assert_array_equal(sp_filter.filters, clf.filters_) - np.testing.assert_array_equal(sp_filter.patterns, clf.patterns_) - assert sp_filter.evals is None - def test_cross_val_multiscore(): """Test cross_val_multiscore for computing scores on decoding over time.""" diff --git a/mne/decoding/tests/test_ged.py b/mne/decoding/tests/test_ged.py index bd0fc2af7cd..5f413f81aee 100644 --- a/mne/decoding/tests/test_ged.py +++ b/mne/decoding/tests/test_ged.py @@ -29,7 +29,6 @@ from mne.decoding._mod_ged import _no_op_mod from mne.decoding.base import _GEDTransformer from mne.io import read_raw -from mne.viz.decoding.ged import get_spatial_filter_from_estimator data_dir = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = data_dir / "test_raw.fif" @@ -146,7 +145,7 @@ def test_sklearn_compliance(estimator, check): check(estimator) -def _get_X_y(event_id, return_info=False): +def _get_X_y(event_id): raw = read_raw(raw_fname, preload=False) events = read_events(event_name) picks = pick_types( @@ -167,8 +166,6 @@ def _get_X_y(event_id, return_info=False): ) X = epochs.get_data(copy=False, units=dict(eeg="uV", grad="fT/cm", mag="fT")) y = epochs.events[:, -1] - if return_info: - return X, y, epochs.info return X, y @@ -389,22 +386,3 @@ def test__no_op_mod(): assert evals is evals_no_op assert evecs is evecs_no_op assert sorter_no_op is None - - -def test_get_spatial_filter(): - """Test instantiation of spatial filter.""" - event_id = dict(aud_l=1, vis_l=3) - X, y, info = _get_X_y(event_id, return_info=True) - - ged = _GEDTransformer( - n_components=4, - cov_callable=_mock_cov_callable, - mod_ged_callable=_mock_mod_ged_callable, - restr_type="restricting", - ) - ged.fit(X, y) - sp_filter = get_spatial_filter_from_estimator(ged, info) - assert sp_filter.patterns_method == "pinv" - np.testing.assert_array_equal(sp_filter.filters, ged.filters_) - np.testing.assert_array_equal(sp_filter.patterns, ged.patterns_) - np.testing.assert_array_equal(sp_filter.evals, ged.evals_) diff --git a/mne/viz/__init__.pyi b/mne/viz/__init__.pyi index 3ec5954ba88..1078195e622 100644 --- a/mne/viz/__init__.pyi +++ b/mne/viz/__init__.pyi @@ -23,6 +23,7 @@ __all__ = [ "get_3d_backend", "get_brain_class", "get_browser_backend", + "get_spatial_filter_from_estimator", "iter_topography", "link_brains", "mne_analyze_colormap", @@ -119,7 +120,7 @@ from .backends.renderer import ( use_3d_backend, ) from .circle import circular_layout, plot_channel_labels_circle -from .decoding import SpatialFilter +from .decoding.ged import SpatialFilter, get_spatial_filter_from_estimator from .epochs import plot_drop_log, plot_epochs, plot_epochs_image, plot_epochs_psd from .evoked import ( plot_compare_evokeds, diff --git a/mne/viz/decoding/ged.py b/mne/viz/decoding/ged.py index 9d0121b23e8..c84107fb72c 100644 --- a/mne/viz/decoding/ged.py +++ b/mne/viz/decoding/ged.py @@ -9,6 +9,7 @@ from ...defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT from ...utils import _check_option, fill_doc +from ..utils import plt_show def _plot_model( @@ -607,6 +608,7 @@ def plot_scree( add_cumul_evals=True, plt_style="seaborn-v0_8-whitegrid", axes=None, + show=True, ): """Plot scree for GED eigenvalues. @@ -631,6 +633,7 @@ def plot_scree( """ if self.evals is None: raise AttributeError("Can't plot scree if eigenvalues are not provided.") + fig = _plot_scree( self.evals, title=title, @@ -638,4 +641,5 @@ def plot_scree( plt_style=plt_style, axes=axes, ) + plt_show(show, block=False) return fig diff --git a/mne/viz/decoding/tests/test_ged.py b/mne/viz/decoding/tests/test_ged.py index 379194ef5ed..a9d8ae64943 100644 --- a/mne/viz/decoding/tests/test_ged.py +++ b/mne/viz/decoding/tests/test_ged.py @@ -2,17 +2,139 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. +from collections.abc import Iterable +from pathlib import Path + +import matplotlib.pyplot as plt import numpy as np import pytest +from numpy.testing import assert_array_equal +from sklearn.linear_model import LinearRegression + +from mne import Epochs, create_info, io, pick_types, read_events +from mne.decoding import CSP, LinearModel +from mne.viz import SpatialFilter, get_spatial_filter_from_estimator + +data_dir = Path(__file__).parents[3] / "io" / "tests" / "data" +raw_fname = data_dir / "test_raw.fif" +event_name = data_dir / "test-eve.fif" +tmin, tmax = -0.1, 0.2 +event_id = dict(aud_l=1, vis_l=3) +start, stop = 0, 8 + + +def _get_X_y(event_id, return_info=False): + raw = io.read_raw(raw_fname, preload=False) + events = read_events(event_name) + picks = pick_types( + raw.info, meg=True, stim=False, ecg=False, eog=False, exclude="bads" + ) + picks = picks[2:12:3] # subselect channels -> disable proj! + raw.add_proj([], remove_existing=True) + epochs = Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + baseline=(None, 0), + preload=True, + proj=False, + ) + X = epochs.get_data(copy=False, units=dict(eeg="uV", grad="fT/cm", mag="fT")) + y = epochs.events[:, -1] + if return_info: + return X, y, epochs.info + return X, y + + +def test_spatial_filter_init(): + """Test the initialization of the SpatialFilter class.""" + # Test initialization and factory function + rng = np.random.RandomState(0) + n, n_features = 20, 3 + X = rng.rand(n, n_features) + n_targets = 5 + y = rng.rand(n, n_targets) + clf = LinearModel(LinearRegression()) + clf.fit(X, y) + + # test get_spatial_filter_from_estimator for LinearModel + info = create_info(n_features, 1000.0, "eeg") + sp_filter = get_spatial_filter_from_estimator(clf, info) + assert sp_filter.patterns_method == "haufe" + assert_array_equal(sp_filter.filters, clf.filters_) + assert_array_equal(sp_filter.patterns, clf.patterns_) + assert sp_filter.evals is None + + event_id = dict(aud_l=1, vis_l=3) + X, y, info = _get_X_y(event_id, return_info=True) + csp = CSP(n_components=4) + csp.fit(X, y) + + # test get_spatial_filter_from_estimator for GED + sp_filter = get_spatial_filter_from_estimator(csp, info) + assert sp_filter.patterns_method == "pinv" + np.testing.assert_array_equal(sp_filter.filters, csp.filters_) + np.testing.assert_array_equal(sp_filter.patterns, csp.patterns_) + np.testing.assert_array_equal(sp_filter.evals, csp.evals_) + assert sp_filter.info is info + + # test basic initialization + sp_filter = SpatialFilter( + info, filters=csp.filters_, patterns=csp.patterns_, evals=csp.evals_ + ) + assert_array_equal(sp_filter.filters, csp.filters_) + assert_array_equal(sp_filter.patterns, csp.patterns_) + assert_array_equal(sp_filter.evals, csp.evals_) + assert sp_filter.info is info + + # test automatic pattern calculation via pinv + sp_filter_pinv = SpatialFilter(info, filters=csp.filters_, evals=csp.evals_) + patterns_pinv = np.linalg.pinv(csp.filters_.T) + assert_array_equal(sp_filter_pinv.patterns, patterns_pinv) + assert sp_filter_pinv.patterns_method == "pinv" + + # test shape mismatch error + with pytest.raises(ValueError, match="Shape mismatch"): + SpatialFilter(info, filters=csp.filters_, patterns=csp.patterns_[:-1]) + + # test invalid patterns_method + with pytest.raises(ValueError, match="patterns_method"): + SpatialFilter(info, filters=csp.filters_, patterns_method="foo") + + # test n_components > n_channels error + bad_filters = np.random.randn(31, 30) # 31 components, 30 channels + with pytest.raises(ValueError, match="Number of components can't be greater"): + SpatialFilter(info, filters=bad_filters) + + +def test_spatial_filter_plotting(): + """Test the plotting methods of SpatialFilter.""" + event_id = dict(aud_l=1, vis_l=3) + X, y, info = _get_X_y(event_id, return_info=True) + csp = CSP(n_components=4) + csp.fit(X, y) + + sp_filter = get_spatial_filter_from_estimator(csp, info) + + # test plot_filters + fig_filters = sp_filter.plot_filters(components=[0, 1], show=False) + assert isinstance(fig_filters, plt.Figure | Iterable) + plt.close("all") -from mne import create_info -from mne.viz import SpatialFilter + # test plot_patterns + fig_patterns = sp_filter.plot_patterns(show=False) + assert isinstance(fig_patterns, plt.Figure | Iterable) + plt.close("all") + # test plot_scree + fig_scree = sp_filter.plot_scree(show=False) + assert isinstance(fig_scree, plt.Figure) + plt.close("all") -def test_plot_scree_raises(): - """Tests that plot_scree can't plot without evals.""" - info = create_info(2, 1000.0, "eeg") - filters = np.array([[1, 2], [3, 4]]) - sp_filter = SpatialFilter(info, filters, evals=None) - with pytest.raises(AttributeError): - sp_filter.plot_scree() + # test plot_scree raises error if evals is None + sp_filter_no_evals = SpatialFilter(info, filters=csp.filters_, evals=None) + with pytest.raises(AttributeError, match="eigenvalues are not provided"): + sp_filter_no_evals.plot_scree() From 06efd971bd75bbd5c1c22801174266db90080521 Mon Sep 17 00:00:00 2001 From: Genuster Date: Tue, 29 Jul 2025 20:25:27 +0300 Subject: [PATCH 14/23] fix docstrings --- mne/viz/decoding/ged.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/mne/viz/decoding/ged.py b/mne/viz/decoding/ged.py index c84107fb72c..a8167fe48b9 100644 --- a/mne/viz/decoding/ged.py +++ b/mne/viz/decoding/ged.py @@ -293,12 +293,12 @@ class SpatialFilter: r"""Visualization container for spatial filter weights (evecs) and patterns. This object is obtained either by generalized eigendecomposition (GED) algorithms - such as `mne.decoding.CSP`, `mne.decoding.SPoC`, `mne.decoding.SSD`, - `mne.decoding.XdawnTransformer` or by `mne.decoding.LinearModel`, - wrapping linear models like SVM or Logit. + such as :class:`mne.decoding.CSP`, :class:`mne.decoding.SPoC`, + :class:`mne.decoding.SSD`, :class:`mne.decoding.XdawnTransformer` or by + :class:`mne.decoding.LinearModel`, wrapping linear models like SVM or Logit. The object stores the filters that projects sensor data to a reduced component space, and the corresponding patterns (obtained by pseudoinverse in GED case or - Haufe's trick in case of `mne.decoding.LinearModel`). It can also be directly + Haufe's trick in case of :class:`mne.decoding.LinearModel`). It can also be directly initialized using filters from other transformers (e.g. PyRiemann), but make sure that the dimensions match. @@ -315,7 +315,7 @@ class SpatialFilter: from the filters using pseudoinverse. Defaults to ``None``. patterns_method : str The method used to compute the patterns. Can be ``'pinv'`` or ``'haufe'``. - If `patterns` is None, it will be set to ``'pinv'``. Defaults to ``'pinv'``. + If ``patterns`` is None, it will be set to ``'pinv'``. Defaults to ``'pinv'``. Attributes ---------- @@ -602,6 +602,7 @@ def plot_patterns( ) return fig + @fill_doc def plot_scree( self, title="Scree plot", @@ -625,6 +626,7 @@ def plot_scree( Defaults to ``'seaborn-v0_8-whitegrid'``. axes : instance of Axes | None The matplotlib axes to plot to. Defaults to ``None``. + %(show)s Returns ------- From 49dbdaa126f7dc1994952cb92e531af5a661645b Mon Sep 17 00:00:00 2001 From: Genuster Date: Tue, 29 Jul 2025 20:32:17 +0300 Subject: [PATCH 15/23] add factory func to doc api --- doc/api/visualization.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/api/visualization.rst b/doc/api/visualization.rst index 39381b62a99..0fba65d81d8 100644 --- a/doc/api/visualization.rst +++ b/doc/api/visualization.rst @@ -18,6 +18,7 @@ Visualization EvokedField Figure3D SpatialFilter + get_spatial_filter_from_estimator add_background_image centers_to_edges compare_fiff From 9f6b62688559c642e1138a7d2f06531318c91493 Mon Sep 17 00:00:00 2001 From: Genuster Date: Tue, 29 Jul 2025 20:59:55 +0300 Subject: [PATCH 16/23] another docstring fix and sklearn importorskip --- mne/viz/decoding/ged.py | 2 +- mne/viz/decoding/tests/test_ged.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/mne/viz/decoding/ged.py b/mne/viz/decoding/ged.py index a8167fe48b9..b6e98ea52b9 100644 --- a/mne/viz/decoding/ged.py +++ b/mne/viz/decoding/ged.py @@ -219,7 +219,7 @@ def get_spatial_filter_from_estimator( Parameters ---------- - estimator : instance of sklearn.BaseEstimator + estimator : instance of sklearn.base.BaseEstimator Sklearn-based estimator or meta-estimator from which to initialize spatial filter. Use ``step_name`` to select relevant transformer from the pipeline object (works with nested names using ``__`` syntax). diff --git a/mne/viz/decoding/tests/test_ged.py b/mne/viz/decoding/tests/test_ged.py index a9d8ae64943..20847dbd567 100644 --- a/mne/viz/decoding/tests/test_ged.py +++ b/mne/viz/decoding/tests/test_ged.py @@ -9,6 +9,9 @@ import numpy as np import pytest from numpy.testing import assert_array_equal + +pytest.importorskip("sklearn") + from sklearn.linear_model import LinearRegression from mne import Epochs, create_info, io, pick_types, read_events From 04c7d0118aca6c3c5018592901cab02445f960ae Mon Sep 17 00:00:00 2001 From: Genuster Date: Wed, 30 Jul 2025 00:16:33 +0300 Subject: [PATCH 17/23] more get_coef tests --- mne/decoding/base.py | 15 +++-------- mne/decoding/tests/test_base.py | 48 +++++++++++++++++++++++---------- 2 files changed, 37 insertions(+), 26 deletions(-) diff --git a/mne/decoding/base.py b/mne/decoding/base.py index f8a141b1312..c8a4a7cdf83 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -562,11 +562,9 @@ def _get_inverse_funcs_before_step(estimator, step_name): parts = step_name.split("__") inverse_funcs = list() current_pipeline = estimator - for i, part_name in enumerate(parts): - is_last_part = i == len(parts) - 1 + for part_name in parts: all_names = [name for name, _ in current_pipeline.steps] part_idx = all_names.index(part_name) - # get all preceding steps for the current step for prec_name, prec_step in current_pipeline.steps[:part_idx]: if hasattr(prec_step, "inverse_transform"): @@ -576,14 +574,7 @@ def _get_inverse_funcs_before_step(estimator, step_name): f"Preceding step '{prec_name}' is not invertible " f"and will be skipped." ) - - next_estimator = current_pipeline.named_steps[part_name] - # check if pipeline - if hasattr(next_estimator, "steps"): - current_pipeline = next_estimator - # if not pipeline and not last part - wrong - elif not is_last_part: - raise ValueError(f"Step '{part_name}' is not a pipeline.") + current_pipeline = current_pipeline.named_steps[part_name] return inverse_funcs @@ -630,7 +621,7 @@ def get_coef( if step_name is not None: if not hasattr(estimator, "named_steps"): - raise ValueError("'step_name' can only be used with a pipeline estimator.") + raise ValueError("step_name can only be used with a pipeline estimator.") try: est = est.get_params(deep=True)[step_name] except KeyError: diff --git a/mne/decoding/tests/test_base.py b/mne/decoding/tests/test_base.py index c3f3637fad9..d8df1bfab35 100644 --- a/mne/decoding/tests/test_base.py +++ b/mne/decoding/tests/test_base.py @@ -262,23 +262,13 @@ def test_get_coef_inverse_transform(inverse): ) assert_array_equal(filters_t, filters[:, t]) - with pytest.raises(ValueError, match=r"pipeline estimator"): - _ = get_coef( - clf, - "filters_", - inverse, - step_name="slidingestimator__pipeline__linearmodel", - ) - def test_get_coef_inverse_step_name(): """Test get_coef with inverse_transform=True and a specific step_name.""" X, y, _ = _make_data(n_samples=100, n_features=5, n_targets=1) # Test with a simple pipeline - pipe = make_pipeline( - StandardScaler(), PCA(n_components=3), LinearModel(Ridge(alpha=1)) - ) + pipe = make_pipeline(StandardScaler(), PCA(n_components=3), LinearModel(Ridge())) pipe.fit(X, y) coef_inv_actual = get_coef( @@ -295,16 +285,29 @@ def test_get_coef_inverse_step_name(): # Reshape your data using array.reshape(1, -1) if it contains a single sample. assert_array_almost_equal(coef_inv_actual.reshape(1, -1), coef_inv_desired) + with pytest.raises(ValueError, match="inverse_transform"): + _ = get_coef( + pipe[-1], # LinearModel + "filters_", + inverse_transform=True, + ) + with pytest.raises(ValueError, match="step_name"): + _ = get_coef( + SlidingEstimator(pipe), + "filters_", + inverse_transform=True, + step_name="slidingestimator__pipeline__linearmodel", + ) + # Test with a nested pipeline to check __ parsing inner_pipe = make_pipeline(PCA(n_components=3), LinearModel(Ridge())) nested_pipe = make_pipeline(StandardScaler(), inner_pipe) nested_pipe.fit(X, y) - target_step_name = "pipeline__linearmodel" coef_nested_inv_actual = get_coef( nested_pipe, attr="patterns_", inverse_transform=True, - step_name=target_step_name, + step_name="pipeline__linearmodel", ) linearmodel = nested_pipe.named_steps["pipeline"].named_steps["linearmodel"] pca = nested_pipe.named_steps["pipeline"].named_steps["pca"] @@ -319,12 +322,29 @@ def test_get_coef_inverse_step_name(): coef_nested_inv_actual.reshape(1, -1), coef_nested_inv_desired ) - # Test error case with pytest.raises(ValueError, match="i_do_not_exist"): get_coef( pipe, attr="patterns_", inverse_transform=True, step_name="i_do_not_exist" ) + class NonInvertibleTransformer(BaseEstimator, TransformerMixin): + def fit(self, X, y=None): + return self + + def transform(self, X): + # In a real scenario, this would modify X + return X + + pipe = make_pipeline(NonInvertibleTransformer(), LinearModel(Ridge())) + pipe.fit(X, y) + with pytest.warns(RuntimeWarning, match="not invertible"): + _ = get_coef( + pipe, + "filters_", + inverse_transform=True, + step_name="linearmodel", + ) + @pytest.mark.parametrize("n_features", [1, 5]) @pytest.mark.parametrize("n_targets", [1, 3]) From 16e35b796da238237549f0b8825d18e94e202978 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Tue, 29 Jul 2025 17:28:47 -0400 Subject: [PATCH 18/23] FIX: Timeout --- azure-pipelines.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index c6a36af99d0..31a0db70cec 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -244,7 +244,7 @@ stages: PYTHONIOENCODING: 'utf-8' AZURE_CI_WINDOWS: 'true' PYTHON_ARCH: 'x64' - timeoutInMinutes: 90 + timeoutInMinutes: 95 strategy: maxParallel: 4 matrix: From d355011a75fd0cde108555f00b92581ddd25bf0f Mon Sep 17 00:00:00 2001 From: Genuster Date: Wed, 30 Jul 2025 00:42:38 +0300 Subject: [PATCH 19/23] more spatial filter viz tests --- mne/viz/decoding/ged.py | 6 +++- mne/viz/decoding/tests/test_ged.py | 52 +++++++++++++++++++++++++----- 2 files changed, 49 insertions(+), 9 deletions(-) diff --git a/mne/viz/decoding/ged.py b/mne/viz/decoding/ged.py index b6e98ea52b9..5fff9a90e64 100644 --- a/mne/viz/decoding/ged.py +++ b/mne/viz/decoding/ged.py @@ -181,7 +181,11 @@ def _plot_scree( layout="constrained", ) else: - assert len(axes) == n_classes + if len(axes) != n_classes: + raise ValueError( + "Number of provided axes should be " + "equal to the number of classes" + ) fig = None for class_idx in range(n_classes): _plot_scree_per_class( diff --git a/mne/viz/decoding/tests/test_ged.py b/mne/viz/decoding/tests/test_ged.py index 20847dbd567..6ca22c54f58 100644 --- a/mne/viz/decoding/tests/test_ged.py +++ b/mne/viz/decoding/tests/test_ged.py @@ -2,7 +2,6 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. -from collections.abc import Iterable from pathlib import Path import matplotlib.pyplot as plt @@ -13,9 +12,11 @@ pytest.importorskip("sklearn") from sklearn.linear_model import LinearRegression +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler from mne import Epochs, create_info, io, pick_types, read_events -from mne.decoding import CSP, LinearModel +from mne.decoding import CSP, LinearModel, Vectorizer, XdawnTransformer from mne.viz import SpatialFilter, get_spatial_filter_from_estimator data_dir = Path(__file__).parents[3] / "io" / "tests" / "data" @@ -71,13 +72,18 @@ def test_spatial_filter_init(): assert_array_equal(sp_filter.patterns, clf.patterns_) assert sp_filter.evals is None + with pytest.raises(ValueError, match="can only include"): + _ = get_spatial_filter_from_estimator( + clf, info, get_coefs=("foo", "foo", "foo") + ) + event_id = dict(aud_l=1, vis_l=3) X, y, info = _get_X_y(event_id, return_info=True) - csp = CSP(n_components=4) - csp.fit(X, y) - + estimator = make_pipeline(Vectorizer(), StandardScaler(), CSP(n_components=4)) + estimator.fit(X, y) + csp = estimator[-1] # test get_spatial_filter_from_estimator for GED - sp_filter = get_spatial_filter_from_estimator(csp, info) + sp_filter = get_spatial_filter_from_estimator(estimator, info, step_name="csp") assert sp_filter.patterns_method == "pinv" np.testing.assert_array_equal(sp_filter.filters, csp.filters_) np.testing.assert_array_equal(sp_filter.patterns, csp.patterns_) @@ -124,20 +130,50 @@ def test_spatial_filter_plotting(): # test plot_filters fig_filters = sp_filter.plot_filters(components=[0, 1], show=False) - assert isinstance(fig_filters, plt.Figure | Iterable) + assert isinstance(fig_filters, plt.Figure) plt.close("all") # test plot_patterns fig_patterns = sp_filter.plot_patterns(show=False) - assert isinstance(fig_patterns, plt.Figure | Iterable) + assert isinstance(fig_patterns, plt.Figure) plt.close("all") # test plot_scree fig_scree = sp_filter.plot_scree(show=False) assert isinstance(fig_scree, plt.Figure) plt.close("all") + _, axes = plt.subplots(figsize=(12, 7), layout="constrained") + fig_scree = sp_filter.plot_scree(axes=axes, show=False) + assert fig_scree is None + plt.close("all") # test plot_scree raises error if evals is None sp_filter_no_evals = SpatialFilter(info, filters=csp.filters_, evals=None) with pytest.raises(AttributeError, match="eigenvalues are not provided"): sp_filter_no_evals.plot_scree() + + # 3D case ('multi' GED decomposition) + n_classes = 2 + event_id = dict(aud_l=1, vis_l=3) + X, y, info = _get_X_y(event_id, return_info=True) + xdawn = XdawnTransformer(n_components=4) + xdawn.fit(X, y) + sp_filter = get_spatial_filter_from_estimator(xdawn, info) + + fig_patterns = sp_filter.plot_patterns(show=False) + assert len(fig_patterns) == n_classes + plt.close("all") + + fig_scree = sp_filter.plot_scree(show=False) + gs = fig_scree.axes[0].get_gridspec() + assert gs.nrows == n_classes + plt.close("all") + + with pytest.raises(ValueError, match="should be equal"): + _, axes = plt.subplots(figsize=(12, 7), layout="constrained") + _ = sp_filter.plot_scree(axes=axes, show=False) + + _, axes = plt.subplots(n_classes, figsize=(12, 7), layout="constrained") + fig_scree = sp_filter.plot_scree(axes=axes, show=False) + assert fig_scree is None + plt.close("all") From 395ac48380cd70dbe4d86ffc847c4b1cba49418f Mon Sep 17 00:00:00 2001 From: Genuster Date: Wed, 30 Jul 2025 00:58:01 +0300 Subject: [PATCH 20/23] replace CSP's plots --- mne/decoding/csp.py | 43 +++++++++++------------------------------ mne/viz/decoding/ged.py | 2 +- 2 files changed, 12 insertions(+), 33 deletions(-) diff --git a/mne/decoding/csp.py b/mne/decoding/csp.py index be20b968f07..4cde939a672 100644 --- a/mne/decoding/csp.py +++ b/mne/decoding/csp.py @@ -3,19 +3,18 @@ # Copyright the MNE-Python contributors. import collections.abc as abc -import copy as cp from functools import partial import numpy as np from .._fiff.meas_info import Info from ..defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT -from ..evoked import EvokedArray from ..utils import ( _check_option, _validate_type, fill_doc, ) +from ..viz.decoding.ged import _plot_model from ._covs_ged import _csp_estimate, _spoc_estimate from ._mod_ged import _csp_mod, _spoc_mod from .base import _GEDTransformer @@ -402,20 +401,10 @@ def plot_patterns( fig : instance of matplotlib.figure.Figure The figure. """ - if units is None: - units = "AU" - if components is None: - components = np.arange(self.n_components) - - # set sampling frequency to have 1 component per time point - info = cp.deepcopy(info) - with info._unlock(): - info["sfreq"] = 1.0 - # create an evoked - patterns = EvokedArray(self.patterns_.T, info, tmin=0) - # the call plot_topomap - fig = patterns.plot_topomap( - times=components, + fig = _plot_model( + self.patterns_, + info, + components, ch_type=ch_type, scalings=scalings, sensors=sensors, @@ -437,7 +426,7 @@ def plot_patterns( cbar_fmt=cbar_fmt, units=units, axes=axes, - time_format=name_format, + name_format=name_format, nrows=nrows, ncols=ncols, show=show, @@ -530,20 +519,10 @@ def plot_filters( fig : instance of matplotlib.figure.Figure The figure. """ - if units is None: - units = "AU" - if components is None: - components = np.arange(self.n_components) - - # set sampling frequency to have 1 component per time point - info = cp.deepcopy(info) - with info._unlock(): - info["sfreq"] = 1.0 - # create an evoked - filters = EvokedArray(self.filters_.T, info, tmin=0) - # the call plot_topomap - fig = filters.plot_topomap( - times=components, + fig = _plot_model( + self.filters_, + info, + components, ch_type=ch_type, scalings=scalings, sensors=sensors, @@ -565,7 +544,7 @@ def plot_filters( cbar_fmt=cbar_fmt, units=units, axes=axes, - time_format=name_format, + name_format=name_format, nrows=nrows, ncols=ncols, show=show, diff --git a/mne/viz/decoding/ged.py b/mne/viz/decoding/ged.py index 5fff9a90e64..81a79ee4461 100644 --- a/mne/viz/decoding/ged.py +++ b/mne/viz/decoding/ged.py @@ -217,7 +217,7 @@ def get_spatial_filter_from_estimator( """Instantiate a :class:`mne.viz.SpatialFilter` object. Creates object from the fitted generalized eigendecomposition - transformers or :class":`mne.decoding.LinearModel`. + transformers or :class:`mne.decoding.LinearModel`. This object can be used to visualize spatial filters, patterns, and eigenvalues. From 4ef949cbd70eb2ae6432b0e44beaf45779144e46 Mon Sep 17 00:00:00 2001 From: Genuster Date: Wed, 30 Jul 2025 00:58:11 +0300 Subject: [PATCH 21/23] add changelog entry --- doc/changes/dev/13332.newfeature.rst | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 doc/changes/dev/13332.newfeature.rst diff --git a/doc/changes/dev/13332.newfeature.rst b/doc/changes/dev/13332.newfeature.rst new file mode 100644 index 00000000000..205c18b3810 --- /dev/null +++ b/doc/changes/dev/13332.newfeature.rst @@ -0,0 +1,4 @@ +Implement :class:`mne.viz.SpatialFilter` viz class for filters, +patterns for :class:`LinearModel`and additionally +eigenvalues for GED-based transformers such as +:class:`mne.decoding.XdawnTransformer`, by `Gennadiy Belonosov`_. \ No newline at end of file From de0fe5dabad0d759b5a4ab0c8712ad0b2e4d3f8d Mon Sep 17 00:00:00 2001 From: Genuster Date: Wed, 30 Jul 2025 01:35:51 +0300 Subject: [PATCH 22/23] fix changelog --- doc/changes/dev/13332.newfeature.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/changes/dev/13332.newfeature.rst b/doc/changes/dev/13332.newfeature.rst index 205c18b3810..93ad4c03ba6 100644 --- a/doc/changes/dev/13332.newfeature.rst +++ b/doc/changes/dev/13332.newfeature.rst @@ -1,4 +1,4 @@ Implement :class:`mne.viz.SpatialFilter` viz class for filters, -patterns for :class:`LinearModel`and additionally +patterns for :class:`mne.decoding.LinearModel` and additionally eigenvalues for GED-based transformers such as :class:`mne.decoding.XdawnTransformer`, by `Gennadiy Belonosov`_. \ No newline at end of file From 203e9852eb6768f8df13012d32154033f3fde4c7 Mon Sep 17 00:00:00 2001 From: Genuster Date: Wed, 30 Jul 2025 02:16:08 +0300 Subject: [PATCH 23/23] fix axes handling --- mne/viz/decoding/ged.py | 48 +++++++++++++----------------- mne/viz/decoding/tests/test_ged.py | 2 +- 2 files changed, 21 insertions(+), 29 deletions(-) diff --git a/mne/viz/decoding/ged.py b/mne/viz/decoding/ged.py index 81a79ee4461..65520df58cb 100644 --- a/mne/viz/decoding/ged.py +++ b/mne/viz/decoding/ged.py @@ -86,7 +86,7 @@ def _plot_model( colorbar=colorbar, cbar_fmt=cbar_fmt, units=units, - axes=axes[class_idx], + axes=axes[class_idx] if axes else None, time_format=name_format, nrows=nrows, ncols=ncols, @@ -170,34 +170,26 @@ def _plot_scree( plt_style="seaborn-v0_8-whitegrid", axes=None, ): - with plt.style.context(plt_style): - if evals.ndim == 2: - n_classes = evals.shape[0] - if axes is None: - fig, axes = plt.subplots( - nrows=n_classes, - ncols=1, - figsize=(12, 7 * n_classes), - layout="constrained", - ) - else: - if len(axes) != n_classes: - raise ValueError( - "Number of provided axes should be " - "equal to the number of classes" - ) - fig = None - for class_idx in range(n_classes): - _plot_scree_per_class( - evals[class_idx], add_cumul_evals, axes[class_idx] - ) - else: - if axes is None: - fig, axes = plt.subplots(figsize=(12, 7), layout="constrained") - else: - fig = None - _plot_scree_per_class(evals, add_cumul_evals, axes) + evals_data = evals if evals.ndim == 2 else [evals] + n_classes = len(evals_data) + axes = [axes] if isinstance(axes, plt.Axes) else axes + if axes is not None and n_classes != len(axes): + raise ValueError(f"Received {len(axes)} axes, but expected {n_classes}") + with plt.style.context(plt_style): + fig = None + if axes is None: + fig, axes = plt.subplots( + nrows=n_classes, + ncols=1, + figsize=(12, 7 * n_classes), + layout="constrained", + ) + axes = [axes] if n_classes == 1 else axes + for class_idx in range(n_classes): + _plot_scree_per_class( + evals_data[class_idx], add_cumul_evals, axes[class_idx] + ) if fig: fig.suptitle(title, fontsize=22, fontweight="bold") return fig diff --git a/mne/viz/decoding/tests/test_ged.py b/mne/viz/decoding/tests/test_ged.py index 6ca22c54f58..c8463b2c8df 100644 --- a/mne/viz/decoding/tests/test_ged.py +++ b/mne/viz/decoding/tests/test_ged.py @@ -169,7 +169,7 @@ def test_spatial_filter_plotting(): assert gs.nrows == n_classes plt.close("all") - with pytest.raises(ValueError, match="should be equal"): + with pytest.raises(ValueError, match="but expected"): _, axes = plt.subplots(figsize=(12, 7), layout="constrained") _ = sp_filter.plot_scree(axes=axes, show=False)