grater_jax.disk_model.winnie_class

winnie_class.py

JAX wrapper for spatially-varying JWST PSF convolution.

This module defines the WinniePSF class, a JAX-compatible implementation of the Winnie package (Lawson et al.) for PSF subtraction. It provides tools for constructing and applying spatially-varying PSFs across roll angles, including rotation, masking, and convolution. A helper function jax_rotate_image is also included for JAX-based image rotation.

Functions

generate_nircam_psf_grid(inst, coron_cents)

Generate a grid of synthetic NIRCam PSFs and associated index/throughput maps.

jax_rotate_image(image, angle[, reshape, ...])

Rotate a 2D image using JAX-compatible coordinate transformations.

Classes

WinniePSF(psfs, psf_inds_rolls, ...)

A class for spatially-varying PSF convolution with image cubes, including rotation and masking per roll angle.

class grater_jax.disk_model.winnie_class.WinniePSF(psfs, psf_inds_rolls, im_mask_rolls, psf_offsets, psf_parangs, num_unique_psfs)

Bases: object

A class for spatially-varying PSF convolution with image cubes, including rotation and masking per roll angle. The class is a JAX pytree implementation of Kellen Lawson’s Winnie package for PSF subtraction.

Parameters:
  • psfs (array_like) – Array of PSF kernels, shape (N, H, W), where N is the number of unique PSFs.

  • psf_inds_rolls (array_like) – Array of PSF index maps for each roll angle. Shape (n_rolls, H, W).

  • im_mask_rolls (array_like) – Array of image masks for each roll angle. Shape (n_rolls, H, W).

  • psf_offsets (array_like) – Array of PSF offsets per roll (currently unused).

  • psf_parangs (array_like) – Array of parallactic angles (in degrees) for each roll.

  • num_unique_psfs (int) – The number of unique PSF kernels used across the image.
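The parameter shapes above imply a few consistency rules between the arrays. The sketch below checks them before construction. check_winnie_inputs is a hypothetical helper, not part of grater_jax; it assumes that PSF indices are non-negative integers below num_unique_psfs and that psf_parangs holds exactly one angle per roll, and it does not require the kernel size to match the image size.

```python
import numpy as np

def check_winnie_inputs(psfs, psf_inds_rolls, im_mask_rolls, psf_parangs, num_unique_psfs):
    """Hypothetical helper: validate WinniePSF inputs and return the number of rolls."""
    psfs = np.asarray(psfs)
    psf_inds_rolls = np.asarray(psf_inds_rolls)
    im_mask_rolls = np.asarray(im_mask_rolls)
    psf_parangs = np.asarray(psf_parangs)

    # One kernel per unique PSF: shape (N, H, W).
    if psfs.ndim != 3 or psfs.shape[0] != num_unique_psfs:
        raise ValueError("psfs must have shape (num_unique_psfs, H, W)")
    # Index maps and masks share the same (n_rolls, H, W) grid.
    if psf_inds_rolls.ndim != 3 or im_mask_rolls.shape != psf_inds_rolls.shape:
        raise ValueError("psf_inds_rolls and im_mask_rolls must both be (n_rolls, H, W)")
    n_rolls = psf_inds_rolls.shape[0]
    # Assumption: one parallactic angle (degrees) per roll.
    if psf_parangs.shape != (n_rolls,):
        raise ValueError("psf_parangs must have one entry per roll")
    # Assumption: every index points at an existing kernel.
    if psf_inds_rolls.min() < 0 or psf_inds_rolls.max() >= num_unique_psfs:
        raise ValueError("psf_inds_rolls contains out-of-range PSF indices")
    return n_rolls

n_rolls = check_winnie_inputs(
    psfs=np.ones((3, 9, 9)),
    psf_inds_rolls=np.zeros((2, 64, 64), dtype=int),
    im_mask_rolls=np.ones((2, 64, 64)),
    psf_parangs=np.array([0.0, 10.0]),
    num_unique_psfs=3,
)
print(n_rolls)  # 2
```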

Attributes:
  • psfs (jax.numpy.ndarray)

  • psf_inds_rolls (jax.numpy.ndarray)

  • im_mask_rolls (jax.numpy.ndarray)

  • psf_offsets (jax.numpy.ndarray)

  • psf_parangs (jax.numpy.ndarray)

  • num_unique_psfs (int)

classmethod init_db(db, roll_index=0, normalize=True, log_rscale=True, nr=5, rmax=3.1622776601683795, ntheta=4, use_coeff=True)

Generate PSF-grid inputs from a spaceKLIP Database and package them in a form usable by the WinniePSF class.

Parameters:
  • db (spaceKLIP.database.Database) – Loaded spaceKLIP database object.

  • roll_index (int, optional) – Index of roll angle to use for instrument setup (default is 0).

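  • normalize (bool, optional) – Whether to normalize each PSF slice (default is True).

  • log_rscale (bool, optional) – Whether to sample PSFs logarithmically in radius (default is True).

  • nr (int, optional) – Number of radial zones (default is 5).

  • rmax (float, optional) – Maximum radius in arcsec (default is 3.1622776601683795, i.e. √10).

  • ntheta (int, optional) – Number of angular zones (default is 4).

  • use_coeff (bool, optional) – Whether to use precomputed PSF coefficients (default is True).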
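The radial and angular parameters of init_db describe a polar grid of PSF sample positions. The sketch below shows one plausible way nr, rmax, ntheta, and log_rscale could define such a grid. polar_sample_grid and its rmin lower bound are hypothetical, and the spacing actually used by init_db may differ.

```python
import numpy as np

def polar_sample_grid(nr=5, rmax=3.1622776601683795, ntheta=4, log_rscale=True, rmin=0.01):
    """Hypothetical polar PSF-sampling grid; rmin (arcsec) is an assumed lower bound."""
    if log_rscale:
        # Equal spacing in log10(r): each radius is a constant factor larger than the last.
        radii = np.logspace(np.log10(rmin), np.log10(rmax), nr)
    else:
        radii = np.linspace(rmin, rmax, nr)
    # Evenly spaced position angles in degrees, excluding the duplicate 360.
    thetas = np.linspace(0.0, 360.0, ntheta, endpoint=False)
    rr, tt = np.meshgrid(radii, thetas, indexing="ij")
    x = rr * np.cos(np.deg2rad(tt))  # offsets in arcsec, shape (nr, ntheta)
    y = rr * np.sin(np.deg2rad(tt))
    return radii, thetas, x, y

radii, thetas, x, y = polar_sample_grid()
print(radii)   # five radii, log-spaced from 0.01 to about 3.162 arcsec
print(thetas)  # 0, 90, 180 and 270 degrees
```

The default rmax equals 10**0.5, so on a base-10 logarithmic grid the outermost radius falls half a decade above 1 arcsec.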

get_convolved_cube(image)

Apply spatially-varying convolution to an input image across roll angles.

Parameters:

image (jax.numpy.ndarray) – 2D input image to convolve.

Returns:

Convolved image cube of shape (n_rolls, H, W).

Return type:

jax.numpy.ndarray

convolve_cube(image, cent)

Rotate and convolve the input image for each roll angle.

Parameters:
  • image (jax.numpy.ndarray) – Input 2D image.

  • cent (tuple of float) – Center of rotation (x, y) in pixels.

Returns:

Image cube after convolution, shape (n_rolls, H, W).

Return type:

jax.numpy.ndarray

rotate_hypercube(hcube, angles, cent, cval0=0.0)

Rotate each image slice in a cube by a corresponding angle.

Parameters:
  • hcube (jax.numpy.ndarray) – Input hypercube of shape (n_slices, H, W).

  • angles (jax.numpy.ndarray) – Array of angles (degrees) for each slice.

  • cent (tuple of float) – Center of rotation (x, y) in pixels.

  • cval0 (float, optional) – Fill value for output pixels that map outside the input slice (default is 0.0).

Returns:

Rotated hypercube with the same shape as input.

Return type:

jax.numpy.ndarray
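Because rotate_hypercube applies a different angle to each slice, one natural way to express it in JAX is to write a single-slice rotation and vectorize it with jax.vmap over the slice axis. The sketch below assumes linear interpolation (order=1) via jax.scipy.ndimage.map_coordinates and a particular rotation sign convention; _rotate_slice is a hypothetical helper, and this is not the grater_jax source.

```python
import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

def _rotate_slice(image, angle_deg, cent, cval):
    """Hypothetical helper: rotate one 2D slice about cent = (x, y) pixels."""
    yy, xx = jnp.indices(image.shape)
    theta = jnp.deg2rad(angle_deg)
    dx, dy = xx - cent[0], yy - cent[1]
    # Inverse mapping: each output pixel samples the input at the position
    # rotated by -theta, so no output pixel is left unfilled.
    x_in = jnp.cos(theta) * dx + jnp.sin(theta) * dy + cent[0]
    y_in = -jnp.sin(theta) * dx + jnp.cos(theta) * dy + cent[1]
    return map_coordinates(image, [y_in, x_in], order=1, mode="constant", cval=cval)

def rotate_hypercube(hcube, angles, cent, cval0=0.0):
    """Sketch: rotate slice i of hcube (n_slices, H, W) by angles[i] degrees."""
    # vmap maps over axis 0 of both hcube and angles; cent and cval0 are closed over.
    return jax.vmap(lambda im, ang: _rotate_slice(im, ang, cent, cval0))(hcube, angles)

img = jnp.arange(25.0).reshape(5, 5)
cube = rotate_hypercube(jnp.stack([img, img]), jnp.array([0.0, 90.0]), cent=(2.0, 2.0))
print(cube.shape)  # (2, 5, 5)
```

With this convention, +90 degrees reproduces jnp.rot90(img, k=-1); the sign convention of the real method should be checked before relying on it.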

convolve_with_spatial_psfs(image, psf_inds, im_mask)

Convolve an image using spatially-varying PSFs based on a mask and index map.

Parameters:
  • image (jax.numpy.ndarray) – The image to convolve.

  • psf_inds (jax.numpy.ndarray) – Index map indicating which PSF to use at each pixel.

  • im_mask (jax.numpy.ndarray) – Binary mask to apply before convolution.

Returns:

Convolved image.

Return type:

jax.numpy.ndarray
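A common way to realize spatially-varying convolution is region-wise convolution: zero every pixel not assigned to PSF k, convolve what remains with kernel k, and sum over k. The sketch below assumes that approach and uses FFT convolution via jax.scipy.signal.fftconvolve. It takes the kernel stack psfs as an explicit argument, whereas the method presumably reads it from the instance; it is an illustration, not the grater_jax implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import fftconvolve

def convolve_with_spatial_psfs(image, psf_inds, im_mask, psfs):
    """Sketch: region-wise convolution with a stack of kernels psfs (N, h, w)."""
    masked = image * im_mask  # apply the mask before convolving

    def convolve_region(k, psf):
        # Keep only the pixels assigned to kernel k, then blur them with that kernel.
        region = jnp.where(psf_inds == k, masked, 0.0)
        return fftconvolve(region, psf, mode="same")

    # One convolution per kernel, evaluated in parallel and summed.
    ks = jnp.arange(psfs.shape[0])
    return jax.vmap(convolve_region)(ks, psfs).sum(axis=0)

image = jnp.arange(36.0).reshape(6, 6)
psf_inds = jnp.zeros((6, 6), dtype=jnp.int32).at[:, 3:].set(1)  # left half: PSF 0, right half: PSF 1
im_mask = jnp.ones((6, 6)).at[0, 0].set(0.0)
psfs = jnp.zeros((2, 3, 3)).at[0, 1, 1].set(1.0).at[1, 1, 1].set(2.0)  # delta and 2x delta
out = convolve_with_spatial_psfs(image, psf_inds, im_mask, psfs)
print(out.shape)  # (6, 6)
```

Using vmap keeps all N full-size convolutions in memory at once; with many PSFs, accumulating the sum in jax.lax.fori_loop or jax.lax.scan would trade some speed for lower memory use.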

tree_flatten()

Flatten the object for use with JAX PyTrees.

Returns:

A tuple (children, aux_data), where children holds the JAX-traceable arrays and aux_data holds static, hashable metadata.

Return type:

tuple

classmethod tree_unflatten(aux_data, children)

Unflatten the object from children and auxiliary data.

Parameters:
  • aux_data (tuple) – Static metadata (e.g., the number of unique PSFs).

  • children (tuple) – Flattened arrays representing object state.

Returns:

Reconstructed object.

Return type:

WinniePSF
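tree_flatten and tree_unflatten together implement JAX's custom pytree protocol: the arrays returned as children are traced by transformations such as jax.jit and jax.vmap, while aux_data is treated as static. The class below is a hypothetical, stripped-down stand-in (MiniPSF, with only two array fields) that illustrates the pattern jax.tree_util.register_pytree_node_class expects; it is not the WinniePSF source.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class MiniPSF:
    """Hypothetical two-array stand-in for WinniePSF, showing the pytree protocol."""

    def __init__(self, psfs, psf_parangs, num_unique_psfs):
        # No validation or conversion here: JAX may call __init__ (via
        # tree_unflatten) with tracers or placeholder objects as children.
        self.psfs = psfs
        self.psf_parangs = psf_parangs
        self.num_unique_psfs = num_unique_psfs

    def tree_flatten(self):
        children = (self.psfs, self.psf_parangs)  # traced by jit/vmap/grad
        aux_data = (self.num_unique_psfs,)         # static and hashable
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        psfs, psf_parangs = children
        (num_unique_psfs,) = aux_data
        return cls(psfs, psf_parangs, num_unique_psfs)

@jax.jit
def total_flux(model):
    # model crosses the jit boundary as a pytree; num_unique_psfs is a plain int here.
    return model.psfs.sum() * model.num_unique_psfs

m = MiniPSF(jnp.ones((3, 4, 4)), jnp.array([0.0, 10.0]), 3)
print(float(total_flux(m)))  # 144.0
```

Because num_unique_psfs lives in aux_data, it stays a Python int inside jitted code and can set shapes or loop bounds; changing it triggers recompilation, whereas changing array values does not.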

grater_jax.disk_model.winnie_class.jax_rotate_image(image, angle, reshape=False, cval=0.0, order=1)

Rotate a 2D image using JAX-compatible coordinate transformations.

Parameters:
  • image (jax.numpy.ndarray) – Input 2D image.

  • angle (float) – Rotation angle in degrees.

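  • reshape (bool, optional) – Whether to reshape the output to fit the entire rotated image. Currently unused; the output always has the input's shape (default is False).

  • cval (float, optional) – Fill value for output pixels that map outside the input image (default is 0.0).

  • order (int, optional) – Interpolation order: 0 (nearest) or 1 (linear) (default is 1). jax.scipy.ndimage.map_coordinates supports only these two orders.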

Returns:

Rotated image of the same shape.

Return type:

jax.numpy.ndarray
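Rotation by coordinate transformation is usually done by inverse mapping: for every output pixel, compute where it came from in the input and interpolate there. The sketch below works under stated assumptions: it rotates about the array center ((nx - 1)/2, (ny - 1)/2), has no reshape option, and samples with jax.scipy.ndimage.map_coordinates in mode="constant". rotate_image_sketch is a hypothetical name, not the module's jax_rotate_image.

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

def rotate_image_sketch(image, angle, cval=0.0, order=1):
    """Hypothetical: rotate a 2D image by `angle` degrees about its array center."""
    ny, nx = image.shape
    cy, cx = (ny - 1) / 2.0, (nx - 1) / 2.0
    yy, xx = jnp.indices(image.shape)
    theta = jnp.deg2rad(angle)
    c, s = jnp.cos(theta), jnp.sin(theta)
    dx, dy = xx - cx, yy - cy
    # Inverse mapping: output pixel (x, y) takes its value from the input
    # position obtained by rotating (x, y) through -angle about the center.
    x_in = c * dx + s * dy + cx
    y_in = -s * dx + c * dy + cy
    # order=0 is nearest-neighbour, order=1 is linear; pixels sampled from
    # outside the input receive cval.
    return map_coordinates(image, [y_in, x_in], order=order, mode="constant", cval=cval)

img = jnp.arange(25.0).reshape(5, 5)
print(jnp.allclose(rotate_image_sketch(img, 90.0, order=0), jnp.rot90(img, k=-1)))  # True
```

order must be a Python int because map_coordinates branches on it at trace time; when jitting a wrapper like this, mark it static, for example with jax.jit(rotate_image_sketch, static_argnames=("order",)).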