diff --git a/sleap_nn/data/custom_datasets.py b/sleap_nn/data/custom_datasets.py index c6f348ed..ed623a2d 100644 --- a/sleap_nn/data/custom_datasets.py +++ b/sleap_nn/data/custom_datasets.py @@ -1,6 +1,6 @@ """Custom `torch.utils.data.Dataset`s for different model types.""" -from kornia.geometry.transform import crop_and_resize +from kornia.geometry.transform import crop_and_resize, crop_by_boxes from itertools import cycle from pathlib import Path import torch.distributed as dist @@ -21,14 +21,19 @@ convert_to_rgb, ) from sleap_nn.data.providers import get_max_instances, get_max_height_width, process_lf -from sleap_nn.data.resizing import apply_pad_to_stride, apply_sizematcher, apply_resizer +from sleap_nn.data.resizing import ( + apply_pad_to_stride, + apply_sizematcher, + apply_resizer, + apply_padding, +) from sleap_nn.data.augmentation import ( apply_geometric_augmentation, apply_intensity_augmentation, ) from sleap_nn.data.confidence_maps import generate_confmaps, generate_multiconfmaps from sleap_nn.data.edge_maps import generate_pafs -from sleap_nn.data.instance_cropping import make_centered_bboxes +from sleap_nn.data.instance_cropping import make_centered_bboxes, get_fit_bbox from sleap_nn.training.utils import is_distributed_initialized, get_dist_rank @@ -290,12 +295,13 @@ def __getitem__(self, index) -> Dict: sample["image"] = convert_to_grayscale(sample["image"]) # size matcher - sample["image"], eff_scale = apply_sizematcher( + sample["image"], eff_scale, (pad_w_l, pad_h_t) = apply_sizematcher( sample["image"], max_height=self.max_hw[0], max_width=self.max_hw[1], ) sample["instances"] = sample["instances"] * eff_scale + sample["instances"] = sample["instances"] + torch.Tensor((pad_w_l, pad_h_t)) # resize image sample["image"], sample["instances"] = apply_resizer( @@ -305,9 +311,10 @@ def __getitem__(self, index) -> Dict: ) # Pad the image (if needed) according max stride - sample["image"] = apply_pad_to_stride( + sample["image"], (pad_w_l, pad_h_t) = apply_pad_to_stride( sample["image"], max_stride=self.max_stride ) + sample["instances"] = sample["instances"] + torch.Tensor((pad_w_l, pad_h_t)) # apply augmentation if self.apply_aug and self.augmentation_config is not None: @@ -499,12 +506,13 @@ def __getitem__(self, index) -> Dict: image = convert_to_grayscale(image) # size matcher - image, eff_scale = apply_sizematcher( + image, eff_scale, (pad_w_l, pad_h_t) = apply_sizematcher( image, max_height=self.max_hw[0], max_width=self.max_hw[1], ) instances = instances * eff_scale + instances = instances + torch.Tensor((pad_w_l, pad_h_t)) # resize image image, instances = apply_resizer( @@ -571,9 +579,200 @@ def __getitem__(self, index) -> Dict: sample["centroid"] = centered_centroid # (n_samples=1, 2) # Pad the image (if needed) according max stride - sample["instance_image"] = apply_pad_to_stride( + sample["instance_image"], (pad_w_l, pad_h_t) = apply_pad_to_stride( + sample["instance_image"], max_stride=self.max_stride + ) + sample["instance"] = sample["instance"] + torch.Tensor((pad_w_l, pad_h_t)) + + img_hw = sample["instance_image"].shape[-2:] + + # Generate confidence maps + confidence_maps = generate_confmaps( + sample["instance"], + img_hw=img_hw, + sigma=self.confmap_head_config.sigma, + output_stride=self.confmap_head_config.output_stride, + ) + + sample["confidence_maps"] = confidence_maps + + return sample + + +class CenteredInstanceDatasetFitBbox(CenteredInstanceDataset): + def __init__( + self, + labels, + max_crop_hw, + confmap_head_config, + max_stride, + 
user_instances_only=True, + is_rgb=False, + augmentation_config=None, + scale=1, + apply_aug=False, + max_hw=(None, None), + cache_img=None, + cache_img_path=None, + use_existing_imgs=False, + rank=None, + ): + super().__init__( + labels, + max_crop_hw, + confmap_head_config, + max_stride, + user_instances_only, + is_rgb, + augmentation_config, + scale, + apply_aug, + max_hw, + cache_img, + cache_img_path, + use_existing_imgs, + rank, + ) + self.max_crop_h_w = max_crop_hw + + def __getitem__(self, index): + """Return dict with cropped image and confmaps of instance for given index.""" + lf_idx, inst_idx = self.instance_idx_list[index] + lf = self.labels[lf_idx] + + if lf_idx == self.cache_lf[0]: + img = self.cache_lf[1] + else: + # load the img + if self.cache_img is not None: + if self.cache_img == "disk": + img = np.array( + Image.open(f"{self.cache_img_path}/sample_{lf_idx}.jpg") + ) + elif self.cache_img == "memory": + img = self.cache[lf_idx].copy() + + else: # load from slp file if not cached + img = lf.image # TODO: doesn't work when num_workers > 0 + + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + + self.cache_lf = [lf_idx, img] + + video_idx = self._get_video_idx(lf) + + image = np.transpose(img, (2, 0, 1)) # HWC -> CHW + + instances = [] + for inst in lf: + instances.append(inst.numpy()) + instances = np.stack(instances, axis=0) + + # Add singleton time dimension for single frames. + image = np.expand_dims(image, axis=0) # (n_samples=1, C, H, W) + instances = np.expand_dims( + instances, axis=0 + ) # (n_samples=1, num_instances, num_nodes, 2) + + instances = torch.from_numpy(instances.astype("float32")) + image = torch.from_numpy(image) + + num_instances, _ = instances.shape[1:3] + orig_img_height, orig_img_width = image.shape[-2:] + + instances = instances[:, inst_idx] + + # apply normalization + image = apply_normalization(image) + + if self.is_rgb: + image = convert_to_rgb(image) + else: + image = convert_to_grayscale(image) + + # size matcher + image, eff_scale, (pad_w_l, pad_h_t) = apply_sizematcher( + image, + max_height=self.max_hw[0], + max_width=self.max_hw[1], + ) + instances = instances * eff_scale + instances = instances + torch.Tensor((pad_w_l, pad_h_t)) + + # resize image + image, instances = apply_resizer( + image, + instances, + scale=self.scale, + ) + + # apply augmentation + if self.apply_aug and self.augmentation_config is not None: + if "intensity" in self.augmentation_config: + image, instances = apply_intensity_augmentation( + image, + instances, + **self.augmentation_config.intensity, + ) + + if "geometric" in self.augmentation_config: + image, instances = apply_geometric_augmentation( + image, + instances, + **self.augmentation_config.geometric, + ) + + instance = instances[0] # (n_samples=1) + + bbox = get_fit_bbox(instance) # bbox => (x_min, y_min, x_max, y_max) + bbox[0] = bbox[0] - 16 + bbox[1] = bbox[1] - 16 + bbox[2] = bbox[2] + 16 + bbox[3] = bbox[3] + 16 # padding of 16 on all sides + x_min, y_min, x_max, y_max = bbox + crop_hw = (y_max - y_min, x_max - x_min) + + cropped_image = crop_by_boxes( + image, + src_box=torch.Tensor( + [[x_min, y_min], [x_max, y_min], [x_max, y_max], [x_min, y_max]] + ).unsqueeze(dim=0), + dst_box=torch.tensor( + [ + [ + [0.0, 0.0], + [crop_hw[1], 0.0], + [crop_hw[1], crop_hw[0]], + [0.0, crop_hw[0]], + ] + ] + ), + ) + instance = instance - bbox[:2] # adjust for crops + + cropped_image_match_hw, eff_scale, pad_wh = apply_sizematcher( + cropped_image, self.max_crop_h_w[0], self.max_crop_h_w[1] + ) # resize and 
pad to max crop size + instance = instance * eff_scale # adjust keypoints according to resizing and padding + instance = instance + torch.Tensor(pad_wh) + instance = torch.unsqueeze(instance, dim=0) + + sample = {} + sample["instance_image"] = cropped_image_match_hw + sample["instance_bbox"] = bbox + sample["instance"] = instance + sample["frame_idx"] = torch.tensor(lf.frame_idx, dtype=torch.int32) + sample["video_idx"] = torch.tensor(video_idx, dtype=torch.int32) + sample["num_instances"] = num_instances + sample["orig_size"] = torch.Tensor([orig_img_height, orig_img_width]) + sample["crop_hw"] = torch.Tensor([crop_hw]) + + # Pad the image (if needed) according max stride + sample["instance_image"], pad_wh = apply_pad_to_stride( sample["instance_image"], max_stride=self.max_stride ) + sample["instance"] = sample["instance"] + torch.Tensor(pad_wh) img_hw = sample["instance_image"].shape[-2:] @@ -694,12 +893,13 @@ def __getitem__(self, index) -> Dict: sample["image"] = convert_to_grayscale(sample["image"]) # size matcher - sample["image"], eff_scale = apply_sizematcher( + sample["image"], eff_scale, (pad_w_l, pad_h_t) = apply_sizematcher( sample["image"], max_height=self.max_hw[0], max_width=self.max_hw[1], ) sample["instances"] = sample["instances"] * eff_scale + sample["instances"] = sample["instances"] + torch.Tensor((pad_w_l, pad_h_t)) # resize image sample["image"], sample["instances"] = apply_resizer( @@ -716,9 +916,11 @@ def __getitem__(self, index) -> Dict: sample["centroids"] = centroids # Pad the image (if needed) according max stride - sample["image"] = apply_pad_to_stride( + sample["image"], (pad_w_l, pad_h_t) = apply_pad_to_stride( sample["image"], max_stride=self.max_stride ) + sample["instances"] = sample["instances"] + torch.Tensor((pad_w_l, pad_h_t)) + sample["centroids"] = sample["centroids"] + torch.Tensor((pad_w_l, pad_h_t)) # apply augmentation if self.apply_aug and self.augmentation_config is not None: @@ -857,12 +1059,13 @@ def __getitem__(self, index) -> Dict: sample["image"] = convert_to_grayscale(sample["image"]) # size matcher - sample["image"], eff_scale = apply_sizematcher( + sample["image"], eff_scale, (pad_w_l, pad_h_t) = apply_sizematcher( sample["image"], max_height=self.max_hw[0], max_width=self.max_hw[1], ) sample["instances"] = sample["instances"] * eff_scale + sample["instances"] = sample["instances"] + torch.Tensor((pad_w_l, pad_h_t)) # resize image sample["image"], sample["instances"] = apply_resizer( @@ -872,9 +1075,10 @@ def __getitem__(self, index) -> Dict: ) # Pad the image (if needed) according max stride - sample["image"] = apply_pad_to_stride( + sample["image"], (pad_w_l, pad_h_t) = apply_pad_to_stride( sample["image"], max_stride=self.max_stride ) + sample["instances"] = sample["instances"] + torch.Tensor((pad_w_l, pad_h_t)) # apply augmentation if self.apply_aug and self.augmentation_config is not None: diff --git a/sleap_nn/data/get_data_chunks.py b/sleap_nn/data/get_data_chunks.py index 7a4711c2..540ea772 100644 --- a/sleap_nn/data/get_data_chunks.py +++ b/sleap_nn/data/get_data_chunks.py @@ -66,12 +66,13 @@ def bottomup_data_chunks( data_config.preprocessing.max_height, data_config.preprocessing.max_width, ) - sample["image"], eff_scale = apply_sizematcher( + sample["image"], eff_scale, (pad_w_l, pad_h_t) = apply_sizematcher( sample["image"], max_height=max_height if max_height is not None else max_hw[0], max_width=max_width if max_width is not None else max_hw[1], ) sample["instances"] = sample["instances"] * eff_scale + sample["instances"] =
sample["instances"] + torch.Tensor((pad_w_l, pad_h_t)) # resize the image sample["image"], sample["instances"] = apply_resizer( @@ -140,12 +141,13 @@ def centered_instance_data_chunks( data_config.preprocessing.max_height, data_config.preprocessing.max_width, ) - sample["image"], eff_scale = apply_sizematcher( + sample["image"], eff_scale, (pad_w_l, pad_h_t) = apply_sizematcher( sample["image"], max_height=max_height if max_height is not None else max_hw[0], max_width=max_width if max_width is not None else max_hw[1], ) sample["instances"] = sample["instances"] * eff_scale + sample["instances"] = sample["instances"] + torch.Tensor((pad_w_l, pad_h_t)) # get the centroids based on the anchor idx centroids = generate_centroids(sample["instances"], anchor_ind=anchor_ind) @@ -230,13 +232,14 @@ def centroid_data_chunks( data_config.preprocessing.max_height, data_config.preprocessing.max_width, ) - sample["image"], eff_scale = apply_sizematcher( + sample["image"], eff_scale, (pad_w_l, pad_h_t) = apply_sizematcher( sample["image"], max_height=max_height if max_height is not None else max_hw[0], max_width=max_width if max_width is not None else max_hw[1], ) sample["instances"] = sample["instances"] * eff_scale + sample["instances"] = sample["instances"] + torch.Tensor((pad_w_l, pad_h_t)) # get the centroids based on the anchor idx centroids = generate_centroids(sample["instances"], anchor_ind=anchor_ind) @@ -302,12 +305,13 @@ def single_instance_data_chunks( data_config.preprocessing.max_height, data_config.preprocessing.max_width, ) - sample["image"], eff_scale = apply_sizematcher( + sample["image"], eff_scale, (pad_w_l, pad_h_t) = apply_sizematcher( sample["image"], max_height=max_height if max_height is not None else max_hw[0], max_width=max_width if max_width is not None else max_hw[1], ) sample["instances"] = sample["instances"] * eff_scale + sample["instances"] = sample["instances"] + torch.Tensor((pad_w_l, pad_h_t)) # resize image sample["image"], sample["instances"] = apply_resizer( diff --git a/sleap_nn/data/instance_cropping.py b/sleap_nn/data/instance_cropping.py index 1f414c9c..9f267b49 100644 --- a/sleap_nn/data/instance_cropping.py +++ b/sleap_nn/data/instance_cropping.py @@ -61,6 +61,25 @@ def find_instance_crop_size( return int(crop_size) +def get_fit_bbox(instance: torch.Tensor) -> torch.Tensor: + """Get a fit bbox around the instance. + + Args: + instance: tensor with shape (..., nodes, 2) + + Returns: + bbox coordinates of the form x_min, y_min, x_max, y_max representing the fit bbox around the given instance. 
+ """ + x, y = instance[:, 0], instance[:, 1] + masked_x = x[~torch.isnan(x)] + masked_y = y[~torch.isnan(y)] + x_min, x_max = torch.min(masked_x), torch.max(masked_x) + y_min, y_max = torch.min(masked_y), torch.max(masked_y) + bbox = torch.Tensor([x_min, y_min, x_max, y_max]).to(torch.int32) + + return bbox + + def make_centered_bboxes( centroids: torch.Tensor, box_height: int, box_width: int ) -> torch.Tensor: diff --git a/sleap_nn/data/resizing.py b/sleap_nn/data/resizing.py index 918d04ce..417dedbc 100644 --- a/sleap_nn/data/resizing.py +++ b/sleap_nn/data/resizing.py @@ -3,7 +3,6 @@ from typing import Dict, Iterator, Optional, Tuple, List, Union import torch -import torch.nn.functional as F from sleap_nn.data.providers import LabelsReaderDP, VideoReader import torchvision.transforms.v2.functional as tvf from torch.utils.data.datapipes.datapipe import IterDataPipe @@ -49,8 +48,9 @@ def apply_pad_to_stride(image: torch.Tensor, max_stride: int) -> torch.Tensor: reduction layers in the model. Returns: - The input image with 0-padding applied to the bottom and/or right such that the - new shape's height and width are both divisible by `max_stride`. + A tuple of the input image with 0-padding split between the left/right and top/bottom edges so that the + new height and width are both divisible by `max_stride`, and the (pad_width_left, pad_height_top) offsets + needed to shift ground-truth keypoints into the padded image. """ if max_stride > 1: image_height, image_width = image.shape[-2:] @@ -60,13 +60,20 @@ def apply_pad_to_stride(image: torch.Tensor, max_stride: int) -> torch.Tensor: max_stride=max_stride, ) + pad_width_left = pad_width // 2 + pad_width_right = pad_width - pad_width_left + + pad_height_top = pad_height // 2 + pad_height_bottom = pad_height - pad_height_top + if pad_height > 0 or pad_width > 0: - image = F.pad( + image = tvf.pad( image, - (0, pad_width, 0, pad_height), - mode="constant", + (pad_width_left, pad_height_top, pad_width_right, pad_height_bottom), + 0, + "constant", ).to(torch.float32) - return image + return image, (pad_width_left, pad_height_top) def resize_image(image: torch.Tensor, scale: float): @@ -104,6 +111,41 @@ def apply_resizer(image: torch.Tensor, instances: torch.Tensor, scale: float = 1 return image, instances +def apply_padding( + image: torch.Tensor, + max_height: Optional[int] = None, + max_width: Optional[int] = None, +): + """Pad image with zeros to (max_height, max_width) and return the (left, top) pad offsets.""" + img_height, img_width = image.shape[-2:] + # pad images to max_height and max_width + if max_height is None: + max_height = img_height + if max_width is None: + max_width = img_width + if img_height != max_height or img_width != max_width: + + pad_height = max_height - img_height + pad_width = max_width - img_width + + pad_width_left = pad_width // 2 + pad_width_right = pad_width - pad_width_left + + pad_height_top = pad_height // 2 + pad_height_bottom = pad_height - pad_height_top + + image = tvf.pad( + image, + (pad_width_left, pad_height_top, pad_width_right, pad_height_bottom), + 0, + "constant", + ).to(torch.float32) + + return image, (pad_width_left, pad_height_top) + else: + return image, (0, 0) + + def apply_sizematcher( image: torch.Tensor, max_height: Optional[int] = None, @@ -134,15 +176,22 @@ def apply_sizematcher( pad_height = max_height - target_h pad_width = max_width - target_w - image = F.pad( + pad_width_left = pad_width // 2 + pad_width_right = pad_width - pad_width_left + + pad_height_top = pad_height // 2 + pad_height_bottom = pad_height -
pad_height_top + + image = tvf.pad( image, - (0, pad_width, 0, pad_height), - mode="constant", + (pad_width_left, pad_height_top, pad_width_right, pad_height_bottom), + 0, + "constant", ).to(torch.float32) - return image, eff_scale_ratio + return image, eff_scale_ratio, (pad_width_left, pad_height_top) else: - return image, 1.0 + return image, 1.0, (0, 0) class Resizer(IterDataPipe): @@ -217,7 +266,7 @@ def __init__( def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: """Return an example dictionary with the resized image and `orig_size` key to represent the original shape of the source image.""" for ex in self.source_datapipe: - ex[self.image_key] = apply_pad_to_stride( + ex[self.image_key], _ = apply_pad_to_stride( ex[self.image_key], self.max_stride ) yield ex diff --git a/sleap_nn/data/streaming_datasets.py b/sleap_nn/data/streaming_datasets.py index 9eaf520f..a2338ff4 100644 --- a/sleap_nn/data/streaming_datasets.py +++ b/sleap_nn/data/streaming_datasets.py @@ -80,7 +80,10 @@ def __getitem__(self, index): ) # Pad the image (if needed) according max stride - ex["image"] = apply_pad_to_stride(ex["image"], max_stride=self.max_stride) + ex["image"], (pad_w_l, pad_h_t) = apply_pad_to_stride( + ex["image"], max_stride=self.max_stride + ) + ex["instances"] = ex["instances"] + torch.Tensor((pad_w_l, pad_h_t)) img_hw = ex["image"].shape[-2:] @@ -187,9 +190,10 @@ def __getitem__(self, index): ex["centroid"] = centered_centroid.unsqueeze(0) # (n_samples=1, 2) # Pad the image (if needed) according max stride - ex["instance_image"] = apply_pad_to_stride( + ex["instance_image"], (pad_w_l, pad_h_t) = apply_pad_to_stride( ex["instance_image"], max_stride=self.max_stride ) + ex["instance"] = ex["instance"] + torch.Tensor((pad_w_l, pad_h_t)) img_hw = ex["instance_image"].shape[-2:] @@ -261,7 +265,10 @@ def __getitem__(self, index): ) # Pad the image (if needed) according max stride - ex["image"] = apply_pad_to_stride(ex["image"], max_stride=self.max_stride) + ex["image"], (pad_w_l, pad_h_t) = apply_pad_to_stride( + ex["image"], max_stride=self.max_stride + ) + ex["instances"] = ex["instances"] + torch.Tensor((pad_w_l, pad_h_t)) img_hw = ex["image"].shape[-2:] @@ -335,7 +342,10 @@ def __getitem__(self, index): ) # Pad the image (if needed) according max stride - ex["image"] = apply_pad_to_stride(ex["image"], max_stride=self.max_stride) + ex["image"], (pad_w_l, pad_h_t) = apply_pad_to_stride( + ex["image"], max_stride=self.max_stride + ) + ex["instances"] = ex["instances"] + torch.Tensor((pad_w_l, pad_h_t)) img_hw = ex["image"].shape[-2:] diff --git a/sleap_nn/inference/predictors.py b/sleap_nn/inference/predictors.py index 848ade8b..49d0cb32 100644 --- a/sleap_nn/inference/predictors.py +++ b/sleap_nn/inference/predictors.py @@ -101,6 +101,7 @@ def from_model_paths( device: str = "cpu", preprocess_config: Optional[OmegaConf] = None, output_head_skeleton_num: int = 1, + centered_fitbbox: bool = False, ) -> "Predictor": """Create the appropriate `Predictor` subclass from from the ckpt path. 
@@ -190,6 +191,7 @@ def from_model_paths( device=device, preprocess_config=preprocess_config, output_head_skeleton_num=output_head_skeleton_num, + centered_fitbbox=centered_fitbbox, ) elif "bottomup" in model_names: @@ -270,34 +272,64 @@ def _predict_generator(self) -> Iterator[Dict[str, np.ndarray]]: batch_size = self.preprocess_config["batch_size"] done = False while not done: + source_imgs = [] imgs = [] fidxs = [] vidxs = [] org_szs = [] instances = [] eff_scales = [] + pad_shifts = [] + pad_shifts_stride = [] for _ in range(batch_size): frame = self.pipeline.frame_buffer.get() if frame["image"] is None: done = True break frame["image"] = apply_normalization(frame["image"]) - frame["image"], eff_scale = apply_sizematcher( + source_img = frame["image"].clone() + frame["image"], eff_scale, (pad_w_l, pad_h_t) = apply_sizematcher( frame["image"], self.preprocess_config["max_height"], self.preprocess_config["max_width"], ) + pad_shift = torch.Tensor((pad_w_l, pad_h_t)) if self.instances_key: frame["instances"] = frame["instances"] * eff_scale + frame["instances"] = frame["instances"] + pad_shift.unsqueeze( + dim=0 + ).unsqueeze(dim=0).unsqueeze(dim=0) if self.preprocess_config["is_rgb"] and frame["image"].shape[-3] != 3: frame["image"] = frame["image"].repeat(1, 3, 1, 1) + source_img = source_img.repeat(1, 3, 1, 1) elif not self.preprocess_config["is_rgb"]: frame["image"] = F.rgb_to_grayscale( frame["image"], num_output_channels=1 ) + source_img = F.rgb_to_grayscale(source_img, num_output_channels=1) + + pad_shift_stride = torch.tensor((0, 0)) + if self.preprocess: + scale = self.preprocess_config["scale"] + if scale != 1.0: + if self.instances_key: + frame["image"], frame["instances"] = apply_resizer( + frame["image"], frame["instances"] + ) + else: + frame["image"] = resize_image(frame["image"], scale) + frame["image"], (pad_w_l, pad_h_t) = apply_pad_to_stride( + frame["image"], self.preprocess_config["max_stride"] + ) + pad_shift_stride = torch.Tensor((pad_w_l, pad_h_t)) + if self.instances_key: + frame["instances"] = frame["instances"] + pad_shift_stride eff_scales.append(torch.tensor(eff_scale)) + pad_shifts.append(pad_shift.unsqueeze(dim=0)) + pad_shifts_stride.append(pad_shift_stride.unsqueeze(dim=0)) imgs.append(frame["image"].unsqueeze(dim=0)) + source_imgs.append(source_img.unsqueeze(dim=0)) fidxs.append(frame["frame_idx"]) vidxs.append(frame["video_idx"]) org_szs.append(frame["orig_size"].unsqueeze(dim=0)) @@ -306,37 +338,28 @@ def _predict_generator(self) -> Iterator[Dict[str, np.ndarray]]: if imgs: # TODO: all preprocessing should be moved into InferenceModels to be exportable. 
imgs = torch.concatenate(imgs, dim=0) + source_imgs = torch.nested.nested_tensor(source_imgs) fidxs = torch.tensor(fidxs, dtype=torch.int32) vidxs = torch.tensor(vidxs, dtype=torch.int32) org_szs = torch.concatenate(org_szs, dim=0) + pad_shifts = torch.concatenate(pad_shifts, dim=0) + pad_shifts_stride = torch.concatenate(pad_shifts_stride, dim=0) eff_scales = torch.tensor(eff_scales, dtype=torch.float32) if self.instances_key: instances = torch.concatenate(instances, dim=0) ex = { "image": imgs, + "source_image": source_imgs, "frame_idx": fidxs, "video_idx": vidxs, "orig_size": org_szs, "eff_scale": eff_scales, + "pad_shifts": pad_shifts, + "pad_shifts_stride": pad_shifts_stride, } if self.instances_key: ex["instances"] = instances - if self.preprocess_config["is_rgb"] and ex["image"].shape[-3] != 3: - ex["image"] = ex["image"].repeat(1, 1, 3, 1, 1) - elif not self.preprocess_config["is_rgb"]: - ex["image"] = F.rgb_to_grayscale(ex["image"], num_output_channels=1) - if self.preprocess: - scale = self.preprocess_config["scale"] - if scale != 1.0: - if self.instances_key: - ex["image"], ex["instances"] = apply_resizer( - ex["image"], ex["instances"] - ) - else: - ex["image"] = resize_image(ex["image"], scale) - ex["image"] = apply_pad_to_stride( - ex["image"], self.preprocess_config["max_stride"] - ) + outputs_list = self.inference_model(ex) if outputs_list is not None: for output in outputs_list: @@ -455,6 +478,7 @@ class TopDownPredictor(Predictor): anchor_ind: Optional[int] = None is_multi_head_model: bool = False output_head_skeleton_num: int = 0 + centered_fitbbox: bool = False def _initialize_inference_model(self): """Initialize the inference model from the trained models and configuration.""" @@ -569,8 +593,31 @@ def _initialize_inference_model(self): return_confmaps=self.return_confmaps, max_stride=max_stride, input_scale=scale, + centered_fitbbox=self.centered_fitbbox, ) centroid_crop_layer.precrop_resize = scale + centroid_crop_layer.output_head_skeleton_num = self.output_head_skeleton_num + if self.preprocess_config.centered_max_height is None: + self.preprocess_config.centered_max_height = ( + self.confmap_config.data_config.preprocessing.max_height[ + self.output_head_skeleton_num + ] + ) + if self.preprocess_config.centered_max_width is None: + self.preprocess_config.centered_max_width = ( + self.confmap_config.data_config.preprocessing.max_width[ + self.output_head_skeleton_num + ] + ) + if self.preprocess_config.max_crop_size is None: + if self.centered_fitbbox: + self.preprocess_config.max_crop_size = ( + self.confmap_config.data_config.max_crop_sizes[ + self.output_head_skeleton_num + ] + ) + centroid_crop_layer.preprocess_config = self.preprocess_config + centroid_crop_layer.centered_fitbbox = self.centered_fitbbox if self.centroid_config is None and self.confmap_config is not None: self.instances_key = ( @@ -611,6 +658,7 @@ def from_trained_models( device: str = "cpu", preprocess_config: Optional[OmegaConf] = None, output_head_skeleton_num: int = 1, + centered_fitbbox: bool = False, ) -> "TopDownPredictor": """Create predictor from saved models. @@ -815,6 +863,7 @@ def from_trained_models( anchor_ind=preprocess_config["anchor_ind"], output_head_skeleton_num=output_head_skeleton_num, is_multi_head_model=is_multi_head_model, + centered_fitbbox=centered_fitbbox, ) obj._initialize_inference_model() @@ -976,57 +1025,117 @@ def _make_labeled_frames_from_generator( preds = defaultdict(list) predicted_frames = [] skeleton_idx = 0 - # Loop through each predicted instance. 
- for ex in generator: - # loop through each sample in a batch - for ( - video_idx, - frame_idx, - bbox, - pred_instances, - pred_values, - instance_score, - org_size, - ) in zip( - ex["video_idx"], - ex["frame_idx"], - ex["instance_bbox"], - ex["pred_instance_peaks"], - ex["pred_peak_values"], - ex["centroid_val"], - ex["orig_size"], - ): - pred_instances = pred_instances + bbox.squeeze(axis=0)[0, :] - preds[(int(video_idx), int(frame_idx))].append( - sio.PredictedInstance.from_numpy( - points_data=pred_instances, - skeleton=self.skeletons[skeleton_idx], - point_scores=pred_values, - score=instance_score, + if isinstance( + self.inference_model.instance_peaks, FindInstancePeaksGroundTruth + ): + for ex in generator: + # loop through each sample in a batch + for ( + video_idx, + frame_idx, + pred_instances, + pred_values, + instance_score, + org_size, + ) in zip( + ex["video_idx"], + ex["frame_idx"], + ex["pred_instance_peaks"], + ex["pred_peak_values"], + ex["centroid_vals"], + ex["orig_size"], + ): + # Loop over instances. + predicted_instances = [] + for pts, confs, score in zip( + pred_instances, pred_values, instance_score + ): + if np.isnan(pts).all(): + continue + + predicted_instances.append( + sio.PredictedInstance.from_numpy( + points_data=pts, + point_scores=confs, + score=score, + skeleton=self.skeletons[skeleton_idx], + ) + ) + + lf = sio.LabeledFrame( + video=self.videos[video_idx], + frame_idx=frame_idx, + instances=predicted_instances, ) - ) - for key, inst in preds.items(): - # Create list of LabeledFrames. - video_idx, frame_idx = key - lf = sio.LabeledFrame( - video=self.videos[video_idx], - frame_idx=frame_idx, - instances=inst, + + if self.tracker: + lf.instances = self.tracker.track( + untracked_instances=predicted_instances, + frame_idx=frame_idx, + image=lf.image, + ) + + predicted_frames.append(lf) + + pred_labels = sio.Labels( + videos=self.videos, + skeletons=self.skeletons, + labeled_frames=predicted_frames, ) + return pred_labels - if self.tracker: - lf.instances = self.tracker.track( - untracked_instances=inst, frame_idx=frame_idx, image=lf.image + else: + # Loop through each predicted instance. + for ex in generator: + # loop through each sample in a batch + for ( + video_idx, + frame_idx, + bbox, + pred_instances, + pred_values, + instance_score, + org_size, + ) in zip( + ex["video_idx"], + ex["frame_idx"], + ex["instance_bbox"], + ex["pred_instance_peaks"], + ex["pred_peak_values"], + ex["centroid_val"], + ex["orig_size"], + ): + pred_instances = pred_instances + bbox.squeeze(axis=0)[0, :] + preds[(int(video_idx), int(frame_idx))].append( + sio.PredictedInstance.from_numpy( + points_data=pred_instances, + skeleton=self.skeletons[skeleton_idx], + point_scores=pred_values, + score=instance_score, + ) + ) + for key, inst in preds.items(): + # Create list of LabeledFrames.
+ video_idx, frame_idx = key + lf = sio.LabeledFrame( + video=self.videos[video_idx], + frame_idx=frame_idx, + instances=inst, ) - predicted_frames.append(lf) + if self.tracker: + lf.instances = self.tracker.track( + untracked_instances=inst, frame_idx=frame_idx, image=lf.image + ) - pred_labels = sio.Labels( - videos=self.videos, - skeletons=self.skeletons, - labeled_frames=predicted_frames, - ) - return pred_labels + predicted_frames.append(lf) + + pred_labels = sio.Labels( + videos=self.videos, + skeletons=self.skeletons, + labeled_frames=predicted_frames, + ) + return pred_labels @attrs.define @@ -1937,6 +2046,9 @@ def main( max_instances: Optional[int] = None, max_width: Optional[int] = None, max_height: Optional[int] = None, + centered_fitbbox: bool = False, + centered_max_height: Optional[int] = None, + centered_max_width: Optional[int] = None, is_rgb: bool = False, anchor_ind: Optional[int] = None, provider: Optional[str] = None, @@ -1945,6 +2057,7 @@ def main( videoreader_start_idx: Optional[int] = None, videoreader_end_idx: Optional[int] = None, crop_hw: Optional[List[int]] = None, + max_crop_size: Optional[list] = None, peak_threshold: Union[float, List[float]] = 0.2, integral_refinement: str = None, integral_patch_size: int = 5, @@ -2094,6 +2207,9 @@ def main( "max_width": max_width, "max_height": max_height, "anchor_ind": anchor_ind, + "centered_max_height": centered_max_height, + "centered_max_width": centered_max_width, + "max_crop_size": max_crop_size, } if provider is None: @@ -2121,6 +2237,7 @@ def main( device=device, preprocess_config=OmegaConf.create(preprocess_config), output_head_skeleton_num=output_head_skeleton_num, + centered_fitbbox=centered_fitbbox, ) if tracking: diff --git a/sleap_nn/inference/topdown.py b/sleap_nn/inference/topdown.py index 1736ea02..d255bc6c 100644 --- a/sleap_nn/inference/topdown.py +++ b/sleap_nn/inference/topdown.py @@ -5,15 +5,18 @@ import lightning as L import numpy as np from sleap_nn.data.resizing import ( + apply_sizematcher, resize_image, apply_pad_to_stride, ) from sleap_nn.inference.peak_finding import crop_bboxes from sleap_nn.data.instance_centroids import generate_centroids -from sleap_nn.data.instance_cropping import make_centered_bboxes +from sleap_nn.data.instance_cropping import make_centered_bboxes, get_fit_bbox from sleap_nn.inference.peak_finding import find_global_peaks, find_local_peaks from loguru import logger from collections import defaultdict +from omegaconf import DictConfig +from kornia.geometry.transform import crop_and_resize, crop_by_boxes class CentroidCrop(L.LightningModule): @@ -77,6 +80,9 @@ def __init__( max_stride: int = 1, use_gt_centroids: bool = False, anchor_ind: Optional[int] = None, + output_head_skeleton_num: Optional[int] = 0, + preprocess_config: Optional[dict] = None, + centered_fitbbox: bool = False, **kwargs, ): """Initialise the model attributes.""" @@ -95,19 +101,35 @@ def __init__( self.max_stride = max_stride self.use_gt_centroids = use_gt_centroids self.anchor_ind = anchor_ind + self.output_head_skeleton_num = output_head_skeleton_num + self.preprocess_config = preprocess_config + self.centered_fitbbox = centered_fitbbox def _generate_crops(self, inputs): """Generate Crops from the predicted centroids.""" crops_dict = [] - for centroid, centroid_val, image, fidx, vidx, sz, eff_sc in zip( + for centroid, centroid_val, image, fidx, vidx, sz, eff_sc, pad_shifts in zip( self.refined_peaks_batched, self.peak_vals_batched, - inputs["image"], + inputs["source_image"], inputs["frame_idx"], 
inputs["video_idx"], inputs["orig_size"], inputs["eff_scale"], + inputs["pad_shifts"], ): + + # size matcher + max_h = self.preprocess_config.centered_max_height + max_w = self.preprocess_config.centered_max_width + image, eff_sc, (pad_w_l, pad_h_t) = apply_sizematcher(image, max_h, max_w) + centroid = centroid * eff_sc + pad_shifts = torch.Tensor((pad_w_l, pad_h_t)) + centroid = centroid + pad_shifts + + image = resize_image(image, self.precrop_resize) + centroid = centroid * self.precrop_resize + if torch.any(torch.isnan(centroid)): if torch.all(torch.isnan(centroid)): continue @@ -148,6 +170,114 @@ def _generate_crops(self, inputs): ex["instance_image"] = instance_image.unsqueeze(dim=1) ex["orig_size"] = torch.cat([torch.Tensor(sz)] * n) ex["eff_scale"] = torch.Tensor([eff_sc] * n) + ex["pad_shifts"] = pad_shifts.unsqueeze(0).repeat(n, 1) + crops_dict.append(ex) + + return crops_dict + + def _generate_fitbbox_crops(self, inputs): + """Generate fit-bbox crops around the input instances.""" + crops_dict = [] + for ( + centroid, + centroid_val, + image, + fidx, + vidx, + sz, + eff_sc, + pad_shifts, + instances, + ) in zip( + self.refined_peaks_batched, + self.peak_vals_batched, + inputs["source_image"], + inputs["frame_idx"], + inputs["video_idx"], + inputs["orig_size"], + inputs["eff_scale"], + inputs["pad_shifts"], + inputs["instances"], + ): + + # adjust for initial size matching in preprocessing + + instances = ( + instances - pad_shifts + ) + instances = instances / eff_sc + + # size matcher + max_h = self.preprocess_config.centered_max_height + max_w = self.preprocess_config.centered_max_width + image, eff_sc, (pad_w_l, pad_h_t) = apply_sizematcher(image, max_h, max_w) + instances = instances * eff_sc + pad_shifts = torch.Tensor((pad_w_l, pad_h_t)) + instances = instances + pad_shifts.unsqueeze(dim=0).unsqueeze( + dim=0 + ).unsqueeze(dim=0) + + image = resize_image(image, self.precrop_resize) + instances = instances * self.precrop_resize + + n = centroid.shape[0] + + # get max bbox size for this batch + max_crop_size = self.preprocess_config.max_crop_size + + instance_images = [] + bbox_shifts = [] + eff_scale_crops = [] + padding_shifts_crops = [] + for instance in instances[0]: + if torch.all(torch.isnan(instance)): + continue + bbox = get_fit_bbox(instance) # bbox => (x_min, y_min, x_max, y_max) + bbox[0] = bbox[0] - 16 + bbox[1] = bbox[1] - 16 + bbox[2] = bbox[2] + 16 + bbox[3] = bbox[3] + 16 # padding of 16 on all sides + x_min, y_min, x_max, y_max = bbox + crop_hw = (y_max - y_min, x_max - x_min) + + cropped_image = crop_by_boxes( + image, + src_box=torch.Tensor( + [[x_min, y_min], [x_max, y_min], [x_max, y_max], [x_min, y_max]] + ).unsqueeze(dim=0), + dst_box=torch.tensor( + [ + [ + [0.0, 0.0], + [crop_hw[1], 0.0], + [crop_hw[1], crop_hw[0]], + [0.0, crop_hw[0]], + ] + ] + ), + ) + bbox_shifts.append(bbox[:2].unsqueeze(dim=0)) + + cropped_image_match_hw, eff_scale, pad_wh = apply_sizematcher( + cropped_image, max_crop_size[0], max_crop_size[1] + ) # resize and pad to max crop size + instance_images.append(cropped_image_match_hw.unsqueeze(dim=0)) + eff_scale_crops.append(eff_scale) + padding_shifts_crops.append(torch.Tensor(pad_wh).unsqueeze(dim=0)) + + ex = {} + ex["image"] = torch.cat([image] * n) + ex["centroid_val"] = centroid_val + ex["frame_idx"] = torch.Tensor([fidx] * n) + ex["video_idx"] = torch.Tensor([vidx] * n) + ex["instance_bbox"] = torch.zeros((n, 1, 4, 2)) + ex["instance_image"] = torch.cat(instance_images, dim=0)
+ ex["eff_scale_crops"] = torch.Tensor(eff_scale_crops) + ex["padding_shifts_crops"] = torch.cat(padding_shifts_crops, dim=0) + ex["bbox_shifts"] = torch.cat(bbox_shifts, dim=0) + ex["orig_size"] = torch.cat([torch.Tensor(sz)] * n) + ex["eff_scale"] = torch.Tensor([eff_sc] * n) + ex["pad_shifts"] = pad_shifts.unsqueeze(0).repeat(n, 1) crops_dict.append(ex) return crops_dict @@ -176,21 +306,21 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: centroids = generate_centroids( inputs["instances"], anchor_ind=self.anchor_ind ) + centroid_vals = torch.ones(centroids.shape)[..., 0] self.refined_peaks_batched = [x[0] for x in centroids] self.peak_vals_batched = [x[0] for x in centroid_vals] - max_instances = ( - self.max_instances - if self.max_instances is not None - else inputs["instances"].shape[-3] - ) + max_instances = inputs["instances"].shape[-3] refined_peaks_with_nans = torch.zeros((batch, max_instances, 2)) peak_vals_with_nans = torch.zeros((batch, max_instances)) for ind, (r, p) in enumerate( zip(self.refined_peaks_batched, self.peak_vals_batched) ): + if not self.centered_fitbbox: + r -= inputs["pad_shifts"][ind].unsqueeze(dim=0) + r /= inputs["eff_scale"][ind] refined_peaks_with_nans[ind] = r peak_vals_with_nans[ind] = p @@ -202,12 +332,13 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: ) if self.return_crops: - crops_dict = self._generate_crops(inputs) - inputs["image"] = resize_image(inputs["image"], self.precrop_resize) - inputs["centroids"] *= self.precrop_resize + if self.centered_fitbbox: + crops_dict = self._generate_fitbbox_crops(inputs) + else: + crops_dict = self._generate_crops(inputs) scaled_refined_peaks = [] for ref_peak in self.refined_peaks_batched: - scaled_refined_peaks.append(ref_peak * self.precrop_resize) + scaled_refined_peaks.append(ref_peak) self.refined_peaks_batched = scaled_refined_peaks return crops_dict else: @@ -217,7 +348,9 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: orig_image = inputs["image"] scaled_image = resize_image(orig_image, self.input_scale) if self.max_stride != 1: - scaled_image = apply_pad_to_stride(scaled_image, self.max_stride) + scaled_image, (pad_stride_w, pad_stride_h) = apply_pad_to_stride( + scaled_image, self.max_stride + ) cms = self.torch_model(scaled_image) if isinstance(cms, list): # only one head for centroid model @@ -231,7 +364,6 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: ) # Adjust for stride and scale. refined_peaks = refined_peaks * self.output_stride # (n_centroids, 2) - refined_peaks = refined_peaks / self.input_scale batch = cms.shape[0] @@ -275,10 +407,18 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: # Generate crops if return_crops=True to pass the crops to CenteredInstance model. 
if self.return_crops: - inputs["image"] = resize_image(inputs["image"], self.precrop_resize) scaled_refined_peaks = [] - for ref_peak in self.refined_peaks_batched: - scaled_refined_peaks.append(ref_peak * self.precrop_resize) + for ind, ref_peak in enumerate(self.refined_peaks_batched): + # remove padding stride -> input scale + if self.max_stride != 1: + ref_peak = ref_peak - torch.tensor( + (pad_stride_w, pad_stride_h) + ).to(ref_peak.device) + ref_peak = ref_peak / self.input_scale + ref_peak = ref_peak - inputs["pad_shifts"][ind].to(ref_peak.device) + ref_peak = ref_peak / (inputs["eff_scale"][ind].to(ref_peak.device)) + + scaled_refined_peaks.append(ref_peak) self.refined_peaks_batched = scaled_refined_peaks inputs.update( @@ -296,6 +436,12 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: for ind, (r, p) in enumerate( zip(self.refined_peaks_batched, self.peak_vals_batched) ): + # remove padding from max stride -> input scale -> size matcher padding -> size matcher scaling + if self.max_stride != 1: + r = r - torch.tensor((pad_stride_w, pad_stride_h)).to(r.device) + r = r / self.input_scale + r = r - inputs["pad_shifts"][ind].to(r.device) + refined_peaks_with_nans[ind] = r peak_vals_with_nans[ind] = p refined_peaks_with_nans = refined_peaks_with_nans / ( @@ -361,6 +507,14 @@ def __init__( def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, np.array]: """Return the ground truth instance peaks given a set of crops.""" b, _, max_inst, nodes, _ = batch["instances"].shape + batch["instances"] = batch["instances"] - batch["pad_shifts"].unsqueeze( + dim=1 + ).unsqueeze(dim=2).unsqueeze(dim=3).to(batch["instances"].device) + batch["instances"] = batch["instances"] / batch["eff_scale"].unsqueeze( + dim=1 + ).unsqueeze(dim=2).unsqueeze(dim=3).unsqueeze(dim=4).to( + batch["instances"].device + ) inst = ( batch["instances"].unsqueeze(dim=-4).float() ) # (batch, 1, 1, n_inst, nodes, 2) @@ -422,14 +576,7 @@ def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, np.array]: peaks_vals = vals peaks_output = batch - if peaks.size(0) != 0: - peaks = peaks / ( - batch["eff_scale"] - .unsqueeze(dim=1) - .unsqueeze(dim=2) - .unsqueeze(dim=3) - .to(peaks.device) - ) + peaks_output["instance_bbox"] = torch.Tensor([0, 0]).to(peaks.device) peaks_output["pred_instance_peaks"] = peaks peaks_output["pred_peak_values"] = peaks_vals @@ -478,6 +625,7 @@ def __init__( return_confmaps: Optional[bool] = False, input_scale: float = 1.0, max_stride: int = 1, + centered_fitbbox: bool = False, **kwargs, ): """Initialise the model attributes.""" @@ -490,6 +638,7 @@ def __init__( self.return_confmaps = return_confmaps self.input_scale = input_scale self.max_stride = max_stride + self.centered_fitbbox = centered_fitbbox def forward( self, inputs: Dict[str, torch.Tensor], output_head_skeleton_num: int = 0 @@ -525,7 +674,9 @@ def forward( # resize and pad the input image input_image = inputs["instance_image"] if self.max_stride != 1: - input_image = apply_pad_to_stride(input_image, self.max_stride) + input_image, (pad_w_l, pad_h_t) = apply_pad_to_stride( + input_image, self.max_stride + ) cms = self.torch_model(input_image) if isinstance(cms, list): @@ -540,15 +691,48 @@ def forward( # Adjust for stride and scale. 
peak_points = peak_points * self.output_stride + + # adjust for padding for max stride + if self.max_stride != 1: + pad_shift_strides = torch.Tensor((pad_w_l, pad_h_t)).to(peak_points.device) + peak_points = peak_points - pad_shift_strides + # inputs["instance_bbox"] = inputs["instance_bbox"] - pad_shift_strides + + if self.centered_fitbbox: + # max crop size size matching + peak_points = peak_points - inputs["padding_shifts_crops"].unsqueeze( + dim=1 + ).to(peak_points.device) + + peak_points = peak_points / ( + inputs["eff_scale_crops"] + .unsqueeze(dim=1) + .unsqueeze(dim=2) + .to(peak_points.device) + ) + + # bbox shifts (for fit bbox) + peak_points = peak_points + ( + inputs["bbox_shifts"].unsqueeze(dim=1).to(peak_points.device) + ) + # adjust for scaling: resizing if self.input_scale != 1.0: peak_points = peak_points / self.input_scale + inputs["instance_bbox"] = inputs["instance_bbox"] / self.input_scale + + # adjust for sizematcher: padding + peak_points = peak_points - inputs["pad_shifts"].unsqueeze(dim=1).to( + peak_points.device + ) + + # inputs["instance_bbox"] = inputs["instance_bbox"] - (inputs["pad_shifts"].unsqueeze(dim=1).unsqueeze(dim=2)) + + # # adjust for sizematcher: scaling peak_points = peak_points / ( inputs["eff_scale"].unsqueeze(dim=1).unsqueeze(dim=2).to(peak_points.device) ) - inputs["instance_bbox"] = inputs["instance_bbox"] / self.input_scale - inputs["instance_bbox"] = inputs["instance_bbox"] / ( inputs["eff_scale"] .unsqueeze(dim=1) @@ -561,6 +745,7 @@ def forward( outputs = {"pred_instance_peaks": peak_points, "pred_peak_values": peak_vals} if self.return_confmaps: outputs["pred_confmaps"] = cms.detach() + inputs["instance_image"] = input_image inputs.update(outputs) return inputs diff --git a/sleap_nn/training/lightning_modules.py b/sleap_nn/training/lightning_modules.py index 0673cc5a..64dd58dc 100644 --- a/sleap_nn/training/lightning_modules.py +++ b/sleap_nn/training/lightning_modules.py @@ -823,7 +823,10 @@ def __init__( model_type=model_type, ) self.inf_layer = FindInstancePeaks( - torch_model=self.forward, peak_threshold=0.2, return_confmaps=True + torch_model=self.forward, + peak_threshold=0.2, + return_confmaps=True, + centered_fitbbox=True, ) def on_train_epoch_start(self): @@ -840,6 +843,13 @@ def on_train_epoch_start(self): for d_num in self.config.dataset_mapper: sample = next(iter(self.trainer.val_dataloaders[d_num])) sample["eff_scale"] = torch.ones(sample["video_idx"].shape) + sample["pad_shifts"] = torch.zeros( + (sample["video_idx"].shape[0], 2) + ) + sample["eff_scale_crops"] = torch.ones(sample["video_idx"].shape) + sample["padding_shifts_crops"] = torch.zeros( + (sample["video_idx"].shape[0], 2) + ) for k, v in sample.items(): sample[k] = v.to(device=self.device) self.inf_layer.output_stride = self.config.model_config.head_configs.centered_instance.confmaps[ @@ -1260,6 +1270,9 @@ def on_train_epoch_start(self): sample = next(iter(self.trainer.val_dataloaders[d_num])) gt_centroids = sample["centroids"] sample["eff_scale"] = torch.ones(sample["video_idx"].shape) + sample["pad_shifts"] = torch.zeros( + (sample["video_idx"].shape[0], 2) + ) for k, v in sample.items(): sample[k] = v.to(device=self.device) output = self.centroid_inf_layer(sample) diff --git a/sleap_nn/training/model_trainer.py b/sleap_nn/training/model_trainer.py index 05204c40..d22c44a9 100644 --- a/sleap_nn/training/model_trainer.py +++ b/sleap_nn/training/model_trainer.py @@ -2,6 +2,7 @@ from pathlib import Path import os +import numpy as np import psutil import shutil 
import subprocess @@ -41,6 +42,7 @@ from sleap_nn.data.custom_datasets import ( BottomUpDataset, CenteredInstanceDataset, + CenteredInstanceDatasetFitBbox, CentroidDataset, SingleInstanceDataset, CyclerDataLoader, @@ -1035,6 +1037,7 @@ def __init__( self.max_heights = {} self.max_widths = {} self.crop_hws = {} + self.max_crop_hws = {} self.skeletons_dict = {} OmegaConf.update(self.config.data_config, f"skeletons", {}) @@ -1134,32 +1137,46 @@ def __init__( if self.model_type == "centered_instance": # compute crop size - self.crop_hws[d_num] = self.config.data_config.preprocessing.crop_hw[ - d_num - ] - if self.crop_hws[d_num] is None: - - min_crop_size = ( - self.config.data_config.preprocessing.min_crop_size - if "min_crop_size" in self.config.data_config.preprocessing - else None - ) - crop_size = find_instance_crop_size( - self.train_labels[d_num], - maximum_stride=self.max_stride, - min_crop_size=min_crop_size, - input_scaling=self.config.data_config.preprocessing.scale[ - d_num - ], - ) - self.crop_hws[d_num] = crop_size - self.config.data_config.preprocessing.crop_hw[d_num] = ( - self.crop_hws[d_num], - self.crop_hws[d_num], - ) - else: - self.crop_hws[d_num] = self.crop_hws[d_num][0] - + self.max_crop_hws[d_num] = [0, 0] + for lf in self.train_labels[d_num]: + for instance in lf.instances: + inst = instance.numpy() + x, y = inst[:, 0], inst[:, 1] + x_min, x_max = np.nanmin(x), np.nanmax(x) + y_min, y_max = np.nanmin(y), np.nanmax(y) + h, w = y_max - y_min, x_max - x_min + if h > self.max_crop_hws[d_num][0]: + self.max_crop_hws[d_num][0] = int(h) + if w > self.max_crop_hws[d_num][1]: + self.max_crop_hws[d_num][1] = int(w) + + # self.crop_hws[d_num] = self.config.data_config.preprocessing.crop_hw[ + # d_num + # ] + # if self.crop_hws[d_num] is None: + + # min_crop_size = ( + # self.config.data_config.preprocessing.min_crop_size + # if "min_crop_size" in self.config.data_config.preprocessing + # else None + # ) + # crop_size = find_instance_crop_size( + # self.train_labels[d_num], + # maximum_stride=self.max_stride, + # min_crop_size=min_crop_size, + # input_scaling=self.config.data_config.preprocessing.scale[ + # d_num + # ], + # ) + # self.crop_hws[d_num] = crop_size + # self.config.data_config.preprocessing.crop_hw[d_num] = ( + # self.crop_hws[d_num], + # self.crop_hws[d_num], + # ) + # else: + # self.crop_hws[d_num] = self.crop_hws[d_num][0] + + self.config.data_config.max_crop_sizes = self.max_crop_hws OmegaConf.save(config=self.config, f=f"{self.dir_path}/training_config.yaml") def _create_data_loaders_torch_dataset(self, d_num): @@ -1233,7 +1250,7 @@ def _create_data_loaders_torch_dataset(self, d_num): ) elif self.model_type == "centered_instance": - self.train_datasets[d_num] = CenteredInstanceDataset( + self.train_datasets[d_num] = CenteredInstanceDatasetFitBbox( labels=self.train_labels[d_num], confmap_head_config=self.config.model_config.head_configs.centered_instance.confmaps[ d_num @@ -1244,14 +1261,14 @@ def _create_data_loaders_torch_dataset(self, d_num): augmentation_config=self.config.data_config.augmentation_config, scale=self.config.data_config.preprocessing.scale[d_num], apply_aug=self.config.data_config.use_augmentations_train, - crop_hw=(self.crop_hws[d_num], self.crop_hws[d_num]), + max_crop_hw=self.max_crop_hws[d_num], max_hw=(self.max_heights[d_num], self.max_widths[d_num]), cache_img=self.cache_img, cache_img_path=self.train_cache_img_paths[d_num], use_existing_imgs=self.use_existing_imgs[d_num], rank=self.trainer.global_rank if self.trainer is not None else 
None, ) - self.val_datasets[d_num] = CenteredInstanceDataset( + self.val_datasets[d_num] = CenteredInstanceDatasetFitBbox( labels=self.val_labels[d_num], confmap_head_config=self.config.model_config.head_configs.centered_instance.confmaps[ d_num @@ -1262,7 +1279,7 @@ def _create_data_loaders_torch_dataset(self, d_num): augmentation_config=None, scale=self.config.data_config.preprocessing.scale[d_num], apply_aug=False, - crop_hw=(self.crop_hws[d_num], self.crop_hws[d_num]), + max_crop_hw=self.max_crop_hws[d_num], max_hw=(self.max_heights[d_num], self.max_widths[d_num]), cache_img=self.cache_img, cache_img_path=self.val_cache_img_paths[d_num], diff --git a/tests/data/test_resizing.py b/tests/data/test_resizing.py index 6862ae86..284d77d1 100644 --- a/tests/data/test_resizing.py +++ b/tests/data/test_resizing.py @@ -116,11 +116,12 @@ def test_apply_pad_to_stride(minimal_instance): lf = labels[0] ex = process_lf(lf, 0, 2) - image = apply_pad_to_stride(ex["image"], max_stride=2) + image, _ = apply_pad_to_stride(ex["image"], max_stride=2) assert image.shape == torch.Size([1, 1, 384, 384]) - image = apply_pad_to_stride(ex["image"], max_stride=200) + image, padding = apply_pad_to_stride(ex["image"], max_stride=200) assert image.shape == torch.Size([1, 1, 400, 400]) + assert padding == (8, 8) def test_apply_sizematcher(caplog, minimal_instance): @@ -129,14 +130,15 @@ def test_apply_sizematcher(caplog, minimal_instance): lf = labels[0] ex = process_lf(lf, 0, 2) - image, _ = apply_sizematcher(ex["image"], 500, 500) + image, _, _ = apply_sizematcher(ex["image"], 500, 500) assert image.shape == torch.Size([1, 1, 500, 500]) - image, _ = apply_sizematcher(ex["image"], 700, 600) + image, _, _ = apply_sizematcher(ex["image"], 700, 600) assert image.shape == torch.Size([1, 1, 700, 600]) - image, _ = apply_sizematcher(ex["image"]) + image, _, padding = apply_sizematcher(ex["image"]) assert image.shape == torch.Size([1, 1, 384, 384]) + assert padding == (0, 0) - image, eff = apply_sizematcher(ex["image"], 100, 480) + image, eff, _ = apply_sizematcher(ex["image"], 100, 480) assert image.shape == torch.Size([1, 1, 100, 480]) diff --git a/tests/training/test_model_trainer.py b/tests/training/test_model_trainer.py index 4753f8a8..00cdca38 100644 --- a/tests/training/test_model_trainer.py +++ b/tests/training/test_model_trainer.py @@ -16,7 +16,6 @@ CentroidLightningModule, BottomUpLightningModule, ) -from torch.nn.functional import mse_loss import os import wandb from lightning.pytorch.loggers import WandbLogger
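The common thread across these hunks is that `apply_sizematcher` and `apply_pad_to_stride` now also return the `(pad_w_l, pad_h_t)` offsets of the centered zero-padding: the dataset classes add the offsets to keypoints after scaling by `eff_scale`, and the inference layers invert the transform (subtract the offsets, then divide by `eff_scale`) before reporting peaks in original image coordinates. Below is a minimal, self-contained sketch of that round trip, assuming a 300x400 frame size-matched to 384x384; the helper names are illustrative only and are not sleap-nn functions.

# Illustrative sketch of the padding-offset bookkeeping used throughout this patch.
# Only the convention (scale by eff_scale, then shift by the left/top pad offsets,
# and invert the same two steps at inference) mirrors the diff; names are hypothetical.
import torch


def center_pad_offsets(h: int, w: int, target_h: int, target_w: int):
    """Return (pad_w_left, pad_h_top) for centered zero-padding to (target_h, target_w)."""
    pad_h, pad_w = target_h - h, target_w - w
    return pad_w // 2, pad_h // 2


def forward_adjust(points: torch.Tensor, eff_scale: float, pad_wh: tuple) -> torch.Tensor:
    """Map keypoints from original image coords into the scaled + center-padded image."""
    return points * eff_scale + torch.tensor(pad_wh, dtype=points.dtype)


def inverse_adjust(points: torch.Tensor, eff_scale: float, pad_wh: tuple) -> torch.Tensor:
    """Map predicted peaks from the padded image back to original image coords."""
    return (points - torch.tensor(pad_wh, dtype=points.dtype)) / eff_scale


if __name__ == "__main__":
    # A 300x400 frame size-matched to 384x384: scaled by 384/400, then padded top/bottom.
    eff_scale = 384 / 400
    pad_wh = center_pad_offsets(int(300 * eff_scale), int(400 * eff_scale), 384, 384)
    pts = torch.tensor([[120.0, 80.0], [250.0, 200.0]])
    padded_pts = forward_adjust(pts, eff_scale, pad_wh)
    recovered = inverse_adjust(padded_pts, eff_scale, pad_wh)
    assert torch.allclose(recovered, pts, atol=1e-5)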