diff --git a/nitransforms/io/__init__.py b/nitransforms/io/__init__.py
index f9030724..a2ec7e6b 100644
--- a/nitransforms/io/__init__.py
+++ b/nitransforms/io/__init__.py
@@ -1,6 +1,7 @@
 # emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
 # vi: set ft=python sts=4 ts=4 sw=4 et:
 """Read and write transforms."""
+
 from nitransforms.io import afni, fsl, itk, lta, x5
 from nitransforms.io.base import TransformIOError, TransformFileError
@@ -27,7 +28,37 @@

 def get_linear_factory(fmt, is_array=True):
-    """Return the type required by a given format."""
+    """
+    Return the type required by a given format.
+
+    Parameters
+    ----------
+    fmt : :obj:`str`
+        A format identifying string.
+    is_array : :obj:`bool`
+        Whether the array version of the class should be returned.
+
+    Returns
+    -------
+    type
+        The class object (not an instance) of the linear transform to be created
+        (for example, :obj:`~nitransforms.io.itk.ITKLinearTransform`).
+
+    Examples
+    --------
+    >>> get_linear_factory("itk")
+    <class 'nitransforms.io.itk.ITKLinearTransformArray'>
+    >>> get_linear_factory("itk", is_array=False)
+    <class 'nitransforms.io.itk.ITKLinearTransform'>
+    >>> get_linear_factory("fsl")
+    <class 'nitransforms.io.fsl.FSLLinearTransformArray'>
+    >>> get_linear_factory("fsl", is_array=False)
+    <class 'nitransforms.io.fsl.FSLLinearTransform'>
+    >>> get_linear_factory("fakepackage")  # doctest: +IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+    TypeError: Unsupported transform format <fakepackage>.
+
+    """
     if fmt.lower() not in _IO_TYPES:
         raise TypeError(f"Unsupported transform format <{fmt}>.")
diff --git a/nitransforms/io/lta.py b/nitransforms/io/lta.py
index 334266bb..1e7445bf 100644
--- a/nitransforms/io/lta.py
+++ b/nitransforms/io/lta.py
@@ -1,4 +1,5 @@
 """Read/write linear transforms."""
+
 import numpy as np
 from nibabel.volumeutils import Recoder
 from nibabel.affines import voxel_sizes, from_matvec
@@ -29,12 +30,12 @@ class VolumeGeometry(StringBasedStruct):
     template_dtype = np.dtype(
         [
             ("valid", "i4"),  # Valid values: 0, 1
-            ("volume", "i4", (3, )),  # width, height, depth
-            ("voxelsize", "f4", (3, )),  # xsize, ysize, zsize
+            ("volume", "i4", (3,)),  # width, height, depth
+            ("voxelsize", "f4", (3,)),  # xsize, ysize, zsize
             ("xras", "f8", (3, 1)),  # x_r, x_a, x_s
             ("yras", "f8", (3, 1)),  # y_r, y_a, y_s
             ("zras", "f8", (3, 1)),  # z_r, z_a, z_s
-            ("cras", "f8", (3, )),  # c_r, c_a, c_s
+            ("cras", "f8", (3,)),  # c_r, c_a, c_s
             ("filename", "U1024"),
         ]
     )  # Not conformant (may be >1024 bytes)
@@ -109,14 +110,19 @@ def from_string(cls, string):
             label, valstring = lines.pop(0).split(" =")
             assert label.strip() == key

-            val = ""
-            if valstring.strip():
-                parsed = np.genfromtxt(
+            parsed = (
+                np.genfromtxt(
                     [valstring.encode()], autostrip=True, dtype=cls.dtype[key]
                 )
-                if parsed.size:
-                    val = parsed.reshape(sa[key].shape)
-                sa[key] = val
+                if valstring.strip()
+                else None
+            )
+
+            if parsed is not None and parsed.size:
+                sa[key] = parsed.reshape(sa[key].shape)
+            else:  # pragma: no coverage
+                """Do not set sa[key]"""
+
         return volgeom
@@ -218,11 +224,15 @@ def to_ras(self, moving=None, reference=None):
     def to_string(self, partial=False):
         """Convert this transform to text."""
         sa = self.structarr
-        lines = [
-            "# LTA file created by NiTransforms",
-            "type = {}".format(sa["type"]),
-            "nxforms = 1",
-        ] if not partial else []
+        lines = (
+            [
+                "# LTA file created by NiTransforms",
+                "type = {}".format(sa["type"]),
+                "nxforms = 1",
+            ]
+            if not partial
+            else []
+        )

         # Standard preamble
         lines += [
@@ -232,10 +242,7 @@
         ]

         # Format parameters matrix
-        lines += [
-            " ".join(f"{v:18.15e}" for v in sa["m_L"][i])
-            for i in range(4)
-        ]
+        lines += [" ".join(f"{v:18.15e}" for v in sa["m_L"][i]) for i in range(4)]

         lines += [
             "src volume info",
@@ -324,10 +331,7 @@ def __getitem__(self, idx):
     def to_ras(self, moving=None, reference=None):
         """Set type to RAS2RAS and return the new matrix."""
         self.structarr["type"] = 1
-        return [
-            xfm.to_ras(moving=moving, reference=reference)
-            for xfm in self.xforms
-        ]
+        return [xfm.to_ras(moving=moving, reference=reference) for xfm in self.xforms]

     def to_string(self):
         """Convert this LTA into text format."""
@@ -396,9 +400,11 @@ def from_ras(cls, ras, moving=None, reference=None):
         sa["type"] = 1
         sa["nxforms"] = ras.shape[0]
         for i in range(sa["nxforms"]):
-            lt._xforms.append(cls._inner_type.from_ras(
-                ras[i, ...], moving=moving, reference=reference
-            ))
+            lt._xforms.append(
+                cls._inner_type.from_ras(
+                    ras[i, ...], moving=moving, reference=reference
+                )
+            )

         sa["subject"] = "unset"
         sa["fscale"] = 0.0
@@ -407,8 +413,10 @@

 def _drop_comments(string):
     """Drop comments."""
-    return "\n".join([
-        line.split("#")[0].strip()
-        for line in string.splitlines()
-        if line.split("#")[0].strip()
-    ])
+    return "\n".join(
+        [
+            line.split("#")[0].strip()
+            for line in string.splitlines()
+            if line.split("#")[0].strip()
+        ]
+    )
diff --git a/nitransforms/io/x5.py b/nitransforms/io/x5.py
index 463a1336..a86a8554 100644
--- a/nitransforms/io/x5.py
+++ b/nitransforms/io/x5.py
@@ -136,3 +136,53 @@ def to_filename(fname: str | Path, x5_list: List[X5Transform]):
             #     "AdditionalParameters", data=node.additional_parameters
             # )
     return fname
+
+
+def from_filename(fname: str | Path) -> List[X5Transform]:
+    """Read a list of :class:`X5Transform` objects from an X5 HDF5 file."""
+    try:
+        with h5py.File(str(fname), "r") as in_file:
+            if in_file.attrs.get("Format") != "X5":
+                raise TypeError("Input file is not in X5 format")
+
+            tg = in_file["TransformGroup"]
+            return [
+                _read_x5_group(node)
+                for _, node in sorted(tg.items(), key=lambda kv: int(kv[0]))
+            ]
+    except OSError as err:
+        if "file signature not found" in err.args[0]:
+            raise TypeError("Input file is not HDF5.")
+
+        raise  # pragma: no cover
+
+
+def _read_x5_group(node) -> X5Transform:
+    x5 = X5Transform(
+        type=node.attrs["Type"],
+        transform=np.asarray(node["Transform"]),
+        subtype=node.attrs.get("SubType"),
+        representation=node.attrs.get("Representation"),
+        metadata=json.loads(node.attrs["Metadata"])
+        if "Metadata" in node.attrs
+        else None,
+        dimension_kinds=[
+            k.decode() if isinstance(k, bytes) else k
+            for k in node["DimensionKinds"][()]
+        ],
+        domain=None,
+        inverse=np.asarray(node["Inverse"]) if "Inverse" in node else None,
+        jacobian=np.asarray(node["Jacobian"]) if "Jacobian" in node else None,
+        array_length=int(node.attrs.get("ArrayLength", 1)),
+    )
+
+    if "Domain" in node:
+        dgrp = node["Domain"]
+        x5.domain = X5Domain(
+            grid=bool(int(np.asarray(dgrp["Grid"]))),
+            size=tuple(np.asarray(dgrp["Size"])),
+            mapping=np.asarray(dgrp["Mapping"]),
+            coordinates=dgrp.attrs.get("Coordinates"),
+        )
+
+    return x5
diff --git a/nitransforms/linear.py b/nitransforms/linear.py
index cf8f8465..79c32776 100644
--- a/nitransforms/linear.py
+++ b/nitransforms/linear.py
@@ -9,6 +9,7 @@
 """Linear transforms."""
 import warnings
+from collections import namedtuple
 import numpy as np
 from pathlib import Path
@@ -27,7 +28,12 @@
     EQUALITY_TOL,
 )
 from nitransforms.io import get_linear_factory, TransformFileError
-from nitransforms.io.x5 import X5Transform, X5Domain, to_filename as save_x5
+from nitransforms.io.x5 import (
+    X5Transform,
+    X5Domain,
+    to_filename as save_x5,
+    from_filename as load_x5,
+)


 class Affine(TransformBase):
@@ -149,19 +155,17 @@ def __matmul__(self, b):
         True

         >>> xfm1 = Affine([[1, 0, 0, 4], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
-        >>> xfm1 @ np.eye(4) == xfm1
+        >>> xfm1 @ Affine() == xfm1
         True

         """
-        if not isinstance(b, self.__class__):
-            _b = self.__class__(b)
-        else:
-            _b = b
+        if isinstance(b, self.__class__):
+            return self.__class__(
+                b.matrix @ self.matrix,
+                reference=b.reference,
+            )

-        retval = self.__class__(self.matrix.dot(_b.matrix))
-        if _b.reference:
-            retval.reference = _b.reference
-        return retval
+        return b @ self

     @property
     def matrix(self):
@@ -174,8 +178,29 @@ def ndim(self):
         return self._matrix.ndim + 1

     @classmethod
-    def from_filename(cls, filename, fmt=None, reference=None, moving=None):
+    def from_filename(
+        cls, filename, fmt=None, reference=None, moving=None, x5_position=0
+    ):
         """Create an affine from a transform file."""
+
+        if fmt and fmt.upper() == "X5":
+            x5_xfm = load_x5(filename)[x5_position]
+            Transform = cls if x5_xfm.array_length == 1 else LinearTransformsMapping
+            if (
+                x5_xfm.domain
+                and not x5_xfm.domain.grid
+                and len(x5_xfm.domain.size) == 3
+            ):  # pragma: no cover
+                raise NotImplementedError(
+                    "Only 3D regularly gridded domains are supported"
+                )
+            elif x5_xfm.domain:
+                # Override reference
+                Domain = namedtuple("Domain", "affine shape")
+                reference = Domain(x5_xfm.domain.mapping, x5_xfm.domain.size)
+
+            return Transform(x5_xfm.transform, reference=reference)
+
         fmtlist = [fmt] if fmt is not None else ("itk", "lta", "afni", "fsl")

         if fmt is not None and not Path(filename).exists():
@@ -265,7 +290,9 @@ def to_filename(self, filename, fmt="X5", moving=None, x5_inverse=False):
         if fmt.upper() == "X5":
             return save_x5(filename, [self.to_x5(store_inverse=x5_inverse)])

-        writer = get_linear_factory(fmt, is_array=isinstance(self, LinearTransformsMapping))
+        writer = get_linear_factory(
+            fmt, is_array=isinstance(self, LinearTransformsMapping)
+        )

         if fmt.lower() in ("itk", "ants", "elastix"):
             writer.from_ras(self.matrix).to_filename(filename)
@@ -348,11 +375,6 @@ def __init__(self, transforms, reference=None):
         )
         self._inverse = np.linalg.inv(self._matrix)

-    def __iter__(self):
-        """Enable iterating over the series of transforms."""
-        for _m in self.matrix:
-            yield Affine(_m, reference=self._reference)
-
     def __getitem__(self, i):
         """Enable indexed access to the series of matrices."""
         return Affine(self.matrix[i, ...], reference=self._reference)
diff --git a/nitransforms/manip.py b/nitransforms/manip.py
index 9389197d..9e0327cf 100644
--- a/nitransforms/manip.py
+++ b/nitransforms/manip.py
@@ -8,7 +8,6 @@
 ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
 """Common interface for transforms."""
 from collections.abc import Iterable
-import numpy as np

 from .base import (
     TransformBase,
@@ -145,9 +144,9 @@ def map(self, x, inverse=False):

         return x

-    def asaffine(self, indices=None):
+    def collapse(self):
         """
-        Combine a succession of linear transforms into one.
+        Combine a succession of transforms into one.

         Example
         ------
@@ -155,7 +154,7 @@
         ...     Affine.from_matvec(vec=(2, -10, 3)),
         ...     Affine.from_matvec(vec=(-2, 10, -3)),
         ... ])
-        >>> chain.asaffine()
+        >>> chain.collapse()
         array([[1., 0., 0., 0.],
                [0., 1., 0., 0.],
                [0., 0., 1., 0.],
@@ -165,7 +164,7 @@
         ...     Affine.from_matvec(vec=(1, 2, 3)),
         ...     Affine.from_matvec(mat=[[0, 1, 0], [0, 0, 1], [1, 0, 0]]),
         ... ])
-        >>> chain.asaffine()
+        >>> chain.collapse()
         array([[0., 1., 0., 2.],
                [0., 0., 1., 3.],
                [1., 0., 0., 1.],
@@ -173,7 +172,7 @@

         >>> np.allclose(
         ...     chain.map((4, -2, 1)),
-        ...     chain.asaffine().map((4, -2, 1)),
+        ...     chain.collapse().map((4, -2, 1)),
         ... )
         True

@@ -183,9 +182,8 @@
             The indices of the values to extract.

         """
-        affines = self.transforms if indices is None else np.take(self.transforms, indices)
-        retval = affines[0]
-        for xfm in affines[1:]:
+        retval = self.transforms[-1]
+        for xfm in reversed(self.transforms[:-1]):
             retval = xfm @ retval

         return retval
diff --git a/nitransforms/resampling.py b/nitransforms/resampling.py
index 53750206..98ef4454 100644
--- a/nitransforms/resampling.py
+++ b/nitransforms/resampling.py
@@ -10,6 +10,7 @@
 import asyncio
 from os import cpu_count
+from contextlib import suppress
 from functools import partial
 from pathlib import Path
 from typing import Callable, TypeVar, Union
@@ -108,12 +109,17 @@ async def _apply_serial(
     semaphore = asyncio.Semaphore(max_concurrent)

     for t in range(n_resamplings):
-        xfm_t = transform if (n_resamplings == 1 or transform.ndim < 4) else transform[t]
+        xfm_t = (
+            transform if (n_resamplings == 1 or transform.ndim < 4) else transform[t]
+        )

-        if targets is None:
-            targets = ImageGrid(spatialimage).index(  # data should be an image
+        targets_t = (
+            ImageGrid(spatialimage).index(
                 _as_homogeneous(xfm_t.map(ref_ndcoords), dim=ref_ndim)
             )
+            if targets is None
+            else targets[t, ...]
+        )

         data_t = (
             data
@@ -127,7 +133,7 @@
                 partial(
                     ndi.map_coordinates,
                     data_t,
-                    targets,
+                    targets_t,
                     output=output[..., t],
                     order=order,
                     mode=mode,
@@ -255,11 +261,22 @@ def apply(
                 dim=_ref.ndim,
             )
         )
-    elif xfm_nvols == 1:
-        targets = ImageGrid(spatialimage).index(  # data should be an image
-            _as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim)
+    else:
+        # Targets' shape is (Nt, 3, Nv) with Nv = Num. voxels, Nt = Num. timepoints.
+        targets = (
+            ImageGrid(spatialimage).index(
+                _as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim)
+            )
+            if targets is None
+            else targets
         )
+        if targets.ndim == 3:
+            targets = np.rollaxis(targets, targets.ndim - 1, 0)
+        else:
+            assert targets.ndim == 2
+            targets = targets[np.newaxis, ...]
+

     if serialize_4d:
         data = (
             np.asanyarray(spatialimage.dataobj, dtype=input_dtype)
@@ -294,17 +311,24 @@
     else:
         data = np.asanyarray(spatialimage.dataobj, dtype=input_dtype)

-        if targets is None:
-            targets = ImageGrid(spatialimage).index(  # data should be an image
-                _as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim)
-            )
-
+        if data_nvols == 1 and xfm_nvols == 1:
+            targets = np.squeeze(targets)
+            assert targets.ndim == 2
         # Cast 3D data into 4D if 4D nonsequential transform
-        if data_nvols == 1 and xfm_nvols > 1:
+        elif data_nvols == 1 and xfm_nvols > 1:
             data = data[..., np.newaxis]

-        if transform.ndim == 4:
-            targets = _as_homogeneous(targets.reshape(-2, targets.shape[0])).T
+        if xfm_nvols > 1:
+            assert targets.ndim == 3
+            n_time, n_dim, n_vox = targets.shape
+            # Reshape to (3, n_time x n_vox)
+            ijk_targets = np.rollaxis(targets, 0, 2).reshape((n_dim, -1))
+            time_row = np.repeat(np.arange(n_time), n_vox)[None, :]
+
+            # Now targets is (4, n_vox x n_time), with indexes (t, i, j, k)
+            # t is the slowest-changing axis, so we put it first
+            targets = np.vstack((time_row, ijk_targets))
+            data = np.rollaxis(data, data.ndim - 1, 0)

         resampled = ndi.map_coordinates(
             data,
@@ -323,11 +347,19 @@
         )

         hdr.set_data_dtype(output_dtype or spatialimage.header.get_data_dtype())
-        moved = spatialimage.__class__(
-            resampled.reshape(_ref.shape if n_resamplings == 1 else _ref.shape + (-1,)),
-            _ref.affine,
-            hdr,
-        )
+        if serialize_4d:
+            resampled = resampled.reshape(
+                _ref.shape
+                if n_resamplings == 1
+                else _ref.shape + (resampled.shape[-1],)
+            )
+        else:
+            resampled = resampled.reshape((-1, *_ref.shape))
+            resampled = np.rollaxis(resampled, 0, resampled.ndim)
+            with suppress(ValueError):
+                resampled = np.squeeze(resampled, axis=3)
+
+        moved = spatialimage.__class__(resampled, _ref.affine, hdr)
         return moved

     output_dtype = output_dtype or input_dtype
diff --git a/nitransforms/tests/test_linear.py b/nitransforms/tests/test_linear.py
index 32634c61..1d45dd81 100644
--- a/nitransforms/tests/test_linear.py
+++ b/nitransforms/tests/test_linear.py
@@ -265,6 +265,9 @@ def test_linear_to_x5(tmpdir, store_inverse):

     aff.to_filename("export1.x5", x5_inverse=store_inverse)

+    # Test round trip
+    assert aff == nitl.Affine.from_filename("export1.x5", fmt="X5")
+
     # Test with Domain
     img = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="float32"), np.eye(4))
     img_path = Path(tmpdir) / "ref.nii.gz"
@@ -275,21 +278,32 @@
     assert node.domain.size == aff.reference.shape

     aff.to_filename("export2.x5", x5_inverse=store_inverse)
+    # Test round trip
+    assert aff == nitl.Affine.from_filename("export2.x5", fmt="X5")
+
     # Test with Jacobian
     node.jacobian = np.zeros((2, 2, 2), dtype="float32")
     io.x5.to_filename("export3.x5", [node])


-def test_mapping_to_x5():
+@pytest.mark.parametrize("store_inverse", [True, False])
+def test_mapping_to_x5(tmp_path, store_inverse):
     mats = [
         np.eye(4),
         np.array([[1, 0, 0, 1], [0, 1, 0, 2], [0, 0, 1, 3], [0, 0, 0, 1]]),
     ]
     mapping = nitl.LinearTransformsMapping(mats)
-    node = mapping.to_x5()
+    node = mapping.to_x5(
+        metadata={"GeneratedBy": "FreeSurfer 8"}, store_inverse=store_inverse
+    )
     assert node.array_length == 2
     assert node.transform.shape == (2, 4, 4)

+    mapping.to_filename(tmp_path / "export1.x5", x5_inverse=store_inverse)
+
+    # Test round trip
+    assert mapping == nitl.Affine.from_filename(tmp_path / "export1.x5", fmt="X5")
+

 def test_mulmat_operator(testdata_path):
     """Check the @ operator."""
@@ -298,10 +312,10 @@ def test_mulmat_operator(testdata_path):
     mat2 = from_matvec(np.eye(3), (4, 2, -1))
     aff = nitl.Affine(mat1, reference=ref)

-    composed = aff @ mat2
+    composed = aff @ nitl.Affine(mat2)
     assert composed.reference is None
-    assert composed == nitl.Affine(mat1.dot(mat2))
+    assert composed == nitl.Affine(mat2 @ mat1)

     composed = nitl.Affine(mat2) @ aff
     assert composed.reference == aff.reference
-    assert composed == nitl.Affine(mat2.dot(mat1), reference=ref)
+    assert composed == nitl.Affine(mat1 @ mat2, reference=ref)
diff --git a/nitransforms/tests/test_manip.py b/nitransforms/tests/test_manip.py
index b5dd5c62..0275c6e6 100644
--- a/nitransforms/tests/test_manip.py
+++ b/nitransforms/tests/test_manip.py
@@ -11,6 +11,55 @@
 FMT = {"lta": "fs", "tfm": "itk"}


+def test_itk_h5(tmp_path, testdata_path):
+    """Cross-check resampling with a composite ITK HDF5 (.h5) transform, before and after collapse()."""
+    os.chdir(str(tmp_path))
+    img_fname = testdata_path / "T1w_scanner.nii.gz"
+    xfm_fname = (
+        testdata_path
+        / "ds-005_sub-01_from-T1w_to-MNI152NLin2009cAsym_mode-image_xfm.h5"
+    )
+
+    xfm = _load(xfm_fname)
+
+    assert len(xfm) == 2
+
+    ref_fname = tmp_path / "reference.nii.gz"
+    nb.Nifti1Image(
+        np.zeros(xfm.reference.shape, dtype="uint16"), xfm.reference.affine,
+    ).to_filename(str(ref_fname))
+
+    # Then apply the transform and cross-check with software
+    cmd = APPLY_NONLINEAR_CMD["itk"](
+        transform=xfm_fname,
+        reference=ref_fname,
+        moving=img_fname,
+        output="resampled.nii.gz",
+        extra="",
+    )
+
+    # skip test if command is not available on host
+    exe = cmd.split(" ", 1)[0]
+    if not shutil.which(exe):
+        pytest.skip(f"Command {exe} not found on host")
+
+    exit_code = check_call([cmd], shell=True)
+    assert exit_code == 0
+    sw_moved = nb.load("resampled.nii.gz")
+
+    nt_moved = xfm.apply(img_fname, order=0)
+    nt_moved.to_filename("nt_resampled.nii.gz")
+    diff = sw_moved.get_fdata() - nt_moved.get_fdata()
+    # A certain tolerance is necessary because of resampling at borders
+    assert (np.abs(diff) > 1e-3).sum() / diff.size < RMSE_TOL
+
+    col_moved = xfm.collapse().apply(img_fname, order=0)
+    col_moved.to_filename("nt_collapse_resampled.nii.gz")
+    diff = sw_moved.get_fdata() - col_moved.get_fdata()
+    # A certain tolerance is necessary because of resampling at borders
+    assert (np.abs(diff) > 1e-3).sum() / diff.size < RMSE_TOL
+
+
 @pytest.mark.parametrize("ext0", ["lta", "tfm"])
 @pytest.mark.parametrize("ext1", ["lta", "tfm"])
 @pytest.mark.parametrize("ext2", ["lta", "tfm"])
@@ -31,7 +80,7 @@ def test_collapse_affines(tmp_path, data_path, ext0, ext1, ext2):
         ]
     )
     assert np.allclose(
-        chain.asaffine().matrix,
+        chain.collapse().matrix,
         Affine.from_filename(
             data_path / "regressions" / f"from-fsnative_to-bold_mode-image.{ext2}",
             fmt=f"{FMT[ext2]}",
diff --git a/nitransforms/tests/test_resampling.py b/nitransforms/tests/test_resampling.py
index 2384ad97..0e11df5b 100644
--- a/nitransforms/tests/test_resampling.py
+++ b/nitransforms/tests/test_resampling.py
@@ -363,3 +363,28 @@ def test_LinearTransformsMapping_apply(
         reference=testdata_path / "sbref.nii.gz",
         serialize_nvols=2 if serialize_4d else np.inf,
     )
+
+
+@pytest.mark.parametrize("serialize_4d", [True, False])
+def test_apply_4d(serialize_4d):
+    """Regression test for per-volume transforms with serialized resampling."""
+    nvols = 9
+    shape = (10, 5, 5)
+    base = np.zeros(shape, dtype=np.float32)
+    base[9, 2, 2] = 1
+    img = nb.Nifti1Image(np.stack([base] * nvols, axis=-1), np.eye(4))
+
+    transforms = []
+    for i in range(nvols):
+        mat = np.eye(4)
+        mat[0, 3] = i
+        transforms.append(nitl.Affine(mat))
+
+    extraparams = {} if serialize_4d else {"serialize_nvols": nvols + 1}
+
+    xfm = nitl.LinearTransformsMapping(transforms, reference=img)
+
+    moved = apply(xfm, img, order=0, **extraparams)
+    data = np.asanyarray(moved.dataobj)
+    idxs = [tuple(np.argwhere(data[..., i])[0]) for i in range(nvols)]
+    assert idxs == [(9 - i, 2, 2) for i in range(nvols)]
diff --git a/nitransforms/tests/test_x5.py b/nitransforms/tests/test_x5.py
index 8502a387..89b49e06 100644
--- a/nitransforms/tests/test_x5.py
+++ b/nitransforms/tests/test_x5.py
@@ -1,7 +1,8 @@
 import numpy as np
+import pytest
 from h5py import File as H5File

-from ..io.x5 import X5Transform, X5Domain, to_filename
+from ..io.x5 import X5Transform, X5Domain, to_filename, from_filename


 def test_x5_transform_defaults():
@@ -39,3 +40,38 @@ def test_to_filename(tmp_path):
         assert "0" in grp
         assert grp["0"].attrs["Type"] == "linear"
         assert grp["0"].attrs["ArrayLength"] == 1
+
+
+def test_from_filename_roundtrip(tmp_path):
+    domain = X5Domain(grid=False, size=(5, 5, 5), mapping=np.eye(4))
+    node = X5Transform(
+        type="linear",
+        transform=np.eye(4),
+        dimension_kinds=("space", "space", "space", "vector"),
+        domain=domain,
+        metadata={"foo": "bar"},
+        inverse=np.eye(4),
+    )
+    fname = tmp_path / "test.x5"
+    to_filename(fname, [node])
+
+    x5_list = from_filename(fname)
+    assert len(x5_list) == 1
+    x5 = x5_list[0]
+    assert x5.type == node.type
+    assert np.allclose(x5.transform, node.transform)
+    assert x5.dimension_kinds == list(node.dimension_kinds)
+    assert x5.domain.grid == domain.grid
+    assert x5.domain.size == tuple(domain.size)
+    assert np.allclose(x5.domain.mapping, domain.mapping)
+    assert x5.metadata == node.metadata
+    assert np.allclose(x5.inverse, node.inverse)
+
+
+def test_from_filename_invalid(tmp_path):
+    fname = tmp_path / "invalid.h5"
+    with H5File(fname, "w") as f:
+        f.attrs["Format"] = "NOTX5"
+
+    with pytest.raises(TypeError):
+        from_filename(fname)
diff --git a/pyproject.toml b/pyproject.toml
index f11e2e5e..a6ac0859 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -98,3 +98,12 @@ exclude_lines = [
     "raise NotImplementedError",
     "warnings\\.warn",
 ]
+
+[tool.flake8]
+max-line-length = 99
+doctests = false
+ignore = [
+    "E266",
+    "E231",
+    "W503",
+]
diff --git a/setup.cfg b/setup.cfg
deleted file mode 100644
index f355be94..00000000
--- a/setup.cfg
+++ /dev/null
@@ -1,7 +0,0 @@
-[flake8]
-max-line-length = 99
-doctests = False
-ignore =
-    E266
-    E231
-    W503
diff --git a/tox.ini b/tox.ini
index fe549039..50d167bc 100644
--- a/tox.ini
+++ b/tox.ini
@@ -59,6 +59,7 @@ description = Check our style guide
 labels = check
 deps =
     flake8
+    flake8-pyproject
 skip_install = true
 commands =
     flake8 nitransforms
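
For orientation, here is a minimal sketch of the X5 round trip that the new from_filename reader and the round-trip assertions in test_linear.py exercise. It is illustrative only, not part of the changeset: the file name xfm.x5 is made up, and the snippet assumes nitransforms is installed with these changes applied.

import nitransforms.linear as nitl

# Write an affine to X5 (the default format of to_filename) and read it back.
aff = nitl.Affine.from_matvec(vec=(2, -10, 3))
aff.to_filename("xfm.x5")                               # hypothetical output path, stored as HDF5/X5
loaded = nitl.Affine.from_filename("xfm.x5", fmt="X5")  # new X5 code path added above
assert loaded == aff                                    # same matrix within tolerance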
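A second sketch, under the same assumptions, spells out the revised composition semantics that the updated test_mulmat_operator and the new TransformChain.collapse() rely on: a @ b now yields a transform whose matrix is b.matrix @ a.matrix, so b is applied after a when mapping points, and collapse() folds a chain into one equivalent transform.

import numpy as np
import nitransforms.linear as nitl
from nitransforms.manip import TransformChain

a = nitl.Affine.from_matvec(vec=(1, 0, 0))           # translate along x
b = nitl.Affine.from_matvec(mat=np.diag((2, 2, 2)))  # isotropic scaling

composed = a @ b
assert np.allclose(composed.matrix, b.matrix @ a.matrix)
assert np.allclose(composed.map((1, 1, 1)), b.map(a.map((1, 1, 1))))

# collapse() (formerly asaffine) returns a single transform that maps
# points exactly like the original chain.
chain = TransformChain([a, b])
assert np.allclose(chain.map((1, 1, 1)), chain.collapse().map((1, 1, 1)))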