Skip to content

ENH: Add X5 support of nonlinear transforms #249

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions nitransforms/io/x5.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ class X5Transform:
For parametric models it is generally possible to obtain it analytically, so this dataset
could not be as useful in that case.
"""
# additional_parameters: Optional[np.ndarray] = None
# AdditionalParameters is empty in the draft spec - ignore for now.
# Only documentation ATM is for SubType:
# The SubType setting enables setting the additional parameters on a dataset called
# "AdditionalParameters" that hangs directly from this transform node.
additional_parameters: Optional[np.ndarray] = None
"""
An OPTIONAL field to store additional parameters, depending on the SubType of the
transform.
"""
array_length: int = 1
"""Undocumented field in the draft to enable a single transform group for 4D transforms."""

Expand Down Expand Up @@ -130,11 +130,10 @@ def to_filename(fname: str | Path, x5_list: List[X5Transform]):
g.create_dataset("Inverse", data=node.inverse)
if node.jacobian is not None:
g.create_dataset("Jacobian", data=node.jacobian)
# Disabled until we need SubType and AdditionalParameters
# if node.additional_parameters is not None:
# g.create_dataset(
# "AdditionalParameters", data=node.additional_parameters
# )
if node.additional_parameters is not None:
g.create_dataset(
"AdditionalParameters", data=node.additional_parameters
)
return fname


Expand Down Expand Up @@ -174,6 +173,9 @@ def _read_x5_group(node) -> X5Transform:
inverse=np.asarray(node["Inverse"]) if "Inverse" in node else None,
jacobian=np.asarray(node["Jacobian"]) if "Jacobian" in node else None,
array_length=int(node.attrs.get("ArrayLength", 1)),
additional_parameters=np.asarray(node["AdditionalParameters"])
if "AdditionalParameters" in node
else None,
)

if "Domain" in node:
Expand Down
137 changes: 122 additions & 15 deletions nitransforms/nonlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@
#
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""Nonlinear transforms."""

import warnings
from functools import partial
from collections import namedtuple
import numpy as np
import nibabel as nb

from nitransforms import io
from nitransforms.io.base import _ensure_image
from nitransforms.io.x5 import from_filename as load_x5
from nitransforms.interp.bspline import grid_bspline_weights, _cubic_bspline
from nitransforms.base import (
TransformBase,
Expand All @@ -22,11 +26,17 @@
)
from scipy.ndimage import map_coordinates

# Avoids circular imports
try:
from nitransforms._version import __version__
except ModuleNotFoundError: # pragma: no cover
__version__ = "0+unknown"


class DenseFieldTransform(TransformBase):
"""Represents dense field (voxel-wise) transforms."""

__slots__ = ("_field", "_deltas")
__slots__ = ("_field", "_deltas", "_is_deltas")

def __init__(self, field=None, is_deltas=True, reference=None):
"""
Expand Down Expand Up @@ -60,14 +70,7 @@ def __init__(self, field=None, is_deltas=True, reference=None):

super().__init__()

if field is not None:
field = _ensure_image(field)
self._field = np.squeeze(
np.asanyarray(field.dataobj) if hasattr(field, "dataobj") else field
)
else:
self._field = np.zeros((*reference.shape, reference.ndim), dtype="float32")
is_deltas = True
self._is_deltas = is_deltas

try:
self.reference = ImageGrid(reference if reference is not None else field)
Expand All @@ -78,22 +81,44 @@ def __init__(self, field=None, is_deltas=True, reference=None):
else "Reference is not a spatial image"
)

fieldshape = (*self.reference.shape, self.reference.ndim)
if field is not None:
field = _ensure_image(field)
self._field = np.squeeze(
np.asanyarray(field.dataobj) if hasattr(field, "dataobj") else field
)
if fieldshape != self._field.shape:
raise TransformError(
f"Shape of the field ({'x'.join(str(i) for i in self._field.shape)}) "
f"doesn't match that of the reference({'x'.join(str(i) for i in fieldshape)})"
)
else:
self._field = np.zeros(fieldshape, dtype="float32")
self._is_deltas = True

if self._field.shape[-1] != self.ndim:
raise TransformError(
"The number of components of the field (%d) does not match "
"the number of dimensions (%d)" % (self._field.shape[-1], self.ndim)
)

if is_deltas:
self._deltas = self._field
if self._is_deltas:
self._deltas = (
self._field.copy()
) # IMPORTANT: you don't want to update deltas
# Convert from displacements (deltas) to deformations fields
# (just add its origin to each delta vector)
self._field += self.reference.ndcoords.T.reshape(self._field.shape)
self._field += self.reference.ndcoords.T.reshape(fieldshape)

def __repr__(self):
"""Beautify the python representation."""
return f"<{self.__class__.__name__}[{self._field.shape[-1]}D] {self._field.shape[:3]}>"

@property
def is_deltas(self):
"""Check whether this is a displacements (``True``) or a deformation (``False``) field."""
return self._is_deltas

@property
def ndim(self):
"""Get the dimensions of the transform."""
Expand Down Expand Up @@ -222,22 +247,60 @@ def __eq__(self, other):
True

"""
_eq = np.array_equal(self._field, other._field)
_eq = np.allclose(self._field, other._field)
if _eq and self._reference != other._reference:
warnings.warn("Fields are equal, but references do not match.")
return _eq

def to_x5(self, metadata=None):
"""Return an :class:`~nitransforms.io.x5.X5Transform` representation."""
metadata = {"WrittenBy": f"NiTransforms {__version__}"} | (metadata or {})

