Skip to content

ENH: Viz for spatial filters #13332

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 28 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
0765454
topomap and scree plots
Genuster Jul 14, 2025
5a7f5bd
Add SpatialFilter visualization class
Genuster Jul 15, 2025
614cd1b
Some doc-related fixes
Genuster Jul 15, 2025
c3feb72
more imrovements
Genuster Jul 18, 2025
3266d16
nest EvokedArray import
Genuster Jul 18, 2025
d720dc2
add some minor tests and fixes
Genuster Jul 18, 2025
d2fdc2d
make get_coef to work with an arbitrary named step
Genuster Jul 19, 2025
4ac29d5
clean get_coef's inverse_transform test
Genuster Jul 19, 2025
37cd069
make instantiation of the spatial filter from ged/linear model a stan…
Genuster Jul 19, 2025
7d78df6
fix tests and nesting
Genuster Jul 19, 2025
874af24
[pre-commit.ci] pre-commit autoupdate
pre-commit-ci[bot] Jul 28, 2025
f3da23a
add test for get_coef with step_name
Genuster Jul 29, 2025
842764e
add some initialziation and plotting tests
Genuster Jul 29, 2025
8f8336b
Merge remote-tracking branch 'upstream/main' into ged-viz
Genuster Jul 29, 2025
06efd97
fix docstrings
Genuster Jul 29, 2025
49dbdaa
add factory func to doc api
Genuster Jul 29, 2025
9f6b626
another docstring fix and sklearn importorskip
Genuster Jul 29, 2025
04c7d01
more get_coef tests
Genuster Jul 29, 2025
16e35b7
FIX: Timeout
larsoner Jul 29, 2025
e5d300a
Merge remote-tracking branch 'upstream/main' into pre-commit-ci-updat…
larsoner Jul 29, 2025
d355011
more spatial filter viz tests
Genuster Jul 29, 2025
395ac48
replace CSP's plots
Genuster Jul 29, 2025
4ef949c
add changelog entry
Genuster Jul 29, 2025
05450d4
Merge remote-tracking branch 'upstream/pre-commit-ci-update-config' i…
Genuster Jul 29, 2025
de0fe5d
fix changelog
Genuster Jul 29, 2025
203e985
fix axes handling
Genuster Jul 29, 2025
f3748c7
Merge branch 'main' into ged-viz
Genuster Jul 30, 2025
4839d51
Merge branch 'main' into ged-viz
Genuster Aug 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/_includes/ged.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ This section describes the mathematical formulation and application of
Generalized Eigendecomposition (GED), often used in spatial filtering
and source separation algorithms, such as :class:`mne.decoding.CSP`,
:class:`mne.decoding.SPoC`, :class:`mne.decoding.SSD` and
:class:`mne.preprocessing.Xdawn`.
:class:`mne.decoding.XdawnTransformer`.

The core principle of GED is to find a set of channel weights (spatial filter)
that maximizes the ratio of signal power between two data features.
Expand Down
2 changes: 2 additions & 0 deletions doc/api/visualization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ Visualization
ClickableImage
EvokedField
Figure3D
SpatialFilter
get_spatial_filter_from_estimator
add_background_image
centers_to_edges
compare_fiff
Expand Down
4 changes: 4 additions & 0 deletions doc/changes/dev/13332.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Implement :class:`mne.viz.SpatialFilter` viz class for filters,
patterns for :class:`mne.decoding.LinearModel` and additionally
eigenvalues for GED-based transformers such as
:class:`mne.decoding.XdawnTransformer`, by `Gennadiy Belonosov`_.
53 changes: 49 additions & 4 deletions mne/decoding/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,8 +556,32 @@ def _get_inverse_funcs(estimator, terminal=True):
return inverse_func


