grater_jax.optimization.optimize_framework

optimizer.py

High-level optimization interface for GRaTeR-JAX.

This module defines the Optimizer class, a top-level wrapper for the GRaTeR-JAX framework that integrates disk morphology, scattering phase functions, PSFs, and stellar PSFs into a single object. It provides unified methods for generating disk models, computing likelihoods and gradients, and performing optimization with both deterministic (SciPy) and stochastic (MCMC) approaches.

Main features

  • Optimizer.get_model : Build a model image from current parameters.

  • Optimizer.get_objective_likelihood, get_objective_gradient : Evaluate log-likelihood and its gradient.

  • Optimizer.scipy_optimize, Optimizer.scipy_bounded_optimize : Run SciPy-based minimization with optional analytic gradients.

  • Optimizer.mcmc : Run emcee-based MCMC sampling for posterior inference.

  • Optimizer.save_human_readable, save_machine_readable, load_machine_readable : Save/load model parameters for reproducibility.

  • OptimizeUtils : Helper utilities for empirical error maps, image preprocessing, and post-processing MCMC chains.

Functions

check_grad(func, grad, x0, *args[, epsilon, ...])

Check the correctness of a gradient function by comparing it against a (forward) finite-difference approximation of the gradient.

log_likelihood(image, target_image, err_map)

Compute the Gaussian log-likelihood between a model image and observed data.

minimize(fun, x0[, args, method, jac, hess, ...])

Minimization of scalar function of one or more variables.

objective_fit(params_fit, fit_keys, ...[, ...])

Same as the objective_ll function but accepts replacement values for the given parameters in fit_keys.

objective_fit_grad(params_fit, fit_keys, ...)

Get the gradient of the log likelihood with respect to each fitted parameter for the generated disk model image, given the disk, scattering phase function, point spread function, stellar PSF, and miscellaneous parameters along with the target image and error map.

objective_grad(disk_params, spf_params, ...)

Get the gradient of the log likelihood with respect to each parameter for the generated disk model image, given the disk, scattering phase function, point spread function, stellar PSF, and miscellaneous parameters along with the target image and error map.

objective_ll(disk_params, spf_params, ...[, ...])

Get the log likelihood for the generated disk model image, given the disk, scattering phase function, point spread function, stellar PSF, and miscellaneous parameters along with the target image and error map.

objective_model(disk_params, spf_params, ...)

Generate a disk model image given the disk, scattering phase function, point spread function, stellar PSF, and miscellaneous parameters.

recommended_num_knots(inclination_deg, ...)

Return a recommended knot count scaled to the scattering angles probed at a given inclination.

vmap(fun[, in_axes, out_axes, axis_name, ...])

Vectorizing map.

Classes

DoubleHenyeyGreenstein_SPF()

Implementation of a scattering phase function with a double Henyey Greenstein function.

DustEllipticalDistribution2PowerLaws()

Two-power-law elliptical dust density distribution model.

EMP_PSF()

Empirical point spread function (PSF) model.

GAUSSIAN_PSF()

Gaussian PSF model.

HenyeyGreenstein_SPF()

Implementation of a scattering phase function with a single Henyey Greenstein function.

InterpolatedUnivariateSpline(x, y[, k, ...])

InterpolatedUnivariateSpline_SPF()

Implementation of a spline scattering phase function.

Jax_class()

Base class for custom JAX-compatible objects that can be compressed into and uncompressed from JAX arrays.

LinearStellarPSF()

Stellar PSF model as a linear combination of reference images.

MCMC_model(fun, theta_bounds, name)

Markov Chain Monte Carlo (MCMC) wrapper for emcee sampler with tools for inference, diagnostics, and visualization.

OptimizeUtils()

Optimizer(DiskModel, DistrModel, FuncModel, ...)

The Optimizer class is the high level wrapper for the GRaTeR-JAX framework and all of its functionality.

PositionalStellarPSF()

Stellar PSF model with position-dependent reference image weights.

StellarPSFReference()

Reference images that the Stellar PSF classes will use.

WinniePSF(psfs, psf_inds_rolls, ...)

A class for spatially-varying PSF convolution with image cubes, including rotation and masking per roll angle.

Winnie_PSF()

Creates a JWST PSF model, using the package Winnie.

partial

partial(func, *args, **keywords) - new function with partial application of the given arguments and keywords.

class grater_jax.optimization.optimize_framework.Optimizer(DiskModel, DistrModel, FuncModel, PSFModel, disk_params, spf_params, psf_params, misc_params, StellarPSFModel=None, stellar_psf_params=None, empirical_psf_image=None, throughput=None, **kwargs)

Bases: object

The Optimizer class is the high-level wrapper for the GRaTeR-JAX framework and all of its functionality. It handles all of the supported parameters for the disk morphology, scattering phase function, off-axis point spread function, on-axis stellar PSF, and miscellaneous settings. The Optimizer object is updated in place as the user calls the various optimization and modification methods. To preserve the current disk model before running an optimization method, save a copy of the current Optimizer object.

Parameters:
  • DiskModel (class (ScatteredLightDisk is the only supported disk model)) – The disk model type

  • DistrModel (class (DustEllipticalDistribution2PowerLaws is the only supported dust distribution model)) – The dust distribution model type

  • FuncModel (class (Can be found in disk_model/SLD_utils.py)) – The scattering phase function model type

  • PSFModel (class (Can be found in disk_model/SLD_utils.py)) – The point spread function model type

  • disk_params (dict of (str, float) pairs) – The corresponding parameter dictionary for the disk model, dictionary is made of (parameter name, parameter value) pairs.

  • spf_params (dict of (str, float) pairs) – The corresponding parameter dictionary for the scattering phase function, dictionary is made of (parameter name, parameter value) pairs.

  • psf_params (dict of (str, float) pairs) – The corresponding parameter dictionary for the point spread function, dictionary is made of (parameter name, parameter value) pairs.

  • misc_params (dict of (str, float) pairs) – The parameter dictionary for miscellaneous values, such as image size and flux scaling. The dictionary is made of (parameter name, parameter value) pairs.

  • StellarPSFModel (class, optional) – The stellar PSF model type. Defaults to None, indicating no stellar PSF model.

  • stellar_psf_params (dict of (str, float) pairs, optional) – The corresponding parameter dictionary for the on axis stellar psf model, dictionary is made of (parameter name, parameter value) pairs.

  • empirical_psf_image (np.ndarray) – Image that the model uses for the empirical psf.

  • kwargs (dict, optional) – Additional keyword arguments that are passed into the objective model function.

jit_compile_model()

Just-in-time compiles the disk model. Call this again if any of the component classes are changed.

jit_compile_gradient(target_image, err_map)

Just-in-time compiles the gradient of the disk model for a given target image and err_map.

set_throughput(throughput)

Sets the throughput image to the given image.

Parameters:

throughput (numpy.ndarray) – 2D array representing the throughput image.

Raises:

ValueError – If the shape of the throughput image does not match the specified image dimensions in misc_params.

get_model()

Returns the disk model as per the current parameters.

Returns:

2d image of the generated disk model

Return type:

numpy.ndarray

get_objective_likelihood(params_fit, fit_keys, target_image, err_map)

Returns the log likelihood of the disk model as per the given parameters, target image, and error map.

Parameters:
  • params_fit (list of float) – List of parameter values that are set to replace the original values of the parameters specified by fit_keys.

  • fit_keys (list of str) – List of names of parameters that are set to be replaced by values in params_fit.

  • target_image (numpy.ndarray) – Target image that the log likelihood is being computed with.

  • err_map (numpy.ndarray) – Error map of same shape as target image that the log likelihood is being computed with.

Returns:

Log likelihood value of the current model.

Return type:

float
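The pairing of params_fit and fit_keys can be pictured as an overlay on the stored parameter dictionaries. The sketch below is only an illustration of that idea, not the library's code: apply_fit_values is a hypothetical helper, and the parameter names are invented for the example.

```python
import copy


def apply_fit_values(params_fit, fit_keys, *param_dicts):
    """Return copies of the dictionaries with fit_keys overridden by params_fit.

    Hypothetical helper for illustration; GRaTeR-JAX's internals may differ.
    """
    if len(params_fit) != len(fit_keys):
        raise ValueError("params_fit and fit_keys must have the same length")
    updated = [copy.deepcopy(d) for d in param_dicts]
    for key, value in zip(fit_keys, params_fit):
        for d in updated:
            if key in d:
                d[key] = value  # replace only the named parameter
                break
        else:
            raise KeyError(f"unknown parameter: {key}")
    return updated


# Illustrative parameter names and values (assumptions, not library defaults).
disk_params = {"inclination": 60.0, "sma": 40.0}
spf_params = {"g": 0.3}
new_disk, new_spf = apply_fit_values(
    [75.0, 0.5], ["inclination", "g"], disk_params, spf_params
)
```

The original dictionaries are left untouched, so the same stored state can be evaluated repeatedly with different trial values.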

get_gradient(keys, target_image, err_map)

Returns the gradient of the disk model for the given target image and error map for the parameters specified in keys. Note: Log-scaling is not handled in this method. Some parameters may not work.

Parameters:
  • keys (list of str) – Names of parameters that the gradient is being computed for.

  • target_image (numpy.ndarray) – Target image that the gradient is being computed with.

  • err_map (numpy.ndarray) – Error map of same shape as target image that the gradient is being computed with.

Returns:

Gradient of the specified values in keys, returned in the same order as keys.

Return type:

list

get_objective_gradient(params_fit, fit_keys, target_image, err_map)

Objective function that returns the gradient of the log likelihood of the disk model for the given target image and error map. Takes a list of parameters and their new values and outputs their gradients. Note: Log-scaling is not handled in this method. Some parameters may not work.

Parameters:
  • params_fit (list of float) – List of parameter values that are set to replace the original values of the parameters specified by fit_keys.

  • fit_keys (list of str) – Names of parameters that the gradient is being computed for.

  • target_image (numpy.ndarray) – Target image that the gradient is being computed with.

  • err_map (numpy.ndarray) – Error map of same shape as target image that the gradient is being computed with.

Returns:

Gradient of the specified values in fit_keys, returned in the same order as fit_keys.

Return type:

numpy.ndarray
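A returned gradient can be sanity-checked against a forward finite-difference approximation, which is the idea behind the check_grad function listed in the module summary. The following is a minimal NumPy sketch on a toy objective; toy_ll and toy_ll_grad are stand-ins invented for the example, not part of GRaTeR-JAX.

```python
import numpy as np


def forward_difference_gradient(func, x0, epsilon=1e-6):
    """Approximate the gradient of func at x0 with forward differences."""
    x0 = np.asarray(x0, dtype=float)
    f0 = func(x0)
    grad = np.empty_like(x0)
    for i in range(x0.size):
        step = np.zeros_like(x0)
        step[i] = epsilon
        grad[i] = (func(x0 + step) - f0) / epsilon
    return grad


# Toy quadratic standing in for a log-likelihood (assumption for illustration).
_PEAK = np.array([1.0, -2.0])


def toy_ll(x):
    return -0.5 * np.sum((x - _PEAK) ** 2)


def toy_ll_grad(x):
    return -(x - _PEAK)


x0 = np.array([0.5, 0.5])
max_abs_error = np.max(
    np.abs(forward_difference_gradient(toy_ll, x0) - toy_ll_grad(x0))
)
```

A small max_abs_error suggests the analytic gradient is consistent with the objective. A large one usually points to a scaling or ordering mismatch between the parameters and the gradient.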

get_disk()

Returns the disk model as per the current parameters, without the off-axis PSF or the stellar PSF applied. Note: the first call takes longer because the objective function must be JIT compiled again after the component classes change.

Returns:

2d image of the generated disk model.

Return type:

numpy.ndarray

get_values(keys)

Retrieves the current values of the specified model parameters.

Parameters:

keys (list of str) – List of parameter names to retrieve from the disk, SPF, PSF, stellar PSF, or misc parameter dictionaries.

Returns:

List of parameter values in the same order as the keys. If a key is not found, None is returned for that key.

Return type:

list

log_likelihood_pos(target_image, err_map)

Returns the negative log-likelihood of the current model compared to the target image.

Parameters:
  • target_image (numpy.ndarray) – 2D array representing the target image.

  • err_map (numpy.ndarray) – 2D array of the same shape as target_image, representing the error map.

Returns:

Negative log-likelihood of the model with respect to the given image and error map.

Return type:

float

log_likelihood(target_image, err_map)

Returns the log-likelihood of the current model compared to the target image.

Parameters:
  • target_image (numpy.ndarray) – 2D array representing the target image.

  • err_map (numpy.ndarray) – 2D array of the same shape as target_image, representing the error map.

Returns:

Log-likelihood of the model with respect to the given image and error map.

Return type:

float
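For orientation, one common form of a Gaussian log-likelihood between a model image and data is sketched below. This is an assumption about the general shape of the calculation: it drops the constant normalisation term, and the exact expression used by GRaTeR-JAX (including any NaN handling) is not stated here.

```python
import numpy as np


def gaussian_log_likelihood(image, target_image, err_map):
    """Chi-square style Gaussian log-likelihood, up to an additive constant.

    Sketch only: omits the -sum(log(sqrt(2*pi) * err_map)) term and does not
    treat NaN pixels specially.
    """
    residual = (np.asarray(target_image) - np.asarray(image)) / np.asarray(err_map)
    return -0.5 * float(np.sum(residual ** 2))


model = np.ones((4, 4))
data = model + 1.0             # every pixel is off by 1
errors = np.full((4, 4), 2.0)  # uniform per-pixel uncertainty
ll = gaussian_log_likelihood(model, data, errors)
```

With this sign convention a perfect match gives 0 and worse fits give increasingly negative values, so the negative log-likelihood is the quantity a minimizer would reduce.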

define_reference_images(reference_images)

Sets the reference images to be used for constructing the stellar PSF.

Parameters:

reference_images (list of numpy.ndarray) – List of 2D reference images to use for constructing the stellar PSF.

scipy_optimize(fit_keys, logscaled_params, array_params, target_image, err_map, disp_soln=False, iters=500, method=None, use_grad=False, scale=1.0, ftol=1e-40, gtol=1e-40, eps=1e-40, **kwargs)

Runs unconstrained optimization using SciPy’s minimize on the specified parameters to maximize the log-likelihood. Uses current parameter dictionary values.

Parameters:
  • fit_keys (list of str) – Names of parameters to fit.

  • logscaled_params (list of str) – Subset of fit_keys that are log-scaled.

  • array_params (list of str) – Subset of fit_keys that are array-valued parameters.

  • target_image (numpy.ndarray) – Target image to fit the model to.

  • err_map (numpy.ndarray) – Error map associated with the target image.

  • disp_soln (bool, optional) – If True, prints the optimization result after completion. Default is False.

  • iters (int, optional) – Maximum number of iterations. Default is 500.

  • method (str or None, optional) – Optimization method to use (e.g., 'BFGS', 'L-BFGS-B'). If None, defaults to method chosen by minimize.

  • use_grad (bool, optional) – Whether to use the analytical gradient. Default is False.

  • scale (float, optional) – Scaling factor applied to the objective value and gradient. Default is 1.

  • ftol (float, optional) – Tolerance for change in function value for convergence. Default is 1e-40.

  • gtol (float, optional) – Tolerance for the norm of the projected gradient. Default is 1e-40.

  • eps (float, optional) – Step size used for numerical approximation of the Jacobian (if use_grad is False). Default is 1e-40.

  • **kwargs (dict) – Additional keyword arguments (currently unused).

Returns:

soln – The result of the optimization. Includes optimized parameters, success status, and diagnostic information.

Return type:

scipy.optimize.OptimizeResult
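The roles of fit_keys, logscaled_params, and array_params can be illustrated by flattening a parameter dictionary into the single vector an optimizer works on. The sketch below is a guess at the general pattern rather than the library's implementation: pack_parameters and unpack_parameters are hypothetical helpers, and the use of the natural logarithm for log-scaled parameters is an assumption.

```python
import numpy as np


def pack_parameters(values, fit_keys, logscaled_params, array_params):
    """Flatten the selected parameters into one vector (hypothetical helper)."""
    pieces, lengths = [], []
    for key in fit_keys:
        v = np.atleast_1d(np.asarray(values[key], dtype=float))
        if key not in array_params and v.size != 1:
            raise ValueError(f"{key} is array-valued but not in array_params")
        if key in logscaled_params:
            v = np.log(v)  # assumption: natural log
        pieces.append(v)
        lengths.append(v.size)
    return np.concatenate(pieces), lengths


def unpack_parameters(x, fit_keys, logscaled_params, array_params, lengths):
    """Invert pack_parameters, undoing the log scaling."""
    out, start = {}, 0
    for key, n in zip(fit_keys, lengths):
        v = np.asarray(x[start:start + n], dtype=float)
        start += n
        if key in logscaled_params:
            v = np.exp(v)
        out[key] = v if key in array_params else float(v[0])
    return out


# Illustrative values; the names mirror ones mentioned in this reference.
values = {"flux_scaling": 1e4, "inclination": 60.0, "knot_values": [1.0, 0.8, 0.6]}
fit_keys = ["flux_scaling", "inclination", "knot_values"]
x, lengths = pack_parameters(values, fit_keys, ["flux_scaling"], ["knot_values"])
restored = unpack_parameters(x, fit_keys, ["flux_scaling"], ["knot_values"], lengths)
```

Fitting a strictly positive quantity such as a flux scale in log space keeps it positive and evens out its dynamic range relative to the other parameters.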

scipy_bounded_optimize(fit_keys, fit_bounds, logscaled_params, array_params, target_image, err_map, disp_soln=False, iters=500, ftol=1e-12, gtol=1e-12, eps=1e-08, scale=1.0, use_grad=False, maxls=20, **kwargs)

Runs bounded optimization using SciPy’s minimize with the L-BFGS-B algorithm to maximize the log-likelihood. Uses current parameter dictionary values.

Parameters:
  • fit_keys (list of str) – Names of parameters to fit.

  • fit_bounds (tuple of list of float) – Tuple (lower_bounds, upper_bounds) giving bounds for each parameter. Each bound can be scalar or array-like.

  • logscaled_params (list of str) – Subset of fit_keys that are log-scaled.

  • array_params (list of str) – Subset of fit_keys that are array-valued parameters.

  • target_image (numpy.ndarray) – Target image to fit the model to.

  • err_map (numpy.ndarray) – Error map associated with the target image.

  • disp_soln (bool, optional) – If True, prints the optimization result after completion. Default is False.

  • iters (int, optional) – Maximum number of iterations. Default is 500.

  • ftol (float, optional) – Tolerance for change in function value for convergence. Default is 1e-12.

  • gtol (float, optional) – Tolerance for the norm of the projected gradient. Default is 1e-12.

  • eps (float, optional) – Step size used for numerical approximation of the Jacobian (if use_grad is False). Default is 1e-8.

  • scale (float, optional) – Scaling factor applied to the objective value and gradient. Default is 1.

  • use_grad (bool, optional) – Whether to use the analytical gradient. Default is False.

  • maxls (int, optional) – Maximum number of line search steps per iteration for L-BFGS-B. Default is 20 (scipy default). Increase (e.g. to 50–100) when ABNORMAL_TERMINATION_IN_LNSRCH is reported at low iteration counts.

  • **kwargs (dict) – Additional keyword arguments (currently unused).

Returns:

soln – The result of the optimization. Includes optimized parameters, success status, and diagnostic information.

Return type:

scipy.optimize.OptimizeResult
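As a point of reference, the sketch below runs SciPy's L-BFGS-B directly on a toy negative log-likelihood with a (lower_bounds, upper_bounds) pair. How scipy_bounded_optimize forwards its arguments to SciPy internally is an assumption here. The option names maxiter, ftol, gtol, and maxls are SciPy's own, and neg_ll is an invented stand-in objective.

```python
import numpy as np
from scipy.optimize import minimize

_PEAK = np.array([2.0, -1.0])  # unconstrained optimum of the toy problem


def neg_ll(x):
    """Toy negative log-likelihood (stand-in for the disk-model objective)."""
    return 0.5 * np.sum((x - _PEAK) ** 2)


def neg_ll_grad(x):
    """Analytic gradient of neg_ll, playing the role of use_grad=True."""
    return x - _PEAK


lower_bounds = [0.0, 0.0]
upper_bounds = [1.5, 3.0]

soln = minimize(
    neg_ll,
    x0=np.array([0.5, 0.5]),
    jac=neg_ll_grad,
    method="L-BFGS-B",
    bounds=list(zip(lower_bounds, upper_bounds)),  # SciPy wants (low, high) pairs
    options={"maxiter": 500, "ftol": 1e-12, "gtol": 1e-12, "maxls": 20},
)
```

Because the unconstrained optimum lies outside the box, the solution lands on the boundary at (1.5, 0.0). A fitted parameter sitting exactly on one of its bounds is worth inspecting for the same reason.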

mcmc(fit_keys, logscaled_params, array_params, target_image, err_map, BOUNDS, nwalkers=250, niter=250, burns=50, resume=False, scale_objective_function_for_shape=False, **kwargs)

Runs Markov Chain Monte Carlo (MCMC) sampling using an emcee-based sampler to estimate posterior distributions for model parameters based on the log-likelihood function. Uses current parameter dictionary values.

Parameters:
  • fit_keys (list of str) – Names of parameters to optimize using MCMC.

  • logscaled_params (list of str) – Subset of fit_keys whose values are log-scaled.

  • array_params (list of str) – Subset of fit_keys that correspond to array-valued parameters.

  • target_image (numpy.ndarray) – 2D array of the observed image data to compare against the model.

  • err_map (numpy.ndarray) – 2D array of the same shape as target_image representing per-pixel uncertainties.

  • BOUNDS (tuple of (list, list)) – Tuple of (lower_bounds, upper_bounds) for each parameter in fit_keys. Each bound can be a scalar or a list if the parameter is an array.

  • nwalkers (int, optional) – Number of MCMC walkers. Default is 250.

  • niter (int, optional) – Total number of MCMC steps per walker. Default is 250.

  • burns (int, optional) – Number of burn-in steps to discard from each chain. Default is 50.

  • resume (bool, optional) – If True, continue sampling from the existing HDF5 backend, extending the chain by niter steps without re-running burn-in. If False (default), start a fresh run and overwrite the existing backend. Pass confirm_overwrite=False via **kwargs to suppress the interactive overwrite prompt in scripted or notebook contexts.

  • scale_objective_function_for_shape (bool, optional) – If True, scales the log-likelihood by the number of pixels in target_image to normalize shape differences. Default is False.

  • **kwargs (dict) – Additional arguments passed to the underlying MCMC_model.run() method. Notable: confirm_overwrite=False suppresses the interactive prompt when overwriting an existing backend (required in non-interactive contexts).

Returns:

The fitted MCMC_model object containing the sampler chain, statistics, and results.

Return type:

MCMC_model

Raises:

Exception – If the initial guess is out of the specified parameter bounds.
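Bounds in an emcee-style sampler are commonly enforced as a flat prior: the log-probability is minus infinity outside the box and equal to the log-likelihood inside it. The sketch below shows that pattern together with an initial-guess check. Both functions are hypothetical illustrations, and the exception type raised by the library may differ from the ValueError used here.

```python
import numpy as np


def check_initial_guess(theta0, lower, upper):
    """Raise if any starting value lies outside its bounds (illustrative)."""
    theta0 = np.asarray(theta0, dtype=float)
    lower = np.asarray(lower, dtype=float)
    upper = np.asarray(upper, dtype=float)
    outside = (theta0 < lower) | (theta0 > upper)
    if np.any(outside):
        bad = np.flatnonzero(outside).tolist()
        raise ValueError(f"initial guess outside bounds at indices {bad}")


def bounded_log_prob(theta, lower, upper, log_likelihood):
    """Flat prior inside the bounds, zero probability outside."""
    theta = np.asarray(theta, dtype=float)
    lower = np.asarray(lower, dtype=float)
    upper = np.asarray(upper, dtype=float)
    if np.any(theta < lower) or np.any(theta > upper):
        return -np.inf
    return log_likelihood(theta)
```

Walkers that propose a step outside the bounds receive a log-probability of minus infinity, so the step is always rejected and the chain stays inside the box.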

inc_bound_knots(buffer=0)

Applies inclination bounds to the knot values in order to improve fitting. Note: only use this if you have a good initial inclination guess.

Parameters:

buffer (int, optional) – Degrees of inclination to be added to the knot bounds for padding. Default is 0.

Returns:

Copy of the SPF parameters for easier access. Can usually be ignored.

Return type:

dict

warm_start_knots(init_g=0.5)

Initialise spline SPF knot values from a normalised Henyey-Greenstein shape.

The spline SPF is normalised to 1 at 90 degrees (cos_phi = 0) by construction, so the center knot is fixed and only the num_knots free knots are set here. Warm-starting from an HG shape (rather than all-ones) avoids a flat-spline local minimum that appears at lower inclinations where the probed scattering-angle range is narrow. init_g is an assumed asymmetry parameter used only to generate the initial guess — it is not a constraint on the fit.

Call inc_bound_knots before this method if you want the knot positions to reflect the actual scattering angles probed by the disk inclination.

Parameters:

init_g (float, optional) – Henyey-Greenstein asymmetry parameter used for initialisation (default 0.5). Positive values produce forward-scattering shapes; negative values produce back-scattering shapes. The value is not constrained during fitting.

Returns:

The initialised free knot values (also stored in self.spf_params['knot_values']).

Return type:

jnp.ndarray
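The warm-start idea can be sketched with the standard single Henyey-Greenstein phase function evaluated at a set of knot positions and normalised to 1 at cos_phi = 0. The knot positions below are an assumption for illustration, and the sketch includes the centre knot that the library keeps fixed. It is not the library's implementation.

```python
import numpy as np


def hg_phase_function(cos_phi, g):
    """Standard Henyey-Greenstein phase function."""
    return (1.0 - g ** 2) / (4.0 * np.pi * (1.0 + g ** 2 - 2.0 * g * cos_phi) ** 1.5)


def hg_knot_values(knot_cos_phi, init_g=0.5):
    """HG values at the knots, normalised to 1 at 90 degrees (cos_phi = 0)."""
    knot_cos_phi = np.asarray(knot_cos_phi, dtype=float)
    return hg_phase_function(knot_cos_phi, init_g) / hg_phase_function(0.0, init_g)


# Assumed knot positions in cos(phi); the real positions depend on the
# inclination and on the SPF configuration.
knots = np.linspace(-0.9, 0.9, 7)
values = hg_knot_values(knots, init_g=0.5)
```

For a positive init_g the values rise towards cos_phi = 1 (forward scattering), which gives the fit a non-flat starting shape.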

estimate_flux_scaling(target_image, inner_mask_radius=12)

Estimate a starting flux_scaling by matching the model peak to the target image.

The model (at flux_scaling = 1) is evaluated once and its peak is compared to the 99th-percentile brightness of unmasked pixels in the target image. The ratio gives a robust starting estimate that avoids bias from background pixels.

Note

This is one reasonable heuristic for initialising flux_scaling, not the only valid approach. Users are responsible for verifying that the estimate is sensible for their data.

Parameters:
  • target_image (numpy.ndarray) – The science image to match against.

  • inner_mask_radius (float, optional) – Pixels within this radius (in pixels) of the image centre are excluded from the 99th-percentile calculation to avoid contamination from a coronagraphic mask or stellar residuals (default 12).

Returns:

The estimated flux_scaling (also stored in self.misc_params['flux_scaling']).

Return type:

float
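The heuristic described above can be sketched in a few lines of NumPy. This is an illustration under stated assumptions, not the method's source: the function name is hypothetical, the model image at flux_scaling = 1 is passed in explicitly, and the image centre is taken to be the geometric centre of the array.

```python
import numpy as np


def estimate_flux_scaling_sketch(model_unit_flux, target_image, inner_mask_radius=12):
    """Ratio of the target's 99th-percentile brightness to the model peak."""
    target_image = np.asarray(target_image, dtype=float)
    ny, nx = target_image.shape
    yy, xx = np.indices((ny, nx))
    cy, cx = (ny - 1) / 2.0, (nx - 1) / 2.0  # assumed image centre
    unmasked = np.hypot(yy - cy, xx - cx) > inner_mask_radius
    target_level = np.nanpercentile(target_image[unmasked], 99)
    return float(target_level / np.nanmax(model_unit_flux))


# Synthetic check: the target is the unit-flux model scaled by 250.
rng = np.random.default_rng(0)
unit_model = rng.random((64, 64))
estimate = estimate_flux_scaling_sketch(unit_model, 250.0 * unit_model)
```

The estimate comes out slightly below the true factor because a 99th percentile sits just under the peak. That is acceptable for a starting value that the optimizer will refine.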

compute_stellar_psf_image()

Returns an image of the stellar PSF model.

Returns:

2D array of the stellar PSF image.

Return type:

numpy.ndarray

set_empirical_psf(image)

Sets the empirical psf image to the given image.

Parameters:

image (numpy.ndarray) – 2D array of the empirical PSF image.

fix_all_nonphysical_params()

Catch-all method that clamps every parameter holding a non-physical value to its maximum or minimum allowed value.

print_params()

Prints parameters to console.

plot_spline(num_points=100)

Plot the current spline SPF as a function of cos(phi).

Parameters:

num_points (int, optional) – Number of evaluation points along the cos(phi) axis. Default is 100.

Returns:

Figure containing the SPF curve with knot positions marked.

Return type:

matplotlib.figure.Figure

get_flux_scale()

Returns flux scaling value.

Returns:

Flux scaling value.

Return type:

float

initialize_stellar_psf_weights(target_image, rcond=1e-08)

Initialize the weights for the stellar PSF reference images by solving a linear least squares problem to best match the target image.

This method computes a weighted linear combination of the reference PSF images that best fits the target image in a least-squares sense. The resulting weights are stored in self.stellar_psf_params['stellar_weights'].

Parameters:
  • target_image (ndarray) – The 2D target image that the stellar PSF model should fit.

  • rcond (float, optional) – Cut-off ratio for small singular values in the least squares solver. Values smaller than rcond * largest_singular_value are treated as zero. Default is 1e-8.

Return type:

None
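The least-squares step can be sketched with numpy.linalg.lstsq: each reference image becomes one column of the design matrix and the flattened target is the right-hand side. The function name is hypothetical, and details such as NaN or mask handling in the library are not covered by this sketch.

```python
import numpy as np


def fit_reference_weights(reference_images, target_image, rcond=1e-8):
    """Least-squares weights so that sum(w_i * ref_i) approximates the target."""
    # Design matrix: one flattened reference image per column.
    design = np.stack([np.ravel(ref) for ref in reference_images], axis=1)
    rhs = np.ravel(target_image)
    weights, *_ = np.linalg.lstsq(design, rhs, rcond=rcond)
    return weights


# Synthetic check with known weights.
rng = np.random.default_rng(1)
refs = [rng.normal(size=(8, 8)) for _ in range(3)]
true_weights = np.array([0.5, -1.0, 2.0])
target = sum(w * r for w, r in zip(true_weights, refs))
recovered = fit_reference_weights(refs, target)
```

The rcond cut-off matters when the reference images are nearly linearly dependent, because it stops tiny singular values from producing very large, unstable weights.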

save_human_readable(dirname)

Save all current model parameters to a human-readable text file.

Parameters:

dirname (str) – Directory in which to write the file. The filename is derived from self.name and self.last_fit.

save_machine_readable(dirname)

Save all current model parameters to JSON files for later reloading.

Writes four separate JSON files (disk, SPF, PSF, misc params) into dirname. Array-valued parameters are serialised as lists.

Parameters:

dirname (str) – Directory in which to write the JSON files. Filenames are derived from self.name and self.last_fit.

load_machine_readable(dirname, method=None)

Load model parameters from previously saved JSON files.

Parameters:
  • dirname (str) – Directory containing the JSON files to load.

  • method (str, optional) – Fitting method tag used in the filename ('scipyminimize', 'scipyboundminimize', or 'mcmc'). Defaults to self.last_fit.

class grater_jax.optimization.optimize_framework.OptimizeUtils

Bases: object

classmethod create_empirical_err_map(data, annulus_width=5, mask_rad=9, outlier_pixels=None)

Constructs an empirical error map based on radial annuli standard deviation.

Parameters:
  • data (numpy.ndarray) – 2D image data array.

  • annulus_width (int, optional) – Width (in pixels) of each radial annulus. Default is 5.

  • mask_rad (int, optional) – Radius (in pixels) within which the error is artificially set to a large value. Default is 9.

  • outlier_pixels (list of tuple, optional) – List of pixel coordinates (tuples) to mark as extreme outliers with inflated error values.

Returns:

noise_array – 2D array of the same shape as data, representing the empirical error map.

Return type:

numpy.ndarray
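An annulus-based error map can be sketched as follows. The function name, the choice of the geometric array centre, and the specific large value used inside mask_rad are assumptions for illustration, and outlier_pixels handling is left out.

```python
import numpy as np


def annular_error_map(data, annulus_width=5, mask_rad=9, large_value=1e10):
    """Per-pixel error equal to the standard deviation of its radial annulus."""
    data = np.asarray(data, dtype=float)
    ny, nx = data.shape
    yy, xx = np.indices((ny, nx))
    radius = np.hypot(yy - (ny - 1) / 2.0, xx - (nx - 1) / 2.0)

    noise = np.zeros_like(data)
    for r_in in np.arange(0.0, radius.max() + annulus_width, annulus_width):
        in_annulus = (radius >= r_in) & (radius < r_in + annulus_width)
        if np.any(in_annulus):
            noise[in_annulus] = np.nanstd(data[in_annulus])

    # Down-weight the central region by assigning it a very large error.
    noise[radius < mask_rad] = large_value
    return noise


rng = np.random.default_rng(2)
image = rng.normal(0.0, 2.0, size=(41, 41))
err_map = annular_error_map(image)
```

Pixels inside mask_rad then contribute almost nothing to a chi-square style likelihood, which is the usual reason for inflating their error.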

classmethod convert_dhg_params_to_spline_params(g1, g2, w, spf_params)

Converts Double Henyey-Greenstein parameters into spline phase function values.

Parameters:
  • g1 (float) – First asymmetry parameter.

  • g2 (float) – Second asymmetry parameter.

  • w (float) – Weighting factor between the two lobes.

  • spf_params (dict) – SPF parameters containing knot information for the spline.

Returns:

spline_vals – Evaluated spline phase function values.

Return type:

numpy.ndarray

classmethod process_image(image, scale_factor=1, bounds=(70, 210, 70, 210))

Crops, scales, and safely converts an image to float32.

Parameters:
  • image (numpy.ndarray) – Input image array.

  • scale_factor (int, optional) – Factor to downsample the image spatially. Default is 1 (no scaling).

  • bounds (tuple of int, optional) – A 4-element tuple defining the cropping window as (ymin, ymax, xmin, xmax). Default is (70, 210, 70, 210).

Returns:

fin_image – Processed and cropped image array in float32 format with NaNs converted to 0.

Return type:

numpy.ndarray
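The crop, downsample, and float32 conversion steps can be sketched as below. The downsampling method is not specified in this reference, so the block-mean used here is an assumption, as are the function name and the order of operations.

```python
import numpy as np


def process_image_sketch(image, scale_factor=1, bounds=(70, 210, 70, 210)):
    """Crop to bounds, replace NaNs with 0, optionally block-average, cast to float32."""
    ymin, ymax, xmin, xmax = bounds
    cropped = np.asarray(image, dtype=np.float32)[ymin:ymax, xmin:xmax]
    cropped = np.nan_to_num(cropped, nan=0.0)

    if scale_factor > 1:
        s = int(scale_factor)
        # Trim so both axes divide evenly, then average s-by-s blocks.
        ny = (cropped.shape[0] // s) * s
        nx = (cropped.shape[1] // s) * s
        cropped = cropped[:ny, :nx].reshape(ny // s, s, nx // s, s).mean(axis=(1, 3))

    return cropped.astype(np.float32)


raw = np.arange(280 * 280, dtype=float).reshape(280, 280)
raw[100, 100] = np.nan
full_res = process_image_sketch(raw)
half_res = process_image_sketch(raw, scale_factor=2)
```

With the default bounds a 280 by 280 frame is cropped to 140 by 140 pixels, and a scale factor of 2 then yields 70 by 70.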

classmethod get_mask(data, annulus_width=5, mask_rad=9, outlier_pixels=None)

Generates a binary mask identifying the central masked region of an image.

Parameters:
  • data (numpy.ndarray) – 2D image array.

  • annulus_width (int, optional) – Width of annuli used for noise estimation (not returned). Default is 5.

  • mask_rad (int, optional) – Radius of the central region to mask out (set to True). Default is 9.

  • outlier_pixels (list of tuple, optional) – Currently unused in mask generation but included for consistency.

Returns:

mask – Boolean array of same shape as input, where True indicates masked (excluded) pixels.

Return type:

numpy.ndarray

classmethod unlogscale_mcmc_model(mc_model, fit_keys, logscaled_params, array_params, array_lengths)

Unlog-scales MCMC chain values in-place for parameters originally fit in log space.

Parameters:
  • mc_model (MCMC_model) – Instance of MCMC_model whose chain will be modified.

  • fit_keys (list of str) – Names of the parameters in the chain.

  • logscaled_params (list of str) – Subset of fit_keys that were log-scaled during sampling.

  • array_params (list of str) – Subset of fit_keys that are array-valued.

  • array_lengths (list of int) – Length of each parameter in flattened form (1 for scalars, >1 for arrays).

classmethod scale_spline_chains(mc_model, fit_keys, array_params, array_lengths, scale_factor)

Applies normalization scaling to spline knots and compensates flux scaling if present in MCMC chains.

Parameters:
  • mc_model (MCMC_model) – The MCMC model object containing the chain.

  • fit_keys (list of str) – Names of parameters in the MCMC chain.

  • array_params (list of str) – Parameters treated as arrays (e.g., spline knots).

  • array_lengths (list of int) – Flattened size of each parameter.

  • scale_factor (float) – Value to scale knot_values by. If flux_scaling is present, it will be inversely scaled.
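The compensation between knot_values and flux_scaling can be sketched on a plain array whose last axis holds the flattened parameters. This is an illustration only: the function name is hypothetical, it returns a copy rather than editing an MCMC_model in place, and it assumes flux_scaling is stored linearly (not log-scaled) in the chain.

```python
import numpy as np


def scale_spline_chain(chain, fit_keys, array_lengths, scale_factor):
    """Multiply knot columns by scale_factor and divide flux_scaling by it."""
    chain = np.array(chain, dtype=float)  # work on a copy
    start = 0
    for key, n in zip(fit_keys, array_lengths):
        cols = slice(start, start + n)
        if key == "knot_values":
            chain[..., cols] *= scale_factor
        elif key == "flux_scaling":
            chain[..., cols] /= scale_factor
        start += n
    return chain


# Toy chain: 10 steps, 4 walkers, 1 flux_scaling column + 3 knot columns.
toy_chain = np.ones((10, 4, 4))
toy_chain[..., 0] = 8.0
scaled = scale_spline_chain(
    toy_chain, ["flux_scaling", "knot_values"], [1, 3], scale_factor=2.0
)
```

The product of flux_scaling and each knot value is unchanged by the operation, so the modelled surface brightness stays the same while the knots are re-normalised.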