Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions micro_sam/sam_annotator/image_series_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def image_series_annotator(
output_folder: str,
model_type: str = util._DEFAULT_MODEL,
embedding_path: Optional[str] = None,
segmentation_files: Optional[Union[List[Union[os.PathLike, str]], List[np.ndarray]]] = None,
tile_shape: Optional[Tuple[int, int]] = None,
halo: Optional[Tuple[int, int]] = None,
viewer: Optional["napari.viewer.Viewer"] = None,
Expand All @@ -94,6 +95,7 @@ def image_series_annotator(
model_type: The Segment Anything model to use. For details on the available models check out
https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
embedding_path: Filepath where to save the embeddings.
segmentation_files: TODO
tile_shape: Shape of tiles for tiled embedding prediction.
If `None` then the whole image is passed to Segment Anything.
halo: Shape of the overlap between tiles, which is needed to segment objects on tile boarders.
Expand All @@ -111,6 +113,11 @@ def image_series_annotator(
Returns:
The napari viewer, only returned if `return_viewer=True`.
"""
# TODO
# If segmentation_files were passed,
# then check that we have the same number of initial segmentations as images.
if segmentation_files is not None:
pass

os.makedirs(output_folder, exist_ok=True)
next_image_id = 0
Expand Down Expand Up @@ -246,11 +253,12 @@ def _next_image(viewer):


def image_folder_annotator(
input_folder: str,
input_folder: Union[os.PathLike, str],
output_folder: str,
pattern: str = "*",
viewer: Optional["napari.viewer.Viewer"] = None,
return_viewer: bool = False,
segmentation_folder: Optional[Union[os.PathLike, str]] = None,
**kwargs
) -> Optional["napari.viewer.Viewer"]:
"""Run the 2d annotation tool for a series of images in a folder.
Expand All @@ -263,14 +271,22 @@ def image_folder_annotator(
viewer: The viewer to which the SegmentAnything functionality should be added.
This enables using a pre-initialized viewer.
return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
segmentation_folder: Folder with initial segmentation results that should be loaded in the
'commited_objects' layer. This enables correcting initial segmentations obtained with another tool.
The segmentation files should have the same name as the corresponding image files.
kwargs: The keyword arguments for `micro_sam.sam_annotator.image_series_annotator`.

Returns:
The napari viewer, only returned if `return_viewer=True`.
"""
image_files = sorted(glob(os.path.join(input_folder, pattern)))
if segmentation_folder is None:
segmentation_files is None
else:
segmentation_files = sorted(glob(os.path.join(segmentation_folder, pattern)))
return image_series_annotator(
image_files, output_folder, viewer=viewer, return_viewer=return_viewer, **kwargs
image_files, output_folder, viewer=viewer, return_viewer=return_viewer,
segmentation_files=segmentation_files, **kwargs
)


Expand Down Expand Up @@ -424,6 +440,11 @@ def main():
"NOTE: It is recommended to pass this argument and store the embeddings, "
"otherwise they will be recomputed every time (which can take a long time)."
)
parser.add_argument(
"-s", "--segmentation_folder",
help="Optional filepath to a folder with initial segmentation results."
"The files in this folder need to have the same name as the corresponding image files."
)
parser.add_argument(
"-m", "--model_type", default=util._DEFAULT_MODEL,
help=f"The segment anything model that will be used, one of {available_models}."
Expand Down Expand Up @@ -454,8 +475,9 @@ def main():

image_folder_annotator(
args.input_folder, args.output_folder, args.pattern,
embedding_path=args.embedding_path, model_type=args.model_type,
tile_shape=args.tile_shape, halo=args.halo, precompute_amg_state=args.precompute_amg_state,
embedding_path=args.embedding_path, segmentation_folder=args.segmentation_folder,
model_type=args.model_type, tile_shape=args.tile_shape,
halo=args.halo, precompute_amg_state=args.precompute_amg_state,
checkpoint_path=args.checkpoint, device=args.device, is_volumetric=args.is_volumetric,
prefer_decoder=args.prefer_decoder,
)