Merged
1 change: 1 addition & 0 deletions changelog/249.feature.rst
@@ -0,0 +1 @@
Adds support for multiprocessing in PSF generation, along with changes to patch cleanup.
128 changes: 53 additions & 75 deletions regularizepsf/builder.py
@@ -1,15 +1,15 @@
"""Functions for building PSF models from images."""

import pathlib
import multiprocessing
from collections.abc import Generator

import numpy as np
import sep
from astropy.io import fits
from scipy.interpolate import RectBivariateSpline
from scipy.ndimage import binary_dilation, binary_erosion, label
from skimage.transform import downscale_local_mean

from regularizepsf.exceptions import IncorrectShapeError, PSFBuilderError
from regularizepsf.image_processing import calculate_background, process_single_image
from regularizepsf.psf import ArrayPSF
from regularizepsf.util import IndexedCube, calculate_covering

@@ -34,71 +34,14 @@ def generator() -> np.ndarray:
elif isinstance(images, list) and (isinstance(images[0], str) or isinstance(images[0], pathlib.Path)):
def generator() -> np.ndarray:
for image_path in images:
with fits.open(image_path) as hdul:
yield hdul[hdu_choice].data.astype(float)
yield image_path
data_iterator = generator()
else:
msg = "Unsupported type for `images`"
raise TypeError(msg)

return data_iterator

def _scale_image(image, interpolation_scale):
interpolator = RectBivariateSpline(np.arange(image.shape[0]),
np.arange(image.shape[1]),
image)
image = interpolator(np.linspace(0,
image.shape[0] - 1,
1 + (image.shape[0] - 1) * interpolation_scale),
np.linspace(0,
image.shape[1] - 1,
1 + (image.shape[1] - 1) * interpolation_scale))
return image

def _find_patches(image, star_threshold, star_mask, interpolation_scale, psf_size, i,
saturation_threshold: float = np.inf, image_mask: np.ndarray | None = None):
background = sep.Background(image)
image_background_removed = image - background
image_star_coords = sep.extract(image_background_removed,
star_threshold,
err=background.globalrms,
mask=star_mask)

coordinates = [(i,
int(round(x - psf_size * interpolation_scale / 2)),
int(round(y - psf_size * interpolation_scale / 2)))
for x, y in zip(image_star_coords["y"], image_star_coords["x"], strict=True)]

# pad in case someone selects a region on the edge of the image
padding_shape = ((psf_size * interpolation_scale, psf_size * interpolation_scale),
(psf_size * interpolation_scale, psf_size * interpolation_scale))
padded_image = np.pad(image_background_removed,
padding_shape,
mode="reflect")

# the mask indicates which pixel should be ignored in the calculation
if image_mask is not None:
padded_mask = np.pad(image_mask, padding_shape, mode='reflect')
else: # if no mask is provided, we create an empty mask
padded_mask = np.zeros_like(padded_image, dtype=bool)

patches = {}
for coordinate in coordinates:
patch = padded_image[coordinate[1] + interpolation_scale * psf_size:
coordinate[1] + 2 * interpolation_scale * psf_size,
coordinate[2] + interpolation_scale * psf_size:
coordinate[2] + 2 * interpolation_scale * psf_size]
mask_patch = padded_mask[coordinate[1] + interpolation_scale * psf_size:
coordinate[1] + 2 * interpolation_scale * psf_size,
coordinate[2] + interpolation_scale * psf_size:
coordinate[2] + 2 * interpolation_scale * psf_size]

# we do not add patches that have saturated pixels
if np.all(patch < saturation_threshold):
patch[mask_patch] = np.nan
patches[coordinate] = patch
return patches

def _find_matches(coordinate, x_bounds, y_bounds, psf_size):
center_x = coordinate[1] + psf_size // 2
center_y = coordinate[2] + psf_size // 2
@@ -140,6 +83,8 @@ def _average_patches_by_percentile(patches, corners, x_bounds, y_bounds, psf_siz
counts = {tuple(corner): 0 for corner in corners}

for coordinate, patch in patches.items():
if not isinstance(patch, np.ndarray):
continue
Member: When would this scenario occur? Can you add a comment for clarity?

Member Author: I put this in when testing the multiprocessing (and it did not originally return patches). No longer needed; removed.

patch = patch / patch[psf_size // 2, psf_size // 2] # normalize so the star brightness is always 1
match_indices = _find_matches(coordinate, x_bounds, y_bounds, psf_size)

@@ -178,6 +123,7 @@ def _average_patches(patches, corners, method='mean', percentile: float=None):

return averages, counts


class ArrayPSFBuilder:
"""A builder that will take a series of images and construct an ArrayPSF to represent their implicit PSF."""

@@ -193,19 +139,25 @@ def build(self,
images: list[str] | list[pathlib.Path] | np.ndarray | Generator,
sep_mask: list[str] | list[pathlib.Path] | np.ndarray | Generator | None = None,
hdu_choice: int | None = 0,
num_workers: int | None = None,
interpolation_scale: int = 1,
star_threshold: int = 3,
average_method: str = 'median',
percentile: float = 50,
saturation_threshold: float = np.inf,
image_mask: np.ndarray | None = None,
star_minimum: float = 0,
star_maximum: float = np.inf,
sqrt_compressed: bool = False,
return_patches: bool = False) -> (ArrayPSF, dict):
Member: I know the other parameters aren't documented, but could you include the new ones you added in the docstring? We really should include all of them. You'll get an extra prize if you do, haha.

"""Build the PSF model.

Parameters
----------
images : list[pathlib.Path] | np.ndarray | Generator
images to use
num_workers : int | None
Number of worker processes. None uses all available CPUs.

Returns
-------
@@ -223,24 +175,27 @@ def generator() -> None:
else:
mask_iterator = _convert_to_generator(sep_mask, hdu_choice=hdu_choice)

# We'll store the first image's shape, and then make sure the others match.
image_shape = None
args = [
(i, image, star_mask, interpolation_scale, self.psf_size,
star_threshold, saturation_threshold, image_mask, hdu_choice,
star_minimum, star_maximum, sqrt_compressed)
for i, (image, star_mask) in enumerate(zip(data_iterator, mask_iterator))
]

with multiprocessing.Pool(processes = num_workers) as pool:
results = pool.map(process_single_image, args)

patches = {}
for i, (image, star_mask) in enumerate(zip(data_iterator, mask_iterator, strict=False)):
image_shape = None
for patch, data_shape in results:
if image_shape is None:
image_shape = image.shape
elif image.shape != image_shape:
image_shape = data_shape
elif image_shape != data_shape:
msg = ("Images must all be the same shape."
f"Found both {image_shape} and {image.shape}.")
f"Found both {image_shape} and {data_shape}.")
raise PSFBuilderError(msg)

# if the image should be scaled then, do the scaling before anything else
if interpolation_scale != 1:
image = _scale_image(image, interpolation_scale=1)

# find stars using SEP
patches.update(_find_patches(image, star_threshold, star_mask, interpolation_scale, self.psf_size, i,
saturation_threshold, image_mask))
patches.update(patch)

corners = calculate_covering((image_shape[0] * interpolation_scale,
image_shape[1] * interpolation_scale),
@@ -254,7 +209,30 @@ def generator() -> None:
if interpolation_scale != 1:
this_patch = downscale_local_mean(this_patch,(interpolation_scale, interpolation_scale))
values_coords.append(coordinate)
values_array[i, :, :] = this_patch / np.nansum(this_patch)

this_background = calculate_background(this_patch)
this_patch -= this_background

this_patch[this_patch == 0] = np.nan

this_value = this_patch[this_patch.shape[0]//2, this_patch.shape[1]//2]
Member: Can you rename this variable to patch_central_value for clarity?

this_value_mask = this_patch < (0.005 * this_value)
this_value_mask = binary_erosion(this_value_mask, border_value = 1)

this_patch[this_value_mask] = np.nan

patch_zeroed = np.copy(this_patch)
patch_zeroed[~np.isfinite(patch_zeroed)] = 0

patch_lab = label(patch_zeroed)[0]
Member: Just for clarity, let's call this something like labeled_patch or patch_labels.

psf_core_mask = patch_lab == patch_lab[patch_lab.shape[0]//2,patch_lab.shape[1]//2]

psf_core_mask = binary_dilation(psf_core_mask)

fixed_patch = patch_zeroed * psf_core_mask
fixed_patch = fixed_patch / np.nansum(fixed_patch)

values_array[i,:,:] = fixed_patch

if return_patches:
return ArrayPSF(IndexedCube(values_coords, values_array)), counts, patches
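As a quick usage sketch of the new multiprocessing path above (the ArrayPSFBuilder constructor argument and the file names below are assumptions for illustration; they are not shown in this diff):

from regularizepsf.builder import ArrayPSFBuilder

# assumed: the builder is constructed with the PSF patch size in pixels
builder = ArrayPSFBuilder(32)

# build from a list of FITS paths; num_workers=None would use all available CPUs
psf_model, star_counts = builder.build(
    ["image_000.fits", "image_001.fits"],  # hypothetical input files
    num_workers=4,
)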
7 changes: 3 additions & 4 deletions regularizepsf/exceptions.py
@@ -4,21 +4,20 @@
class RegularizePSFError(Exception):
"""Base class for regularizepsf exceptions."""


class InvalidCoordinateError(RegularizePSFError):
"""The key for this coordinate does not exist in the model."""


class IncorrectShapeError(RegularizePSFError):
"""The shapes do not match for the model and the value."""


class InvalidFunctionError(RegularizePSFError):
"""Function for functional model has invalid parameters."""


class FunctionParameterMismatchError(RegularizePSFError):
"""Function evaluated with nonexistent kwargs."""

class PSFBuilderError(RegularizePSFError):
"""Something went wrong building the PSF model."""

class InvalidDataError(RegularizePSFError):
"""Invalid input data provided for PSF generation."""
141 changes: 141 additions & 0 deletions regularizepsf/image_processing.py
@@ -0,0 +1,141 @@
import pathlib

import numpy as np
import scipy
import sep
from astropy.io import fits
from scipy.interpolate import RectBivariateSpline
from scipy.ndimage import binary_dilation, binary_erosion

from regularizepsf.exceptions import InvalidDataError


def calculate_background(patch: np.ndarray) -> np.ndarray:
patch_y, patch_x = np.indices(patch.shape)

mask = patch != 0
mask = binary_erosion(mask)
mask[0, :] = False
mask[-1, :] = False
mask[:, 0] = False
mask[:, -1] = False
mask = binary_dilation(mask) & ~mask

value_center = patch[patch.shape[1]//2, patch.shape[0]//2]
mask = mask & (patch < value_center)

A = np.c_[patch_x[mask], patch_y[mask], np.ones_like(patch_x[mask])]
coefficients, _, _, _ = scipy.linalg.lstsq(A, patch[mask])
background = coefficients[0] * patch_x + coefficients[1] * patch_y + coefficients[2]
background[patch == 0] = np.nan

return background
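# Illustrative sanity check (hypothetical, not part of this module): for a purely
# planar patch with no zero pixels, the least-squares plane fit over the boundary
# ring reproduces the input plane, so the fitted background subtracts to ~zero:
#     ramp = np.outer(np.linspace(1.0, 2.0, 32), np.ones(32))
#     residual = ramp - calculate_background(ramp)
#     assert np.allclose(residual, 0.0, atol=1e-8)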


def _scale_image(image, interpolation_scale, hdu_choice):
interpolator = RectBivariateSpline(np.arange(image.shape[0]),
np.arange(image.shape[1]),
image)
image = interpolator(np.linspace(0,
image.shape[0] - 1,
1 + (image.shape[0] - 1) * interpolation_scale),
np.linspace(0,
image.shape[1] - 1,
1 + (image.shape[1] - 1) * interpolation_scale))
return image


def _find_patches(image, star_threshold, star_mask, interpolation_scale, psf_size, i,
saturation_threshold: float = np.inf, image_mask: np.ndarray | None = None,
star_minimum: float = 0, star_maximum: float = np.inf):
background = sep.Background(image)
image_background_removed = image - background

try:
image_star_coords = sep.extract(image_background_removed,
star_threshold,
err=background.globalrms,
mask=star_mask)
except Exception:
return {"x":[], "y":[]}

coordinates = [(i,
int(round(x - psf_size * interpolation_scale / 2)),
int(round(y - psf_size * interpolation_scale / 2)))
for x, y in zip(image_star_coords["y"], image_star_coords["x"], strict=True)]

# pad in case someone selects a region on the edge of the image
padding_shape = ((psf_size * interpolation_scale, psf_size * interpolation_scale),
(psf_size * interpolation_scale, psf_size * interpolation_scale))
padded_image = np.pad(image,
padding_shape,
mode="reflect")

# the mask indicates which pixel should be ignored in the calculation
if image_mask is not None:
padded_mask = np.pad(image_mask, padding_shape, mode='reflect')
else: # if no mask is provided, we create an empty mask
padded_mask = np.zeros_like(padded_image, dtype=bool)

patches = {}
for coordinate in coordinates:
patch = padded_image[coordinate[1] + interpolation_scale * psf_size:
coordinate[1] + 2 * interpolation_scale * psf_size,
coordinate[2] + interpolation_scale * psf_size:
coordinate[2] + 2 * interpolation_scale * psf_size]
mask_patch = padded_mask[coordinate[1] + interpolation_scale * psf_size:
coordinate[1] + 2 * interpolation_scale * psf_size,
coordinate[2] + interpolation_scale * psf_size:
coordinate[2] + 2 * interpolation_scale * psf_size]

# Separately background subtract each patch
background_patch = calculate_background(patch)
patch_background_subtracted = patch - background_patch
patch_background_subtracted[patch == 0] = np.nan

# we do not add patches that have saturated pixels
if np.all(patch_background_subtracted < saturation_threshold):
patch_background_subtracted[mask_patch] = np.nan
patches[coordinate] = patch_background_subtracted

# # we do not add patches that have central stars outside of our defined limits
center = (patch_background_subtracted.shape[1] // 2, patch_background_subtracted.shape[0] // 2)
if (patch_background_subtracted[center] < star_minimum) | (patch_background_subtracted[center] > star_maximum):
patch_background_subtracted[mask_patch] = np.nan
patches[coordinate] = patch_background_subtracted
return patches


def process_single_image(args):
"""Process a single image to extract patches.
Parameters
----------
args : tuple
Tuple containing (i, image, star_mask, interpolation_scale, psf_size,
star_threshold, saturation_threshold, image_mask, hdu_choice,
star_minimum, star_maximum, sqrt_compressed)
Returns
-------
tuple
Dictionary of patches found in the image, plus the shape of the image data
"""
i, image, star_mask, interpolation_scale, psf_size, star_threshold, saturation_threshold, image_mask, hdu_choice, star_minimum, star_maximum, sqrt_compressed = args

if isinstance(image, (str, pathlib.Path)):
with fits.open(image) as hdul:
header = hdul[hdu_choice].header
if sqrt_compressed:
data = ((hdul[hdu_choice].data.astype(float))**2)/header['SCALE']
else:
data = hdul[hdu_choice].data.astype(float)
elif isinstance(image, np.ndarray):
data = image
else:
raise InvalidDataError

if interpolation_scale != 1:
data = _scale_image(data, interpolation_scale=interpolation_scale, hdu_choice=hdu_choice)

return _find_patches(data, star_threshold, star_mask, interpolation_scale, psf_size, i,
saturation_threshold, image_mask, star_minimum, star_maximum), data.shape