Disk Modeling Basics
Introduction
This tutorial introduces the basics of scattered-light disk modeling with GRaTeR-JAX. The model is structured into the following pieces:
Dust Distribution – specifies the 3D density of dust grains (e.g., radial power-law slopes, eccentricity, vertical scale height).
Scattering Phase Function (SPF) – controls how starlight is scattered by dust as a function of angle (Henyey–Greenstein, double HG, or spline-based models).
Point Spread Function (PSF) – applies instrument response by convolving the model image with a chosen PSF kernel (Gaussian, empirical, etc.).
The notebook demonstrates how to configure these components, run them through the objective_model pipeline, and generate synthetic disk images that can be compared directly with observations. Additional information on the analytic model definition is provided in the sections below.
Defining the Dust Distribution
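GRaTeR-JAX follows the same dust distribution parameterization as the disk forward modeling module in VIP, which is in turn based on the framework introduced in Augereau+ 1999. The density distribution is expressed in cylindrical coordinates \((r, \theta, z)\) and contains three major terms: the dust density in the midplane (\(z\)=0) at the reference radius, \(\rho_0\) (dens_at_r0); the radial density distribution in the midplane; and the vertical profile, as shown below:
\[\rho(r,\theta,z) = \rho_0 \left[\frac{2}{\left(\frac{r}{R(\theta)}\right)^{-2\alpha_{\rm in}} + \left(\frac{r}{R(\theta)}\right)^{-2\alpha_{\rm out}}}\right]^{1/2} \exp\left[-\left(\frac{|z|}{H(r)}\right)^{\gamma}\right]\]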
where:
\(\alpha_{\rm in}\) and \(\alpha_{\rm out}\) are the inner and outer power laws describing the radial dust distribution
\(R(\theta)\) is the reference radius
\(a\) is the semimajor axis
\(e\) is the eccentricity
\(\gamma\) is the decay exponent (\(\gamma\)=2 corresponds to a Gaussian profile)
\(H(r)\) is the scale height, defined as follows, where \(\xi_0\) is the scale height at the reference radius and \(\beta\) is the flaring coefficient (\(\beta\)=1 corresponds to linear flaring):
\[H(r) = \xi_0 \left(\frac{r}{R(\theta)}\right)^{\beta}\]
Note that the reference radius is dependent on \(\theta\) for an elliptical disk, as \(R(\theta) = \frac{a(1-e^2)}{1+e{\rm cos}(\theta)}\).
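The density profile described in this section can be evaluated directly. The sketch below is a minimal NumPy version of the two-power-law profile with a generalized-Gaussian vertical term. It assumes the \(\sqrt{2}\) normalization, which makes \(\rho = \rho_0\) at \(r = R(\theta)\), \(z = 0\), as the definition of \(\rho_0\) requires. The default values are taken from the parameters printed later in this notebook (alpha_in=5, alpha_out=-5, ksi0=3, gamma=2, beta=1, sma=40). This is not GRaTeR-JAX's implementation: details such as the rmin and accuracy cutoffs that appear in disk_params are ignored here.

```python
import numpy as np

def reference_radius(theta, a, e):
    """R(theta) = a(1 - e^2) / (1 + e cos theta) for an elliptical ring."""
    return a * (1.0 - e**2) / (1.0 + e * np.cos(theta))

def dust_density(r, theta, z, rho0=1.0, alpha_in=5.0, alpha_out=-5.0,
                 a=40.0, e=0.0, ksi0=3.0, gamma=2.0, beta=1.0):
    """Sketch of the two-power-law dust density in cylindrical coordinates (assumed form)."""
    ratio = r / reference_radius(theta, a, e)
    # Smoothly joined power laws: ~ratio**alpha_in inside R, ~ratio**alpha_out outside.
    # The factor of 2 makes the radial term equal rho0 exactly at r = R(theta).
    radial = rho0 * np.sqrt(2.0 / (ratio ** (-2.0 * alpha_in) + ratio ** (-2.0 * alpha_out)))
    # Scale height equals ksi0 at the reference radius and flares as ratio**beta
    scale_height = ksi0 * ratio ** beta
    vertical = np.exp(-(np.abs(z) / scale_height) ** gamma)
    return radial * vertical

rho_peak = dust_density(40.0, 0.0, 0.0)     # at the reference radius in the midplane
rho_high = dust_density(40.0, 0.0, 3.0)     # one scale height above the midplane
```

With gamma=2 the vertical term at \(|z| = \xi_0\) is \(e^{-1}\), which makes this a quick sanity check on any implementation.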
Scattering Phase Functions
The scattering phase function (SPF) describes the amount of light scattered by a particle (e.g., a dust grain) at different scattering angles \(\theta_{\rm sca}\), which range from 0° (forward scattering) to 180° (back scattering). Dust in debris disks is generally forward scattering.
The range of scattering angles probed in any given observation is limited by the viewing geometry. GRaTeR-JAX incorporates this limit by using an initial guess of the inclination \(i\) to bound the scattering angles used in the model’s SPF, via \(\theta_{\rm sca,min}\approx90^\circ-i\) and \(\theta_{\rm sca,max}\approx90^\circ+i\).
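As a worked example, the relationships above can be converted into the cosine bounds that the spline SPF uses (backscatt_bound and forwardscatt_bound). Forward scattering (0°) maps to \(\cos\theta_{\rm sca} = +1\) and back scattering (180°) to \(-1\). The helper names below are illustrative only and are not part of the GRaTeR-JAX API.

```python
import numpy as np

def scattering_angle_bounds(inclination_deg):
    """Approximate range of scattering angles (deg) probed at a given inclination."""
    theta_min = 90.0 - inclination_deg   # most forward-scattering angle seen
    theta_max = 90.0 + inclination_deg   # most back-scattering angle seen
    return theta_min, theta_max

def cosine_bounds(inclination_deg):
    """Return (backscatt_bound, forwardscatt_bound) as cosines of the scattering angle."""
    theta_min, theta_max = scattering_angle_bounds(inclination_deg)
    back = np.cos(np.radians(theta_max))
    forward = np.cos(np.radians(theta_min))
    return back, forward

back, forward = cosine_bounds(60.0)   # the inclination of the example disk in this tutorial
```

For \(i = 60^\circ\) this gives scattering angles from 30° to 150°, i.e., cosine bounds of about \(\pm 0.866\). The fitting example later in this notebook uses slightly wider bounds of \(\pm 0.9\).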
GRaTeR-JAX provides three options for the SPF:
Henyey-Greenstein: an analytical function determined by the scattering angle \(\theta_{\rm sca}\) and the asymmetry parameter \(g\), where \(g = 0\) indicates isotropic scattering, \(0 < g \le 1\) forward scattering, and \(-1 \le g < 0\) back scattering (Henyey and Greenstein 1941).
\[I(\theta_{\rm sca}) = \frac{1}{4\pi} \frac{1-g^2}{\left(1+g^2-2g\cos\theta_{\rm sca}\right)^{3/2}}\]
Double Henyey-Greenstein: a weighted combination of two HG functions, where the free parameters are the first function’s asymmetry parameter \(g_1\), the second function’s asymmetry parameter \(g_2\), and the relative weight of the two functions.
Spline: a more flexible framework not based on an analytic function; here, a specified number of knots are fit at scattering angles between bounds defined by the approximate inclination of the disk. The number of knots \(n_k\) is determined by the user. The absolute values of the spline knots are degenerate with the overall flux scaling of the disk; as a result, the SPF is normalized such that it is equal to 1 at 90 deg and the flux scaling factor is adjusted accordingly. See the spline injection and recovery tutorial notebook for more information on the assumptions in this formalism.
Note: \(n_k\) is the number of free knots, i.e., not counting the fixed knot at \(\cos\theta_{\rm sca} = 0\) (90°), whose value is 1. If \(n_k\) is odd, the extra knot goes on the forward scattering side of the spline.
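The analytic SPF options above can be written in a few lines of NumPy. This is a minimal sketch, not the GRaTeR-JAX classes. The DHG weighting convention (weight on the \(g_1\) term, \(1 - \) weight on the \(g_2\) term) is an assumption. The sketch also shows the normalization to 1 at 90° that the spline SPF uses, and checks that the HG form integrates to 1 over the sphere (\(d\Omega = 2\pi\, d\cos\theta_{\rm sca}\)).

```python
import numpy as np
from scipy.integrate import quad

def henyey_greenstein(cos_theta, g):
    """Single HG phase function, normalized to integrate to 1 over the sphere."""
    return (1.0 - g**2) / (4.0 * np.pi * (1.0 + g**2 - 2.0 * g * cos_theta) ** 1.5)

def double_henyey_greenstein(cos_theta, g1, g2, weight):
    """Weighted sum of two HG functions (assumed convention: weight applies to the g1 term)."""
    return (weight * henyey_greenstein(cos_theta, g1)
            + (1.0 - weight) * henyey_greenstein(cos_theta, g2))

def normalize_at_90(phase_function, cos_theta, *params):
    """Rescale a phase function so it equals 1 at 90 deg (cos_theta = 0)."""
    return phase_function(cos_theta, *params) / phase_function(0.0, *params)

# Integrate over the sphere: dOmega = 2*pi d(cos theta), so the total should be 1
total, _ = quad(lambda mu: 2.0 * np.pi * double_henyey_greenstein(mu, 0.7, -0.3, 0.8), -1.0, 1.0)
```

Dividing by the value at 90° mirrors the normalization convention used for the spline SPF, so analytic and spline phase functions can be plotted on the same scale.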
Setup
First, import the relevant components of GRaTeR-JAX and outside packages. The environment variable set below (XLA_PYTHON_CLIENT_MEM_FRACTION) limits the fraction of GPU memory that JAX claims, which helps avoid out-of-memory errors. It must be set before JAX is first used (here, before it is imported).
Note: This notebook was run on an NVIDIA GeForce RTX 4090.
[1]:
import os
# Can limit JAX/XLA to using at most a fraction of total GPU memory to help
# avoid out-of-memory issues when sharing the GPU with other processes.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.15"
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import numpy as np
from astropy.io import fits
from grater_jax.disk_model.objective_functions import objective_model, log_likelihood, Parameter_Index
from grater_jax.disk_model.SLD_ojax import ScatteredLightDisk
from grater_jax.optimization.optimize_framework import Optimizer
from grater_jax.disk_model.SLD_utils import (
DoubleHenyeyGreenstein_SPF,
InterpolatedUnivariateSpline_SPF,
EMP_PSF,
DustEllipticalDistribution2PowerLaws,
)
jax.config.update("jax_enable_x64", True)
# jax.config.update("jax_debug_nans", True) <-- can use this to trace issues at the cost of runtime
Generating Disk Models with Simple PSFs
This example notebook uses an empirically-measured PSF from Gemini Planet Imager H-band data (Macintosh+ 2014, Crotts+ 2025). More complex PSFs, such as those from JWST, are explored in the JWST Tutorial (Experimental).
[2]:
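# Load the empirical GPI H-band PSF (first slice of the cube) and crop it around the core
scale_factor = 1
image = fits.getdata("tutorial_images/GPI_Hband_PSF.fits")[0, :, :]
scaled_image = (image[::scale_factor, ::scale_factor])[1::, 1::]  # downsampled copy (not used below)
cropped_image = image[70:211, 70:211]
def safe_float32_conversion(value):
    # Convert one pixel value to float32; report (and zero out) anything that cannot be converted
    try:
        return np.float32(value)
    except (ValueError, TypeError):
        print("This value is unjaxable: " + str(value))
        return np.float32(0.0)
# Replace NaNs with zeros, then cast every pixel to float32
emp_psf_image = np.nan_to_num(cropped_image)
emp_psf_image = np.vectorize(safe_float32_conversion)(emp_psf_image)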
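Conceptually, the PSF step convolves the model image with a PSF kernel. The sketch below uses a toy Gaussian kernel as a stand-in for the empirical GPI PSF loaded above, together with SciPy's FFT convolution. It assumes the kernel is normalized to unit sum so that total flux is preserved. Whether EMP_PSF normalizes its image, and how it performs the convolution internally, are not shown in this notebook.

```python
import numpy as np
from scipy.signal import fftconvolve

def gaussian_psf(size=15, fwhm_pix=3.0):
    """Toy Gaussian PSF kernel (a stand-in for the empirical PSF), normalized to unit sum."""
    sigma = fwhm_pix / (2.0 * np.sqrt(2.0 * np.log(2.0)))
    y, x = np.mgrid[:size, :size] - (size - 1) / 2.0
    kernel = np.exp(-0.5 * (x**2 + y**2) / sigma**2)
    return kernel / kernel.sum()

def convolve_with_psf(model_image, psf):
    """Convolve a model image with a PSF kernel, keeping the input image shape."""
    return fftconvolve(model_image, psf, mode="same")

# A point source at the image center spreads into the PSF shape but keeps its total flux
model = np.zeros((64, 64))
model[32, 32] = 1.0
blurred = convolve_with_psf(model, gaussian_psf())
```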
Here, we generate a simple disk model with specified parameters. There are four dictionaries: disk_params, spf_params, psf_params, and misc_params. These contain the following parameters relevant to a user:
Disk Parameters (disk_params):
inner radial density exponent \(\alpha_{\rm in}\) (alpha_in)
outer radial density exponent \(\alpha_{\rm out}\) (alpha_out)
semimajor axis of the disk \(a\) (sma)
disk eccentricity \(e\) (e)
scale height at the reference radius \(\xi_0\) (ksi0)
vertical decay exponent \(\gamma\) (gamma)
flaring exponent \(\beta\) (beta)
disk inclination \(i\) (inclination)
disk position angle \(\mathrm{PA}\) (position_angle)
x coordinate of the disk center \(x_c\) (x_center)
y coordinate of the disk center \(y_c\) (y_center)
argument of pericenter \(\omega\) (omega)
SPF Parameters (spf_params):
If Single HG (HenyeyGreenstein_SPF):
asymmetry parameter \(g\) (g)
If Double HG (DoubleHenyeyGreenstein_SPF):
first asymmetry parameter \(g_1\) (g1)
second asymmetry parameter \(g_2\) (g2)
relative weight (weight)
If Spline (InterpolatedUnivariateSpline_SPF):
number of free knots \(n_k\) (num_knots)
array of knot \(y\) values (knot_values)
back scattering bound: cosine of the scattering angle; defaults to -1 (equivalent to 180 deg) but can be set from the inclination (backscatt_bound)
forward scattering bound: cosine of the scattering angle; defaults to 1 (equivalent to 0 deg) but can be set from the inclination (forwardscatt_bound)
Miscellaneous Parameters (misc_params):
distance to the target (distance)
pixel scale in arcsec/pixel (pxInArcsec)
number of pixels in the x dimension (nx)
number of pixels in the y dimension (ny)
number of slices computed in the model; increase this if you encounter computational artifacts, at the cost of a longer runtime (halfNbSlices)
arbitrary flux scaling factor (flux_scaling)
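Later cells build their dictionaries with .copy(), which suggests that Parameter_Index.disk_params and the .params attributes are shared default dictionaries. If so, assigning one to a new name and then editing it also changes the defaults. The plain-Python sketch below illustrates the difference using a hypothetical stand-in class (DefaultParams is not part of GRaTeR-JAX).

```python
import copy
import numpy as np

class DefaultParams:
    """Hypothetical stand-in for a class that keeps default parameters in class-level dicts."""
    disk_params = {"sma": 50.0, "inclination": 0.0}
    spf_params = {"knot_values": np.ones(4)}

# Plain assignment binds a second name to the SAME dict, so edits leak into the defaults
aliased = DefaultParams.disk_params
aliased["sma"] = 40.0
leaked = DefaultParams.disk_params["sma"]        # the default was overwritten
DefaultParams.disk_params["sma"] = 50.0          # restore the default for the next step

# .copy() creates a new top-level dict, so reassigning a key leaves the defaults alone
independent = DefaultParams.disk_params.copy()
independent["sma"] = 40.0

# The copy is shallow: arrays inside are still shared, so reassign them (as the fitting
# cell does with knot_values) or use copy.deepcopy before editing them in place
deep = copy.deepcopy(DefaultParams.spf_params)
deep["knot_values"][0] = 2.0
```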
[3]:
EMP_PSF.img = emp_psf_image
# Using parameter dictionaries from SLD_utils (copied so the shared defaults stay untouched)
spf_params = DoubleHenyeyGreenstein_SPF.params.copy()
psf_params = EMP_PSF.params.copy()
# Initializing the disk parameter dictionary (all the possible parameters that can be configured)
disk_params = Parameter_Index.disk_params.copy()
# Setting certain parameters to custom values
disk_params['sma'] = 40
disk_params['inclination'] = 60
disk_params['position_angle'] = 30
# Initializing the misc parameter dictionary
misc_params = Parameter_Index.misc_params.copy()
# Setting flux_scaling to a custom value
misc_params['flux_scaling'] = 1e4
simple_disk_img = objective_model(disk_params, spf_params, psf_params, misc_params,
ScatteredLightDisk, DustEllipticalDistribution2PowerLaws, DoubleHenyeyGreenstein_SPF, EMP_PSF)
Disk Model Image
[4]:
plt.imshow(simple_disk_img, origin='lower', cmap='inferno')
plt.title('Simple Scattered Light Disk Model')
plt.show()
Runtime Analysis
[5]:
%timeit objective_model(disk_params, spf_params, psf_params, misc_params, ScatteredLightDisk, DustEllipticalDistribution2PowerLaws, DoubleHenyeyGreenstein_SPF, EMP_PSF)
993 μs ± 892 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
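When timing JAX code, two things matter. The first call to a jitted function includes compilation, and JAX dispatches work asynchronously, so a call can return before the result is computed. The sketch below times a hypothetical stand-in model (toy_model, a 140 x 140 Gaussian image). It warms up first and then calls .block_until_ready() on each result. Whether objective_model is jitted internally and returns a JAX array is an assumption here. If it does, appending .block_until_ready() to the %timeit call above would give a more conservative timing.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def toy_model(params):
    """Hypothetical stand-in for one model evaluation: a 140 x 140 Gaussian image."""
    coords = jnp.arange(140.0)
    dx2 = (coords[None, :] - params[1]) ** 2
    dy2 = (coords[:, None] - params[2]) ** 2
    return params[0] * jnp.exp(-(dx2 + dy2) / 50.0)

params = jnp.array([1.0, 70.0, 70.0])
toy_model(params).block_until_ready()   # first call compiles; keep it out of the timing

n_calls = 100
start = time.perf_counter()
for _ in range(n_calls):
    out = toy_model(params).block_until_ready()   # wait for the asynchronously dispatched result
seconds_per_call = (time.perf_counter() - start) / n_calls
print(f"{1e6 * seconds_per_call:.1f} us per call")
```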
Forward Modeling
Generating a Target Disk
To illustrate GRaTeR-JAX’s ability to model a disk accurately, we use the image generated above as our “data” by adding Gaussian noise to it.
[6]:
noise_level = 1
seed = 24601
np.random.seed(seed)
noise = np.random.normal(0, noise_level, simple_disk_img.shape)
target_image = simple_disk_img + noise
err_map = jnp.ones(simple_disk_img.shape)*noise_level
plt.imshow(target_image, origin='lower', cmap='inferno')
plt.title('Target Disk')
plt.show()
Getting the Log-Likelihood
Now that the target image is ready, we can set up the infrastructure for the fit. We initialize starting parameters that are deliberately offset from the true values (different sma, inclination, position_angle, and flux_scaling) to demonstrate that the optimizer can recover them.
A note about the log-likelihood in GRaTeR-JAX:
You can set the keyword argument ll_method to either sum or mean.
The mean is usually better for gradient descent workflows because it keeps the objective function on a consistent scale, regardless of image size.
The sum is usually the better choice for MCMC, because the width of the posterior depends on the absolute scale of the log-likelihood; averaging divides it by the number of pixels and artificially broadens the posterior.
These are guidelines, not hard rules, so use whichever method works best for your workflow.
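The difference between the two modes is a factor of the number of pixels. The sketch below uses a hypothetical helper (not GRaTeR-JAX's log_likelihood). It assumes a chi-square-style Gaussian form, \(-\tfrac{1}{2}\sum((d-m)/\sigma)^2\), with constant terms dropped. The values printed in this notebook are consistent with that form: the per-pixel mean after the fit is about -0.50, and the sum is about -9805 over 140 x 140 = 19,600 pixels. Still, check the library source for its exact conventions.

```python
import numpy as np

def gaussian_log_likelihood(model, data, err_map, ll_method="sum"):
    """Hypothetical chi-square-style Gaussian log-likelihood (constant terms dropped)."""
    valid = err_map > 0                    # ignore pixels without a usable error estimate
    chi = (data[valid] - model[valid]) / err_map[valid]
    if ll_method == "sum":
        return -0.5 * np.sum(chi ** 2)
    if ll_method == "mean":
        return -0.5 * np.mean(chi ** 2)
    raise ValueError(f"unknown ll_method: {ll_method!r}")

# A residual of exactly 1 sigma in every pixel of a 140 x 140 image
data = np.ones((140, 140))
model = np.zeros((140, 140))
err = np.ones((140, 140))
ll_sum = gaussian_log_likelihood(model, data, err, "sum")    # -9800.0
ll_mean = gaussian_log_likelihood(model, data, err, "mean")  # -0.5
```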
A few things to note about the spline SPF parameters:
num_knots sets the number of free knot values. The spline always has one additional fixed knot at cos_phi = 0 (90 degrees, value = 1.0), so it has num_knots + 1 total control points.
backscatt_bound and forwardscatt_bound (not frontscatt_bound) set the cos_phi range covered by the spline. Values outside this range are extrapolated: linearly on the forward scattering side (to capture physically motivated forward-scattering behavior), and as a constant beyond the backscattering boundary.
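The knot layout and extrapolation rules above can be sketched with SciPy's InterpolatedUnivariateSpline. Several details are assumptions and may differ from InterpolatedUnivariateSpline_SPF and its get_knots method: the free knots are evenly spaced in cos_phi on each side of the fixed knot, and knot_values are ordered by increasing cos_phi. The helper names are hypothetical.

```python
import numpy as np
from scipy.interpolate import InterpolatedUnivariateSpline

def knot_positions(num_knots, backscatt_bound=-1.0, forwardscatt_bound=1.0):
    """Hypothetical layout: free knots split around the fixed knot at cos_phi = 0,
    evenly spaced on each side, with any extra (odd) knot on the forward side."""
    n_back = num_knots // 2
    n_fwd = num_knots - n_back
    back = np.linspace(backscatt_bound, 0.0, n_back + 1)[:-1]      # excludes cos_phi = 0
    fwd = np.linspace(0.0, forwardscatt_bound, n_fwd + 1)[1:]      # excludes cos_phi = 0
    return back, fwd

def build_spline_spf(knot_values, backscatt_bound=-1.0, forwardscatt_bound=1.0):
    """Return an SPF callable from free knot values (assumed ordered by increasing cos_phi)."""
    knot_values = np.asarray(knot_values, dtype=float)
    back, fwd = knot_positions(len(knot_values), backscatt_bound, forwardscatt_bound)
    x = np.concatenate([back, [0.0], fwd])                  # includes the fixed knot at 0
    y = np.concatenate([knot_values[:len(back)], [1.0], knot_values[len(back):]])
    spline = InterpolatedUnivariateSpline(x, y, k=min(3, len(x) - 1))
    edge_value = spline(forwardscatt_bound)
    edge_slope = spline.derivative()(forwardscatt_bound)

    def spf(cos_phi):
        cos_phi = np.asarray(cos_phi, dtype=float)
        # Clipping makes the curve constant beyond the backscattering bound
        inside = spline(np.clip(cos_phi, backscatt_bound, forwardscatt_bound))
        # Linear continuation beyond the forward scattering bound
        linear = edge_value + edge_slope * (cos_phi - forwardscatt_bound)
        return np.where(cos_phi > forwardscatt_bound, linear, inside)

    return spf, x

spf, knots = build_spline_spf([1.2, 0.9, 1.5, 3.0], backscatt_bound=-0.9, forwardscatt_bound=0.9)
```

With num_knots = 4 this gives 5 control points (2 backward, the fixed knot, 2 forward), matching the num_knots + 1 count above.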
[7]:
EMP_PSF.img = emp_psf_image
start_spf_params = InterpolatedUnivariateSpline_SPF.params.copy()
start_psf_params = EMP_PSF.params.copy()
start_spf_params['num_knots'] = 4
start_spf_params['knot_values'] = jnp.ones(4)
start_spf_params['backscatt_bound'] = -0.9
start_spf_params['forwardscatt_bound'] = 0.9
start_disk_params = Parameter_Index.disk_params.copy()
# Perturb starting guess away from truth (sma=40, inc=60, PA=30) but within gradient reach
start_disk_params['sma'] = 35
start_disk_params['inclination'] = 50
start_disk_params['position_angle'] = 40
misc_params = Parameter_Index.misc_params.copy()
# Start flux_scaling 2× below the true value of 1e4
misc_params['flux_scaling'] = 5e3
start_image = objective_model(start_disk_params, start_spf_params, start_psf_params, misc_params,
ScatteredLightDisk, DustEllipticalDistribution2PowerLaws, InterpolatedUnivariateSpline_SPF, EMP_PSF)
fig, axes = plt.subplots(1, 2)
axes[0].imshow(start_image, origin='lower', cmap='inferno', vmin=0, vmax=10)
axes[0].set_title('Initial Image')
axes[1].imshow(target_image, origin='lower', cmap='inferno', vmin=0, vmax=10)
axes[1].set_title('Target Image')
plt.show()
print(f'Log-Likelihood: {log_likelihood(start_image, target_image, err_map, ll_method="mean")}')
Log-Likelihood: -13.246594097464568
Disk Fitting using Scipy Optimize Minimize
Here we perform a simple fit using scipy.optimize.minimize (the log below shows the bounded L-BFGS-B method). More complex parameter estimation is possible with MCMC, which is illustrated in the Advanced Disk Modeling tutorial.
[8]:
opt = Optimizer(ScatteredLightDisk, DustEllipticalDistribution2PowerLaws, InterpolatedUnivariateSpline_SPF, EMP_PSF,
start_disk_params, start_spf_params, start_psf_params, misc_params)
fit_keys = ['sma', 'inclination', 'position_angle', 'knot_values', 'flux_scaling']
logscaled_params = ['knot_values', 'flux_scaling']
array_params = ['knot_values']
nk = start_spf_params['num_knots']
bounds = (
[0, 0, 0, np.zeros(nk), 1e0],
[150, 90, 360, 1e3*np.ones(nk), 1e8],
)
soln = opt.scipy_bounded_optimize(fit_keys, bounds, logscaled_params, array_params,
target_image, err_map, use_grad=True, disp_soln=True, iters=500)
RUNNING THE L-BFGS-B CODE
* * *
Machine precision = 2.220D-16
N = 8 M = 10
At X0 0 variables are exactly at the bounds
At iterate 0 f= 1.32466D+01 |proj g|= 1.22467D+01
At iterate 1 f= 5.13428D+00 |proj g|= 8.31159D+00
At iterate 2 f= 3.09070D+00 |proj g|= 6.73641D+00
At iterate 3 f= 1.75483D+00 |proj g|= 3.36645D+00
At iterate 4 f= 1.16921D+00 |proj g|= 1.74458D+00
At iterate 5 f= 8.79353D-01 |proj g|= 8.43197D-01
At iterate 6 f= 7.46850D-01 |proj g|= 3.73329D-01
At iterate 7 f= 6.80809D-01 |proj g|= 1.24504D-01
At iterate 8 f= 6.36095D-01 |proj g|= 1.00747D-01
At iterate 9 f= 5.90680D-01 |proj g|= 5.93322D-02
At iterate 10 f= 5.67931D-01 |proj g|= 7.94534D-02
At iterate 11 f= 5.62380D-01 |proj g|= 4.84469D-02
At iterate 12 f= 5.55815D-01 |proj g|= 1.60445D-02
At iterate 13 f= 5.53923D-01 |proj g|= 1.02575D-02
At iterate 14 f= 5.52836D-01 |proj g|= 8.10256D-03
At iterate 15 f= 5.50551D-01 |proj g|= 1.73617D-02
At iterate 16 f= 5.44972D-01 |proj g|= 1.90235D-02
At iterate 17 f= 5.43070D-01 |proj g|= 4.30690D-02
At iterate 18 f= 5.41422D-01 |proj g|= 2.22181D-02
At iterate 19 f= 5.40426D-01 |proj g|= 1.73453D-02
At iterate 20 f= 5.38053D-01 |proj g|= 8.40614D-03
At iterate 21 f= 5.34752D-01 |proj g|= 7.95210D-03
At iterate 22 f= 5.30916D-01 |proj g|= 2.14449D-02
At iterate 23 f= 5.24791D-01 |proj g|= 1.97159D-02
At iterate 24 f= 5.13090D-01 |proj g|= 2.74549D-02
At iterate 25 f= 5.11366D-01 |proj g|= 2.72399D-02
At iterate 26 f= 5.07581D-01 |proj g|= 4.13448D-02
At iterate 27 f= 5.03314D-01 |proj g|= 9.09905D-03
At iterate 28 f= 5.02754D-01 |proj g|= 4.24950D-03
At iterate 29 f= 5.02607D-01 |proj g|= 7.97418D-03
At iterate 30 f= 5.02506D-01 |proj g|= 1.66059D-03
At iterate 31 f= 5.02466D-01 |proj g|= 1.17427D-03
At iterate 32 f= 5.02359D-01 |proj g|= 1.09700D-03
At iterate 33 f= 5.02251D-01 |proj g|= 1.36186D-03
At iterate 34 f= 5.02122D-01 |proj g|= 1.19104D-03
At iterate 35 f= 5.01662D-01 |proj g|= 4.48908D-03
At iterate 36 f= 5.01311D-01 |proj g|= 1.56942D-02
At iterate 37 f= 5.00872D-01 |proj g|= 7.96930D-03
At iterate 38 f= 5.00458D-01 |proj g|= 4.09628D-03
At iterate 39 f= 5.00274D-01 |proj g|= 2.08938D-03
At iterate 40 f= 5.00267D-01 |proj g|= 2.23872D-03
At iterate 41 f= 5.00261D-01 |proj g|= 2.08744D-04
At iterate 42 f= 5.00261D-01 |proj g|= 1.27325D-04
At iterate 43 f= 5.00260D-01 |proj g|= 2.92911D-04
At iterate 44 f= 5.00260D-01 |proj g|= 2.55053D-04
At iterate 45 f= 5.00260D-01 |proj g|= 2.60350D-04
At iterate 46 f= 5.00259D-01 |proj g|= 1.61525D-04
At iterate 47 f= 5.00259D-01 |proj g|= 9.52033D-06
At iterate 48 f= 5.00259D-01 |proj g|= 4.09235D-06
At iterate 49 f= 5.00259D-01 |proj g|= 3.77639D-06
At iterate 50 f= 5.00259D-01 |proj g|= 9.06016D-06
At iterate 51 f= 5.00259D-01 |proj g|= 4.20666D-06
At iterate 52 f= 5.00259D-01 |proj g|= 3.34326D-06
At iterate 53 f= 5.00259D-01 |proj g|= 3.33596D-06
* * *
Tit = total number of iterations
Tnf = total number of function evaluations
Tnint = total number of segments explored during Cauchy searches
Skip = number of BFGS updates skipped
Nact = number of active bounds at final generalized Cauchy point
Projg = norm of the final projected gradient
F = final function value
* * *
N Tit Tnf Tnint Skip Nact Projg F
8 53 71 54 0 0 3.336D-06 5.003D-01
F = 0.50025938606239595
CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH
message: CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH
success: True
status: 0
fun: 0.500259386062396
x: [ 3.996e+01 5.994e+01 3.020e+01 1.336e+00 1.940e-01
-1.019e-01 2.400e-01 6.285e+00]
nit: 53
jac: [ 3.742e-07 -4.274e-07 8.057e-08 -2.375e-06 -2.357e-06
1.324e-07 2.011e-07 -3.336e-06]
nfev: 71
njev: 71
hess_inv: <8x8 LbfgsInvHessProduct with dtype=float64>
Numeric Results
[9]:
fitted_image = opt.get_model()
print(f"Disk Fit Log Likelihood: {log_likelihood(fitted_image, target_image, err_map, ll_method='sum')}")  # summed over all pixels
opt.print_params()
Disk Fit Log Likelihood: -9805.08396682296
Disk Params: {'accuracy': 0.005, 'alpha_in': 5, 'alpha_out': -5, 'sma': np.float64(39.957633097328475), 'e': 0.0, 'ksi0': 3.0, 'gamma': 2.0, 'beta': 1.0, 'rmin': 0.0, 'dens_at_r0': 1.0, 'inclination': np.float64(59.93720052006036), 'position_angle': np.float64(30.200311103327692), 'x_center': 70.0, 'y_center': 70.0, 'omega': 0.0}
SPF Params: {'backscatt_bound': -0.9, 'forwardscatt_bound': 0.9, 'num_knots': 4, 'knot_values': array([3.80432572, 1.21406514, 0.9031456 , 1.27126598])}
PSF Params: {'scale_factor': 1.0, 'offset': 1.0}
Stellar PSF Params: None
Misc Params: {'distance': 50.0, 'pxInArcsec': 0.01414, 'nx': 140, 'ny': 140, 'halfNbSlices': 25, 'flux_scaling': np.float64(536.1988796365137)}
Visual Results
[10]:
fig, axes = plt.subplots(2, 2, figsize=(12,8))
fitted_pic = axes[0][0].imshow(fitted_image, origin='lower', cmap='inferno')
axes[0][0].set_title('Disk Fit')
plt.colorbar(fitted_pic, ax=axes[0][0], shrink=0.75)
target_pic = axes[0][1].imshow(target_image, origin='lower', cmap='inferno')
axes[0][1].set_title('Target Disk')
plt.colorbar(target_pic, ax=axes[0][1], shrink=0.75)
residual_pic = axes[1][0].imshow(target_image-fitted_image, origin='lower', cmap='inferno')
axes[1][0].set_title('Residual')
plt.colorbar(residual_pic, ax=axes[1][0], shrink=0.75)
residual_pic = axes[1][1].imshow(np.where(err_map == 0, 0., (target_image-fitted_image)/err_map), origin='lower', cmap='inferno')
axes[1][1].set_title('Normalized Residual (Residual / Error)')
plt.colorbar(residual_pic, ax=axes[1][1], shrink=0.75)
fig.suptitle("Disk Fit Results vs Target Image")
plt.show()
As the figure shows, even a relatively simple fitting routine lets GRaTeR-JAX produce a model that visually matches the target well. The final mean log-likelihood of about -0.50 per pixel is what pure unit-variance noise would give. Although the target disk was generated with a Double Henyey-Greenstein SPF, we perform the fit using the flexible spline framework. The plot below illustrates how well the spline recovers the input SPF.
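The spline's knot values are degenerate with the overall flux scaling, so the fitted flux_scaling (about 536) should not be compared directly with the input value of 1e4. Write \(F\) for flux_scaling and \(\Phi\) for the phase function. The model image scales with the product \(F\,\Phi(\theta_{\rm sca})\), so renormalizing \(\Phi\) to 1 at 90° moves a constant factor into \(F\):

\[F\,\Phi(\theta_{\rm sca}) = \left[F\,\Phi(90^\circ)\right]\,\frac{\Phi(\theta_{\rm sca})}{\Phi(90^\circ)} \equiv F'\,\Phi'(\theta_{\rm sca}), \qquad \Phi'(90^\circ) = 1\]

A spline that matches the normalized input SPF is therefore expected to come with \(F' \approx F\,\Phi_{\rm DHG}(90^\circ)\) rather than the input \(F\).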
Spline Fit Results
[11]:
x_vals = jnp.linspace(-1, 1, 200)
knot_locs = InterpolatedUnivariateSpline_SPF.get_knots(opt.spf_params)
fitted_spf = InterpolatedUnivariateSpline_SPF.pack_pars(opt.spf_params['knot_values'], knots=knot_locs)
# Normalise the target DHG to 1 at cos(phi) = 0, matching the fitted spline's convention
dhg_vals = DoubleHenyeyGreenstein_SPF.compute_phase_function_from_cosphi(
DoubleHenyeyGreenstein_SPF.pack_pars(spf_params), x_vals)
dhg_at_90 = float(DoubleHenyeyGreenstein_SPF.compute_phase_function_from_cosphi(
DoubleHenyeyGreenstein_SPF.pack_pars(spf_params), jnp.array([0.0]))[0])
dhg_normalised = dhg_vals / dhg_at_90
fwd = float(opt.spf_params['forwardscatt_bound'])
bck = float(opt.spf_params['backscatt_bound'])
fig, ax = plt.subplots()
# Shade the cos(phi) ranges outside the spline bounds (chosen to roughly match the angles probed at this inclination)
ax.axvspan(-1.0, bck, alpha=0.15, color='gray', label='Outside spline bounds')
ax.axvspan(fwd, 1.0, alpha=0.15, color='gray')
ax.scatter(knot_locs, fitted_spf(knot_locs), zorder=3, label='Fitted spline knots')
ax.plot(x_vals, fitted_spf(x_vals), label='Fitted spline SPF')
ax.plot(x_vals, dhg_normalised, '--', label='Target SPF (normalised to 1 at 90°)')
ax.set_xlabel('cos(phi)')
ax.set_ylabel('Phase function (normalised to 1 at 90°)')
ax.legend()
plt.show()
Takeaways
The objective functions and the basic Optimizer workflow recover the input disk geometry closely. The fit returns sma ≈ 39.96 (input 40), inclination ≈ 59.94° (input 60°), and position angle ≈ 30.20° (input 30°). For a simpler way to work with more complex disks, to leverage autodifferentiation, and to run comprehensive sampling algorithms, refer to the Advanced Disk Modeling tutorial, where we demonstrate the Optimizer class in more depth.