diff --git a/webknossos/tests/dataset/test_layer.py b/webknossos/tests/dataset/test_layer.py index 34d943006..f69cd4a46 100644 --- a/webknossos/tests/dataset/test_layer.py +++ b/webknossos/tests/dataset/test_layer.py @@ -1,9 +1,466 @@ from pathlib import Path +from typing import Tuple, Any import numpy as np import tensorstore +import pytest # Import pytest +from skimage.transform import AffineTransform # Import AffineTransform import webknossos as wk +from webknossos import BoundingBox, Dataset, Layer, Mag # Import necessary classes + + +# Helper function to create a sample dataset with a color layer +def create_dataset_with_color_layer( + tmp_path: Path, + dataset_name: str, + layer_name: str, + bounding_box_shape: Tuple[int, int, int], + dtype: Any = np.uint8, + num_channels: int = 1, + voxel_size: Tuple[int, int, int] = (1, 1, 1), + mag: Mag = Mag(1), + data_format: str = "raw", # Default to raw for simplicity in tests + chunk_shape: Tuple[int, int, int] = (64, 64, 64), +) -> Layer: + """Helper function to create a dataset with a single color layer.""" + dataset_path = tmp_path / dataset_name + dataset = Dataset( + dataset_path, + voxel_size=voxel_size, + exist_ok=True, # Allow re-creation for tests if needed + ) + layer = dataset.add_layer( + layer_name, + wk.COLOR_CATEGORY, + bounding_box=BoundingBox((0, 0, 0), shape=bounding_box_shape), + dtype_per_channel=dtype, + num_channels=num_channels, + data_format=data_format, + ) + # Add the default mag if it doesn't exist, configure with chunk_shape + if mag not in layer.mags: + layer.add_mag(mag, chunk_shape=chunk_shape) + else: + # Ensure existing mag also has compatible chunk_shape if we were to rely on it + # For simplicity, we assume add_mag handles or we create fresh datasets. + pass + return layer + + +class TestLayerTransform: + def test_identity_transform(self, tmp_path: Path) -> None: + """Tests the transform method with an identity transformation.""" + input_dataset_name = "input_ds_identity" + input_layer_name = "input_layer_identity" + output_dataset_name = "output_ds_identity" + output_layer_name = "output_layer_identity" + bbox_shape = (128, 128, 128) # Make it a bit larger to test chunking + dtype = np.uint8 + mag_level = Mag(1) + chunk_s = (32,32,32) # Smaller chunk for testing multiple chunks + + # 1. Create an input layer with some data + input_layer = create_dataset_with_color_layer( + tmp_path, + input_dataset_name, + input_layer_name, + bbox_shape, + dtype=dtype, + mag=mag_level, + chunk_shape=chunk_s, + ) + # Populate with some data + # Data shape (C, X, Y, Z) + input_data = np.random.randint(0, 255, size=(input_layer.num_channels, *bbox_shape), dtype=dtype) + input_mag_view = input_layer.get_mag(mag_level) + # Write data to the entire bounding box of the mag view + input_mag_view.write(input_data, input_mag_view.bounding_box) + + + # 2. Create an empty output layer with the same properties + output_layer = create_dataset_with_color_layer( + tmp_path, + output_dataset_name, + output_layer_name, + bbox_shape, + dtype=dtype, + mag=mag_level, + chunk_shape=chunk_s, + ) + + # 3. Define an identity inverse_transform function + def identity_inverse_transform(coords: np.ndarray) -> np.ndarray: + # coords are (N, 3) in global output space + # For identity, input global space is the same + return coords + + # 4. 
Call layer.transform with the identity transform + # Use a small number of threads for testing to avoid overhead issues on CI + input_layer.transform( + output_layer=output_layer, + inverse_transform=identity_inverse_transform, + mag=mag_level, + num_threads=2, # Use 2 threads for testing parallelism + chunk_shape=chunk_s # Pass the same chunk_shape + ) + + # 5. Assert that the data in the output layer is identical to the input layer + output_mag_view = output_layer.get_mag(mag_level) + output_data = output_mag_view.read(output_mag_view.bounding_box) + + np.testing.assert_array_equal(output_data, input_data, + err_msg="Data in output layer does not match input layer after identity transform.") + + # Additional check: ensure bounding boxes were handled correctly + assert output_layer.bounding_box == input_layer.bounding_box + assert output_mag_view.bounding_box.shape == input_mag_view.bounding_box.shape + assert output_mag_view.bounding_box.min_coord == input_mag_view.bounding_box.min_coord + + def test_translation_transform(self, tmp_path: Path) -> None: + """Tests the transform method with a translation.""" + input_dataset_name = "input_ds_translate" + input_layer_name = "input_layer_translate" + output_dataset_name = "output_ds_translate" + output_layer_name = "output_layer_translate" + + bbox_shape_orig = (64, 64, 64) # Original data size + dtype = np.uint8 + mag_level = Mag(1) + chunk_s = (16, 16, 16) # Smaller chunk for testing + num_ch = 1 + + # Translation vector (in global coordinates) + translation_vector = np.array([10, -5, 20]) # dx, dy, dz + + # 1. Create an input layer with some data + input_layer = create_dataset_with_color_layer( + tmp_path, input_dataset_name, input_layer_name, + bbox_shape_orig, dtype=dtype, mag=mag_level, chunk_shape=chunk_s, num_channels=num_ch + ) + # Populate with sequential data for easier verification + input_data_flat = np.arange(np.prod(bbox_shape_orig), dtype=dtype).reshape(bbox_shape_orig) + input_data = input_data_flat[np.newaxis, ...] # Add channel dimension: (1, X, Y, Z) + input_mag_view = input_layer.get_mag(mag_level) + input_mag_view.write(input_data, input_mag_view.bounding_box) + + # 2. Create an empty output layer with the same properties & original bbox + # The transform method itself doesn't change the output layer's bbox, + # it writes within the output_layer.bounding_box or the specified output_bounding_box. + # For a pure translation, the output data region will be the same size as input. + output_layer = create_dataset_with_color_layer( + tmp_path, output_dataset_name, output_layer_name, + bbox_shape_orig, dtype=dtype, mag=mag_level, chunk_shape=chunk_s, num_channels=num_ch + ) + + # 3. Define the inverse_transform for translation + # output_coord -> input_coord. So, if output is translated by T, input is output - T. + inverse_translation = lambda coords: coords - translation_vector + + # 4. Call layer.transform + input_layer.transform( + output_layer=output_layer, + inverse_transform=inverse_translation, + mag=mag_level, + num_threads=1, # Test with single thread first + chunk_shape=chunk_s + ) + + # 5. 
Manually compute the expected output data + expected_output_data = np.zeros_like(input_data) # Fill with 0 (default fill value) + + # Determine the overlapping region in the output layer's perspective + # These are slices for the output_data array + out_x_slice = slice(max(0, int(translation_vector[0])), min(bbox_shape_orig[0], int(bbox_shape_orig[0] + translation_vector[0]))) + out_y_slice = slice(max(0, int(translation_vector[1])), min(bbox_shape_orig[1], int(bbox_shape_orig[1] + translation_vector[1]))) + out_z_slice = slice(max(0, int(translation_vector[2])), min(bbox_shape_orig[2], int(bbox_shape_orig[2] + translation_vector[2]))) + + # Determine the corresponding region in the input layer's perspective + # These are slices for the input_data array + in_x_slice = slice(max(0, int(-translation_vector[0])), min(bbox_shape_orig[0], int(bbox_shape_orig[0] - translation_vector[0]))) + in_y_slice = slice(max(0, int(-translation_vector[1])), min(bbox_shape_orig[1], int(bbox_shape_orig[1] - translation_vector[1]))) + in_z_slice = slice(max(0, int(-translation_vector[2])), min(bbox_shape_orig[2], int(bbox_shape_orig[2] - translation_vector[2]))) + + expected_output_data[0, out_x_slice, out_y_slice, out_z_slice] = input_data[0, in_x_slice, in_y_slice, in_z_slice] + + output_mag_view = output_layer.get_mag(mag_level) + actual_output_data = output_mag_view.read(output_mag_view.bounding_box) + + np.testing.assert_array_equal(actual_output_data, expected_output_data, + err_msg="Data in output layer does not match expected translated data.") + + def test_scaling_transform(self, tmp_path: Path) -> None: + """Tests the transform method with a scaling transformation (2x magnification).""" + input_dataset_name = "input_ds_scale" + input_layer_name = "input_layer_scale" + output_dataset_name = "output_ds_scale" + output_layer_name = "output_layer_scale" + + input_bbox_shape = (32, 32, 32) # Small input for easier manual verification + dtype = np.uint16 # Use a different dtype + mag_level = Mag(1) # Assume scaling happens at mag 1 for simplicity + chunk_s = (16, 16, 16) + num_ch = 1 + scaling_factor = 2.0 + + # 1. Create input layer and populate with data + input_layer = create_dataset_with_color_layer( + tmp_path, input_dataset_name, input_layer_name, + input_bbox_shape, dtype=dtype, mag=mag_level, chunk_shape=chunk_s, num_channels=num_ch + ) + # Create simple ramp data for predictable scaling + input_data_flat = np.arange(np.prod(input_bbox_shape), dtype=dtype).reshape(input_bbox_shape) + input_data = input_data_flat[np.newaxis, ...] # (1, X, Y, Z) + input_mag_view = input_layer.get_mag(mag_level) + input_mag_view.write(input_data, input_mag_view.bounding_box) + + # 2. Define output bounding box and create output layer + # Output bounding box should be scaled version of input + output_bbox_shape = tuple(int(s * scaling_factor) for s in input_bbox_shape) + output_layer = create_dataset_with_color_layer( + tmp_path, output_dataset_name, output_layer_name, + output_bbox_shape, dtype=dtype, mag=mag_level, chunk_shape=chunk_s, num_channels=num_ch + ) + # The output_bounding_box for the transform call will be the output_layer's bounding_box. + + # 3. Define inverse_transform for scaling + # output_coord -> input_coord. So, input_coord = output_coord / scaling_factor + inverse_scaling = lambda coords: coords / scaling_factor + + # 4. 
Call layer.transform
+        input_layer.transform(
+            output_layer=output_layer,
+            inverse_transform=inverse_scaling,
+            mag=mag_level,
+            num_threads=0,  # num_threads=0 exercises the sequential code path
+            chunk_shape=chunk_s,
+            # output_bounding_box defaults to output_layer.bounding_box here
+        )
+
+        # 5. Manually compute the expected output data (nearest neighbor)
+        expected_output_data = np.zeros((num_ch, *output_bbox_shape), dtype=dtype)
+        for c in range(num_ch):
+            for ox in range(output_bbox_shape[0]):
+                for oy in range(output_bbox_shape[1]):
+                    for oz in range(output_bbox_shape[2]):
+                        # Corresponding input coordinate (float)
+                        ix_f = ox / scaling_factor
+                        iy_f = oy / scaling_factor
+                        iz_f = oz / scaling_factor
+
+                        # Nearest-neighbor rounding
+                        ix = int(round(ix_f))
+                        iy = int(round(iy_f))
+                        iz = int(round(iz_f))
+
+                        # Clamp to the input bounds (rounding can yield e.g. 31.5 -> 32)
+                        ix = np.clip(ix, 0, input_bbox_shape[0] - 1)
+                        iy = np.clip(iy, 0, input_bbox_shape[1] - 1)
+                        iz = np.clip(iz, 0, input_bbox_shape[2] - 1)
+
+                        expected_output_data[c, ox, oy, oz] = input_data[c, ix, iy, iz]
+
+        output_mag_view = output_layer.get_mag(mag_level)
+        actual_output_data = output_mag_view.read(output_mag_view.bounding_box)
+
+        np.testing.assert_array_equal(
+            actual_output_data,
+            expected_output_data,
+            err_msg="Data in output layer does not match expected scaled data.",
+        )
+
+    def test_affine_transform_simple_rotation(self, tmp_path: Path) -> None:
+        """Tests a simple 90-degree rotation around the Z-axis, followed by a translation."""
+        input_name, output_name = "input_ds_affine", "output_ds_affine"
+        layer_name = "layer_affine"
+        # The input is a 2x1x1 strip along X, which makes the rotation easy to verify.
+        input_bbox_shape = (2, 1, 1)
+        dtype, mag_level, num_ch = np.uint8, Mag(1), 1
+        chunk_s = (1, 1, 1)
+
+        input_layer = create_dataset_with_color_layer(
+            tmp_path, input_name, layer_name, input_bbox_shape,
+            dtype=dtype, mag=mag_level, chunk_shape=chunk_s, num_channels=num_ch
+        )
+        # Data: [100, 200] along the X axis for the single channel, shape (C, X, Y, Z)
+        input_data = np.array([100, 200], dtype=dtype).reshape(num_ch, *input_bbox_shape)
+        input_mag_view = input_layer.get_mag(mag_level)
+        input_mag_view.write(input_data, input_mag_view.bounding_box)
+
+        # The output layer must be large enough to contain the rotated & translated data:
+        # original data sits at (0,0,0) and (1,0,0);
+        # rotated 90 deg around Z (about the origin): (0,0,0) -> (0,0,0); (1,0,0) -> (0,1,0);
+        # translated by (1,1,0): (0,0,0) -> (1,1,0); (0,1,0) -> (1,2,0).
+        # So the output needs to cover at least (1,1,0) to (1,2,0).
+        output_bbox_shape = (2, 3, 1)  # A bit larger, to be safe
+        output_layer = create_dataset_with_color_layer(
+            tmp_path, output_name, layer_name, output_bbox_shape,
+            dtype=dtype, mag=mag_level, chunk_shape=chunk_s, num_channels=num_ch
+        )
+
+        # Define the affine transform: 90 deg rotation around Z, then translation by (1,1,0).
+        # (skimage's AffineTransform takes the rotation in radians, not degrees, and acts
+        # on (col, row), i.e. (x, y); z is handled separately.)
+        # Rotation matrix for 90 deg around Z: [[0, -1, 0], [1, 0, 0], [0, 0, 1]]
+        # Translation: (1,1,0)
+        # The center of rotation is implicitly (0,0,0) for the matrix part.
+        # The inverse_transform passed to the method expects global coordinates.
+        # Our input data is at global coords (0,0,0) and (1,0,0).
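+        # As a hedged cross-check (sketch only, nothing below is required by the test):
+        # the same inverse falls out of the already-imported skimage AffineTransform,
+        # whose 3x3 homogeneous matrix covers the in-plane (x, y) part; z only sees the
+        # translation tz. The rotation argument is in radians, counter-clockwise:
+        #     tf = AffineTransform(rotation=np.pi / 2, translation=(tx, ty))
+        #     inv = np.linalg.inv(tf.params)  # invert rotation + translation at once
+        #     xy_in = (inv @ np.c_[xy_out, np.ones(len(xy_out))].T).T[:, :2]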
+ + # Transformation: T_translate * T_rotate + # Output_coord = T_translate * T_rotate * Input_coord + # Input_coord = T_rotate_inv * T_translate_inv * Output_coord + + # Rotation by +90 deg around Z (counter-clockwise) + # x' = x*cos(a) - y*sin(a) + # y' = x*sin(a) + y*cos(a) + # For +90: x' = -y, y' = x + # Inverse rotation (-90 deg): x = y', y = -x' + + # Translation by (tx, ty, tz) + # x_out = x_rot + tx + # y_out = y_rot + ty + # Inverse translation: x_rot = x_out - tx, y_rot = y_out - ty + + # Combined inverse: + # 1. Output coord (ox, oy, oz) + # 2. Inverse translate: (ox - tx, oy - ty, oz - tz) -> (ix_t, iy_t, iz_t) + # 3. Inverse rotate: (iy_t, -ix_t, iz_t) -> input coord (inx, iny, inz) + + tx, ty, tz = 1, 1, 0 + + def affine_inverse_transform(output_coords_global: np.ndarray) -> np.ndarray: + # output_coords_global is (N,3) + input_coords_translated_inv = output_coords_global - np.array([tx, ty, tz]) + + input_coords_rotated_inv = np.zeros_like(input_coords_translated_inv) + input_coords_rotated_inv[:, 0] = input_coords_translated_inv[:, 1] # x_in = y_translated_inv + input_coords_rotated_inv[:, 1] = -input_coords_translated_inv[:, 0] # y_in = -x_translated_inv + input_coords_rotated_inv[:, 2] = input_coords_translated_inv[:, 2] # z_in = z_translated_inv + return input_coords_rotated_inv + + input_layer.transform( + output_layer, affine_inverse_transform, mag=mag_level, num_threads=None, chunk_shape=chunk_s + ) + + expected_output = np.zeros((num_ch, *output_bbox_shape), dtype=dtype) + # Input (0,0,0) [val 100] -> Rot (0,0,0) -> Trans (1,1,0) + # Input (1,0,0) [val 200] -> Rot (0,1,0) -> Trans (1,2,0) + if 0 <= 1 < output_bbox_shape[0] and 0 <= 1 < output_bbox_shape[1] and 0 <= 0 < output_bbox_shape[2]: + expected_output[0, 1, 1, 0] = 100 # voxel at (1,1,0) in output + if 0 <= 1 < output_bbox_shape[0] and 0 <= 2 < output_bbox_shape[1] and 0 <= 0 < output_bbox_shape[2]: + expected_output[0, 1, 2, 0] = 200 # voxel at (1,2,0) in output + + actual_output = output_layer.get_mag(mag_level).read() + np.testing.assert_array_equal(actual_output, expected_output, err_msg="Affine transform failed.") + + def test_output_bounding_box_smaller_and_shifted(self, tmp_path: Path) -> None: + """Tests transform when output_bounding_box is smaller and shifted.""" + input_ds, input_layer_n = "in_ds_obb", "in_l_obb" + output_ds, output_layer_n = "out_ds_obb", "out_l_obb" + input_bbox_shape = (50, 50, 50) + dtype, mag, num_ch = np.uint8, Mag(1), 1 + chunk_s = (10,10,10) + + input_layer = create_dataset_with_color_layer( + tmp_path, input_ds, input_layer_n, input_bbox_shape, + dtype=dtype, mag=mag, chunk_shape=chunk_s, num_channels=num_ch + ) + # Sequential data + input_data = np.arange(np.prod(input_bbox_shape), dtype=dtype).reshape(num_ch, *input_bbox_shape) + input_layer.get_mag(mag).write(input_data, input_layer.get_mag(mag).bounding_box) + + # Output layer is larger, but we'll only write to a small, shifted part of it + output_layer_bbox_shape = (60, 60, 60) + output_layer = create_dataset_with_color_layer( + tmp_path, output_ds, output_layer_n, output_layer_bbox_shape, + dtype=dtype, mag=mag, chunk_shape=chunk_s, num_channels=num_ch + ) + # Fill output layer with a distinct value to check only target OBB is written + fill_value = 77 + output_layer.get_mag(mag).write(np.full((num_ch, *output_layer_bbox_shape), fill_value, dtype=dtype)) + + + # Define the specific output_bounding_box for the transform operation + # This OBB is in global coordinates. 
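+        # (Illustration, using this PR's BoundingBox keywords: min_coord is the inclusive
+        # lower corner and shape counts voxels per axis, so
+        #     BoundingBox(min_coord=(5, 5, 5), shape=(20, 20, 20))
+        # covers voxels (5, 5, 5) up to, but excluding, (25, 25, 25).)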
+ # Let's pick a 20x20x20 cube shifted by (5,5,5) in the output layer + # This OBB will read from input layer starting at (5,5,5) due to identity transform + obb_min_coord = (5, 5, 5) + obb_shape = (20, 20, 20) + specific_output_bb = BoundingBox(min_coord=obb_min_coord, shape=obb_shape) + + input_layer.transform( + output_layer, + inverse_transform=lambda coords: coords, # Identity + mag=mag, + output_bounding_box=specific_output_bb, + num_threads=2, + chunk_shape=chunk_s + ) + + # Expected data in the output layer + expected_data_in_output_layer = np.full((num_ch, *output_layer_bbox_shape), fill_value, dtype=dtype) + # The part that should be overwritten comes from input_data[0, 5:25, 5:25, 5:25] + # and written to expected_data_in_output_layer[0, 5:25, 5:25, 5:25] + src_data_slice = (slice(None), slice(obb_min_coord[0], obb_min_coord[0]+obb_shape[0]), \ + slice(obb_min_coord[1], obb_min_coord[1]+obb_shape[1]), \ + slice(obb_min_coord[2], obb_min_coord[2]+obb_shape[2])) + expected_data_in_output_layer[src_data_slice] = input_data[src_data_slice] + + actual_output_data = output_layer.get_mag(mag).read() + np.testing.assert_array_equal(actual_output_data, expected_data_in_output_layer) + + def test_transform_all_coords_outside_input(self, tmp_path: Path) -> None: + """Tests transform when all transformed coords are outside input layer's bounds.""" + input_ds, input_layer_n = "in_ds_outside", "in_l_outside" + output_ds, output_layer_n = "out_ds_outside", "out_l_outside" + bbox_shape = (10, 10, 10) + dtype, mag, num_ch, chunk_s = np.uint8, Mag(1), 1, (5,5,5) + + input_layer = create_dataset_with_color_layer( + tmp_path, input_ds, input_layer_n, bbox_shape, dtype=dtype, mag=mag, chunk_shape=chunk_s + ) + input_layer.get_mag(mag).write(np.ones((num_ch, *bbox_shape), dtype=dtype)) # Fill with 1s + + output_layer = create_dataset_with_color_layer( + tmp_path, output_ds, output_layer_n, bbox_shape, dtype=dtype, mag=mag, chunk_shape=chunk_s + ) + + # Translation that shifts everything out of input bounds + # Input is at [0,0,0] to [9,9,9]. 
Shift by [100,100,100] + translation_far = np.array([100, 100, 100]) + inverse_transform_far = lambda coords: coords - translation_far + + input_layer.transform( + output_layer, inverse_transform_far, mag=mag, num_threads=1, chunk_shape=chunk_s + ) + + # Expected output is all zeros (default fill value from clamping/empty read) + expected_output = np.zeros((num_ch, *bbox_shape), dtype=dtype) + actual_output = output_layer.get_mag(mag).read() + np.testing.assert_array_equal(actual_output, expected_output) + + def test_transform_small_input_layer(self, tmp_path: Path) -> None: + """Tests transform with a very small input layer (2x2x2).""" + input_ds, input_layer_n = "in_ds_small", "in_l_small" + output_ds, output_layer_n = "out_ds_small", "out_l_small" + input_bbox_shape = (2, 2, 2) + dtype, mag, num_ch, chunk_s = np.uint8, Mag(1), 1, (1,1,1) # Chunk size 1 + + input_layer = create_dataset_with_color_layer( + tmp_path, input_ds, input_layer_n, input_bbox_shape, dtype=dtype, mag=mag, chunk_shape=chunk_s + ) + input_data = np.arange(np.prod(input_bbox_shape), dtype=dtype).reshape(num_ch, *input_bbox_shape) + input_layer.get_mag(mag).write(input_data, input_layer.get_mag(mag).bounding_box) + + output_layer = create_dataset_with_color_layer( + tmp_path, output_ds, output_layer_n, input_bbox_shape, dtype=dtype, mag=mag, chunk_shape=chunk_s + ) + + # Simple identity transform + input_layer.transform( + output_layer, lambda coords: coords, mag=mag, num_threads=None, chunk_shape=chunk_s + ) + + actual_output = output_layer.get_mag(mag).read() + np.testing.assert_array_equal(actual_output, input_data) def test_add_mag_from_zarrarray(tmp_path: Path) -> None: diff --git a/webknossos/webknossos/dataset/layer.py b/webknossos/webknossos/dataset/layer.py index d85417524..13cb0fc77 100644 --- a/webknossos/webknossos/dataset/layer.py +++ b/webknossos/webknossos/dataset/layer.py @@ -5,7 +5,7 @@ from os import PathLike from os.path import relpath from pathlib import Path -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Optional, Union, Callable, Tuple from urllib.parse import urlparse import numpy as np @@ -14,7 +14,7 @@ from upath import UPath from ..client.context import _get_context -from ..geometry import Mag, NDBoundingBox, Vec3Int, Vec3IntLike +from ..geometry import Mag, NDBoundingBox, Vec3Int, Vec3IntLike, BoundingBox from ..geometry.mag import MagLike from ._array import ArrayException, TensorStoreArray from ._downsampling_utils import ( @@ -1619,6 +1619,201 @@ def __repr__(self) -> str: def _get_largest_segment_id_maybe(self) -> int | None: return None + def transform( + self, + output_layer: "Layer", + inverse_transform: Callable[[np.ndarray], np.ndarray], + mag: Mag = Mag(1), + num_threads: int | None = None, + output_bounding_box: Optional[BoundingBox] = None, + chunk_shape: Optional[Tuple[int, int, int]] = None, + ) -> None: + """Transforms the layer into the output_layer using the inverse_transform function. + + Args: + output_layer (Layer): The layer to write the transformed data to. + inverse_transform (Callable[[np.ndarray], np.ndarray]): A function that takes a NumPy array of shape (N, 3) + representing coordinates in the output space and returns a NumPy array of shape (N, 3) + representing the corresponding coordinates in the input space. + mag (Mag, optional): The magnification level to use for reading and writing. Defaults to Mag(1). + num_threads (int | None, optional): The number of threads to use for parallel processing. 
Defaults to None, which lets the thread pool choose its default worker count. Pass 0 to process the chunks sequentially.
+            output_bounding_box (BoundingBox | None, optional): The bounding box in the output
+                layer to transform, in global coordinates. Defaults to the bounding box of the
+                output_layer.
+            chunk_shape (Tuple[int, int, int] | None, optional): The shape of the chunks to
+                process. Defaults to (64, 64, 64).
+        """
+        if output_bounding_box is None:
+            output_bounding_box = output_layer.bounding_box
+            if output_bounding_box is None:
+                raise ValueError(
+                    "output_bounding_box must be provided if output_layer has no bounding_box"
+                )
+
+        if chunk_shape is None:
+            chunk_shape_vec = Vec3Int(64, 64, 64)
+        else:
+            chunk_shape_vec = Vec3Int.from_tuple(chunk_shape)
+
+        input_mag_view = self.get_mag(mag)
+        output_mag_view = output_layer.get_mag(mag)
+
+        # Only the output layer is written to; the input layer is merely read.
+        output_layer._ensure_writable()
+
+        mag_factor = mag.to_vec3_int().to_np()
+        # Bounding box of the input layer in mag units, used for the in-bounds mask below.
+        input_layer_bbox_mag = input_mag_view.bounding_box.in_mag(mag)
+
+        def process_chunk(chunk_bbox_mag_coords: NDBoundingBox) -> None:
+            # Generate the coordinates of all voxels in the current output chunk.
+            # chunk_bbox_mag_coords is already in mag units.
+            x_coords = np.arange(chunk_bbox_mag_coords.min_x, chunk_bbox_mag_coords.max_x)
+            y_coords = np.arange(chunk_bbox_mag_coords.min_y, chunk_bbox_mag_coords.max_y)
+            z_coords = np.arange(chunk_bbox_mag_coords.min_z, chunk_bbox_mag_coords.max_z)
+
+            # indexing="ij" keeps the (x, y, z) axis order of the inputs.
+            grid_x, grid_y, grid_z = np.meshgrid(x_coords, y_coords, z_coords, indexing="ij")
+
+            # Flatten and stack to an (N, 3) array of output coordinates, N = Dx * Dy * Dz.
+            output_coords_mag = np.stack(
+                [grid_x.ravel(), grid_y.ravel(), grid_z.ravel()], axis=-1
+            )
+
+            # The chunk coordinates are absolute (they come from output_bounding_box.in_mag(mag)),
+            # so multiplying by the mag factor alone yields global coordinates. Adding the
+            # bounding box offset on top would count it twice.
+            output_coords_global = output_coords_mag * mag_factor
+
+            # Apply the inverse transform to obtain input coordinates in global space ...
+            input_coords_global = inverse_transform(output_coords_global)
+
+            # ... and convert them back to mag units for reading.
+            input_coords_mag = input_coords_global / mag_factor
+
+            # Out-of-bounds handling: voxels whose (continuous) source position falls outside
+            # the input layer are filled with 0. Clamping them to the border instead would
+            # smear border values over regions that have no source data.
+            in_bounds = (
+                (input_coords_mag[:, 0] >= input_layer_bbox_mag.min_x)
+                & (input_coords_mag[:, 0] < input_layer_bbox_mag.max_x)
+                & (input_coords_mag[:, 1] >= input_layer_bbox_mag.min_y)
+                & (input_coords_mag[:, 1] < input_layer_bbox_mag.max_y)
+                & (input_coords_mag[:, 2] >= input_layer_bbox_mag.min_z)
+                & (input_coords_mag[:, 2] < input_layer_bbox_mag.max_z)
+            )
+
+            # Nearest-neighbor interpolation: round to the nearest voxel. Rounding can push
+            # an in-bounds coordinate just past the last voxel (e.g. 31.5 -> 32), so clip
+            # the rounded coordinates into the valid index range for the gather below.
+            input_coords_mag_rounded = np.round(input_coords_mag).astype(int)
+            input_coords_mag_rounded[:, 0] = np.clip(
+                input_coords_mag_rounded[:, 0],
+                input_layer_bbox_mag.min_x,
+                input_layer_bbox_mag.max_x - 1,
+            )
+            input_coords_mag_rounded[:, 1] = np.clip(
+                input_coords_mag_rounded[:, 1],
+                input_layer_bbox_mag.min_y,
+                input_layer_bbox_mag.max_y - 1,
+            )
+            input_coords_mag_rounded[:, 2] = np.clip(
+                input_coords_mag_rounded[:, 2],
+                input_layer_bbox_mag.min_z,
+                input_layer_bbox_mag.max_z - 1,
+            )
+
+            # Reading voxel by voxel would be very slow. Instead, read the one bounding box
+            # that encloses all needed input voxels and gather from it.
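+            # Toy sketch of that gather pattern (illustrative only, not executed here):
+            #     data = np.arange(24).reshape(1, 2, 3, 4)      # (C, X, Y, Z)
+            #     xs, ys, zs = np.array([0, 1]), np.array([2, 0]), np.array([3, 1])
+            #     data[:, xs, ys, zs].T                          # -> (N, C) == (2, 1)
+            # Each row of the result is one gathered voxel across all channels.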
+            min_input_coords = np.min(input_coords_mag_rounded, axis=0)
+            max_input_coords = np.max(input_coords_mag_rounded, axis=0)
+
+            # Bounding box enclosing all required input voxels, in mag coordinates.
+            # The max coordinate is exclusive, hence the + 1.
+            input_read_bbox_mag = NDBoundingBox.from_iterator(
+                np.concatenate((min_input_coords, max_input_coords + 1))
+            )
+            # The clipping above already keeps the box inside the layer; intersecting is a
+            # cheap extra safety net.
+            input_read_bbox_mag = input_read_bbox_mag.intersection(input_layer_bbox_mag)
+
+            num_voxels = input_coords_mag_rounded.shape[0]
+            if not np.any(in_bounds) or input_read_bbox_mag.is_empty():
+                # No voxel of this chunk maps into the input layer: emit the fill value 0.
+                read_data_flat = np.zeros(
+                    (num_voxels, self.num_channels), dtype=self.dtype_per_channel
+                )
+            else:
+                # Read the enclosing bounding box once; shape (C, Dx', Dy', Dz').
+                encompassing_data = input_mag_view.read(input_read_bbox_mag)
+
+                # Express the rounded input coordinates relative to the start of the read box.
+                relative_x = input_coords_mag_rounded[:, 0] - input_read_bbox_mag.min_x
+                relative_y = input_coords_mag_rounded[:, 1] - input_read_bbox_mag.min_y
+                relative_z = input_coords_mag_rounded[:, 2] - input_read_bbox_mag.min_z
+
+                # Advanced indexing selects one voxel per coordinate triple across all
+                # channels: (C, N), transposed to (N, C).
+                read_data_flat = encompassing_data[
+                    :, relative_x, relative_y, relative_z
+                ].transpose()
+
+                # Blank out the voxels whose source position lies outside the input layer.
+                read_data_flat[~in_bounds] = 0
+
+            # Reshape the flat (N, C) data back to the chunk shape (C, Dx, Dy, Dz).
+            # meshgrid with indexing="ij" plus C-ordered ravel() iterates z fastest, then y,
+            # then x - exactly the order reshape((dx, dy, dz, C)) expects.
+            dx = chunk_bbox_mag_coords.shape[0]
+            dy = chunk_bbox_mag_coords.shape[1]
+            dz = chunk_bbox_mag_coords.shape[2]
+            output_data = read_data_flat.reshape((dx, dy, dz, self.num_channels)).transpose(
+                3, 0, 1, 2
+            )
+
+            # chunk_bbox_mag_coords is already in the coordinate space of output_mag_view.
+            output_mag_view.write(output_data, chunk_bbox_mag_coords)
+
+        # Iterate over the output bounding box in chunks. output_bounding_box is in global
+        # coordinates; convert it to mag coordinates for chunking.
+        output_bbox_mag = output_bounding_box.in_mag(mag)
+
+        # chunk_shape_vec is interpreted in ("x", "y", "z") axis order; bounding boxes with a
+        # different axis order would need an explicit mapping here.
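+        # Coordinate-convention sketch (using Mag(2) as an assumed example):
+        #     BoundingBox(min_coord=(0, 0, 0), shape=(64, 64, 64)).in_mag(Mag(2))
+        # spans mag coordinates (0, 0, 0) up to (32, 32, 32); process_chunk scales such
+        # chunk coordinates back to global space by the mag factor before applying
+        # inverse_transform.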
+
+        # Collect one task per chunk. NDBoundingBox.chunk yields sub-boxes that are already
+        # clipped to output_bbox_mag, so no manual max-coordinate handling is needed and no
+        # empty chunks can occur.
+        tasks = list(output_bbox_mag.chunk(chunk_shape_vec))
+
+        if num_threads == 0:
+            # Run sequentially, e.g. for debugging.
+            logging.info(f"Running transform for layer {self.name} sequentially.")
+            for task_bbox in tasks:
+                process_chunk(task_bbox)
+        else:
+            logging.info(
+                f"Running transform for layer {self.name} with up to "
+                f"{num_threads or 'default'} threads."
+            )
+            # Imported locally because only this code path needs it.
+            from concurrent.futures import ThreadPoolExecutor
+
+            with ThreadPoolExecutor(max_workers=num_threads) as executor:
+                # Use submit + result rather than map so that exceptions raised in worker
+                # threads surface here.
+                futures = [executor.submit(process_chunk, chunk_bbox) for chunk_bbox in tasks]
+                for future in futures:
+                    future.result()  # Wait for completion; re-raises any worker exception
+
 
     def as_segmentation_layer(self) -> "SegmentationLayer":
         """Casts into SegmentationLayer."""
         if isinstance(self, SegmentationLayer):
@@ -1626,6 +1821,7 @@ def as_segmentation_layer(self) -> "SegmentationLayer":
         else:
             raise TypeError(f"self is not a SegmentationLayer. Got: {type(self)}")
 
+    @classmethod
     def _ensure_layer(cls, layer: Union[str, PathLike, "Layer"]) -> "Layer":
         if isinstance(layer, Layer):