grater_jax.disk_model.objective_functions
objective_functions.py
Objective functions for disk modeling and optimization.
This module provides functions that connect high-level optimization routines to the underlying JAX-accelerated disk model wrappers. It defines parameter templates, packing utilities, and objective functions for model evaluation, log-likelihood computation, and gradient calculation.
Functions
- Get the generated disk model image given disk, scattering phase function, point spread function (PSF), stellar PSF, and miscellaneous parameters, along with the target image and error map.
- Gradient of jax_model_ll with respect to positional argument(s) (5, 6, 7, 8, 12).
- Get the generated disk model image given disk, scattering phase function, point spread function (PSF), stellar PSF, and miscellaneous parameters, along with the target image and error map.
- Gradient of jax_model_spline_ll with respect to positional argument(s) (5, 6, 7, 8, 12).
- Get the generated disk model image given disk, scattering phase function, point spread function (PSF), stellar PSF, and miscellaneous parameters, along with the target image and error map.
- Gradient of jax_model_spline_winnie_ll with respect to positional argument(s) (5, 6, 7, 11).
- Get the generated disk model image given disk, scattering phase function, point spread function (PSF), stellar PSF, and miscellaneous parameters, along with the target image and error map.
- Gradient of jax_model_winnie_ll with respect to positional argument(s) (5, 6, 7, 11).
- Compute the Gaussian log-likelihood between a model image and observed data.
- Same as the objective_ll function, but accepts replacement values for the parameters named in fit_keys.
- Get the gradient of the log-likelihood with respect to each parameter for the disk model image generated from disk, scattering phase function, PSF, stellar PSF, and miscellaneous parameters, given the target image and error map.
- Get the gradient of the log-likelihood with respect to each parameter for the disk model image generated from disk, scattering phase function, PSF, stellar PSF, and miscellaneous parameters, given the target image and error map.
- Get the log-likelihood of the disk model image generated from disk, scattering phase function, PSF, stellar PSF, and miscellaneous parameters, given the target image and error map.
- Generate a disk model image given disk, scattering phase function, PSF, stellar PSF, and miscellaneous parameters.
- Pack parameter values from a dictionary into a JAX array.
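The "Gradient of ... with respect to positional argument(s) (...)" entries in the function summary appear to be the docstring that jax.grad attaches to the functions it returns, which suggests those gradient functions are built with jax.grad and a tuple of argnums. The sketch below shows that mechanism on a hypothetical toy function (toy_ll is not one of the package's jax_model_*_ll functions): with a tuple argnums, jax.grad returns one gradient per selected argument, and the remaining arguments are passed through undifferentiated.

```python
import jax
import jax.numpy as jnp

def toy_ll(a, b, c, scale):
    """Hypothetical stand-in for a log-likelihood with several positional arguments."""
    return -0.5 * jnp.sum((a * b - c) ** 2) / scale

# Differentiate only with respect to positional arguments 0 (a) and 2 (c);
# b and scale are treated as constants.
toy_ll_grad = jax.grad(toy_ll, argnums=(0, 2))

a = jnp.array([1.0, 2.0])
b = jnp.array([3.0, 4.0])
c = jnp.array([2.0, 5.0])

# A tuple argnums yields a tuple of gradients, in the same order.
grad_a, grad_c = toy_ll_grad(a, b, c, 1.0)
```

Here the residual a * b - c is [1, 3], so grad_a = -residual * b = [-3, -12] and grad_c = residual = [1, 3].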
Classes
- Implementation of a spline scattering phase function.
- Default parameter sets for disk modeling components.
- Creates a JWST PSF model using the Winnie package.
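As a rough illustration of what a spline scattering phase function does, the sketch below interpolates a few hypothetical phase-function samples over scattering angle with SciPy's InterpolatedUnivariateSpline. This is an assumption-laden stand-in, not the package's spline SPF class: that class may place its knots differently (for example in the cosine of the scattering angle) and is not necessarily built on SciPy. The knot angles and values are made up for illustration.

```python
import numpy as np
from scipy.interpolate import InterpolatedUnivariateSpline

# Hypothetical knots: phase-function values at a few scattering angles (degrees).
knot_angles_deg = np.array([0.0, 30.0, 60.0, 90.0, 120.0, 150.0, 180.0])
knot_values = np.array([1.0, 0.6, 0.35, 0.25, 0.22, 0.24, 0.3])

# Cubic (k=3) interpolating spline that passes through every knot.
spf = InterpolatedUnivariateSpline(knot_angles_deg, knot_values, k=3)

# Evaluate the phase function on a fine angular grid.
angles = np.linspace(0.0, 180.0, 181)
phase = spf(angles)
```

One design consequence worth noting: a SciPy spline cannot be traced by JAX, so a differentiable objective would need a JAX-native interpolant for the knot values to receive gradients.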
- class grater_jax.disk_model.objective_functions.Parameter_Index
Bases: object
Default parameter sets for disk modeling components.
This class defines canonical parameter dictionaries used to optimize and simulate scattered-light disks. These dictionaries act as templates for the packing and fitting routines and list all modifiable physical and observational quantities.
- disk_params
- Dictionary of physical parameters describing the disk geometry and density:
- accuracy (float)
Numerical accuracy for integrators.
- alpha_in, alpha_out (float)
Inner and outer radial density power-law slopes.
- sma (float)
Semi-major axis (in pixels or AU, depending on the model).
- e (float)
Eccentricity of the disk.
- ksi0 (float)
Reference vertical scale height of the disk at the reference radius.
- gamma, beta (float)
Vertical structure parameters.
- rmin (float)
Inner radius cutoff.
- dens_at_r0 (float)
Density normalization at the reference radius.
- inclination (float)
Inclination angle (degrees).
- position_angle (float)
Disk orientation on the sky (degrees).
- x_center, y_center (float)
Image center coordinates.
- halfNbSlices (int)
Number of angular slices for 3D integration (half of the total).
- omega (float)
Argument of pericenter (in radians or degrees).
- Type:
dict
- misc_params
- Dictionary of observational and image grid parameters:
- distance (float)
Distance to the system (in parsecs).
- pxInArcsec (float)
Pixel scale (arcseconds per pixel).
- nx, ny (int)
Image dimensions (width and height in pixels).
- halfNbSlices (int)
Number of angular slices for rendering (should match disk_params).
- flux_scaling (float)
Normalization factor for model brightness.
- Type:
dict
Notes
Parameter dictionaries for scattering phase functions (SPFs), PSFs, and stellar PSFs are defined in SLD_utils.py. Spline SPFs and Winnie PSFs use instantiated model classes (e.g., InterpolatedUnivariateSpline_SPF, Winnie_PSF).
- disk_params = {'accuracy': 0.005, 'alpha_in': 5, 'alpha_out': -5, 'beta': 1.0, 'dens_at_r0': 1.0, 'e': 0.0, 'gamma': 2.0, 'inclination': 0, 'ksi0': 3.0, 'omega': 0.0, 'position_angle': 0, 'rmin': 0.0, 'sma': 50, 'x_center': 70.0, 'y_center': 70.0}
- misc_params = {'distance': 50.0, 'flux_scaling': 1000000.0, 'halfNbSlices': 25, 'nx': 140, 'ny': 140, 'pxInArcsec': 0.01414}
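Because these dictionaries are templates, a typical pattern is to copy one and override only the values of interest, so the template itself is never mutated and the key order (which pack_pars relies on) is preserved. The sketch below retypes the disk_params defaults listed above into a local DISK_TEMPLATE; make_disk_params is a hypothetical helper, not part of the package. In real code you would start from Parameter_Index.disk_params (for example with dict(Parameter_Index.disk_params)) rather than retyping it.

```python
# Local copy of the Parameter_Index.disk_params defaults listed above.
DISK_TEMPLATE = {'accuracy': 0.005, 'alpha_in': 5, 'alpha_out': -5, 'beta': 1.0,
                 'dens_at_r0': 1.0, 'e': 0.0, 'gamma': 2.0, 'inclination': 0,
                 'ksi0': 3.0, 'omega': 0.0, 'position_angle': 0, 'rmin': 0.0,
                 'sma': 50, 'x_center': 70.0, 'y_center': 70.0}

def make_disk_params(**overrides):
    """Hypothetical helper: copy the template and replace selected values."""
    unknown = set(overrides) - set(DISK_TEMPLATE)
    if unknown:
        raise KeyError(f"not a disk parameter: {sorted(unknown)}")
    # A shallow copy is enough because every value is a plain number.
    params = dict(DISK_TEMPLATE)
    # Updating existing keys keeps their original positions in the dict.
    params.update(overrides)
    return params

my_disk = make_disk_params(sma=40.0, inclination=70.0, position_angle=30.0)
```

Rejecting unknown keys catches typos such as "inclinaton" early, instead of silently adding a parameter that the model never reads.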
- grater_jax.disk_model.objective_functions.pack_pars(p_dict, orig_dict)
Pack parameter values from a dictionary into a JAX array.
The output array follows the key order defined in orig_dict, which is typically a template dictionary that defines the parameter structure. This packing is how parameter dictionaries are passed to the JAX-wrapped model classes.
- Parameters:
p_dict (dict) – Dictionary of parameter values to be packed. Keys must match those in orig_dict.
orig_dict (dict) – Reference dictionary that defines the desired key ordering.
- Returns:
JAX array of parameter values in the order defined by orig_dict.
- Return type:
jnp.ndarray
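The documented behaviour (values taken from p_dict, ordered by the keys of orig_dict) can be sketched as follows. pack_pars_sketch is a hypothetical re-implementation for illustration only; the real pack_pars may differ, for example in dtype handling or in how it reports keys missing from p_dict.

```python
import jax.numpy as jnp

def pack_pars_sketch(p_dict, orig_dict):
    """Hypothetical re-implementation: p_dict's values in orig_dict's key order."""
    # Iterating over orig_dict fixes the order; a key missing from p_dict raises KeyError.
    return jnp.array([p_dict[key] for key in orig_dict], dtype=jnp.float32)

template = {'sma': 50.0, 'e': 0.0, 'inclination': 0.0}
values = {'inclination': 60.0, 'sma': 42.0, 'e': 0.1}  # different insertion order
packed = pack_pars_sketch(values, template)  # ordered as sma, e, inclination
```

Fixing the order through the template means the position of each parameter in the array never depends on how the caller happened to build its dictionary.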
- grater_jax.disk_model.objective_functions.objective_model(disk_params, spf_params, psf_params, misc_params, DiskModel, DistrModel, FuncModel, PSFModel, stellar_psf_params=None, StellarPSFModel=None, throughput=None, **kwargs)
Generate a disk model image given disk, scattering phase function, point spread function (PSF), stellar PSF, and miscellaneous parameters.
- Parameters:
- disk_params (dict of (str, float) pairs)
The parameter dictionary for the disk model, made of (parameter name, parameter value) pairs.
- spf_params (dict of (str, float) pairs)
The parameter dictionary for the scattering phase function, made of (parameter name, parameter value) pairs.
- psf_params (dict of (str, float) pairs)
The parameter dictionary for the point spread function, made of (parameter name, parameter value) pairs.
- misc_params (dict of (str, float) pairs)
The parameter dictionary for miscellaneous values, such as image size and flux scaling, made of (parameter name, parameter value) pairs.
- DiskModel (class; ScatteredLightDisk is the only supported disk model)
The disk model type.
- DistrModel (class; DustEllipticalDistribution2PowerLaws is the only supported dust distribution model)
The dust distribution model type.
- FuncModel (class; available in disk_model/SLD_utils.py)
The scattering phase function model type.
- PSFModel (class; available in disk_model/SLD_utils.py)
The point spread function model type.
- stellar_psf_params (dict of (str, float) pairs, optional)
The parameter dictionary for the on-axis stellar PSF model, made of (parameter name, parameter value) pairs.
- StellarPSFModel (class, optional)
The stellar PSF model type. Defaults to None, indicating no stellar PSF model.
- throughput (np.ndarray, optional)
The throughput image to apply to the generated disk image. Defaults to None, indicating all 1s.
- kwargs (dict, optional)
Additional keyword arguments passed to the objective model function.
- Returns:
Generated disk model image
- Return type:
jnp.ndarray
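The throughput parameter is described as an image applied to the generated disk image, with None meaning all 1s. The sketch below assumes "applied" means a pixel-by-pixel multiplication; apply_throughput is a hypothetical helper, and where in the pipeline the real objective_model applies throughput (before or after PSF convolution, for instance) is not stated here.

```python
import jax.numpy as jnp

def apply_throughput(image, throughput=None):
    """Hypothetical helper: multiply a model image by a throughput map.

    None is treated as a map of ones, i.e. the image is returned unchanged.
    """
    if throughput is None:
        throughput = jnp.ones_like(image)
    return image * throughput

image = jnp.full((4, 4), 2.0)
unchanged = apply_throughput(image)                       # default: all 1s
halved = apply_throughput(image, jnp.full((4, 4), 0.5))   # 50% throughput everywhere
```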
- grater_jax.disk_model.objective_functions.objective_ll(disk_params, spf_params, psf_params, stellar_psf_params, misc_params, DiskModel, DistrModel, FuncModel, PSFModel, StellarPSFModel, target_image, err_map, throughput=None, **kwargs)
Get the log-likelihood of the disk model image generated from disk, scattering phase function, point spread function (PSF), stellar PSF, and miscellaneous parameters, given the target image and error map.
- Parameters:
- disk_params (dict of (str, float) pairs)
The parameter dictionary for the disk model, made of (parameter name, parameter value) pairs.
- spf_params (dict of (str, float) pairs)
The parameter dictionary for the scattering phase function, made of (parameter name, parameter value) pairs.
- psf_params (dict of (str, float) pairs)
The parameter dictionary for the point spread function, made of (parameter name, parameter value) pairs.
- misc_params (dict of (str, float) pairs)
The parameter dictionary for miscellaneous values, such as image size and flux scaling, made of (parameter name, parameter value) pairs.
- DiskModel (class; ScatteredLightDisk is the only supported disk model)
The disk model type.
- DistrModel (class; DustEllipticalDistribution2PowerLaws is the only supported dust distribution model)
The dust distribution model type.
- FuncModel (class; available in disk_model/SLD_utils.py)
The scattering phase function model type.
- PSFModel (class; available in disk_model/SLD_utils.py)
The point spread function model type.
- stellar_psf_params (dict of (str, float) pairs, or None)
The parameter dictionary for the on-axis stellar PSF model, made of (parameter name, parameter value) pairs. Pass None if no stellar PSF model is used.
- StellarPSFModel (class, or None)
The stellar PSF model type. Pass None to indicate no stellar PSF model.
- target_image (np.ndarray)
The target image for which the log-likelihood is computed.
- err_map (np.ndarray)
The error map for the target image.
- throughput (np.ndarray, optional)
The throughput image to apply to the generated disk image. Defaults to None, indicating all 1s.
- kwargs (dict, optional)
Additional keyword arguments passed to the objective model function.
- Returns:
Log-likelihood computed from the generated disk image, the target image, and the error map.
- Return type:
float
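objective_ll scores a generated model image against target_image using err_map. A common form for such a score, and the one assumed here, is the independent-pixel Gaussian log-likelihood, ln L = -1/2 Σ [(d - m)²/σ² + ln(2πσ²)], where d is the target image, m the model image, and σ the error map. Whether the package keeps the constant ln(2πσ²) term, or masks pixels such as NaNs, is not stated here; gaussian_log_likelihood is a hypothetical helper, not the package's function.

```python
import jax.numpy as jnp

def gaussian_log_likelihood(model, data, err):
    """Independent-pixel Gaussian log-likelihood (assumed form, with normalization term)."""
    resid = (data - model) / err
    return -0.5 * jnp.sum(resid ** 2 + jnp.log(2.0 * jnp.pi * err ** 2))

data = jnp.array([[1.0, 2.0], [3.0, 4.0]])
err = jnp.ones_like(data)

ll_perfect = gaussian_log_likelihood(data, data, err)       # model matches the data exactly
ll_offset = gaussian_log_likelihood(data + 1.0, data, err)  # model is off by 1 sigma everywhere
```

With four pixels and unit errors, a perfect model gives -2 ln(2π), and shifting every pixel by one sigma lowers the log-likelihood by exactly 2.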
- grater_jax.disk_model.objective_functions.objective_fit(params_fit, fit_keys, disk_params, spf_params, psf_params, misc_params, DiskModel, DistrModel, FuncModel, PSFModel, target_image, err_map, throughput=None, stellar_psf_params=None, StellarPSFModel=None, **kwargs)
Same as the objective_ll function, but accepts replacement values for the parameters named in fit_keys. This makes it well suited to fitting, since it provides a clean objective function of the selected parameters.
For example, with fit_keys = ['alpha_in', 'alpha_out'] and params_fit = [-5., 5.], the function computes the log-likelihood after replacing alpha_in with -5 and alpha_out with 5.
- Parameters:
- params_fit (list of float)
The replacement values for the parameters named in fit_keys.
- fit_keys (list of str)
The names of the parameters whose values are replaced by the corresponding values in params_fit.
- disk_params (dict of (str, float) pairs)
The parameter dictionary for the disk model, made of (parameter name, parameter value) pairs.
- spf_params (dict of (str, float) pairs)
The parameter dictionary for the scattering phase function, made of (parameter name, parameter value) pairs.
- psf_params (dict of (str, float) pairs)
The parameter dictionary for the point spread function, made of (parameter name, parameter value) pairs.
- misc_params (dict of (str, float) pairs)
The parameter dictionary for miscellaneous values, such as image size and flux scaling, made of (parameter name, parameter value) pairs.
- DiskModel (class; ScatteredLightDisk is the only supported disk model)
The disk model type.
- DistrModel (class; DustEllipticalDistribution2PowerLaws is the only supported dust distribution model)
The dust distribution model type.
- FuncModel (class; available in disk_model/SLD_utils.py)
The scattering phase function model type.
- PSFModel (class; available in disk_model/SLD_utils.py)
The point spread function model type.
- stellar_psf_params (dict of (str, float) pairs, optional)
The parameter dictionary for the on-axis stellar PSF model, made of (parameter name, parameter value) pairs.
- StellarPSFModel (class, optional)
The stellar PSF model type. Defaults to None, indicating no stellar PSF model.
- target_image (np.ndarray)
The target image for which the log-likelihood is computed.
- err_map (np.ndarray)
The error map for the target image.
- throughput (np.ndarray, optional)
The throughput image to apply to the generated disk image. Defaults to None, indicating all 1s.
- kwargs (dict, optional)
Additional keyword arguments passed to the objective model function.
- Returns:
Log-likelihood computed from the generated disk image, the target image, and the error map.
- Return type:
float
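The replacement step that objective_fit performs before scoring can be sketched as below. replace_fit_values is hypothetical, and it assumes each name in fit_keys is looked up across the parameter dictionaries, with the first dictionary containing the key receiving the new value; the package's actual lookup rule is not stated here. The 'g' key in spf is likewise a made-up SPF parameter.

```python
def replace_fit_values(params_fit, fit_keys, *param_dicts):
    """Hypothetical: copies of param_dicts with fit_keys replaced by params_fit.

    Assumes each key lives in one dict; the first dict containing it is updated.
    """
    if len(params_fit) != len(fit_keys):
        raise ValueError("params_fit and fit_keys must have the same length")
    updated = [dict(d) for d in param_dicts]  # shallow copies; originals are untouched
    for key, value in zip(fit_keys, params_fit):
        for d in updated:
            if key in d:
                d[key] = value
                break
        else:
            raise KeyError(f"{key!r} not found in any parameter dictionary")
    return updated

disk = {'alpha_in': 5.0, 'alpha_out': -5.0, 'sma': 50.0}
spf = {'g': 0.3}  # hypothetical SPF parameter
new_disk, new_spf = replace_fit_values([-5.0, 5.0], ['alpha_in', 'alpha_out'], disk, spf)
```

Because objective_fit returns a log-likelihood, it would be negated before being passed to a minimizer such as scipy.optimize.minimize.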
- grater_jax.disk_model.objective_functions.objective_grad(disk_params, spf_params, psf_params, misc_params, DiskModel, DistrModel, FuncModel, PSFModel, target_image, err_map, throughput=None, stellar_psf_params=None, StellarPSFModel=None, **kwargs)
Get the gradient of the log-likelihood with respect to each parameter for the disk model image generated from disk, scattering phase function, point spread function (PSF), stellar PSF, and miscellaneous parameters, given the target image and error map.
- Parameters:
- disk_params (dict of (str, float) pairs)
The parameter dictionary for the disk model, made of (parameter name, parameter value) pairs.
- spf_params (dict of (str, float) pairs)
The parameter dictionary for the scattering phase function, made of (parameter name, parameter value) pairs.
- psf_params (dict of (str, float) pairs)
The parameter dictionary for the point spread function, made of (parameter name, parameter value) pairs.
- misc_params (dict of (str, float) pairs)
The parameter dictionary for miscellaneous values, such as image size and flux scaling, made of (parameter name, parameter value) pairs.
- DiskModel (class; ScatteredLightDisk is the only supported disk model)
The disk model type.
- DistrModel (class; DustEllipticalDistribution2PowerLaws is the only supported dust distribution model)
The dust distribution model type.
- FuncModel (class; available in disk_model/SLD_utils.py)
The scattering phase function model type.
- PSFModel (class; available in disk_model/SLD_utils.py)
The point spread function model type.
- stellar_psf_params (dict of (str, float) pairs, optional)
The parameter dictionary for the on-axis stellar PSF model, made of (parameter name, parameter value) pairs.
- StellarPSFModel (class, optional)
The stellar PSF model type. Defaults to None, indicating no stellar PSF model.
- target_image (np.ndarray)
The target image for which the log-likelihood is computed.
- err_map (np.ndarray)
The error map for the target image.
- throughput (np.ndarray, optional)
The throughput image to apply to the generated disk image. Defaults to None, indicating all 1s.
- kwargs (dict, optional)
Additional keyword arguments passed to the objective model function.
- Returns:
Gradients of all the parameters, in the format (gradients of disk_params, gradients of spf_params, gradients of psf_params, gradients of stellar_psf_params).
Notes: If the PSF model is a Winnie_PSF, the gradients of psf_params are not included in the output. Not all parameters support gradient evaluation, due to limitations of the JAX model. The raw gradient output is transformed by the Optimizer class, which wraps this function.
- Return type:
list of jnp array
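The idea behind objective_grad, differentiating a log-likelihood with respect to whole parameter dictionaries, can be illustrated with jax.grad, which accepts dictionaries (pytrees) as inputs and returns a dictionary of gradients with the same keys. toy_model and toy_ll below are hypothetical stand-ins (a 1D Gaussian ring profile and a Gaussian log-likelihood without the constant term), not the package's disk model. One general JAX constraint to keep in mind: jax.grad rejects integer-valued inputs, so in a sketch like this integer defaults such as 'inclination': 0 from the Parameter_Index template would have to be converted to floats first.

```python
import jax
import jax.numpy as jnp

R = jnp.linspace(0.0, 100.0, 101)  # hypothetical 1D radial grid (pixels)

def toy_model(disk):
    """Hypothetical 1D stand-in for a disk image: a Gaussian ring profile."""
    return disk['flux'] * jnp.exp(-0.5 * ((R - disk['sma']) / disk['width']) ** 2)

def toy_ll(disk, data, err):
    """Gaussian log-likelihood without the constant normalization term."""
    resid = (data - toy_model(disk)) / err
    return -0.5 * jnp.sum(resid ** 2)

truth = {'flux': 1.0, 'sma': 50.0, 'width': 5.0}
data = toy_model(truth)
err = jnp.full(R.shape, 0.1)

# Differentiate with respect to the whole dict; the result is a dict
# with the same keys holding d(log L)/d(parameter).
guess = {'flux': 1.0, 'sma': 48.0, 'width': 5.0}
grads = jax.grad(toy_ll)(guess, data, err)
```

Since the guessed ring sits 2 pixels inside the true one, grads['sma'] is positive: increasing sma moves the model toward the data and raises the log-likelihood.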
- grater_jax.disk_model.objective_functions.objective_fit_grad(params_fit, fit_keys, disk_params, spf_params, psf_params, misc_params, DiskModel, DistrModel, FuncModel, PSFModel, target_image, err_map, throughput=None, stellar_psf_params=None, StellarPSFModel=None, **kwargs)
Get the gradient of the log-likelihood with respect to each parameter for the disk model image generated from disk, scattering phase function, point spread function (PSF), stellar PSF, and miscellaneous parameters, given the target image and error map, after replacing the parameters named in fit_keys with the values in params_fit.
For example, with fit_keys = ['alpha_in', 'alpha_out'] and params_fit = [-5., 5.], the function computes the gradients of the log-likelihood with respect to all the parameters after replacing alpha_in with -5 and alpha_out with 5.
- Parameters:
- params_fit (list of float)
The replacement values for the parameters named in fit_keys.
- fit_keys (list of str)
The names of the parameters whose values are replaced by the corresponding values in params_fit.
- disk_params (dict of (str, float) pairs)
The parameter dictionary for the disk model, made of (parameter name, parameter value) pairs.
- spf_params (dict of (str, float) pairs)
The parameter dictionary for the scattering phase function, made of (parameter name, parameter value) pairs.
- psf_params (dict of (str, float) pairs)
The parameter dictionary for the point spread function, made of (parameter name, parameter value) pairs.
- misc_params (dict of (str, float) pairs)
The parameter dictionary for miscellaneous values, such as image size and flux scaling, made of (parameter name, parameter value) pairs.
- DiskModel (class; ScatteredLightDisk is the only supported disk model)
The disk model type.
- DistrModel (class; DustEllipticalDistribution2PowerLaws is the only supported dust distribution model)
The dust distribution model type.
- FuncModel (class; available in disk_model/SLD_utils.py)
The scattering phase function model type.
- PSFModel (class; available in disk_model/SLD_utils.py)
The point spread function model type.
- stellar_psf_params (dict of (str, float) pairs, optional)
The parameter dictionary for the on-axis stellar PSF model, made of (parameter name, parameter value) pairs.
- StellarPSFModel (class, optional)
The stellar PSF model type. Defaults to None, indicating no stellar PSF model.
- target_image (np.ndarray)
The target image for which the log-likelihood is computed.
- err_map (np.ndarray)
The error map for the target image.
- throughput (np.ndarray, optional)
The throughput image to apply to the generated disk image. Defaults to None, indicating all 1s.
- kwargs (dict, optional)
Additional keyword arguments passed to the objective model function.
- Returns:
Gradients of all the parameters, in the format (gradients of disk_params, gradients of spf_params, gradients of psf_params, gradients of stellar_psf_params).
Notes: If the PSF model is a Winnie_PSF, the gradients of psf_params are not included in the output. Not all parameters support gradient evaluation, due to limitations of the JAX model. The raw gradient output is transformed by the Optimizer class, which wraps this function.
- Return type:
list of jnp array
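The pattern of writing fitted values into a parameter dictionary and differentiating through that replacement can be sketched with jax.grad applied to the params_fit array. This toy version is hypothetical: toy_ll is a made-up smooth log-likelihood peaked at alpha_in = 5 and alpha_out = -5, and unlike the documented function (which returns gradients for all parameter groups) it returns only the gradient with respect to params_fit. Non-differentiated arguments such as the fit_keys list and the dictionary pass through jax.grad untouched.

```python
import jax
import jax.numpy as jnp

def toy_ll(disk):
    """Hypothetical smooth log-likelihood, maximal at alpha_in=5, alpha_out=-5."""
    return -0.5 * ((disk['alpha_in'] - 5.0) ** 2 + (disk['alpha_out'] + 5.0) ** 2)

def toy_fit_ll(params_fit, fit_keys, disk):
    """Replace the fitted entries with (traced) values from params_fit, then score."""
    updated = dict(disk)
    for i, key in enumerate(fit_keys):
        updated[key] = params_fit[i]
    return toy_ll(updated)

disk = {'alpha_in': 5.0, 'alpha_out': -5.0}
fit_keys = ['alpha_in', 'alpha_out']

# Differentiate only with respect to params_fit (argument 0).
grad_fn = jax.grad(toy_fit_ll, argnums=0)
g = grad_fn(jnp.array([-5.0, 5.0]), fit_keys, disk)
```

At params_fit = [-5, 5] the gradient is [10, -10]: both entries point back toward the maximum at (5, -5).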