Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 182 additions & 0 deletions docs/examples/models/phase_thick_3d_sector_illumination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
"""
phase thick 3d with sector illumination
========================================

# 3D phase reconstruction with oblique sector illumination
# This example demonstrates multi-channel phase reconstruction where each channel
# corresponds to a different illumination sector angle.
"""

import napari
import numpy as np
import torch

from waveorder.models import phase_thick_3d

# Parameters
# all lengths must use consistent units e.g. um
simulation_arguments = {
"zyx_shape": (100, 256, 256),
"yx_pixel_size": 6.5 / 63,
"z_pixel_size": 0.25,
"index_of_refraction_media": 1.3,
}
phantom_arguments = {"index_of_refraction_sample": 1.50, "sphere_radius": 5}
transfer_function_arguments = {
"z_padding": 0,
"wavelength_illumination": 0.532,
"numerical_aperture_illumination": 0.9,
"numerical_aperture_detection": 1.2,
}

# Define 9 sector illumination angles
# 8 sectors at 45-degree intervals + 1 full aperture
sector_angle = 45
sector_angle_offset = -22.5
illumination_sector_angles = [
(
(i * sector_angle + sector_angle_offset) % 360,
((i + 1) * sector_angle + sector_angle_offset) % 360,
)
for i in range(8)
] + [(0, 360)]

print(f"Using {illumination_sector_angles} illumination sectors")

# Create a phantom
zyx_phase = phase_thick_3d.generate_test_phantom(
**simulation_arguments, **phantom_arguments
)

# Calculate multi-channel transfer function (one for each sector)
(
real_potential_transfer_function,
imag_potential_transfer_function,
) = phase_thick_3d.calculate_transfer_function(
**simulation_arguments,
**transfer_function_arguments,
illumination_sector_angles=illumination_sector_angles,
)

print(
f"Transfer function shape: {real_potential_transfer_function.shape}"
) # Should be (C, Z, Y, X)

# Display complete multi-channel transfer function
viewer = napari.Viewer()
zyx_scale = np.array(
[
simulation_arguments["z_pixel_size"],
simulation_arguments["yx_pixel_size"],
simulation_arguments["yx_pixel_size"],
]
)

# Add full CZYX transfer function (imaginary part) as single 4D layer
# Match the visualization style from add_transfer_function_to_viewer
czyx_shape = imag_potential_transfer_function.shape
voxel_scale = np.array(
[
czyx_shape[1] * zyx_scale[0], # Z extent
czyx_shape[2] * zyx_scale[1], # Y extent
czyx_shape[3] * zyx_scale[2], # X extent
]
)
lim = 0.5 * torch.max(torch.abs(imag_potential_transfer_function)).item()

viewer.add_image(
torch.fft.ifftshift(
torch.imag(imag_potential_transfer_function), dim=(-3, -2, -1)
)
.cpu()
.numpy(),
name="Imag pot. TF (CZYX)",
colormap="bwr",
contrast_limits=(-lim, lim),
scale=(1,) + tuple(1 / voxel_scale), # No scaling on C dimension
)

# Set up XZ view with C and Y as sliders
viewer.dims.order = [0, 2, 1, 3] # (C, Y, Z, X) for XZ display
viewer.dims.current_step = (
0,
czyx_shape[1] // 2,
czyx_shape[2] // 2,
czyx_shape[3] // 2,
)

input(
"Showing CZYX OTF in XZ view (use C and Y sliders). Press <enter> to continue..."
)
viewer.layers.select_all()
viewer.layers.remove_selected()

# Simulate multi-channel data (one channel per sector)
# In practice, these would come from your microscope as separate acquisitions
zyx_data_multi_channel = []
for c in range(len(illumination_sector_angles)):
zyx_data_channel = phase_thick_3d.apply_transfer_function(
zyx_phase,
real_potential_transfer_function[c],
transfer_function_arguments["z_padding"],
brightness=1e3,
)
zyx_data_multi_channel.append(zyx_data_channel)

# Stack into (C, Z, Y, X) tensor
zyx_data_multi_channel = torch.stack(zyx_data_multi_channel, dim=0)
print(f"Multi-channel data shape: {zyx_data_multi_channel.shape}")

# Reconstruct phase from all channels combined
zyx_recon = phase_thick_3d.apply_inverse_transfer_function(
zyx_data_multi_channel,
real_potential_transfer_function,
imag_potential_transfer_function,
transfer_function_arguments["z_padding"],
)

# Display
viewer.add_image(zyx_phase.numpy(), name="Phantom", scale=zyx_scale)
viewer.add_image(
zyx_data_multi_channel.numpy(),
name="Data (CZYX)",
scale=zyx_scale,
)
viewer.add_image(zyx_recon.numpy(), name="Reconstruction", scale=zyx_scale)

# Show comparison with single channel (full aperture) for reference
print("\nComparing with single-channel (full aperture) reconstruction...")
(
real_tf_single,
imag_tf_single,
) = phase_thick_3d.calculate_transfer_function(
**simulation_arguments,
**transfer_function_arguments,
illumination_sector_angles=None, # Full aperture
)
zyx_data_single = phase_thick_3d.apply_transfer_function(
zyx_phase,
real_tf_single[0], # Single channel
transfer_function_arguments["z_padding"],
brightness=1e3,
)
zyx_recon_single = phase_thick_3d.apply_inverse_transfer_function(
zyx_data_single[None, ...], # Add channel dimension
real_tf_single,
imag_tf_single,
transfer_function_arguments["z_padding"],
)
viewer.add_image(
zyx_recon_single.numpy(),
name="Reconstruction (single channel)",
scale=zyx_scale,
)

