diff --git a/nitransforms/manip.py b/nitransforms/manip.py index 9389197..5ff9d53 100644 --- a/nitransforms/manip.py +++ b/nitransforms/manip.py @@ -193,6 +193,9 @@ def asaffine(self, indices=None): def from_filename(cls, filename, fmt="X5", reference=None, moving=None): """Load a transform file.""" from .io import itk + from .io.x5 import from_filename as load_x5 + from . import linear as nitl + from . import nonlinear as nitn retval = [] if str(filename).endswith(".h5"): @@ -206,6 +209,26 @@ def from_filename(cls, filename, fmt="X5", reference=None, moving=None): return TransformChain(retval) + if fmt.upper() == "X5" or str(filename).endswith(".x5"): + for i, x5_xfm in enumerate(load_x5(filename)): + if x5_xfm.type == "linear": + xfm = nitl.Affine.from_filename( + filename, + fmt="X5", + reference=reference, + moving=moving, + x5_position=i, + ) + elif x5_xfm.type == "nonlinear": + xfm = nitn.DenseFieldTransform.from_x5(x5_xfm) + else: + raise NotImplementedError( + f"Unsupported X5 transform type {x5_xfm.type}" + ) + retval.append(xfm) + + return TransformChain(retval) + raise NotImplementedError diff --git a/nitransforms/nonlinear.py b/nitransforms/nonlinear.py index 9c29c53..a451cb8 100644 --- a/nitransforms/nonlinear.py +++ b/nitransforms/nonlinear.py @@ -10,6 +10,7 @@ import warnings from functools import partial import numpy as np +import nibabel as nb from nitransforms import io from nitransforms.io.base import _ensure_image @@ -239,6 +240,26 @@ def from_filename(cls, filename, fmt="X5"): return cls(_factory[fmt].from_filename(filename)) + @classmethod + def from_x5(cls, x5struct): + """Instantiate a dense field transform from an :class:`X5Transform`.""" + if x5struct.type != "nonlinear": + raise TypeError("X5 structure is not a nonlinear transform") + if not x5struct.domain or not x5struct.domain.grid: + raise NotImplementedError( + "Only regularly gridded nonlinear X5 transforms are supported" + ) + + hdr = nb.Nifti1Header() + hdr.set_intent("vector") + img = nb.Nifti1Image(x5struct.transform.astype("float32"), x5struct.domain.mapping, hdr) + + is_deltas = True + if x5struct.representation and "def" in x5struct.representation.lower(): + is_deltas = False + + return cls(img, is_deltas=is_deltas) + load = DenseFieldTransform.from_filename diff --git a/nitransforms/tests/test_x5.py b/nitransforms/tests/test_x5.py index 89b49e0..6c0603b 100644 --- a/nitransforms/tests/test_x5.py +++ b/nitransforms/tests/test_x5.py @@ -3,6 +3,9 @@ from h5py import File as H5File from ..io.x5 import X5Transform, X5Domain, to_filename, from_filename +from nitransforms import linear as nitl +from nitransforms import manip as nitm +from nitransforms import nonlinear as nitn def test_x5_transform_defaults(): @@ -75,3 +78,36 @@ def test_from_filename_invalid(tmp_path): with pytest.raises(TypeError): from_filename(fname) + + +def test_transformchain_from_x5(tmp_path): + aff1 = nitl.Affine.from_matvec(vec=(1, 2, 3)) + aff2 = nitl.Affine.from_matvec(vec=(-1, -2, -3)) + fname = tmp_path / "chain.x5" + to_filename(fname, [aff1.to_x5(), aff2.to_x5()]) + + chain = nitm.TransformChain.from_filename(fname, fmt="X5") + assert len(chain.transforms) == 2 + assert chain.transforms[0] == aff1 + assert chain.transforms[1] == aff2 + + +def test_transformchain_from_x5_nonlinear(tmp_path): + field = np.zeros((2, 2, 2, 3), dtype=float) + domain = X5Domain(grid=True, size=(2, 2, 2), mapping=np.eye(4)) + nonlinear_node = X5Transform( + type="nonlinear", + transform=field, + representation="dense_field", + dimension_kinds=("space", "space", "space", "vector"), + domain=domain, + ) + aff = nitl.Affine.from_matvec(vec=(0, 0, 0)) + fname = tmp_path / "nonlinear_chain.x5" + to_filename(fname, [aff.to_x5(), nonlinear_node]) + + chain = nitm.TransformChain.from_filename(fname, fmt="X5") + assert len(chain.transforms) == 2 + assert chain.transforms[0] == aff + assert isinstance(chain.transforms[1], nitn.DenseFieldTransform) + assert chain.transforms[1].reference.shape == (2, 2, 2)