From ad903f985f1dae0e0e6cc0240f9452acc157d78d Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Sat, 19 Jul 2025 19:04:41 +0200 Subject: [PATCH 1/2] Refactor TransformChain X5 loading --- CHANGES.rst | 1 + nitransforms/manip.py | 20 ++++++++++++++++++++ nitransforms/tests/test_x5.py | 14 ++++++++++++++ 3 files changed, 35 insertions(+) diff --git a/CHANGES.rst b/CHANGES.rst index 44579977..4608bc16 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -9,6 +9,7 @@ CHANGES ------- * FIX: Broken 4D resampling by @oesteban in https://github.com/nipy/nitransforms/pull/247 * ENH: Loading of X5 (linear) transforms by @oesteban in https://github.com/nipy/nitransforms/pull/243 +* ENH: Add X5 support to transform chains by @oesteban * ENH: Implement X5 representation and output to filesystem by @oesteban in https://github.com/nipy/nitransforms/pull/241 * DOC: Fix references to ``os.PathLike`` by @oesteban in https://github.com/nipy/nitransforms/pull/242 * MNT: Increase coverage by testing edge cases and adding docstrings by @oesteban in https://github.com/nipy/nitransforms/pull/248 diff --git a/nitransforms/manip.py b/nitransforms/manip.py index 9389197d..21eee2e8 100644 --- a/nitransforms/manip.py +++ b/nitransforms/manip.py @@ -193,6 +193,8 @@ def asaffine(self, indices=None): def from_filename(cls, filename, fmt="X5", reference=None, moving=None): """Load a transform file.""" from .io import itk + from .io.x5 import from_filename as load_x5 + from . import linear as nitl retval = [] if str(filename).endswith(".h5"): @@ -206,6 +208,24 @@ def from_filename(cls, filename, fmt="X5", reference=None, moving=None): return TransformChain(retval) + if fmt.upper() == "X5" or str(filename).endswith(".x5"): + for i, x5_xfm in enumerate(load_x5(filename)): + if x5_xfm.type != "linear": + raise NotImplementedError( + "Only linear X5 transforms are currently supported" + ) + + xfm = nitl.Affine.from_filename( + filename, + fmt="X5", + reference=reference, + moving=moving, + x5_position=i, + ) + retval.append(xfm) + + return TransformChain(retval) + raise NotImplementedError diff --git a/nitransforms/tests/test_x5.py b/nitransforms/tests/test_x5.py index 89b49e06..c47a37d0 100644 --- a/nitransforms/tests/test_x5.py +++ b/nitransforms/tests/test_x5.py @@ -3,6 +3,8 @@ from h5py import File as H5File from ..io.x5 import X5Transform, X5Domain, to_filename, from_filename +from nitransforms import linear as nitl +from nitransforms import manip as nitm def test_x5_transform_defaults(): @@ -75,3 +77,15 @@ def test_from_filename_invalid(tmp_path): with pytest.raises(TypeError): from_filename(fname) + + +def test_transformchain_from_x5(tmp_path): + aff1 = nitl.Affine.from_matvec(vec=(1, 2, 3)) + aff2 = nitl.Affine.from_matvec(vec=(-1, -2, -3)) + fname = tmp_path / "chain.x5" + to_filename(fname, [aff1.to_x5(), aff2.to_x5()]) + + chain = nitm.TransformChain.from_filename(fname, fmt="X5") + assert len(chain.transforms) == 2 + assert chain.transforms[0] == aff1 + assert chain.transforms[1] == aff2 From 8e7abcc80c75c8bcd4b9a7b01b6babfa0ef633f1 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Sun, 20 Jul 2025 08:09:25 +0200 Subject: [PATCH 2/2] Support nonlinear steps when loading X5 chains --- CHANGES.rst | 1 - nitransforms/manip.py | 23 +++++++++++++---------- nitransforms/nonlinear.py | 21 +++++++++++++++++++++ nitransforms/tests/test_x5.py | 22 ++++++++++++++++++++++ 4 files changed, 56 insertions(+), 11 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 4608bc16..44579977 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -9,7 +9,6 @@ CHANGES ------- * FIX: Broken 4D resampling by @oesteban in https://github.com/nipy/nitransforms/pull/247 * ENH: Loading of X5 (linear) transforms by @oesteban in https://github.com/nipy/nitransforms/pull/243 -* ENH: Add X5 support to transform chains by @oesteban * ENH: Implement X5 representation and output to filesystem by @oesteban in https://github.com/nipy/nitransforms/pull/241 * DOC: Fix references to ``os.PathLike`` by @oesteban in https://github.com/nipy/nitransforms/pull/242 * MNT: Increase coverage by testing edge cases and adding docstrings by @oesteban in https://github.com/nipy/nitransforms/pull/248 diff --git a/nitransforms/manip.py b/nitransforms/manip.py index 21eee2e8..5ff9d534 100644 --- a/nitransforms/manip.py +++ b/nitransforms/manip.py @@ -195,6 +195,7 @@ def from_filename(cls, filename, fmt="X5", reference=None, moving=None): from .io import itk from .io.x5 import from_filename as load_x5 from . import linear as nitl + from . import nonlinear as nitn retval = [] if str(filename).endswith(".h5"): @@ -210,18 +211,20 @@ def from_filename(cls, filename, fmt="X5", reference=None, moving=None): if fmt.upper() == "X5" or str(filename).endswith(".x5"): for i, x5_xfm in enumerate(load_x5(filename)): - if x5_xfm.type != "linear": + if x5_xfm.type == "linear": + xfm = nitl.Affine.from_filename( + filename, + fmt="X5", + reference=reference, + moving=moving, + x5_position=i, + ) + elif x5_xfm.type == "nonlinear": + xfm = nitn.DenseFieldTransform.from_x5(x5_xfm) + else: raise NotImplementedError( - "Only linear X5 transforms are currently supported" + f"Unsupported X5 transform type {x5_xfm.type}" ) - - xfm = nitl.Affine.from_filename( - filename, - fmt="X5", - reference=reference, - moving=moving, - x5_position=i, - ) retval.append(xfm) return TransformChain(retval) diff --git a/nitransforms/nonlinear.py b/nitransforms/nonlinear.py index 9c29c53c..a451cb82 100644 --- a/nitransforms/nonlinear.py +++ b/nitransforms/nonlinear.py @@ -10,6 +10,7 @@ import warnings from functools import partial import numpy as np +import nibabel as nb from nitransforms import io from nitransforms.io.base import _ensure_image @@ -239,6 +240,26 @@ def from_filename(cls, filename, fmt="X5"): return cls(_factory[fmt].from_filename(filename)) + @classmethod + def from_x5(cls, x5struct): + """Instantiate a dense field transform from an :class:`X5Transform`.""" + if x5struct.type != "nonlinear": + raise TypeError("X5 structure is not a nonlinear transform") + if not x5struct.domain or not x5struct.domain.grid: + raise NotImplementedError( + "Only regularly gridded nonlinear X5 transforms are supported" + ) + + hdr = nb.Nifti1Header() + hdr.set_intent("vector") + img = nb.Nifti1Image(x5struct.transform.astype("float32"), x5struct.domain.mapping, hdr) + + is_deltas = True + if x5struct.representation and "def" in x5struct.representation.lower(): + is_deltas = False + + return cls(img, is_deltas=is_deltas) + load = DenseFieldTransform.from_filename diff --git a/nitransforms/tests/test_x5.py b/nitransforms/tests/test_x5.py index c47a37d0..6c0603b9 100644 --- a/nitransforms/tests/test_x5.py +++ b/nitransforms/tests/test_x5.py @@ -5,6 +5,7 @@ from ..io.x5 import X5Transform, X5Domain, to_filename, from_filename from nitransforms import linear as nitl from nitransforms import manip as nitm +from nitransforms import nonlinear as nitn def test_x5_transform_defaults(): @@ -89,3 +90,24 @@ def test_transformchain_from_x5(tmp_path): assert len(chain.transforms) == 2 assert chain.transforms[0] == aff1 assert chain.transforms[1] == aff2 + + +def test_transformchain_from_x5_nonlinear(tmp_path): + field = np.zeros((2, 2, 2, 3), dtype=float) + domain = X5Domain(grid=True, size=(2, 2, 2), mapping=np.eye(4)) + nonlinear_node = X5Transform( + type="nonlinear", + transform=field, + representation="dense_field", + dimension_kinds=("space", "space", "space", "vector"), + domain=domain, + ) + aff = nitl.Affine.from_matvec(vec=(0, 0, 0)) + fname = tmp_path / "nonlinear_chain.x5" + to_filename(fname, [aff.to_x5(), nonlinear_node]) + + chain = nitm.TransformChain.from_filename(fname, fmt="X5") + assert len(chain.transforms) == 2 + assert chain.transforms[0] == aff + assert isinstance(chain.transforms[1], nitn.DenseFieldTransform) + assert chain.transforms[1].reference.shape == (2, 2, 2)