def _get_inverse_funcs_before_step(estimator, step_name):
"""Get the inverse_transform methods for all steps before a target step."""
# in case step_name is nested with __
parts = step_name.split("__")
inverse_funcs = list()
current_pipeline = estimator
for part_name in parts:
all_names = [name for name, _ in current_pipeline.steps]
part_idx = all_names.index(part_name)
# get all preceding steps for the current step
for prec_name, prec_step in current_pipeline.steps[:part_idx]:
if hasattr(prec_step, "inverse_transform"):
inverse_funcs.append(prec_step.inverse_transform)
else:
warn(
f"Preceding step '{prec_name}' is not invertible "
f"and will be skipped."
)
current_pipeline = current_pipeline.named_steps[part_name]
return inverse_funcs


@verbose
def get_coef(estimator, attr="filters_", inverse_transform=False, *, verbose=None):
def get_coef(
estimator, attr="filters_", inverse_transform=False, *, step_name=None, verbose=None
):
"""Retrieve the coefficients of an estimator ending with a Linear Model.

This is typically useful to retrieve "spatial filters" or "spatial
Expand All @@ -573,6 +597,13 @@ def get_coef(estimator, attr="filters_", inverse_transform=False, *, verbose=Non
inverse_transform : bool
If True, returns the coefficients after inverse transforming them with
the transformer steps of the estimator.
step_name : str | None
Name of the sklearn's pipeline step to get the coef from.
If inverse_transform is True, the inverse transformations
will be applied using transformers before this step.
If None, the last step will be used. Defaults to None.

.. versionadded:: 1.11
%(verbose)s

Returns
Expand All @@ -587,8 +618,17 @@ def get_coef(estimator, attr="filters_", inverse_transform=False, *, verbose=Non
# Get the coefficients of the last estimator in case of nested pipeline
est = estimator
logger.debug(f"Getting coefficients from estimator: {est.__class__.__name__}")
while hasattr(est, "steps"):
est = est.steps[-1][1]

if step_name is not None:
if not hasattr(estimator, "named_steps"):
raise ValueError("step_name can only be used with a pipeline estimator.")
try:
est = est.get_params(deep=True)[step_name]
except KeyError:
raise ValueError(f"Step '{step_name}' is not part of the pipeline.")
else:
while hasattr(est, "steps"):
est = est.steps[-1][1]

squeeze_first_dim = False

Expand Down Expand Up @@ -617,9 +657,14 @@ def get_coef(estimator, attr="filters_", inverse_transform=False, *, verbose=Non
raise ValueError(
"inverse_transform can only be applied onto pipeline estimators."
)
if step_name is None:
inverse_funcs = _get_inverse_funcs(estimator)
else:
inverse_funcs = _get_inverse_funcs_before_step(estimator, step_name)

# The inverse_transform parameter will call this method on any
# estimator contained in the pipeline, in reverse order.
for inverse_func in _get_inverse_funcs(estimator)[::-1]:
for inverse_func in inverse_funcs[::-1]:
logger.debug(f" Applying inverse transformation: {inverse_func}.")
coef = inverse_func(coef)

Expand Down
43 changes: 11 additions & 32 deletions mne/decoding/csp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,18 @@
# Copyright the MNE-Python contributors.

import collections.abc as abc
import copy as cp
from functools import partial

import numpy as np

from .._fiff.meas_info import Info
from ..defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT
from ..evoked import EvokedArray
from ..utils import (
_check_option,
_validate_type,
fill_doc,
)
from ..viz.decoding.ged import _plot_model
from ._covs_ged import _csp_estimate, _spoc_estimate
from ._mod_ged import _csp_mod, _spoc_mod
from .base import _GEDTransformer
Expand Down Expand Up @@ -402,20 +401,10 @@ def plot_patterns(
fig : instance of matplotlib.figure.Figure
The figure.
"""
if units is None:
units = "AU"
if components is None:
components = np.arange(self.n_components)

