grater_jax.disk_model.SLD_utils
SLD_utils.py
Utility classes for disk modeling.
This module defines base JAX classes and implementations of density distributions, scattering phase functions, and point spread functions (PSFs) used in scattered-light disk forward modeling. It includes:
Jax_class : Base class for packing/unpacking parameter dictionaries.
DustEllipticalDistribution2PowerLaws : Two-power-law dust density model.
HenyeyGreenstein_SPF, DoubleHenyeyGreenstein_SPF : Scattering phase functions.
InterpolatedUnivariateSpline_SPF : Spline-based scattering phase function.
GAUSSIAN_PSF, EMP_PSF, Winnie_PSF : PSF models.
LinearStellarPSF, PositionalStellarPSF : Stellar PSF models using reference images.
The module can be extended to add new density distributions, scattering phase functions, and point spread functions to the framework.
Functions
- recommended_num_knots : Return a recommended knot count scaled to the scattering angles probed at a given inclination.
- vmap : Vectorizing map.
Classes
- DoubleHenyeyGreenstein_SPF : Implementation of a scattering phase function with a double Henyey Greenstein function.
- DustEllipticalDistribution2PowerLaws : Two-power-law elliptical dust density distribution model.
- EMP_PSF : Empirical point spread function (PSF) model.
- GAUSSIAN_PSF : Gaussian PSF model.
- HenyeyGreenstein_SPF : Implementation of a scattering phase function with a single Henyey Greenstein function.
- InterpolatedUnivariateSpline_SPF : Implementation of a spline scattering phase function.
- Jax_class : Base class for custom JAX-compatible objects that can be compressed into and uncompressed from JAX arrays.
- LinearStellarPSF : Stellar PSF model as a linear combination of reference images.
- PositionalStellarPSF : Stellar PSF model with position-dependent reference image weights.
- StellarPSFReference : Reference images used by the stellar PSF classes.
- WinniePSF : A class for spatially-varying PSF convolution with image cubes, including rotation and masking per roll angle.
- Winnie_PSF : JWST PSF model built with the Winnie package.
- partial : partial(func, *args, **keywords) - new function with partial application of the given arguments and keywords.
- class grater_jax.disk_model.SLD_utils.Jax_class
Bases: object
Base class for custom JAX-compatible objects that can be compressed into and uncompressed from JAX arrays.
- params = {}
- classmethod unpack_pars(p_arr)
Unpack a parameter array into a dictionary keyed by parameter names.
- Parameters:
p_arr (jax.numpy.ndarray or array-like) – 1D array of parameter values. The order must match the keys in cls.params.
- Returns:
Dictionary mapping parameter names (str) to their corresponding values.
- Return type:
dict
- classmethod pack_pars(p_dict)
Pack a parameter dictionary into a JAX array.
The order of parameters in the array is determined by the key order in cls.params.
- Parameters:
p_dict (dict) – Dictionary mapping parameter names (str) to values.
- Returns:
1D array of parameter values in the order specified by cls.params.
- Return type:
jax.numpy.ndarray
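Here is a minimal sketch of the packing contract, written as a Jax_class-style base. JaxClassSketch and TwoParamSketch are hypothetical names. The sketch assumes every entry of params is a scalar, so the real class may handle array-valued entries (such as knot_values) differently.

```python
import jax.numpy as jnp


class JaxClassSketch:
    """Hypothetical stand-in for Jax_class: parameter order = key order of `params`."""

    params = {}

    @classmethod
    def pack_pars(cls, p_dict):
        # Read values in the key order of cls.params, not the order of p_dict.
        return jnp.array([p_dict[key] for key in cls.params])

    @classmethod
    def unpack_pars(cls, p_arr):
        # Pair each array entry with the key at the same position in cls.params.
        return {key: p_arr[i] for i, key in enumerate(cls.params)}


class TwoParamSketch(JaxClassSketch):
    # Hypothetical subclass: only the params dictionary changes.
    params = {"a": 1.0, "b": 2.0}


arr = TwoParamSketch.pack_pars({"b": 5.0, "a": 4.0})
```

The order comes from cls.params, so a dictionary whose keys appear in a different order still packs to the same array layout.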
- class grater_jax.disk_model.SLD_utils.DustEllipticalDistribution2PowerLaws
Bases: Jax_class
Two-power-law elliptical dust density distribution model.
- params = {'accuracy': 0.005, 'alpha_in': 5.0, 'alpha_out': -5.0, 'beta': 1.0, 'dens_at_r0': 1.0, 'e': 0.0, 'gamma': 2.0, 'itiltthreshold': 0.0, 'ksi0': 1.0, 'p': 0.0, 'pmin': 0.0, 'rmax': 0.0, 'rmin': 0.0, 'rpeak': 0.0, 'rpeak_surface_density': 0.0, 'sma': 60.0, 'zmax': 0.0}
- classmethod init(accuracy=0.005, alpha_in=5.0, alpha_out=-5.0, sma=60.0, e=0.0, ksi0=1.0, gamma=2.0, beta=1.0, rmin=0.0, dens_at_r0=1.0)
Constructor for the two-power-law dust distribution.
The dust density is assumed to be 0 radially once it drops below 0.5% (the accuracy parameter) of the peak midplane density, and vertically once it drops below 0.5% of the peak midplane density.
Based on the VIP disk forward-modeling code by Julien Milli.
- Parameters:
accuracy (float) – Density limit as described above. Default is 5.e-3.
alpha_in (float) – slope of the power-law distribution in the inner disk. It must be positive (default 5)
alpha_out (float) – slope of the power-law distribution in the outer disk. It must be negative (default -5)
sma (float) – reference radius in au (default 60)
e (float) – eccentricity (default 0)
ksi0 (float) – scale height in au at the reference radius (default 1 au)
gamma (float) – exponent (2 = Gaussian, 1 = exponential profile, default 2)
beta (float) – flaring index (0 = no flaring, 1 = linear flaring, default 1)
rmin (float) – minimum semi-major axis: the dust density is 0 below this value (default 0)
The other entries of params are derived from the parameters above.
- classmethod density_cylindrical(distr_params, r, costheta, z)
Return the particle volume density at cylindrical coordinates (r, theta, z). The azimuth is passed as costheta = cos(theta).
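The two-power-law profile can be sketched as follows, assuming the GRaTeR-style form used in VIP. The radial term smoothly joins an inner and an outer power law. The vertical term is exp(-(|z|/h)^gamma), with a flared scale height h = ksi0 (r/r_ref)^beta. Several parts of this sketch are assumptions rather than the module's implementation: the function name two_power_law_density, the elliptical reference radius r_ref = p / (1 + e cos theta) with p = sma (1 - e^2), and the omission of the accuracy-based truncation.

```python
import jax.numpy as jnp


def two_power_law_density(r, costheta, z, sma=60.0, e=0.0, alpha_in=5.0,
                          alpha_out=-5.0, ksi0=1.0, gamma=2.0, beta=1.0,
                          dens_at_r0=1.0, rmin=0.0):
    """Hypothetical sketch of a GRaTeR-style two-power-law density (no accuracy cut)."""
    r = jnp.asarray(r, dtype=jnp.float32)
    z = jnp.asarray(z, dtype=jnp.float32)
    # Assumed elliptical reference radius at this azimuth: r_ref = p / (1 + e cos theta).
    p = sma * (1.0 - e**2)
    ratio = r / (p / (1.0 + e * costheta))
    # Smooth join of the inner and outer power laws; the sqrt(2 / ...) factor
    # makes the radial term equal dens_at_r0 at ratio = 1.
    radial = dens_at_r0 * jnp.sqrt(
        2.0 / (ratio ** (-2.0 * alpha_in) + ratio ** (-2.0 * alpha_out)))
    # Flared scale height and vertical profile (gamma = 2 Gaussian, 1 exponential).
    h = ksi0 * ratio ** beta
    vertical = jnp.exp(-(jnp.abs(z) / h) ** gamma)
    # No dust inside rmin.
    return jnp.where(r < rmin, 0.0, radial * vertical)
```

With the defaults, the midplane density at r = sma equals dens_at_r0. One scale height above the midplane, it falls by a factor of e.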
- classmethod pack_pars(p_dict)
Pack a parameter dictionary into a JAX array.
The order of parameters in the array is determined by the key order in cls.params.
- Parameters:
p_dict (dict) – Dictionary mapping parameter names (str) to values.
- Returns:
1D array of parameter values in the order specified by cls.params.
- Return type:
jax.numpy.ndarray
- classmethod unpack_pars(p_arr)
Unpack a parameter array into a dictionary keyed by parameter names.
- Parameters:
p_arr (jax.numpy.ndarray or array-like) – 1D array of parameter values. The order must match the keys in cls.params.
- Returns:
Dictionary mapping parameter names (str) to their corresponding values.
- Return type:
dict
- class grater_jax.disk_model.SLD_utils.HenyeyGreenstein_SPF
Bases: Jax_class
Implementation of a scattering phase function with a single Henyey Greenstein function.
- params = {'g': 0.3}
- classmethod init(func_params)
Constructor of a Henyey Greenstein phase function.
- Parameters:
func_params (dict) – dictionary containing the key "g" (float), the Henyey Greenstein coefficient, which should be between -1 (backward scattering) and 1 (forward scattering).
- classmethod compute_phase_function_from_cosphi(phase_func_params, cos_phi)
Compute the phase function at one or more scattering angles phi. The argument is cos(phi) rather than phi, for optimization reasons.
- Parameters:
cos_phi (float or array) – cosine of the scattering angle(s) at which the scattering function must be calculated.
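The textbook Henyey Greenstein function, normalised to integrate to 1 over the sphere, can be written as p(cos phi) = (1 - g^2) / (4 pi (1 + g^2 - 2 g cos phi)^(3/2)). The sketch below implements only that formula. The module's compute_phase_function_from_cosphi may use a different normalisation, and henyey_greenstein is a hypothetical name.

```python
import jax.numpy as jnp


def henyey_greenstein(g, cos_phi):
    """Textbook Henyey Greenstein phase function (unit integral over the sphere)."""
    # g > 0 favours forward scattering (cos_phi -> 1), g < 0 backward scattering.
    return (1.0 - g**2) / (4.0 * jnp.pi * (1.0 + g**2 - 2.0 * g * cos_phi) ** 1.5)
```

For g = 0, the function reduces to the isotropic value 1 / (4 pi) at every angle.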
- classmethod pack_pars(p_dict)
Pack a parameter dictionary into a JAX array.
The order of parameters in the array is determined by the key order in cls.params.
- Parameters:
p_dict (dict) – Dictionary mapping parameter names (str) to values.
- Returns:
1D array of parameter values in the order specified by cls.params.
- Return type:
jax.numpy.ndarray
- classmethod unpack_pars(p_arr)
Unpack a parameter array into a dictionary keyed by parameter names.
- Parameters:
p_arr (jax.numpy.ndarray or array-like) – 1D array of parameter values. The order must match the keys in cls.params.
- Returns:
Dictionary mapping parameter names (str) to their corresponding values.
- Return type:
dict
- class grater_jax.disk_model.SLD_utils.DoubleHenyeyGreenstein_SPF
Bases: Jax_class
Implementation of a scattering phase function with a double Henyey Greenstein function.
- Parameters:
g1 (float) – first Henyey Greenstein coefficient; should be between -1 (backward scattering) and 1 (forward scattering)
g2 (float) – second Henyey Greenstein coefficient; should be between -1 (backward scattering) and 1 (forward scattering)
weight (float) – weight of the first Henyey Greenstein component
- params = {'g1': 0.5, 'g2': -0.3, 'weight': 0.7}
- classmethod init(func_params)
- classmethod compute_phase_function_from_cosphi(phase_func_params, cos_phi)
Compute the phase function at one or more scattering angles phi. The argument is cos(phi) rather than phi, for optimization reasons.
- Parameters:
cos_phi (float or array) – cosine of the scattering angle(s) at which the scattering function must be calculated.
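The sketch below assumes the two components combine as a convex mixture: weight on the first component and 1 - weight on the second. That mixing convention is an assumption based on the description of weight. The function names are hypothetical, and each component uses the textbook normalisation.

```python
import jax.numpy as jnp


def henyey_greenstein(g, cos_phi):
    """Textbook Henyey Greenstein phase function (unit integral over the sphere)."""
    return (1.0 - g**2) / (4.0 * jnp.pi * (1.0 + g**2 - 2.0 * g * cos_phi) ** 1.5)


def double_henyey_greenstein(g1, g2, weight, cos_phi):
    # Assumed convex mixture: `weight` on the first lobe, (1 - weight) on the second.
    return (weight * henyey_greenstein(g1, cos_phi)
            + (1.0 - weight) * henyey_greenstein(g2, cos_phi))
```

With the default params (g1 = 0.5, g2 = -0.3, weight = 0.7), this mixture has a strong forward lobe and a weaker backward lobe.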
- classmethod pack_pars(p_dict)
Pack a parameter dictionary into a JAX array.
The order of parameters in the array is determined by the key order in cls.params.
- Parameters:
p_dict (dict) – Dictionary mapping parameter names (str) to values.
- Returns:
1D array of parameter values in the order specified by cls.params.
- Return type:
jax.numpy.ndarray
- classmethod unpack_pars(p_arr)
Unpack a parameter array into a dictionary keyed by parameter names.
- Parameters:
p_arr (jax.numpy.ndarray or array-like) – 1D array of parameter values. The order must match the keys in cls.params.
- Returns:
Dictionary mapping parameter names (str) to their corresponding values.
- Return type:
dict
- grater_jax.disk_model.SLD_utils.recommended_num_knots(inclination_deg, num_knots_full_range, boundary_buffer=0.1)
Return a recommended knot count scaled to the scattering angles probed at a given inclination.
A disk at inclination i samples scattering angles roughly in the range [90 - i, 90 + i] degrees, which corresponds to cos(phi) in [-sin(i), +sin(i)]. Placing num_knots_full_range knots over the full 180-degree range is appropriate when the full range is probed (high inclinations). At low inclinations, where only a narrow window is accessible, it over-parameterises the spline. Over-parameterisation makes the optimiser ill-conditioned and can cause convergence failures.
This function scales the knot count linearly with the angular window that is actually probed. It adds a small outward buffer so knots are not placed right at the geometric boundary.
The minimum returned value is 4. A cubic spline technically requires only 4 control points (num_knots = 3). However, empirical testing shows that num_knots = 3 produces a spline that is too rigid: at low inclinations, the optimiser consistently converges to local minima that are consistent with the image but not with the SPF. num_knots = 4 (5 control points) provides enough flexibility to recover the SPF shape reliably.
Using this function is optional. Users can always pass num_knots directly to InterpolatedUnivariateSpline_SPF.params. This helper just provides a principled starting point when the inclination is known.
- Parameters:
inclination_deg (float) – Disk inclination in degrees (0 = face-on, 90 = edge-on).
num_knots_full_range (int) – Desired knot count for a disk that probes the full 0–180 degree range. Typically 5–8; the default InterpolatedUnivariateSpline_SPF uses 6.
boundary_buffer (float, optional) – Buffer added beyond the geometric probed range, in cos_phi units, before computing the angular window. It gives the spline a little room at the edges so knots are not placed exactly on the boundary. Default 0.1.
- Returns:
Recommended number of free knots, in the range [4, num_knots_full_range].
- Return type:
int
Examples
>>> recommended_num_knots(30.0, 6)
4
>>> recommended_num_knots(50.0, 6)
4
>>> recommended_num_knots(80.0, 6)
6
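The scaling can be sketched as follows. This is a hypothetical reimplementation (recommended_num_knots_sketch), not the module's code. It widens the probed half-width sin(i) by boundary_buffer and caps it at 1. It then converts that half-width to an angular window in degrees, scales num_knots_full_range by window / 180, rounds, and clamps the result. Under these assumptions, it returns the same values as the three documented examples.

```python
import math


def recommended_num_knots_sketch(inclination_deg, num_knots_full_range, boundary_buffer=0.1):
    """Hypothetical sketch: knot count proportional to the probed angular window."""
    # Half-width of the probed cos(phi) interval, widened by the buffer and capped at 1.
    half_width = min(1.0, math.sin(math.radians(inclination_deg)) + boundary_buffer)
    # Window in degrees from arccos(+half_width) to arccos(-half_width).
    window_deg = 180.0 - 2.0 * math.degrees(math.acos(half_width))
    scaled = round(num_knots_full_range * window_deg / 180.0)
    # Clamp to [4, num_knots_full_range].
    return max(4, min(num_knots_full_range, scaled))
```

At i = 50 degrees, the buffered half-width is about 0.866, so the window is about 120 degrees. Scaling 6 knots by 120 / 180 gives exactly 4.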
- class grater_jax.disk_model.SLD_utils.InterpolatedUnivariateSpline_SPF
Bases: Jax_class
Implementation of a spline scattering phase function. It uses 6 free knots by default and takes the knot y-values as parameters. The knot x-positions are fixed; pack_pars and init both return the spline model itself.
- Parameters:
backscatt_bound (float) – cosine of the backward-scattering limit (the end closer to 180 deg) of the scattering-angle range covered by the spline
forwardscatt_bound (float) – cosine of the forward-scattering limit (the end closer to 0 deg) of the scattering-angle range covered by the spline
num_knots (int) – number of knots
knot_values (array) – y values of the knots
- params = {'backscatt_bound': -1, 'forwardscatt_bound': 1, 'knot_values': (1.0, 1.0, 1.0, 1.0, 1.0, 1.0), 'num_knots': 6}
- classmethod init(p_arr, knots)
- classmethod get_knots(p_dict)
Return knot x-positions (in cos_phi space) for the spline SPF.
Returns num_knots + 1 positions. A fixed knot is always placed at cos_phi = 0 (90 degrees), which is the normalization point. The remaining num_knots positions are split evenly on either side.
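Here is a sketch of the knot layout. It assumes the free knots are spaced linearly in cos_phi between each bound and 0. It also assumes the backward side gets num_knots // 2 of them, matching the n_left = len(p_arr) // 2 convention of pack_pars. knot_positions is a hypothetical name, and the spacing rule is an assumption.

```python
import jax.numpy as jnp


def knot_positions(num_knots=6, backscatt_bound=-1.0, forwardscatt_bound=1.0):
    """Hypothetical sketch: num_knots free knots plus one fixed knot at cos_phi = 0."""
    n_left = num_knots // 2          # free knots on the backward-scattering side
    n_right = num_knots - n_left     # free knots on the forward-scattering side
    # Each linspace includes 0 as an endpoint; drop it so the fixed knot appears once.
    left = jnp.linspace(backscatt_bound, 0.0, n_left + 1)[:-1]
    right = jnp.linspace(0.0, forwardscatt_bound, n_right + 1)[1:]
    return jnp.concatenate([left, jnp.zeros(1), right])
```

With the defaults, this sketch gives seven ascending positions: -1, -2/3, -1/3, 0, 1/3, 2/3, 1.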
- classmethod pack_pars(p_arr, knots)
Build an InterpolatedUnivariateSpline from the free knot values.
p_arr contains num_knots free values: the knot y-values at all positions except the fixed normalization point at cos_phi = 0. A value of 1.0 is inserted at the center index (n_left = len(p_arr) // 2) before the spline is constructed, so spline(0) = 1 by construction. This eliminates the degeneracy between the knot scale and flux_scaling.
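The normalisation step inserts the fixed value 1.0 at index len(p_arr) // 2, which can be sketched with jax.numpy.insert. full_knot_values is a hypothetical helper. Building the InterpolatedUnivariateSpline itself is left out because its source library is not specified here.

```python
import jax.numpy as jnp


def full_knot_values(p_arr):
    """Hypothetical sketch: free knot values plus the fixed 1.0 at cos_phi = 0."""
    p_arr = jnp.asarray(p_arr, dtype=jnp.float32)
    n_left = len(p_arr) // 2
    # The inserted 1.0 becomes the spline value at the central knot (cos_phi = 0).
    return jnp.insert(p_arr, n_left, 1.0)
```

For six free values [2, 3, 4, 5, 6, 7], the result is [2, 3, 4, 1, 5, 6, 7].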
- classmethod compute_phase_function_from_cosphi(spline_model, cos_phi)
Compute the phase function at one or more scattering angles phi. The argument is cos(phi) rather than phi, for optimization reasons.
The spline is normalized so that it equals 1 at 90 degrees (cos_phi=0), breaking the degeneracy between the SPF knot values and the absolute flux scaling parameter.
When the spline knots do not cover the full [-1, 1] cos_phi range (i.e., when inclination-dependent knot placement is used), values outside the knot range are extrapolated rather than evaluated with the boundary polynomial segment, which diverges for cubic splines:
Forward side (cos_phi > knot max): linear extrapolation using the spline's first derivative at the forward boundary. This is physically motivated, since SPFs typically rise toward forward scattering.
Backward side (cos_phi < knot min): constant extrapolation at the backward boundary value. SPFs are typically flat at large angles.
When knots cover the full [-1, 1] range these branches are never triggered, so this is fully backward-compatible.
- Parameters:
spline_model (InterpolatedUnivariateSpline) – spline model to represent scattering light phase function
cos_phi (float or array) – cosine of the scattering angle(s) at which the scattering function must be calculated.
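The boundary handling can be sketched with jnp.where. The sketch assumes the spline is available as a JAX function spline_fn that jax.grad can differentiate at a scalar, with knot range [knot_min, knot_max]. extrapolated_spf is a hypothetical name. The spline is only ever evaluated on clipped inputs, so the diverging boundary polynomial is never used outside the knots.

```python
import jax
import jax.numpy as jnp


def extrapolated_spf(spline_fn, cos_phi, knot_min, knot_max):
    """Hypothetical sketch of the documented extrapolation rules."""
    cos_phi = jnp.asarray(cos_phi)
    # Inside the knot range: evaluate the spline (clipping keeps it in range everywhere).
    inside = spline_fn(jnp.clip(cos_phi, knot_min, knot_max))
    # Forward side: linear extrapolation with the slope at the forward boundary.
    slope_fwd = jax.grad(spline_fn)(knot_max)
    forward = spline_fn(knot_max) + slope_fwd * (cos_phi - knot_max)
    # Backward side: constant value at the backward boundary.
    backward = spline_fn(knot_min)
    return jnp.where(cos_phi > knot_max, forward,
                     jnp.where(cos_phi < knot_min, backward, inside))
```

Any callable that jax.grad can differentiate works as spline_fn. For example, lambda x: 1.0 + x + x**2 with knot_min = -0.5 and knot_max = 0.5 gives 2.35 at cos_phi = 0.8 and 0.75 at cos_phi = -0.9.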
- classmethod unpack_pars(p_arr)
Unpack a parameter array into a dictionary keyed by parameter names.
- Parameters:
p_arr (jax.numpy.ndarray or array-like) – 1D array of parameter values. The order must match the keys in cls.params.
- Returns:
Dictionary mapping parameter names (str) to their corresponding values.
- Return type:
dict
- class grater_jax.disk_model.SLD_utils.GAUSSIAN_PSF
Bases: Jax_class
Gaussian PSF model, defined by the following parameters:
- Parameters:
FWHM (float) – Full width at half maximum of the Gaussian PSF.
xo (float) – X coordinate of the center of the PSF.
yo (float) – Y coordinate of the center of the PSF.
theta (float) – Rotation angle of the PSF in radians.
offset (float) – Offset value to be added to the PSF.
amplitude (float) – Amplitude of the PSF.
- params = {'FWHM': 3.0, 'amplitude': 1.0, 'offset': 0.0, 'theta': 0.0, 'xo': 0.0, 'yo': 0.0}
- classmethod generate(image, psf_params)
Apply a Gaussian PSF to an image via FFT-based convolution.
- Parameters:
image (jax.numpy.ndarray) – 2D input image to be smoothed.
psf_params (jax.numpy.ndarray) – Packed PSF parameter array as returned by pack_pars.
- Returns:
Smoothed image with the Gaussian PSF applied and offset added.
- Return type:
jax.numpy.ndarray
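Here is a minimal sketch of FFT-based Gaussian smoothing. It assumes a circular Gaussian centred on the image and normalised to unit sum, so total flux is preserved. gaussian_smooth and gaussian_kernel are hypothetical names. The sketch ignores xo, yo, theta, and amplitude, since this page does not document how generate uses them. The FFT also makes the convolution periodic, so flux near one edge wraps around to the opposite edge.

```python
import jax.numpy as jnp


def gaussian_kernel(shape, fwhm):
    """Unit-sum circular Gaussian centred on pixel (ny // 2, nx // 2)."""
    ny, nx = shape
    sigma = fwhm / (2.0 * jnp.sqrt(2.0 * jnp.log(2.0)))  # FWHM -> standard deviation
    y = jnp.arange(ny) - ny // 2
    x = jnp.arange(nx) - nx // 2
    yy, xx = jnp.meshgrid(y, x, indexing="ij")
    kernel = jnp.exp(-(xx**2 + yy**2) / (2.0 * sigma**2))
    return kernel / kernel.sum()


def gaussian_smooth(image, fwhm, offset=0.0):
    """Hypothetical sketch: FFT convolution with a Gaussian, then add a constant offset."""
    kernel = gaussian_kernel(image.shape, fwhm)
    # ifftshift moves the kernel centre to pixel (0, 0) so the output is not shifted.
    kernel_ft = jnp.fft.rfft2(jnp.fft.ifftshift(kernel))
    smoothed = jnp.fft.irfft2(jnp.fft.rfft2(image) * kernel_ft, s=image.shape)
    return smoothed + offset
```

Smoothing a single bright pixel returns the kernel itself, centred on that pixel.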
- classmethod pack_pars(p_dict)
Pack a parameter dictionary into a JAX array.
The order of parameters in the array is determined by the key order in cls.params.
- Parameters:
p_dict (dict) – Dictionary mapping parameter names (str) to values.
- Returns:
1D array of parameter values in the order specified by cls.params.
- Return type:
jax.numpy.ndarray
- classmethod unpack_pars(p_arr)
Unpack a parameter array into a dictionary keyed by parameter names.
- Parameters:
p_arr (jax.numpy.ndarray or array-like) – 1D array of parameter values. The order must match the keys in cls.params.
- Returns:
Dictionary mapping parameter names (str) to their corresponding values.
- Return type:
dict
- class grater_jax.disk_model.SLD_utils.EMP_PSF
Bases: Jax_class
Empirical point spread function (PSF) model.
- params = {'offset': 1.0, 'scale_factor': 1.0}
- img = None
- classmethod generate(image, psf_params)
Convolve the input image with the empirical PSF via FFT.
- Parameters:
image (jax.numpy.ndarray) – 2D input image to convolve.
psf_params (jax.numpy.ndarray) – Packed PSF parameter array (unused; convolution uses cls.img).
- Returns:
Image convolved with the stored empirical PSF.
- Return type:
jax.numpy.ndarray
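The empirical variant follows the same FFT pattern but reads the kernel from a class attribute, as EMP_PSF does with img, and ignores psf_params. EmpiricalPSFSketch is hypothetical. It assumes img has the same shape as the input image and is centred at (ny // 2, nx // 2).

```python
import jax.numpy as jnp


class EmpiricalPSFSketch:
    """Hypothetical sketch: convolution with a PSF image stored on the class."""

    img = None  # assumed same shape as the input image, centred at (ny // 2, nx // 2)

    @classmethod
    def generate(cls, image, psf_params=None):
        # psf_params is ignored, mirroring the documented behaviour of EMP_PSF.generate.
        kernel_ft = jnp.fft.rfft2(jnp.fft.ifftshift(cls.img))
        return jnp.fft.irfft2(jnp.fft.rfft2(image) * kernel_ft, s=image.shape)
```

Storing the PSF on the class keeps it out of the packed parameter array, so the optimiser never differentiates with respect to it.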
- classmethod pack_pars(p_dict)
Pack a parameter dictionary into a JAX array.
The order of parameters in the array is determined by the key order in cls.params.
- Parameters:
p_dict (dict) – Dictionary mapping parameter names (str) to values.
- Returns:
1D array of parameter values in the order specified by cls.params.
- Return type:
jax.numpy.ndarray
- classmethod unpack_pars(p_arr)
Unpack a parameter array into a dictionary keyed by parameter names.
- Parameters:
p_arr (jax.numpy.ndarray or array-like) – 1D array of parameter values. The order must match the keys in cls.params.
- Returns:
Dictionary mapping parameter names (str) to their corresponding values.
- Return type:
dict
- class grater_jax.disk_model.SLD_utils.Winnie_PSF
Bases: Jax_class
JWST PSF model built with the Winnie package. See the Winnie documentation for further details on JWST PSFs.
- classmethod init(psfs, psf_inds_rolls, im_mask_rolls, psf_offsets, psf_parangs, num_unique_psfs)
Initialise and return a WinniePSF object from raw PSF grid arrays.
- Parameters:
psfs (array-like) – PSF grid images.
psf_inds_rolls (array-like) – Roll indices for each PSF.
im_mask_rolls (array-like) – Image mask rolls.
psf_offsets (array-like) – Positional offsets for each PSF.
psf_parangs (array-like) – Position angles for each PSF.
num_unique_psfs (int) – Number of unique PSFs in the grid.
- Returns:
Initialised Winnie PSF model object.
- Return type:
WinniePSF
- classmethod pack_pars(winnie_model)
Return the WinniePSF model unchanged (identity packing).
- classmethod generate(image, winnie_model)
Convolve an image with the Winnie PSF and return the mean over spacecraft rolls.
- Parameters:
image (jax.numpy.ndarray) – 2D input image to convolve.
winnie_model (WinniePSF) – Packed Winnie PSF model as returned by pack_pars.
- Returns:
Mean of the roll-convolved image cube.
- Return type:
jax.numpy.ndarray
- params = {}
- classmethod unpack_pars(p_arr)
Unpack a parameter array into a dictionary keyed by parameter names.
- Parameters:
p_arr (jax.numpy.ndarray or array-like) – 1D array of parameter values. The order must match the keys in cls.params.
- Returns:
Dictionary mapping parameter names (str) to their corresponding values.
- Return type:
dict
- class grater_jax.disk_model.SLD_utils.StellarPSFReference
Bases: object
Reference images used by the stellar PSF classes.
- reference_images = array of zeros with shape (10, 10)
- class grater_jax.disk_model.SLD_utils.LinearStellarPSF
Bases: Jax_class
Stellar PSF model as a linear combination of reference images.
- params = {'stellar_weights': None}
- classmethod pack_pars(p_dict)
Pack stellar PSF parameters into a flat array.
- Parameters:
p_dict (dict) – Parameter dictionary with key 'stellar_weights'.
- Returns:
1D array of stellar weights.
- Return type:
jax.numpy.ndarray
- classmethod unpack_pars(stellar_psf_params)
Unpack a flat parameter array into the stellar PSF parameter dictionary.
- Parameters:
stellar_psf_params (jax.numpy.ndarray) – 1D array of stellar weights.
- Returns:
Dictionary with key 'stellar_weights'.
- Return type:
dict
- classmethod compute_stellar_psf_image(stellar_weights, nx, ny)
Compute the on-axis stellar PSF as a weighted sum of the reference images, and resize the final image to (nx, ny).
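Here is a sketch of the linear combination. It assumes the reference images are stacked as an array of shape (n_refs, h, w) and passed explicitly, whereas the real method reads them from StellarPSFReference. The resize uses jax.image.resize with linear interpolation and follows the (nx, ny) order stated in the docstring. The interpolation method and axis order used by the module are assumptions, and linear_stellar_psf is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def linear_stellar_psf(reference_images, stellar_weights, nx, ny):
    """Hypothetical sketch: weighted sum of reference images, then resize."""
    # (n_refs,) . (n_refs, h, w) -> (h, w): contract over the reference axis.
    psf = jnp.tensordot(stellar_weights, reference_images, axes=1)
    # Output shape follows the (nx, ny) order stated in the docstring.
    return jax.image.resize(psf, (nx, ny), method="linear")
```

The weighted sum is linear in stellar_weights, so gradients with respect to the weights are cheap to compute.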
- class grater_jax.disk_model.SLD_utils.PositionalStellarPSF
Bases: Jax_class
Stellar PSF model with position-dependent reference image weights.
- params = {'stellar_weights': None, 'stellar_xs': None, 'stellar_ys': None}
- classmethod pack_pars(p_dict)
Pack positional stellar PSF parameters into a single flat array.
Concatenates stellar_weights, stellar_xs, and stellar_ys in that order.
- Parameters:
p_dict (dict) – Dictionary with keys 'stellar_weights', 'stellar_xs', and 'stellar_ys'.
- Returns:
1D concatenated parameter array.
- Return type:
jax.numpy.ndarray
- classmethod unpack_pars(stellar_psf_params)
Unpack a flat parameter array into the positional stellar PSF dictionary.
Splits the array into stellar_weights, stellar_xs, and stellar_ys based on the number of PSF reference images.
- Parameters:
stellar_psf_params (jax.numpy.ndarray) – 1D concatenated array produced by pack_pars.
- Returns:
Dictionary with keys 'stellar_weights', 'stellar_xs', 'stellar_ys'.
- Return type:
dict
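The pack/unpack pair amounts to a concatenation and a three-way split. This sketch assumes each block has exactly one entry per reference image, so the split points are n_refs and 2 * n_refs. pack_positional and unpack_positional are hypothetical names.

```python
import jax.numpy as jnp


def pack_positional(p_dict):
    """Hypothetical sketch: concatenate weights, x positions and y positions."""
    return jnp.concatenate([jnp.ravel(jnp.asarray(p_dict["stellar_weights"])),
                            jnp.ravel(jnp.asarray(p_dict["stellar_xs"])),
                            jnp.ravel(jnp.asarray(p_dict["stellar_ys"]))])


def unpack_positional(p_arr, n_refs):
    # Assumes one weight, one x and one y per reference image.
    return {"stellar_weights": p_arr[:n_refs],
            "stellar_xs": p_arr[n_refs:2 * n_refs],
            "stellar_ys": p_arr[2 * n_refs:3 * n_refs]}
```

Because the split points depend only on n_refs, unpacking inverts packing exactly.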
- classmethod compute_stellar_psf_image(stellar_psf_params, nx, ny)
Compute the stellar PSF from the linear weights, x positions, and y positions, and resize the final image to (nx, ny).