Skip to content

Commit 5dc7281

Browse files
committed
Add array_length support for X5 and adjust mapping
1 parent eff0b71 commit 5dc7281

File tree

5 files changed

+192
-8
lines changed

5 files changed

+192
-8
lines changed

nitransforms/io/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
22
# vi: set ft=python sts=4 ts=4 sw=4 et:
33
"""Read and write transforms."""
4-
from nitransforms.io import afni, fsl, itk, lta
4+
from nitransforms.io import afni, fsl, itk, lta, x5
55
from nitransforms.io.base import TransformIOError, TransformFileError
66

77
__all__ = [
88
"afni",
99
"fsl",
1010
"itk",
1111
"lta",
12+
"x5",
1213
"get_linear_factory",
1314
"TransformFileError",
1415
"TransformIOError",

nitransforms/io/x5.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""Data structures for the X5 transform format."""
2+
3+
from __future__ import annotations
4+
5+
from dataclasses import dataclass
6+
from typing import Any, Dict, Optional, Sequence, List
7+
8+
import json
9+
import h5py
10+
11+
import numpy as np
12+
13+
14+
@dataclass
15+
class X5Domain:
16+
"""Domain information of a transform."""
17+
18+
grid: bool
19+
size: Sequence[int]
20+
mapping: np.ndarray
21+
coordinates: Optional[str] = None
22+
23+
24+
@dataclass
25+
class X5Transform:
26+
"""Represent one transform entry of an X5 file."""
27+
28+
type: str
29+
transform: np.ndarray
30+
dimension_kinds: Sequence[str]
31+
array_length: int = 1
32+
domain: Optional[X5Domain] = None
33+
subtype: Optional[str] = None
34+
representation: Optional[str] = None
35+
metadata: Optional[Dict[str, Any]] = None
36+
inverse: Optional[np.ndarray] = None
37+
jacobian: Optional[np.ndarray] = None
38+
additional_parameters: Optional[np.ndarray] = None
39+
40+
41+
def to_filename(fname: str, x5_list: List[X5Transform]):
42+
"""Write a list of :class:`X5Transform` objects to an X5 HDF5 file."""
43+
with h5py.File(str(fname), "w") as out_file:
44+
out_file.attrs["Format"] = "X5"
45+
out_file.attrs["Version"] = np.uint16(1)
46+
tg = out_file.create_group("TransformGroup")
47+
for i, node in enumerate(x5_list):
48+
g = tg.create_group(str(i))
49+
g.attrs["Type"] = node.type
50+
g.attrs["ArrayLength"] = node.array_length
51+
if node.subtype is not None:
52+
g.attrs["SubType"] = node.subtype
53+
if node.representation is not None:
54+
g.attrs["Representation"] = node.representation
55+
if node.metadata is not None:
56+
g.attrs["Metadata"] = json.dumps(node.metadata)
57+
g.create_dataset("Transform", data=node.transform)
58+
g.create_dataset(
59+
"DimensionKinds",
60+
data=np.asarray(node.dimension_kinds, dtype="S"),
61+
)
62+
if node.domain is not None:
63+
dgrp = g.create_group("Domain")
64+
dgrp.create_dataset(
65+
"Grid", data=np.uint8(1 if node.domain.grid else 0)
66+
)
67+
dgrp.create_dataset("Size", data=np.asarray(node.domain.size))
68+
dgrp.create_dataset("Mapping", data=node.domain.mapping)
69+
if node.domain.coordinates is not None:
70+
dgrp.attrs["Coordinates"] = node.domain.coordinates
71+
72+
if node.inverse is not None:
73+
g.create_dataset("Inverse", data=node.inverse)
74+
if node.jacobian is not None:
75+
g.create_dataset("Jacobian", data=node.jacobian)
76+
if node.additional_parameters is not None:
77+
g.create_dataset(
78+
"AdditionalParameters", data=node.additional_parameters
79+
)
80+
return str(fname)
81+

nitransforms/linear.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
EQUALITY_TOL,
2121
)
2222
from nitransforms.io import get_linear_factory, TransformFileError
23+
from nitransforms.io.x5 import X5Transform, X5Domain
2324

2425

2526
class Affine(TransformBase):
@@ -276,6 +277,25 @@ def __repr__(self):
276277
"""
277278
return repr(self.matrix)
278279

280+
def to_x5(self):
281+
"""Return an :class:`~nitransforms.io.x5.X5Transform` representation."""
282+
domain = None
283+
if self._reference is not None:
284+
domain = X5Domain(
285+
grid=True,
286+
size=self.reference.shape,
287+
mapping=self.reference.affine,
288+
)
289+
kinds = tuple("space" for _ in range(self.ndim)) + ("vector",)
290+
return X5Transform(
291+
type="linear",
292+
subtype="affine",
293+
transform=self.matrix,
294+
dimension_kinds=kinds,
295+
domain=domain,
296+
inverse=(~self).matrix,
297+
)
298+
279299

280300
class LinearTransformsMapping(Affine):
281301
"""Represents a series of linear transforms."""
@@ -330,6 +350,26 @@ def __getitem__(self, i):
330350
"""Enable indexed access to the series of matrices."""
331351
return Affine(self.matrix[i, ...], reference=self._reference)
332352

353+
def to_x5(self):
354+
"""Return an :class:`~nitransforms.io.x5.X5Transform` object."""
355+
domain = None
356+
if self._reference is not None:
357+
domain = X5Domain(
358+
grid=True,
359+
size=self.reference.shape,
360+
mapping=self.reference.affine,
361+
)
362+
kinds = tuple("space" for _ in range(self.ndim - 1)) + ("vector",)
363+
return X5Transform(
364+
type="linear",
365+
subtype="affine",
366+
transform=self.matrix,
367+
dimension_kinds=kinds,
368+
domain=domain,
369+
inverse=self._inverse,
370+
array_length=len(self),
371+
)
372+
333373
def map(self, x, inverse=False):
334374
r"""
335375
Apply :math:`y = f(x)`.

nitransforms/tests/test_linear.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@
66
import numpy as np
77
import h5py
88

9+
from pathlib import Path
10+
911
from nibabel.eulerangles import euler2mat
1012
from nibabel.affines import from_matvec
13+
import nibabel as nb
1114
from nitransforms import linear as nitl
1215
from nitransforms import io
1316
from .utils import assert_affines_by_filename
@@ -243,16 +246,35 @@ def test_linear_save(tmpdir, data_path, get_testdata, image_orientation, sw_tool
243246
assert_affines_by_filename(xfm_fname1, xfm_fname2)
244247

245248

246-
def test_Affine_to_x5(tmpdir, testdata_path):
249+
def test_Affine_to_x5(tmpdir):
247250
"""Test affine's operations."""
248251
tmpdir.chdir()
249252
aff = nitl.Affine()
250-
with h5py.File("xfm.x5", "w") as f:
251-
aff._to_hdf5(f.create_group("Affine"))
252-
253-
aff.reference = testdata_path / "someones_anatomy.nii.gz"
254-
with h5py.File("withref-xfm.x5", "w") as f:
255-
aff._to_hdf5(f.create_group("Affine"))
253+
node = aff.to_x5()
254+
assert node.type == "linear"
255+
assert node.domain is None
256+
assert node.transform.shape == (4, 4)
257+
assert node.array_length == 1
258+
259+
img = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="float32"), np.eye(4))
260+
img_path = Path(tmpdir) / "ref.nii.gz"
261+
img.to_filename(str(img_path))
262+
263+
aff.reference = img_path
264+
node = aff.to_x5()
265+
assert node.domain.grid
266+
assert node.domain.size == aff.reference.shape
267+
268+
269+
def test_mapping_to_x5():
270+
mats = [
271+
np.eye(4),
272+
np.array([[1, 0, 0, 1], [0, 1, 0, 2], [0, 0, 1, 3], [0, 0, 0, 1]]),
273+
]
274+
mapping = nitl.LinearTransformsMapping(mats)
275+
node = mapping.to_x5()
276+
assert node.array_length == 2
277+
assert node.transform.shape == (2, 4, 4)
256278