print(
f"\nReconstruction error (multi-channel): {torch.mean(torch.abs(zyx_recon - zyx_phase)).item():.6f}"
)
print(
f"Reconstruction error (single channel): {torch.mean(torch.abs(zyx_recon_single - zyx_phase)).item():.6f}"
)

input("\nShowing phantom, data, and reconstructions. Press <enter> to quit...")
10 changes: 5 additions & 5 deletions waveorder/cli/apply_inverse_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,19 +92,19 @@ def phase(

# [phase only, 3]
elif recon_dim == 3:
# Load transfer functions
# Load transfer functions (keep channel dimension)
real_potential_transfer_function = torch.tensor(
transfer_function_dataset["real_potential_transfer_function"][0, 0]
transfer_function_dataset["real_potential_transfer_function"][0]
)
imaginary_potential_transfer_function = torch.tensor(
transfer_function_dataset["imaginary_potential_transfer_function"][
0, 0
0
]
)

# Apply
# Apply (pass full CZYX data)
output = phase_thick_3d.apply_inverse_transfer_function(
czyx_data[0],
czyx_data,
real_potential_transfer_function,
imaginary_potential_transfer_function,
z_padding=settings_phase.transfer_function.z_padding,
Expand Down
11 changes: 5 additions & 6 deletions waveorder/cli/compute_transfer_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,14 +211,12 @@ def generate_and_save_phase_transfer_function(
# Save
dataset.create_image(
"real_potential_transfer_function",
real_potential_transfer_function.cpu().numpy()[None, None, ...],
real_potential_transfer_function.cpu().numpy()[None, ...],
chunks=(1, 1, 1, zyx_shape[1], zyx_shape[2]),
)
dataset.create_image(
"imaginary_potential_transfer_function",
imaginary_potential_transfer_function.cpu().numpy()[
None, None, ...
],
imaginary_potential_transfer_function.cpu().numpy()[None, ...],
chunks=(1, 1, 1, zyx_shape[1], zyx_shape[2]),
)

Expand Down Expand Up @@ -367,14 +365,15 @@ def compute_transfer_function_cli(
print("Found z_focus_offset:", z_focus_offset)

# Prepare output dataset
num_channels = (
num_input_channel = len(settings.input_channel_names)
num_output_channels = (
2 if settings.reconstruction_dimension == 2 else 1
) # space for SVD
output_dataset = open_ome_zarr(
output_dirpath,
layout="fov",
mode="w",
channel_names=num_channels * ["None"],
channel_names=num_input_channel * num_output_channels * ["None"],
)

# Pass settings to appropriate calculate_transfer_function and save
Expand Down
68 changes: 55 additions & 13 deletions waveorder/cli/settings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import warnings
from pathlib import Path
from typing import List, Literal, Optional, Union
from typing import List, Literal, Optional, Tuple, Union

from pydantic.v1 import (
BaseModel,
Expand Down Expand Up @@ -100,6 +100,7 @@ class PhaseTransferFunctionSettings(
):
numerical_aperture_illumination: NonNegativeFloat = 0.5
invert_phase_contrast: bool = False
illumination_sector_angles: Optional[List[Tuple[float, float]]] = None

@validator("numerical_aperture_illumination")
def na_ill(cls, v, values):
Expand All @@ -110,6 +111,27 @@ def na_ill(cls, v, values):
)
return v

@validator("illumination_sector_angles")
def validate_sector_angles(cls, v):
if v is None:
return v
if len(v) == 0:
raise ValueError(
"illumination_sector_angles must contain at least one sector"
)
normalized = []
for start, end in v:
if start >= end:
raise ValueError(
f"Sector start angle {start} must be less than end angle {end}"
)
# Normalize angles to [0, 360) using modulo 360
# Special case: preserve 360 for full aperture (don't reduce to 0)
normalized_start = start % 360
normalized_end = end % 360 if end % 360 != 0 else 360
normalized.append((normalized_start, normalized_end))
return normalized


class FluorescenceTransferFunctionSettings(FourierTransferFunctionSettings):
wavelength_emission: PositiveFloat = 0.507
Expand Down Expand Up @@ -171,17 +193,37 @@ def validate_reconstruction_types(cls, values):
'"fluorescence" cannot be present alongside "birefringence" or "phase". Please use one configuration file for a "fluorescence" reconstruction and another configuration file for a "birefringence" and/or "phase" reconstructions.'
)
num_channel_names = len(values.get("input_channel_names"))
if values.get("birefringence") is None:
if (
values.get("phase") is None
and values.get("fluorescence") is None
):
raise ValueError(
"Provide settings for either birefringence, phase, birefringence + phase, or fluorescence."
)
if num_channel_names != 1:
raise ValueError(
f"{num_channel_names} channels names provided. Please provide a single channel for fluorescence/phase reconstructions."
)

# Check for sector illumination in phase reconstruction
phase_settings = values.get("phase")
if phase_settings is not None:
sector_angles = (
phase_settings.transfer_function.illumination_sector_angles
)
if sector_angles is not None:
# Multi-channel reconstruction with sector illumination
if len(sector_angles) != num_channel_names:
raise ValueError(
f"Number of illumination_sector_angles ({len(sector_angles)}) must match number of input_channel_names ({num_channel_names})"
)
else:
# Single channel phase reconstruction without sector illumination
if (
values.get("birefringence") is None
and num_channel_names != 1
):
raise ValueError(
f"{num_channel_names} channels names provided. Please provide a single channel for phase reconstructions without sector illumination."
)
else:
if values.get("birefringence") is None:
if values.get("fluorescence") is None:
raise ValueError(
"Provide settings for either birefringence, phase, birefringence + phase, or fluorescence."
)
if num_channel_names != 1:
raise ValueError(
f"{num_channel_names} channels names provided. Please provide a single channel for fluorescence reconstructions."
)

return values
Loading
Loading