Skip to content

ENH: Collapse linear and nonlinear transforms chains #170

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 24 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
fbc9228
ENH: Collapse linear and nonlinear transforms chains
oesteban Jul 20, 2022
0b13408
enh: read X5 transform files
oesteban Jul 17, 2025
a22a6d0
refactor: simplify X5 loader
oesteban Jul 17, 2025
2d3ba2f
test: cover x5.from_filename
oesteban Jul 17, 2025
eeb4d5d
enh: enable loading of X5 affines
oesteban Jul 17, 2025
6b000c8
Merge pull request #244 from nipy/codex/exclude-tests-from-coverage-r…
oesteban Jul 18, 2025
8e50969
fix: process exceptions when trying to open X5
oesteban Jul 17, 2025
b02077e
tst: add round-trip test to linear mappings
oesteban Jul 18, 2025
1b544cb
move flake8 config to pyproject
oesteban Jul 18, 2025
c2f9bde
Merge pull request #245 from nipy/codex/migrate-flake8-config-to-pypr…
oesteban Jul 18, 2025
33d91ad
hotfix: make tox pickup flake8 config from ``pyproject.toml``
oesteban Jul 18, 2025
82f58c1
Merge pull request #243 from nipy/codex/add-support-for-affine-and-li…
oesteban Jul 18, 2025
55e6937
tst: refactor io/lta to reduce one partial line
oesteban Jul 18, 2025
a94e577
Merge pull request #246 from nipy/tst/one-less-partial-lta
oesteban Jul 18, 2025
987eaa8
Add failing test for serialized resampling
oesteban Jul 18, 2025
5f08a36
enh: add docstring and doctests for `nitransforms.io.get_linear_factory`
oesteban Jul 18, 2025
5da27b1
FIX: recompute targets for serialized per-volume resampling
oesteban Jul 18, 2025
b562eb3
fix: remove ``__iter__()`` as iterator protocol is not met
oesteban Jul 18, 2025
72cd04f
fix: recompute coordinates per volume in serial resampling
oesteban Jul 18, 2025
4e159c2
fix: generalize targets, test all branches
oesteban Jul 18, 2025
f1efba1
Merge pull request #247 from nipy/codex/investigate-4d-dataset-resamp…
oesteban Jul 18, 2025
a7265e4
Merge branch 'master' into tst/increase-coverage
oesteban Jul 18, 2025
6f497c0
Merge pull request #248 from nipy/tst/increase-coverage
oesteban Jul 18, 2025
ac69db0
Merge branch 'master' into enh/89-collapse-nonlinear
oesteban Jul 18, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion nitransforms/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
"""Read and write transforms."""

from nitransforms.io import afni, fsl, itk, lta, x5
from nitransforms.io.base import TransformIOError, TransformFileError

Expand All @@ -27,7 +28,37 @@


def get_linear_factory(fmt, is_array=True):
"""Return the type required by a given format."""
"""
Return the type required by a given format.

Parameters
----------
fmt : :obj:`str`
A format identifying string.
is_array : :obj:`bool`
Whether the array version of the class should be returned.

Returns
-------
type
The class object (not an instance) of the linear transfrom to be created
(for example, :obj:`~nitransforms.io.itk.ITKLinearTransform`).

Examples
--------
>>> get_linear_factory("itk")
<class 'nitransforms.io.itk.ITKLinearTransformArray'>
>>> get_linear_factory("itk", is_array=False)
<class 'nitransforms.io.itk.ITKLinearTransform'>
>>> get_linear_factory("fsl")
<class 'nitransforms.io.fsl.FSLLinearTransformArray'>
>>> get_linear_factory("fsl", is_array=False)
<class 'nitransforms.io.fsl.FSLLinearTransform'>
>>> get_linear_factory("fakepackage") # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
TypeError: Unsupported transform format <fakepackage>.

"""
if fmt.lower() not in _IO_TYPES:
raise TypeError(f"Unsupported transform format <{fmt}>.")