257279

258280
def test_mulmat_operator(testdata_path):

nitransforms/tests/test_x5.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import numpy as np
2+
from h5py import File as H5File
3+
4+
from ..io.x5 import X5Transform, X5Domain, to_filename
5+
6+
7+
def test_x5_transform_defaults():
8+
xf = X5Transform(
9+
type="linear",
10+
transform=np.eye(4),
11+
dimension_kinds=("space", "space", "space", "vector"),
12+
)
13+
assert xf.domain is None
14+
assert xf.subtype is None
15+
assert xf.representation is None
16+
assert xf.metadata is None
17+
assert xf.inverse is None
18+
assert xf.jacobian is None
19+
assert xf.additional_parameters is None
20+
assert xf.array_length == 1
21+
22+
23+
def test_to_filename(tmp_path):
24+
domain = X5Domain(grid=True, size=(10, 10, 10), mapping=np.eye(4))
25+
node = X5Transform(
26+
type="linear",
27+
transform=np.eye(4),
28+
dimension_kinds=("space", "space", "space", "vector"),
29+
domain=domain,
30+
)
31+
fname = tmp_path / "test.x5"
32+
to_filename(fname, [node])
33+
34+
with H5File(fname, "r") as f:
35+
assert f.attrs["Format"] == "X5"
36+
assert f.attrs["Version"] == 1
37+
grp = f["TransformGroup"]
38+
assert "0" in grp
39+
assert grp["0"].attrs["Type"] == "linear"
40+
assert grp["0"].attrs["ArrayLength"] == 1

0 commit comments

Comments
 (0)