226 changes: 215 additions & 11 deletions sleap_nn/data/custom_datasets.py
@@ -1,6 +1,6 @@
"""Custom `torch.utils.data.Dataset`s for different model types."""

from kornia.geometry.transform import crop_and_resize
from kornia.geometry.transform import crop_and_resize, crop_by_boxes
from itertools import cycle
from pathlib import Path
import torch.distributed as dist
@@ -21,14 +21,19 @@
convert_to_rgb,
)
from sleap_nn.data.providers import get_max_instances, get_max_height_width, process_lf
from sleap_nn.data.resizing import apply_pad_to_stride, apply_sizematcher, apply_resizer
from sleap_nn.data.resizing import (
apply_pad_to_stride,
apply_sizematcher,
apply_resizer,
apply_padding,
)
from sleap_nn.data.augmentation import (
apply_geometric_augmentation,
apply_intensity_augmentation,
)
from sleap_nn.data.confidence_maps import generate_confmaps, generate_multiconfmaps
from sleap_nn.data.edge_maps import generate_pafs
from sleap_nn.data.instance_cropping import make_centered_bboxes
from sleap_nn.data.instance_cropping import make_centered_bboxes, get_fit_bbox
from sleap_nn.training.utils import is_distributed_initialized, get_dist_rank


@@ -290,12 +295,13 @@ def __getitem__(self, index) -> Dict:
sample["image"] = convert_to_grayscale(sample["image"])

# size matcher
sample["image"], eff_scale = apply_sizematcher(
sample["image"], eff_scale, (pad_w_l, pad_h_t) = apply_sizematcher(
sample["image"],
max_height=self.max_hw[0],
max_width=self.max_hw[1],
)
sample["instances"] = sample["instances"] * eff_scale
sample["instances"] = sample["instances"] + torch.Tensor((pad_w_l, pad_h_t))

# resize image
sample["image"], sample["instances"] = apply_resizer(
@@ -305,9 +311,10 @@
)

# Pad the image (if needed) according to max stride
sample["image"] = apply_pad_to_stride(
sample["image"], (pad_w_l, pad_h_t) = apply_pad_to_stride(
sample["image"], max_stride=self.max_stride
)
sample["instances"] = sample["instances"] + torch.Tensor((pad_w_l, pad_h_t))

# apply augmentation
if self.apply_aug and self.augmentation_config is not None:
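
Note: the call sites above assume updated signatures in `sleap_nn.data.resizing`: `apply_sizematcher` now additionally returns the (left, top) padding applied while matching the image to `(max_height, max_width)`, and `apply_pad_to_stride` returns the padding added to make each side a multiple of `max_stride`. A minimal sketch of the assumed contracts (the names, return order, and center-padding split are inferred from the call sites, not from the library source):

import torch.nn.functional as F

def apply_sizematcher(image, max_height, max_width):
    # Assumed: scale to fit within (max_height, max_width), center-pad to
    # exactly that size, and report the (left, top) pad offsets.
    h, w = image.shape[-2:]
    eff_scale = min(max_height / h, max_width / w)
    new_h, new_w = int(h * eff_scale), int(w * eff_scale)
    image = F.interpolate(image, size=(new_h, new_w), mode="bilinear")
    pad_w, pad_h = max_width - new_w, max_height - new_h
    pad_w_l, pad_h_t = pad_w // 2, pad_h // 2  # left/top share of the padding
    image = F.pad(image, (pad_w_l, pad_w - pad_w_l, pad_h_t, pad_h - pad_h_t))
    return image, eff_scale, (pad_w_l, pad_h_t)

def apply_pad_to_stride(image, max_stride):
    # Assumed: pad H and W up to the next multiple of max_stride and report
    # the (left, top) pad offsets.
    h, w = image.shape[-2:]
    pad_h = (max_stride - h % max_stride) % max_stride
    pad_w = (max_stride - w % max_stride) % max_stride
    pad_w_l, pad_h_t = pad_w // 2, pad_h // 2
    image = F.pad(image, (pad_w_l, pad_w - pad_w_l, pad_h_t, pad_h - pad_h_t))
    return image, (pad_w_l, pad_h_t)

Under either contract, keypoints stay aligned with the image by scaling with `eff_scale` first and then shifting by `(pad_w_l, pad_h_t)`, which is exactly what the updated call sites do.
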
@@ -499,12 +506,13 @@ def __getitem__(self, index) -> Dict:
image = convert_to_grayscale(image)

# size matcher
image, eff_scale = apply_sizematcher(
image, eff_scale, (pad_w_l, pad_h_t) = apply_sizematcher(
image,
max_height=self.max_hw[0],
max_width=self.max_hw[1],
)
instances = instances * eff_scale
instances = instances + torch.Tensor((pad_w_l, pad_h_t))

# resize image
image, instances = apply_resizer(
@@ -571,9 +579,200 @@ def __getitem__(self, index) -> Dict:
sample["centroid"] = centered_centroid # (n_samples=1, 2)

# Pad the image (if needed) according to max stride
sample["instance_image"] = apply_pad_to_stride(
sample["instance_image"], (pad_w_l, pad_h_t) = apply_pad_to_stride(
sample["instance_image"], max_stride=self.max_stride
)
sample["instance"] = sample["instance"] + torch.Tensor((pad_w_l, pad_h_t))

img_hw = sample["instance_image"].shape[-2:]

# Generate confidence maps
confidence_maps = generate_confmaps(
sample["instance"],
img_hw=img_hw,
sigma=self.confmap_head_config.sigma,
output_stride=self.confmap_head_config.output_stride,
)

sample["confidence_maps"] = confidence_maps

return sample


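For context, `generate_confmaps` (used in `__getitem__` above) renders one Gaussian confidence map per node at the head's `sigma` and `output_stride`. A minimal sketch of the underlying computation (hypothetical reimplementation; the real function also handles batching and NaN keypoints):

import torch

def gaussian_confmaps(points, img_hw, sigma=1.5, output_stride=2):
    # points: (num_nodes, 2) keypoints in (x, y) image coordinates.
    h, w = img_hw[0] // output_stride, img_hw[1] // output_stride
    yy, xx = torch.meshgrid(
        torch.arange(h, dtype=torch.float32),
        torch.arange(w, dtype=torch.float32),
        indexing="ij",
    )
    pts = points / output_stride  # map keypoints onto the strided output grid
    dx = xx[None] - pts[:, 0, None, None]
    dy = yy[None] - pts[:, 1, None, None]
    return torch.exp(-(dx**2 + dy**2) / (2 * sigma**2))  # (num_nodes, h, w)
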
class CenteredInstanceDatasetFitBbox(CenteredInstanceDataset):
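"""Variant of `CenteredInstanceDataset` that crops each instance to its fit bounding box (the tight box around its keypoints plus a 16 px margin), then size-matches the crop to `max_crop_hw`."""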
def __init__(
self,
labels,
max_crop_hw,
confmap_head_config,
max_stride,
user_instances_only=True,
is_rgb=False,
augmentation_config=None,
scale=1,
apply_aug=False,
max_hw=(None, None),
cache_img=None,
cache_img_path=None,
use_existing_imgs=False,
rank=None,
):
super().__init__(
labels,
max_crop_hw,
confmap_head_config,
max_stride,
user_instances_only,
is_rgb,
augmentation_config,
scale,
apply_aug,
max_hw,
cache_img,
cache_img_path,
use_existing_imgs,
rank,
)
self.max_crop_h_w = max_crop_hw

def __getitem__(self, index):
"""Return dict with cropped image and confmaps of instance for given index."""
lf_idx, inst_idx = self.instance_idx_list[index]
lf = self.labels[lf_idx]

if lf_idx == self.cache_lf[0]:
img = self.cache_lf[1]
else:
# load the img
if self.cache_img is not None:
if self.cache_img == "disk":
img = np.array(
Image.open(f"{self.cache_img_path}/sample_{lf_idx}.jpg")
)
elif self.cache_img == "memory":
img = self.cache[lf_idx].copy()

else: # load from slp file if not cached
img = lf.image # TODO: doesn't work when num_workers > 0

if img.ndim == 2:
img = np.expand_dims(img, axis=2)

self.cache_lf = [lf_idx, img]

video_idx = self._get_video_idx(lf)

image = np.transpose(img, (2, 0, 1)) # HWC -> CHW

instances = []
for inst in lf:
instances.append(inst.numpy())
instances = np.stack(instances, axis=0)

# Add singleton time dimension for single frames.
image = np.expand_dims(image, axis=0) # (n_samples=1, C, H, W)
instances = np.expand_dims(
instances, axis=0
) # (n_samples=1, num_instances, num_nodes, 2)

instances = torch.from_numpy(instances.astype("float32"))
image = torch.from_numpy(image)

num_instances, _ = instances.shape[1:3]
orig_img_height, orig_img_width = image.shape[-2:]

instances = instances[:, inst_idx]

# apply normalization
image = apply_normalization(image)

if self.is_rgb:
image = convert_to_rgb(image)
else:
image = convert_to_grayscale(image)

# size matcher
image, eff_scale, (pad_w_l, pad_h_t) = apply_sizematcher(
image,
max_height=self.max_hw[0],
max_width=self.max_hw[1],
)
instances = instances * eff_scale
instances = instances + torch.Tensor((pad_w_l, pad_h_t))

# resize image
image, instances = apply_resizer(
image,
instances,
scale=self.scale,
)

# apply augmentation
if self.apply_aug and self.augmentation_config is not None:
if "intensity" in self.augmentation_config:
image, instances = apply_intensity_augmentation(
image,
instances,
**self.augmentation_config.intensity,
)

if "geometric" in self.augmentation_config:
image, instances = apply_geometric_augmentation(
image,
instances,
**self.augmentation_config.geometric,
)

instance = instances[0] # (n_samples=1)

bbox = get_fit_bbox(instance) # bbox => (x_min, y_min, x_max, y_max)
bbox[0] = bbox[0] - 16
bbox[1] = bbox[1] - 16
bbox[2] = bbox[2] + 16
bbox[3] = bbox[3] + 16 # padding of 16 on all sides
x_min, y_min, x_max, y_max = bbox
crop_hw = (y_max - y_min, x_max - x_min)

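# `crop_by_boxes` (kornia) warps the `src_box` quadrilateral (corners ordered
# top-left, top-right, bottom-right, bottom-left) onto `dst_box` via a
# perspective transform; with both boxes axis-aligned as here, this amounts
# to a plain crop whose output size is inferred from `dst_box`.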
cropped_image = crop_by_boxes(
image,
src_box=torch.Tensor(
[[x_min, y_min], [x_max, y_min], [x_max, y_max], [x_min, y_max]]
).unsqueeze(dim=0),
dst_box=torch.tensor(
[
[
[0.0, 0.0],
[crop_hw[1], 0.0],
[crop_hw[1], crop_hw[0]],
[0.0, crop_hw[0]],
]
]
),
)
instance = instance - bbox[:2] # adjust for crops

cropped_image_match_hw, eff_scale, pad_wh = apply_sizematcher(
cropped_image, self.max_crop_h_w[0], self.max_crop_h_w[1]
) # resize and pad to max crop size
instance = instance * eff_scale # adjust keypoints according to resizing/padding
instance = instance + torch.Tensor(pad_wh)
instance = torch.unsqueeze(instance, dim=0)

sample = {}
sample["instance_image"] = cropped_image_match_hw
sample["instance_bbox"] = bbox
sample["instance"] = instance
sample["frame_idx"] = torch.tensor(lf.frame_idx, dtype=torch.int32)
sample["video_idx"] = torch.tensor(video_idx, dtype=torch.int32)
sample["num_instances"] = num_instances
sample["orig_size"] = torch.Tensor([orig_img_height, orig_img_width])
sample["crop_hw"] = torch.Tensor([crop_hw])

# Pad the image (if needed) according to max stride
sample["instance_image"], pad_wh = apply_pad_to_stride(
sample["instance_image"], max_stride=self.max_stride
)
sample["instance"] = sample["instance"] + torch.Tensor(pad_wh)

img_hw = sample["instance_image"].shape[-2:]

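For reference, `get_fit_bbox` is assumed to return the tight axis-aligned box around an instance's keypoints as `(x_min, y_min, x_max, y_max)`; the code above then grows it by a fixed 16 px margin. A hypothetical sketch of that helper (the real implementation likely also guards against instances with no visible nodes):

import torch

def get_fit_bbox(instance: torch.Tensor) -> torch.Tensor:
    # instance: (num_nodes, 2) keypoints in (x, y); NaN marks missing nodes.
    valid = instance[~torch.isnan(instance).any(dim=-1)]
    (x_min, y_min), (x_max, y_max) = valid.min(dim=0).values, valid.max(dim=0).values
    return torch.stack([x_min, y_min, x_max, y_max])
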
@@ -694,12 +893,13 @@ def __getitem__(self, index) -> Dict:
sample["image"] = convert_to_grayscale(sample["image"])

# size matcher
sample["image"], eff_scale = apply_sizematcher(
sample["image"], eff_scale, (pad_w_l, pad_h_t) = apply_sizematcher(
sample["image"],
max_height=self.max_hw[0],
max_width=self.max_hw[1],
)
sample["instances"] = sample["instances"] * eff_scale
sample["instances"] = sample["instances"] + torch.Tensor((pad_w_l, pad_h_t))

# resize image
sample["image"], sample["instances"] = apply_resizer(
@@ -716,9 +916,11 @@
sample["centroids"] = centroids

# Pad the image (if needed) according to max stride
sample["image"] = apply_pad_to_stride(
sample["image"], (pad_w_l, pad_h_t) = apply_pad_to_stride(
sample["image"], max_stride=self.max_stride
)
sample["instances"] = sample["instances"] + torch.Tensor((pad_w_l, pad_h_t))
sample["centroids"] = sample["centroids"] + torch.Tensor((pad_w_l, pad_h_t))

# apply augmentation
if self.apply_aug and self.augmentation_config is not None:
@@ -857,12 +1059,13 @@ def __getitem__(self, index) -> Dict:
sample["image"] = convert_to_grayscale(sample["image"])

# size matcher
sample["image"], eff_scale = apply_sizematcher(
sample["image"], eff_scale, (pad_w_l, pad_h_t) = apply_sizematcher(
sample["image"],
max_height=self.max_hw[0],
max_width=self.max_hw[1],
)
sample["instances"] = sample["instances"] * eff_scale
sample["instances"] = sample["instances"] + torch.Tensor((pad_w_l, pad_h_t))

# resize image
sample["image"], sample["instances"] = apply_resizer(
@@ -872,9 +1075,10 @@
)

# Pad the image (if needed) according to max stride
sample["image"] = apply_pad_to_stride(
sample["image"], (pad_w_l, pad_h_t) = apply_pad_to_stride(
sample["image"], max_stride=self.max_stride
)
sample["instances"] = sample["instances"] + torch.Tensor((pad_w_l, pad_h_t))

# apply augmentation
if self.apply_aug and self.augmentation_config is not None:
12 changes: 8 additions & 4 deletions sleap_nn/data/get_data_chunks.py
@@ -66,12 +66,13 @@ def bottomup_data_chunks(
data_config.preprocessing.max_height,
data_config.preprocessing.max_width,
)
sample["image"], eff_scale = apply_sizematcher(
sample["image"], eff_scale, (pad_w_l, pad_h_t) = apply_sizematcher(
sample["image"],
max_height=max_height if max_height is not None else max_hw[0],
max_width=max_width if max_width is not None else max_hw[1],
)
sample["instances"] = sample["instances"] * eff_scale
sample["instances"] = sample["instances"] + torch.Tensor((pad_w_l, pad_h_t))

# resize the image
sample["image"], sample["instances"] = apply_resizer(
@@ -140,12 +141,13 @@ def centered_instance_data_chunks(
data_config.preprocessing.max_height,
data_config.preprocessing.max_width,
)
sample["image"], eff_scale = apply_sizematcher(
sample["image"], eff_scale, (pad_w_l, pad_h_t) = apply_sizematcher(
sample["image"],
max_height=max_height if max_height is not None else max_hw[0],
max_width=max_width if max_width is not None else max_hw[1],
)
sample["instances"] = sample["instances"] * eff_scale
sample["instances"] = sample["instances"] + torch.Tensor((pad_w_l, pad_h_t))

# get the centroids based on the anchor idx
centroids = generate_centroids(sample["instances"], anchor_ind=anchor_ind)
@@ -230,13 +232,14 @@ def centroid_data_chunks(
data_config.preprocessing.max_height,
data_config.preprocessing.max_width,
)
sample["image"], eff_scale = apply_sizematcher(
sample["image"], eff_scale, (pad_w_l, pad_h_t) = apply_sizematcher(
sample["image"],
max_height=max_height if max_height is not None else max_hw[0],
max_width=max_width if max_width is not None else max_hw[1],
)

sample["instances"] = sample["instances"] * eff_scale
sample["instances"] = sample["instances"] + torch.Tensor((pad_w_l, pad_h_t))

# get the centroids based on the anchor idx
centroids = generate_centroids(sample["instances"], anchor_ind=anchor_ind)
@@ -302,12 +305,13 @@ def single_instance_data_chunks(
data_config.preprocessing.max_height,
data_config.preprocessing.max_width,
)
sample["image"], eff_scale = apply_sizematcher(
sample["image"], eff_scale, (pad_w_l, pad_h_t) = apply_sizematcher(
sample["image"],
max_height=max_height if max_height is not None else max_hw[0],
max_width=max_width if max_width is not None else max_hw[1],
)
sample["instances"] = sample["instances"] * eff_scale
sample["instances"] = sample["instances"] + torch.Tensor((pad_w_l, pad_h_t))

# resize image
sample["image"], sample["instances"] = apply_resizer(
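
Throughout both files, the padding offset is applied to keypoints by broadcasting: `torch.Tensor((pad_w_l, pad_h_t))` has shape `(2,)` and lines up with the trailing `(x, y)` dimension of `instances`, whatever the leading shape. A quick self-contained check of that behavior:

import torch

instances = torch.zeros(1, 3, 5, 2)  # (n_samples, n_instances, n_nodes, 2), (x, y) last
pad_w_l, pad_h_t = 8, 4
shifted = instances + torch.Tensor((pad_w_l, pad_h_t))  # broadcasts over the last dim
assert shifted[..., 0].eq(8).all() and shifted[..., 1].eq(4).all()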