From 896d46b820512e51293917ea80cb9b23c59447a1 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Sat, 19 Jul 2025 18:56:18 +0200 Subject: [PATCH 1/4] Add nonlinear X5 tests --- CHANGES.rst | 2 + nitransforms/nonlinear.py | 75 +++++++++++++++++++++++++++- nitransforms/tests/test_nonlinear.py | 45 +++++++++++++++++ 3 files changed, 120 insertions(+), 2 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 44579977..94631b4c 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,12 +4,14 @@ A new major release with critical updates. The new release includes a critical hotfix for 4D resamplings. The second major improvement is the inclusion of a first implementation of the X5 format (BIDS). The X5 implementation is currently restricted to reading/writing of linear transforms. +It now supports nonlinear transforms as well. CHANGES ------- * FIX: Broken 4D resampling by @oesteban in https://github.com/nipy/nitransforms/pull/247 * ENH: Loading of X5 (linear) transforms by @oesteban in https://github.com/nipy/nitransforms/pull/243 * ENH: Implement X5 representation and output to filesystem by @oesteban in https://github.com/nipy/nitransforms/pull/241 +* ENH: Support reading and writing of nonlinear transforms in X5 * DOC: Fix references to ``os.PathLike`` by @oesteban in https://github.com/nipy/nitransforms/pull/242 * MNT: Increase coverage by testing edge cases and adding docstrings by @oesteban in https://github.com/nipy/nitransforms/pull/248 * MNT: Refactor io/lta to reduce one partial line by @oesteban in https://github.com/nipy/nitransforms/pull/246 diff --git a/nitransforms/nonlinear.py b/nitransforms/nonlinear.py index 9c29c53c..d752cffa 100644 --- a/nitransforms/nonlinear.py +++ b/nitransforms/nonlinear.py @@ -9,6 +9,7 @@ """Nonlinear transforms.""" import warnings from functools import partial +from collections import namedtuple import numpy as np from nitransforms import io @@ -227,17 +228,54 @@ def __eq__(self, other): warnings.warn("Fields are equal, but references do not match.") return _eq + def to_x5(self, metadata=None): + """Return an :class:`~nitransforms.io.x5.X5Transform` representation.""" + from ._version import __version__ + from .io.x5 import X5Domain, X5Transform + + metadata = {"WrittenBy": f"NiTransforms {__version__}"} | (metadata or {}) + + domain = None + if (reference := self.reference) is not None: + domain = X5Domain( + grid=True, + size=getattr(reference, "shape", (0, 0, 0)), + mapping=reference.affine, + coordinates="cartesian", + ) + + kinds = tuple("space" for _ in range(self.ndim)) + ("vector",) + + return X5Transform( + type="nonlinear", + subtype="densefield", + representation="displacements", + metadata=metadata, + transform=self._deltas, + dimension_kinds=kinds, + domain=domain, + ) + @classmethod def from_filename(cls, filename, fmt="X5"): _factory = { "afni": io.afni.AFNIDisplacementsField, "itk": io.itk.ITKDisplacementsField, "fsl": io.fsl.FSLDisplacementsField, + "X5": None, } - if fmt not in _factory: + fmt = fmt.upper() + if fmt not in {k.upper() for k in _factory}: raise NotImplementedError(f"Unsupported format <{fmt}>") - return cls(_factory[fmt].from_filename(filename)) + if fmt == "X5": + from .io.x5 import from_filename as load_x5 + x5_xfm = load_x5(filename)[0] + Domain = namedtuple("Domain", "affine shape") + reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size) + return cls(x5_xfm.transform, is_deltas=True, reference=reference) + + return cls(_factory[fmt.lower()].from_filename(filename)) load = DenseFieldTransform.from_filename @@ -293,6 +331,39 @@ def to_field(self, reference=None, dtype="float32"): field.astype(dtype).reshape(*_ref.shape, -1), reference=_ref ) + def to_x5(self, metadata=None): + """Return an :class:`~nitransforms.io.x5.X5Transform` representation.""" + from ._version import __version__ + from .io.x5 import X5Transform, X5Domain + + metadata = {"WrittenBy": f"NiTransforms {__version__}"} | (metadata or {}) + + domain = None + if (reference := self.reference) is not None: + domain = X5Domain( + grid=True, + size=getattr(reference, "shape", (0, 0, 0)), + mapping=reference.affine, + coordinates="cartesian", + ) + + meta = metadata | { + "KnotsAffine": self._knots.affine.tolist(), + "KnotsShape": self._knots.shape, + } + + kinds = tuple("space" for _ in range(self.ndim)) + ("vector",) + + return X5Transform( + type="nonlinear", + subtype="bspline", + representation="coefficients", + metadata=meta, + transform=self._coeffs, + dimension_kinds=kinds, + domain=domain, + ) + def map(self, x, inverse=False): r""" Apply the transformation to a list of physical coordinate points. diff --git a/nitransforms/tests/test_nonlinear.py b/nitransforms/tests/test_nonlinear.py index 6112f633..d2ba4bd7 100644 --- a/nitransforms/tests/test_nonlinear.py +++ b/nitransforms/tests/test_nonlinear.py @@ -14,6 +14,7 @@ BSplineFieldTransform, DenseFieldTransform, ) +from nitransforms import io from ..io.itk import ITKDisplacementsField @@ -119,3 +120,47 @@ def test_bspline(tmp_path, testdata_path): ).mean() < 0.2 ) + + +def test_densefield_x5_roundtrip(tmp_path): + """Ensure dense field transforms roundtrip via X5.""" + ref = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="uint8"), np.eye(4)) + disp = nb.Nifti1Image(np.random.rand(2, 2, 2, 3).astype("float32"), np.eye(4)) + + xfm = DenseFieldTransform(disp, reference=ref) + + node = xfm.to_x5(metadata={"GeneratedBy": "pytest"}) + assert node.type == "nonlinear" + assert node.subtype == "densefield" + assert node.representation == "displacements" + assert node.domain.size == ref.shape + assert node.metadata["GeneratedBy"] == "pytest" + + fname = tmp_path / "test.x5" + io.x5.to_filename(fname, [node]) + + xfm2 = DenseFieldTransform.from_filename(fname, fmt="X5") + diff = xfm2._deltas - xfm._deltas + coords = xfm.reference.ndcoords.T.reshape(xfm._deltas.shape) + assert np.allclose(diff, coords) + assert xfm2.reference.shape == ref.shape + assert np.allclose(xfm2.reference.affine, ref.affine) + + +def test_bspline_to_x5(tmp_path): + """Check BSpline transforms export to X5.""" + coeff = nb.Nifti1Image(np.zeros((2, 2, 2, 3), dtype="float32"), np.eye(4)) + ref = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="uint8"), np.eye(4)) + + xfm = BSplineFieldTransform(coeff, reference=ref) + node = xfm.to_x5(metadata={"tool": "pytest"}) + assert node.type == "nonlinear" + assert node.subtype == "bspline" + assert node.representation == "coefficients" + assert node.metadata["tool"] == "pytest" + + fname = tmp_path / "bspline.x5" + io.x5.to_filename(fname, [node]) + node2 = io.x5.from_filename(fname)[0] + assert np.allclose(node2.transform, node.transform) + assert node2.metadata["tool"] == "pytest" From 7347e87b18825ec310deac2e929117882c1d438f Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Sat, 19 Jul 2025 19:04:10 +0200 Subject: [PATCH 2/4] DOC: revert CHANGES entry --- CHANGES.rst | 2 -- 1 file changed, 2 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 94631b4c..44579977 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,14 +4,12 @@ A new major release with critical updates. The new release includes a critical hotfix for 4D resamplings. The second major improvement is the inclusion of a first implementation of the X5 format (BIDS). The X5 implementation is currently restricted to reading/writing of linear transforms. -It now supports nonlinear transforms as well. CHANGES ------- * FIX: Broken 4D resampling by @oesteban in https://github.com/nipy/nitransforms/pull/247 * ENH: Loading of X5 (linear) transforms by @oesteban in https://github.com/nipy/nitransforms/pull/243 * ENH: Implement X5 representation and output to filesystem by @oesteban in https://github.com/nipy/nitransforms/pull/241 -* ENH: Support reading and writing of nonlinear transforms in X5 * DOC: Fix references to ``os.PathLike`` by @oesteban in https://github.com/nipy/nitransforms/pull/242 * MNT: Increase coverage by testing edge cases and adding docstrings by @oesteban in https://github.com/nipy/nitransforms/pull/248 * MNT: Refactor io/lta to reduce one partial line by @oesteban in https://github.com/nipy/nitransforms/pull/246 From aa973a37571d5a10552d1341145095cb57f6f0ce Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Sun, 20 Jul 2025 08:27:09 +0200 Subject: [PATCH 3/4] enh: move lazy imports to top of nonlinear file --- nitransforms/nonlinear.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/nitransforms/nonlinear.py b/nitransforms/nonlinear.py index d752cffa..9af12b98 100644 --- a/nitransforms/nonlinear.py +++ b/nitransforms/nonlinear.py @@ -7,6 +7,7 @@ # ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## """Nonlinear transforms.""" + import warnings from functools import partial from collections import namedtuple @@ -23,6 +24,12 @@ ) from scipy.ndimage import map_coordinates +# Avoids circular imports +try: + from nitransforms._version import __version__ +except ModuleNotFoundError: # pragma: no cover + __version__ = "0+unknown" + class DenseFieldTransform(TransformBase): """Represents dense field (voxel-wise) transforms.""" @@ -230,14 +237,11 @@ def __eq__(self, other): def to_x5(self, metadata=None): """Return an :class:`~nitransforms.io.x5.X5Transform` representation.""" - from ._version import __version__ - from .io.x5 import X5Domain, X5Transform - metadata = {"WrittenBy": f"NiTransforms {__version__}"} | (metadata or {}) domain = None if (reference := self.reference) is not None: - domain = X5Domain( + domain = io.x5.X5Domain( grid=True, size=getattr(reference, "shape", (0, 0, 0)), mapping=reference.affine, @@ -246,7 +250,7 @@ def to_x5(self, metadata=None): kinds = tuple("space" for _ in range(self.ndim)) + ("vector",) - return X5Transform( + return io.x5.X5Transform( type="nonlinear", subtype="densefield", representation="displacements", @@ -270,6 +274,7 @@ def from_filename(cls, filename, fmt="X5"): if fmt == "X5": from .io.x5 import from_filename as load_x5 + x5_xfm = load_x5(filename)[0] Domain = namedtuple("Domain", "affine shape") reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size) @@ -333,14 +338,11 @@ def to_field(self, reference=None, dtype="float32"): def to_x5(self, metadata=None): """Return an :class:`~nitransforms.io.x5.X5Transform` representation.""" - from ._version import __version__ - from .io.x5 import X5Transform, X5Domain - metadata = {"WrittenBy": f"NiTransforms {__version__}"} | (metadata or {}) domain = None if (reference := self.reference) is not None: - domain = X5Domain( + domain = io.x5.X5Domain( grid=True, size=getattr(reference, "shape", (0, 0, 0)), mapping=reference.affine, @@ -354,7 +356,7 @@ def to_x5(self, metadata=None): kinds = tuple("space" for _ in range(self.ndim)) + ("vector",) - return X5Transform( + return io.x5.X5Transform( type="nonlinear", subtype="bspline", representation="coefficients", From 27e95c7aa9d138bc00a1fcd46ce2f3dba12faa95 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Sun, 20 Jul 2025 09:11:44 +0200 Subject: [PATCH 4/4] enh: enable B-Splines X5 i/o and DenseFields' is_deltas --- nitransforms/io/x5.py | 22 ++++---- nitransforms/nonlinear.py | 82 ++++++++++++++++++++-------- nitransforms/tests/test_nonlinear.py | 20 ++++--- 3 files changed, 81 insertions(+), 43 deletions(-) diff --git a/nitransforms/io/x5.py b/nitransforms/io/x5.py index a86a8554..2f86e8ab 100644 --- a/nitransforms/io/x5.py +++ b/nitransforms/io/x5.py @@ -73,11 +73,11 @@ class X5Transform: For parametric models it is generally possible to obtain it analytically, so this dataset could not be as useful in that case. """ - # additional_parameters: Optional[np.ndarray] = None - # AdditionalParameters is empty in the draft spec - ignore for now. - # Only documentation ATM is for SubType: - # The SubType setting enables setting the additional parameters on a dataset called - # "AdditionalParameters" that hangs directly from this transform node. + additional_parameters: Optional[np.ndarray] = None + """ + An OPTIONAL field to store additional parameters, depending on the SubType of the + transform. + """ array_length: int = 1 """Undocumented field in the draft to enable a single transform group for 4D transforms.""" @@ -130,11 +130,10 @@ def to_filename(fname: str | Path, x5_list: List[X5Transform]): g.create_dataset("Inverse", data=node.inverse) if node.jacobian is not None: g.create_dataset("Jacobian", data=node.jacobian) - # Disabled until we need SubType and AdditionalParameters - # if node.additional_parameters is not None: - # g.create_dataset( - # "AdditionalParameters", data=node.additional_parameters - # ) + if node.additional_parameters is not None: + g.create_dataset( + "AdditionalParameters", data=node.additional_parameters + ) return fname @@ -174,6 +173,9 @@ def _read_x5_group(node) -> X5Transform: inverse=np.asarray(node["Inverse"]) if "Inverse" in node else None, jacobian=np.asarray(node["Jacobian"]) if "Jacobian" in node else None, array_length=int(node.attrs.get("ArrayLength", 1)), + additional_parameters=np.asarray(node["AdditionalParameters"]) + if "AdditionalParameters" in node + else None, ) if "Domain" in node: diff --git a/nitransforms/nonlinear.py b/nitransforms/nonlinear.py index 9af12b98..372db3cf 100644 --- a/nitransforms/nonlinear.py +++ b/nitransforms/nonlinear.py @@ -12,9 +12,11 @@ from functools import partial from collections import namedtuple import numpy as np +import nibabel as nb from nitransforms import io from nitransforms.io.base import _ensure_image +from nitransforms.io.x5 import from_filename as load_x5 from nitransforms.interp.bspline import grid_bspline_weights, _cubic_bspline from nitransforms.base import ( TransformBase, @@ -34,7 +36,7 @@ class DenseFieldTransform(TransformBase): """Represents dense field (voxel-wise) transforms.""" - __slots__ = ("_field", "_deltas") + __slots__ = ("_field", "_deltas", "_is_deltas") def __init__(self, field=None, is_deltas=True, reference=None): """ @@ -68,14 +70,7 @@ def __init__(self, field=None, is_deltas=True, reference=None): super().__init__() - if field is not None: - field = _ensure_image(field) - self._field = np.squeeze( - np.asanyarray(field.dataobj) if hasattr(field, "dataobj") else field - ) - else: - self._field = np.zeros((*reference.shape, reference.ndim), dtype="float32") - is_deltas = True + self._is_deltas = is_deltas try: self.reference = ImageGrid(reference if reference is not None else field) @@ -86,22 +81,44 @@ def __init__(self, field=None, is_deltas=True, reference=None): else "Reference is not a spatial image" ) + fieldshape = (*self.reference.shape, self.reference.ndim) + if field is not None: + field = _ensure_image(field) + self._field = np.squeeze( + np.asanyarray(field.dataobj) if hasattr(field, "dataobj") else field + ) + if fieldshape != self._field.shape: + raise TransformError( + f"Shape of the field ({'x'.join(str(i) for i in self._field.shape)}) " + f"doesn't match that of the reference({'x'.join(str(i) for i in fieldshape)})" + ) + else: + self._field = np.zeros(fieldshape, dtype="float32") + self._is_deltas = True + if self._field.shape[-1] != self.ndim: raise TransformError( "The number of components of the field (%d) does not match " "the number of dimensions (%d)" % (self._field.shape[-1], self.ndim) ) - if is_deltas: - self._deltas = self._field + if self._is_deltas: + self._deltas = ( + self._field.copy() + ) # IMPORTANT: you don't want to update deltas # Convert from displacements (deltas) to deformations fields # (just add its origin to each delta vector) - self._field += self.reference.ndcoords.T.reshape(self._field.shape) + self._field += self.reference.ndcoords.T.reshape(fieldshape) def __repr__(self): """Beautify the python representation.""" return f"<{self.__class__.__name__}[{self._field.shape[-1]}D] {self._field.shape[:3]}>" + @property + def is_deltas(self): + """Check whether this is a displacements (``True``) or a deformation (``False``) field.""" + return self._is_deltas + @property def ndim(self): """Get the dimensions of the transform.""" @@ -230,7 +247,7 @@ def __eq__(self, other): True """ - _eq = np.array_equal(self._field, other._field) + _eq = np.allclose(self._field, other._field) if _eq and self._reference != other._reference: warnings.warn("Fields are equal, but references do not match.") return _eq @@ -253,9 +270,9 @@ def to_x5(self, metadata=None): return io.x5.X5Transform( type="nonlinear", subtype="densefield", - representation="displacements", + representation="displacements" if self.is_deltas else "deformations", metadata=metadata, - transform=self._deltas, + transform=self._deltas if self.is_deltas else self._field, dimension_kinds=kinds, domain=domain, ) @@ -273,12 +290,15 @@ def from_filename(cls, filename, fmt="X5"): raise NotImplementedError(f"Unsupported format <{fmt}>") if fmt == "X5": - from .io.x5 import from_filename as load_x5 - x5_xfm = load_x5(filename)[0] Domain = namedtuple("Domain", "affine shape") reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size) - return cls(x5_xfm.transform, is_deltas=True, reference=reference) + field = nb.Nifti1Image(x5_xfm.transform, reference.affine) + return cls( + field, + is_deltas=x5_xfm.representation == "displacements", + reference=reference, + ) return cls(_factory[fmt.lower()].from_filename(filename)) @@ -315,6 +335,24 @@ def ndim(self): """Get the dimensions of the transform.""" return self._coeffs.ndim - 1 + @classmethod + def from_filename(cls, filename, fmt="X5"): + _factory = { + "X5": None, + } + fmt = fmt.upper() + if fmt not in {k.upper() for k in _factory}: + raise NotImplementedError(f"Unsupported format <{fmt}>") + + x5_xfm = load_x5(filename)[0] + Domain = namedtuple("Domain", "affine shape") + reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size) + + coefficients = nb.Nifti1Image(x5_xfm.transform, x5_xfm.additional_parameters) + return cls(coefficients, reference=reference) + + # return cls(_factory[fmt.lower()].from_filename(filename)) + def to_field(self, reference=None, dtype="float32"): """Generate a displacements deformation field from this B-Spline field.""" _ref = ( @@ -349,21 +387,17 @@ def to_x5(self, metadata=None): coordinates="cartesian", ) - meta = metadata | { - "KnotsAffine": self._knots.affine.tolist(), - "KnotsShape": self._knots.shape, - } - kinds = tuple("space" for _ in range(self.ndim)) + ("vector",) return io.x5.X5Transform( type="nonlinear", subtype="bspline", representation="coefficients", - metadata=meta, + metadata=metadata, transform=self._coeffs, dimension_kinds=kinds, domain=domain, + additional_parameters=self._knots.affine, ) def map(self, x, inverse=False): diff --git a/nitransforms/tests/test_nonlinear.py b/nitransforms/tests/test_nonlinear.py index d2ba4bd7..9df5dcea 100644 --- a/nitransforms/tests/test_nonlinear.py +++ b/nitransforms/tests/test_nonlinear.py @@ -122,17 +122,18 @@ def test_bspline(tmp_path, testdata_path): ) -def test_densefield_x5_roundtrip(tmp_path): +@pytest.mark.parametrize("is_deltas", [True, False]) +def test_densefield_x5_roundtrip(tmp_path, is_deltas): """Ensure dense field transforms roundtrip via X5.""" ref = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="uint8"), np.eye(4)) disp = nb.Nifti1Image(np.random.rand(2, 2, 2, 3).astype("float32"), np.eye(4)) - xfm = DenseFieldTransform(disp, reference=ref) + xfm = DenseFieldTransform(disp, is_deltas=is_deltas, reference=ref) node = xfm.to_x5(metadata={"GeneratedBy": "pytest"}) assert node.type == "nonlinear" assert node.subtype == "densefield" - assert node.representation == "displacements" + assert node.representation == "displacements" if is_deltas else "deformations" assert node.domain.size == ref.shape assert node.metadata["GeneratedBy"] == "pytest" @@ -140,11 +141,10 @@ def test_densefield_x5_roundtrip(tmp_path): io.x5.to_filename(fname, [node]) xfm2 = DenseFieldTransform.from_filename(fname, fmt="X5") - diff = xfm2._deltas - xfm._deltas - coords = xfm.reference.ndcoords.T.reshape(xfm._deltas.shape) - assert np.allclose(diff, coords) + assert xfm2.reference.shape == ref.shape assert np.allclose(xfm2.reference.affine, ref.affine) + assert xfm == xfm2 def test_bspline_to_x5(tmp_path): @@ -161,6 +161,8 @@ def test_bspline_to_x5(tmp_path): fname = tmp_path / "bspline.x5" io.x5.to_filename(fname, [node]) - node2 = io.x5.from_filename(fname)[0] - assert np.allclose(node2.transform, node.transform) - assert node2.metadata["tool"] == "pytest" + + xfm2 = BSplineFieldTransform.from_filename(fname, fmt="X5") + assert np.allclose(xfm._coeffs, xfm2._coeffs) + assert xfm2.reference.shape == ref.shape + assert np.allclose(xfm2.reference.affine, ref.affine)