# set sampling frequency to have 1 component per time point
info = cp.deepcopy(info)
with info._unlock():
info["sfreq"] = 1.0
# create an evoked
patterns = EvokedArray(self.patterns_.T, info, tmin=0)
# the call plot_topomap
fig = patterns.plot_topomap(
times=components,
fig = _plot_model(
self.patterns_,
info,
components,
ch_type=ch_type,
scalings=scalings,
sensors=sensors,
Expand All @@ -437,7 +426,7 @@ def plot_patterns(
cbar_fmt=cbar_fmt,
units=units,
axes=axes,
time_format=name_format,
name_format=name_format,
nrows=nrows,
ncols=ncols,
show=show,
Expand Down Expand Up @@ -530,20 +519,10 @@ def plot_filters(
fig : instance of matplotlib.figure.Figure
The figure.
"""
if units is None:
units = "AU"
if components is None:
components = np.arange(self.n_components)

# set sampling frequency to have 1 component per time point
info = cp.deepcopy(info)
with info._unlock():
info["sfreq"] = 1.0
# create an evoked
filters = EvokedArray(self.filters_.T, info, tmin=0)
# the call plot_topomap
fig = filters.plot_topomap(
times=components,
fig = _plot_model(
self.filters_,
info,
components,
ch_type=ch_type,
scalings=scalings,
sensors=sensors,
Expand All @@ -565,7 +544,7 @@ def plot_filters(
cbar_fmt=cbar_fmt,
units=units,
axes=axes,
time_format=name_format,
name_format=name_format,
nrows=nrows,
ncols=ncols,
show=show,
Expand Down
119 changes: 97 additions & 22 deletions mne/decoding/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
is_classifier,
is_regressor,
)
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression, LogisticRegression, Ridge
from sklearn.model_selection import (
GridSearchCV,
Expand Down Expand Up @@ -228,35 +229,24 @@ def inverse_transform(self, X):
assert patterns[0] != patterns_inv[0]


class _Noop(BaseEstimator, TransformerMixin):
def fit(self, X, y=None):
return self
# class _Noop(BaseEstimator, TransformerMixin):
# def fit(self, X, y=None):
# return self

def transform(self, X):
return X.copy()
# def transform(self, X):
# return X.copy()

inverse_transform = transform
# inverse_transform = transform


@pytest.mark.parametrize("inverse", (True, False))
@pytest.mark.parametrize(
"Scale, kwargs",
[
(Scaler, dict(info=None, scalings="mean")),
(_Noop, dict()),
],
)
def test_get_coef_inverse_transform(inverse, Scale, kwargs):
def test_get_coef_inverse_transform(inverse):
"""Test get_coef with and without inverse_transform."""
lm_regression = LinearModel(Ridge())
X, y, A = _make_data(n_samples=1000, n_features=3, n_targets=1)
# Check with search_light and combination of preprocessing ending with sl:
# slider = SlidingEstimator(make_pipeline(StandardScaler(), lm_regression))
# XXX : line above should work but does not as only last step is
# used in get_coef ...
slider = SlidingEstimator(make_pipeline(lm_regression))
clf = SlidingEstimator(make_pipeline(StandardScaler(), lm_regression))
X = np.transpose([X, -X], [1, 2, 0]) # invert X across 2 time samples
clf = make_pipeline(Scale(**kwargs), slider)
clf.fit(X, y)
patterns = get_coef(clf, "patterns_", inverse)
filters = get_coef(clf, "filters_", inverse)
Expand All @@ -265,10 +255,95 @@ def test_get_coef_inverse_transform(inverse, Scale, kwargs):
assert_equal(patterns[0, 0], -patterns[0, 1])
for t in [0, 1]:
filters_t = get_coef(
clf.named_steps["slidingestimator"].estimators_[t], "filters_", False
clf.estimators_[t],
"filters_",
inverse,
verbose=False,
)
assert_array_equal(filters_t, filters[:, t])


def test_get_coef_inverse_step_name():
"""Test get_coef with inverse_transform=True and a specific step_name."""
X, y, _ = _make_data(n_samples=100, n_features=5, n_targets=1)

# Test with a simple pipeline
pipe = make_pipeline(StandardScaler(), PCA(n_components=3), LinearModel(Ridge()))
pipe.fit(X, y)

coef_inv_actual = get_coef(
pipe, attr="patterns_", inverse_transform=True, step_name="linearmodel"
)
# Reshape your data using array.reshape(1, -1) if it contains a single sample.
coef_raw = pipe.named_steps["linearmodel"].patterns_.reshape(1, -1)
coef_inv_desired = pipe.named_steps["pca"].inverse_transform(coef_raw)
coef_inv_desired = pipe.named_steps["standardscaler"].inverse_transform(
coef_inv_desired
)

assert coef_inv_actual.shape == (X.shape[1],)
# Reshape your data using array.reshape(1, -1) if it contains a single sample.
assert_array_almost_equal(coef_inv_actual.reshape(1, -1), coef_inv_desired)

with pytest.raises(ValueError, match="inverse_transform"):
_ = get_coef(
pipe[-1], # LinearModel
"filters_",
inverse_transform=True,
)
with pytest.raises(ValueError, match="step_name"):
_ = get_coef(
SlidingEstimator(pipe),
"filters_",
inverse_transform=True,
step_name="slidingestimator__pipeline__linearmodel",
)

# Test with a nested pipeline to check __ parsing
inner_pipe = make_pipeline(PCA(n_components=3), LinearModel(Ridge()))
nested_pipe = make_pipeline(StandardScaler(), inner_pipe)
nested_pipe.fit(X, y)
coef_nested_inv_actual = get_coef(
nested_pipe,
attr="patterns_",
inverse_transform=True,
step_name="pipeline__linearmodel",
)
linearmodel = nested_pipe.named_steps["pipeline"].named_steps["linearmodel"]
pca = nested_pipe.named_steps["pipeline"].named_steps["pca"]
scaler = nested_pipe.named_steps["standardscaler"]

coef_nested_raw = linearmodel.patterns_.reshape(1, -1)
coef_nested_inv_desired = pca.inverse_transform(coef_nested_raw)
coef_nested_inv_desired = scaler.inverse_transform(coef_nested_inv_desired)

assert coef_nested_inv_actual.shape == (X.shape[1],)
assert_array_almost_equal(
coef_nested_inv_actual.reshape(1, -1), coef_nested_inv_desired
)

with pytest.raises(ValueError, match="i_do_not_exist"):
get_coef(
pipe, attr="patterns_", inverse_transform=True, step_name="i_do_not_exist"
)

class NonInvertibleTransformer(BaseEstimator, TransformerMixin):
def fit(self, X, y=None):
return self

def transform(self, X):
# In a real scenario, this would modify X
return X

pipe = make_pipeline(NonInvertibleTransformer(), LinearModel(Ridge()))
pipe.fit(X, y)
with pytest.warns(RuntimeWarning, match="not invertible"):
_ = get_coef(
pipe,
"filters_",
inverse_transform=True,
step_name="linearmodel",
)
if Scale is _Noop:
assert_array_equal(filters_t, filters[:, t])


@pytest.mark.parametrize("n_features", [1, 5])
Expand Down
3 changes: 3 additions & 0 deletions mne/viz/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ __all__ = [
"ClickableImage",
"EvokedField",
"Figure3D",
"SpatialFilter",
"_RAW_CLIP_DEF",
"_get_plot_ch_type",
"_get_presser",
Expand All @@ -22,6 +23,7 @@ __all__ = [
"get_3d_backend",
"get_brain_class",
"get_browser_backend",
"get_spatial_filter_from_estimator",
"iter_topography",
"link_brains",
"mne_analyze_colormap",
Expand Down Expand Up @@ -118,6 +120,7 @@ from .backends.renderer import (
use_3d_backend,
)
from .circle import circular_layout, plot_channel_labels_circle
from .decoding.ged import SpatialFilter, get_spatial_filter_from_estimator
from .epochs import plot_drop_log, plot_epochs, plot_epochs_image, plot_epochs_psd
from .evoked import (
plot_compare_evokeds,
Expand Down
6 changes: 6 additions & 0 deletions mne/viz/decoding/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

"""Decoding visualization routines."""
from .ged import SpatialFilter
Loading
Loading