grater_jax.disk_model.winnie_class
winnie_class.py
JAX wrapper for spatially-varying JWST PSF convolution.
This module defines the WinniePSF class, a JAX-compatible implementation of the Winnie package (Lawson et al.) for PSF subtraction. It provides tools for constructing and applying spatially-varying PSFs across roll angles, including rotation, masking, and convolution. A helper function jax_rotate_image is also included for JAX-based image rotation.
Functions
- Generate a grid of synthetic NIRCam PSFs and associated index/throughput maps.
- jax_rotate_image – Rotate a 2D image using JAX-compatible coordinate transformations.
- map_coordinates – Map the input array to new coordinates using interpolation.
- register_pytree_node_class – Extends the set of types that are considered internal nodes in pytrees.
- resize – Image resize.
- vmap – Vectorizing map.
Classes
- WinniePSF – A class for spatially-varying PSF convolution with image cubes, including rotation and masking per roll angle.
- partial – partial(func, *args, **keywords) - new function with partial application of the given arguments and keywords.
- class grater_jax.disk_model.winnie_class.WinniePSF(psfs, psf_inds_rolls, im_mask_rolls, psf_offsets, psf_parangs, num_unique_psfs)
Bases: object
A class for spatially-varying PSF convolution with image cubes, including rotation and masking per roll angle. The class is a JAX pytree implementation of Kellen Lawson’s Winnie package for PSF subtraction.
- Parameters:
psfs (array_like) – Array of PSF kernels, shape (N, H, W), where N is the number of unique PSFs.
psf_inds_rolls (array_like) – Array of PSF index maps for each roll angle. Shape (n_rolls, H, W).
im_mask_rolls (array_like) – Array of image masks for each roll angle. Shape (n_rolls, H, W).
psf_offsets (array_like) – Array of PSF offsets per roll (currently unused).
psf_parangs (array_like) – Array of parallactic angles (in degrees) for each roll.
num_unique_psfs (int) – The number of unique PSF kernels used across the image.
- Attributes:
psfs (jax.numpy.ndarray)
psf_inds_rolls (jax.numpy.ndarray)
im_mask_rolls (jax.numpy.ndarray)
psf_offsets (jax.numpy.ndarray)
psf_parangs (jax.numpy.ndarray)
num_unique_psfs (int)
- classmethod init_db(db, roll_index=0, normalize=True, log_rscale=True, nr=5, rmax=3.1622776601683795, ntheta=4, use_coeff=True)
Generate PSF grid inputs from a spaceKLIP Database and store them in a form usable by the WinniePSF class.
- Parameters:
db (spaceKLIP.database.Database) – Loaded spaceKLIP database object.
roll_index (int, optional) – Index of roll angle to use for instrument setup (default is 0).
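normalize (bool, optional) – Whether to normalize each PSF slice (default is True).
log_rscale (bool, optional) – Whether to sample PSFs logarithmically in radius (default is True).
nr (int, optional) – Number of radial zones (default is 5).
rmax (float, optional) – Maximum radius in arcsec (default is √10 ≈ 3.162).
ntheta (int, optional) – Number of angular zones (default is 4).
use_coeff (bool, optional) – Whether to use precomputed PSF coefficients (default is True).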
- get_convolved_cube(image)
Apply spatially-varying convolution to an input image across roll angles.
- Parameters:
image (jax.numpy.ndarray) – 2D input image to convolve.
- Returns:
Convolved image cube of shape (n_rolls, H, W).
- Return type:
jax.numpy.ndarray
- convolve_cube(image, cent)
Rotate and convolve the input image for each roll angle.
- Parameters:
image (jax.numpy.ndarray) – Input 2D image.
cent (tuple of float) – Center of rotation (x, y) in pixels.
- Returns:
Image cube after convolution, shape (n_rolls, H, W).
- Return type:
jax.numpy.ndarray
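The per-roll "rotate, then convolve" pattern can be sketched as follows. This is a minimal illustration under stated assumptions, not the grater_jax source: it assumes each roll rotates the input image by its entry in psf_parangs (the sign convention is an assumption, here the direction used by scipy.ndimage.rotate), applies that roll's mask and PSF index map, and convolves each PSF zone with jax.scipy.signal.fftconvolve. The helper names rotate_about, convolve_zones and convolve_cube_sketch are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates
from jax.scipy.signal import fftconvolve


def rotate_about(image, angle_deg, cent):
    """Hypothetical helper: bilinear rotation about cent = (x, y), angle in degrees."""
    cx, cy = cent
    theta = jnp.deg2rad(angle_deg)
    c, s = jnp.cos(theta), jnp.sin(theta)
    yy, xx = jnp.indices(image.shape)
    dy, dx = yy - cy, xx - cx
    # Inverse mapping (scipy.ndimage.rotate direction, assumed): output pixel -> input position.
    y_in = cy + dx * s + dy * c
    x_in = cx + dx * c - dy * s
    return map_coordinates(image, [y_in, x_in], order=1, mode="constant", cval=0.0)


def convolve_zones(image, psf_inds, im_mask, psfs):
    """Hypothetical helper: mask, split into PSF zones, convolve each zone, sum."""
    masked = image * im_mask

    def one_zone(k):
        zone = jnp.where(psf_inds == k, masked, 0.0)
        return fftconvolve(zone, psfs[k], mode="same")

    return jax.vmap(one_zone)(jnp.arange(psfs.shape[0])).sum(axis=0)


def convolve_cube_sketch(image, cent, psfs, psf_inds_rolls, im_mask_rolls, psf_parangs):
    """Hypothetical stand-in for WinniePSF.convolve_cube: one output slice per roll."""

    def one_roll(psf_inds, im_mask, parang):
        # Rotate into this roll's frame, then apply the roll's index map and mask.
        rotated = rotate_about(image, parang, cent)
        return convolve_zones(rotated, psf_inds, im_mask, psfs)

    # vmap over the leading (roll) axis; image, cent and psfs are shared via the closure.
    return jax.vmap(one_roll)(psf_inds_rolls, im_mask_rolls, psf_parangs)
```

Mapping over the roll axis with jax.vmap keeps psf_inds_rolls, im_mask_rolls and psf_parangs aligned by index, which is presumably why the class stores them with a leading n_rolls dimension.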
- rotate_hypercube(hcube, angles, cent, cval0=0.0)
Rotate each image slice in a cube by a corresponding angle.
- Parameters:
hcube (jax.numpy.ndarray) – Input hypercube of shape (n_slices, H, W).
angles (jax.numpy.ndarray) – Array of angles (degrees) for each slice.
cent (tuple of float) – Center of rotation (x, y).
cval0 (float, optional) – Fill value for pixels that map outside the input image (default is 0.0).
- Returns:
Rotated hypercube with the same shape as input.
- Return type:
jax.numpy.ndarray
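Rotating every slice by its own angle is a natural fit for jax.vmap. The sketch below is a hypothetical stand-in (rotate_hypercube_sketch and rotate_about are illustrative names, not the package's code) and assumes bilinear interpolation with jax.scipy.ndimage.map_coordinates and the rotation direction of scipy.ndimage.rotate.

```python
import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates


def rotate_about(image, angle_deg, cent, cval=0.0):
    """Hypothetical helper: bilinear rotation of one 2D slice about cent = (x, y)."""
    cx, cy = cent
    theta = jnp.deg2rad(angle_deg)
    c, s = jnp.cos(theta), jnp.sin(theta)
    yy, xx = jnp.indices(image.shape)
    dy, dx = yy - cy, xx - cx
    # Inverse mapping (scipy.ndimage.rotate direction, assumed): output pixel -> input position.
    y_in = cy + dx * s + dy * c
    x_in = cx + dx * c - dy * s
    return map_coordinates(image, [y_in, x_in], order=1, mode="constant", cval=cval)


def rotate_hypercube_sketch(hcube, angles, cent, cval0=0.0):
    """Hypothetical stand-in for rotate_hypercube: slice i is rotated by angles[i]."""
    # vmap maps over axis 0 of hcube and angles together; cent and cval0 are
    # captured by the closure, so every slice shares them.
    return jax.vmap(lambda im, ang: rotate_about(im, ang, cent, cval0))(hcube, angles)
```

Using vmap instead of a Python loop keeps the whole cube a single traced operation, which composes with jax.jit and jax.grad.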
- convolve_with_spatial_psfs(image, psf_inds, im_mask)
Convolve an image using spatially-varying PSFs based on a mask and index map.
- Parameters:
image (jax.numpy.ndarray) – The image to convolve.
psf_inds (jax.numpy.ndarray) – Index map indicating which PSF to use at each pixel.
im_mask (jax.numpy.ndarray) – Binary mask to apply before convolution.
- Returns:
Convolved image.
- Return type:
jax.numpy.ndarray
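One common way to implement spatially varying convolution from an index map, and a plausible reading of the parameters above, is to split the masked image into zones by PSF index, convolve each zone with its own kernel, and sum the results; this is valid because convolution is linear. The sketch assumes psfs has shape (N, h, w) with each kernel centred in its slice and uses jax.scipy.signal.fftconvolve. The function name is hypothetical and this is not the grater_jax source.

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import fftconvolve


def convolve_with_spatial_psfs_sketch(image, psf_inds, im_mask, psfs):
    """Hypothetical sketch: each pixel's flux is spread by the PSF its index selects."""
    masked = image * im_mask  # apply the mask before convolving

    def one_zone(k):
        # Keep only the pixels assigned to PSF k, then convolve them with that PSF.
        zone = jnp.where(psf_inds == k, masked, 0.0)
        return fftconvolve(zone, psfs[k], mode="same")

    # Convolution is linear, so summing the per-zone results gives the full image.
    return jax.vmap(one_zone)(jnp.arange(psfs.shape[0])).sum(axis=0)
```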
- tree_flatten()
Flatten the object for use with JAX PyTrees.
- Returns:
A tuple (children, aux_data) where children are JAX-traceable arrays and aux_data is any other metadata.
- Return type:
tuple
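tree_flatten is half of JAX's custom-pytree protocol: jax.tree_util.register_pytree_node_class also requires a tree_unflatten classmethod that rebuilds the object from (aux_data, children). The sketch below shows the pattern on a hypothetical, cut-down class (TinyPSFModel). Which WinniePSF fields are children versus aux_data is an assumption here: arrays are traced children, and the integer num_unique_psfs is static aux_data so it can size jnp.arange under jit.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class TinyPSFModel:
    """Hypothetical stand-in for WinniePSF illustrating the pytree protocol."""

    def __init__(self, psfs, psf_parangs, num_unique_psfs):
        self.psfs = psfs
        self.psf_parangs = psf_parangs
        self.num_unique_psfs = num_unique_psfs

    def tree_flatten(self):
        children = (self.psfs, self.psf_parangs)  # traced by jit/vmap/grad
        aux_data = (self.num_unique_psfs,)        # static, must be hashable
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, *aux_data)

    def total_flux(self):
        return jnp.sum(self.psfs)

    def zone_ids(self):
        # Works under jit because num_unique_psfs stays a Python int.
        return jnp.arange(self.num_unique_psfs)


@jax.jit
def flux(model):
    return model.total_flux()
```

Because the object is a pytree, it can be passed directly to jitted functions and transformed with jax.tree_util.tree_map.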
- grater_jax.disk_model.winnie_class.jax_rotate_image(image, angle, reshape=False, cval=0.0, order=1)
Rotate a 2D image using JAX-compatible coordinate transformations.
- Parameters:
image (jax.numpy.ndarray) – Input 2D image.
angle (float) – Rotation angle in degrees.
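reshape (bool, optional) – Whether to reshape the output to fit the entire rotated image (currently not used; the output keeps the input shape).
cval (float, optional) – Constant value for points outside the input boundaries (default is 0.0).
order (int, optional) – Interpolation order: 0 = nearest neighbour, 1 = linear (default is 1). JAX's map_coordinates implements only these two orders.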
- Returns:
Rotated image of the same shape.
- Return type:
jax.numpy.ndarray
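Rotation by "JAX-compatible coordinate transformations" can be sketched as inverse coordinate mapping: for every output pixel, compute where it came from in the input and interpolate there with jax.scipy.ndimage.map_coordinates. This is an illustration, not the grater_jax implementation; it assumes rotation about the array's geometric centre and the direction convention of scipy.ndimage.rotate, and the name rotate_image_sketch is hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates


def rotate_image_sketch(image, angle, cval=0.0, order=1):
    """Hypothetical sketch: rotate a 2D image by `angle` degrees about its centre.

    Direction follows scipy.ndimage.rotate (counter-clockwise when displayed
    with origin='upper'); the output keeps the input shape.
    """
    ny, nx = image.shape
    cy, cx = (ny - 1) / 2.0, (nx - 1) / 2.0
    theta = jnp.deg2rad(angle)
    c, s = jnp.cos(theta), jnp.sin(theta)
    yy, xx = jnp.indices(image.shape)
    dy, dx = yy - cy, xx - cx
    # Inverse mapping: locate each output pixel's source position in the input.
    y_in = cy + dx * s + dy * c
    x_in = cx + dx * c - dy * s
    # order must be 0 or 1 for JAX; points outside the input are filled with cval.
    return map_coordinates(image, [y_in, x_in], order=order, mode="constant", cval=cval)
```

Since every step is a jax.numpy operation, the linear-interpolation result is differentiable with respect to both the image and the angle.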