grater_jax.disk_model.jax_gradient_wrappers

jax_gradient_wrappers.py

Differentiable log-likelihood functions for disk modeling.

This module defines JAX-compiled objective functions and their gradients for disk image generation and likelihood evaluation. It provides standard, Winnie PSF, and spline-based variants, along with convenience wrappers for computing per-pixel residuals and Gaussian log-likelihoods. These functions are differentiable counterparts of the methods in jax_model_wrappers.py. Some overhead is needed to match the parameters of the two; this is handled by objective_grad in objective_functions.py.
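The parameter matching mentioned above can be pictured as flattening a dictionary of named parameters into an array that jax.grad can differentiate, then mapping the gradient back to names. The sketch below does this with jax.flatten_util.ravel_pytree. It is an illustration under that assumption, not the actual logic of objective_grad; the parameter names and the toy objective are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical named parameters in dictionary style (names are illustrative only).
disk_params = {"alpha_in": 5.0, "alpha_out": -5.0, "sma": 40.0}

# Flatten the dictionary into a 1-D array and keep a function that rebuilds it.
flat, unravel = ravel_pytree(disk_params)

def toy_objective(flat_params):
    # Rebuild the named parameters inside the differentiated function.
    p = unravel(flat_params)
    # Stand-in for a log-likelihood: any smooth scalar function works here.
    return -0.5 * (p["sma"] - 42.0) ** 2 - 0.1 * p["alpha_in"] ** 2

grad_flat = jax.grad(toy_objective)(flat)  # gradient as a flat array
grad_dict = unravel(grad_flat)             # gradient mapped back to names
```

JAX flattens dictionaries in sorted key order, so the flat vector here is ordered as alpha_in, alpha_out, sma.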

Functions

jax_model_grad(DiskModel, DistrModel, ...[, ...])

Gradient of jax_model_ll with respect to positional argument(s) (5, 6, 7, 8, 12).

jax_model_ll(DiskModel, DistrModel, ...[, ...])

Compute the log-likelihood of the generated disk model image given the disk, scattering phase function, point spread function (PSF), stellar PSF, and miscellaneous parameters, along with the target image and error map.

jax_model_spline_grad(DiskModel, DistrModel, ...)

Gradient of jax_model_spline_ll with respect to positional argument(s) (5, 6, 7, 8, 12).

jax_model_spline_ll(DiskModel, DistrModel, ...)

Compute the log-likelihood of the generated disk model image given the disk, scattering phase function, point spread function (PSF), stellar PSF, and miscellaneous parameters, along with the target image and error map.

jax_model_spline_winnie_grad(DiskModel, ...)

Gradient of jax_model_spline_winnie_ll with respect to positional argument(s) (5, 6, 7, 11).

jax_model_spline_winnie_ll(DiskModel, ...[, ...])

Compute the log-likelihood of the generated disk model image given the disk, scattering phase function, point spread function (PSF), stellar PSF, and miscellaneous parameters, along with the target image and error map.

jax_model_winnie_grad(DiskModel, DistrModel, ...)

Gradient of jax_model_winnie_ll with respect to positional argument(s) (5, 6, 7, 11).

jax_model_winnie_ll(DiskModel, DistrModel, ...)

Compute the log-likelihood of the generated disk model image given the disk, scattering phase function, point spread function (PSF), stellar PSF, and miscellaneous parameters, along with the target image and error map.

log_likelihood(image, target_image, err_map)

Compute the Gaussian log-likelihood between a model image and observed data.

residuals(image, target_image, err_map)

Compute per-pixel squared normalized residuals plus log-variance term.

grater_jax.disk_model.jax_gradient_wrappers.log_likelihood(image, target_image, err_map, ll_method='sum')

Compute the Gaussian log-likelihood between a model image and observed data.

Parameters:
  • image (jnp.ndarray) – Model image of shape (H, W).

  • target_image (jnp.ndarray) – Observed image of the same shape as image.

  • err_map (jnp.ndarray) – Per-pixel standard deviation (noise map); same shape as image.

  • ll_method ({'sum', 'mean'}, optional) – How to reduce the per-pixel Gaussian log-density terms across the image. 'sum' (default) returns the total log-likelihood, which is the statistically correct choice for MCMC because log-posterior ratios depend on absolute values. 'mean' divides the total by jnp.size(image) and returns the per-pixel average; this is sometimes preferred for scipy gradient-based optimisation because the smaller magnitude keeps the L-BFGS-B ftol convergence criterion from tripping prematurely on large images.

Returns:

The reduced log-likelihood assuming independent Gaussian errors.

Return type:

float

Raises:

ValueError – If ll_method is not 'sum' or 'mean'.
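A minimal sketch of the documented behaviour, assuming each per-pixel term is the full Gaussian log-density including the log(2πσ²) normalisation. The name log_likelihood_sketch is hypothetical and this is not the module's source.

```python
import jax.numpy as jnp

def log_likelihood_sketch(image, target_image, err_map, ll_method="sum"):
    """Gaussian log-likelihood assuming independent per-pixel errors (sketch)."""
    # Per-pixel Gaussian log-density, including the log(2*pi*sigma^2) normalisation
    per_pixel = -0.5 * (((target_image - image) / err_map) ** 2
                        + jnp.log(2.0 * jnp.pi * err_map ** 2))
    total = jnp.sum(per_pixel)
    if ll_method == "sum":
        return total                      # absolute value, suitable for MCMC
    if ll_method == "mean":
        return total / jnp.size(image)    # per-pixel average
    raise ValueError(f"ll_method must be 'sum' or 'mean', got {ll_method!r}")
```

Because the reduction is selected with a Python-level string comparison, ll_method has to be a static (non-traced) argument if a function like this is wrapped in jax.jit.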

grater_jax.disk_model.jax_gradient_wrappers.residuals(image, target_image, err_map)

Compute per-pixel squared normalized residuals plus log-variance term.

Parameters:
  • image (jnp.ndarray) – Model image of shape (H, W).

  • target_image (jnp.ndarray) – Observed image of the same shape.

  • err_map (jnp.ndarray) – Per-pixel standard deviation (noise map); same shape as image.

Returns:

Residual map of the same shape as image.

Return type:

jnp.ndarray
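One reading of "squared normalized residuals plus log-variance term", written as a hypothetical residuals_sketch. Whether the real function includes the constant log(2π) is an assumption; including or dropping it shifts the log-likelihood by a constant and does not change any gradient.

```python
import jax.numpy as jnp

def residuals_sketch(image, target_image, err_map):
    # Squared normalised residual for each pixel
    chi2 = ((target_image - image) / err_map) ** 2
    # Log-variance term; the 2*pi constant is an assumption
    return chi2 + jnp.log(2.0 * jnp.pi * err_map ** 2)
```

Under this assumption, -0.5 * jnp.sum(residuals_sketch(image, target_image, err_map)) equals the 'sum' Gaussian log-likelihood.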

grater_jax.disk_model.jax_gradient_wrappers.jax_model_ll(DiskModel, DistrModel, FuncModel, PSFModel, StellarPSFModel, disk_params, spf_params, psf_params, stellar_psf_params, target_image, err_map, throughput, flux_scaling=1000000.0, distance=0.0, pxInArcsec=0.0, nx=140, ny=140, halfNbSlices=25, ll_method='sum')

Compute the log-likelihood of the generated disk model image given the disk, scattering phase function, point spread function (PSF), stellar PSF, and miscellaneous parameters, along with the target image and error map. This function applies only when PSFModel is not Winnie_PSF and FuncModel is not InterpolatedUnivariateSpline_SPF. It is differentiated with jax.grad to produce jax_model_grad.

Parameters:
  • DiskModel (class) – The disk model type. ScatteredLightDisk is the only supported disk model.

  • DistrModel (class) – The dust distribution model type. DustEllipticalDistribution2PowerLaws is the only supported dust distribution model.

  • FuncModel (class) – The scattering phase function model type (see disk_model/SLD_utils.py).

  • PSFModel (class) – The point spread function model type (see disk_model/SLD_utils.py).

  • StellarPSFModel (class) – The on-axis stellar PSF model type; None indicates no stellar PSF model.

  • disk_params (jnp.array) – The parameter dictionary for the disk model, made of (parameter name, parameter value) pairs.

  • spf_params (jnp.array) – The parameter dictionary for the scattering phase function, made of (parameter name, parameter value) pairs.

  • psf_params (jnp.array) – The parameter dictionary for the point spread function, made of (parameter name, parameter value) pairs.

  • stellar_psf_params (jnp.array) – The parameter dictionary for the on-axis stellar PSF model, made of (parameter name, parameter value) pairs.

  • target_image (np.ndarray) – The target image for which the log-likelihood is computed.

  • err_map (np.ndarray) – The error map for the target image.

  • throughput (np.ndarray) – The throughput image applied to the generated disk image.

  • flux_scaling (float) – Scaling factor for the disk model (default 1000000.0).

  • distance (float) – Distance to the star in pc (default 0.0).

  • pxInArcsec (float) – Pixel scale in arcsec/px (default 0.0).

  • nx (int) – Number of pixels along the x axis of the image (default 140).

  • ny (int) – Number of pixels along the y axis of the image (default 140).

  • halfNbSlices (int) – Half the number of distances sampled along the line of sight (default 25).

Returns:

Log likelihood for the generated disk image, target image, and error map.

Return type:

float
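To show how the pieces listed above fit together, here is a self-contained toy analogue: an image is generated from a parameter array, scaled, multiplied by the throughput map, and compared with the target using a Gaussian log-likelihood. The ring model (toy_disk_image), the function names, and the choice to multiply by flux_scaling are assumptions for illustration only; the real function builds its image from DiskModel, DistrModel, FuncModel, and the PSF models.

```python
import jax
import jax.numpy as jnp

def toy_disk_image(disk_params, nx, ny):
    # Hypothetical stand-in for the disk/dust/SPF models:
    # a centred Gaussian ring with radius r0 and width w, in pixels.
    r0, w = disk_params[0], disk_params[1]
    y, x = jnp.mgrid[0:ny, 0:nx]
    r = jnp.hypot(x - (nx - 1) / 2.0, y - (ny - 1) / 2.0)
    return jnp.exp(-0.5 * ((r - r0) / w) ** 2)

def toy_model_ll(disk_params, target_image, err_map, throughput,
                 flux_scaling=1.0, nx=32, ny=32):
    # Assumption: flux_scaling multiplies the raw model image.
    image = flux_scaling * toy_disk_image(disk_params, nx, ny)
    image = image * throughput  # apply the throughput map
    per_pixel = -0.5 * (((target_image - image) / err_map) ** 2
                        + jnp.log(2.0 * jnp.pi * err_map ** 2))
    return jnp.sum(per_pixel)
```

At the parameters that generated the target, the residuals vanish, so jax.grad(toy_model_ll) returns a gradient of zero there.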

grater_jax.disk_model.jax_gradient_wrappers.jax_model_winnie_ll(DiskModel, DistrModel, FuncModel, winnie_psf, StellarPSFModel, disk_params, spf_params, stellar_psf_params, target_image, err_map, throughput, flux_scaling=1000000.0, distance=0.0, pxInArcsec=0.0, nx=140, ny=140, halfNbSlices=25, ll_method='sum')

Compute the log-likelihood of the generated disk model image given the disk, scattering phase function, point spread function (PSF), stellar PSF, and miscellaneous parameters, along with the target image and error map. This function applies only when PSFModel is Winnie_PSF and FuncModel is not InterpolatedUnivariateSpline_SPF. It is differentiated with jax.grad to produce jax_model_winnie_grad.

Parameters:
  • DiskModel (class) – The disk model type. ScatteredLightDisk is the only supported disk model.

  • DistrModel (class) – The dust distribution model type. DustEllipticalDistribution2PowerLaws is the only supported dust distribution model.

  • FuncModel (class) – The scattering phase function model type (see disk_model/SLD_utils.py).

  • winnie_psf (class) – The JWST off-axis PSF, modeled after the Winnie framework (see disk_model/winnie_class.py).

  • StellarPSFModel (class) – The on-axis stellar PSF model type; None indicates no stellar PSF model.

  • disk_params (jnp.array) – The parameter dictionary for the disk model, made of (parameter name, parameter value) pairs.

  • spf_params (jnp.array) – The parameter dictionary for the scattering phase function, made of (parameter name, parameter value) pairs.

  • stellar_psf_params (jnp.array) – The parameter dictionary for the on-axis stellar PSF model, made of (parameter name, parameter value) pairs.

  • target_image (np.ndarray) – The target image for which the log-likelihood is computed.

  • err_map (np.ndarray) – The error map for the target image.

  • throughput (np.ndarray) – The throughput image applied to the generated disk image.

  • flux_scaling (float) – Scaling factor for the disk model (default 1000000.0).

  • distance (float) – Distance to the star in pc (default 0.0).

  • pxInArcsec (float) – Pixel scale in arcsec/px (default 0.0).

  • nx (int) – Number of pixels along the x axis of the image (default 140).

  • ny (int) – Number of pixels along the y axis of the image (default 140).

  • halfNbSlices (int) – Half the number of distances sampled along the line of sight (default 25).

Returns:

Log likelihood for the generated disk image, target image, and error map.

Return type:

float
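In the Winnie variants the PSF is passed in as an object rather than as PSFModel plus psf_params, and the corresponding gradient functions do not differentiate with respect to it. The toy sketch below mimics that split: a fixed convolution kernel is applied with jax.scipy.signal.fftconvolve and only the disk parameters are differentiated. A single fixed Gaussian kernel is a swapped-in simplification, not Winnie's JWST PSF treatment, and all names here are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import fftconvolve

def gaussian_kernel(size=9, sigma=1.5):
    # Fixed, normalised PSF kernel that plays the role of the PSF object.
    ax = jnp.arange(size) - (size - 1) / 2.0
    k = jnp.exp(-0.5 * (ax[:, None] ** 2 + ax[None, :] ** 2) / sigma ** 2)
    return k / jnp.sum(k)

def toy_convolved_image(disk_params, psf_kernel, n):
    # Hypothetical unconvolved ring (amplitude, radius), then a fixed-kernel blur.
    y, x = jnp.mgrid[0:n, 0:n]
    c = (n - 1) / 2.0
    r = jnp.hypot(x - c, y - c)
    ring = disk_params[0] * jnp.exp(-0.5 * ((r - disk_params[1]) / 2.0) ** 2)
    return fftconvolve(ring, psf_kernel, mode="same")

def toy_winnie_ll(disk_params, psf_kernel, target_image, err_map):
    image = toy_convolved_image(disk_params, psf_kernel, target_image.shape[0])
    return -0.5 * jnp.sum(((target_image - image) / err_map) ** 2)

# Differentiate only with respect to disk_params; the PSF kernel stays fixed.
toy_winnie_grad = jax.grad(toy_winnie_ll, argnums=0)
```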

grater_jax.disk_model.jax_gradient_wrappers.jax_model_spline_ll(DiskModel, DistrModel, FuncModel, PSFModel, StellarPSFModel, disk_params, spf_params, psf_params, stellar_psf_params, target_image, err_map, throughput, flux_scaling=1000000.0, distance=0.0, pxInArcsec=0.0, nx=140, ny=140, halfNbSlices=25, knots=array([1., 0.5, -0.5, -1.]), ll_method='sum')

Compute the log-likelihood of the generated disk model image given the disk, scattering phase function, point spread function (PSF), stellar PSF, and miscellaneous parameters, along with the target image and error map. This function applies only when PSFModel is not Winnie_PSF and FuncModel is InterpolatedUnivariateSpline_SPF. It is differentiated with jax.grad to produce jax_model_spline_grad.

Parameters:
  • DiskModel (class) – The disk model type. ScatteredLightDisk is the only supported disk model.

  • DistrModel (class) – The dust distribution model type. DustEllipticalDistribution2PowerLaws is the only supported dust distribution model.

  • FuncModel (class) – The scattering phase function model type (see disk_model/SLD_utils.py).

  • PSFModel (class) – The point spread function model type (see disk_model/SLD_utils.py).

  • StellarPSFModel (class) – The on-axis stellar PSF model type; None indicates no stellar PSF model.

  • disk_params (jnp.array) – The parameter dictionary for the disk model, made of (parameter name, parameter value) pairs.

  • spf_params (jnp.array) – The parameter dictionary for the scattering phase function, made of (parameter name, parameter value) pairs.

  • psf_params (jnp.array) – The parameter dictionary for the point spread function, made of (parameter name, parameter value) pairs.

  • stellar_psf_params (jnp.array) – The parameter dictionary for the on-axis stellar PSF model, made of (parameter name, parameter value) pairs.

  • target_image (np.ndarray) – The target image for which the log-likelihood is computed.

  • err_map (np.ndarray) – The error map for the target image.

  • throughput (np.ndarray) – The throughput image applied to the generated disk image.

  • flux_scaling (float) – Scaling factor for the disk model (default 1000000.0).

  • distance (float) – Distance to the star in pc (default 0.0).

  • pxInArcsec (float) – Pixel scale in arcsec/px (default 0.0).

  • nx (int) – Number of pixels along the x axis of the image (default 140).

  • ny (int) – Number of pixels along the y axis of the image (default 140).

  • halfNbSlices (int) – Half the number of distances sampled along the line of sight (default 25).

Returns:

Log likelihood for the generated disk image, target image, and error map.

Return type:

float
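The spline variants take a knots array whose default, [1, 0.5, -0.5, -1], suggests nodes in cos(scattering angle); that reading is an assumption here. The sketch below shows why such a parameterisation is convenient for gradients: the phase function is interpolated between values at the knots, so its output is differentiable with respect to those values. Linear interpolation with jnp.interp is swapped in for InterpolatedUnivariateSpline, and toy_spf and toy_spline_objective are hypothetical names.

```python
import jax
import jax.numpy as jnp

# Same default knots as the signature; assumed to be cos(scattering angle) nodes.
knots = jnp.array([1.0, 0.5, -0.5, -1.0])

def toy_spf(knot_values, cos_theta):
    # jnp.interp needs increasing x-coordinates, so flip the descending knots.
    return jnp.interp(cos_theta, knots[::-1], knot_values[::-1])

def toy_spline_objective(knot_values, cos_theta, observed):
    # Least-squares stand-in for the log-likelihood.
    return -0.5 * jnp.sum((toy_spf(knot_values, cos_theta) - observed) ** 2)

# The gradient has one entry per knot value.
spline_grad = jax.grad(toy_spline_objective)
```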

grater_jax.disk_model.jax_gradient_wrappers.jax_model_spline_winnie_ll(DiskModel, DistrModel, FuncModel, winnie_psf, StellarPSFModel, disk_params, spf_params, stellar_psf_params, target_image, err_map, throughput, flux_scaling=1000000.0, distance=0.0, pxInArcsec=0.0, nx=140, ny=140, halfNbSlices=25, knots=array([1., 0.5, -0.5, -1.]), ll_method='sum')

Compute the log-likelihood of the generated disk model image given the disk, scattering phase function, point spread function (PSF), stellar PSF, and miscellaneous parameters, along with the target image and error map. This function applies only when PSFModel is Winnie_PSF and FuncModel is InterpolatedUnivariateSpline_SPF. It is differentiated with jax.grad to produce jax_model_spline_winnie_grad.

Parameters:
  • DiskModel (class) – The disk model type. ScatteredLightDisk is the only supported disk model.

  • DistrModel (class) – The dust distribution model type. DustEllipticalDistribution2PowerLaws is the only supported dust distribution model.

  • FuncModel (class) – The scattering phase function model type (see disk_model/SLD_utils.py).

  • winnie_psf (class) – The JWST off-axis PSF, modeled after the Winnie framework (see disk_model/winnie_class.py).

  • StellarPSFModel (class) – The on-axis stellar PSF model type; None indicates no stellar PSF model.

  • disk_params (jnp.array) – The parameter dictionary for the disk model, made of (parameter name, parameter value) pairs.

  • spf_params (jnp.array) – The parameter dictionary for the scattering phase function, made of (parameter name, parameter value) pairs.

  • stellar_psf_params (jnp.array) – The parameter dictionary for the on-axis stellar PSF model, made of (parameter name, parameter value) pairs.

  • target_image (np.ndarray) – The target image for which the log-likelihood is computed.

  • err_map (np.ndarray) – The error map for the target image.

  • throughput (np.ndarray) – The throughput image applied to the generated disk image.

  • flux_scaling (float) – Scaling factor for the disk model (default 1000000.0).

  • distance (float) – Distance to the star in pc (default 0.0).

  • pxInArcsec (float) – Pixel scale in arcsec/px (default 0.0).

  • nx (int) – Number of pixels along the x axis of the image (default 140).

  • ny (int) – Number of pixels along the y axis of the image (default 140).

  • halfNbSlices (int) – Half the number of distances sampled along the line of sight (default 25).

Returns:

Log likelihood for the generated disk image, target image, and error map.

Return type:

float

grater_jax.disk_model.jax_gradient_wrappers.jax_model_grad(DiskModel, DistrModel, FuncModel, PSFModel, StellarPSFModel, disk_params, spf_params, psf_params, stellar_psf_params, target_image, err_map, throughput, flux_scaling=1000000.0, distance=0.0, pxInArcsec=0.0, nx=140, ny=140, halfNbSlices=25, ll_method='sum')

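Gradient of jax_model_ll with respect to the zero-based positional arguments (5, 6, 7, 8, 12), i.e. disk_params, spf_params, psf_params, stellar_psf_params, and flux_scaling. Takes the same arguments as jax_model_ll but returns a tuple of gradients, one per differentiated argument, each with the same shape as the corresponding argument.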
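The argnums convention can be checked on a toy function with the same argument layout as jax_model_ll. In the sketch below, toy_ll is hypothetical and only mirrors the positions: 0 to 4 are model classes (passed as None and ignored), 5 to 8 are parameter arrays, and 12 is flux_scaling. JAX raises a TypeError if a differentiated argument is not passed positionally, so flux_scaling is supplied as the 13th positional argument.

```python
import jax
import jax.numpy as jnp

def toy_ll(DiskModel, DistrModel, FuncModel, PSFModel, StellarPSFModel,
           disk_params, spf_params, psf_params, stellar_psf_params,
           target_image, err_map, throughput, flux_scaling=1.0):
    # Positions 0-4 stand in for model classes and are ignored here.
    signal = (jnp.sum(disk_params) + jnp.sum(spf_params)
              + jnp.sum(psf_params) + jnp.sum(stellar_psf_params))
    model = flux_scaling * throughput * signal
    return -0.5 * jnp.sum(((target_image - model) / err_map) ** 2)

# Same argnums as documented for jax_model_grad.
toy_grad = jax.grad(toy_ll, argnums=(5, 6, 7, 8, 12))

shape = (4, 4)
grads = toy_grad(
    None, None, None, None, None,                 # model classes, not differentiated
    jnp.ones(2), jnp.ones(3), jnp.ones(1), jnp.ones(1),
    jnp.zeros(shape), jnp.ones(shape), jnp.ones(shape),
    1.0,                                          # flux_scaling, passed positionally
)
```

grads is a 5-tuple whose entries match the shapes of disk_params, spf_params, psf_params, stellar_psf_params, and flux_scaling (a scalar).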

grater_jax.disk_model.jax_gradient_wrappers.jax_model_winnie_grad(DiskModel, DistrModel, FuncModel, winnie_psf, StellarPSFModel, disk_params, spf_params, stellar_psf_params, target_image, err_map, throughput, flux_scaling=1000000.0, distance=0.0, pxInArcsec=0.0, nx=140, ny=140, halfNbSlices=25, ll_method='sum')

Gradient of jax_model_winnie_ll with respect to the zero-based positional arguments (5, 6, 7, 11), i.e. disk_params, spf_params, stellar_psf_params, and flux_scaling. Takes the same arguments as jax_model_winnie_ll but returns a tuple of gradients, one per differentiated argument, each with the same shape as the corresponding argument.

grater_jax.disk_model.jax_gradient_wrappers.jax_model_spline_grad(DiskModel, DistrModel, FuncModel, PSFModel, StellarPSFModel, disk_params, spf_params, psf_params, stellar_psf_params, target_image, err_map, throughput, flux_scaling=1000000.0, distance=0.0, pxInArcsec=0.0, nx=140, ny=140, halfNbSlices=25, knots=array([1., 0.5, -0.5, -1.]), ll_method='sum')

Gradient of jax_model_spline_ll with respect to the zero-based positional arguments (5, 6, 7, 8, 12), i.e. disk_params, spf_params, psf_params, stellar_psf_params, and flux_scaling. Takes the same arguments as jax_model_spline_ll but returns a tuple of gradients, one per differentiated argument, each with the same shape as the corresponding argument.

grater_jax.disk_model.jax_gradient_wrappers.jax_model_spline_winnie_grad(DiskModel, DistrModel, FuncModel, winnie_psf, StellarPSFModel, disk_params, spf_params, stellar_psf_params, target_image, err_map, throughput, flux_scaling=1000000.0, distance=0.0, pxInArcsec=0.0, nx=140, ny=140, halfNbSlices=25, knots=array([1., 0.5, -0.5, -1.]), ll_method='sum')

Gradient of jax_model_spline_winnie_ll with respect to the zero-based positional arguments (5, 6, 7, 11), i.e. disk_params, spf_params, stellar_psf_params, and flux_scaling. Takes the same arguments as jax_model_spline_winnie_ll but returns a tuple of gradients, one per differentiated argument, each with the same shape as the corresponding argument.
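The ll_method notes for log_likelihood mention scipy's L-BFGS-B. A log-likelihood and its JAX gradient can be handed to scipy.optimize.minimize together with jac=True. The sketch below does this for a toy one-dimensional profile fit using jax.value_and_grad and a mean reduction in the spirit of ll_method='mean'. The profile model and names are hypothetical, the sign is flipped because scipy minimizes, and the gradient is cast to float64 because scipy expects double precision.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize

xs = jnp.linspace(-5.0, 5.0, 200)

def profile(p):
    # Hypothetical model: Gaussian profile with amplitude p[0] and width p[1].
    return p[0] * jnp.exp(-0.5 * (xs / p[1]) ** 2)

target = profile(jnp.array([2.0, 1.0]))  # noise-free synthetic data
err = 0.1 * jnp.ones_like(xs)

def neg_mean_ll(p):
    # Mean reduction of the per-pixel terms, negated for minimisation.
    per_pixel = -0.5 * ((target - profile(p)) / err) ** 2
    return -jnp.mean(per_pixel)

value_and_grad = jax.jit(jax.value_and_grad(neg_mean_ll))

def scipy_fun(p):
    # scipy passes float64 NumPy arrays; return (value, gradient) for jac=True.
    v, g = value_and_grad(jnp.asarray(p, dtype=jnp.float32))
    return float(v), np.asarray(g, dtype=np.float64)

result = minimize(scipy_fun, x0=np.array([1.0, 2.0]), jac=True,
                  method="L-BFGS-B", bounds=[(0.1, 10.0), (0.1, 10.0)])
```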