Expand Down
68 changes: 38 additions & 30 deletions nitransforms/io/lta.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Read/write linear transforms."""

import numpy as np
from nibabel.volumeutils import Recoder
from nibabel.affines import voxel_sizes, from_matvec
Expand Down Expand Up @@ -29,12 +30,12 @@ class VolumeGeometry(StringBasedStruct):
template_dtype = np.dtype(
[
("valid", "i4"), # Valid values: 0, 1
("volume", "i4", (3, )), # width, height, depth
("voxelsize", "f4", (3, )), # xsize, ysize, zsize
("volume", "i4", (3,)), # width, height, depth
("voxelsize", "f4", (3,)), # xsize, ysize, zsize
("xras", "f8", (3, 1)), # x_r, x_a, x_s
("yras", "f8", (3, 1)), # y_r, y_a, y_s
("zras", "f8", (3, 1)), # z_r, z_a, z_s
("cras", "f8", (3, )), # c_r, c_a, c_s
("cras", "f8", (3,)), # c_r, c_a, c_s
("filename", "U1024"),
]
) # Not conformant (may be >1024 bytes)
Expand Down Expand Up @@ -109,14 +110,19 @@ def from_string(cls, string):
label, valstring = lines.pop(0).split(" =")
assert label.strip() == key

val = ""
if valstring.strip():
parsed = np.genfromtxt(
parsed = (
np.genfromtxt(
[valstring.encode()], autostrip=True, dtype=cls.dtype[key]
)
if parsed.size:
val = parsed.reshape(sa[key].shape)
sa[key] = val
if valstring.strip()
else None
)

if parsed is not None and parsed.size:
sa[key] = parsed.reshape(sa[key].shape)
else: # pragma: no coverage
"""Do not set sa[key]"""

return volgeom


Expand Down Expand Up @@ -218,11 +224,15 @@ def to_ras(self, moving=None, reference=None):
def to_string(self, partial=False):
"""Convert this transform to text."""
sa = self.structarr
lines = [
"# LTA file created by NiTransforms",
"type = {}".format(sa["type"]),
"nxforms = 1",
] if not partial else []
lines = (
[
"# LTA file created by NiTransforms",
"type = {}".format(sa["type"]),
"nxforms = 1",
]
if not partial
else []
)

# Standard preamble
lines += [
Expand All @@ -232,10 +242,7 @@ def to_string(self, partial=False):
]

# Format parameters matrix
lines += [
" ".join(f"{v:18.15e}" for v in sa["m_L"][i])
for i in range(4)
]
lines += [" ".join(f"{v:18.15e}" for v in sa["m_L"][i]) for i in range(4)]

lines += [
"src volume info",
Expand Down Expand Up @@ -324,10 +331,7 @@ def __getitem__(self, idx):
def to_ras(self, moving=None, reference=None):
"""Set type to RAS2RAS and return the new matrix."""
self.structarr["type"] = 1
return [
xfm.to_ras(moving=moving, reference=reference)
for xfm in self.xforms
]
return [xfm.to_ras(moving=moving, reference=reference) for xfm in self.xforms]

def to_string(self):
"""Convert this LTA into text format."""
Expand Down Expand Up @@ -396,9 +400,11 @@ def from_ras(cls, ras, moving=None, reference=None):
sa["type"] = 1
sa["nxforms"] = ras.shape[0]
for i in range(sa["nxforms"]):
lt._xforms.append(cls._inner_type.from_ras(
ras[i, ...], moving=moving, reference=reference
))
lt._xforms.append(
cls._inner_type.from_ras(
ras[i, ...], moving=moving, reference=reference
)
)

sa["subject"] = "unset"
sa["fscale"] = 0.0
Expand All @@ -407,8 +413,10 @@ def from_ras(cls, ras, moving=None, reference=None):

def _drop_comments(string):
"""Drop comments."""
return "\n".join([
line.split("#")[0].strip()
for line in string.splitlines()
if line.split("#")[0].strip()
])
return "\n".join(
[
line.split("#")[0].strip()
for line in string.splitlines()
if line.split("#")[0].strip()
]
)
50 changes: 50 additions & 0 deletions nitransforms/io/x5.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,53 @@ def to_filename(fname: str | Path, x5_list: List[X5Transform]):
# "AdditionalParameters", data=node.additional_parameters
# )
return fname


def from_filename(fname: str | Path) -> List[X5Transform]:
"""Read a list of :class:`X5Transform` objects from an X5 HDF5 file."""
try:
with h5py.File(str(fname), "r") as in_file:
if in_file.attrs.get("Format") != "X5":
raise TypeError("Input file is not in X5 format")

tg = in_file["TransformGroup"]
return [
_read_x5_group(node)
for _, node in sorted(tg.items(), key=lambda kv: int(kv[0]))
]
except OSError as err:
if "file signature not found" in err.args[0]:
raise TypeError("Input file is not HDF5.")

raise # pragma: no cover


def _read_x5_group(node) -> X5Transform:
x5 = X5Transform(
type=node.attrs["Type"],
transform=np.asarray(node["Transform"]),
subtype=node.attrs.get("SubType"),
representation=node.attrs.get("Representation"),
metadata=json.loads(node.attrs["Metadata"])
if "Metadata" in node.attrs
else None,
dimension_kinds=[
k.decode() if isinstance(k, bytes) else k
for k in node["DimensionKinds"][()]
],
domain=None,
inverse=np.asarray(node["Inverse"]) if "Inverse" in node else None,
jacobian=np.asarray(node["Jacobian"]) if "Jacobian" in node else None,
array_length=int(node.attrs.get("ArrayLength", 1)),
)

if "Domain" in node:
dgrp = node["Domain"]
x5.domain = X5Domain(
grid=bool(int(np.asarray(dgrp["Grid"]))),
size=tuple(np.asarray(dgrp["Size"])),
mapping=np.asarray(dgrp["Mapping"]),
coordinates=dgrp.attrs.get("Coordinates"),
)

return x5
56 changes: 39 additions & 17 deletions nitransforms/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"""Linear transforms."""

import warnings
from collections import namedtuple
import numpy as np
from pathlib import Path

Expand All @@ -27,7 +28,12 @@
EQUALITY_TOL,
)
from nitransforms.io import get_linear_factory, TransformFileError
from nitransforms.io.x5 import X5Transform, X5Domain, to_filename as save_x5
from nitransforms.io.x5 import (
X5Transform,
X5Domain,
to_filename as save_x5,
from_filename as load_x5,
)


class Affine(TransformBase):
Expand Down Expand Up @@ -149,19 +155,17 @@ def __matmul__(self, b):
True

>>> xfm1 = Affine([[1, 0, 0, 4], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
>>> xfm1 @ np.eye(4) == xfm1
>>> xfm1 @ Affine() == xfm1
True

"""
if not isinstance(b, self.__class__):
_b = self.__class__(b)
else:
_b = b
if isinstance(b, self.__class__):
return self.__class__(
b.matrix @ self.matrix,
reference=b.reference,
)

