diff --git a/nitransforms/io/x5.py b/nitransforms/io/x5.py index a86a855..2f86e8a 100644 --- a/nitransforms/io/x5.py +++ b/nitransforms/io/x5.py @@ -73,11 +73,11 @@ class X5Transform: For parametric models it is generally possible to obtain it analytically, so this dataset could not be as useful in that case. """ - # additional_parameters: Optional[np.ndarray] = None - # AdditionalParameters is empty in the draft spec - ignore for now. - # Only documentation ATM is for SubType: - # The SubType setting enables setting the additional parameters on a dataset called - # "AdditionalParameters" that hangs directly from this transform node. + additional_parameters: Optional[np.ndarray] = None + """ + An OPTIONAL field to store additional parameters, depending on the SubType of the + transform. + """ array_length: int = 1 """Undocumented field in the draft to enable a single transform group for 4D transforms.""" @@ -130,11 +130,10 @@ def to_filename(fname: str | Path, x5_list: List[X5Transform]): g.create_dataset("Inverse", data=node.inverse) if node.jacobian is not None: g.create_dataset("Jacobian", data=node.jacobian) - # Disabled until we need SubType and AdditionalParameters - # if node.additional_parameters is not None: - # g.create_dataset( - # "AdditionalParameters", data=node.additional_parameters - # ) + if node.additional_parameters is not None: + g.create_dataset( + "AdditionalParameters", data=node.additional_parameters + ) return fname @@ -174,6 +173,9 @@ def _read_x5_group(node) -> X5Transform: inverse=np.asarray(node["Inverse"]) if "Inverse" in node else None, jacobian=np.asarray(node["Jacobian"]) if "Jacobian" in node else None, array_length=int(node.attrs.get("ArrayLength", 1)), + additional_parameters=np.asarray(node["AdditionalParameters"]) + if "AdditionalParameters" in node + else None, ) if "Domain" in node: diff --git a/nitransforms/nonlinear.py b/nitransforms/nonlinear.py index 9c29c53..372db3c 100644 --- a/nitransforms/nonlinear.py +++ b/nitransforms/nonlinear.py @@ -7,12 +7,16 @@ # ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## """Nonlinear transforms.""" + import warnings from functools import partial +from collections import namedtuple import numpy as np +import nibabel as nb from nitransforms import io from nitransforms.io.base import _ensure_image +from nitransforms.io.x5 import from_filename as load_x5 from nitransforms.interp.bspline import grid_bspline_weights, _cubic_bspline from nitransforms.base import ( TransformBase, @@ -22,11 +26,17 @@ ) from scipy.ndimage import map_coordinates +# Avoids circular imports +try: + from nitransforms._version import __version__ +except ModuleNotFoundError: # pragma: no cover + __version__ = "0+unknown" + class DenseFieldTransform(TransformBase): """Represents dense field (voxel-wise) transforms.""" - __slots__ = ("_field", "_deltas") + __slots__ = ("_field", "_deltas", "_is_deltas") def __init__(self, field=None, is_deltas=True, reference=None): """ @@ -60,14 +70,7 @@ def __init__(self, field=None, is_deltas=True, reference=None): super().__init__() - if field is not None: - field = _ensure_image(field) - self._field = np.squeeze( - np.asanyarray(field.dataobj) if hasattr(field, "dataobj") else field - ) - else: - self._field = np.zeros((*reference.shape, reference.ndim), dtype="float32") - is_deltas = True + self._is_deltas = is_deltas try: self.reference = ImageGrid(reference if reference is not None else field) @@ -78,22 +81,44 @@ def __init__(self, field=None, is_deltas=True, reference=None): else "Reference is not a spatial image" ) + fieldshape = (*self.reference.shape, self.reference.ndim) + if field is not None: + field = _ensure_image(field) + self._field = np.squeeze( + np.asanyarray(field.dataobj) if hasattr(field, "dataobj") else field + ) + if fieldshape != self._field.shape: + raise TransformError( + f"Shape of the field ({'x'.join(str(i) for i in self._field.shape)}) " + f"doesn't match that of the reference({'x'.join(str(i) for i in fieldshape)})" + ) + else: + self._field = np.zeros(fieldshape, dtype="float32") + self._is_deltas = True + if self._field.shape[-1] != self.ndim: raise TransformError( "The number of components of the field (%d) does not match " "the number of dimensions (%d)" % (self._field.shape[-1], self.ndim) ) - if is_deltas: - self._deltas = self._field + if self._is_deltas: + self._deltas = ( + self._field.copy() + ) # IMPORTANT: you don't want to update deltas # Convert from displacements (deltas) to deformations fields # (just add its origin to each delta vector) - self._field += self.reference.ndcoords.T.reshape(self._field.shape) + self._field += self.reference.ndcoords.T.reshape(fieldshape) def __repr__(self): """Beautify the python representation.""" return f"<{self.__class__.__name__}[{self._field.shape[-1]}D] {self._field.shape[:3]}>" + @property + def is_deltas(self): + """Check whether this is a displacements (``True``) or a deformation (``False``) field.""" + return self._is_deltas + @property def ndim(self): """Get the dimensions of the transform.""" @@ -222,22 +247,60 @@ def __eq__(self, other): True """ - _eq = np.array_equal(self._field, other._field) + _eq = np.allclose(self._field, other._field) if _eq and self._reference != other._reference: warnings.warn("Fields are equal, but references do not match.") return _eq + def to_x5(self, metadata=None): + """Return an :class:`~nitransforms.io.x5.X5Transform` representation.""" + metadata = {"WrittenBy": f"NiTransforms {__version__}"} | (metadata or {}) + + domain = None + if (reference := self.reference) is not None: + domain = io.x5.X5Domain( + grid=True, + size=getattr(reference, "shape", (0, 0, 0)), + mapping=reference.affine, + coordinates="cartesian", + ) + + kinds = tuple("space" for _ in range(self.ndim)) + ("vector",) + + return io.x5.X5Transform( + type="nonlinear", + subtype="densefield", + representation="displacements" if self.is_deltas else "deformations", + metadata=metadata, + transform=self._deltas if self.is_deltas else self._field, + dimension_kinds=kinds, + domain=domain, + ) + @classmethod def from_filename(cls, filename, fmt="X5"): _factory = { "afni": io.afni.AFNIDisplacementsField, "itk": io.itk.ITKDisplacementsField, "fsl": io.fsl.FSLDisplacementsField, + "X5": None, } - if fmt not in _factory: + fmt = fmt.upper() + if fmt not in {k.upper() for k in _factory}: raise NotImplementedError(f"Unsupported format <{fmt}>") - return cls(_factory[fmt].from_filename(filename)) + if fmt == "X5": + x5_xfm = load_x5(filename)[0] + Domain = namedtuple("Domain", "affine shape") + reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size) + field = nb.Nifti1Image(x5_xfm.transform, reference.affine) + return cls( + field, + is_deltas=x5_xfm.representation == "displacements", + reference=reference, + ) + + return cls(_factory[fmt.lower()].from_filename(filename)) load = DenseFieldTransform.from_filename @@ -272,6 +335,24 @@ def ndim(self): """Get the dimensions of the transform.""" return self._coeffs.ndim - 1 + @classmethod + def from_filename(cls, filename, fmt="X5"): + _factory = { + "X5": None, + } + fmt = fmt.upper() + if fmt not in {k.upper() for k in _factory}: + raise NotImplementedError(f"Unsupported format <{fmt}>") + + x5_xfm = load_x5(filename)[0] + Domain = namedtuple("Domain", "affine shape") + reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size) + + coefficients = nb.Nifti1Image(x5_xfm.transform, x5_xfm.additional_parameters) + return cls(coefficients, reference=reference) + + # return cls(_factory[fmt.lower()].from_filename(filename)) + def to_field(self, reference=None, dtype="float32"): """Generate a displacements deformation field from this B-Spline field.""" _ref = ( @@ -293,6 +374,32 @@ def to_field(self, reference=None, dtype="float32"): field.astype(dtype).reshape(*_ref.shape, -1), reference=_ref ) + def to_x5(self, metadata=None): + """Return an :class:`~nitransforms.io.x5.X5Transform` representation.""" + metadata = {"WrittenBy": f"NiTransforms {__version__}"} | (metadata or {}) + + domain = None + if (reference := self.reference) is not None: + domain = io.x5.X5Domain( + grid=True, + size=getattr(reference, "shape", (0, 0, 0)), + mapping=reference.affine, + coordinates="cartesian", + ) + + kinds = tuple("space" for _ in range(self.ndim)) + ("vector",) + + return io.x5.X5Transform( + type="nonlinear", + subtype="bspline", + representation="coefficients", + metadata=metadata, + transform=self._coeffs, + dimension_kinds=kinds, + domain=domain, + additional_parameters=self._knots.affine, + ) + def map(self, x, inverse=False): r""" Apply the transformation to a list of physical coordinate points. diff --git a/nitransforms/tests/test_nonlinear.py b/nitransforms/tests/test_nonlinear.py index 6112f63..9df5dce 100644 --- a/nitransforms/tests/test_nonlinear.py +++ b/nitransforms/tests/test_nonlinear.py @@ -14,6 +14,7 @@ BSplineFieldTransform, DenseFieldTransform, ) +from nitransforms import io from ..io.itk import ITKDisplacementsField @@ -119,3 +120,49 @@ def test_bspline(tmp_path, testdata_path): ).mean() < 0.2 ) + + +@pytest.mark.parametrize("is_deltas", [True, False]) +def test_densefield_x5_roundtrip(tmp_path, is_deltas): + """Ensure dense field transforms roundtrip via X5.""" + ref = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="uint8"), np.eye(4)) + disp = nb.Nifti1Image(np.random.rand(2, 2, 2, 3).astype("float32"), np.eye(4)) + + xfm = DenseFieldTransform(disp, is_deltas=is_deltas, reference=ref) + + node = xfm.to_x5(metadata={"GeneratedBy": "pytest"}) + assert node.type == "nonlinear" + assert node.subtype == "densefield" + assert node.representation == "displacements" if is_deltas else "deformations" + assert node.domain.size == ref.shape + assert node.metadata["GeneratedBy"] == "pytest" + + fname = tmp_path / "test.x5" + io.x5.to_filename(fname, [node]) + + xfm2 = DenseFieldTransform.from_filename(fname, fmt="X5") + + assert xfm2.reference.shape == ref.shape + assert np.allclose(xfm2.reference.affine, ref.affine) + assert xfm == xfm2 + + +def test_bspline_to_x5(tmp_path): + """Check BSpline transforms export to X5.""" + coeff = nb.Nifti1Image(np.zeros((2, 2, 2, 3), dtype="float32"), np.eye(4)) + ref = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="uint8"), np.eye(4)) + + xfm = BSplineFieldTransform(coeff, reference=ref) + node = xfm.to_x5(metadata={"tool": "pytest"}) + assert node.type == "nonlinear" + assert node.subtype == "bspline" + assert node.representation == "coefficients" + assert node.metadata["tool"] == "pytest" + + fname = tmp_path / "bspline.x5" + io.x5.to_filename(fname, [node]) + + xfm2 = BSplineFieldTransform.from_filename(fname, fmt="X5") + assert np.allclose(xfm._coeffs, xfm2._coeffs) + assert xfm2.reference.shape == ref.shape + assert np.allclose(xfm2.reference.affine, ref.affine)