grater_jax.disk_model.SLD_utils

SLD_utils.py

Utility classes for disk modeling.

This module defines base JAX classes and implementations of density distributions, scattering phase functions, and point spread functions (PSFs) used in scattered-light disk forward modeling. It includes:

  • Jax_class : Base class for packing/unpacking parameter dictionaries.

  • DustEllipticalDistribution2PowerLaws : Two-power-law dust density model.

  • HenyeyGreenstein_SPF, DoubleHenyeyGreenstein_SPF : Scattering phase functions.

  • InterpolatedUnivariateSpline_SPF : Spline-based scattering phase function.

  • GAUSSIAN_PSF, EMP_PSF, Winnie_PSF : PSF models.

  • LinearStellarPSF, PositionalStellarPSF : Stellar PSF models using reference images.

The module can be extended with new distribution functions, scattering phase functions, and point spread functions.

Functions

recommended_num_knots(inclination_deg, ...)

Return a recommended knot count scaled to the scattering angles probed at a given inclination.

vmap(fun[, in_axes, out_axes, axis_name, ...])

Vectorizing map.

Classes

DoubleHenyeyGreenstein_SPF()

Implementation of a scattering phase function with a double Henyey Greenstein function.

DustEllipticalDistribution2PowerLaws()

Two-power-law elliptical dust density distribution model.

EMP_PSF()

Empirical point spread function (PSF) model.

GAUSSIAN_PSF()

Gaussian PSF model.

HenyeyGreenstein_SPF()

Implementation of a scattering phase function with a single Henyey Greenstein function.

InterpolatedUnivariateSpline(x, y[, k, ...])

InterpolatedUnivariateSpline_SPF()

Implementation of a spline scattering phase function.

Jax_class()

Base class for custom JAX-compatible objects that can be compressed into and uncompressed from JAX arrays.

LinearStellarPSF()

Stellar PSF model as a linear combination of reference images.

PositionalStellarPSF()

Stellar PSF model with position-dependent reference image weights.

StellarPSFReference()

Reference images that the Stellar PSF classes will use.

WinniePSF(psfs, psf_inds_rolls, ...)

A class for spatially-varying PSF convolution with image cubes, including rotation and masking per roll angle.

Winnie_PSF()

Creates a JWST PSF model, using the package Winnie.

partial

partial(func, *args, **keywords) - new function with partial application of the given arguments and keywords.

class grater_jax.disk_model.SLD_utils.Jax_class

Bases: object

Base class for custom JAX-compatible objects that can be compressed into and uncompressed from JAX arrays.

params = {}
classmethod unpack_pars(p_arr)

Unpack a parameter array into a dictionary keyed by parameter names.

Parameters:

p_arr (jax.numpy.ndarray or array-like) – 1D array of parameter values. The order must match the keys in cls.params.

Returns:

Dictionary mapping parameter names (str) to their corresponding values.

Return type:

dict

classmethod pack_pars(p_dict)

Pack a parameter dictionary into a JAX array.

The order of parameters in the array is determined by the key order in cls.params.

Parameters:

p_dict (dict) – Dictionary mapping parameter names (str) to values.

Returns:

1D array of parameter values in the order specified by cls.params.

Return type:

jax.numpy.ndarray
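
The pack/unpack convention described above can be illustrated with a small stand-in class. This is a minimal sketch, not the library implementation: ParamsSketch and its params dictionary are hypothetical, and NumPy is used in place of jax.numpy so the example runs without JAX.

```python
import numpy as np

class ParamsSketch:
    """Hypothetical stand-in that mirrors the pack/unpack pattern."""
    params = {"g1": 0.5, "g2": -0.3, "weight": 0.7}

    @classmethod
    def pack_pars(cls, p_dict):
        # Values follow the key order of cls.params, not the order of p_dict.
        return np.array([p_dict[name] for name in cls.params], dtype=float)

    @classmethod
    def unpack_pars(cls, p_arr):
        # Pair each key of cls.params with the value at the same index.
        return {name: float(value) for name, value in zip(cls.params, p_arr)}

# The input dictionary is deliberately given in a different key order.
p_arr = ParamsSketch.pack_pars({"weight": 0.6, "g1": 0.4, "g2": -0.2})
p_dict = ParamsSketch.unpack_pars(p_arr)
```

Because both directions rely on the key order of params, a packed array is only meaningful for the class that produced it.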

class grater_jax.disk_model.SLD_utils.DustEllipticalDistribution2PowerLaws

Bases: Jax_class

Two-power-law elliptical dust density distribution model.

params = {'accuracy': 0.005, 'alpha_in': 5.0, 'alpha_out': -5.0, 'beta': 1.0, 'dens_at_r0': 1.0, 'e': 0.0, 'gamma': 2.0, 'itiltthreshold': 0.0, 'ksi0': 1.0, 'p': 0.0, 'pmin': 0.0, 'rmax': 0.0, 'rmin': 0.0, 'rpeak': 0.0, 'rpeak_surface_density': 0.0, 'sma': 60.0, 'zmax': 0.0}
classmethod init(accuracy=0.005, alpha_in=5.0, alpha_out=-5.0, sma=60.0, e=0.0, ksi0=1.0, gamma=2.0, beta=1.0, rmin=0.0, dens_at_r0=1.0)

Constructor for the DustEllipticalDistribution2PowerLaws class.

The dust density is treated as zero wherever it drops below a fraction accuracy (0.5% by default) of the peak midplane density, both radially and vertically.

Based on code from the VIP disk forward modeling by Julien Milli.

Parameters:
  • accuracy (float) – Density limit as described above. Default is 5.e-3.

  • alpha_in (float) – slope of the power-law distribution in the inner disk. It must be positive (default 5)

  • alpha_out (float) – slope of the power-law distribution in the outer disk. It must be negative (default -5)

  • sma (float) – reference radius in au (default 60)

  • e (float) – eccentricity (default 0)

  • ksi0 (float) – scale height in au at the reference radius (default 1 au)

  • gamma (float) – exponent of the vertical profile (2 = Gaussian, 1 = exponential; default 2)

  • beta (float) – flaring index (0=no flaring, 1=linear flaring, default 1)

  • rmin (float) – minimum semi-major axis: the dust density is 0 below this value (default 0)

Other kwargs / params are generated from the above parameters.

classmethod density_cylindrical(distr_params, r, costheta, z)

Returns the particle volume density at (r, theta, z).
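
To make the roles of alpha_in, alpha_out, ksi0, gamma, and beta concrete, here is a minimal sketch of a two-power-law density in the form used by VIP's scattered-light disk model. It is an assumption that this class evaluates the same expression. The function name and keyword defaults are illustrative, and NumPy stands in for jax.numpy.

```python
import numpy as np

def two_power_law_density(r, costheta, z, sma=60.0, e=0.0, alpha_in=5.0,
                          alpha_out=-5.0, ksi0=1.0, gamma=2.0, beta=1.0,
                          dens_at_r0=1.0):
    """Hypothetical sketch of a VIP-style two-power-law dust density."""
    # Reference radius along this direction; it equals sma for a circular disk.
    p = sma * (1.0 - e**2)
    ratio = r / (p / (1.0 - e * costheta))
    # Smooth junction of the rising inner and falling outer power laws.
    radial = dens_at_r0 * np.sqrt(
        2.0 / (ratio ** (-2.0 * alpha_in) + ratio ** (-2.0 * alpha_out))
    )
    # The scale height grows as ratio**beta; gamma sets the vertical shape.
    height = ksi0 * ratio**beta
    vertical = np.exp(-(np.abs(z) / height) ** gamma)
    return radial * vertical

peak = two_power_law_density(60.0, 1.0, 0.0)   # midplane, reference radius
inner = two_power_law_density(30.0, 1.0, 0.0)  # inside the reference radius
above = two_power_law_density(60.0, 1.0, 1.0)  # one scale height above
```

With these defaults the midplane density at the reference radius equals dens_at_r0, and one scale height above the midplane it falls by a factor of exp(-1) for gamma = 2.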

classmethod pack_pars(p_dict)

Pack a parameter dictionary into a JAX array.

The order of parameters in the array is determined by the key order in cls.params.

Parameters:

p_dict (dict) – Dictionary mapping parameter names (str) to values.

Returns:

1D array of parameter values in the order specified by cls.params.

Return type:

jax.numpy.ndarray

classmethod unpack_pars(p_arr)

Unpack a parameter array into a dictionary keyed by parameter names.

Parameters:

p_arr (jax.numpy.ndarray or array-like) – 1D array of parameter values. The order must match the keys in cls.params.

Returns:

Dictionary mapping parameter names (str) to their corresponding values.

Return type:

dict

class grater_jax.disk_model.SLD_utils.HenyeyGreenstein_SPF

Bases: Jax_class

Implementation of a scattering phase function with a single Henyey Greenstein function.

params = {'g': 0.3}
classmethod init(func_params)

Constructor of a Henyey Greenstein phase function.

Parameters:

func_params (dictionary containing the key "g" (float)) – g is the Henyey Greenstein coefficient and should be between -1 (backward scattering) and 1 (forward scattering).

classmethod compute_phase_function_from_cosphi(phase_func_params, cos_phi)

Compute the phase function at one or more scattering angles phi. The argument is cos(phi) rather than phi, for optimization reasons.

Parameters:

cos_phi (float or array) – cosine of the scattering angle(s) at which the scattering function must be calculated.
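
For reference, the standard Henyey Greenstein function can be written directly in terms of cos(phi). The sketch below assumes the common 1/(4π) normalization, which may differ from the constant used by this class. The function name is illustrative, and NumPy stands in for jax.numpy.

```python
import numpy as np

def henyey_greenstein(g, cos_phi):
    """Standard Henyey Greenstein phase function, normalized over the sphere."""
    cos_phi = np.asarray(cos_phi, dtype=float)
    return (1.0 - g**2) / (4.0 * np.pi * (1.0 + g**2 - 2.0 * g * cos_phi) ** 1.5)

cos_phi = np.array([-1.0, 0.0, 1.0])         # back, side, forward scattering
isotropic = henyey_greenstein(0.0, cos_phi)  # g = 0 gives isotropic scattering
forward = henyey_greenstein(0.3, cos_phi)    # g > 0 favours forward scattering
```

Working in cos(phi) avoids a call to arccos followed by cos for every pixel of the model image.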

classmethod pack_pars(p_dict)

Pack a parameter dictionary into a JAX array.

The order of parameters in the array is determined by the key order in cls.params.

Parameters:

p_dict (dict) – Dictionary mapping parameter names (str) to values.

Returns:

1D array of parameter values in the order specified by cls.params.

Return type:

jax.numpy.ndarray

classmethod unpack_pars(p_arr)

Unpack a parameter array into a dictionary keyed by parameter names.

Parameters:

p_arr (jax.numpy.ndarray or array-like) – 1D array of parameter values. The order must match the keys in cls.params.

Returns:

Dictionary mapping parameter names (str) to their corresponding values.

Return type:

dict

class grater_jax.disk_model.SLD_utils.DoubleHenyeyGreenstein_SPF

Bases: Jax_class

Implementation of a scattering phase function with a double Henyey Greenstein function.

Parameters:
  • g1 (float) – the first Henyey Greenstein coefficient; should be between -1 (backward scattering) and 1 (forward scattering)

  • g2 (float) – the second Henyey Greenstein coefficient; should be between -1 (backward scattering) and 1 (forward scattering)

  • weight (float) – weighting of the first Henyey Greenstein component

params = {'g1': 0.5, 'g2': -0.3, 'weight': 0.7}
classmethod init(func_params)
classmethod compute_phase_function_from_cosphi(phase_func_params, cos_phi)

Compute the phase function at one or more scattering angles phi. The argument is cos(phi) rather than phi, for optimization reasons.

Parameters:

cos_phi (float or array) – cosine of the scattering angle(s) at which the scattering function must be calculated.
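
A two-component phase function of this kind is usually a weighted sum of two Henyey Greenstein terms. The sketch below assumes the second component receives weight (1 - weight) and uses the 1/(4π) normalization. Both are assumptions about this class, the function names are illustrative, and NumPy stands in for jax.numpy.

```python
import numpy as np

def hg(g, cos_phi):
    # Single Henyey Greenstein term with the usual 1/(4*pi) normalization.
    return (1.0 - g**2) / (4.0 * np.pi * (1.0 + g**2 - 2.0 * g * cos_phi) ** 1.5)

def double_hg(cos_phi, g1=0.5, g2=-0.3, weight=0.7):
    """Hypothetical sketch: weighted sum of two Henyey Greenstein terms."""
    cos_phi = np.asarray(cos_phi, dtype=float)
    # Assumption: the second component is weighted by (1 - weight).
    return weight * hg(g1, cos_phi) + (1.0 - weight) * hg(g2, cos_phi)

cos_phi = np.linspace(-1.0, 1.0, 5)
mixed = double_hg(cos_phi)                # default parameters
single = double_hg(cos_phi, weight=1.0)   # reduces to one forward component
```

A negative g2 adds a backscattering lobe, so the mixed function exceeds the single forward component at cos(phi) = -1.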

classmethod pack_pars(p_dict)

Pack a parameter dictionary into a JAX array.

The order of parameters in the array is determined by the key order in cls.params.

Parameters:

p_dict (dict) – Dictionary mapping parameter names (str) to values.

Returns:

1D array of parameter values in the order specified by cls.params.

Return type:

jax.numpy.ndarray

classmethod unpack_pars(p_arr)

Unpack a parameter array into a dictionary keyed by parameter names.

Parameters:

p_arr (jax.numpy.ndarray or array-like) – 1D array of parameter values. The order must match the keys in cls.params.

Returns:

Dictionary mapping parameter names (str) to their corresponding values.

Return type:

dict

grater_jax.disk_model.SLD_utils.recommended_num_knots(inclination_deg, num_knots_full_range, boundary_buffer=0.1)

Return a recommended knot count scaled to the scattering angles probed at a given inclination.

A disk at inclination i samples scattering angles roughly in the range [90 - i, 90 + i] degrees, which corresponds to cos(phi) in [-sin(i), +sin(i)]. Placing num_knots_full_range knots over the full 180-degree range is appropriate when the full range is probed (high inclinations), but over-parameterises the spline for low inclinations where only a narrow window is accessible. Over-parameterisation makes the optimiser ill-conditioned and can cause convergence failures.

This function scales the knot count linearly with the angular window that is actually probed, with a small outward buffer so knots are not placed right at the geometric boundary.

The minimum returned value is 4. While a cubic spline technically requires only 4 control points (num_knots = 3), empirical testing shows that num_knots = 3 produces a spline that is too rigid: the optimiser consistently converges to image-consistent but SPF-inconsistent local minima at low inclinations. num_knots = 4 (5 control points) provides enough flexibility to recover the SPF shape reliably.

Using this function is optional. Users can always pass num_knots directly to InterpolatedUnivariateSpline_SPF.params; this helper just provides a principled starting point when the inclination is known.

Parameters:
  • inclination_deg (float) – Disk inclination in degrees (0 = face-on, 90 = edge-on).

  • num_knots_full_range (int) – Desired knot count for a disk that probes the full 0–180 degree range. Typically 5–8; the default InterpolatedUnivariateSpline_SPF uses 6.

  • boundary_buffer (float, optional) – Buffer added beyond the geometric probed range in cos_phi units before computing the angular window. Gives the spline a little room at the edges so knots are not placed exactly on the boundary. Default 0.1.

Returns:

Recommended number of free knots, in the range [4, num_knots_full_range].

Return type:

int

Examples

>>> recommended_num_knots(30.0, 6)
4
>>> recommended_num_knots(50.0, 6)
4
>>> recommended_num_knots(80.0, 6)
6
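
The scaling described above can be sketched as follows. This is one plausible reading of the description, not the library's code: the function name sketch_num_knots, the rounding rule, and the exact definition of the angular window are assumptions.

```python
import math

def sketch_num_knots(inclination_deg, num_knots_full_range,
                     boundary_buffer=0.1, minimum=4):
    """Hypothetical sketch: scale a knot count by the probed angular window."""
    # Largest |cos(phi)| reached, padded by the buffer and clipped to 1.
    c = min(math.sin(math.radians(inclination_deg)) + boundary_buffer, 1.0)
    # Angular width of cos(phi) in [-c, +c] as a fraction of 180 degrees.
    fraction = (math.pi - 2.0 * math.acos(c)) / math.pi
    # Assumption: round to the nearest integer, then clamp to the allowed range.
    scaled = round(num_knots_full_range * fraction)
    return max(minimum, min(num_knots_full_range, scaled))

low = sketch_num_knots(30.0, 6)
mid = sketch_num_knots(50.0, 6)
high = sketch_num_knots(80.0, 6)
```

Under these assumptions the sketch returns 4, 4, and 6 for inclinations of 30, 50, and 80 degrees, the same values as the documented examples.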
class grater_jax.disk_model.SLD_utils.InterpolatedUnivariateSpline_SPF

Bases: Jax_class

Implementation of a spline scattering phase function. Uses 6 knots by default and takes the knot y values as parameters. Knot locations are fixed to the given knots; pack_pars and init both return the spline model itself.

Parameters:
  • backscatt_bound (float) – cosine of the backscattering bound (the scattering angle closest to 180 deg) used for the spline

  • forwardscatt_bound (float) – cosine of the forward-scattering bound (the scattering angle closest to 0 deg) used for the spline

  • num_knots (int) – number of knots

  • knot_values (array) – y values of the knots

params = {'backscatt_bound': -1, 'forwardscatt_bound': 1, 'knot_values': (1.0, 1.0, 1.0, 1.0, 1.0, 1.0), 'num_knots': 6}
classmethod init(p_arr, knots)
classmethod get_knots(p_dict)

Return knot x-positions (in cos_phi space) for the spline SPF.

Returns num_knots + 1 positions. A fixed knot is always placed at cos_phi = 0 (90 degrees), which is the normalization point. The remaining num_knots positions are split evenly on either side.
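
One way to lay out such knots is sketched below. The function name and the uniform spacing within each side are assumptions; only the total count and the fixed knot at cos_phi = 0 come from the description above. NumPy stands in for jax.numpy.

```python
import numpy as np

def sketch_knot_positions(num_knots=6, backscatt_bound=-1.0,
                          forwardscatt_bound=1.0):
    """Hypothetical sketch: num_knots + 1 positions with a fixed knot at 0."""
    n_left = num_knots // 2
    n_right = num_knots - n_left
    # Uniform spacing on each side, excluding the shared point at 0.
    left = np.linspace(backscatt_bound, 0.0, n_left + 1)[:-1]
    right = np.linspace(0.0, forwardscatt_bound, n_right + 1)[1:]
    return np.concatenate([left, [0.0], right])

knots = sketch_knot_positions()
```

With the default bounds of -1 and 1 and six free knots, this layout reduces to seven evenly spaced positions.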

classmethod pack_pars(p_arr, knots)

Build an InterpolatedUnivariateSpline from the free knot values.

p_arr contains num_knots free values — the knot y-values at all positions except the fixed normalization point at cos_phi = 0. A value of 1.0 is inserted at the center index (n_left = len(p_arr) // 2) before constructing the spline, so spline(0) = 1 by construction. This eliminates the degeneracy between knot scale and flux_scaling.
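
The insertion of the fixed normalization value can be sketched in a few lines. The helper name is hypothetical and the example stops before the spline is built. NumPy stands in for jax.numpy.

```python
import numpy as np

def insert_normalization_value(p_arr):
    """Hypothetical sketch: add the fixed value 1.0 at the center index."""
    p_arr = np.asarray(p_arr, dtype=float)
    n_left = len(p_arr) // 2
    # The inserted entry corresponds to the knot at cos_phi = 0.
    return np.insert(p_arr, n_left, 1.0)

# Six free knot values, ordered from backscattering to forward scattering.
free_values = [0.7, 0.8, 0.9, 1.2, 1.5, 2.0]
all_values = insert_normalization_value(free_values)
```

The resulting array has num_knots + 1 entries, matching the number of knot positions.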

classmethod compute_phase_function_from_cosphi(spline_model, cos_phi)

Compute the phase function at one or more scattering angles phi. The argument is cos(phi) rather than phi, for optimization reasons.

The spline is normalized so that it equals 1 at 90 degrees (cos_phi=0), breaking the degeneracy between the SPF knot values and the absolute flux scaling parameter.

When the spline knots do not cover the full [-1, 1] cos_phi range (i.e., when inclination-dependent knot placement is used), values outside the knot range are extrapolated rather than evaluated with the boundary polynomial segment, which diverges for cubic splines:

  • Forward side (cos_phi > knot max): linear extrapolation using the spline’s first derivative at the forward boundary. Physically motivated — SPFs typically rise toward forward scattering.

  • Backward side (cos_phi < knot min): constant extrapolation at the backward boundary value. SPFs are typically flat at large angles.

When knots cover the full [-1, 1] range these branches are never triggered, so this is fully backward-compatible.

Parameters:
  • spline_model (InterpolatedUnivariateSpline) – spline model representing the scattered-light phase function

  • cos_phi (float or array) – cosine of the scattering angle(s) at which the scattering function must be calculated.
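
The two extrapolation rules can be sketched with a generic interior function standing in for the spline. Everything named here is hypothetical: evaluate_with_extrapolation, the quadratic stand-in f, and its derivative dfdx are illustrations, not the library's spline object. NumPy stands in for jax.numpy.

```python
import numpy as np

def evaluate_with_extrapolation(f, dfdx, x_min, x_max, cos_phi):
    """Hypothetical sketch: constant backward, linear forward extrapolation."""
    cos_phi = np.asarray(cos_phi, dtype=float)
    # Clipping returns the boundary value outside the knot range, which
    # already gives the constant extrapolation on the backward side.
    inside = f(np.clip(cos_phi, x_min, x_max))
    # Forward side: continue along the tangent at the forward boundary.
    forward = f(x_max) + dfdx(x_max) * (cos_phi - x_max)
    return np.where(cos_phi > x_max, forward, inside)

def f(x):
    # Stand-in for the spline, equal to 1 at cos_phi = 0.
    return 1.0 + 0.5 * x + 0.25 * x**2

def dfdx(x):
    return 0.5 + 0.5 * x

values = evaluate_with_extrapolation(f, dfdx, -0.6, 0.6, [-1.0, 0.0, 1.0])
```

Using where with both branches evaluated keeps the function free of Python control flow on array values, the pattern that JAX tracing generally requires.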

classmethod unpack_pars(p_arr)

Unpack a parameter array into a dictionary keyed by parameter names.

Parameters:

p_arr (jax.numpy.ndarray or array-like) – 1D array of parameter values. The order must match the keys in cls.params.

Returns:

Dictionary mapping parameter names (str) to their corresponding values.

Return type:

dict

class grater_jax.disk_model.SLD_utils.GAUSSIAN_PSF

Bases: Jax_class

Gaussian PSF model. The PSF is defined by the following parameters:

Parameters:
  • FWHM (float) – Full width at half maximum of the Gaussian PSF.

  • xo (float) – X coordinate of the center of the PSF.

  • yo (float) – Y coordinate of the center of the PSF.

  • theta (float) – Rotation angle of the PSF in radians.

  • offset (float) – Offset value to be added to the PSF.

  • amplitude (float) – Amplitude of the PSF.

params = {'FWHM': 3.0, 'amplitude': 1.0, 'offset': 0.0, 'theta': 0.0, 'xo': 0.0, 'yo': 0.0}
classmethod generate(image, psf_params)

Apply a Gaussian PSF to an image via FFT-based convolution.

Parameters:
  • image (jax.numpy.ndarray) – 2D input image to be smoothed.

  • psf_params (jax.numpy.ndarray) – Packed PSF parameter array as returned by pack_pars.

Returns:

Smoothed image with the Gaussian PSF applied and offset added.

Return type:

jax.numpy.ndarray
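
FFT-based Gaussian smoothing can be sketched as below. This is a simplified illustration under stated assumptions: it uses a circularly symmetric, unit-sum kernel and ignores xo, yo, theta, and amplitude, and it does not reproduce the library's implementation. The function name is hypothetical, and NumPy stands in for jax.numpy.

```python
import numpy as np

def gaussian_fft_smooth(image, fwhm=3.0, offset=0.0):
    """Hypothetical sketch: circular FFT convolution with a Gaussian kernel."""
    ny, nx = image.shape
    # Convert the full width at half maximum to a standard deviation.
    sigma = fwhm / (2.0 * np.sqrt(2.0 * np.log(2.0)))
    y, x = np.mgrid[:ny, :nx]
    kernel = np.exp(-((x - nx // 2) ** 2 + (y - ny // 2) ** 2) / (2.0 * sigma**2))
    kernel /= kernel.sum()  # unit sum, so total flux is preserved
    # Move the kernel center to index (0, 0) so the result is not shifted.
    kernel_fft = np.fft.fft2(np.fft.ifftshift(kernel))
    # Multiplication in Fourier space is a periodic (wrap-around) convolution.
    smoothed = np.real(np.fft.ifft2(np.fft.fft2(image) * kernel_fft))
    return smoothed + offset

image = np.zeros((32, 32))
image[10, 12] = 1.0  # a single point source
smoothed = gaussian_fft_smooth(image, fwhm=3.0)
peak_index = np.unravel_index(np.argmax(smoothed), smoothed.shape)
```

Because the convolution is periodic, flux near one edge wraps to the opposite edge; padding the image first avoids this when sources lie close to the border.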

classmethod pack_pars(p_dict)

Pack a parameter dictionary into a JAX array.

The order of parameters in the array is determined by the key order in cls.params.

Parameters:

p_dict (dict) – Dictionary mapping parameter names (str) to values.

Returns:

1D array of parameter values in the order specified by cls.params.

Return type:

jax.numpy.ndarray

classmethod unpack_pars(p_arr)

Unpack a parameter array into a dictionary keyed by parameter names.

Parameters:

p_arr (jax.numpy.ndarray or array-like) – 1D array of parameter values. The order must match the keys in cls.params.

Returns:

Dictionary mapping parameter names (str) to their corresponding values.

Return type:

dict

class grater_jax.disk_model.SLD_utils.EMP_PSF

Bases: Jax_class

Empirical point spread function (PSF) model.

params = {'offset': 1.0, 'scale_factor': 1.0}
img = None
classmethod generate(image, psf_params)

Convolve the input image with the empirical PSF via FFT.

Parameters:
  • image (jax.numpy.ndarray) – 2D input image to convolve.

  • psf_params (jax.numpy.ndarray) – Packed PSF parameter array (unused; convolution uses cls.img).

Returns:

Image convolved with the stored empirical PSF.

Return type:

jax.numpy.ndarray

classmethod pack_pars(p_dict)

Pack a parameter dictionary into a JAX array.

The order of parameters in the array is determined by the key order in cls.params.

Parameters:

p_dict (dict) – Dictionary mapping parameter names (str) to values.

Returns:

1D array of parameter values in the order specified by cls.params.

Return type:

jax.numpy.ndarray

classmethod unpack_pars(p_arr)

Unpack a parameter array into a dictionary keyed by parameter names.

Parameters:

p_arr (jax.numpy.ndarray or array-like) – 1D array of parameter values. The order must match the keys in cls.params.

Returns:

Dictionary mapping parameter names (str) to their corresponding values.

Return type:

dict

class grater_jax.disk_model.SLD_utils.Winnie_PSF

Bases: Jax_class

Creates a JWST PSF model, using the package Winnie. See Winnie for further JWST PSF documentation.

classmethod init(psfs, psf_inds_rolls, im_mask_rolls, psf_offsets, psf_parangs, num_unique_psfs)

Initialise and return a WinniePSF object from raw PSF grid arrays.

Parameters:
  • psfs (array-like) – PSF grid images.

  • psf_inds_rolls (array-like) – Roll indices for each PSF.

  • im_mask_rolls (array-like) – Image mask rolls.

  • psf_offsets (array-like) – Positional offsets for each PSF.

  • psf_parangs (array-like) – Position angles for each PSF.

  • num_unique_psfs (int) – Number of unique PSFs in the grid.

Returns:

Initialised Winnie PSF model object.

Return type:

WinniePSF

classmethod pack_pars(winnie_model)

Return the WinniePSF model unchanged (identity packing).

Parameters:

winnie_model (WinniePSF) – Winnie PSF model object.

Returns:

The same model object, unchanged.

Return type:

WinniePSF

classmethod generate(image, winnie_model)

Convolve an image with the Winnie PSF and return the mean over spacecraft rolls.

Parameters:
  • image (jax.numpy.ndarray) – 2D input image to convolve.

  • winnie_model (WinniePSF) – Packed Winnie PSF model as returned by pack_pars.

Returns:

Mean of the roll-convolved image cube.

Return type:

jax.numpy.ndarray

params = {}
classmethod unpack_pars(p_arr)

Unpack a parameter array into a dictionary keyed by parameter names.

Parameters:

p_arr (jax.numpy.ndarray or array-like) – 1D array of parameter values. The order must match the keys in cls.params.

Returns:

Dictionary mapping parameter names (str) to their corresponding values.

Return type:

dict

class grater_jax.disk_model.SLD_utils.StellarPSFReference

Bases: object

Reference images that the Stellar PSF classes will use.

reference_images = array of zeros, shape (10, 10)
class grater_jax.disk_model.SLD_utils.LinearStellarPSF

Bases: Jax_class

Stellar PSF model as a linear combination of reference images.

params = {'stellar_weights': None}
classmethod pack_pars(p_dict)

Pack stellar PSF parameters into a flat array.

Parameters:

p_dict (dict) – Parameter dictionary with key 'stellar_weights'.

Returns:

1D array of stellar weights.

Return type:

jax.numpy.ndarray

classmethod unpack_pars(stellar_psf_params)

Unpack a flat parameter array into the stellar PSF parameter dictionary.

Parameters:

stellar_psf_params (jax.numpy.ndarray) – 1D array of stellar weights.

Returns:

Dictionary with key 'stellar_weights'.

Return type:

dict

classmethod compute_stellar_psf_image(stellar_weights, nx, ny)

Computes the on-axis PSF from the reference images and linear weights. Resizes the final image to (nx, ny).
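
The linear combination itself can be sketched as a contraction of the weights with a stack of reference images. The stack shape (n_ref, ny, nx) and the toy values are assumptions, and the resize step is omitted. NumPy stands in for jax.numpy.

```python
import numpy as np

# Assumption: reference images are stacked along the first axis.
reference_images = np.stack([np.full((4, 4), 1.0), np.full((4, 4), 10.0)])
stellar_weights = np.array([0.5, 0.2])

# Sum over the reference index: one weighted image of shape (ny, nx).
stellar_psf = np.tensordot(stellar_weights, reference_images, axes=1)
```

Each pixel of the result is 0.5 * 1 + 0.2 * 10 = 2.5 for these toy inputs.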

class grater_jax.disk_model.SLD_utils.PositionalStellarPSF

Bases: Jax_class

Stellar PSF model with position-dependent reference image weights.

params = {'stellar_weights': None, 'stellar_xs': None, 'stellar_ys': None}
classmethod pack_pars(p_dict)

Pack positional stellar PSF parameters into a single flat array.

Concatenates stellar_weights, stellar_xs, and stellar_ys in that order.

Parameters:

p_dict (dict) – Dictionary with keys 'stellar_weights', 'stellar_xs', and 'stellar_ys'.

Returns:

1D concatenated parameter array.

Return type:

jax.numpy.ndarray

classmethod unpack_pars(stellar_psf_params)

Unpack a flat parameter array into the positional stellar PSF dictionary.

Splits the array into stellar_weights, stellar_xs, and stellar_ys based on the number of PSF reference images.

Parameters:

stellar_psf_params (jax.numpy.ndarray) – 1D concatenated array produced by pack_pars.

Returns:

Dictionary with keys 'stellar_weights', 'stellar_xs', 'stellar_ys'.

Return type:

dict
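
The concatenate-and-split convention can be sketched as below. The helper names are hypothetical and the number of reference images is passed explicitly here, whereas the class is described as deriving it from the reference images. NumPy stands in for jax.numpy.

```python
import numpy as np

def pack_positional(p_dict):
    # Order matters: weights first, then x positions, then y positions.
    return np.concatenate(
        [p_dict["stellar_weights"], p_dict["stellar_xs"], p_dict["stellar_ys"]]
    )

def unpack_positional(p_arr, n_ref):
    # Each block has one entry per reference image.
    return {
        "stellar_weights": p_arr[:n_ref],
        "stellar_xs": p_arr[n_ref:2 * n_ref],
        "stellar_ys": p_arr[2 * n_ref:3 * n_ref],
    }

original = {
    "stellar_weights": np.array([0.5, 0.2]),
    "stellar_xs": np.array([1.0, -1.0]),
    "stellar_ys": np.array([0.0, 2.0]),
}
packed = pack_positional(original)
restored = unpack_positional(packed, n_ref=2)
```

The round trip recovers the three arrays only if the same number of reference images is used for packing and unpacking.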

classmethod compute_stellar_psf_image(stellar_psf_params, nx, ny)

Efficiently computes the resulting stellar PSF from the linear weights, x positions, and y positions. Resizes the final image to (nx, ny).