retval = self.__class__(self.matrix.dot(_b.matrix))
if _b.reference:
retval.reference = _b.reference
return retval
return b @ self

@property
def matrix(self):
Expand All @@ -174,8 +178,29 @@ def ndim(self):
return self._matrix.ndim + 1

@classmethod
def from_filename(cls, filename, fmt=None, reference=None, moving=None):
def from_filename(
cls, filename, fmt=None, reference=None, moving=None, x5_position=0
):
"""Create an affine from a transform file."""

if fmt and fmt.upper() == "X5":
x5_xfm = load_x5(filename)[x5_position]
Transform = cls if x5_xfm.array_length == 1 else LinearTransformsMapping
if (
x5_xfm.domain
and not x5_xfm.domain.grid
and len(x5_xfm.domain.size) == 3
): # pragma: no cover
raise NotImplementedError(
"Only 3D regularly gridded domains are supported"
)
elif x5_xfm.domain:
# Override reference
Domain = namedtuple("Domain", "affine shape")
reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size)

return Transform(x5_xfm.transform, reference=reference)

fmtlist = [fmt] if fmt is not None else ("itk", "lta", "afni", "fsl")

if fmt is not None and not Path(filename).exists():
Expand Down Expand Up @@ -265,7 +290,9 @@ def to_filename(self, filename, fmt="X5", moving=None, x5_inverse=False):
if fmt.upper() == "X5":
return save_x5(filename, [self.to_x5(store_inverse=x5_inverse)])

