Disk Modeling Basics

Introduction

This notebook introduces the disk modeling workflow in GRaTeR-JAX.
It shows how the framework builds scattered-light debris disk images from modular components that represent different physical and instrumental effects.

The model is structured into the following pieces:

  • Dust Distribution – specifies the 3D density of dust grains (e.g., radial power-law slopes, eccentricity, vertical scale height).

  • Scattering Phase Function (SPF) – controls how starlight is scattered by dust as a function of angle (Henyey–Greenstein, double HG, or spline-based models).

  • Point Spread Function (PSF) – applies the instrument response by convolving the model image with a chosen PSF kernel (e.g., Gaussian or empirical).

The notebook demonstrates how to configure these components, run them through the objective_model pipeline, and generate synthetic disk images that can be compared directly with observations. The analytic model definitions are also summarized below.

Defining the Dust Distribution

GRaTeR-JAX follows the same dust distribution parameterization as the disk forward-modeling module in VIP, which is in turn based on the framework introduced in Augereau+ 1999. The density distribution is expressed in cylindrical coordinates and is the product of three terms: the density \(\rho_0\) at the reference radius in the midplane (\(z\)=0), the radial density profile in the midplane, and the vertical profile:

\[\rho(r,\theta,z) = \rho_0 \times \left( \frac{2}{\left(\frac{r}{R(\theta)}\right)^{-2\alpha_{\rm in}} + \left(\frac{r}{R(\theta)}\right)^{-2\alpha_{\rm out}}}\right)^{1/2} \times e^{-\left(\frac{|z|}{H(r)}\right)^\gamma}\]

where:

  • \(\alpha_{\rm in}\) and \(\alpha_{\rm out}\) are the inner and outer power-law exponents of the radial dust distribution (\(\alpha_{\rm in} > 0\) and \(\alpha_{\rm out} < 0\), so the density peaks near \(R(\theta)\))

  • \(R(\theta)\) is the reference radius

  • \(a\) is the semimajor axis

  • \(e\) is the eccentricity

  • \(\gamma\) is the decay exponent (\(\gamma\)=2 corresponds to a Gaussian profile)

  • \(H(r)\) is the scale height, defined below, where \(\xi_0\) is the scale height at the reference radius and \(\beta\) is the flaring exponent (\(\beta\)=1 corresponds to linear flaring):

\[H(r) = \xi_0 \times \left(\frac{r}{R(\theta)}\right)^\beta\]

Note that for an eccentric disk the reference radius depends on the azimuthal angle \(\theta\), measured from pericenter: \(R(\theta) = \frac{a(1-e^2)}{1+e\cos\theta}\).
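To make the parameterization concrete, here is a plain-Python sketch of the density law in the form \(\rho_0 \left(2 / (x^{-2\alpha_{\rm in}} + x^{-2\alpha_{\rm out}})\right)^{1/2} e^{-(|z|/H)^\gamma}\) with \(x = r/R(\theta)\). It is written for readability, not speed, and is not GRaTeR-JAX's implementation, which operates on JAX arrays. It assumes θ is measured from pericenter and that α_in > 0 and α_out < 0, consistent with the alpha_in = 5 and alpha_out = -5 values printed later in this notebook. The function names are hypothetical.

```python
import math


def reference_radius(theta, a, e):
    """Elliptical reference radius R(theta) = a(1 - e^2) / (1 + e cos(theta))."""
    return a * (1.0 - e**2) / (1.0 + e * math.cos(theta))


def dust_density(r, theta, z, a, e, alpha_in, alpha_out,
                 ksi0, beta, gamma, rho0=1.0):
    """Two-power-law dust density rho(r, theta, z) (illustrative sketch).

    Assumes alpha_in > 0 and alpha_out < 0, so the density peaks near R(theta).
    """
    R = reference_radius(theta, a, e)
    x = r / R
    # Radial term: smoothly joins the inner (r << R) and outer (r >> R) power laws
    radial = math.sqrt(2.0 / (x ** (-2.0 * alpha_in) + x ** (-2.0 * alpha_out)))
    # Vertical term with flared scale height H(r) = ksi0 * (r / R)^beta
    H = ksi0 * x ** beta
    vertical = math.exp(-(abs(z) / H) ** gamma)
    return rho0 * radial * vertical


# Default-like values: a=40, e=0, alpha_in=5, alpha_out=-5, ksi0=3, beta=1, gamma=2
peak = dust_density(40.0, 0.0, 0.0, a=40.0, e=0.0, alpha_in=5, alpha_out=-5,
                    ksi0=3.0, beta=1.0, gamma=2.0)  # -> 1.0 (r = R, midplane)
```

At \(r = R(\theta)\) both radial terms equal 1, so the radial factor is exactly 1. One scale height above the midplane with \(\gamma = 2\), the density drops by a factor of \(e\).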

Scattering Phase Functions

The scattering phase function (SPF) describes how much light a particle (e.g., a dust grain) scatters at each scattering angle \(\theta_{\rm sca}\). The angle ranges from 0° (forward scattering) to 180° (back scattering). Dust in debris disks is generally forward scattering.

The viewing geometry limits the range of scattering angles probed in a given observation. GRaTeR-JAX incorporates this limit by using an initial guess of the inclination \(i\) to bound the scattering angles in the model’s SPF: \(\theta_{\rm sca,min}\approx90^\circ-i\) and \(\theta_{\rm sca,max}\approx90^\circ+i\).
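The spline SPF parameters (backscatt_bound and forwardscatt_bound) express these angular bounds as cosines. The helper below is a hypothetical illustration that converts an inclination into those cosines. It is not a GRaTeR-JAX function.

```python
import math


def scattering_angle_range(inclination_deg):
    """Approximate scattering-angle range probed at a given inclination.

    Uses theta_min ~ 90 - i and theta_max ~ 90 + i (degrees) and also returns
    their cosines, which is how the spline SPF bounds are expressed.
    """
    theta_min = 90.0 - inclination_deg
    theta_max = 90.0 + inclination_deg
    forward_cos = math.cos(math.radians(theta_min))   # forward-scattering bound
    backward_cos = math.cos(math.radians(theta_max))  # back-scattering bound
    return theta_min, theta_max, forward_cos, backward_cos


theta_min, theta_max, fwd, back = scattering_angle_range(60.0)
```

At the 60° inclination used in this notebook, the cosines are about ±0.866. These sit slightly inside the ±0.9 bounds chosen for the spline fit later on.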

GRaTeR-JAX provides three options for the SPF:

  1. Henyey-Greenstein (HG): an analytic function of the scattering angle \(\theta_{\rm sca}\) and the asymmetry parameter \(g\). Isotropic scattering corresponds to \(g = 0\), forward scattering to \(0 < g \le 1\), and back scattering to \(-1 \le g < 0\) (Henyey and Greenstein 1941).

    \[I(\theta_{\rm sca}) = \frac{1}{4\pi} \frac{1-g^2}{\left(1+g^2-2g\cos\theta_{\rm sca}\right)^{3/2}}\]
  2. Double Henyey-Greenstein: a weighted combination of two HG functions. Its free parameters are the first asymmetry parameter \(g_1\), the second asymmetry parameter \(g_2\), and the relative weight of the two components.

  3. Spline: a more flexible, non-analytic SPF. A user-specified number of knots \(n_k\) is fit at scattering angles between bounds set by the approximate inclination of the disk. The absolute values of the knots are degenerate with the overall flux scaling of the disk. The SPF is therefore normalized to 1 at 90°, and the flux scaling factor absorbs the difference. See the spline injection and recovery tutorial notebook for more information on the assumptions in this formalism.

Note: \(n_k\) is the number of free knots. It does not count the fixed knot at \(\cos\theta_{\rm sca} = 0\) (90°), whose value is fixed to 1. If \(n_k\) is odd, the extra knot goes on the forward-scattering side of the spline.
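The two analytic SPFs can be written in a few lines of plain Python. This sketch is for intuition only and is not GRaTeR-JAX's implementation. It assumes the DHG weight multiplies the \(g_1\) component and (1 - weight) multiplies the \(g_2\) component; check DoubleHenyeyGreenstein_SPF for the convention it actually uses. The parameter values in the example are illustrative.

```python
import math


def henyey_greenstein(cos_theta, g):
    """Single Henyey-Greenstein phase function I(theta_sca) as written above."""
    return (1.0 - g**2) / (4.0 * math.pi * (1.0 + g**2 - 2.0 * g * cos_theta) ** 1.5)


def double_henyey_greenstein(cos_theta, g1, g2, weight):
    """Weighted sum of two HG functions.

    Assumption: `weight` multiplies the g1 component and (1 - weight) the g2
    component; the library's convention may differ.
    """
    return (weight * henyey_greenstein(cos_theta, g1)
            + (1.0 - weight) * henyey_greenstein(cos_theta, g2))


# Normalize to 1 at 90 degrees (cos = 0), mirroring the spline SPF convention
dhg_90 = double_henyey_greenstein(0.0, 0.6, -0.3, 0.8)
normalized_forward = double_henyey_greenstein(0.9, 0.6, -0.3, 0.8) / dhg_90
```

With \(g = 0\), the HG function reduces to the isotropic value \(1/4\pi\). Any single HG function integrates to 1 over the sphere.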

Setup

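First, import the relevant components of GRaTeR-JAX and outside packages. The environment variable XLA_PYTHON_CLIENT_MEM_FRACTION limits the fraction of GPU memory that JAX preallocates, which helps avoid out-of-memory errors. It is set before importing JAX so that it takes effect. The jax_enable_x64 setting at the end of the cell enables double-precision (64-bit) floats.

Note: This notebook was run on an NVIDIA GeForce RTX 4090.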

[1]:
import os

# Can limit JAX/XLA to using at most a fraction of total GPU memory to help
# avoid out-of-memory issues when sharing the GPU with other processes.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.15"

import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import numpy as np
from astropy.io import fits

from grater_jax.disk_model.objective_functions import objective_model, log_likelihood, Parameter_Index
from grater_jax.disk_model.SLD_ojax import ScatteredLightDisk
from grater_jax.optimization.optimize_framework import Optimizer
from grater_jax.disk_model.SLD_utils import (
    DoubleHenyeyGreenstein_SPF,
    InterpolatedUnivariateSpline_SPF,
    EMP_PSF,
    DustEllipticalDistribution2PowerLaws,
)

jax.config.update("jax_enable_x64", True)
# jax.config.update("jax_debug_nans", True) <-- can use this to trace issues at the cost of runtime

Generating Disk Models with Simple PSFs

This example notebook uses an empirically measured PSF from Gemini Planet Imager H-band data (Macintosh+ 2014, Crotts+ 2025). More complex PSFs, such as those from JWST, are explored in the JWST Tutorial (Experimental).

[2]:
# Load the empirical GPI H-band PSF (first slice of the cube) and crop a
# 141 x 141 pixel cutout
image = fits.getdata("tutorial_images/GPI_Hband_PSF.fits", ext=0)[0, :, :]
cropped_image = image[70:211, 70:211]

# Replace NaNs with zeros and cast to float32 so the array can be used in JAX
emp_psf_image = np.nan_to_num(cropped_image).astype(np.float32)

Here, we generate a simple disk model with specified parameters. There are four dictionaries: disk_params, spf_params, psf_params, and misc_params. These contain the following parameters relevant to a user:

Disk Parameters (disk_params):

  • inner radial density exponent \(\alpha_{\rm in}\) (alpha_in)

  • outer radial density exponent \(\alpha_{\rm out}\) (alpha_out)

  • semimajor axis of the disk \(a\) (sma)

  • disk eccentricity \(e\) (e)

  • scale height at the reference radius \(\xi_0\) (ksi0)

  • vertical decay exponent \(\gamma\) (gamma)

  • flaring exponent \(\beta\) (beta)

  • disk inclination \(i\) (inclination)

  • disk position angle \(\mathrm{PA}\) (position_angle)

  • x position of the disk center in pixels \(x_c\) (x_center)

  • y position of the disk center in pixels \(y_c\) (y_center)

  • argument of pericenter \(\omega\) (omega)

SPF Parameters (spf_params):

If Single HG (HenyeyGreenstein_SPF):

  • asymmetry parameter \(g\) (g)

If Double HG (DoubleHenyeyGreenstein_SPF):

  • first asymmetry parameter \(g_1\) (g1)

  • second asymmetry parameter \(g_2\) (g2)

  • relative weight (weight)

If Spline (InterpolatedUnivariateSpline_SPF):

  • number of knots \(n_k\) (num_knots)

  • array of knot \(y\) values (knot_values)

  • back-scattering bound, given as the cosine of the scattering angle; defaults to -1 (180°) but can be set from the inclination (backscatt_bound)

  • forward-scattering bound, given as the cosine of the scattering angle; defaults to 1 (0°) but can be set from the inclination (forwardscatt_bound)

Miscellaneous Parameters (misc_params):

  • distance to target (distance)

  • pixel scale in arcsec/pix (pxInArcsec)

  • number of pixels in x dimension (nx)

  • number of pixels in y dimension (ny)

  • number of slices computed in the model; increase it if you encounter computational artifacts, at the cost of longer runtime (halfNbSlices)

  • arbitrary flux scaling (flux_scaling)
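Distance, pixel scale, and sma together determine where the disk lands on the image grid. The sketch below converts a physical length to pixels with the small-angle relation (1 au at 1 pc subtends 1 arcsec). It assumes, following VIP's convention, that sma is in au and distance in pc; this notebook does not state the units explicitly. au_to_pixels is a hypothetical helper.

```python
def au_to_pixels(length_au, distance_pc, pixel_scale_arcsec):
    """Convert a physical length to detector pixels.

    Assumes length in au and distance in pc, so that
    angle [arcsec] = length [au] / distance [pc].
    """
    angle_arcsec = length_au / distance_pc
    return angle_arcsec / pixel_scale_arcsec


# sma = 40 with distance = 50 and pxInArcsec = 0.01414 (values used in this notebook)
ring_radius_px = au_to_pixels(40.0, 50.0, 0.01414)
```

Under these assumptions, the ring radius is about 57 pixels. That fits inside the 140 × 140 pixel model image.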

[3]:
EMP_PSF.img = emp_psf_image

# Copy the default parameter dictionaries from SLD_utils / Parameter_Index so
# that the edits below do not modify the shared class-level defaults
spf_params = DoubleHenyeyGreenstein_SPF.params.copy()
psf_params = EMP_PSF.params.copy()

# Initializing the disk parameter dictionary (all the possible parameters that can be configured)
disk_params = Parameter_Index.disk_params.copy()

# Setting certain parameters to custom values
disk_params['sma'] = 40
disk_params['inclination'] = 60
disk_params['position_angle'] = 30

# Initializing the misc parameter dictionary
misc_params = Parameter_Index.misc_params.copy()

# Setting flux_scaling to custom value
misc_params['flux_scaling'] = 1e4

simple_disk_img = objective_model(disk_params, spf_params, psf_params, misc_params,
                        ScatteredLightDisk, DustEllipticalDistribution2PowerLaws, DoubleHenyeyGreenstein_SPF, EMP_PSF)

Disk Model Image

[4]:
plt.imshow(simple_disk_img, origin='lower', cmap='inferno')
plt.title('Simple Scattered Light Disk Model')
[Figure: Simple Scattered Light Disk Model]

Runtime Analysis

[5]:
%timeit objective_model(disk_params, spf_params, psf_params, misc_params, ScatteredLightDisk, DustEllipticalDistribution2PowerLaws, DoubleHenyeyGreenstein_SPF, EMP_PSF)
993 μs ± 892 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Forward Modeling

Generating a Target Disk

To illustrate GRaTeR-JAX’s ability to recover disk parameters, we use the image generated above as our “data” by adding Gaussian noise to it.

[6]:
noise_level = 1
seed = 24601
np.random.seed(seed)
noise = np.random.normal(0, noise_level, simple_disk_img.shape)
target_image = simple_disk_img + noise
err_map = jnp.ones(simple_disk_img.shape)*noise_level
plt.imshow(target_image, origin='lower', cmap='inferno')
plt.title('Target Disk')
[Figure: Target Disk]

Getting the Log-Likelihood

With the target image in place, we can set up the infrastructure for the fit. The starting parameters are deliberately offset from the true values (different sma, inclination, position_angle, and flux_scaling) to show that the optimizer can recover them. The SPF is also modeled with a spline rather than the double HG used to generate the target.

A note about the log-likelihood in GRaTeR-JAX:

  • You can set the keyword argument ll_method to either sum or mean.

  • The mean is usually better for gradient descent workflows because it keeps the objective function on a consistent scale, regardless of image size.

  • The sum is the full log-likelihood, which is what MCMC samplers need. Dividing by the number of pixels \(N\) is equivalent to raising the likelihood to the power \(1/N\), which flattens the posterior and inflates the inferred uncertainties.

  • For optimization these are guidelines rather than hard rules, so use whichever method works best for your workflow.
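Here is a minimal sketch of the two ll_method options. It assumes a Gaussian log-likelihood of the form \(-\frac{1}{2}\sum ((d - m)/\sigma)^2\) with constant terms dropped. The values printed in the fitting section below are consistent with this form, but it is not taken from the library source. gaussian_log_likelihood is a hypothetical stand-in that works on flat Python lists rather than JAX arrays.

```python
def gaussian_log_likelihood(model, data, err, ll_method="sum"):
    """Gaussian log-likelihood up to an additive constant (hypothetical sketch).

    The -0.5 * log(2 * pi * err**2) normalization terms are omitted.
    """
    terms = [-0.5 * ((d - m) / s) ** 2 for m, d, s in zip(model, data, err)]
    total = sum(terms)
    if ll_method == "sum":
        return total
    if ll_method == "mean":
        return total / len(terms)
    raise ValueError("ll_method must be 'sum' or 'mean'")


model = [1.0, 2.0, 3.0]
data = [2.0, 2.0, 5.0]
err = [1.0, 1.0, 2.0]
ll_sum = gaussian_log_likelihood(model, data, err, "sum")
ll_mean = gaussian_log_likelihood(model, data, err, "mean")
```

The mean is the sum divided by the number of pixels, so its gradient is scaled by the same factor. That is why the mean keeps optimizer step sizes comparable across image sizes but changes the shape of the posterior an MCMC sampler would explore.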

A few things to note about the spline SPF parameters:

  • num_knots sets the number of free knot values. The spline always has one additional fixed knot at cos_phi = 0 (90 degrees, value = 1.0), so the spline has num_knots + 1 total control points.

  • backscatt_bound and forwardscatt_bound (not frontscatt_bound) set the cosine-angle range covered by the spline. Values outside this range are extrapolated: linear extrapolation on the forward scattering side (to capture physically motivated forward-scattering behavior), and constant extrapolation at the backscattering boundary.
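The knot layout and extrapolation rules above can be illustrated in plain Python. This is a hypothetical sketch, not GRaTeR-JAX's get_knots or spline code. It makes three simplifications:

- It spaces the free knots evenly in ascending cos_phi order, with the extra knot on the forward side when num_knots is odd.
- It inserts the fixed knot of value 1.0 at cos_phi = 0.
- It uses piecewise-linear interpolation as a stand-in for the cubic spline.

The helper names and knot values are illustrative.

```python
def hypothetical_knot_locations(num_knots, backscatt_bound, forwardscatt_bound):
    """Return num_knots free knot positions plus the fixed knot at cos_phi = 0.

    Hypothetical even spacing in ascending cos_phi order: num_knots // 2 free
    knots on the back-scattering side, the rest (one extra when num_knots is
    odd) on the forward side. GRaTeR-JAX's get_knots may differ.
    """
    n_back = num_knots // 2
    n_fwd = num_knots - n_back
    back = [backscatt_bound * (1.0 - k / n_back) for k in range(n_back)]
    fwd = [forwardscatt_bound * (k + 1) / n_fwd for k in range(n_fwd)]
    return back + [0.0] + fwd


def insert_fixed_knot_value(free_values, num_knots):
    """Insert the fixed value 1.0 (at cos_phi = 0) among the free knot values."""
    n_back = num_knots // 2
    return list(free_values[:n_back]) + [1.0] + list(free_values[n_back:])


def evaluate_stand_in(x, knots, values):
    """Piecewise-linear stand-in for the spline SPF.

    Constant extrapolation below the back-scattering bound, linear
    extrapolation beyond the forward-scattering bound.
    """
    if x <= knots[0]:
        return values[0]  # constant extrapolation on the back-scattering side
    if x >= knots[-1]:
        slope = (values[-1] - values[-2]) / (knots[-1] - knots[-2])
        return values[-1] + slope * (x - knots[-1])  # linear extrapolation forward
    for i in range(len(knots) - 1):
        if knots[i] <= x <= knots[i + 1]:
            t = (x - knots[i]) / (knots[i + 1] - knots[i])
            return values[i] + t * (values[i + 1] - values[i])
    raise ValueError("knots must be sorted in ascending order")


# 4 free knots between -0.9 and 0.9 give num_knots + 1 = 5 control points
knots = hypothetical_knot_locations(4, -0.9, 0.9)
values = insert_fixed_knot_value([0.8, 0.9, 1.5, 2.5], 4)
```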

[7]:
EMP_PSF.img = emp_psf_image

start_spf_params = InterpolatedUnivariateSpline_SPF.params.copy()
start_psf_params = EMP_PSF.params.copy()

start_spf_params['num_knots'] = 4
start_spf_params['knot_values'] = jnp.ones(4)
start_spf_params['backscatt_bound'] = -0.9
start_spf_params['forwardscatt_bound'] = 0.9

start_disk_params = Parameter_Index.disk_params.copy()
# Perturb starting guess away from truth (sma=40, inc=60, PA=30) but within gradient reach
start_disk_params['sma'] = 35
start_disk_params['inclination'] = 50
start_disk_params['position_angle'] = 40

misc_params = Parameter_Index.misc_params.copy()
# Start flux_scaling 2× below the true value of 1e4
misc_params['flux_scaling'] = 5e3

start_image = objective_model(start_disk_params, start_spf_params, start_psf_params, misc_params,
                        ScatteredLightDisk, DustEllipticalDistribution2PowerLaws, InterpolatedUnivariateSpline_SPF, EMP_PSF)

fig, axes = plt.subplots(1, 2)
axes[0].imshow(start_image, origin='lower', cmap='inferno', vmin=0, vmax=10)
axes[0].set_title('Initial Image')
axes[1].imshow(target_image, origin='lower', cmap='inferno', vmin=0, vmax=10)
axes[1].set_title('Target Image')
plt.show()

print(f'Log-Likelihood: {log_likelihood(start_image, target_image, err_map, ll_method="mean")}')
[Figure: Initial Image (left) and Target Image (right)]
Log-Likelihood: -13.246594097464568

Disk Fitting Using scipy.optimize.minimize

Here we perform a simple bounded fit with scipy.optimize.minimize (the L-BFGS-B method, as the output shows). More complex parameter estimation, such as MCMC, is illustrated in the advanced disk modeling tutorial.

[8]:
opt = Optimizer(ScatteredLightDisk, DustEllipticalDistribution2PowerLaws, InterpolatedUnivariateSpline_SPF, EMP_PSF,
                start_disk_params, start_spf_params, start_psf_params, misc_params)

fit_keys = ['sma', 'inclination', 'position_angle', 'knot_values', 'flux_scaling']
logscaled_params = ['knot_values', 'flux_scaling']
array_params = ['knot_values']

nk = start_spf_params['num_knots']
bounds = (
    [0,    0,    0,   np.zeros(nk), 1e0],
    [150,  90,  360,  1e3*np.ones(nk), 1e8],
)

soln = opt.scipy_bounded_optimize(fit_keys, bounds, logscaled_params, array_params,
                                   target_image, err_map, use_grad=True, disp_soln=True, iters=500)
RUNNING THE L-BFGS-B CODE

           * * *

Machine precision = 2.220D-16
 N =            8     M =           10

At X0         0 variables are exactly at the bounds

At iterate    0    f=  1.32466D+01    |proj g|=  1.22467D+01

At iterate    1    f=  5.13428D+00    |proj g|=  8.31159D+00

At iterate    2    f=  3.09070D+00    |proj g|=  6.73641D+00

At iterate    3    f=  1.75483D+00    |proj g|=  3.36645D+00

At iterate    4    f=  1.16921D+00    |proj g|=  1.74458D+00

At iterate    5    f=  8.79353D-01    |proj g|=  8.43197D-01

At iterate    6    f=  7.46850D-01    |proj g|=  3.73329D-01

At iterate    7    f=  6.80809D-01    |proj g|=  1.24504D-01

At iterate    8    f=  6.36095D-01    |proj g|=  1.00747D-01

At iterate    9    f=  5.90680D-01    |proj g|=  5.93322D-02

At iterate   10    f=  5.67931D-01    |proj g|=  7.94534D-02

At iterate   11    f=  5.62380D-01    |proj g|=  4.84469D-02

At iterate   12    f=  5.55815D-01    |proj g|=  1.60445D-02

At iterate   13    f=  5.53923D-01    |proj g|=  1.02575D-02

At iterate   14    f=  5.52836D-01    |proj g|=  8.10256D-03

At iterate   15    f=  5.50551D-01    |proj g|=  1.73617D-02

At iterate   16    f=  5.44972D-01    |proj g|=  1.90235D-02

At iterate   17    f=  5.43070D-01    |proj g|=  4.30690D-02

At iterate   18    f=  5.41422D-01    |proj g|=  2.22181D-02

At iterate   19    f=  5.40426D-01    |proj g|=  1.73453D-02

At iterate   20    f=  5.38053D-01    |proj g|=  8.40614D-03

At iterate   21    f=  5.34752D-01    |proj g|=  7.95210D-03

At iterate   22    f=  5.30916D-01    |proj g|=  2.14449D-02

At iterate   23    f=  5.24791D-01    |proj g|=  1.97159D-02

At iterate   24    f=  5.13090D-01    |proj g|=  2.74549D-02

At iterate   25    f=  5.11366D-01    |proj g|=  2.72399D-02

At iterate   26    f=  5.07581D-01    |proj g|=  4.13448D-02

At iterate   27    f=  5.03314D-01    |proj g|=  9.09905D-03

At iterate   28    f=  5.02754D-01    |proj g|=  4.24950D-03

At iterate   29    f=  5.02607D-01    |proj g|=  7.97418D-03

At iterate   30    f=  5.02506D-01    |proj g|=  1.66059D-03

At iterate   31    f=  5.02466D-01    |proj g|=  1.17427D-03

At iterate   32    f=  5.02359D-01    |proj g|=  1.09700D-03

At iterate   33    f=  5.02251D-01    |proj g|=  1.36186D-03

At iterate   34    f=  5.02122D-01    |proj g|=  1.19104D-03

At iterate   35    f=  5.01662D-01    |proj g|=  4.48908D-03

At iterate   36    f=  5.01311D-01    |proj g|=  1.56942D-02

At iterate   37    f=  5.00872D-01    |proj g|=  7.96930D-03

At iterate   38    f=  5.00458D-01    |proj g|=  4.09628D-03

At iterate   39    f=  5.00274D-01    |proj g|=  2.08938D-03

At iterate   40    f=  5.00267D-01    |proj g|=  2.23872D-03

At iterate   41    f=  5.00261D-01    |proj g|=  2.08744D-04

At iterate   42    f=  5.00261D-01    |proj g|=  1.27325D-04

At iterate   43    f=  5.00260D-01    |proj g|=  2.92911D-04

At iterate   44    f=  5.00260D-01    |proj g|=  2.55053D-04

At iterate   45    f=  5.00260D-01    |proj g|=  2.60350D-04

At iterate   46    f=  5.00259D-01    |proj g|=  1.61525D-04

At iterate   47    f=  5.00259D-01    |proj g|=  9.52033D-06

At iterate   48    f=  5.00259D-01    |proj g|=  4.09235D-06

At iterate   49    f=  5.00259D-01    |proj g|=  3.77639D-06

At iterate   50    f=  5.00259D-01    |proj g|=  9.06016D-06

At iterate   51    f=  5.00259D-01    |proj g|=  4.20666D-06

At iterate   52    f=  5.00259D-01    |proj g|=  3.34326D-06

At iterate   53    f=  5.00259D-01    |proj g|=  3.33596D-06

           * * *

Tit   = total number of iterations
Tnf   = total number of function evaluations
Tnint = total number of segments explored during Cauchy searches
Skip  = number of BFGS updates skipped
Nact  = number of active bounds at final generalized Cauchy point
Projg = norm of the final projected gradient
F     = final function value

           * * *

   N    Tit     Tnf  Tnint  Skip  Nact     Projg        F
    8     53     71     54     0     0   3.336D-06   5.003D-01
  F =  0.50025938606239595

CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH
  message: CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH
  success: True
   status: 0
      fun: 0.500259386062396
        x: [ 3.996e+01  5.994e+01  3.020e+01  1.336e+00  1.940e-01
            -1.019e-01  2.400e-01  6.285e+00]
      nit: 53
      jac: [ 3.742e-07 -4.274e-07  8.057e-08 -2.375e-06 -2.357e-06
             1.324e-07  2.011e-07 -3.336e-06]
     nfev: 71
     njev: 71
 hess_inv: <8x8 LbfgsInvHessProduct with dtype=float64>
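The solution vector x is expressed in the optimizer's internal coordinates. Its entries follow the order of fit_keys: sma, inclination, position_angle, the four knot_values, then flux_scaling. The parameters listed in logscaled_params appear as natural logarithms. Exponentiating them reproduces the knot_values and flux_scaling printed under Numeric Results to the precision shown. The natural-log convention is inferred from that match, not from the Optimizer source.

```python
import math

# Solution vector copied from the output above (rounded to 4 significant digits)
x = [3.996e+01, 5.994e+01, 3.020e+01, 1.336e+00, 1.940e-01,
     -1.019e-01, 2.400e-01, 6.285e+00]
n_knots = 4

# Linear parameters come first, in fit_keys order
sma, inclination, position_angle = x[:3]
# Log-scaled parameters: undo the (assumed natural) log transform
knot_values = [math.exp(v) for v in x[3:3 + n_knots]]
flux_scaling = math.exp(x[3 + n_knots])
```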

Numeric Results

[9]:
fitted_image = opt.get_model()
print(f"Disk Fit Log Likelihood: {log_likelihood(fitted_image, target_image, err_map)}")
opt.print_params()
Disk Fit Log Likelihood: -9805.08396682296
Disk Params: {'accuracy': 0.005, 'alpha_in': 5, 'alpha_out': -5, 'sma': np.float64(39.957633097328475), 'e': 0.0, 'ksi0': 3.0, 'gamma': 2.0, 'beta': 1.0, 'rmin': 0.0, 'dens_at_r0': 1.0, 'inclination': np.float64(59.93720052006036), 'position_angle': np.float64(30.200311103327692), 'x_center': 70.0, 'y_center': 70.0, 'omega': 0.0}
SPF Params: {'backscatt_bound': -0.9, 'forwardscatt_bound': 0.9, 'num_knots': 4, 'knot_values': array([3.80432572, 1.21406514, 0.9031456 , 1.27126598])}
PSF Params: {'scale_factor': 1.0, 'offset': 1.0}
Stellar PSF Params: None
Misc Params: {'distance': 50.0, 'pxInArcsec': 0.01414, 'nx': 140, 'ny': 140, 'halfNbSlices': 25, 'flux_scaling': np.float64(536.1988796365137)}
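The log-likelihood printed above (about -9805) is not directly comparable to the initial value of -13.25, which was computed with ll_method="mean". The check below uses only numbers from this notebook's output. It shows that the printed value equals the optimizer's final objective F times the 140 × 140 = 19,600 pixels. That suggests the default ll_method is "sum", an inference from these numbers rather than from the library source.

```python
# Numbers copied from the outputs above
final_objective = 0.50025938606239595   # L-BFGS-B final F (negative mean log-likelihood)
reported_ll = -9805.08396682296         # "Disk Fit Log Likelihood" with the default ll_method
nx, ny = 140, 140                       # image size from misc_params

n_pixels = nx * ny                      # 19,600 pixels
implied_sum_ll = -final_objective * n_pixels
mean_ll = reported_ll / n_pixels        # comparable to the initial mean value of -13.25
```

A mean log-likelihood near -0.5 is what one expects when the model matches the noiseless disk. Each pixel contributes \(-\frac{1}{2}(r/\sigma)^2\), whose expected value is -0.5 for unit-variance Gaussian noise with \(\sigma = 1\). This assumes log_likelihood omits the constant normalization term, as these values suggest.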

Visual Results

[10]:
fig, axes = plt.subplots(2, 2, figsize=(12,8))

fitted_pic = axes[0][0].imshow(fitted_image, origin='lower', cmap='inferno')
axes[0][0].set_title('Disk Fit')
plt.colorbar(fitted_pic, ax=axes[0][0], shrink=0.75)

target_pic = axes[0][1].imshow(target_image, origin='lower', cmap='inferno')
axes[0][1].set_title('Target Disk')
plt.colorbar(target_pic, ax=axes[0][1], shrink=0.75)

residual_pic = axes[1][0].imshow(target_image-fitted_image, origin='lower', cmap='inferno')
axes[1][0].set_title('Residual')
plt.colorbar(residual_pic, ax=axes[1][0], shrink=0.75)

norm_residual_pic = axes[1][1].imshow(np.where(err_map == 0, 0., (target_image-fitted_image)/err_map), origin='lower', cmap='inferno')
axes[1][1].set_title('Residual / Uncertainty')
plt.colorbar(norm_residual_pic, ax=axes[1][1], shrink=0.75)

fig.suptitle("Disk Fit Results vs Target Image")
plt.show()
[Figure: Disk Fit Results vs Target Image]

Even with a relatively simple fitting routine, GRaTeR-JAX produces a model that visually matches the target well. The recovered geometry (sma ≈ 39.96, inclination ≈ 59.94°, position angle ≈ 30.20°) is close to the true values (40, 60°, 30°). Although the target disk was generated with a Double Henyey-Greenstein SPF, the fit uses the flexible spline framework. The plot below illustrates how well the spline recovers the input SPF.

Spline Fit Results

[11]:
x_vals = jnp.linspace(-1, 1, 200)
knot_locs = InterpolatedUnivariateSpline_SPF.get_knots(opt.spf_params)
fitted_spf = InterpolatedUnivariateSpline_SPF.pack_pars(opt.spf_params['knot_values'], knots=knot_locs)

# Normalise the target DHG to 1 at cos(phi) = 0, matching the fitted spline's convention
dhg_vals = DoubleHenyeyGreenstein_SPF.compute_phase_function_from_cosphi(
               DoubleHenyeyGreenstein_SPF.pack_pars(spf_params), x_vals)
dhg_at_90 = float(DoubleHenyeyGreenstein_SPF.compute_phase_function_from_cosphi(
                DoubleHenyeyGreenstein_SPF.pack_pars(spf_params), jnp.array([0.0]))[0])
dhg_normalised = dhg_vals / dhg_at_90

fwd = float(opt.spf_params['forwardscatt_bound'])
bck = float(opt.spf_params['backscatt_bound'])

fig, ax = plt.subplots()

# Shade scattering angles not probed at this inclination
ax.axvspan(-1.0, bck, alpha=0.15, color='gray', label='Not probed')
ax.axvspan(fwd,  1.0, alpha=0.15, color='gray')

ax.scatter(knot_locs, fitted_spf(knot_locs), zorder=3, label='Fitted spline knots')
ax.plot(x_vals, fitted_spf(x_vals), label='Fitted spline SPF')
ax.plot(x_vals, dhg_normalised, '--', label='Target SPF (normalised to 1 at 90°)')

ax.set_xlabel('cos(phi)')
ax.set_ylabel('Phase function (normalised to 1 at 90°)')
ax.legend()
plt.show()
[Figure: Fitted spline SPF and normalized target SPF versus cos(phi)]

Takeaways

The objective functions and the basics of the Optimizer class are enough to fit a simple disk and recover its geometry closely. To model more complex disks, take fuller advantage of autodifferentiation, or run comprehensive sampling algorithms such as MCMC, refer to the Advanced Disk Modeling tutorial, which covers the Optimizer class in more depth.