147 changes: 147 additions & 0 deletions docs/config_slumbr.yaml
@@ -0,0 +1,147 @@
dataset_mapper:
0: dataset_1
1: dataset_2
2: dataset_3
data_config:
provider: LabelsReader
train_labels_path:
0: minimal_instance.pkg.slp
1: minimal_instance.pkg.slp
2: minimal_instance.pkg.slp
val_labels_path:
0: minimal_instance.pkg.slp
1: minimal_instance.pkg.slp
2: minimal_instance.pkg.slp
test_file_path:
0: minimal_instance.pkg.slp
1: minimal_instance.pkg.slp
2: minimal_instance.pkg.slp
data_pipeline_fw: torch_dataset_cache_img_disk # one of `torch_dataset`, `torch_dataset_cache_img_memory`, `torch_dataset_cache_img_disk`
cache_img_path: ./img_dir/
use_existing_imgs:
0:
1:
2:
user_instances_only: True
litdata_chunks_path:
chunk_size:
delete_cache_imgs_after_training:
preprocessing:
max_width:
0:
1:
2:
max_height:
0:
1:
2:
scale:
0:
1:
2:
is_rgb: True
crop_hw:
0:
1:
2:
min_crop_size:
use_augmentations_train: true
augmentation_config:
geometric:
rotation: 180.0
scale: null
translate_width: 0
translate_height: 0
affine_p: 0.5

model_config:
init_weights: xavier
pre_trained_weights:
pretrained_backbone_weights:
pretrained_head_weights:
backbone_type: unet
backbone_config:
in_channels: 1
kernel_size: 3
filters: 16
filters_rate: 2
max_stride: 16
convs_per_block: 2
stacks: 1
stem_stride:
middle_block: True
up_interpolate: True
head_configs:
single_instance:
centroid:
bottomup:
centered_instance:
confmaps:
0:
part_names:
anchor_part: 0
sigma: 1.5
output_stride: 2
1:
part_names:
anchor_part: 2
sigma: 1.5
output_stride: 2
2:
part_names:
anchor_part: 4
sigma: 1.5
output_stride: 2
trainer_config:
train_data_loader:
batch_size: 4
shuffle: true
num_workers: 2
val_data_loader:
batch_size: 4
num_workers: 2
combined_loader_mode: "max_size_cycle" # or "min_size", "max_size"
model_ckpt:
save_top_k: 1
save_last: true
trainer_devices: 1
trainer_accelerator: gpu
profiler: simple
trainer_strategy: "ddp_find_unused_parameters_false" # "auto", "ddp", "ddp_find_unused_parameters_true"
enable_progress_bar: false
log_inf_epochs: 5
steps_per_epoch:
0:
1:
2:
max_epochs: 10
seed: 1000
use_wandb: true
wandb:
entity:
project: 'test_centroid_centered'
name: 'fly_unet_centered'
wandb_mode: ''
api_key: ''
prv_runid:
group:
save_ckpt: true
save_ckpt_path: 'multi_head'
resume_ckpt_path:
optimizer_name: Adam
optimizer:
lr: 0.0001
amsgrad: false
lr_scheduler:
scheduler: ReduceLROnPlateau
reduce_lr_on_plateau:
threshold: 1.0e-07
threshold_mode: abs
cooldown: 3
patience: 5
factor: 0.5
min_lr: 1.0e-08
early_stopping:
stop_training_on_plateau: True
min_delta: 1.0e-08
patience: 20
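
The integer keys (`0`, `1`, `2`) index the datasets declared in `dataset_mapper`, so every per-dataset field lines up across sections. A minimal sketch of loading this config and reading the per-dataset head settings, assuming OmegaConf >= 2.1 (which allows integer keys):

from omegaconf import OmegaConf

# Load the multi-head config shown above.
cfg = OmegaConf.load("docs/config_slumbr.yaml")

# Every dataset index in `dataset_mapper` has a matching per-dataset entry,
# e.g. one centered-instance confmap head per dataset.
for d_num, name in cfg.dataset_mapper.items():
    head = cfg.model_config.head_configs.centered_instance.confmaps[d_num]
    print(name, "-> anchor_part:", head.anchor_part, "sigma:", head.sigma)
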
133 changes: 132 additions & 1 deletion sleap_nn/architectures/model.py
@@ -6,7 +6,7 @@
"""

from typing import List

from collections import defaultdict
import torch
from omegaconf.dictconfig import DictConfig
from torch import nn
@@ -178,3 +178,134 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
outputs[head.name] = head_layer(backbone_outputs["outputs"][idx])

return outputs


class MultiHeadModel(nn.Module):
"""Model creates a model consisting of a backbone and head.

Attributes:
backbone_type: Backbone type. One of `unet`, `convnext`, or `swint`.
backbone_config: A `DictConfig` configuration dictionary for the model backbone.
model_type: Type of the model. One of `single_instance`, `centered_instance`, `centroid`, `bottomup`.
head_configs: A `DictConfig` configuration dictionary for the model heads
(this should contain one head config per dataset).
"""

def __init__(
self,
backbone_type: str,
backbone_config: DictConfig,
model_type: str,
head_configs: DictConfig,
) -> None:
"""Initialize the backbone and head based on the backbone_config."""
super().__init__()
self.backbone_type = backbone_type
self.backbone_config = backbone_config
self.model_type = model_type
self.head_configs = head_configs

self.heads = []
if self.model_type == "single_instance":
for d_num, _ in self.head_configs.confmaps.items():
self.heads.append(
SingleInstanceConfmapsHead(**self.head_configs.confmaps[d_num])
)

elif self.model_type == "centered_instance":
for d_num, _ in self.head_configs.confmaps.items():
self.heads.append(
CenteredInstanceConfmapsHead(**self.head_configs.confmaps[d_num])
)

elif self.model_type == "centroid":
centroid_confmaps = self.head_configs.confmaps[0].copy()
centroid_confmaps.anchor_part = None
self.heads.append(CentroidConfmapsHead(**centroid_confmaps))

elif self.model_type == "bottomup":
for d_num, _ in self.head_configs.confmaps.items():
self.heads.append(
MultiInstanceConfmapsHead(**self.head_configs.confmaps[d_num])
)
for d_num, _ in self.head_configs.pafs.items():
self.heads.append(
PartAffinityFieldsHead(**self.head_configs.pafs[d_num])
)

else:
message = f"{self.model_type} is not a defined model type. Please choose one of `single_instance`, `centered_instance`, `centroid`, `bottomup`."
logger.error(message)
raise Exception(message)

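# find the finest (smallest) output stride requested by any head; the backbone decoder must reach at least this resolution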
output_strides = []
for head_type in head_configs:
head_config = head_configs[head_type]
output_strides.extend(
[head_config[cfg].output_stride for cfg in head_config]
)

min_output_stride = min(output_strides)
min_output_stride = min(min_output_stride, self.backbone_config.output_stride)

self.backbone = get_backbone(
self.backbone_type,
backbone_config,
)

strides = self.backbone.dec.current_strides
self.head_layers = nn.ModuleList([])
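# each head taps the decoder level matching its output stride; coarser strides map to earlier levels, whose channel counts grow by filters_rate per level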
for head in self.heads:
in_channels = int(
round(
self.backbone.max_channels
/ (
self.backbone_config.filters_rate
** len(self.backbone.dec.decoder_stack)
)
)
)
if head.output_stride != min_output_stride:
factor = strides.index(min_output_stride) - strides.index(
head.output_stride
)
in_channels = in_channels * (self.backbone_config.filters_rate**factor)
self.head_layers.append(head.make_head(x_in=int(in_channels)))

@classmethod
def from_config(
cls,
backbone_type: str,
backbone_config: DictConfig,
model_type: str,
head_configs: DictConfig,
) -> "MultiHeadModel":
"""Create the model from a config dictionary."""
return cls(
backbone_type=backbone_type,
backbone_config=backbone_config,
model_type=model_type,
head_configs=head_configs,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass through the model.

Args:
x: Input image.

Returns:
A dictionary mapping each head name to a list of confidence maps, one per
skeleton format (in the order the datasets appear in the config).
"""
backbone_outputs = self.backbone(x)

outputs = defaultdict(list)
for head, head_layer in zip(self.heads, self.head_layers):
idx = backbone_outputs["strides"].index(head.output_stride)
outputs[head.name].append(
head_layer(backbone_outputs["outputs"][idx])
) # e.g.: outputs = {"SingleInstanceConfmapsHead": [output_head_0, output_head_1, output_head_2, ...]}

return outputs
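
A hypothetical usage sketch tying the class to the config file above: build a centered-instance multi-head model and run one forward pass. The input shape is arbitrary, and the sketch assumes the backbone config carries every field the constructor reads (e.g. `output_stride`):

import torch
from omegaconf import OmegaConf

cfg = OmegaConf.load("docs/config_slumbr.yaml")
model = MultiHeadModel.from_config(
    backbone_type=cfg.model_config.backbone_type,
    backbone_config=cfg.model_config.backbone_config,
    model_type="centered_instance",
    head_configs=cfg.model_config.head_configs.centered_instance,
)

# One grayscale crop (`in_channels: 1` in the config above).
x = torch.rand(1, 1, 160, 160)
outputs = model(x)

# One confmap tensor per dataset, in config order.
print(len(outputs["CenteredInstanceConfmapsHead"]))  # -> 3
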
13 changes: 12 additions & 1 deletion sleap_nn/inference/bottomup.py
@@ -47,6 +47,9 @@ class BottomUpInferenceModel(L.LightningModule):
representation used during instance grouping.
input_scale: Float indicating if the images should be resized before being
passed to the model.
output_head_skeleton_num: Dataset number (as given in the config) indicating
which skeleton format to output. This parameter is only required for
multi-head model inference.
"""

def __init__(
@@ -62,6 +65,7 @@ def __init__(
return_pafs: Optional[bool] = False,
return_paf_graph: Optional[bool] = False,
input_scale: float = 1.0,
output_head_skeleton_num: int = 0,
):
"""Initialise the model attributes."""
super().__init__()
@@ -76,6 +80,7 @@ def __init__(
self.return_pafs = return_pafs
self.return_paf_graph = return_paf_graph
self.input_scale = input_scale
self.output_head_skeleton_num = output_head_skeleton_num

def _generate_cms_peaks(self, cms):
peaks, peak_vals, sample_inds, peak_channel_inds = find_local_peaks(
@@ -120,7 +125,13 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
self.batch_size = inputs["image"].shape[0]
output = self.torch_model(inputs["image"])
cms = output["MultiInstanceConfmapsHead"]
pafs = output["PartAffinityFieldsHead"].permute(0, 2, 3, 1)
# for a multi-head model, the output is a list of confmaps, one per head
if isinstance(cms, list):
cms = cms[self.output_head_skeleton_num]
pafs = output["PartAffinityFieldsHead"]
if isinstance(pafs, list):
pafs = pafs[self.output_head_skeleton_num]
pafs = pafs.permute(0, 2, 3, 1)
cms_peaks, cms_peak_vals, cms_peak_channel_inds = self._generate_cms_peaks(cms)

(
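
To make the new selection logic concrete, a standalone sketch of how `output_head_skeleton_num` picks one skeleton format from a multi-head model's output; the `output` dict below mimics the shape of what the multi-head `forward` returns, with made-up tensor sizes:

import torch

# Mimic a multi-head model's output: one tensor per dataset under each head name.
output = {
    "MultiInstanceConfmapsHead": [torch.rand(1, 2, 96, 96) for _ in range(3)],
    "PartAffinityFieldsHead": [torch.rand(1, 4, 48, 48) for _ in range(3)],
}

output_head_skeleton_num = 1  # dataset index, as in the config's `dataset_mapper`

cms = output["MultiInstanceConfmapsHead"]
if isinstance(cms, list):  # single-head models return a tensor directly
    cms = cms[output_head_skeleton_num]
pafs = output["PartAffinityFieldsHead"]
if isinstance(pafs, list):
    pafs = pafs[output_head_skeleton_num]
pafs = pafs.permute(0, 2, 3, 1)  # channels-last, as the PAF grouping code expects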