writer = get_linear_factory(fmt, is_array=isinstance(self, LinearTransformsMapping))
writer = get_linear_factory(
fmt, is_array=isinstance(self, LinearTransformsMapping)
)

if fmt.lower() in ("itk", "ants", "elastix"):
writer.from_ras(self.matrix).to_filename(filename)
Expand Down Expand Up @@ -348,11 +375,6 @@ def __init__(self, transforms, reference=None):
)
self._inverse = np.linalg.inv(self._matrix)

def __iter__(self):
"""Enable iterating over the series of transforms."""
for _m in self.matrix:
yield Affine(_m, reference=self._reference)

def __getitem__(self, i):
"""Enable indexed access to the series of matrices."""
return Affine(self.matrix[i, ...], reference=self._reference)
Expand Down
16 changes: 7 additions & 9 deletions nitransforms/manip.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""Common interface for transforms."""
from collections.abc import Iterable
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from collections.abc import Iterable
from collections.abc import Iterable
from functools import reduce
import operator as op

import numpy as np

from .base import (
TransformBase,
Expand Down Expand Up @@ -145,17 +144,17 @@ def map(self, x, inverse=False):

return x

def asaffine(self, indices=None):
def collapse(self):
"""
Combine a succession of linear transforms into one.
Combine a succession of transforms into one.

Example
------
>>> chain = TransformChain(transforms=[
... Affine.from_matvec(vec=(2, -10, 3)),
... Affine.from_matvec(vec=(-2, 10, -3)),
... ])
>>> chain.asaffine()
>>> chain.collapse()
array([[1., 0., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 1., 0.],
Expand All @@ -165,15 +164,15 @@ def asaffine(self, indices=None):
... Affine.from_matvec(vec=(1, 2, 3)),
... Affine.from_matvec(mat=[[0, 1, 0], [0, 0, 1], [1, 0, 0]]),
... ])
>>> chain.asaffine()
>>> chain.collapse()
array([[0., 1., 0., 2.],
[0., 0., 1., 3.],
[1., 0., 0., 1.],
[0., 0., 0., 1.]])

>>> np.allclose(
... chain.map((4, -2, 1)),
... chain.asaffine().map((4, -2, 1)),
... chain.collapse().map((4, -2, 1)),
... )
True

Expand All @@ -183,9 +182,8 @@ def asaffine(self, indices=None):
The indices of the values to extract.

"""
affines = self.transforms if indices is None else np.take(self.transforms, indices)
retval = affines[0]
for xfm in affines[1:]:
retval = self.transforms[-1]
for xfm in reversed(self.transforms[:-1]):
retval = xfm @ retval
return retval
Comment on lines +185 to 188
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like it would be more intuitive to swap the arguments of the @ than to reverse the order of the list:

Suggested change
retval = self.transforms[-1]
for xfm in reversed(self.transforms[:-1]):
retval = xfm @ retval
return retval
retval = affines[0]
for xfm in affines[1:]:
retval = retval @ xfm
return retval

But we can also just use a reduce (I've added the imports above if you want to go this way):

Suggested change
retval = self.transforms[-1]
for xfm in reversed(self.transforms[:-1]):
retval = xfm @ retval
return retval
return reduce(op.matmul, self.transforms)


Expand Down
Loading
Loading