diff --git a/deeptrack/image.py b/deeptrack/image.py index 6a221a3f..54697c33 100644 --- a/deeptrack/image.py +++ b/deeptrack/image.py @@ -96,11 +96,17 @@ class is central to DeepTrack2, acting as a container for numerical data import operator as ops from typing import Any, Callable, Iterable +import array_api_compat as apc import numpy as np +from numpy.typing import NDArray +from deeptrack.backend import config, TORCH_AVAILABLE, xp from deeptrack.properties import Property from deeptrack.types import NumberLike +if TORCH_AVAILABLE: + import torch + #TODO ***??*** revise _binary_method - typing, docstring, unit test def _binary_method( @@ -1694,12 +1700,11 @@ def coerce( _FASTEST_SIZES = np.sort(_FASTEST_SIZES) -#TODO ***??*** revise pad_image_to_fft - typing, docstring, unit test def pad_image_to_fft( - image: Image | np.ndarray | np.ndarray, + image: Image | NDArray | torch.Tensor, axes: Iterable[int] = (0, 1), -) -> Image | np.ndarray: - """Pads an image to optimize Fast Fourier Transform (FFT) performance. +) -> Image | NDArray | torch.Tensor: + """Pad an image to optimize Fast Fourier Transform (FFT) performance. This function pads an image by adding zeros to the end of specified axes so that their lengths match the nearest larger size in `_FASTEST_SIZES`. @@ -1707,7 +1712,7 @@ def pad_image_to_fft( Parameters ---------- - image: Image | np.ndarray + image: Image | np.ndarray | torch.Tensor The input image to pad. It should be an instance of the `Image` class or any array-like structure compatible with FFT operations. axes: Iterable[int], optional @@ -1715,7 +1720,7 @@ def pad_image_to_fft( Returns ------- - Image | np.ndarray + Image | np.ndarray | torch.Tensor The padded image with dimensions optimized for FFT performance. 
Raises @@ -1725,30 +1730,37 @@ def pad_image_to_fft( Examples -------- - >>> import numpy as np >>> from deeptrack.image import Image, pad_image_to_fft Pad an Image object: - - >>> img = Image(np.zeros((7, 13))) + >>> import numpy as np + >>> + >>> img = Image(np.ones((7, 13))) >>> padded_img = pad_image_to_fft(img) >>> print(padded_img.shape) (8, 16) Pad a NumPy array: - - >>> img = np.zeros((5, 11))) + >>> img = np.ones((5, 11)) >>> padded_img = pad_image_to_fft(img) >>> print(padded_img.shape) (6, 12) + Pad a PyTorch tensor: + >>> import torch + >>> + >>> img = torch.ones(7, 11) + >>> padded_img = pad_image_to_fft(img) + >>> print(padded_img.shape) + torch.Size([8, 12]) + """ def _closest( dim: int, ) -> int: - # Returns the smallest value frin _FASTEST_SIZES larger than dim. + # Return the smallest value from _FASTEST_SIZES that is >= dim. for size in _FASTEST_SIZES: if size >= dim: return size @@ -1763,7 +1775,18 @@ def _closest( new_shape[axis] = _closest(new_shape[axis]) # Calculate the padding for each axis. - pad_width = [(0, increase) for increase in np.array(new_shape) - image.shape] + pad_width = [ + (0, increase) + for increase in np.array(new_shape) - np.array(image.shape) + ] + + # Apply zero-padding with torch.nn.functional.pad if the input is a + # PyTorch tensor + if apc.is_torch_array(image): + pad = [] + for before, after in reversed(pad_width): + pad.extend([before, after]) + return torch.nn.functional.pad(image, pad, mode="constant", value=0) - # Pad the image using constant mode (add zeros). 
+ # Apply zero-padding with np.pad if the input is a NumPy array or an Image return np.pad(image, pad_width, mode="constant") diff --git a/deeptrack/tests/test_image.py b/deeptrack/tests/test_image.py index d413c8da..5d4901a8 100644 --- a/deeptrack/tests/test_image.py +++ b/deeptrack/tests/test_image.py @@ -12,7 +12,10 @@ import numpy as np -from deeptrack import features, image +from deeptrack import features, image, TORCH_AVAILABLE + +if TORCH_AVAILABLE: + import torch class TestImage(unittest.TestCase): @@ -389,6 +392,7 @@ def test_Image__view(self): def test_pad_image_to_fft(self): + # Test with dt.Image input_image = image.Image(np.zeros((7, 25))) padded_image = image.pad_image_to_fft(input_image) self.assertEqual(padded_image.shape, (8, 27)) @@ -401,6 +405,33 @@ def test_pad_image_to_fft(self): padded_image = image.pad_image_to_fft(input_image) self.assertEqual(padded_image.shape, (324, 432)) + # Test with NumPy array + input_image = np.ones((7, 13)) + padded_image = image.pad_image_to_fft(input_image) + self.assertEqual(padded_image.shape, (8, 16)) + + input_image = np.ones((5,)) + padded_image = image.pad_image_to_fft(input_image, axes=(0,)) + self.assertEqual(padded_image.shape, (6,)) + + ### Test with PyTorch tensor (if available) + if TORCH_AVAILABLE: + input_image = torch.ones(3, 5) + padded_image = image.pad_image_to_fft(input_image) + self.assertEqual(padded_image.shape, (3, 6)) + self.assertIsInstance(padded_image, torch.Tensor) + + input_image = torch.ones(5, 7, 11, 13) + padded_image = image.pad_image_to_fft(input_image, axes=(0, 1, 3)) + padded_image_np = image.pad_image_to_fft( + input_image.numpy(), axes=(0, 1, 3) + ) + self.assertEqual(padded_image.shape, (6, 8, 11, 16)) + self.assertIsInstance(padded_image, torch.Tensor) + np.testing.assert_allclose( + padded_image.numpy(), padded_image_np, atol=1e-6 + ) + if __name__ == "__main__": unittest.main() \ No newline at end of file