grater_jax.disk_model.objective_functions

objective_functions.py

Objective functions for disk modeling and optimization.

This module provides functions that connect high-level optimization routines to the underlying JAX-accelerated disk model wrappers. It defines parameter templates, packing utilities, and objective functions for model evaluation, log-likelihood computation, and gradient calculation.

Functions

jax_model(DiskModel, DistrModel, FuncModel, ...)

Get the generated disk model image given disk, scattering phase function, point spread function, stellar PSF, and miscellaneous parameters along with the target image and error map.

jax_model_grad(DiskModel, DistrModel, ...[, ...])

Gradient of jax_model_ll with respect to positional argument(s) (5, 6, 7, 8, 12).

jax_model_spline(DiskModel, DistrModel, ...)

Get the generated disk model image given disk, scattering phase function, point spread function, stellar PSF, and miscellaneous parameters along with the target image and error map.

jax_model_spline_grad(DiskModel, DistrModel, ...)

Gradient of jax_model_spline_ll with respect to positional argument(s) (5, 6, 7, 8, 12).

jax_model_spline_winnie(DiskModel, ...[, ...])

Get the generated disk model image given disk, scattering phase function, point spread function, stellar PSF, and miscellaneous parameters along with the target image and error map.

jax_model_spline_winnie_grad(DiskModel, ...)

Gradient of jax_model_spline_winnie_ll with respect to positional argument(s) (5, 6, 7, 11).

jax_model_winnie(DiskModel, DistrModel, ...)

Get the generated disk model image given disk, scattering phase function, point spread function, stellar PSF, and miscellaneous parameters along with the target image and error map.

jax_model_winnie_grad(DiskModel, DistrModel, ...)

Gradient of jax_model_winnie_ll with respect to positional argument(s) (5, 6, 7, 11).

log_likelihood(image, target_image, err_map)

Compute the Gaussian log-likelihood between a model image and observed data.

objective_fit(params_fit, fit_keys, ...[, ...])

Same as the objective_ll function but accepts replacement values for the given parameters in fit_keys.

objective_fit_grad(params_fit, fit_keys, ...)

Compute the gradient of the log likelihood with respect to each parameter for the generated disk model image, given disk, scattering phase function, point spread function, stellar PSF, and miscellaneous parameters along with the target image and error map.

objective_grad(disk_params, spf_params, ...)

Compute the gradient of the log likelihood with respect to each parameter for the generated disk model image, given disk, scattering phase function, point spread function, stellar PSF, and miscellaneous parameters along with the target image and error map.

objective_ll(disk_params, spf_params, ...[, ...])

Compute the log likelihood of the generated disk model image, given disk, scattering phase function, point spread function, stellar PSF, and miscellaneous parameters along with the target image and error map.

objective_model(disk_params, spf_params, ...)

Generate a disk model image given disk, scattering phase function, point spread function, stellar PSF, and miscellaneous parameters.

pack_pars(p_dict, orig_dict)

Pack parameter values from a dictionary into a JAX array.
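The "Gradient of ... with respect to positional argument(s) ..." summaries above appear to be the docstrings that jax.grad generates when it is called with a tuple of argnums. A minimal toy sketch of that mechanism follows; toy_ll is a hypothetical stand-in, not the grater_jax likelihood, and the argnums of the *_grad functions are taken from their summaries rather than from the source code.

```python
import jax

# Hypothetical toy log likelihood with five positional arguments.
# It only illustrates jax.grad with a tuple argnums; it is NOT the grater_jax model.
def toy_ll(scale, offset, a, b, c):
    # scale and offset are treated as fixed (non-differentiated) inputs
    return -0.5 * scale * (a ** 2 + 2.0 * b ** 2 + 3.0 * c ** 2) + offset

# Differentiate with respect to positional arguments 2, 3 and 4 only.
toy_ll_grad = jax.grad(toy_ll, argnums=(2, 3, 4))

grads = toy_ll_grad(1.0, 0.0, 1.0, 1.0, 1.0)
# grads is a tuple with one entry per index in argnums: (-1.0, -2.0, -3.0)
```

With a tuple argnums, the gradient function returns a tuple with one entry per listed argument, so jax_model_grad presumably returns five gradient objects, one for each of positions 5, 6, 7, 8 and 12.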

Classes

InterpolatedUnivariateSpline_SPF()

Implementation of a spline scattering phase function.

Parameter_Index()

Default parameter sets for disk modeling components.

Winnie_PSF()

Creates a JWST PSF model, using the package Winnie.

class grater_jax.disk_model.objective_functions.Parameter_Index

Bases: object

Default parameter sets for disk modeling components.

This class defines canonical parameter dictionaries used for optimization and simulation of scattered light disks. These parameters act as templates for packing and fitting routines and represent all modifiable physical and observational quantities.

disk_params
Dictionary of physical parameters describing the disk geometry and density:
  • accuracy (float)

    Numerical accuracy for integrators.

  • alpha_in, alpha_out (float)

    Inner and outer radial density power-law slopes.

  • sma (float)

    Semi-major axis (in pixels or AU, depending on model).

  • e (float)

    Eccentricity of the disk.

  • ksi0 (float)

    Vertical scale height of the disk at the reference radius.

  • gamma, beta (float)

    Vertical structure parameters.

  • rmin (float)

    Inner radius cutoff.

  • dens_at_r0 (float)

    Density normalization at the reference radius.

  • inclination (float)

    Inclination angle (degrees).

  • position_angle (float)

    Disk orientation on sky (degrees).

  • x_center, y_center (float)

    Image center coordinates.

  • halfNbSlices (int)

    Number of angular slices for 3D integration (half of total).

  • omega (float)

    Argument of pericenter (in radians or degrees).

Type:

dict

misc_params
Dictionary of observational and image grid parameters:
  • distance (float)

    Distance to the system (in parsecs).

  • pxInArcsec (float)

    Pixel scale (arcseconds per pixel).

  • nx, ny (int)

    Image dimensions (width and height in pixels).

  • halfNbSlices (int)

    Number of angular slices for rendering (should match disk_params).

  • flux_scaling (float)

    Normalization factor for model brightness.

Type:

dict

Notes

Parameter dictionaries for scattering phase functions (SPFs), PSFs, and stellar PSFs are defined in SLD_utils.py. Spline SPFs and Winnie PSFs use instantiated model classes (e.g., InterpolatedUnivariateSpline_SPF, Winnie_PSF).

disk_params = {'accuracy': 0.005, 'alpha_in': 5, 'alpha_out': -5, 'beta': 1.0, 'dens_at_r0': 1.0, 'e': 0.0, 'gamma': 2.0, 'inclination': 0, 'ksi0': 3.0, 'omega': 0.0, 'position_angle': 0, 'rmin': 0.0, 'sma': 50, 'x_center': 70.0, 'y_center': 70.0}
misc_params = {'distance': 50.0, 'flux_scaling': 1000000.0, 'halfNbSlices': 25, 'nx': 140, 'ny': 140, 'pxInArcsec': 0.01414}
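As a worked example of how the default templates fit together, the sketch below converts the pixel scale into AU per pixel using the small-angle relation (1 arcsec at 1 pc subtends 1 AU) and checks the defaults. It assumes sma is expressed in AU, which the description above leaves model-dependent; the dictionaries are copied from the defaults above rather than imported from grater_jax.

```python
# Default templates, copied from Parameter_Index above.
disk_params = {'accuracy': 0.005, 'alpha_in': 5, 'alpha_out': -5, 'beta': 1.0,
               'dens_at_r0': 1.0, 'e': 0.0, 'gamma': 2.0, 'inclination': 0,
               'ksi0': 3.0, 'omega': 0.0, 'position_angle': 0, 'rmin': 0.0,
               'sma': 50, 'x_center': 70.0, 'y_center': 70.0}
misc_params = {'distance': 50.0, 'flux_scaling': 1000000.0, 'halfNbSlices': 25,
               'nx': 140, 'ny': 140, 'pxInArcsec': 0.01414}

def au_per_pixel(misc):
    # 1 arcsec at 1 pc subtends 1 AU, so AU/pixel = (arcsec/pixel) * (distance in pc).
    return misc['pxInArcsec'] * misc['distance']

scale = au_per_pixel(misc_params)     # about 0.707 AU per pixel
sma_px = disk_params['sma'] / scale    # about 70.7 pixels, assuming sma is in AU

# The default center sits in the middle of the nx-by-ny grid.
assert disk_params['x_center'] == misc_params['nx'] / 2
assert disk_params['y_center'] == misc_params['ny'] / 2

# Build a modified copy instead of editing the shared class-level template.
my_disk_params = dict(disk_params, inclination=60.0, position_angle=30.0)
```

Under the AU assumption, the default semi-major axis (about 70.7 pixels) reaches slightly past the 70-pixel half-width of the image, so a face-on ring with these defaults would touch the image edge.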
grater_jax.disk_model.objective_functions.pack_pars(p_dict, orig_dict)

Pack parameter values from a dictionary into a JAX array.

The output array follows the key order defined in orig_dict, which is typically a template dictionary that defines the parameter structure. This ordered array is the form in which parameters are passed to the JAX-wrapped model classes.

Parameters:
  • p_dict (dict) – Dictionary of parameter values to be packed. Keys must match those in orig_dict.

  • orig_dict (dict) – Reference dictionary that defines the desired key ordering.

Returns:

JAX array of parameter values in the order defined by orig_dict.

Return type:

jnp.ndarray
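Going by the description above, pack_pars orders values by the keys of orig_dict. A minimal sketch of that behavior, assuming a plain list comprehension over the template keys; pack_pars_sketch and unpack_pars_sketch are hypothetical names, not part of grater_jax, and the real implementation may differ (for example in how missing keys are handled).

```python
import jax.numpy as jnp

def pack_pars_sketch(p_dict, orig_dict):
    # Walk the template's keys so the array order is fixed by orig_dict,
    # independent of the insertion order of p_dict. Missing keys raise KeyError.
    return jnp.array([p_dict[key] for key in orig_dict])

def unpack_pars_sketch(packed, orig_dict):
    # Inverse mapping: pair each array entry with the template key at the same position.
    return {key: packed[i] for i, key in enumerate(orig_dict)}

template = {'alpha_in': 5.0, 'alpha_out': -5.0, 'sma': 50.0}
values = {'sma': 42.0, 'alpha_out': -3.0, 'alpha_in': 8.0}   # different key order

packed = pack_pars_sketch(values, template)   # [8.0, -3.0, 42.0]
roundtrip = unpack_pars_sketch(packed, template)
```

Fixing the order through a template is what allows a flat array, the form JAX transformations and optimizers work with most easily, to be mapped back to named parameters.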

grater_jax.disk_model.objective_functions.objective_model(disk_params, spf_params, psf_params, misc_params, DiskModel, DistrModel, FuncModel, PSFModel, stellar_psf_params=None, StellarPSFModel=None, throughput=None, **kwargs)

Generate a disk model image given disk, scattering phase function, point spread function, stellar PSF, and miscellaneous parameters.

Parameters:

disk_params : dict of (str, float) pairs

Parameter dictionary for the disk model, made of (parameter name, parameter value) pairs.

spf_params : dict of (str, float) pairs

Parameter dictionary for the scattering phase function, made of (parameter name, parameter value) pairs.

psf_params : dict of (str, float) pairs

Parameter dictionary for the point spread function, made of (parameter name, parameter value) pairs.

misc_params : dict of (str, float) pairs

Parameter dictionary for miscellaneous values, such as image size and flux scaling, made of (parameter name, parameter value) pairs.

DiskModel : class (ScatteredLightDisk is the only supported disk model)

The disk model type.

DistrModel : class (DustEllipticalDistribution2PowerLaws is the only supported dust distribution model)

The dust distribution model type.

FuncModel : class (available in disk_model/SLD_utils.py)

The scattering phase function model type.

PSFModel : class (available in disk_model/SLD_utils.py)

The point spread function model type.

stellar_psf_params : dict of (str, float) pairs, optional

Parameter dictionary for the on-axis stellar PSF model, made of (parameter name, parameter value) pairs.

StellarPSFModel : class, optional

The stellar PSF model type. Defaults to None, indicating no stellar PSF model.

throughput : np.ndarray, optional

The throughput image to apply to the generated disk image. Defaults to None, indicating all 1s.

kwargs : dict, optional

Additional keyword arguments passed to the objective model function.

Returns:

Generated disk model image

Return type:

jnp.ndarray
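The throughput convention documented above (None meaning an image of all 1s) amounts to a pixel-wise multiplication. A minimal sketch, where apply_throughput is a hypothetical helper; where in the pipeline grater_jax applies the throughput (for example before or after PSF convolution) is not stated here.

```python
import jax.numpy as jnp

def apply_throughput(image, throughput=None):
    # None stands for a throughput map of all ones, i.e. no attenuation.
    if throughput is None:
        throughput = jnp.ones_like(image)
    # Pixel-wise product; the throughput map must match (or broadcast to) the image shape.
    return image * throughput

image = jnp.full((4, 4), 2.0)
unchanged = apply_throughput(image)                      # all 2.0
halved = apply_throughput(image, jnp.full((4, 4), 0.5))  # all 1.0
```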

grater_jax.disk_model.objective_functions.objective_ll(disk_params, spf_params, psf_params, stellar_psf_params, misc_params, DiskModel, DistrModel, FuncModel, PSFModel, StellarPSFModel, target_image, err_map, throughput=None, ll_method='sum', **kwargs)

Compute the log likelihood of the generated disk model image, given disk, scattering phase function, point spread function, stellar PSF, and miscellaneous parameters along with the target image and error map.

Parameters:

disk_params : dict of (str, float) pairs

Parameter dictionary for the disk model, made of (parameter name, parameter value) pairs.

spf_params : dict of (str, float) pairs

Parameter dictionary for the scattering phase function, made of (parameter name, parameter value) pairs.

psf_params : dict of (str, float) pairs

Parameter dictionary for the point spread function, made of (parameter name, parameter value) pairs.

misc_params : dict of (str, float) pairs

Parameter dictionary for miscellaneous values, such as image size and flux scaling, made of (parameter name, parameter value) pairs.

DiskModel : class (ScatteredLightDisk is the only supported disk model)

The disk model type.

DistrModel : class (DustEllipticalDistribution2PowerLaws is the only supported dust distribution model)

The dust distribution model type.

FuncModel : class (available in disk_model/SLD_utils.py)

The scattering phase function model type.

PSFModel : class (available in disk_model/SLD_utils.py)

The point spread function model type.

stellar_psf_params : dict of (str, float) pairs, optional

Parameter dictionary for the on-axis stellar PSF model, made of (parameter name, parameter value) pairs.

StellarPSFModel : class, optional

The stellar PSF model type. None indicates no stellar PSF model.

target_image : np.ndarray

The target image for which the log likelihood is computed.

err_map : np.ndarray

The error map for the target image.

throughput : np.ndarray, optional

The throughput image to apply to the generated disk image. Defaults to None, indicating all 1s.

kwargs : dict, optional

Additional keyword arguments passed to the objective model function.

Returns:

Log likelihood for the generated disk image, target image, and error map

Return type:

float
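The module summary describes log_likelihood as a Gaussian log-likelihood between a model image and the observed data. A minimal sketch of that quantity, summed over pixels as ll_method='sum' suggests; gaussian_log_likelihood is a hypothetical name, and whether grater_jax keeps the normalization term, masks NaN pixels, or supports other ll_method values is not stated here.

```python
import jax.numpy as jnp

def gaussian_log_likelihood(model, target, err_map):
    # Standardized residuals: how many sigma each pixel deviates from the model.
    resid = (target - model) / err_map
    # Sum over pixels of log N(target | model, err_map**2), normalization included.
    return -0.5 * jnp.sum(resid ** 2 + jnp.log(2.0 * jnp.pi * err_map ** 2))

target = jnp.zeros((2, 2))
err_map = jnp.ones((2, 2))
ll_exact = gaussian_log_likelihood(target, target, err_map)         # -2 * log(2*pi)
ll_offset = gaussian_log_likelihood(target + 1.0, target, err_map)  # ll_exact - 2.0
```

Because the normalization term depends only on the error map, dropping it shifts the log likelihood by a constant and leaves its gradients with respect to the model parameters unchanged.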

grater_jax.disk_model.objective_functions.objective_fit(params_fit, fit_keys, disk_params, spf_params, psf_params, misc_params, DiskModel, DistrModel, FuncModel, PSFModel, target_image, err_map, throughput=None, stellar_psf_params=None, StellarPSFModel=None, ll_method='sum', **kwargs)

Same as objective_ll, but accepts replacement values for the parameters named in fit_keys. This makes it well suited to fitting, since it provides a clean objective function of just the fitted parameters.

Example: with fit_keys = ['alpha_in', 'alpha_out'] and params_fit = [-5., 5.], the function computes the log likelihood after replacing alpha_in with -5 and alpha_out with 5.
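The replacement step in the example above can be sketched as follows. replace_fit_values is a hypothetical helper, and two behaviors are assumptions rather than documented grater_jax behavior: a key is updated in every dictionary that contains it, and an unknown key raises KeyError.

```python
def replace_fit_values(params_fit, fit_keys, *param_dicts):
    # Work on shallow copies so the template dictionaries stay untouched.
    updated = [dict(d) for d in param_dicts]
    for key, value in zip(fit_keys, params_fit):
        found = False
        for d in updated:
            if key in d:
                d[key] = value   # replace the value wherever the key appears
                found = True
        if not found:
            raise KeyError(f"fit key {key!r} not found in any parameter dictionary")
    return updated

disk_params = {'alpha_in': 5, 'alpha_out': -5, 'sma': 50}
spf_params = {'g': 0.3}   # hypothetical SPF parameter name

new_disk, new_spf = replace_fit_values([-5., 5.], ['alpha_in', 'alpha_out'],
                                       disk_params, spf_params)
```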

Parameters:

params_fit : list of float

The replacement values for the parameters named in fit_keys.

fit_keys : list of str

The names of the parameters whose values are replaced by the corresponding values in params_fit.

disk_params : dict of (str, float) pairs

Parameter dictionary for the disk model, made of (parameter name, parameter value) pairs.

spf_params : dict of (str, float) pairs

Parameter dictionary for the scattering phase function, made of (parameter name, parameter value) pairs.

psf_params : dict of (str, float) pairs

Parameter dictionary for the point spread function, made of (parameter name, parameter value) pairs.

misc_params : dict of (str, float) pairs

Parameter dictionary for miscellaneous values, such as image size and flux scaling, made of (parameter name, parameter value) pairs.

DiskModel : class (ScatteredLightDisk is the only supported disk model)

The disk model type.

DistrModel : class (DustEllipticalDistribution2PowerLaws is the only supported dust distribution model)

The dust distribution model type.

FuncModel : class (available in disk_model/SLD_utils.py)

The scattering phase function model type.

PSFModel : class (available in disk_model/SLD_utils.py)

The point spread function model type.

stellar_psf_params : dict of (str, float) pairs, optional

Parameter dictionary for the on-axis stellar PSF model, made of (parameter name, parameter value) pairs.

StellarPSFModel : class, optional

The stellar PSF model type. Defaults to None, indicating no stellar PSF model.

target_image : np.ndarray

The target image for which the log likelihood is computed.

err_map : np.ndarray

The error map for the target image.

throughput : np.ndarray, optional

The throughput image to apply to the generated disk image. Defaults to None, indicating all 1s.

kwargs : dict, optional

Additional keyword arguments passed to the objective model function.

Returns:

Log likelihood for the generated disk image, target image, and error map

Return type:

float

grater_jax.disk_model.objective_functions.objective_grad(disk_params, spf_params, psf_params, misc_params, DiskModel, DistrModel, FuncModel, PSFModel, target_image, err_map, throughput=None, stellar_psf_params=None, StellarPSFModel=None, ll_method='sum', **kwargs)

Compute the gradient of the log likelihood with respect to each parameter for the generated disk model image, given disk, scattering phase function, point spread function, stellar PSF, and miscellaneous parameters along with the target image and error map.

Parameters:

disk_params : dict of (str, float) pairs

Parameter dictionary for the disk model, made of (parameter name, parameter value) pairs.

spf_params : dict of (str, float) pairs

Parameter dictionary for the scattering phase function, made of (parameter name, parameter value) pairs.

psf_params : dict of (str, float) pairs

Parameter dictionary for the point spread function, made of (parameter name, parameter value) pairs.

misc_params : dict of (str, float) pairs

Parameter dictionary for miscellaneous values, such as image size and flux scaling, made of (parameter name, parameter value) pairs.

DiskModel : class (ScatteredLightDisk is the only supported disk model)

The disk model type.

DistrModel : class (DustEllipticalDistribution2PowerLaws is the only supported dust distribution model)

The dust distribution model type.

FuncModel : class (available in disk_model/SLD_utils.py)

The scattering phase function model type.

PSFModel : class (available in disk_model/SLD_utils.py)

The point spread function model type.

stellar_psf_params : dict of (str, float) pairs, optional

Parameter dictionary for the on-axis stellar PSF model, made of (parameter name, parameter value) pairs.

StellarPSFModel : class, optional

The stellar PSF model type. Defaults to None, indicating no stellar PSF model.

target_image : np.ndarray

The target image for which the log likelihood is computed.

err_map : np.ndarray

The error map for the target image.

throughput : np.ndarray, optional

The throughput image to apply to the generated disk image. Defaults to None, indicating all 1s.

kwargs : dict, optional

Additional keyword arguments passed to the objective model function.

Returns:

Gradients of all the parameters in the format (gradients of disk params, gradients of spf params, gradients of psf params, gradients of stellar psf params).

Note: if the PSF model is a Winnie_PSF, the gradients of psf_params are not included in the output.

Note: not all parameters support gradient evaluation, due to limitations of the JAX model.

Note: the raw gradient output is transformed by the Optimizer class, which wraps this method.

Return type:

list of jnp.ndarray
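jax.grad returns gradients with the same pytree structure as the differentiated arguments, which is one way to obtain per-dictionary gradients like those described above. A toy sketch, assuming dictionary-valued parameters; toy_model, toy_ll and the parameter names are hypothetical, and the actual objective_grad returns a list of arrays and excludes some parameters, as the notes above say.

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(-1.0, 1.0, 5)

def toy_model(disk, spf):
    # 1-D Gaussian profile: disk sets position and width, spf sets brightness.
    return spf["weight"] * jnp.exp(-((x - disk["center"]) ** 2) / disk["width"])

def toy_ll(disk, spf, target):
    # Chi-square style log likelihood with unit errors.
    return -0.5 * jnp.sum((target - toy_model(disk, spf)) ** 2)

disk = {"center": 0.0, "width": 0.5}
spf = {"weight": 1.0}
target = jnp.zeros_like(x)   # data fainter than the model

# Differentiate with respect to the two parameter dictionaries (arguments 0 and 1).
grad_disk, grad_spf = jax.grad(toy_ll, argnums=(0, 1))(disk, spf, target)
# grad_disk has keys 'center' and 'width'; grad_spf has the key 'weight'.
```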

grater_jax.disk_model.objective_functions.objective_fit_grad(params_fit, fit_keys, disk_params, spf_params, psf_params, misc_params, DiskModel, DistrModel, FuncModel, PSFModel, target_image, err_map, throughput=None, stellar_psf_params=None, StellarPSFModel=None, ll_method='sum', **kwargs)

Compute the gradient of the log likelihood with respect to each parameter for the generated disk model image, given disk, scattering phase function, point spread function, stellar PSF, and miscellaneous parameters along with the target image and error map. As in objective_fit, the parameters named in fit_keys are replaced by the values in params_fit.

Example: with fit_keys = ['alpha_in', 'alpha_out'] and params_fit = [-5., 5.], the function computes the gradients of the log likelihood with respect to all the parameters after replacing alpha_in with -5 and alpha_out with 5.

Parameters:

params_fit : list of float

The replacement values for the parameters named in fit_keys.

fit_keys : list of str

The names of the parameters whose values are replaced by the corresponding values in params_fit.

disk_params : dict of (str, float) pairs

Parameter dictionary for the disk model, made of (parameter name, parameter value) pairs.

spf_params : dict of (str, float) pairs

Parameter dictionary for the scattering phase function, made of (parameter name, parameter value) pairs.

psf_params : dict of (str, float) pairs

Parameter dictionary for the point spread function, made of (parameter name, parameter value) pairs.

misc_params : dict of (str, float) pairs

Parameter dictionary for miscellaneous values, such as image size and flux scaling, made of (parameter name, parameter value) pairs.

DiskModel : class (ScatteredLightDisk is the only supported disk model)

The disk model type.

DistrModel : class (DustEllipticalDistribution2PowerLaws is the only supported dust distribution model)

The dust distribution model type.

FuncModel : class (available in disk_model/SLD_utils.py)

The scattering phase function model type.

PSFModel : class (available in disk_model/SLD_utils.py)

The point spread function model type.

stellar_psf_params : dict of (str, float) pairs, optional

Parameter dictionary for the on-axis stellar PSF model, made of (parameter name, parameter value) pairs.

StellarPSFModel : class, optional

The stellar PSF model type. Defaults to None, indicating no stellar PSF model.

target_image : np.ndarray

The target image for which the log likelihood is computed.

err_map : np.ndarray

The error map for the target image.

throughput : np.ndarray, optional

The throughput image to apply to the generated disk image. Defaults to None, indicating all 1s.

kwargs : dict, optional

Additional keyword arguments passed to the objective model function.

Returns:

Gradients of all the parameters in the format (gradients of disk params, gradients of spf params, gradients of psf params, gradients of stellar psf params).

Note: if the PSF model is a Winnie_PSF, the gradients of psf_params are not included in the output.

Note: not all parameters support gradient evaluation, due to limitations of the JAX model.

Note: the raw gradient output is transformed by the Optimizer class, which wraps this method.

Return type:

list of jnp.ndarray
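objective_fit_grad takes a flat params_fit vector, which is the interface gradient-based optimizers expect. The toy sketch below differentiates a log likelihood with respect to such a vector and takes a few gradient-ascent steps; toy_fit_ll, the parameter names and the step size are hypothetical, and the real function returns gradients for all the parameters in the grouped format described above rather than a single vector.

```python
import jax
import jax.numpy as jnp

fit_keys = ["amp", "width"]                  # hypothetical parameter names
base = {"amp": 1.0, "width": 0.5, "center": 0.0}
x = jnp.linspace(-1.0, 1.0, 9)
target = 2.0 * jnp.exp(-(x ** 2) / 0.5)     # synthetic data generated with amp = 2

def toy_fit_ll(params_fit):
    # Substitute the fitted values into a copy of the template, in fit_keys order.
    p = dict(base)
    for i, key in enumerate(fit_keys):
        p[key] = params_fit[i]
    model = p["amp"] * jnp.exp(-((x - p["center"]) ** 2) / p["width"])
    return -0.5 * jnp.sum((target - model) ** 2)

grad_fn = jax.grad(toy_fit_ll)
params = jnp.array([1.0, 0.5])
ll_start = toy_fit_ll(params)

# Plain gradient ascent on the log likelihood (step size chosen by hand).
for _ in range(20):
    params = params + 0.02 * grad_fn(params)

ll_end = toy_fit_ll(params)   # higher than ll_start
```

To hand such an objective to a minimizer such as scipy.optimize.minimize, negate both the log likelihood and its gradient.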