domain = None
if (reference := self.reference) is not None:
domain = io.x5.X5Domain(
grid=True,
size=getattr(reference, "shape", (0, 0, 0)),
mapping=reference.affine,
coordinates="cartesian",
)

kinds = tuple("space" for _ in range(self.ndim)) + ("vector",)

return io.x5.X5Transform(
type="nonlinear",
subtype="densefield",
representation="displacements" if self.is_deltas else "deformations",
metadata=metadata,
transform=self._deltas if self.is_deltas else self._field,
dimension_kinds=kinds,
domain=domain,
)

@classmethod
def from_filename(cls, filename, fmt="X5"):
_factory = {
"afni": io.afni.AFNIDisplacementsField,
"itk": io.itk.ITKDisplacementsField,
"fsl": io.fsl.FSLDisplacementsField,
"X5": None,
}
if fmt not in _factory:
fmt = fmt.upper()
if fmt not in {k.upper() for k in _factory}:
raise NotImplementedError(f"Unsupported format <{fmt}>")

return cls(_factory[fmt].from_filename(filename))
if fmt == "X5":
x5_xfm = load_x5(filename)[0]
Domain = namedtuple("Domain", "affine shape")
reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size)
field = nb.Nifti1Image(x5_xfm.transform, reference.affine)
return cls(
field,
is_deltas=x5_xfm.representation == "displacements",
reference=reference,
)

return cls(_factory[fmt.lower()].from_filename(filename))


load = DenseFieldTransform.from_filename
Expand Down Expand Up @@ -272,6 +335,24 @@ def ndim(self):
"""Get the dimensions of the transform."""
return self._coeffs.ndim - 1

@classmethod
def from_filename(cls, filename, fmt="X5"):
_factory = {
"X5": None,
}
fmt = fmt.upper()
if fmt not in {k.upper() for k in _factory}:
raise NotImplementedError(f"Unsupported format <{fmt}>")

x5_xfm = load_x5(filename)[0]
Domain = namedtuple("Domain", "affine shape")
reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size)

coefficients = nb.Nifti1Image(x5_xfm.transform, x5_xfm.additional_parameters)
return cls(coefficients, reference=reference)

# return cls(_factory[fmt.lower()].from_filename(filename))

def to_field(self, reference=None, dtype="float32"):
"""Generate a displacements deformation field from this B-Spline field."""
_ref = (
Expand All @@ -293,6 +374,32 @@ def to_field(self, reference=None, dtype="float32"):
field.astype(dtype).reshape(*_ref.shape, -1), reference=_ref
)

def to_x5(self, metadata=None):
"""Return an :class:`~nitransforms.io.x5.X5Transform` representation."""
metadata = {"WrittenBy": f"NiTransforms {__version__}"} | (metadata or {})

domain = None
if (reference := self.reference) is not None:
domain = io.x5.X5Domain(
grid=True,
size=getattr(reference, "shape", (0, 0, 0)),
mapping=reference.affine,
coordinates="cartesian",
)

kinds = tuple("space" for _ in range(self.ndim)) + ("vector",)

return io.x5.X5Transform(
type="nonlinear",
subtype="bspline",
representation="coefficients",
metadata=metadata,
transform=self._coeffs,
dimension_kinds=kinds,
domain=domain,
additional_parameters=self._knots.affine,
)

def map(self, x, inverse=False):
r"""
Apply the transformation to a list of physical coordinate points.
Expand Down
47 changes: 47 additions & 0 deletions nitransforms/tests/test_nonlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
BSplineFieldTransform,
DenseFieldTransform,
)
from nitransforms import io
from ..io.itk import ITKDisplacementsField


Expand Down Expand Up @@ -119,3 +120,49 @@ def test_bspline(tmp_path, testdata_path):
).mean()
< 0.2
)


@pytest.mark.parametrize("is_deltas", [True, False])
def test_densefield_x5_roundtrip(tmp_path, is_deltas):
"""Ensure dense field transforms roundtrip via X5."""
ref = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="uint8"), np.eye(4))
disp = nb.Nifti1Image(np.random.rand(2, 2, 2, 3).astype("float32"), np.eye(4))

xfm = DenseFieldTransform(disp, is_deltas=is_deltas, reference=ref)

node = xfm.to_x5(metadata={"GeneratedBy": "pytest"})
assert node.type == "nonlinear"
assert node.subtype == "densefield"
assert node.representation == "displacements" if is_deltas else "deformations"
assert node.domain.size == ref.shape
assert node.metadata["GeneratedBy"] == "pytest"

fname = tmp_path / "test.x5"
io.x5.to_filename(fname, [node])

xfm2 = DenseFieldTransform.from_filename(fname, fmt="X5")

assert xfm2.reference.shape == ref.shape
assert np.allclose(xfm2.reference.affine, ref.affine)
assert xfm == xfm2


def test_bspline_to_x5(tmp_path):
"""Check BSpline transforms export to X5."""
coeff = nb.Nifti1Image(np.zeros((2, 2, 2, 3), dtype="float32"), np.eye(4))
ref = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="uint8"), np.eye(4))

xfm = BSplineFieldTransform(coeff, reference=ref)
node = xfm.to_x5(metadata={"tool": "pytest"})
assert node.type == "nonlinear"
assert node.subtype == "bspline"
assert node.representation == "coefficients"
assert node.metadata["tool"] == "pytest"

fname = tmp_path / "bspline.x5"
io.x5.to_filename(fname, [node])

xfm2 = BSplineFieldTransform.from_filename(fname, fmt="X5")
assert np.allclose(xfm._coeffs, xfm2._coeffs)
assert xfm2.reference.shape == ref.shape
assert np.allclose(xfm2.reference.affine, ref.affine)
Loading