grater_jax.disk_model.jax_gradient_wrappers
jax_gradient_wrappers.py
Differentiable log-likelihood functions for disk modeling.
This module defines JAX-compiled objective functions and their gradients for disk image generation and likelihood evaluation. It provides standard, Winnie PSF, and spline-based variants, along with convenience wrappers for computing per-pixel residuals and Gaussian log-likelihoods. These functions are intended as differentiable counterparts of the methods in jax_model_wrappers.py. The two modules take their parameters in different forms, so some bookkeeping is needed to map between them; this is handled by objective_grad in objective_functions.py.
Functions
- jax_model_grad – Gradient of jax_model_ll with respect to positional arguments (5, 6, 7, 8, 12).
- jax_model_ll – Log-likelihood of the generated disk model image, given disk, scattering phase function, point spread function, stellar PSF, and miscellaneous parameters, against the target image and error map.
- jax_model_spline_grad – Gradient of jax_model_spline_ll with respect to positional arguments (5, 6, 7, 8, 12).
- jax_model_spline_ll – As jax_model_ll, for a spline scattering phase function (InterpolatedUnivariateSpline_SPF).
- jax_model_spline_winnie_grad – Gradient of jax_model_spline_winnie_ll with respect to positional arguments (5, 6, 7, 11).
- jax_model_spline_winnie_ll – As jax_model_ll, for a spline scattering phase function combined with a Winnie PSF.
- jax_model_winnie_grad – Gradient of jax_model_winnie_ll with respect to positional arguments (5, 6, 7, 11).
- jax_model_winnie_ll – As jax_model_ll, for a Winnie PSF.
- log_likelihood – Compute the Gaussian log-likelihood between a model image and observed data.
- residuals – Compute per-pixel squared normalized residuals plus log-variance term.
Classes
- partial – partial(func, *args, **keywords): new function with partial application of the given arguments and keywords (this is functools.partial).
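The functions in this module take model classes as leading positional arguments. A common JAX idiom for such non-array arguments is to mark them static when compiling with jax.jit, typically written as partial(jax.jit, static_argnums=...). This may be one reason functools.partial is imported here, but that is an assumption, since the module source is not shown.

The minimal sketch below uses a hypothetical ToyGaussianModel class and toy_ll function. It is not the library's code.

```python
from functools import partial

import jax
import jax.numpy as jnp


class ToyGaussianModel:
    """Hypothetical stand-in for a model class such as an SPF or PSF model."""

    @staticmethod
    def evaluate(params, x):
        # params[0] = amplitude, params[1] = width
        return params[0] * jnp.exp(-0.5 * (x / params[1]) ** 2)


# The model class is not an array, so it is marked static: JAX hashes it and
# recompiles only when a different class is passed in position 0.
@partial(jax.jit, static_argnums=(0,))
def toy_ll(Model, params, x, data, err):
    model = Model.evaluate(params, x)
    return -0.5 * jnp.sum(((model - data) / err) ** 2)


x = jnp.linspace(-3.0, 3.0, 61)
params = jnp.array([1.0, 0.8])
data = ToyGaussianModel.evaluate(params, x)  # noiseless data at the true params
err = jnp.full_like(x, 0.1)
ll_at_truth = toy_ll(ToyGaussianModel, params, x, data, err)
```

Marking the class static means each distinct model class triggers a separate compilation. The array arguments, by contrast, stay traced and can be differentiated.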
- grater_jax.disk_model.jax_gradient_wrappers.log_likelihood(image, target_image, err_map, ll_method='sum')
Compute the Gaussian log-likelihood between a model image and observed data.
- Parameters:
image (jnp.ndarray) – Model image of shape (H, W).
target_image (jnp.ndarray) – Observed image of the same shape as image.
err_map (jnp.ndarray) – Per-pixel standard deviation (noise map); same shape as image.
ll_method ({'sum', 'mean'}, optional) – How to reduce the per-pixel Gaussian log-density terms across the image.
'sum' (default) returns the total log-likelihood, the statistically correct choice for MCMC, since log-posterior ratios depend on absolute values. 'mean' divides by jnp.size(image) and returns the per-pixel average. This is sometimes preferred for scipy gradient-based optimisation because the smaller magnitude keeps the L-BFGS-B ftol convergence criterion from tripping prematurely on large images.
- Returns:
The reduced log-likelihood assuming independent Gaussian errors.
- Return type:
float
- Raises:
ValueError – If ll_method is not 'sum' or 'mean'.
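With independent per-pixel Gaussian errors σ, the log-likelihood of model m given data d is −½ Σ [((m − d)/σ)² + ln(2πσ²)]. The sketch below implements that formula with both reductions. The name log_likelihood_sketch is hypothetical. Whether the library keeps the constant ln(2π) term is also an assumption; the constant does not affect gradients.

```python
import jax.numpy as jnp


def log_likelihood_sketch(image, target_image, err_map, ll_method="sum"):
    """Hypothetical re-implementation of a Gaussian log-likelihood."""
    # Per-pixel Gaussian log-density, including the normalization constant.
    per_pixel = -0.5 * (((image - target_image) / err_map) ** 2
                        + jnp.log(2 * jnp.pi * err_map ** 2))
    if ll_method == "sum":
        return jnp.sum(per_pixel)
    if ll_method == "mean":
        # Per-pixel average, mirroring the documented division by jnp.size(image).
        return jnp.sum(per_pixel) / jnp.size(image)
    raise ValueError(f"ll_method must be 'sum' or 'mean', got {ll_method!r}")


img = jnp.ones((2, 2))
err = jnp.ones((2, 2))
ll_sum = log_likelihood_sketch(img, img, err, "sum")    # perfect fit, sigma = 1
ll_mean = log_likelihood_sketch(img, img, err, "mean")
```

Because ll_method is a Python string checked with if/else, it must be treated as static if a function like this is wrapped in jax.jit.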
- grater_jax.disk_model.jax_gradient_wrappers.residuals(image, target_image, err_map)
Compute per-pixel squared normalized residuals plus log-variance term.
- Parameters:
image (jnp.ndarray) – Model image of shape (H, W).
target_image (jnp.ndarray) – Observed image of the same shape.
err_map (jnp.ndarray) – Per-pixel standard deviation (noise map); same shape as image.
- Returns:
Residual map of the same shape as image.
- Return type:
jnp.ndarray
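The sketch below computes such a residual map, assuming the log-variance term is ln(σ²); the exact constant the library uses is not stated here. Under that assumption, −½ times the sum of this map equals the Gaussian log-likelihood up to the constant −(N/2) ln(2π), where N is the number of pixels. The name residuals_sketch is hypothetical.

```python
import jax.numpy as jnp


def residuals_sketch(image, target_image, err_map):
    """Hypothetical helper: ((model - data) / sigma)**2 + log(sigma**2) per pixel."""
    chi = (image - target_image) / err_map   # normalized residual
    return chi ** 2 + jnp.log(err_map ** 2)  # add the log-variance term


image = jnp.array([[1.0, 2.0], [3.0, 4.0]])
target = jnp.array([[1.0, 1.0], [3.0, 3.0]])
err = jnp.array([[1.0, 1.0], [2.0, 2.0]])
res = residuals_sketch(image, target, err)
ll_up_to_constant = -0.5 * jnp.sum(res)  # log-likelihood minus the 2*pi constant
```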
- grater_jax.disk_model.jax_gradient_wrappers.jax_model_ll(DiskModel, DistrModel, FuncModel, PSFModel, StellarPSFModel, disk_params, spf_params, psf_params, stellar_psf_params, target_image, err_map, throughput, flux_scaling=1000000.0, distance=0.0, pxInArcsec=0.0, nx=140, ny=140, halfNbSlices=25, ll_method='sum')
Get the log-likelihood of the generated disk model image, given the disk, scattering phase function, point spread function, stellar PSF, and miscellaneous parameters, against the target image and error map. This function applies only when PSFModel != Winnie_PSF and FuncModel != InterpolatedUnivariateSpline_SPF. It is differentiated with jax.grad to produce jax_model_grad.
- Parameters:
DiskModel (class) – The disk model type. ScatteredLightDisk is the only supported disk model.
DistrModel (class) – The dust distribution model type. DustEllipticalDistribution2PowerLaws is the only supported dust distribution model.
FuncModel (class) – The scattering phase function model type (see disk_model/SLD_utils.py).
PSFModel (class) – The point spread function model type (see disk_model/SLD_utils.py).
StellarPSFModel (class or None) – The on-axis stellar PSF model type; None indicates no stellar PSF model.
disk_params (jnp.array) – Disk model parameter values, corresponding to the disk model's (parameter name, parameter value) pairs.
spf_params (jnp.array) – Scattering phase function parameter values, corresponding to its (parameter name, parameter value) pairs.
psf_params (jnp.array) – Point spread function parameter values, corresponding to its (parameter name, parameter value) pairs.
stellar_psf_params (jnp.array) – On-axis stellar PSF parameter values, corresponding to its (parameter name, parameter value) pairs.
target_image (np.ndarray) – The target image for which the log-likelihood is computed.
err_map (np.ndarray) – The error map for the target image.
throughput (np.ndarray) – The throughput image applied to the generated disk image.
flux_scaling (float) – Scaling factor for the disk model (default 1e6).
distance (float) – Distance to the star in pc (signature default 0.0).
pxInArcsec (float) – Pixel field of view in arcsec/px (signature default 0.0; the SPHERE pixel scale is 0.01225 arcsec/px).
nx (int) – Number of pixels along the x axis of the image (default 140).
ny (int) – Number of pixels along the y axis of the image (default 140).
halfNbSlices (int) – Half the number of distances sampled along the line of sight (default 25).
ll_method ({'sum', 'mean'}, optional) – How the per-pixel log-likelihood terms are reduced; see log_likelihood.
- Returns:
Log likelihood for the generated disk image, target image, and error map.
- Return type:
float
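The exact image-generation steps inside jax_model_ll are not documented here. The sketch below is a rough, hypothetical illustration of the overall flow:

- build a model image;
- convolve it with a PSF;
- apply throughput and flux scaling;
- evaluate a Gaussian log-likelihood.

Every function in it (toy_disk_image, toy_psf, toy_forward, toy_model_ll) is a toy stand-in for the real model classes. The order of the convolution, throughput and flux-scaling steps is an assumption.

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve2d


def toy_disk_image(disk_params, nx=32, ny=32):
    # Hypothetical stand-in for the disk model: a face-on Gaussian ring.
    r0, width = disk_params[0], disk_params[1]
    y, x = jnp.mgrid[:ny, :nx]
    r = jnp.hypot(x - nx / 2, y - ny / 2)
    return jnp.exp(-0.5 * ((r - r0) / width) ** 2)


def toy_psf(psf_params, size=7):
    # Hypothetical stand-in for a PSF model: a normalized Gaussian kernel.
    y, x = jnp.mgrid[:size, :size] - size // 2
    kernel = jnp.exp(-0.5 * (x ** 2 + y ** 2) / psf_params[0] ** 2)
    return kernel / jnp.sum(kernel)


def toy_forward(disk_params, psf_params, throughput, flux_scaling):
    # Convolve with the PSF, then apply throughput and flux scaling (assumed order).
    image = convolve2d(toy_disk_image(disk_params), toy_psf(psf_params), mode="same")
    return image * throughput * flux_scaling


def toy_model_ll(disk_params, psf_params, target_image, err_map, throughput, flux_scaling):
    model = toy_forward(disk_params, psf_params, throughput, flux_scaling)
    per_pixel = -0.5 * (((model - target_image) / err_map) ** 2
                        + jnp.log(2 * jnp.pi * err_map ** 2))
    return jnp.sum(per_pixel)


true_disk = jnp.array([8.0, 2.0])
true_psf = jnp.array([1.5])
throughput = jnp.ones((32, 32))
err_map = jnp.full((32, 32), 0.1)
target = toy_forward(true_disk, true_psf, throughput, 1.0)  # noiseless synthetic data

ll_true = toy_model_ll(true_disk, true_psf, target, err_map, throughput, 1.0)
ll_off = toy_model_ll(jnp.array([9.0, 2.0]), true_psf, target, err_map, throughput, 1.0)
# Differentiate with respect to the disk, PSF and flux-scaling arguments.
grads = jax.grad(toy_model_ll, argnums=(0, 1, 5))(
    true_disk, true_psf, target, err_map, throughput, 1.0)
```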
- grater_jax.disk_model.jax_gradient_wrappers.jax_model_winnie_ll(DiskModel, DistrModel, FuncModel, winnie_psf, StellarPSFModel, disk_params, spf_params, stellar_psf_params, target_image, err_map, throughput, flux_scaling=1000000.0, distance=0.0, pxInArcsec=0.0, nx=140, ny=140, halfNbSlices=25, ll_method='sum')
Get the log-likelihood of the generated disk model image, given the disk, scattering phase function, Winnie PSF, stellar PSF, and miscellaneous parameters, against the target image and error map. This function applies only when PSFModel == Winnie_PSF and FuncModel != InterpolatedUnivariateSpline_SPF. It is differentiated with jax.grad to produce jax_model_winnie_grad.
- Parameters:
DiskModel (class) – The disk model type. ScatteredLightDisk is the only supported disk model.
DistrModel (class) – The dust distribution model type. DustEllipticalDistribution2PowerLaws is the only supported dust distribution model.
FuncModel (class) – The scattering phase function model type (see disk_model/SLD_utils.py).
winnie_psf (class) – JWST off-axis PSF, modeled after the Winnie framework (see disk_model/winnie_class.py).
StellarPSFModel (class or None) – The on-axis stellar PSF model type; None indicates no stellar PSF model.
disk_params (jnp.array) – Disk model parameter values, corresponding to the disk model's (parameter name, parameter value) pairs.
spf_params (jnp.array) – Scattering phase function parameter values, corresponding to its (parameter name, parameter value) pairs.
stellar_psf_params (jnp.array) – On-axis stellar PSF parameter values, corresponding to its (parameter name, parameter value) pairs.
target_image (np.ndarray) – The target image for which the log-likelihood is computed.
err_map (np.ndarray) – The error map for the target image.
throughput (np.ndarray) – The throughput image applied to the generated disk image.
flux_scaling (float) – Scaling factor for the disk model (default 1e6).
distance (float) – Distance to the star in pc (signature default 0.0).
pxInArcsec (float) – Pixel field of view in arcsec/px (signature default 0.0; the SPHERE pixel scale is 0.01225 arcsec/px).
nx (int) – Number of pixels along the x axis of the image (default 140).
ny (int) – Number of pixels along the y axis of the image (default 140).
halfNbSlices (int) – Half the number of distances sampled along the line of sight (default 25).
ll_method ({'sum', 'mean'}, optional) – How the per-pixel log-likelihood terms are reduced; see log_likelihood.
- Returns:
Log likelihood for the generated disk image, target image, and error map.
- Return type:
float
- grater_jax.disk_model.jax_gradient_wrappers.jax_model_spline_ll(DiskModel, DistrModel, FuncModel, PSFModel, StellarPSFModel, disk_params, spf_params, psf_params, stellar_psf_params, target_image, err_map, throughput, flux_scaling=1000000.0, distance=0.0, pxInArcsec=0.0, nx=140, ny=140, halfNbSlices=25, knots=array([1., 0.5, -0.5, -1.]), ll_method='sum')
Get the log-likelihood of the generated disk model image, given the disk, spline scattering phase function, point spread function, stellar PSF, and miscellaneous parameters, against the target image and error map. This function applies only when PSFModel != Winnie_PSF and FuncModel == InterpolatedUnivariateSpline_SPF. It is differentiated with jax.grad to produce jax_model_spline_grad.
- Parameters:
DiskModel (class) – The disk model type. ScatteredLightDisk is the only supported disk model.
DistrModel (class) – The dust distribution model type. DustEllipticalDistribution2PowerLaws is the only supported dust distribution model.
FuncModel (class) – The scattering phase function model type (see disk_model/SLD_utils.py).
PSFModel (class) – The point spread function model type (see disk_model/SLD_utils.py).
StellarPSFModel (class or None) – The on-axis stellar PSF model type; None indicates no stellar PSF model.
disk_params (jnp.array) – Disk model parameter values, corresponding to the disk model's (parameter name, parameter value) pairs.
spf_params (jnp.array) – Scattering phase function parameter values, corresponding to its (parameter name, parameter value) pairs.
psf_params (jnp.array) – Point spread function parameter values, corresponding to its (parameter name, parameter value) pairs.
stellar_psf_params (jnp.array) – On-axis stellar PSF parameter values, corresponding to its (parameter name, parameter value) pairs.
target_image (np.ndarray) – The target image for which the log-likelihood is computed.
err_map (np.ndarray) – The error map for the target image.
throughput (np.ndarray) – The throughput image applied to the generated disk image.
flux_scaling (float) – Scaling factor for the disk model (default 1e6).
distance (float) – Distance to the star in pc (signature default 0.0).
pxInArcsec (float) – Pixel field of view in arcsec/px (signature default 0.0; the SPHERE pixel scale is 0.01225 arcsec/px).
nx (int) – Number of pixels along the x axis of the image (default 140).
ny (int) – Number of pixels along the y axis of the image (default 140).
halfNbSlices (int) – Half the number of distances sampled along the line of sight (default 25).
knots (array) – Knot locations for the spline scattering phase function (default array([1., 0.5, -0.5, -1.])).
ll_method ({'sum', 'mean'}, optional) – How the per-pixel log-likelihood terms are reduced; see log_likelihood.
- Returns:
Log likelihood for the generated disk image, target image, and error map.
- Return type:
float
- grater_jax.disk_model.jax_gradient_wrappers.jax_model_spline_winnie_ll(DiskModel, DistrModel, FuncModel, winnie_psf, StellarPSFModel, disk_params, spf_params, stellar_psf_params, target_image, err_map, throughput, flux_scaling=1000000.0, distance=0.0, pxInArcsec=0.0, nx=140, ny=140, halfNbSlices=25, knots=array([1., 0.5, -0.5, -1.]), ll_method='sum')
Get the log-likelihood of the generated disk model image, given the disk, spline scattering phase function, Winnie PSF, stellar PSF, and miscellaneous parameters, against the target image and error map. This function applies only when PSFModel == Winnie_PSF and FuncModel == InterpolatedUnivariateSpline_SPF. It is differentiated with jax.grad to produce jax_model_spline_winnie_grad.
- Parameters:
DiskModel (class) – The disk model type. ScatteredLightDisk is the only supported disk model.
DistrModel (class) – The dust distribution model type. DustEllipticalDistribution2PowerLaws is the only supported dust distribution model.
FuncModel (class) – The scattering phase function model type (see disk_model/SLD_utils.py).
winnie_psf (class) – JWST off-axis PSF, modeled after the Winnie framework (see disk_model/winnie_class.py).
StellarPSFModel (class or None) – The on-axis stellar PSF model type; None indicates no stellar PSF model.
disk_params (jnp.array) – Disk model parameter values, corresponding to the disk model's (parameter name, parameter value) pairs.
spf_params (jnp.array) – Scattering phase function parameter values, corresponding to its (parameter name, parameter value) pairs.
stellar_psf_params (jnp.array) – On-axis stellar PSF parameter values, corresponding to its (parameter name, parameter value) pairs.
target_image (np.ndarray) – The target image for which the log-likelihood is computed.
err_map (np.ndarray) – The error map for the target image.
throughput (np.ndarray) – The throughput image applied to the generated disk image.
flux_scaling (float) – Scaling factor for the disk model (default 1e6).
distance (float) – Distance to the star in pc (signature default 0.0).
pxInArcsec (float) – Pixel field of view in arcsec/px (signature default 0.0; the SPHERE pixel scale is 0.01225 arcsec/px).
nx (int) – Number of pixels along the x axis of the image (default 140).
ny (int) – Number of pixels along the y axis of the image (default 140).
halfNbSlices (int) – Half the number of distances sampled along the line of sight (default 25).
knots (array) – Knot locations for the spline scattering phase function (default array([1., 0.5, -0.5, -1.])).
ll_method ({'sum', 'mean'}, optional) – How the per-pixel log-likelihood terms are reduced; see log_likelihood.
- Returns:
Log likelihood for the generated disk image, target image, and error map.
- Return type:
float
- grater_jax.disk_model.jax_gradient_wrappers.jax_model_grad(DiskModel, DistrModel, FuncModel, PSFModel, StellarPSFModel, disk_params, spf_params, psf_params, stellar_psf_params, target_image, err_map, throughput, flux_scaling=1000000.0, distance=0.0, pxInArcsec=0.0, nx=140, ny=140, halfNbSlices=25, ll_method='sum')
Gradient of jax_model_ll with respect to positional arguments (5, 6, 7, 8, 12), i.e. disk_params, spf_params, psf_params, stellar_psf_params, and flux_scaling. Takes the same arguments as jax_model_ll but returns a tuple of gradients, one per differentiated argument, each with the same shape as that argument.
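When jax.grad is given a tuple of argnums, it returns a tuple with one gradient per listed position, each shaped like its argument. The small sketch below shows this with a hypothetical toy_objective. Like the model classes above, its leading argument is passed through without being differentiated.

```python
import jax
import jax.numpy as jnp


def toy_objective(Model, a, b, target, scale):
    # Model is a plain callable here; it occupies a non-differentiated position.
    prediction = scale * Model(a) + jnp.sum(b)
    return -0.5 * jnp.sum((prediction - target) ** 2)


# Differentiate with respect to positions 1, 2 and 4 (a, b and scale).
toy_grad = jax.grad(toy_objective, argnums=(1, 2, 4))

a = jnp.array([0.1, 0.2, 0.3])
b = jnp.array([1.0, 2.0])
target = jnp.zeros(3)
g_a, g_b, g_scale = toy_grad(jnp.sin, a, b, target, 2.0)
```

Each entry of g_b equals −Σ(prediction − target), because every prediction depends on sum(b) with unit weight.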
- grater_jax.disk_model.jax_gradient_wrappers.jax_model_winnie_grad(DiskModel, DistrModel, FuncModel, winnie_psf, StellarPSFModel, disk_params, spf_params, stellar_psf_params, target_image, err_map, throughput, flux_scaling=1000000.0, distance=0.0, pxInArcsec=0.0, nx=140, ny=140, halfNbSlices=25, ll_method='sum')
Gradient of jax_model_winnie_ll with respect to positional arguments (5, 6, 7, 11), i.e. disk_params, spf_params, stellar_psf_params, and flux_scaling. Takes the same arguments as jax_model_winnie_ll but returns a tuple of gradients, one per differentiated argument, each with the same shape as that argument.
- grater_jax.disk_model.jax_gradient_wrappers.jax_model_spline_grad(DiskModel, DistrModel, FuncModel, PSFModel, StellarPSFModel, disk_params, spf_params, psf_params, stellar_psf_params, target_image, err_map, throughput, flux_scaling=1000000.0, distance=0.0, pxInArcsec=0.0, nx=140, ny=140, halfNbSlices=25, knots=array([1., 0.5, -0.5, -1.]), ll_method='sum')
Gradient of jax_model_spline_ll with respect to positional arguments (5, 6, 7, 8, 12), i.e. disk_params, spf_params, psf_params, stellar_psf_params, and flux_scaling. Takes the same arguments as jax_model_spline_ll but returns a tuple of gradients, one per differentiated argument, each with the same shape as that argument.
- grater_jax.disk_model.jax_gradient_wrappers.jax_model_spline_winnie_grad(DiskModel, DistrModel, FuncModel, winnie_psf, StellarPSFModel, disk_params, spf_params, stellar_psf_params, target_image, err_map, throughput, flux_scaling=1000000.0, distance=0.0, pxInArcsec=0.0, nx=140, ny=140, halfNbSlices=25, knots=array([1., 0.5, -0.5, -1.]), ll_method='sum')
Gradient of jax_model_spline_winnie_ll with respect to positional arguments (5, 6, 7, 11), i.e. disk_params, spf_params, stellar_psf_params, and flux_scaling. Takes the same arguments as jax_model_spline_winnie_ll but returns a tuple of gradients, one per differentiated argument, each with the same shape as that argument.