From 6b1bfa6a0242b6aee26bc749371a7c4f46be755e Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Wed, 23 Apr 2025 16:48:16 -0700 Subject: [PATCH 1/8] Add torch nn module for multi-head model --- docs/config_slumbr.yaml | 147 ++++++++++++++++++++++++++++++++ sleap_nn/architectures/model.py | 131 +++++++++++++++++++++++++++- 2 files changed, 277 insertions(+), 1 deletion(-) create mode 100644 docs/config_slumbr.yaml diff --git a/docs/config_slumbr.yaml b/docs/config_slumbr.yaml new file mode 100644 index 00000000..82582b14 --- /dev/null +++ b/docs/config_slumbr.yaml @@ -0,0 +1,147 @@ +dataset_mapper: + 0: dataset_1 + 1: dataset_2 + 2: dataset_3 +data_config: + provider: LabelsReader + train_labels_path: + 0: minimal_instance.pkg.slp + 1: minimal_instance.pkg.slp + 2: minimal_instance.pkg.slp + val_labels_path: + 0: minimal_instance.pkg.slp + 1: minimal_instance.pkg.slp + 2: minimal_instance.pkg.slp + test_file_path: + 0: minimal_instance.pkg.slp + 1: minimal_instance.pkg.slp + 2: minimal_instance.pkg.slp + data_pipeline_fw: torch_dataset_cache_img_disk + cache_img_path: ./img_dir/ + use_existing_imgs: + 0: + 1: + 2: + user_instances_only: True + litdata_chunks_path: + chunk_size: + delete_cache_imgs_after_training: + preprocessing: + max_width: + 0: + 1: + 2: + max_height: + 0: + 1: + 2: + scale: + 0: + 1: + 2: + is_rgb: True + crop_hw: + 0: + 1: + 2: + min_crop_size: + use_augmentations_train: true + augmentation_config: + geometric: + rotation: 180.0 + scale: null + translate_width: 0 + translate_height: 0 + affine_p: 0.5 + +model_config: + init_weights: xavier + pre_trained_weights: + pretrained_backbone_weights: + pretrained_head_weights: + backbone_type: unet + backbone_config: + in_channels: 1 + kernel_size: 3 + filters: 16 + filters_rate: 2 + max_stride: 16 + convs_per_block: 2 + stacks: 1 + stem_stride: + middle_block: True + up_interpolate: True + head_configs: + single_instance: + centroid: + bottomup: + centered_instance: + confmaps: + 0: + part_names: + anchor_part: 0 + sigma: 1.5 + output_stride: 2 + 1: + part_names: + anchor_part: 2 + sigma: 1.5 + output_stride: 2 + 2: + part_names: + anchor_part: 4 + sigma: 1.5 + output_stride: 2 +trainer_config: + train_data_loader: + batch_size: 4 + shuffle: true + num_workers: 2 + val_data_loader: + batch_size: 4 + num_workers: 2 + combined_loader_mode: "max_size_cycle" # or "min_size", "max_size" + model_ckpt: + save_top_k: 1 + save_last: true + trainer_devices: 1 + trainer_accelerator: gpu + profiler: simple + trainer_strategy: "ddp_find_unused_parameters_false" # "auto", "ddp", "ddp_find_unused_parameters_true" + enable_progress_bar: false + log_inf_epochs: 5 + steps_per_epoch: + 0: + 1: + 2: + max_epochs: 10 + seed: 1000 + use_wandb: true + wandb: + entity: + project: 'test_centroid_centered' + name: 'fly_unet_centered' + wandb_mode: '' + api_key: '' + prv_runid: + group: + save_ckpt: true + save_ckpt_path: 'multi_head' + resume_ckpt_path: + optimizer_name: Adam + optimizer: + lr: 0.0001 + amsgrad: false + lr_scheduler: + scheduler: ReduceLROnPlateau + reduce_lr_on_plateau: + threshold: 1.0e-07 + threshold_mode: abs + cooldown: 3 + patience: 5 + factor: 0.5 + min_lr: 1.0e-08 + early_stopping: + stop_training_on_plateau: True + min_delta: 1.0e-08 + patience: 20 \ No newline at end of file diff --git a/sleap_nn/architectures/model.py b/sleap_nn/architectures/model.py index 5e8c3a83..744fca12 100644 --- a/sleap_nn/architectures/model.py +++ b/sleap_nn/architectures/model.py @@ -6,7 +6,7 @@ """ from typing import List - +from collections import defaultdict import torch from omegaconf.dictconfig import DictConfig from torch import nn @@ -184,3 +184,132 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: outputs[head.name] = head_layer(backbone_outputs["outputs"][idx]) return outputs + + +class MultiHeadModel(nn.Module): + """Model creates a model consisting of a backbone and head. + + Attributes: + backbone_type: Backbone type. One of `unet`, `convnext` and `swint`. + backbone_config: An `DictConfig` configuration dictionary for the model backbone. + model_type: Type of the model. One of `single_instance`, `centered_instance`, `centroid`, `bottomup`. + head_configs: An `DictConfig` configuration dictionary for the model heads + (this should have multiple head configs for each dataset). + """ + + def __init__( + self, + backbone_type: str, + backbone_config: DictConfig, + model_type: str, + head_configs: DictConfig, + ) -> None: + """Initialize the backbone and head based on the backbone_config.""" + super().__init__() + self.backbone_type = backbone_type + self.backbone_config = backbone_config + self.model_type = model_type + self.head_configs = head_configs + + self.heads = [] + if self.model_type == "single_instance": + for d_num, _ in self.head_configs.confmaps.items(): + self.heads.append( + SingleInstanceConfmapsHead(**self.head_configs.confmaps[d_num]) + ) + + elif self.model_type == "centered_instance": + for d_num, _ in self.head_configs.confmaps.items(): + self.heads.append( + CenteredInstanceConfmapsHead(**self.head_configs.confmaps[d_num]) + ) + + elif self.model_type == "centroid": + centroid_confmaps = self.head_configs.confmaps[0].copy() + centroid_confmaps.anchor_part = None + self.heads.append(CentroidConfmapsHead(**centroid_confmaps)) + + elif self.model_type == "bottomup": + for d_num, _ in self.head_configs.confmaps.items(): + self.heads.append( + MultiInstanceConfmapsHead(**self.head_configs.confmaps[d_num]) + ) + for d_num, _ in self.head_configs.pafs.items(): + self.heads.append( + PartAffinityFieldsHead(**self.head_configs.pafs[d_num]) + ) + + else: + message = f"{self.model_type} is not a defined model type. Please choose one of `single_instance`, `centered_instance`, `centroid`, `bottomup`." + logger.error(message) + raise Exception(message) + + output_strides = [] + for head_type in head_configs: + head_config = head_configs[head_type] + output_strides.extend([cfg.output_stride for cfg in head_config]) + + min_output_stride = min(output_strides) + min_output_stride = min(min_output_stride, self.backbone_config.output_stride) + + self.backbone = get_backbone( + self.backbone_type, + backbone_config, + ) + + strides = self.backbone.dec.current_strides + self.head_layers = nn.ModuleList([]) + for head in self.heads: + in_channels = int( + round( + self.backbone.max_channels + / ( + self.backbone_config.filters_rate + ** len(self.backbone.dec.decoder_stack) + ) + ) + ) + if head.output_stride != min_output_stride: + factor = strides.index(min_output_stride) - strides.index( + head.output_stride + ) + in_channels = in_channels * (self.backbone_config.filters_rate**factor) + self.head_layers.append(head.make_head(x_in=int(in_channels))) + + @classmethod + def from_config( + cls, + backbone_type: str, + backbone_config: DictConfig, + model_type: str, + head_configs: DictConfig, + ) -> "MultiHeadModel": + """Create the model from a config dictionary.""" + return cls( + backbone_type=backbone_type, + backbone_config=backbone_config, + model_type=model_type, + head_configs=head_configs, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through the model. + + Args: + x: Input image. + + Returns: + A dictionary with key as the head name and values as list of confmaps + for each of the skeleton formats (in the order of datasets in the + config). + """ + backbone_outputs = self.backbone(x) + + outputs = defaultdict(list) + for head, head_layer in zip(self.heads, self.head_layers): + idx = backbone_outputs["strides"].index(head.output_stride) + outputs[head.name].append( + head_layer(backbone_outputs["outputs"][idx]) + ) # eg: outputs = {"SingleInstanceConfmapsHead" : [output_head_0, output_head_1, output_head_2, ...]} + + return outputs From 381d0c37f9fd9a2d796d3cd23c899c51fcd41027 Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Wed, 23 Apr 2025 17:07:47 -0700 Subject: [PATCH 2/8] Modify inference pipeline for multi-head-model --- sleap_nn/inference/bottomup.py | 13 +- sleap_nn/inference/predictors.py | 483 ++++++++++++++++++++----- sleap_nn/inference/single_instance.py | 9 + sleap_nn/inference/topdown.py | 17 +- sleap_nn/training/lightning_modules.py | 20 + 5 files changed, 450 insertions(+), 92 deletions(-) diff --git a/sleap_nn/inference/bottomup.py b/sleap_nn/inference/bottomup.py index f06936eb..3ebfbc66 100644 --- a/sleap_nn/inference/bottomup.py +++ b/sleap_nn/inference/bottomup.py @@ -47,6 +47,9 @@ class BottomUpInferenceModel(L.LightningModule): representation used during instance grouping. input_scale: Float indicating if the images should be resized before being passed to the model. + output_head_skeleton_num: Dataset number (as given in the config) indicating + which skeleton format to output. This parameter is only required for + multi-head model inference. """ def __init__( @@ -62,6 +65,7 @@ def __init__( return_pafs: Optional[bool] = False, return_paf_graph: Optional[bool] = False, input_scale: float = 1.0, + output_head_skeleton_num: int = 0, ): """Initialise the model attributes.""" super().__init__() @@ -76,6 +80,7 @@ def __init__( self.return_pafs = return_pafs self.return_paf_graph = return_paf_graph self.input_scale = input_scale + self.output_head_skeleton_num = output_head_skeleton_num def _generate_cms_peaks(self, cms): peaks, peak_vals, sample_inds, peak_channel_inds = find_local_peaks( @@ -120,7 +125,13 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: self.batch_size = inputs["image"].shape[0] output = self.torch_model(inputs["image"]) cms = output["MultiInstanceConfmapsHead"] - pafs = output["PartAffinityFieldsHead"].permute(0, 2, 3, 1) + # for multi-head-model, output is a list of confmaps for the different heads + if isinstance(cms, list): + cms = cms[self.output_head_skeleton_num] + pafs = output["PartAffinityFieldsHead"] + if isinstance(pafs, list): + pafs = pafs[self.output_head_skeleton_num] + pafs = pafs.permute(0, 2, 3, 1) cms_peaks, cms_peak_vals, cms_peak_channel_inds = self._generate_cms_peaks(cms) ( diff --git a/sleap_nn/inference/predictors.py b/sleap_nn/inference/predictors.py index 9947cbe9..de27b499 100644 --- a/sleap_nn/inference/predictors.py +++ b/sleap_nn/inference/predictors.py @@ -28,6 +28,10 @@ SingleInstanceModel, CentroidModel, BottomUpModel, + TopDownCenteredInstanceMultiHeadLightningModule, + SingleInstanceMultiHeadLightningModule, + CentroidMultiHeadLightningModule, + BottomUpMultiHeadLightningModule, ) from sleap_nn.inference.single_instance import SingleInstanceInferenceModel from sleap_nn.inference.bottomup import BottomUpInferenceModel @@ -96,6 +100,7 @@ def from_model_paths( return_confmaps: bool = False, device: str = "cpu", preprocess_config: Optional[OmegaConf] = None, + output_head_skeleton_num: int = 1, ) -> "Predictor": """Create the appropriate `Predictor` subclass from from the ckpt path. @@ -126,6 +131,9 @@ def from_model_paths( Default: "cpu" preprocess_config: (OmegaConf) OmegaConf object with keys as the parameters in the `data_config.preprocessing` section. + output_head_skeleton_num: Dataset number (as given in the config) indicating + which skeleton format to output. This parameter is only required for + multi-head model inference. Returns: A subclass of `Predictor`. @@ -156,6 +164,7 @@ def from_model_paths( return_confmaps=return_confmaps, device=device, preprocess_config=preprocess_config, + output_head_skeleton_num=output_head_skeleton_num, ) elif "centroid" in model_names or "centered_instance" in model_names: @@ -180,6 +189,7 @@ def from_model_paths( return_confmaps=return_confmaps, device=device, preprocess_config=preprocess_config, + output_head_skeleton_num=output_head_skeleton_num, ) elif "bottomup" in model_names: @@ -196,6 +206,7 @@ def from_model_paths( return_confmaps=return_confmaps, device=device, preprocess_config=preprocess_config, + output_head_skeleton_num=output_head_skeleton_num, ) else: message = f"Could not create predictor from model paths:\n{model_paths}" @@ -417,6 +428,10 @@ class TopDownPredictor(Predictor): if this is `None`. anchor_ind: (int) The index of the node to use as the anchor for the centroid. If not provided, the anchor idx in the `training_config.yaml` is used instead. + is_multi_head_model: True if inference should be performed on a multi-head model. + output_head_skeleton_num: Dataset number (as given in the config) indicating + which skeleton format to output. This parameter is only required for + multi-head model inference. """ @@ -438,6 +453,8 @@ class TopDownPredictor(Predictor): preprocess_config: Optional[OmegaConf] = None tracker: Optional[Tracker] = None anchor_ind: Optional[int] = None + is_multi_head_model: bool = False + output_head_skeleton_num: int = 0 def _initialize_inference_model(self): """Initialize the inference model from the trained models and configuration.""" @@ -457,15 +474,29 @@ def _initialize_inference_model(self): self.data_config.crop_hw = ( self.confmap_config.data_config.preprocessing.crop_hw ) + if self.is_multi_head_model: + self.data_config.crop_hw = ( + self.confmap_config.data_config.preprocessing.crop_hw[ + self.output_head_skeleton_num + ] + ) if self.centroid_config is None: + if self.is_multi_head_model: + anchor_part = self.confmap_config.model_config.head_configs.centered_instance.confmaps[ + self.output_head_skeleton_num + ][ + "anchor_part" + ] + else: + anchor_part = ( + self.confmap_config.model_config.head_configs.centered_instance.confmaps.anchor_part + ) centroid_crop_layer = CentroidCrop( use_gt_centroids=True, crop_hw=self.data_config.crop_hw, anchor_ind=( - self.anchor_ind - if self.anchor_ind is not None - else self.confmap_config.model_config.head_configs.centered_instance.confmaps.anchor_part + self.anchor_ind if self.anchor_ind is not None else anchor_part ), return_crops=return_crops, ) @@ -475,17 +506,31 @@ def _initialize_inference_model(self): f"{self.centroid_backbone_type}" ]["max_stride"] # initialize centroid crop layer + if self.is_multi_head_model: + output_stride = ( + self.centroid_config.model_config.head_configs.centroid.confmaps[ + self.output_head_skeleton_num + ]["output_stride"] + ) + scale = self.centroid_config.data_config.preprocessing.scale[ + self.output_head_skeleton_num + ] + else: + output_stride = ( + self.centroid_config.model_config.head_configs.centroid.confmaps.output_stride + ) + scale = self.centroid_config.data_config.preprocessing.scale centroid_crop_layer = CentroidCrop( torch_model=self.centroid_model, peak_threshold=centroid_peak_threshold, - output_stride=self.centroid_config.model_config.head_configs.centroid.confmaps.output_stride, + output_stride=output_stride, refinement=self.integral_refinement, integral_patch_size=self.integral_patch_size, return_confmaps=self.return_confmaps, return_crops=return_crops, max_instances=self.max_instances, max_stride=max_stride, - input_scale=self.centroid_config.data_config.preprocessing.scale, + input_scale=scale, crop_hw=self.data_config.crop_hw, use_gt_centroids=False, ) @@ -499,19 +544,33 @@ def _initialize_inference_model(self): max_stride = self.confmap_config.model_config.backbone_config[ f"{self.centered_instance_backbone_type}" ]["max_stride"] + + if self.is_multi_head_model: + output_stride = self.confmap_config.model_config.head_configs.centered_instance.confmaps[ + self.output_head_skeleton_num + ][ + "output_stride" + ] + scale = self.confmap_config.data_config.preprocessing.scale[ + self.output_head_skeleton_num + ] + else: + output_stride = ( + self.confmap_config.model_config.head_configs.centered_instance.confmaps.output_stride + ) + scale = self.confmap_config.data_config.preprocessing.scale + instance_peaks_layer = FindInstancePeaks( torch_model=self.confmap_model, peak_threshold=centered_instance_peak_threshold, - output_stride=self.confmap_config.model_config.head_configs.centered_instance.confmaps.output_stride, + output_stride=output_stride, refinement=self.integral_refinement, integral_patch_size=self.integral_patch_size, return_confmaps=self.return_confmaps, max_stride=max_stride, - input_scale=self.confmap_config.data_config.preprocessing.scale, - ) - centroid_crop_layer.precrop_resize = ( - self.confmap_config.data_config.preprocessing.scale + input_scale=scale, ) + centroid_crop_layer.precrop_resize = scale if self.centroid_config is None and self.confmap_config is not None: self.instances_key = ( @@ -520,7 +579,9 @@ def _initialize_inference_model(self): # Initialize the inference model with centroid and instance peak layers self.inference_model = TopDownInferenceModel( - centroid_crop=centroid_crop_layer, instance_peaks=instance_peaks_layer + centroid_crop=centroid_crop_layer, + instance_peaks=instance_peaks_layer, + output_head_skeleton_num=self.output_head_skeleton_num, ) @property @@ -549,6 +610,7 @@ def from_trained_models( return_confmaps: bool = False, device: str = "cpu", preprocess_config: Optional[OmegaConf] = None, + output_head_skeleton_num: int = 1, ) -> "TopDownPredictor": """Create predictor from saved models. @@ -576,6 +638,9 @@ def from_trained_models( Default: "cpu" preprocess_config: (OmegaConf) OmegaConf object with keys as the parameters in the `data_config.preprocessing` section and the `anchor_ind`. + output_head_skeleton_num: Dataset number (as given in the config) indicating + which skeleton format to output. This parameter is only required for + multi-head model inference. Returns: An instance of `TopDownPredictor` with the loaded models. @@ -591,7 +656,10 @@ def from_trained_models( centroid_config = OmegaConf.load( f"{centroid_ckpt_path}/training_config.yaml" ) - skeletons = get_skeleton_from_config(centroid_config.data_config.skeletons) + is_multi_head_model = False + if "dataset_mapper" in centroid_config: + is_multi_head_model = True + ckpt_path = f"{centroid_ckpt_path}/best.ckpt" # check which backbone architecture @@ -600,13 +668,31 @@ def from_trained_models( centroid_backbone_type = k break - centroid_model = CentroidModel.load_from_checkpoint( - checkpoint_path=ckpt_path, - config=centroid_config, - skeletons=skeletons, - model_type="centroid", - backbone_type=centroid_backbone_type, - ) + if is_multi_head_model: + skeletons_dict = {} + for k in centroid_config.data_config.skeletons: + skeletons_dict[k] = get_skeleton_from_config( + centroid_config.data_config.skeletons[k] + ) + centroid_model = CentroidMultiHeadModel.load_from_checkpoint( + checkpoint_path=ckpt_path, + config=centroid_config, + skeletons_dict=skeletons_dict, + model_type="centroid", + backbone_type=centroid_backbone_type, + ) + skeletons = skeletons_dict[output_head_skeleton_num] + else: + skeletons = get_skeleton_from_config( + centroid_config.data_config.skeletons + ) + centroid_model = CentroidModel.load_from_checkpoint( + checkpoint_path=ckpt_path, + config=centroid_config, + skeletons=skeletons, + model_type="centroid", + backbone_type=centroid_backbone_type, + ) if backbone_ckpt_path is not None and head_ckpt_path is not None: logger.info(f"Loading backbone weights from `{backbone_ckpt_path}` ...") @@ -642,21 +728,45 @@ def from_trained_models( if confmap_ckpt_path is not None: # Load confmap model. confmap_config = OmegaConf.load(f"{confmap_ckpt_path}/training_config.yaml") - skeletons = get_skeleton_from_config(confmap_config.data_config.skeletons) ckpt_path = f"{confmap_ckpt_path}/best.ckpt" + is_multi_head_model = False + if "dataset_mapper" in confmap_config: + is_multi_head_model = True # check which backbone architecture for k, v in confmap_config.model_config.backbone_config.items(): if v is not None: centered_instance_backbone_type = k break - confmap_model = TopDownCenteredInstanceModel.load_from_checkpoint( - checkpoint_path=ckpt_path, - config=confmap_config, - skeletons=skeletons, - model_type="centered_instance", - backbone_type=centered_instance_backbone_type, - ) + + if is_multi_head_model: + skeletons_dict = {} + for k in confmap_config.data_config.skeletons: + skeletons_dict[k] = get_skeleton_from_config( + confmap_config.data_config.skeletons[k] + ) + confmap_model = ( + TopDownCenteredInstanceMultiHeadModel.load_from_checkpoint( + checkpoint_path=ckpt_path, + config=confmap_config, + skeletons_dict=skeletons_dict, + model_type="centered_instance", + backbone_type=centered_instance_backbone_type, + ) + ) + skeletons = skeletons_dict[output_head_skeleton_num] + else: + skeletons = get_skeleton_from_config( + confmap_config.data_config.skeletons + ) + confmap_model = TopDownCenteredInstanceModel.load_from_checkpoint( + checkpoint_path=ckpt_path, + config=confmap_config, + skeletons=skeletons, + model_type="centered_instance", + backbone_type=centered_instance_backbone_type, + ) + if backbone_ckpt_path is not None and head_ckpt_path is not None: logger.info(f"Loading backbone weights from `{backbone_ckpt_path}` ...") ckpt = torch.load(backbone_ckpt_path) @@ -706,6 +816,8 @@ def from_trained_models( device=device, preprocess_config=preprocess_config, anchor_ind=preprocess_config["anchor_ind"], + output_head_skeleton_num=output_head_skeleton_num, + is_multi_head_model=is_multi_head_model, ) obj._initialize_inference_model() @@ -741,6 +853,16 @@ def make_pipeline( scale = self.centroid_config.data_config.preprocessing.scale max_height = self.centroid_config.data_config.preprocessing.max_height max_width = self.centroid_config.data_config.preprocessing.max_width + if self.is_multi_head_model: + scale = self.centroid_config.data_config.preprocessing.scale[ + self.output_head_skeleton_num + ] + max_height = self.centroid_config.data_config.preprocessing.max_height[ + self.output_head_skeleton_num + ] + max_width = self.centroid_config.data_config.preprocessing.max_width[ + self.output_head_skeleton_num + ] else: max_stride = self.confmap_config.model_config.backbone_config[ f"{self.centered_instance_backbone_type}" @@ -748,6 +870,16 @@ def make_pipeline( scale = self.confmap_config.data_config.preprocessing.scale max_height = self.confmap_config.data_config.preprocessing.max_height max_width = self.confmap_config.data_config.preprocessing.max_width + if self.is_multi_head_model: + scale = self.confmap_config.data_config.preprocessing.scale[ + self.output_head_skeleton_num + ] + max_height = self.confmap_config.data_config.preprocessing.max_height[ + self.output_head_skeleton_num + ] + max_width = self.confmap_config.data_config.preprocessing.max_width[ + self.output_head_skeleton_num + ] # LabelsReader provider if self.provider == "LabelsReader": @@ -792,7 +924,7 @@ def make_pipeline( self.preprocess = False self.preprocess_config = { "batch_size": self.batch_size, - "scale": self.centroid_config.data_config.preprocessing.scale, + "scale": scale, "is_rgb": self.data_config.is_rgb, "max_stride": ( self.centroid_config.model_config.backbone_config[ @@ -934,6 +1066,10 @@ class SingleInstancePredictor(Predictor): Default: "cpu" preprocess_config: (OmegaConf) OmegaConf object with keys as the parameters in the `data_config.preprocessing` section. + is_multi_head_model: True if inference should be performed on a multi-head model. + output_head_skeleton_num: Dataset number (as given in the config) indicating + which skeleton format to output. This parameter is only required for + multi-head model inference. """ @@ -949,17 +1085,36 @@ class SingleInstancePredictor(Predictor): return_confmaps: bool = False device: str = "cpu" preprocess_config: Optional[OmegaConf] = None + is_multi_head_model: bool = False + output_head_skeleton_num: int = 0 def _initialize_inference_model(self): """Initialize the inference model from the trained models and configuration.""" + if self.is_multi_head_model: + output_stride = ( + self.confmap_config.model_config.head_configs.single_instance.confmaps[ + self.output_head_skeleton_num + ]["output_stride"] + ) + scale = self.confmap_config.data_config.preprocessing.scale[ + self.output_head_skeleton_num + ] + else: + output_stride = ( + self.confmap_config.model_config.head_configs.single_instance.confmaps[ + "output_stride" + ] + ) + scale = self.confmap_config.data_config.preprocessing.scale self.inference_model = SingleInstanceInferenceModel( torch_model=self.confmap_model, peak_threshold=self.peak_threshold, - output_stride=self.confmap_config.model_config.head_configs.single_instance.confmaps.output_stride, + output_stride=output_stride, refinement=self.integral_refinement, integral_patch_size=self.integral_patch_size, return_confmaps=self.return_confmaps, - input_scale=self.confmap_config.data_config.preprocessing.scale, + input_scale=scale, + output_head_skeleton_num=self.output_head_skeleton_num, ) @property @@ -983,6 +1138,7 @@ def from_trained_models( return_confmaps: bool = False, device: str = "cpu", preprocess_config: Optional[OmegaConf] = None, + output_head_skeleton_num: int = 1, ) -> "SingleInstancePredictor": """Create predictor from saved models. @@ -1008,14 +1164,18 @@ def from_trained_models( Default: "cpu" preprocess_config: (OmegaConf) OmegaConf object with keys as the parameters in the `data_config.preprocessing` section. + output_head_skeleton_num: Dataset number (as given in the config) indicating + which skeleton format to output. This parameter is only required for + multi-head model inference. Returns: An instance of `SingleInstancePredictor` with the loaded models. """ confmap_config = OmegaConf.load(f"{confmap_ckpt_path}/training_config.yaml") - skeletons = get_skeleton_from_config(confmap_config.data_config.skeletons) - ckpt_path = f"{confmap_ckpt_path}/best.ckpt" + is_multi_head_model = False + if "dataset_mapper" in confmap_config: + is_multi_head_model = True # check which backbone architecture for k, v in confmap_config.model_config.backbone_config.items(): @@ -1023,13 +1183,31 @@ def from_trained_models( backbone_type = k break - confmap_model = SingleInstanceModel.load_from_checkpoint( - checkpoint_path=ckpt_path, - config=confmap_config, - skeletons=skeletons, - model_type="single_instance", - backbone_type=backbone_type, - ) + ckpt_path = f"{confmap_ckpt_path}/best.ckpt" + + if is_multi_head_model: + skeletons_dict = {} + for k in confmap_config.data_config.skeletons: + skeletons_dict[k] = get_skeleton_from_config( + confmap_config.data_config.skeletons[k] + ) + confmap_model = SingleInstanceMultiHeadModel.load_from_checkpoint( + checkpoint_path=ckpt_path, + config=confmap_config, + skeletons_dict=skeletons_dict, + model_type="single_instance", + backbone_type=backbone_type, + ) + skeletons = skeletons_dict[output_head_skeleton_num] + else: + skeletons = get_skeleton_from_config(confmap_config.data_config.skeletons) + confmap_model = SingleInstanceModel.load_from_checkpoint( + checkpoint_path=ckpt_path, + config=confmap_config, + skeletons=skeletons, + model_type="single_instance", + backbone_type=backbone_type, + ) if backbone_ckpt_path is not None and head_ckpt_path is not None: logger.info(f"Loading backbone weights from `{backbone_ckpt_path}` ...") ckpt = torch.load(backbone_ckpt_path) @@ -1069,6 +1247,8 @@ def from_trained_models( return_confmaps=return_confmaps, device=device, preprocess_config=preprocess_config, + output_head_skeleton_num=output_head_skeleton_num, + is_multi_head_model=is_multi_head_model, ) obj._initialize_inference_model() @@ -1098,6 +1278,19 @@ def make_pipeline( """ self.provider = provider + scale = self.confmap_config.data_config.preprocessing.scale + max_height = self.confmap_config.data_config.preprocessing.max_height + max_height = ( + max_height[self.output_head_skeleton_num] + if self.is_multi_head_model + else max_height + ) + max_width = self.confmap_config.data_config.preprocessing.max_width + max_width = ( + max_width[self.output_head_skeleton_num] + if self.is_multi_head_model + else max_width + ) # LabelsReader provider if self.provider == "LabelsReader": @@ -1110,18 +1303,22 @@ def make_pipeline( self.preprocess = False self.preprocess_config = { "batch_size": self.batch_size, - "scale": self.confmap_config.data_config.preprocessing.scale, + "scale": ( + scale[self.output_head_skeleton_num] + if self.is_multi_head_model + else scale + ), "is_rgb": self.data_config.is_rgb, "max_stride": max_stride, "max_height": ( self.data_config.max_height if self.data_config.max_height is not None - else self.confmap_config.data_config.preprocessing.max_height + else max_height ), "max_width": ( self.data_config.max_width if self.data_config.max_width is not None - else self.confmap_config.data_config.preprocessing.max_width + else max_width ), } @@ -1136,7 +1333,11 @@ def make_pipeline( self.preprocess = True self.preprocess_config = { "batch_size": self.batch_size, - "scale": self.confmap_config.data_config.preprocessing.scale, + "scale": ( + scale[self.output_head_skeleton_num] + if self.is_multi_head_model + else scale + ), "is_rgb": self.data_config.is_rgb, "max_stride": ( self.confmap_config.model_config.backbone_config[ @@ -1146,12 +1347,12 @@ def make_pipeline( "max_height": ( self.data_config.max_height if self.data_config.max_height is not None - else self.confmap_config.data_config.preprocessing.max_height + else max_height ), "max_width": ( self.data_config.max_width if self.data_config.max_width is not None - else self.confmap_config.data_config.preprocessing.max_width + else max_width ), } @@ -1279,6 +1480,10 @@ class BottomUpPredictor(Predictor): tracker: A `sleap.nn.tracking.Tracker` that will be called to associate detections over time. Predicted instances will not be assigned to tracks if if this is `None`. + is_multi_head_model: True if inference should be performed on a multi-head model. + output_head_skeleton_num: Dataset number (as given in the config) indicating + which skeleton format to output. This parameter is only required for + multi-head model inference. """ @@ -1301,40 +1506,88 @@ class BottomUpPredictor(Predictor): device: str = "cpu" preprocess_config: Optional[OmegaConf] = None tracker: Optional[Tracker] = None + is_multi_head_model: bool = False + output_head_skeleton_num: int = 0 def _initialize_inference_model(self): """Initialize the inference model from the trained models and configuration.""" # initialize the paf scorer - paf_scorer = PAFScorer.from_config( - config=OmegaConf.create( - { - "confmaps": self.bottomup_config.model_config.head_configs.bottomup[ - "confmaps" - ], - "pafs": self.bottomup_config.model_config.head_configs.bottomup[ - "pafs" - ], - } - ), - max_edge_length_ratio=self.max_edge_length_ratio, - dist_penalty_weight=self.dist_penalty_weight, - n_points=self.n_points, - min_instance_peaks=self.min_instance_peaks, - min_line_scores=self.min_line_scores, - ) + if self.is_multi_head_model: + paf_scorer = PAFScorer.from_config( + config=OmegaConf.create( + { + "confmaps": self.bottomup_config.model_config.head_configs.bottomup[ + "confmaps" + ][ + self.output_head_skeleton_num + ], + "pafs": self.bottomup_config.model_config.head_configs.bottomup[ + "pafs" + ][self.output_head_skeleton_num], + } + ), + max_edge_length_ratio=self.max_edge_length_ratio, + dist_penalty_weight=self.dist_penalty_weight, + n_points=self.n_points, + min_instance_peaks=self.min_instance_peaks, + min_line_scores=self.min_line_scores, + ) - # initialize the BottomUpInferenceModel - self.inference_model = BottomUpInferenceModel( - torch_model=self.bottomup_model, - paf_scorer=paf_scorer, - peak_threshold=self.peak_threshold, - cms_output_stride=self.bottomup_config.model_config.head_configs.bottomup.confmaps.output_stride, - pafs_output_stride=self.bottomup_config.model_config.head_configs.bottomup.pafs.output_stride, - refinement=self.integral_refinement, - integral_patch_size=self.integral_patch_size, - return_confmaps=self.return_confmaps, - input_scale=self.bottomup_config.data_config.preprocessing.scale, - ) + # initialize the BottomUpInferenceModel + self.inference_model = BottomUpInferenceModel( + torch_model=self.bottomup_model, + paf_scorer=paf_scorer, + peak_threshold=self.peak_threshold, + cms_output_stride=self.bottomup_config.model_config.head_configs.bottomup.confmaps[ + self.output_head_skeleton_num + ][ + "output_stride" + ], + pafs_output_stride=self.bottomup_config.model_config.head_configs.bottomup.pafs[ + self.output_head_skeleton_num + ][ + "output_stride" + ], + refinement=self.integral_refinement, + integral_patch_size=self.integral_patch_size, + return_confmaps=self.return_confmaps, + input_scale=self.bottomup_config.data_config.preprocessing.scale[ + self.output_head_skeleton_num + ], + output_head_skeleton_num=self.output_head_skeleton_num, + ) + else: + paf_scorer = PAFScorer.from_config( + config=OmegaConf.create( + { + "confmaps": self.bottomup_config.model_config.head_configs.bottomup[ + "confmaps" + ], + "pafs": self.bottomup_config.model_config.head_configs.bottomup[ + "pafs" + ], + } + ), + max_edge_length_ratio=self.max_edge_length_ratio, + dist_penalty_weight=self.dist_penalty_weight, + n_points=self.n_points, + min_instance_peaks=self.min_instance_peaks, + min_line_scores=self.min_line_scores, + ) + + # initialize the BottomUpInferenceModel + self.inference_model = BottomUpInferenceModel( + torch_model=self.bottomup_model, + paf_scorer=paf_scorer, + peak_threshold=self.peak_threshold, + cms_output_stride=self.bottomup_config.model_config.head_configs.bottomup.confmaps.output_stride, + pafs_output_stride=self.bottomup_config.model_config.head_configs.bottomup.pafs.output_stride, + refinement=self.integral_refinement, + integral_patch_size=self.integral_patch_size, + return_confmaps=self.return_confmaps, + input_scale=self.bottomup_config.data_config.preprocessing.scale, + output_head_skeleton_num=self.output_head_skeleton_num, + ) @property def data_config(self) -> OmegaConf: @@ -1358,6 +1611,7 @@ def from_trained_models( return_confmaps: bool = False, device: str = "cpu", preprocess_config: Optional[OmegaConf] = None, + output_head_skeleton_num: int = 1, ) -> "BottomUpPredictor": """Create predictor from saved models. @@ -1384,13 +1638,19 @@ def from_trained_models( Default: "cpu" preprocess_config: (OmegaConf) OmegaConf object with keys as the parameters in the `data_config.preprocessing` section. + output_head_skeleton_num: Dataset number (as given in the config) indicating + which skeleton format to output. This parameter is only required for + multi-head model inference. Returns: An instance of `BottomUpPredictor` with the loaded models. """ bottomup_config = OmegaConf.load(f"{bottomup_ckpt_path}/training_config.yaml") - skeletons = get_skeleton_from_config(bottomup_config.data_config.skeletons) + is_multi_head_model = False + if "dataset_mapper" in bottomup_config: + is_multi_head_model = True + ckpt_path = f"{bottomup_ckpt_path}/best.ckpt" # check which backbone architecture @@ -1399,13 +1659,32 @@ def from_trained_models( backbone_type = k break - bottomup_model = BottomUpModel.load_from_checkpoint( - checkpoint_path=ckpt_path, - config=bottomup_config, - skeletons=skeletons, - backbone_type=backbone_type, - model_type="bottomup", - ) + if is_multi_head_model: + skeletons_dict = {} + for k in bottomup_config.data_config.skeletons: + skeletons_dict[k] = get_skeleton_from_config( + bottomup_config.data_config.skeletons[k] + ) + + bottomup_model = BottomUpMultiHeadModel.load_from_checkpoint( + checkpoint_path=ckpt_path, + config=bottomup_config, + skeletons_dict=skeletons_dict, + backbone_type=backbone_type, + model_type="bottomup", + ) + skeletons = skeletons_dict[output_head_skeleton_num] + + else: + skeletons = get_skeleton_from_config(bottomup_config.data_config.skeletons) + bottomup_model = BottomUpModel.load_from_checkpoint( + checkpoint_path=ckpt_path, + config=bottomup_config, + skeletons=skeletons, + backbone_type=backbone_type, + model_type="bottomup", + ) + if backbone_ckpt_path is not None and head_ckpt_path is not None: logger.info(f"Loading backbone weights from `{backbone_ckpt_path}` ...") ckpt = torch.load(backbone_ckpt_path) @@ -1445,6 +1724,8 @@ def from_trained_models( max_instances=max_instances, return_confmaps=return_confmaps, preprocess_config=preprocess_config, + output_head_skeleton_num=output_head_skeleton_num, + is_multi_head_model=is_multi_head_model, ) obj._initialize_inference_model() @@ -1474,6 +1755,19 @@ def make_pipeline( """ self.provider = provider # LabelsReader provider + scale = self.bottomup_config.data_config.preprocessing.scale + max_height = self.bottomup_config.data_config.preprocessing.max_height + max_height = ( + max_height[self.output_head_skeleton_num] + if self.is_multi_head_model + else max_height + ) + max_width = self.bottomup_config.data_config.preprocessing.max_width + max_width = ( + max_width[self.output_head_skeleton_num] + if self.is_multi_head_model + else max_width + ) if self.provider == "LabelsReader": provider = LabelsReader @@ -1484,18 +1778,22 @@ def make_pipeline( self.preprocess = False self.preprocess_config = { "batch_size": self.batch_size, - "scale": self.bottomup_config.data_config.preprocessing.scale, + "scale": ( + scale[self.output_head_skeleton_num] + if self.is_multi_head_model + else scale + ), "is_rgb": self.data_config.is_rgb, "max_stride": max_stride, "max_height": ( self.data_config.max_height if self.data_config.max_height is not None - else self.bottomup_config.data_config.preprocessing.max_height + else max_height ), "max_width": ( self.data_config.max_width if self.data_config.max_width is not None - else self.bottomup_config.data_config.preprocessing.max_width + else max_width ), } @@ -1510,7 +1808,11 @@ def make_pipeline( self.preprocess = True self.preprocess_config = { "batch_size": self.batch_size, - "scale": self.bottomup_config.data_config.preprocessing.scale, + "scale": ( + scale[self.output_head_skeleton_num] + if self.is_multi_head_model + else scale + ), "is_rgb": self.data_config.is_rgb, "max_stride": ( self.bottomup_config.model_config.backbone_config[ @@ -1520,12 +1822,12 @@ def make_pipeline( "max_height": ( self.data_config.max_height if self.data_config.max_height is not None - else self.bottomup_config.data_config.preprocessing.max_height + else max_height ), "max_width": ( self.data_config.max_width if self.data_config.max_width is not None - else self.bottomup_config.data_config.preprocessing.max_width + else max_width ), } @@ -1677,6 +1979,7 @@ def main( of_img_scale: float = 1.0, of_window_size: int = 21, of_max_levels: int = 3, + output_head_skeleton_num: int = 1, ): """Entry point to run inference on trained SLEAP-NN models. @@ -1783,6 +2086,9 @@ def main( of_max_levels: Number of pyramid scale levels to consider. This is different from the scale parameter, which determines the initial image scaling. Default: 3. (only if `use_flow` is True) + output_head_skeleton_num: Dataset number (as given in the config) indicating + which skeleton format to output. This parameter is only required for + multi-head model inference. Returns: Returns `sio.Labels` object if `make_labels` is True. Else this function returns @@ -1821,6 +2127,7 @@ def main( return_confmaps=return_confmaps, device=device, preprocess_config=OmegaConf.create(preprocess_config), + output_head_skeleton_num=output_head_skeleton_num, ) if tracking: diff --git a/sleap_nn/inference/single_instance.py b/sleap_nn/inference/single_instance.py index b20b9ed3..e50da32c 100644 --- a/sleap_nn/inference/single_instance.py +++ b/sleap_nn/inference/single_instance.py @@ -32,6 +32,9 @@ class SingleInstanceInferenceModel(L.LightningModule): the predicted peaks. input_scale: Float indicating if the images should be resized before being passed to the model. + output_head_skeleton_num: Dataset number (as given in the config) indicating + which skeleton format to output. This parameter is only required for + multi-head model inference. """ def __init__( @@ -43,6 +46,7 @@ def __init__( integral_patch_size: int = 5, return_confmaps: Optional[bool] = False, input_scale: float = 1.0, + output_head_skeleton_num: int = 0, ): """Initialise the model attributes.""" super().__init__() @@ -53,6 +57,7 @@ def __init__( self.output_stride = output_stride self.return_confmaps = return_confmaps self.input_scale = input_scale + self.output_head_skeleton_num = output_head_skeleton_num def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """Predict confidence maps and infer peak coordinates. @@ -72,6 +77,10 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ # Network forward pass. cms = self.torch_model(inputs["image"]) + if isinstance( + cms, list + ): # for multi-head-model, output is a list of confmaps for the different heads + cms = cms[self.output_head_skeleton_num] peak_points, peak_vals = find_global_peaks( cms.detach(), diff --git a/sleap_nn/inference/topdown.py b/sleap_nn/inference/topdown.py index 9464c9a6..1736ea02 100644 --- a/sleap_nn/inference/topdown.py +++ b/sleap_nn/inference/topdown.py @@ -220,6 +220,8 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: scaled_image = apply_pad_to_stride(scaled_image, self.max_stride) cms = self.torch_model(scaled_image) + if isinstance(cms, list): # only one head for centroid model + cms = cms[0] refined_peaks, peak_vals, peak_sample_inds, _ = find_local_peaks( cms.detach(), @@ -490,8 +492,7 @@ def __init__( self.max_stride = max_stride def forward( - self, - inputs: Dict[str, torch.Tensor], + self, inputs: Dict[str, torch.Tensor], output_head_skeleton_num: int = 0 ) -> Dict[str, torch.Tensor]: """Predict confidence maps and infer peak coordinates. @@ -502,6 +503,9 @@ def forward( inputs: Dictionary with keys: `"instance_image"`: Cropped images. Other keys will be passed down the pipeline. + output_head_skeleton_num: Dataset number (as given in the config) indicating + which skeleton format to output. This parameter is only required for + multi-head model inference. Returns: A dictionary of outputs with keys: @@ -524,6 +528,8 @@ def forward( input_image = apply_pad_to_stride(input_image, self.max_stride) cms = self.torch_model(input_image) + if isinstance(cms, list): + cms = cms[output_head_skeleton_num] peak_points, peak_vals = find_global_peaks( cms.detach(), @@ -576,17 +582,22 @@ class TopDownInferenceModel(L.LightningModule): or `None`. This layer takes as input the output of the centroid cropper (if CentroidCrop not None else the image is cropped with the InstanceCropper module) and outputs the detected peaks for the instances within each crop. + output_head_skeleton_num: Dataset number (as given in the config) indicating + which skeleton format to output. This parameter is only required for + multi-head model inference. """ def __init__( self, centroid_crop: Union[CentroidCrop, None], instance_peaks: Union[FindInstancePeaks, FindInstancePeaksGroundTruth], + output_head_skeleton_num: int = 0, ): """Initialize the class with Inference models.""" super().__init__() self.centroid_crop = centroid_crop self.instance_peaks = instance_peaks + self.output_head_skeleton_num = output_head_skeleton_num def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """Predict instances for one batch of images. @@ -629,7 +640,7 @@ def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: self.instance_peaks.eval() peaks_output.append( self.instance_peaks( - i, + i, output_head_skeleton_num=self.output_head_skeleton_num ) ) return peaks_output diff --git a/sleap_nn/training/lightning_modules.py b/sleap_nn/training/lightning_modules.py index ebcb684e..c5252e44 100644 --- a/sleap_nn/training/lightning_modules.py +++ b/sleap_nn/training/lightning_modules.py @@ -584,3 +584,23 @@ def validation_step(self, batch, batch_idx): on_epoch=True, logger=True, ) + + +class MultiHeadLightningModule(L.LightningModule): + pass + + +class TopDownCenteredInstanceMultiHeadLightningModule(MultiHeadLightningModule): + pass + + +class SingleInstanceMultiHeadLightningModule(MultiHeadLightningModule): + pass + + +class CentroidMultiHeadLightningModule(MultiHeadLightningModule): + pass + + +class BottomUpMultiHeadLightningModule(MultiHeadLightningModule): + pass From 7581075ae4da57fc1f5ccbb8b57328ba6286a248 Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Wed, 23 Apr 2025 17:29:56 -0700 Subject: [PATCH 3/8] Add lightning modules --- sleap_nn/inference/predictors.py | 22 +- sleap_nn/training/lightning_modules.py | 1112 +++++++++++++++++++++++- sleap_nn/training/utils.py | 62 ++ 3 files changed, 1175 insertions(+), 21 deletions(-) diff --git a/sleap_nn/inference/predictors.py b/sleap_nn/inference/predictors.py index de27b499..8c7a851b 100644 --- a/sleap_nn/inference/predictors.py +++ b/sleap_nn/inference/predictors.py @@ -674,10 +674,9 @@ def from_trained_models( skeletons_dict[k] = get_skeleton_from_config( centroid_config.data_config.skeletons[k] ) - centroid_model = CentroidMultiHeadModel.load_from_checkpoint( + centroid_model = CentroidMultiHeadLightningModule.load_from_checkpoint( checkpoint_path=ckpt_path, config=centroid_config, - skeletons_dict=skeletons_dict, model_type="centroid", backbone_type=centroid_backbone_type, ) @@ -745,14 +744,11 @@ def from_trained_models( skeletons_dict[k] = get_skeleton_from_config( confmap_config.data_config.skeletons[k] ) - confmap_model = ( - TopDownCenteredInstanceMultiHeadModel.load_from_checkpoint( - checkpoint_path=ckpt_path, - config=confmap_config, - skeletons_dict=skeletons_dict, - model_type="centered_instance", - backbone_type=centered_instance_backbone_type, - ) + confmap_model = TopDownCenteredInstanceMultiHeadLightningModule.load_from_checkpoint( + checkpoint_path=ckpt_path, + config=confmap_config, + model_type="centered_instance", + backbone_type=centered_instance_backbone_type, ) skeletons = skeletons_dict[output_head_skeleton_num] else: @@ -1191,10 +1187,9 @@ def from_trained_models( skeletons_dict[k] = get_skeleton_from_config( confmap_config.data_config.skeletons[k] ) - confmap_model = SingleInstanceMultiHeadModel.load_from_checkpoint( + confmap_model = SingleInstanceMultiHeadLightningModule.load_from_checkpoint( checkpoint_path=ckpt_path, config=confmap_config, - skeletons_dict=skeletons_dict, model_type="single_instance", backbone_type=backbone_type, ) @@ -1666,10 +1661,9 @@ def from_trained_models( bottomup_config.data_config.skeletons[k] ) - bottomup_model = BottomUpMultiHeadModel.load_from_checkpoint( + bottomup_model = BottomUpMultiHeadLightningModule.load_from_checkpoint( checkpoint_path=ckpt_path, config=bottomup_config, - skeletons_dict=skeletons_dict, backbone_type=backbone_type, model_type="bottomup", ) diff --git a/sleap_nn/training/lightning_modules.py b/sleap_nn/training/lightning_modules.py index c5252e44..0bedeacf 100644 --- a/sleap_nn/training/lightning_modules.py +++ b/sleap_nn/training/lightning_modules.py @@ -30,9 +30,9 @@ from sleap_nn.inference.single_instance import SingleInstanceInferenceModel from sleap_nn.inference.bottomup import BottomUpInferenceModel from sleap_nn.inference.paf_grouping import PAFScorer -from sleap_nn.architectures.model import Model +from sleap_nn.architectures.model import Model, MultiHeadModel from loguru import logger -from sleap_nn.training.utils import xavier_init_weights +from sleap_nn.training.utils import xavier_init_weights, plot_pred_confmaps_peaks import matplotlib.pyplot as plt MODEL_WEIGHTS = { @@ -587,20 +587,1118 @@ def validation_step(self, batch, batch_idx): class MultiHeadLightningModule(L.LightningModule): - pass + """Base PyTorch Lightning Module for multi-head models. + + This class is a sub-class of Torch Lightning Module to configure the training and validation steps. + + Args: + config: OmegaConf dictionary which has the following: + (i) dataset_mapper: mapping between dataset numbers and dataset name. + (ii) data_config: data loading pre-processing configs. + (iii) model_config: backbone and head configs to be passed to `Model` class. + (iv) trainer_config: trainer configs like accelerator, optimiser params. + model_type: Type of the model. One of `single_instance`, `centered_instance`, `centroid`, `bottomup`. + backbone_type: Backbone model. One of `unet`, `convnext` and `swint`. + """ + + def __init__( + self, + config: OmegaConf, + model_type: str, + backbone_type: str, + ): + """Initialise the configs and the model.""" + super().__init__() + + self.config = config + self.model_config = self.config.model_config + self.trainer_config = self.config.trainer_config + self.data_config = self.config.data_config + self.model_type = model_type + self.save_ckpt = self.config.trainer_config.save_ckpt + self.use_wandb = self.config.trainer_config.use_wandb + if self.save_ckpt: + self.results_path = ( + Path(self.config.trainer_config.save_ckpt_path) / "visualizations" + ) + if not Path(self.results_path).exists(): + Path(self.results_path).mkdir(parents=True, exist_ok=True) + self.backbone_type = backbone_type + self.pretrained_backbone_weights = ( + self.config.model_config.pretrained_backbone_weights + ) + self.pretrained_head_weights = self.config.model_config.pretrained_head_weights + self.in_channels = self.model_config.backbone_config[f"{self.backbone_type}"][ + "in_channels" + ] + self.input_expand_channels = self.in_channels + + # only for swint and convnext + if self.model_config.pre_trained_weights: + ckpt = MODEL_WEIGHTS[ + self.model_config.pre_trained_weights + ].DEFAULT.get_state_dict(progress=True, check_hash=True) + input_channels = ckpt["features.0.0.weight"].shape[-3] + if self.in_channels != input_channels: + self.input_expand_channels = input_channels + OmegaConf.update( + self.model_config, + f"backbone_config.{self.backbone_type}.in_channels", + input_channels, + ) + + self.model = MultiHeadModel( + backbone_type=self.backbone_type, + backbone_config=self.model_config.backbone_config[f"{self.backbone_type}"], + head_configs=self.model_config.head_configs[self.model_type], + model_type=self.model_type, + ) + + if ( + len(self.model_config.head_configs[self.model_type]) > 1 + ): # TODO: online mining for each dataset + self.loss_weights = [ + ( + self.model_config.head_configs[self.model_type][x][1].loss_weight + if self.model_config.head_configs[self.model_type][x][1].loss_weight + is not None + else 1.0 + ) + for x in self.model_config.head_configs[self.model_type] + ] + + self.training_loss = {} + self.val_loss = {} + self.learning_rate = {} + + # Initialization for encoder and decoder stacks. + if self.model_config.init_weights == "xavier": + self.model.apply(xavier_init_weights) + + self.automatic_optimization = False + + self.loss_func = nn.MSELoss() + + # Pre-trained weights for the encoder stack - only for swint and convnext + if self.model_config.pre_trained_weights: + self.model.backbone.enc.load_state_dict(ckpt, strict=False) + + # TODO: Handling different input channels + # Initializing backbone (encoder + decoder) with trained ckpts + if self.pretrained_backbone_weights is not None: + logger.info( + f"Loading backbone weights from `{self.pretrained_backbone_weights}` ..." + ) + ckpt = torch.load(self.pretrained_backbone_weights) + ckpt["state_dict"] = { + k: ckpt["state_dict"][k] + for k in ckpt["state_dict"].keys() + if ".backbone" in k + } + self.load_state_dict(ckpt["state_dict"], strict=False) + + # Initializing head layers with trained ckpts. + if self.pretrained_head_weights is not None: + logger.info( + f"Loading head weights from `{self.pretrained_head_weights}` ..." + ) + ckpt = torch.load(self.pretrained_head_weights) + ckpt["state_dict"] = { + k: ckpt["state_dict"][k] + for k in ckpt["state_dict"].keys() + if ".head_layers" in k + } + self.load_state_dict(ckpt["state_dict"], strict=False) + + def forward(self, img): + """Forward pass of the model.""" + pass + + def on_save_checkpoint(self, checkpoint): + """Configure checkpoint to save parameters.""" + # save the config to the checkpoint file + checkpoint["config"] = self.config + + def on_train_epoch_start(self): + """Configure the train timer at the beginning of each epoch.""" + self.train_start_time = time.time() + + def on_train_epoch_end(self): + """Configure the train timer at the end of every epoch.""" + train_time = time.time() - self.train_start_time + self.log( + "train_time", + train_time, + prog_bar=False, + on_step=False, + on_epoch=True, + logger=True, + ) + + def on_validation_epoch_start(self): + """Configure the val timer at the beginning of each epoch.""" + self.val_start_time = time.time() + + def on_validation_epoch_end(self): + """Configure the val timer at the end of every epoch.""" + val_time = time.time() - self.val_start_time + self.log( + "val_time", + val_time, + prog_bar=False, + on_step=False, + on_epoch=True, + logger=True, + ) + + def training_step(self, batch, batch_idx): + """Training step.""" + pass + + def validation_step(self, batch, batch_idx): + """Validation step.""" + pass + + def configure_optimizers(self): + """Configure optimiser and learning rate scheduler.""" + if self.trainer_config.optimizer_name == "Adam": + optim = torch.optim.Adam + elif self.trainer_config.optimizer_name == "AdamW": + optim = torch.optim.AdamW + + optimizer = optim( + self.parameters(), + lr=self.trainer_config.optimizer.lr, + amsgrad=self.trainer_config.optimizer.amsgrad, + ) + + scheduler = None + for k, v in self.trainer_config.lr_scheduler.items(): + if v is not None: + if k == "step_lr": + scheduler = torch.optim.lr_scheduler.StepLR( + optimizer=optimizer, + step_size=self.trainer_config.lr_scheduler.step_lr.step_size, + gamma=self.trainer_config.lr_scheduler.step_lr.gamma, + ) + break + elif k == "reduce_lr_on_plateau": + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + mode="min", + threshold=self.trainer_config.lr_scheduler.reduce_lr_on_plateau.threshold, + threshold_mode=self.trainer_config.lr_scheduler.reduce_lr_on_plateau.threshold_mode, + cooldown=self.trainer_config.lr_scheduler.reduce_lr_on_plateau.cooldown, + patience=self.trainer_config.lr_scheduler.reduce_lr_on_plateau.patience, + factor=self.trainer_config.lr_scheduler.reduce_lr_on_plateau.factor, + min_lr=self.trainer_config.lr_scheduler.reduce_lr_on_plateau.min_lr, + ) + break + + if scheduler is None: + return { + "optimizer": optimizer, + } + + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "val_loss", + }, + } class TopDownCenteredInstanceMultiHeadLightningModule(MultiHeadLightningModule): - pass + """Lightning Module for TopDownCenteredInstanceMultiHeadLightningModule Model. + + This is a subclass of the `MultiHeadLightningModule` to configure the training/ validation steps + and forward pass specific to TopDown Centered instance multi-head model. + + Args: + config: OmegaConf dictionary which has the following: + (i) dataset_mapper: mapping between dataset numbers and dataset name. + (ii) data_config: data loading pre-processing configs. + (iii) model_config: backbone and head configs to be passed to `Model` class. + (iv) trainer_config: trainer configs like accelerator, optimiser params. + model_type: Type of the model. One of `single_instance`, `centered_instance`, `centroid`, `bottomup`. + backbone_type: Backbone model. One of `unet`, `convnext` and `swint`. + + """ + + def __init__( + self, + config: OmegaConf, + model_type: str, + backbone_type: str, + ): + """Initialise the configs and the model.""" + super().__init__( + config=config, + backbone_type=backbone_type, + model_type=model_type, + ) + self.inf_layer = FindInstancePeaks( + torch_model=self.forward, peak_threshold=0.2, return_confmaps=True + ) + + def on_train_epoch_start(self): + """Configure the train timer at the beginning of each epoch.""" + # add eval + if self.config.trainer_config.log_inf_epochs is not None: + if ( + self.current_epoch > 0 + and self.global_rank == 0 + and (self.current_epoch % self.config.trainer_config.log_inf_epochs) + == 0 + ): + img_array = [] + for d_num in self.config.dataset_mapper: + sample = next(iter(self.trainer.val_dataloaders[d_num])) + sample["eff_scale"] = torch.ones(sample["video_idx"].shape) + for k, v in sample.items(): + sample[k] = v.to(device=self.device) + self.inf_layer.output_stride = self.config.model_config.head_configs.centered_instance.confmaps[ + d_num + ][ + "output_stride" + ] + output = self.inf_layer(sample, output_head_skeleton_num=d_num) + batch_idx = 0 + + # plot predictions on sample image + if self.use_wandb or self.save_ckpt: + peaks = output["pred_instance_peaks"][batch_idx].cpu().numpy() + gt_instances = sample["instance"][batch_idx, 0].cpu().numpy() + img = output["instance_image"][batch_idx, 0].cpu().numpy() + confmaps = output["pred_confmaps"][batch_idx].cpu().numpy() + fig = plot_pred_confmaps_peaks( + img=img, + confmaps=confmaps, + peaks=np.expand_dims(peaks, axis=0), + gt_instances=np.expand_dims(gt_instances, axis=0), + plot_title=f"{self.config.dataset_mapper[d_num]}", + ) + + if self.save_ckpt: + curr_results_path = ( + Path(self.config.trainer_config.save_ckpt_path) + / "visualizations" + / f"epoch_{self.current_epoch}" + ) + if not Path(curr_results_path).exists(): + Path(curr_results_path).mkdir(parents=True, exist_ok=True) + fig.savefig( + (Path(curr_results_path) / f"pred_on_{d_num}").as_posix(), + bbox_inches="tight", + ) + + if self.use_wandb: + fig.canvas.draw() + img = Image.frombytes( + "RGB", + fig.canvas.get_width_height(), + fig.canvas.tostring_rgb(), + ) + + img_array.append(wandb.Image(img)) + + plt.close(fig) + + if self.use_wandb and img_array: + # wandb logging metrics in table + + wandb_table = wandb.Table( + columns=[ + "epoch", + "Predictions on test set", + ], + data=[[self.current_epoch, img_array]], + ) + wandb.log({"Performance": wandb_table}) + + self.train_start_time = time.time() + + def forward(self, img): + """Forward pass of the model.""" + img = torch.squeeze(img, dim=1).to(self.device) + return self.model(img)["CenteredInstanceConfmapsHead"] + + def training_step(self, batch, batch_idx): + """Training step.""" + loss = 0 + opt = self.optimizers() + opt.zero_grad() + for d_num in batch.keys(): + batch_data = batch[d_num] + X, y = torch.squeeze(batch_data["instance_image"], dim=1).to( + self.device + ), torch.squeeze(batch_data["confidence_maps"], dim=1) + + output = self.model(X)["CenteredInstanceConfmapsHead"] + + for h_num in batch.keys(): + if d_num != h_num: + with torch.no_grad(): + output[h_num] = output[h_num].detach() + + y_preds = output[d_num] + curr_loss = 1.0 * self.loss_func(y_preds, y) + loss += curr_loss + + self.manual_backward(curr_loss, retain_graph=True) + + self.log( + f"train_loss_on_head_{d_num}", + curr_loss, + prog_bar=True, + on_step=False, + on_epoch=True, + logger=True, + ) + + self.log( + f"train_loss", + loss, + prog_bar=True, + on_step=False, + on_epoch=True, + logger=True, + ) + + opt.step() + + return loss + + def validation_step(self, batch, batch_idx): + """Perform validation step.""" + total_loss = 0 + for d_num in batch.keys(): + X, y = torch.squeeze(batch[d_num]["instance_image"], dim=1).to( + self.device + ), torch.squeeze(batch[d_num]["confidence_maps"], dim=1) + + y_preds = self.model(X)["CenteredInstanceConfmapsHead"][d_num] + curr_loss = 1.0 * nn.MSELoss()(y_preds, y) + total_loss += curr_loss + + self.log( + f"val_loss_on_head_{d_num}", + curr_loss, + prog_bar=True, + on_step=False, + on_epoch=True, + logger=True, + ) + + self.log( + f"val_loss", + total_loss, + prog_bar=True, + on_step=False, + on_epoch=True, + logger=True, + ) + + lr = self.optimizers().optimizer.param_groups[0]["lr"] + self.log( + "learning_rate", + lr, + prog_bar=True, + on_step=False, + on_epoch=True, + logger=True, + ) class SingleInstanceMultiHeadLightningModule(MultiHeadLightningModule): - pass + """Lightning Module for SingleInstanceMultiHeadLightningModule Model. + + This is a subclass of the `MultiHeadLightningModule` to configure the training/ validation steps + and forward pass specific to single-instance multi-head model. + + Args: + config: OmegaConf dictionary which has the following: + (i) dataset_mapper: mapping between dataset numbers and dataset name. + (ii) data_config: data loading pre-processing configs. + (iii) model_config: backbone and head configs to be passed to `Model` class. + (iv) trainer_config: trainer configs like accelerator, optimiser params. + model_type: Type of the model. One of `single_instance`, `centered_instance`, `centroid`, `bottomup`. + backbone_type: Backbone model. One of `unet`, `convnext` and `swint`. + + """ + + def __init__( + self, + config: OmegaConf, + model_type: str, + backbone_type: str, + ): + """Initialise the configs and the model.""" + super().__init__( + config=config, + backbone_type=backbone_type, + model_type=model_type, + ) + self.single_instance_inf_layer = SingleInstanceInferenceModel( + torch_model=self.forward, + peak_threshold=0.2, + input_scale=1.0, + return_confmaps=True, + ) + + def forward(self, img): + """Forward pass of the model.""" + img = torch.squeeze(img, dim=1).to(self.device) + return self.model(img)["SingleInstanceConfmapsHead"] + + def on_train_epoch_start(self): + """Configure the train timer at the beginning of each epoch.""" + # add eval + if self.config.trainer_config.log_inf_epochs is not None: + if ( + self.current_epoch > 0 + and self.global_rank == 0 + and (self.current_epoch % self.config.trainer_config.log_inf_epochs) + == 0 + ): + img_array = [] + for d_num in self.config.dataset_mapper: + sample = next(iter(self.trainer.val_dataloaders[d_num])) + sample["eff_scale"] = torch.ones(sample["video_idx"].shape) + for k, v in sample.items(): + sample[k] = v.to(device=self.device) + self.single_instance_inf_layer.output_head_skeleton_num = d_num + self.single_instance_inf_layer.output_stride = ( + self.config.model_config.head_configs.single_instance.confmaps[ + d_num + ]["output_stride"] + ) + output = self.single_instance_inf_layer(sample) + batch_idx = 0 + + # plot predictions on sample image + if self.use_wandb or self.save_ckpt: + peaks = output["pred_instance_peaks"][batch_idx].cpu().numpy() + img = output["image"][batch_idx, 0].cpu().numpy() + gt_instances = sample["instances"][batch_idx, 0].cpu().numpy() + confmaps = output["pred_confmaps"][batch_idx].cpu().numpy() + fig = plot_pred_confmaps_peaks( + img=img, + confmaps=confmaps, + peaks=np.expand_dims(peaks, axis=0), + gt_instances=np.expand_dims(gt_instances, axis=0), + plot_title=f"{self.config.dataset_mapper[d_num]}", + ) + + if self.save_ckpt: + curr_results_path = ( + Path(self.config.trainer_config.save_ckpt_path) + / "visualizations" + / f"epoch_{self.current_epoch}" + ) + if not Path(curr_results_path).exists(): + Path(curr_results_path).mkdir(parents=True, exist_ok=True) + fig.savefig( + (Path(curr_results_path) / f"pred_on_{d_num}").as_posix(), + bbox_inches="tight", + ) + + if self.use_wandb: + fig.canvas.draw() + img = Image.frombytes( + "RGB", + fig.canvas.get_width_height(), + fig.canvas.tostring_rgb(), + ) + + img_array.append(wandb.Image(img)) + + plt.close(fig) + + if self.use_wandb and img_array: + # wandb logging metrics in table + + wandb_table = wandb.Table( + columns=[ + "epoch", + "Predictions on test set", + ], + data=[[self.current_epoch, img_array]], + ) + wandb.log({"Performance": wandb_table}) + + self.train_start_time = time.time() + + def training_step(self, batch, batch_idx): + """Training step.""" + loss = 0 + opt = self.optimizers() + opt.zero_grad() + for d_num in batch.keys(): + batch_data = batch[d_num] + X, y = torch.squeeze(batch_data["image"], dim=1).to( + self.device + ), torch.squeeze(batch_data["confidence_maps"], dim=1) + + output = self.model(X)["SingleInstanceConfmapsHead"] + + for h_num in batch.keys(): + if d_num != h_num: + with torch.no_grad(): + output[h_num] = output[h_num].detach() + + y_preds = output[d_num] + curr_loss = 1.0 * self.loss_func(y_preds, y) + loss += curr_loss + + self.manual_backward(curr_loss, retain_graph=True) + + self.log( + f"train_loss_on_head_{d_num}", + curr_loss, + prog_bar=True, + on_step=False, + on_epoch=True, + logger=True, + ) + + self.log( + f"train_loss", + loss, + prog_bar=True, + on_step=False, + on_epoch=True, + logger=True, + ) + + opt.step() + + return loss + + def validation_step(self, batch, batch_idx): + """Perform validation step.""" + total_loss = 0 + for d_num in batch.keys(): + X, y = torch.squeeze(batch[d_num]["image"], dim=1).to( + self.device + ), torch.squeeze(batch[d_num]["confidence_maps"], dim=1) + + y_preds = self.model(X)["SingleInstanceConfmapsHead"][d_num] + curr_loss = 1.0 * nn.MSELoss()(y_preds, y) + total_loss += curr_loss + + self.log( + f"val_loss_on_head_{d_num}", + curr_loss, + prog_bar=True, + on_step=False, + on_epoch=True, + logger=True, + ) + + self.log( + f"val_loss", + total_loss, + prog_bar=True, + on_step=False, + on_epoch=True, + logger=True, + ) + + lr = self.optimizers().optimizer.param_groups[0]["lr"] + self.log( + "learning_rate", + lr, + prog_bar=True, + on_step=False, + on_epoch=True, + logger=True, + ) class CentroidMultiHeadLightningModule(MultiHeadLightningModule): - pass + """Lightning Module for CentroidMultiHeadLightningModule Model. + + This is a subclass of the `MultiHeadLightningModule` to configure the training/ validation steps + and forward pass specific to centroid multi-head model. + + Args: + config: OmegaConf dictionary which has the following: + (i) dataset_mapper: mapping between dataset numbers and dataset name. + (ii) data_config: data loading pre-processing configs. + (iii) model_config: backbone and head configs to be passed to `Model` class. + (iv) trainer_config: trainer configs like accelerator, optimiser params. + model_type: Type of the model. One of `single_instance`, `centered_instance`, `centroid`, `bottomup`. + backbone_type: Backbone model. One of `unet`, `convnext` and `swint`. + + """ + + def __init__( + self, + config: OmegaConf, + model_type: str, + backbone_type: str, + ): + """Initialise the configs and the model.""" + super().__init__( + config=config, + backbone_type=backbone_type, + model_type=model_type, + ) + self.centroid_inf_layer = CentroidCrop( + torch_model=self.forward, + peak_threshold=0.2, + return_confmaps=True, + output_stride=self.config.model_config.head_configs.centroid.confmaps[0][ + "output_stride" + ], + input_scale=1.0, + ) + + def forward(self, img): + """Forward pass of the model.""" + img = torch.squeeze(img, dim=1).to(self.device) + return self.model(img)["CentroidConfmapsHead"] + + def on_train_epoch_start(self): + """Configure the train timer at the beginning of each epoch.""" + # add eval + if self.config.trainer_config.log_inf_epochs is not None: + if ( + self.current_epoch > 0 + and self.global_rank == 0 + and (self.current_epoch % self.config.trainer_config.log_inf_epochs) + == 0 + ): + img_array = [] + for d_num in self.config.dataset_mapper: + sample = next(iter(self.trainer.val_dataloaders[d_num])) + gt_centroids = sample["centroids"] + sample["eff_scale"] = torch.ones(sample["video_idx"].shape) + for k, v in sample.items(): + sample[k] = v.to(device=self.device) + output = self.centroid_inf_layer(sample) + batch_idx = 1 + + # plot predictions on sample image + if self.use_wandb or self.save_ckpt: + centroids = output["centroids"][batch_idx, 0].cpu().numpy() + img = output["image"][batch_idx, 0].cpu().numpy() + confmaps = ( + output["pred_centroid_confmaps"][batch_idx].cpu().numpy() + ) + gt_centroids = gt_centroids[batch_idx, 0].cpu().numpy() + fig = plot_pred_confmaps_peaks( + img=img, + confmaps=confmaps, + peaks=np.expand_dims(centroids, axis=0), + gt_instances=np.expand_dims(gt_centroids, axis=0), + plot_title=f"{self.config.dataset_mapper[d_num]}", + ) + if self.save_ckpt: + curr_results_path = ( + Path(self.config.trainer_config.save_ckpt_path) + / "visualizations" + / f"epoch_{self.current_epoch}" + ) + if not Path(curr_results_path).exists(): + Path(curr_results_path).mkdir(parents=True, exist_ok=True) + fig.savefig( + (Path(curr_results_path) / f"pred_on_{d_num}").as_posix(), + bbox_inches="tight", + ) + + if self.use_wandb: + fig.canvas.draw() + img = Image.frombytes( + "RGB", + fig.canvas.get_width_height(), + fig.canvas.tostring_rgb(), + ) + + img_array.append(wandb.Image(img)) + + plt.close(fig) + + if self.use_wandb and img_array: + # wandb logging metrics in table + + wandb_table = wandb.Table( + columns=[ + "epoch", + "Predictions on test set", + ], + data=[[self.current_epoch, img_array]], + ) + wandb.log({"Performance": wandb_table}) + + self.train_start_time = time.time() + + def training_step(self, batch, batch_idx): + """Training step.""" + loss = 0 + opt = self.optimizers() + opt.zero_grad() + for d_num in batch.keys(): + batch_data = batch[d_num] + X, y = torch.squeeze(batch_data["image"], dim=1).to( + self.device + ), torch.squeeze(batch_data["centroids_confidence_maps"], dim=1).to( + self.device + ) + + output = self.model(X)["CentroidConfmapsHead"] + + y_preds = output[0] + curr_loss = 1.0 * self.loss_func(y_preds, y) + loss += curr_loss + + self.manual_backward(curr_loss, retain_graph=True) + + self.log( + f"train_loss", + loss, + prog_bar=True, + on_step=False, + on_epoch=True, + logger=True, + ) + + opt.step() + + return loss + + def validation_step(self, batch, batch_idx): + """Perform validation step.""" + total_loss = 0 + for d_num in batch.keys(): + X, y = torch.squeeze(batch[d_num]["image"], dim=1).to( + self.device + ), torch.squeeze(batch[d_num]["centroids_confidence_maps"], dim=1).to( + self.device + ) + + y_preds = self.model(X)["CentroidConfmapsHead"][0] + curr_loss = 1.0 * nn.MSELoss()(y_preds, y) + total_loss += curr_loss + + self.log( + f"val_loss_on_head_{d_num}", + curr_loss, + prog_bar=True, + on_step=False, + on_epoch=True, + logger=True, + ) + + self.log( + f"val_loss", + total_loss, + prog_bar=True, + on_step=False, + on_epoch=True, + logger=True, + ) + + lr = self.optimizers().optimizer.param_groups[0]["lr"] + self.log( + "learning_rate", + lr, + prog_bar=True, + on_step=False, + on_epoch=True, + logger=True, + ) class BottomUpMultiHeadLightningModule(MultiHeadLightningModule): - pass + """Lightning Module for BottomUpMultiHeadLightningModule Model. + + This is a subclass of the `MultiHeadLightningModule` to configure the training/ validation steps + and forward pass specific to bottom up multi-head model. + + Args: + config: OmegaConf dictionary which has the following: + (i) dataset_mapper: mapping between dataset numbers and dataset name. + (ii) data_config: data loading pre-processing configs. + (iii) model_config: backbone and head configs to be passed to `Model` class. + (iv) trainer_config: trainer configs like accelerator, optimiser params. + skeletons_dict: Dict of `sio.Skeleton` objects from the input `.slp` file for all the datasets. + model_type: Type of the model. One of `single_instance`, `centered_instance`, `centroid`, `bottomup`. + backbone_type: Backbone model. One of `unet`, `convnext` and `swint`. + + """ + + def __init__( + self, + config: OmegaConf, + skeletons_dict: List[sio.Skeleton], + model_type: str, + backbone_type: str, + ): + """Initialise the configs and the model.""" + super().__init__( + config=config, + skeletons_dict=skeletons_dict, + backbone_type=backbone_type, + model_type=model_type, + ) + paf_scorer = PAFScorer( + part_names=self.config.model_config.head_configs.bottomup.confmaps[0][ + "part_names" + ], + edges=self.config.model_config.head_configs.bottomup.pafs[0]["edges"], + pafs_stride=self.config.model_config.head_configs.bottomup.pafs[0][ + "output_stride" + ], + ) + self.inf_layer = BottomUpInferenceModel( + torch_model=self.forward, + paf_scorer=paf_scorer, + peak_threshold=0.2, + input_scale=1.0, + return_confmaps=True, + cms_output_stride=self.config.model_config.head_configs.bottomup.confmaps[ + 0 + ]["output_stride"], + pafs_output_stride=self.config.model_config.head_configs.bottomup.pafs[0][ + "output_stride" + ], + ) + + def on_train_epoch_start(self): + """Configure the train timer at the beginning of each epoch.""" + # add eval + if self.config.trainer_config.log_inf_epochs is not None: + if ( + self.current_epoch > 0 + and self.global_rank == 0 + and (self.current_epoch % self.config.trainer_config.log_inf_epochs) + == 0 + ): + img_array = [] + for d_num in self.config.dataset_mapper: + sample = next(iter(self.trainer.val_dataloaders[d_num])) + sample["eff_scale"] = torch.ones(sample["video_idx"].shape) + for k, v in sample.items(): + sample[k] = v.to(device=self.device) + + paf_scorer = PAFScorer( + part_names=self.config.model_config.head_configs.bottomup.confmaps[ + d_num + ][ + "part_names" + ], + edges=self.config.model_config.head_configs.bottomup.pafs[ + d_num + ]["edges"], + pafs_stride=self.config.model_config.head_configs.bottomup.pafs[ + d_num + ]["output_stride"], + ) + self.inf_layer.paf_scorer = paf_scorer + self.inf_layer.cms_output_stride = ( + self.config.model_config.head_configs.bottomup.confmaps[d_num][ + "output_stride" + ] + ) + self.inf_layer.pafs_output_stride = ( + self.config.model_config.head_configs.bottomup.pafs[d_num][ + "output_stride" + ] + ) + + output = self.inf_layer(sample, output_head_skeleton_num=d_num) + batch_idx = 0 + + # plot predictions on sample image + if self.use_wandb or self.save_ckpt: + peaks = output["pred_instance_peaks"][batch_idx].cpu().numpy() + img = output["image"][batch_idx, 0].cpu().numpy() + confmaps = output["pred_confmaps"][batch_idx].cpu().numpy() + gt_instances = sample["instances"][batch_idx, 0].cpu().numpy() + fig = plot_pred_confmaps_peaks( + img=img, + confmaps=confmaps, + peaks=peaks, + gt_instances=gt_instances, + plot_title=f"{self.config.dataset_mapper[d_num]}", + ) + plt.imshow( + output["image"][batch_idx, 0] + .cpu() + .numpy() + .transpose(1, 2, 0) + ) + plt.plot( + peaks[:, 0], + peaks[:, 1], + "rx", + label="Predicted", + ) + plt.legend() + plt.title(f"{self.config.dataset_mapper[d_num]}") + plt.axis("off") + + if self.save_ckpt: + curr_results_path = ( + Path(self.config.trainer_config.save_ckpt_path) + / "visualizations" + / f"epoch_{self.current_epoch}" + ) + if not Path(curr_results_path).exists(): + Path(curr_results_path).mkdir(parents=True, exist_ok=True) + plt.savefig( + (Path(curr_results_path) / f"pred_on_{d_num}").as_posix() + ) + + if self.use_wandb: + fig = plt.gcf() + fig.canvas.draw() + img = Image.frombytes( + "RGB", + fig.canvas.get_width_height(), + fig.canvas.tostring_rgb(), + ) + + img_array.append(wandb.Image(img)) + + plt.close(fig) + + if self.use_wandb and img_array: + # wandb logging metrics in table + + wandb_table = wandb.Table( + columns=[ + "epoch", + "Predictions on test set", + ], + data=[[self.current_epoch, img_array]], + ) + wandb.log({"Performance": wandb_table}) + + self.train_start_time = time.time() + + def forward(self, img): + """Forward pass of the model.""" + img = torch.squeeze(img, dim=1).to(self.device) + output = self.model(img) + return { + "MultiInstanceConfmapsHead": output["MultiInstanceConfmapsHead"], + "PartAffinityFieldsHead": output["PartAffinityFieldsHead"], + } + + def training_step(self, batch, batch_idx): + """Training step.""" + loss = 0 + opt = self.optimizers() + opt.zero_grad() + for d_num in batch.keys(): + batch_data = batch[d_num] + X = torch.squeeze(batch_data["image"], dim=1) + y_confmap = torch.squeeze(batch_data["confidence_maps"], dim=1).to( + self.device + ) + y_paf = batch_data["part_affinity_fields"] + + output = self.model(X) + output_confmaps = output["MultiInstanceConfmapsHead"] + output_pafs = output["PartAffinityFieldsHead"] + + for h_num in batch.keys(): + if d_num != h_num: + with torch.no_grad(): + output_confmaps[h_num] = output_confmaps[h_num].detach() + output_pafs[h_num] = output_pafs[h_num].detach() + + losses = { + "MultiInstanceConfmapsHead": nn.MSELoss()( + output_confmaps[d_num], y_confmap + ), + "PartAffinityFieldsHead": nn.MSELoss()(output_pafs[d_num], y_paf), + } + curr_loss = 1.0 * sum( + [s * losses[t] for s, t in zip(self.loss_weights, losses)] + ) + + loss += curr_loss + + self.manual_backward(curr_loss, retain_graph=True) + + self.log( + f"train_loss_on_head_{d_num}", + curr_loss, + prog_bar=True, + on_step=False, + on_epoch=True, + logger=True, + ) + + self.log( + f"train_loss", + loss, + prog_bar=True, + on_step=False, + on_epoch=True, + logger=True, + ) + + opt.step() + + return loss + + def validation_step(self, batch, batch_idx): + """Perform validation step.""" + total_loss = 0 + for d_num in batch.keys(): + batch_data = batch[d_num] + X = torch.squeeze(batch_data["image"], dim=1) + y_confmap = torch.squeeze(batch_data["confidence_maps"], dim=1).to( + self.device + ) + y_paf = batch_data["part_affinity_fields"] + + output = self.model(X) + output_confmaps = output["MultiInstanceConfmapsHead"] + output_pafs = output["PartAffinityFieldsHead"] + + losses = { + "MultiInstanceConfmapsHead": nn.MSELoss()( + output_confmaps[d_num], y_confmap + ), + "PartAffinityFieldsHead": nn.MSELoss()(output_pafs[d_num], y_paf), + } + + curr_loss = 1.0 * sum( + [s * losses[t] for s, t in zip(self.loss_weights, losses)] + ) + total_loss += curr_loss + + self.log( + f"val_loss_on_head_{d_num}", + curr_loss, + prog_bar=True, + on_step=False, + on_epoch=True, + logger=True, + ) + + self.log( + f"val_loss", + total_loss, + prog_bar=True, + on_step=False, + on_epoch=True, + logger=True, + ) + + lr = self.optimizers().optimizer.param_groups[0]["lr"] + self.log( + "learning_rate", + lr, + prog_bar=True, + on_step=False, + on_epoch=True, + logger=True, + ) diff --git a/sleap_nn/training/utils.py b/sleap_nn/training/utils.py index b301f630..6e666321 100644 --- a/sleap_nn/training/utils.py +++ b/sleap_nn/training/utils.py @@ -1,6 +1,7 @@ """Miscellaneous utility functions for training.""" import numpy as np +import matplotlib.pyplot as plt from loguru import logger from torch import nn import torch.distributed as dist @@ -27,6 +28,67 @@ def xavier_init_weights(x): nn.init.constant_(x.bias, 0) +def plot_pred_confmaps_peaks( + img: np.ndarray, + confmaps: np.ndarray, + peaks: Optional[np.ndarray] = None, + gt_instances: Optional[np.ndarray] = None, + plot_title: Optional[str] = None, +): + """Plot the predicted peaks on input image overlayed with confmaps. + + Args: + img: Input image with shape (channel, height, width). + confmaps: Output confmaps with shape (num_nodes, confmap_height, confmap_width). + peaks: Predicted keypoints with shape (num_instances, num_nodes, 2). + gt_instances: Ground-truth keypoints with shape (num_instances, num_nodes, 2). + plot_title: Title for the plot. + """ + img_h, img_w = img.shape[-2:] + img = img.transpose(1, 2, 0) # (C, H, W) -> (H, W, C) + + confmaps = confmaps.transpose(1, 2, 0) # (C, H, W) -> (H, W, C) + confmaps = np.max(np.abs(confmaps), axis=-1) + + fig, ax = plt.subplots() + ax.axis("off") + + plt.subplots_adjust(left=0, right=1, top=1, bottom=0) + + ax.imshow(img) + + ax.imshow(confmaps, alpha=0.5, extent=[0, img_w, img_h, 0]) + + if gt_instances is not None: + for instance in gt_instances: + ax.plot( + instance[:, 0], + instance[:, 1], + "go", + markersize=8, + markeredgewidth=2, + label="GT keypoints", + ) + + if peaks is not None: + for peak in peaks: + ax.plot( + peak[:, 0], + peak[:, 1], + "rx", + markersize=8, + markeredgewidth=2, + label="Predicted peaks", + ) + + if plot_title is not None: + ax.set_title(f"{plot_title}") + + ax.legend() + + return fig + + def check_memory( labels: sio.Labels, max_hw: Tuple[int, int], From c18b2f16ba8212c195d768576cb6cbb4103e4d22 Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Wed, 23 Apr 2025 18:16:31 -0700 Subject: [PATCH 4/8] Add trainer module --- docs/config_slumbr.yaml | 2 +- sleap_nn/training/model_trainer.py | 709 ++++++++++++++++++++++++++++- 2 files changed, 709 insertions(+), 2 deletions(-) diff --git a/docs/config_slumbr.yaml b/docs/config_slumbr.yaml index 82582b14..74307488 100644 --- a/docs/config_slumbr.yaml +++ b/docs/config_slumbr.yaml @@ -16,7 +16,7 @@ data_config: 0: minimal_instance.pkg.slp 1: minimal_instance.pkg.slp 2: minimal_instance.pkg.slp - data_pipeline_fw: torch_dataset_cache_img_disk + data_pipeline_fw: torch_dataset_cache_img_disk # one of `torch_dataset`, `torch_dataset_cache_img_memory`, `torch_dataset_cache_img_disk` cache_img_path: ./img_dir/ use_existing_imgs: 0: diff --git a/sleap_nn/training/model_trainer.py b/sleap_nn/training/model_trainer.py index 5850d93c..946d7297 100644 --- a/sleap_nn/training/model_trainer.py +++ b/sleap_nn/training/model_trainer.py @@ -6,7 +6,7 @@ import shutil import subprocess import torch -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, DistributedSampler import torch.distributed as dist import sleap_io as sio from omegaconf import DictConfig, OmegaConf @@ -36,6 +36,7 @@ ConvNeXt_Small_Weights, ConvNeXt_Large_Weights, ) +from lightning.pytorch.utilities import CombinedLoader import sleap_io as sio from sleap_nn.data.custom_datasets import ( BottomUpDataset, @@ -78,6 +79,10 @@ CentroidModel, TopDownCenteredInstanceModel, SingleInstanceModel, + BottomUpMultiHeadLightningModule, + CentroidMultiHeadLightningModule, + TopDownCenteredInstanceMultiHeadLightningModule, + SingleInstanceMultiHeadLightningModule, ) from sleap_nn.config.training_job_config import verify_training_cfg @@ -908,3 +913,705 @@ def train(self): (Path(self.val_litdata_chunks_path)).as_posix(), ignore_errors=True, ) + + +class MultiHeadModelTrainer: + """Train sleap-nn multi-head model using PyTorch Lightning. + + This class is used to train a multi-head model and save the model checkpoints/ logs with options to logging + with wandb and csvlogger. + + Args: + config: OmegaConf dictionary which has the following: + (i) dataset_mapper: mapping between dataset numbers and dataset name. + (ii) data_config: data loading pre-processing configs. + (iii) model_config: backbone and head configs to be passed to `Model` class. + (iv) trainer_config: trainer configs like accelerator, optimiser params. + + """ + + def __init__( + self, + config: DictConfig, + ): + """Initialise the class with configs and set the seed and device as class attributes.""" + self.config = config + self.data_pipeline_fw = self.config.data_config.data_pipeline_fw + self.use_existing_imgs = self.config.data_config.use_existing_imgs + self.user_instances_only = OmegaConf.select( + self.config, "data_config.user_instances_only", default=True + ) + + # Get ckpt dir path + self.dir_path = self.config.trainer_config.save_ckpt_path + if self.dir_path is None: + self.dir_path = "." + + if not Path(self.dir_path).exists(): + try: + Path(self.dir_path).mkdir(parents=True, exist_ok=True) + except OSError as e: + message = f"Cannot create a new folder in {self.dir_path}. Check the permissions to the given Checkpoint directory. \n {e}" + logger.error(message) + raise OSError(message) + + if ( + self.data_pipeline_fw == "torch_dataset" + or self.data_pipeline_fw == "torch_dataset_cache_img_memory" + or self.data_pipeline_fw == "torch_dataset_cache_img_disk" + ): + + self.cache_img = ( + self.data_pipeline_fw.split("_")[-1] + if "cache_img" in self.data_pipeline_fw + else None + ) + + # get cache imgs path + if self.cache_img == "disk": + self.cache_img_path = ( + Path(self.config.data_config.cache_img_path) + if self.config.data_config.cache_img_path is not None + else Path(self.dir_path) + ) + self.train_cache_img_paths = {} + self.val_cache_img_paths = {} + + for d_num, d_name in self.config.dataset_mapper.items(): + self.train_cache_img_paths[d_num] = ( + Path(self.cache_img_path) / f"{d_name}" / "train_imgs" + ) + self.val_cache_img_paths[d_num] = ( + Path(self.cache_img_path) / f"{d_name}" / "val_imgs" + ) + + if self.use_existing_imgs[d_num]: + if not ( + self.train_cache_img_paths[d_num].exists() + and self.train_cache_img_paths[d_num].is_dir() + and any(self.train_cache_img_paths[d_num].glob("*.jpg")) + ): + message = f"There are no images in the path: {self.train_cache_img_paths[d_num]}" + logger.error(message) + raise Exception(message) + + if not ( + self.val_cache_img_paths[d_num].exists() + and self.val_cache_img_paths[d_num].is_dir() + and any(self.val_cache_img_paths[d_num].glob("*.npz")) + ): + message = f"There are no images in the path: {self.val_cache_img_paths[d_num]}" + logger.error(message) + raise Exception(message) + + else: + raise Exception(f"{self.data_pipeline_fw} is not supported!") + + self.seed = self.config.trainer_config.seed + self.steps_per_epochs = self.config.trainer_config.steps_per_epoch + if self.steps_per_epochs is None: + self.steps_per_epochs = {} + + # initialize attributes + self.model = None + self.train_labels, self.val_labels = {}, {} + self.train_datasets, self.val_datasets = {}, {} + self.train_data_loaders, self.val_data_loaders = {}, {} + self.trainer = None + self.edge_inds = {} + self.max_heights = {} + self.max_widths = {} + self.crop_hws = {} + self.skeletons_dict = {} + OmegaConf.update(self.config.data_config, f"skeletons", {}) + + # check which backbone architecture + for k, v in self.config.model_config.backbone_config.items(): + if v is not None: + self.backbone_type = k + break + + # check which head type to choose the model + for k, v in self.config.model_config.head_configs.items(): + if v is not None: + self.model_type = k + break + + OmegaConf.save(config=self.config, f=f"{self.dir_path}/initial_config.yaml") + + # set seed + torch.manual_seed(self.seed) + + self.max_stride = self.config.model_config.backbone_config[ + f"{self.backbone_type}" + ]["max_stride"] + + # set skeletons, compute preprocessing params from labels file + for d_num, d_name in self.config.dataset_mapper.items(): + self.train_labels[d_num] = sio.load_slp( + self.config.data_config.train_labels_path[d_num] + ) + self.val_labels[d_num] = sio.load_slp( + self.config.data_config.val_labels_path[d_num] + ) + + if self.config.data_config.preprocessing.scale[d_num] is None: + self.config.data_config.preprocessing.scale[d_num] = 1.0 + + self.skeletons_dict[d_num] = self.train_labels[d_num].skeletons + + # save the skeleton in the config + for skl in self.skeletons_dict[d_num]: + if skl.symmetries: + symm = [list(s.nodes) for s in skl.symmetries] + else: + symm = None + skl_name = skl.name if skl.name is not None else "skeleton-0" + self.config["data_config"]["skeletons"][d_num] = { + skl_name: { + "nodes": skl.nodes, + "edges": skl.edges, + "symmetries": symm, + } + } + + # if edges and part names aren't set in config, get it from `sio.Labels` object. + head_configs = self.config.model_config.head_configs[self.model_type] + for key in head_configs: + if "part_names" in head_configs[key][d_num].keys(): + if head_configs[key][d_num]["part_names"] is None: + part_names = [ + x.name for x in self.skeletons_dict[d_num][0].nodes + ] + self.config.model_config.head_configs[self.model_type][key][ + d_num + ]["part_names"] = part_names + + if "edges" in head_configs[key][d_num].keys(): + if head_configs[key][d_num]["edges"] is None: + edges = [ + (x.source.name, x.destination.name) + for x in self.skeletons_dict[d_num][0].edges + ] + self.config.model_config.head_configs[self.model_type][key][ + d_num + ]["edges"] = edges + + self.edge_inds[d_num] = self.train_labels[d_num].skeletons[0].edge_inds + self.max_heights[d_num], self.max_widths[d_num] = get_max_height_width( + self.train_labels[d_num] + ) + if ( + self.config.data_config.preprocessing.max_height[d_num] is None + and self.config.data_config.preprocessing.max_width[d_num] is None + ): + self.config.data_config.preprocessing.max_height[d_num] = ( + self.max_heights[d_num] + ) + self.config.data_config.preprocessing.max_width[d_num] = ( + self.max_widths[d_num] + ) + else: + self.max_heights[d_num] = ( + self.config.data_config.preprocessing.max_height[d_num] + ) + self.max_widths[d_num] = ( + self.config.data_config.preprocessing.max_width[d_num] + ) + + if self.model_type == "centered_instance": + # compute crop size + self.crop_hws[d_num] = self.config.data_config.preprocessing.crop_hw[ + d_num + ] + if self.crop_hws[d_num] is None: + + min_crop_size = ( + self.config.data_config.preprocessing.min_crop_size + if "min_crop_size" in self.config.data_config.preprocessing + else None + ) + crop_size = find_instance_crop_size( + self.train_labels[d_num], + maximum_stride=self.max_stride, + min_crop_size=min_crop_size, + input_scaling=self.config.data_config.preprocessing.scale[ + d_num + ], + ) + self.crop_hws[d_num] = crop_size + self.config.data_config.preprocessing.crop_hw[d_num] = ( + self.crop_hws[d_num], + self.crop_hws[d_num], + ) + else: + self.crop_hws[d_num] = self.crop_hws[d_num][0] + + OmegaConf.save(config=self.config, f=f"{self.dir_path}/training_config.yaml") + + def _create_data_loaders_torch_dataset(self, d_num): + """Create a torch DataLoader for train, validation and test sets using the data_config.""" + if self.data_pipeline_fw == "torch_dataset_cache_img_memory": + train_cache_memory = check_memory( + self.train_labels[d_num], + max_hw=(self.max_heights[d_num], self.max_widths[d_num]), + model_type=self.model_type, + input_scaling=self.config.data_config.preprocessing.scale[d_num], + crop_size=self.crop_hws[d_num] if self.crop_hw != -1 else None, + ) + val_cache_memory = check_memory( + self.val_labels[d_num], + max_hw=(self.max_heights[d_num], self.max_widths[d_num]), + model_type=self.model_type, + input_scaling=self.config.data_config.preprocessing.scale[d_num], + crop_size=self.crop_hws[d_num] if self.crop_hw != -1 else None, + ) + total_cache_memory = train_cache_memory + val_cache_memory + total_cache_memory += 0.1 * total_cache_memory # memory required in bytes + available_memory = ( + psutil.virtual_memory().available + ) # available memory in bytes + + if total_cache_memory > available_memory: + raise Exception( + f"Insufficient memory for in-memory caching. Use disk caching instead." + ) + + if self.model_type == "bottomup": + self.train_datasets[d_num] = BottomUpDataset( + labels=self.train_labels[d_num], + confmap_head_config=self.config.model_config.head_configs.bottomup.confmaps[ + d_num + ], + pafs_head_config=self.config.model_config.head_configs.bottomup.pafs[ + d_num + ], + max_stride=self.max_stride, + user_instances_only=self.config.data_config.user_instances_only, + is_rgb=self.config.data_config.preprocessing.is_rgb, + augmentation_config=self.config.data_config.augmentation_config, + scale=self.config.data_config.preprocessing.scale[d_num], + apply_aug=self.config.data_config.use_augmentations_train, + max_hw=(self.max_heights[d_num], self.max_widths[d_num]), + cache_img=self.cache_img, + cache_img_path=self.train_cache_img_paths[d_num], + use_existing_imgs=self.use_existing_imgs[d_num], + ) + self.val_datasets[d_num] = BottomUpDataset( + labels=self.val_labels[d_num], + confmap_head_config=self.config.model_config.head_configs.bottomup.confmaps[ + d_num + ], + pafs_head_config=self.config.model_config.head_configs.bottomup.pafs[ + d_num + ], + max_stride=self.max_stride, + user_instances_only=self.config.data_config.user_instances_only, + is_rgb=self.config.data_config.preprocessing.is_rgb, + augmentation_config=None, + scale=self.config.data_config.preprocessing.scale[d_num], + apply_aug=False, + max_hw=(self.max_heights[d_num], self.max_widths[d_num]), + cache_img=self.cache_img, + cache_img_path=self.val_cache_img_paths[d_num], + use_existing_imgs=self.use_existing_imgs[d_num], + ) + + elif self.model_type == "centered_instance": + self.train_datasets[d_num] = CenteredInstanceDataset( + labels=self.train_labels[d_num], + confmap_head_config=self.config.model_config.head_configs.centered_instance.confmaps[ + d_num + ], + max_stride=self.max_stride, + user_instances_only=self.config.data_config.user_instances_only, + is_rgb=self.config.data_config.preprocessing.is_rgb, + augmentation_config=self.config.data_config.augmentation_config, + scale=self.config.data_config.preprocessing.scale[d_num], + apply_aug=self.config.data_config.use_augmentations_train, + crop_hw=(self.crop_hws[d_num], self.crop_hws[d_num]), + max_hw=(self.max_heights[d_num], self.max_widths[d_num]), + cache_img=self.cache_img, + cache_img_path=self.train_cache_img_paths[d_num], + use_existing_imgs=self.use_existing_imgs[d_num], + ) + self.val_datasets[d_num] = CenteredInstanceDataset( + labels=self.val_labels[d_num], + confmap_head_config=self.config.model_config.head_configs.centered_instance.confmaps[ + d_num + ], + max_stride=self.max_stride, + user_instances_only=self.config.data_config.user_instances_only, + is_rgb=self.config.data_config.preprocessing.is_rgb, + augmentation_config=None, + scale=self.config.data_config.preprocessing.scale[d_num], + apply_aug=False, + crop_hw=(self.crop_hws[d_num], self.crop_hws[d_num]), + max_hw=(self.max_heights[d_num], self.max_widths[d_num]), + cache_img=self.cache_img, + cache_img_path=self.val_cache_img_paths[d_num], + use_existing_imgs=self.use_existing_imgs[d_num], + ) + + elif self.model_type == "centroid": + self.train_datasets[d_num] = CentroidDataset( + labels=self.train_labels[d_num], + confmap_head_config=self.config.model_config.head_configs.centroid.confmaps[ + d_num + ], + max_stride=self.max_stride, + user_instances_only=self.config.data_config.user_instances_only, + is_rgb=self.config.data_config.preprocessing.is_rgb, + augmentation_config=self.config.data_config.augmentation_config, + scale=self.config.data_config.preprocessing.scale[d_num], + apply_aug=self.config.data_config.use_augmentations_train, + max_hw=(self.max_heights[d_num], self.max_widths[d_num]), + cache_img=self.cache_img, + cache_img_path=self.train_cache_img_paths[d_num], + use_existing_imgs=self.use_existing_imgs[d_num], + ) + self.val_datasets[d_num] = CentroidDataset( + labels=self.val_labels[d_num], + confmap_head_config=self.config.model_config.head_configs.centroid.confmaps[ + d_num + ], + max_stride=self.max_stride, + user_instances_only=self.config.data_config.user_instances_only, + is_rgb=self.config.data_config.preprocessing.is_rgb, + augmentation_config=None, + scale=self.config.data_config.preprocessing.scale[d_num], + apply_aug=False, + max_hw=(self.max_heights[d_num], self.max_widths[d_num]), + cache_img=self.cache_img, + cache_img_path=self.val_cache_img_paths[d_num], + use_existing_imgs=self.use_existing_imgs[d_num], + ) + + elif self.model_type == "single_instance": + self.train_datasets[d_num] = SingleInstanceDataset( + labels=self.train_labels[d_num], + confmap_head_config=self.config.model_config.head_configs.single_instance.confmaps[ + d_num + ], + max_stride=self.max_stride, + user_instances_only=self.config.data_config.user_instances_only, + is_rgb=self.config.data_config.preprocessing.is_rgb, + augmentation_config=self.config.data_config.augmentation_config, + scale=self.config.data_config.preprocessing.scale[d_num], + apply_aug=self.config.data_config.use_augmentations_train, + max_hw=(self.max_heights[d_num], self.max_widths[d_num]), + cache_img=self.cache_img, + cache_img_path=self.train_cache_img_paths[d_num], + use_existing_imgs=self.use_existing_imgs[d_num], + ) + self.val_datasets[d_num] = SingleInstanceDataset( + labels=self.val_labels[d_num], + confmap_head_config=self.config.model_config.head_configs.single_instance.confmaps[ + d_num + ], + max_stride=self.max_stride, + user_instances_only=self.config.data_config.user_instances_only, + is_rgb=self.config.data_config.preprocessing.is_rgb, + augmentation_config=None, + scale=self.config.data_config.preprocessing.scale[d_num], + apply_aug=False, + max_hw=(self.max_heights[d_num], self.max_widths[d_num]), + cache_img=self.cache_img, + cache_img_path=self.val_cache_img_paths[d_num], + use_existing_imgs=self.use_existing_imgs[d_num], + ) + + else: + message = f"Model type: {self.model_type}. Ensure the heads config has one of the keys: [`bottomup`, `centroid`, `centered_instance`, `single_instance`]." + logger.error(message) + raise ValueError(message) + + if self.steps_per_epochs.get(d_num, None) is None: + self.steps_per_epochs[d_num] = ( + len(self.train_datasets[d_num]) + // self.config.trainer_config.train_data_loader.batch_size + ) + if self.steps_per_epochs[d_num] == 0: + self.steps_per_epochs[d_num] = 1 + + val_steps_per_epoch = ( + len(self.val_datasets[d_num]) + // self.config.trainer_config.val_data_loader.batch_size + ) + + pin_memory = ( + self.config.trainer_config.train_data_loader.pin_memory + if "pin_memory" in self.config.trainer_config.train_data_loader + and self.config.trainer_config.train_data_loader.pin_memory is not None + else True + ) + + if self.config.trainer_config.trainer_devices > 1: + # add sampler if using more than 1 gpu + train_sampler = DistributedSampler( + self.train_datasets[d_num], + num_replicas=self.config.trainer_config.trainer_devices, + rank=self.trainer.global_rank if self.trainer is not None else 0, + shuffle=self.config.trainer_config.train_data_loader.shuffle, + ) + + self.train_data_loaders[d_num] = DataLoader( + dataset=self.train_datasets[d_num], + batch_size=self.config.trainer_config.train_data_loader.batch_size, + num_workers=self.config.trainer_config.train_data_loader.num_workers, + pin_memory=pin_memory, + persistent_workers=( + True + if self.config.trainer_config.train_data_loader.num_workers > 0 + else None + ), + prefetch_factor=( + self.config.trainer_config.train_data_loader.batch_size + if self.config.trainer_config.train_data_loader.num_workers > 0 + else None + ), + sampler=train_sampler, + multiprocessing_context="forkserver", + ) + + val_sampler = DistributedSampler( + self.val_datasets[d_num], + num_replicas=self.config.trainer_config.trainer_devices, + rank=self.trainer.global_rank if self.trainer is not None else 0, + shuffle=False, + ) + self.val_data_loaders[d_num] = DataLoader( + dataset=self.val_datasets[d_num], + batch_size=self.config.trainer_config.val_data_loader.batch_size, + num_workers=self.config.trainer_config.val_data_loader.num_workers, + pin_memory=pin_memory, + persistent_workers=( + True + if self.config.trainer_config.val_data_loader.num_workers > 0 + else None + ), + prefetch_factor=( + self.config.trainer_config.val_data_loader.batch_size + if self.config.trainer_config.val_data_loader.num_workers > 0 + else None + ), + sampler=val_sampler, + multiprocessing_context="forkserver", + ) + + else: + self.train_data_loaders[d_num] = DataLoader( + dataset=self.train_datasets[d_num], + batch_size=self.config.trainer_config.train_data_loader.batch_size, + num_workers=self.config.trainer_config.train_data_loader.num_workers, + pin_memory=pin_memory, + persistent_workers=( + True + if self.config.trainer_config.train_data_loader.num_workers > 0 + else None + ), + prefetch_factor=( + self.config.trainer_config.train_data_loader.batch_size + if self.config.trainer_config.train_data_loader.num_workers > 0 + else None + ), + shuffle=self.config.trainer_config.train_data_loader.shuffle, + ) + + self.val_data_loaders[d_num] = DataLoader( + dataset=self.val_datasets[d_num], + batch_size=self.config.trainer_config.val_data_loader.batch_size, + num_workers=self.config.trainer_config.val_data_loader.num_workers, + pin_memory=pin_memory, + persistent_workers=( + True + if self.config.trainer_config.val_data_loader.num_workers > 0 + else None + ), + prefetch_factor=( + self.config.trainer_config.val_data_loader.batch_size + if self.config.trainer_config.val_data_loader.num_workers > 0 + else None + ), + shuffle=False, + ) + + def _set_wandb(self): + wandb.login(key=self.config.trainer_config.wandb.api_key) + + def _initialize_model( + self, + ): + models = { + "single_instance": SingleInstanceMultiHeadLightningModule, + "centered_instance": TopDownCenteredInstanceMultiHeadLightningModule, + "centroid": CentroidMultiHeadLightningModule, + "bottomup": BottomUpMultiHeadLightningModule, + } + self.model = models[self.model_type]( + config=self.config, + model_type=self.model_type, + backbone_type=self.backbone_type, + ) + + def _get_param_count(self): + return sum(p.numel() for p in self.model.parameters()) + + def train(self): + """Initiate the training by calling the fit method of Trainer.""" + self._initialize_model() + total_params = self._get_param_count() + self.config.model_config.total_params = total_params + + training_loggers = [] + + if self.config.trainer_config.save_ckpt: + + # create checkpoint callback + checkpoint_callback = ModelCheckpoint( + save_top_k=self.config.trainer_config.model_ckpt.save_top_k, + save_last=self.config.trainer_config.model_ckpt.save_last, + dirpath=self.dir_path, + filename="best", + monitor="val_loss", + mode="min", + ) + callbacks = [checkpoint_callback] + # logger to create csv with metrics values over the epochs + csv_logger = CSVLogger(self.dir_path) + training_loggers.append(csv_logger) + + else: + callbacks = [] + + if self.config.trainer_config.early_stopping.stop_training_on_plateau: + callbacks.append( + EarlyStopping( + monitor="val_loss", + mode="min", + verbose=False, + min_delta=self.config.trainer_config.early_stopping.min_delta, + patience=self.config.trainer_config.early_stopping.patience, + ) + ) + + if self.config.trainer_config.use_wandb: + wandb_config = self.config.trainer_config.wandb + if wandb_config.wandb_mode == "offline": + os.environ["WANDB_MODE"] = "offline" + else: + self._set_wandb() + wandb_logger = WandbLogger( + entity=wandb_config.entity, + project=wandb_config.project, + name=wandb_config.name, + save_dir=self.dir_path, + id=self.config.trainer_config.wandb.prv_runid, + group=self.config.trainer_config.wandb.group, + ) + training_loggers.append(wandb_logger) + + # save the configs as yaml in the checkpoint dir + self.config.trainer_config.wandb.api_key = "" + + profilers = { + "advanced": AdvancedProfiler(), + "passthrough": PassThroughProfiler(), + "pytorch": PyTorchProfiler(), + "simple": SimpleProfiler(), + } + cfg_profiler = OmegaConf.select( + self.config, "trainer_config.profiler", default=None + ) + profiler = None + if cfg_profiler is not None: + if cfg_profiler in profilers: + profiler = profilers[cfg_profiler] + else: + message = f"{cfg_profiler} is not a valid option. Please choose one of {list(profilers.keys())}" + logger.error(message) + raise ValueError(message) + + strategy = OmegaConf.select( + self.config, "trainer_config.trainer_strategy", default="auto" + ) + + self.trainer = L.Trainer( + callbacks=callbacks, + logger=training_loggers, + enable_checkpointing=self.config.trainer_config.save_ckpt, + devices=self.config.trainer_config.trainer_devices, + max_epochs=self.config.trainer_config.max_epochs, + accelerator=self.config.trainer_config.trainer_accelerator, + enable_progress_bar=self.config.trainer_config.enable_progress_bar, + strategy=strategy, + profiler=profiler, + ) + + # save the configs as yaml in the checkpoint dir + if ( + self.trainer.global_rank == 0 + ): # save config if there are no distributed process or the rank = 0 + OmegaConf.save( + config=self.config, f=f"{self.dir_path}/training_config.yaml" + ) + + for d_num, _ in self.config.dataset_mapper.items(): + self._create_data_loaders_torch_dataset(d_num=d_num) + self.combined_train_dataloader = CombinedLoader( + self.train_data_loaders, + mode=self.config.trainer_config.combined_loader_mode, + ) + self.combined_val_dataloader = CombinedLoader( + self.val_data_loaders, + mode=self.config.trainer_config.combined_loader_mode, + ) + + if self.config.trainer_config.use_wandb: + if ( + self.trainer.global_rank == 0 + ): # save config if there are no distributed process or the rank = 0 + wandb_logger.experiment.config.update({"run_name": wandb_config.name}) + wandb_logger.experiment.config.update( + {"run_config": OmegaConf.to_container(self.config, resolve=True)} + ) + wandb_logger.experiment.config.update({"model_params": total_params}) + + try: + + self.trainer.fit( + self.model, + self.combined_train_dataloader, + self.combined_val_dataloader, + ckpt_path=self.config.trainer_config.resume_ckpt_path, + ) + + except KeyboardInterrupt: + logger.info("Stopping training...") + + finally: + if self.config.trainer_config.use_wandb: + self.config.trainer_config.wandb.run_id = wandb.run.id + wandb.finish() + + # save the config with wandb runid + OmegaConf.save( + config=self.config, f=f"{self.dir_path}/training_config.yaml" + ) + + if ( + self.data_pipeline_fw == "torch_dataset_cache_img_disk" + and self.config.data_config.delete_cache_imgs_after_training + ): + for d_num, d_name in self.config.dataset_mapper.items(): + if (self.train_cache_img_path[d_num]).exists(): + shutil.rmtree( + (self.train_cache_img_path[d_num]).as_posix(), + ignore_errors=True, + ) + + if (self.val_cache_img_path[d_num]).exists(): + shutil.rmtree( + (self.val_cache_img_path[d_num]).as_posix(), + ignore_errors=True, + ) From 884c7097526d2afee8a2c45931a82b50d1ec79c3 Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Thu, 24 Apr 2025 10:11:21 -0700 Subject: [PATCH 5/8] Fix output stride access --- sleap_nn/architectures/model.py | 4 +++- sleap_nn/inference/predictors.py | 14 ++++++++------ 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/sleap_nn/architectures/model.py b/sleap_nn/architectures/model.py index b20a7f2e..474d8e2c 100644 --- a/sleap_nn/architectures/model.py +++ b/sleap_nn/architectures/model.py @@ -241,7 +241,9 @@ def __init__( output_strides = [] for head_type in head_configs: head_config = head_configs[head_type] - output_strides.extend([cfg.output_stride for cfg in head_config]) + output_strides.extend( + [head_config[cfg].output_stride for cfg in head_config] + ) min_output_stride = min(output_strides) min_output_stride = min(min_output_stride, self.backbone_config.output_stride) diff --git a/sleap_nn/inference/predictors.py b/sleap_nn/inference/predictors.py index 95d42c5e..c97172b9 100644 --- a/sleap_nn/inference/predictors.py +++ b/sleap_nn/inference/predictors.py @@ -755,11 +755,13 @@ def from_trained_models( skeletons = get_skeleton_from_config( confmap_config.data_config.skeletons ) - confmap_model = TopDownCenteredInstanceLightningModule.load_from_checkpoint( - checkpoint_path=ckpt_path, - config=confmap_config, + confmap_model = ( + TopDownCenteredInstanceLightningModule.load_from_checkpoint( + checkpoint_path=ckpt_path, + config=confmap_config, model_type="centered_instance", - backbone_type=centered_instance_backbone_type, + backbone_type=centered_instance_backbone_type, + ) ) if backbone_ckpt_path is not None and head_ckpt_path is not None: @@ -1198,7 +1200,7 @@ def from_trained_models( confmap_model = SingleInstanceLightningModule.load_from_checkpoint( checkpoint_path=ckpt_path, config=confmap_config, - model_type="single_instance", + model_type="single_instance", backbone_type=backbone_type, ) if backbone_ckpt_path is not None and head_ckpt_path is not None: @@ -1672,7 +1674,7 @@ def from_trained_models( bottomup_model = BottomUpLightningModule.load_from_checkpoint( checkpoint_path=ckpt_path, config=bottomup_config, - backbone_type=backbone_type, + backbone_type=backbone_type, model_type="bottomup", ) From 839b25e33da625669c6bc133a172ac85bb6e19c1 Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Thu, 24 Apr 2025 11:48:45 -0700 Subject: [PATCH 6/8] Fix raise exception for val imgs --- sleap_nn/training/model_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sleap_nn/training/model_trainer.py b/sleap_nn/training/model_trainer.py index c76aea95..7830e74b 100644 --- a/sleap_nn/training/model_trainer.py +++ b/sleap_nn/training/model_trainer.py @@ -993,7 +993,7 @@ def __init__( if not ( self.val_cache_img_paths[d_num].exists() and self.val_cache_img_paths[d_num].is_dir() - and any(self.val_cache_img_paths[d_num].glob("*.npz")) + and any(self.val_cache_img_paths[d_num].glob("*.jpg")) ): message = f"There are no images in the path: {self.val_cache_img_paths[d_num]}" logger.error(message) From 17e60db658163cff5f37f925a8113e414a95f0b3 Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Sat, 26 Apr 2025 07:42:19 -0700 Subject: [PATCH 7/8] Add rank argument --- sleap_nn/training/model_trainer.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sleap_nn/training/model_trainer.py b/sleap_nn/training/model_trainer.py index cf172591..8605ea6a 100644 --- a/sleap_nn/training/model_trainer.py +++ b/sleap_nn/training/model_trainer.py @@ -1199,6 +1199,7 @@ def _create_data_loaders_torch_dataset(self, d_num): cache_img=self.cache_img, cache_img_path=self.train_cache_img_paths[d_num], use_existing_imgs=self.use_existing_imgs[d_num], + rank=self.trainer.global_rank if self.trainer is not None else None, ) self.val_datasets[d_num] = BottomUpDataset( labels=self.val_labels[d_num], @@ -1218,6 +1219,7 @@ def _create_data_loaders_torch_dataset(self, d_num): cache_img=self.cache_img, cache_img_path=self.val_cache_img_paths[d_num], use_existing_imgs=self.use_existing_imgs[d_num], + rank=self.trainer.global_rank if self.trainer is not None else None, ) elif self.model_type == "centered_instance": @@ -1237,6 +1239,7 @@ def _create_data_loaders_torch_dataset(self, d_num): cache_img=self.cache_img, cache_img_path=self.train_cache_img_paths[d_num], use_existing_imgs=self.use_existing_imgs[d_num], + rank=self.trainer.global_rank if self.trainer is not None else None, ) self.val_datasets[d_num] = CenteredInstanceDataset( labels=self.val_labels[d_num], @@ -1254,6 +1257,7 @@ def _create_data_loaders_torch_dataset(self, d_num): cache_img=self.cache_img, cache_img_path=self.val_cache_img_paths[d_num], use_existing_imgs=self.use_existing_imgs[d_num], + rank=self.trainer.global_rank if self.trainer is not None else None, ) elif self.model_type == "centroid": @@ -1272,6 +1276,7 @@ def _create_data_loaders_torch_dataset(self, d_num): cache_img=self.cache_img, cache_img_path=self.train_cache_img_paths[d_num], use_existing_imgs=self.use_existing_imgs[d_num], + rank=self.trainer.global_rank if self.trainer is not None else None, ) self.val_datasets[d_num] = CentroidDataset( labels=self.val_labels[d_num], @@ -1288,6 +1293,7 @@ def _create_data_loaders_torch_dataset(self, d_num): cache_img=self.cache_img, cache_img_path=self.val_cache_img_paths[d_num], use_existing_imgs=self.use_existing_imgs[d_num], + rank=self.trainer.global_rank if self.trainer is not None else None, ) elif self.model_type == "single_instance": @@ -1306,6 +1312,7 @@ def _create_data_loaders_torch_dataset(self, d_num): cache_img=self.cache_img, cache_img_path=self.train_cache_img_paths[d_num], use_existing_imgs=self.use_existing_imgs[d_num], + rank=self.trainer.global_rank if self.trainer is not None else None, ) self.val_datasets[d_num] = SingleInstanceDataset( labels=self.val_labels[d_num], @@ -1322,6 +1329,7 @@ def _create_data_loaders_torch_dataset(self, d_num): cache_img=self.cache_img, cache_img_path=self.val_cache_img_paths[d_num], use_existing_imgs=self.use_existing_imgs[d_num], + rank=self.trainer.global_rank if self.trainer is not None else None, ) else: From c8aac9d27b6436a1d0afd686cdc0767bc42de637 Mon Sep 17 00:00:00 2001 From: gitttt-1234 Date: Wed, 30 Apr 2025 11:14:42 -0700 Subject: [PATCH 8/8] Explicitly close video handles --- sleap_nn/training/model_trainer.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/sleap_nn/training/model_trainer.py b/sleap_nn/training/model_trainer.py index 1647ec20..05204c40 100644 --- a/sleap_nn/training/model_trainer.py +++ b/sleap_nn/training/model_trainer.py @@ -1367,6 +1367,16 @@ def _create_data_loaders_torch_dataset(self, d_num): else True ) + # If using caching, close the videos to prevent `h5py objects can't be pickled error` when num_workers > 0. + if "cache_img" in self.data_pipeline_fw: + for video in self.train_labels[d_num].videos: + if video.is_open: + video.close() + + for video in self.val_labels[d_num].videos: + if video.is_open: + video.close() + if self.config.trainer_config.trainer_devices > 1: # add sampler if using more than 1 gpu train_sampler = DistributedSampler(