Skip to content

Mg/features loadimage #398

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 72 additions & 30 deletions deeptrack/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def propagate_data_to_dependencies(
"Merge",
"OneOf",
"OneOfDict",
"LoadImage", # TODO ***MG***
"LoadImage",
"SampleToMasks", # TODO ***MG***
"AsType", # TODO ***MG***
"ChannelFirst2d",
Expand Down Expand Up @@ -7195,29 +7195,29 @@ def get(
class LoadImage(Feature):
"""Load an image from disk and preprocess it.

This feature loads an image file using multiple fallback file readers
(`imageio`, `numpy`, `Pillow`, and `OpenCV`) until a suitable reader is
found. The image can be optionally converted to grayscale, reshaped to
ensure a minimum number of dimensions, or treated as a list of images if
This feature loads an image file using multiple fallback file readers
(`imageio`, `numpy`, `Pillow`, and `OpenCV`) until a suitable reader is
found. The image can be optionally converted to grayscale, reshaped to
ensure a minimum number of dimensions, or treated as a list of images if
multiple paths are provided.

Parameters
----------
path: PropertyLike[str or list[str]]
The path(s) to the image(s) to load. Can be a single string or a list
The path(s) to the image(s) to load. Can be a single string or a list
of strings.
load_options: PropertyLike[dict[str, Any]], optional
Additional options passed to the file reader. It defaults to `None`.
as_list: PropertyLike[bool], optional
If `True`, the first dimension of the image will be treated as a list.
If `True`, the first dimension of the image will be treated as a list.
It defaults to `False`.
ndim: PropertyLike[int], optional
Ensures the image has at least this many dimensions. It defaults to
`3`.
to_grayscale: PropertyLike[bool], optional
If `True`, converts the image to grayscale. It defaults to `False`.
get_one_random: PropertyLike[bool], optional
If `True`, extracts a single random image from a stack of images. Only
If `True`, extracts a single random image from a stack of images. Only
used when `as_list` is `True`. It defaults to `False`.

Attributes
Expand All @@ -7228,22 +7228,37 @@ class LoadImage(Feature):

Methods
-------
`get(image: Any, path: str or list[str], load_options: dict[str, Any] | None, ndim: int, to_grayscale: bool, as_list: bool, get_one_random: bool, **kwargs: Any) -> array`
`get(
path: str | list[str],
load_options: dict[str, Any] | None,
ndim: int,
to_grayscale: bool,
as_list: bool,
get_one_random: bool,
**kwargs: Any
) -> NDArray | torch.Tensor | list`
Load the image(s) from disk and process them.

Raises
------
IOError
If no file reader could parse the file or the file does not exist.

Notes
-----
By default, `LoadImage` returns a NumPy array. If you want the output as
a PyTorch tensor, either set the backend to `'torch'` globally using
`dt.backend.config.set_backend('torch')` or convert the feature by calling
`.torch()` before resolving.

Examples
--------
>>> import deeptrack as dt

Create a temporary image file:
>>> import numpy as np
>>> import os, tempfile
>>>
>>>
>>> temp_file = tempfile.NamedTemporaryFile(suffix=".npy", delete=False)
>>> np.save(temp_file.name, np.random.rand(100, 100, 3))

Expand Down Expand Up @@ -7271,7 +7286,21 @@ class LoadImage(Feature):
... )
>>> loaded_image = load_image_feature.resolve()
>>> loaded_image.shape
(2, 2, 3, 1)
(100, 100, 3, 1)

Load an image as a PyTorch tensor by setting the backend of the feature:
>>> load_image_feature = dt.LoadImage(path=temp_file.name)
>>> load_image_feature.torch()
>>> loaded_image = load_image_feature.resolve()
>>> print(type(loaded_image))
<class 'torch.Tensor'>

Load an image as a PyTorch tensor by setting the backend globally:
>>> dt.backend.config.set_backend('torch')
>>> load_image_feature = dt.LoadImage(path=temp_file.name)
>>> loaded_image = load_image_feature.resolve()
>>> print(type(loaded_image))
<class 'torch.Tensor'>

Cleanup the temporary file:
>>> os.remove(temp_file.name)
Expand Down Expand Up @@ -7313,7 +7342,7 @@ def __init__(
If `True`, selects a single random image from a stack when
`as_list=True`. It defaults to `False`.
**kwargs: Any
Additional keyword arguments passed to the parent `Feature` class,
Additional keyword arguments passed to the parent `Feature` class,
allowing further customization.

"""
Expand All @@ -7338,31 +7367,36 @@ def get(
as_list: bool,
get_one_random: bool,
**kwargs: Any,
) -> NDArray | torch.Tensor:
) -> NDArray | torch.Tensor | list:
"""Load and process an image or a list of images from disk.

This method attempts to load an image using multiple file readers
(`imageio`, `numpy`, `Pillow`, and `OpenCV`) until a valid format is
This method attempts to load an image using multiple file readers
(`imageio`, `numpy`, `Pillow`, and `OpenCV`) until a valid format is
found. It supports optional processing steps such as ensuring a minimum
number of dimensions, grayscale conversion, and treating multi-frame
number of dimensions, grayscale conversion, and treating multi-frame
images as lists.

The output is returned as a NumPy array by default. If `as_list=True`,
the result is a Python list of arrays. If the backend is `'torch'`, the
result is returned as a PyTorch tensor; a list of images is first stacked
along a new leading axis before the conversion.

Parameters
----------
path: str or list[str]
The file path(s) to the image(s) to be loaded. A single string
The file path(s) to the image(s) to be loaded. A single string
loads one image, while a list of paths loads multiple images.
load_options: dict of str to Any, optional
Additional options passed to the file reader (e.g., `allow_pickle`
Additional options passed to the file reader (e.g., `allow_pickle`
for NumPy, `mode` for OpenCV). It defaults to `None`.
ndim: int
Ensures the image has at least this many dimensions. If the loaded
image has fewer dimensions, extra dimensions are added.
Ensures the image has at least this many dimensions. If the loaded
image has fewer dimensions, extra dimensions are added. It defaults
to `3`.
to_grayscale: bool
If `True`, converts the image to grayscale. It defaults to `False`.
as_list: bool
If `True`, treats the first dimension as a list of images instead
of stacking them into a NumPy array.
If `True`, treats the first dimension as a list of images instead
of stacking them into a NumPy array. It defaults to `False`.
get_one_random: bool
If `True`, selects a single random image from a multi-frame stack
when `as_list=True`. It defaults to `False`.
Expand All @@ -7371,15 +7405,15 @@ def get(

Returns
-------
array
The loaded and processed image(s). If `as_list=True`, returns a
NDArray | torch.Tensor | list
The loaded and processed image(s). If `as_list=True`, returns a
list of images; otherwise, returns a single NumPy array or PyTorch
tensor.

Raises
------
IOError
If no valid file reader is found or if the specified file does not
If no valid file reader is found or if the specified file does not
exist.

"""
Expand All @@ -7402,8 +7436,9 @@ def get(
try:
import PIL.Image

image = [PIL.Image.open(file, **load_options)
for file in path]
image = [
PIL.Image.open(file, **load_options) for file in path
]
except (IOError, ImportError):
import cv2

Expand Down Expand Up @@ -7439,11 +7474,18 @@ def get(
)

# Ensure the image has at least `ndim` dimensions.
while ndim and image.ndim < ndim:
image = np.expand_dims(image, axis=-1)
if not isinstance(image, list) and ndim:
while image.ndim < ndim:
image = np.expand_dims(image, axis=-1)

# Convert to PyTorch tensor if needed.
#TODO
if self.get_backend() == "torch":

# Convert to stack if needed.
if isinstance(image, list):
image = np.stack(image, axis=0)

image = torch.from_numpy(image)

return image

Expand Down
62 changes: 58 additions & 4 deletions deeptrack/tests/test_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -1836,9 +1836,14 @@ def test_LoadImage(self):

try:
with NamedTemporaryFile(suffix=".npy", delete=False) as temp_npy:
np.save(temp_npy.name, test_image_array)
pass
np.save(temp_npy.name, test_image_array)
# npy_filename = temp_npy.name

with NamedTemporaryFile(suffix=".npy", delete=False) as temp_npy2:
pass
np.save(temp_npy2.name, test_image_array)

with NamedTemporaryFile(suffix=".png", delete=False) as temp_png:
PIL_Image.fromarray(test_image_array).save(temp_png.name)
# png_filename = temp_png.name
Expand Down Expand Up @@ -1877,12 +1882,61 @@ def test_LoadImage(self):
loaded_image = load_feature.resolve()
self.assertGreaterEqual(len(loaded_image.shape), 4)

# Test loading a list of images
load_feature = features.LoadImage(
path=[temp_npy.name, temp_npy2.name], as_list=True
)
loaded_list = load_feature.resolve()
self.assertIsInstance(loaded_list, list)
self.assertEqual(len(loaded_list), 2)

for img in loaded_list:
self.assertTrue(isinstance(img, np.ndarray))

# Test loading a random image from a list of images
load_feature = features.LoadImage(
path=[temp_npy.name, temp_npy2.name],
ndim=4,
as_list=True,
get_one_random=True,
)
loaded_image = load_feature.resolve()
self.assertTrue(
np.allclose(
loaded_image[:, :, 0, 0], test_image_array, rtol=1.e-3
)
)
self.assertEqual(loaded_image.shape, (50, 50, 1, 1))

import gc
gc.collect()

# Test loading an image as a torch tensor.
if TORCH_AVAILABLE:
load_feature = features.LoadImage(path=temp_png.name)
load_feature.torch()
loaded_image = load_feature.resolve()
self.assertIsInstance(loaded_image, torch.Tensor)
self.assertEqual(
loaded_image.shape[:2], test_image_array.shape
)

loaded_image_np = loaded_image.numpy()
self.assertTrue(
np.allclose(
test_image_array, loaded_image_np[:, :, 0], rtol=1.e-3
)
)

finally:
for file in [temp_npy.name, temp_png.name, temp_jpg.name]:
for file in [
temp_npy.name,
temp_png.name,
temp_jpg.name,
temp_npy2.name
]:
os.remove(file)

#TODO: Add a test for loading a list of images.


def test_SampleToMasks(self):
# Parameters
Expand Down
Loading