Disk Modeling Basics
Introduction
This tutorial introduces the basics of disk modeling with GRaTeR-JAX. The model is structured into the following pieces:
Dust Distribution – specifies the 3D density of dust grains (e.g., radial power-law slopes, eccentricity, vertical scale height).
Scattering Phase Function (SPF) – controls how starlight is scattered by dust as a function of angle (Henyey–Greenstein, double HG, or spline-based models).
Point Spread Function (PSF) – applies the instrument response by convolving the model image with a chosen PSF kernel (Gaussian, empirical, etc.).
The notebook demonstrates how to configure these components, run them through the objective_model pipeline, and generate synthetic disk images that can be compared directly with observations. Additional information on the analytic model definition is also provided here.
Defining the Dust Distribution
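GRaTeR-JAX follows the same dust distribution parameterization as the disk forward modeling module in VIP, which is itself based on the framework introduced in Augereau+ 1999. The density distribution is defined in cylindrical coordinates \((r, \theta, z)\) and is the product of three terms: the dust density in the midplane (\(z\)=0) at the reference radius, \(\rho_0\); the radial density distribution in the midplane; and the vertical profile:
\begin{equation}
\rho(r, \theta, z) = \rho_0 \times \left[\frac{2}{\left(\frac{r}{R(\theta)}\right)^{-2\alpha_{\rm in}} + \left(\frac{r}{R(\theta)}\right)^{-2\alpha_{\rm out}}}\right]^{1/2} \times \exp\left[-\left(\frac{|z|}{H(r)}\right)^{\gamma}\right]
\end{equation}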
where:
\(\alpha_{\rm in}\) and \(\alpha_{\rm out}\) are the inner and outer power-law indices of the radial dust distribution
\(R(\theta)\) is the reference radius
\(a\) is the semimajor axis
\(e\) is the eccentricity
\(\gamma\) is the decay exponent of the vertical profile (\(\gamma\)=2 corresponds to a Gaussian profile)
\(H(r)\) is the scale height, where \(\xi_0\) is the scale height at the reference radius and \(\beta\) is the flaring coefficient (\(\beta\)=1 corresponds to linear flaring):
\begin{equation}
H(r) = \xi_0 \times \left(\frac{r}{R(\theta)}\right)^{\beta}
\end{equation}
Note that the reference radius is dependent on \(\theta\) for an elliptical disk, as \(R(\theta) = \frac{a(1-e^2)}{1+e{\rm cos}(\theta)}\).
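To make the three-term density model concrete, the NumPy sketch below evaluates it at a point. It is written for intuition and is not GRaTeR-JAX's implementation. The explicit form it uses, a √2-normalized two-power-law midplane profile times an exponential vertical profile, is our assumption following the Augereau+ 1999 form used in VIP. It also ignores extras such as `rmin`. The function names are hypothetical, and the default argument values mirror the defaults printed by `opt.print_params()` later in this notebook.

```python
import numpy as np

def reference_radius(theta, a, e):
    """R(theta) = a (1 - e^2) / (1 + e cos(theta)) for an elliptical ring."""
    return a * (1.0 - e**2) / (1.0 + e * np.cos(theta))

def dust_density(r, theta, z, a, e=0.0, alpha_in=5.0, alpha_out=-5.0,
                 ksi0=3.0, beta=1.0, gamma=2.0, rho0=1.0):
    """ASSUMED two-power-law density in cylindrical coordinates (r, theta, z).

    rho = rho0 * sqrt(2 / ((r/R)^(-2 alpha_in) + (r/R)^(-2 alpha_out)))
               * exp(-(|z| / H)^gamma),   with H = ksi0 * (r/R)^beta
    """
    ratio = r / reference_radius(theta, a, e)
    # Midplane term: proportional to (r/R)^alpha_in well inside R and to
    # (r/R)^alpha_out well outside R; the factor of 2 makes it exactly 1 at r = R
    midplane = np.sqrt(2.0 / (ratio ** (-2.0 * alpha_in) + ratio ** (-2.0 * alpha_out)))
    # Vertical term with a flaring scale height
    scale_height = ksi0 * ratio ** beta
    vertical = np.exp(-(np.abs(z) / scale_height) ** gamma)
    return rho0 * midplane * vertical

# At the reference radius in the midplane the density equals rho0
print(dust_density(40.0, 0.3, 0.0, a=40.0))
```

With the default \(\gamma\)=2, moving one scale height off the midplane at the reference radius reduces the density by a factor of \(e\).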
Scattering Phase Functions
The scattering phase function (SPF) describes how much light a particle (e.g. a dust grain) scatters at each scattering angle \(\theta_{\rm sca}\), which ranges from 0° (forward scattering) to 180° (back scattering). The dust in debris disks is generally forward scattering, so disks tend to appear brighter at small scattering angles.
The range of scattering angles probed in any given observation is limited by the viewing geometry. GRaTeR-JAX accounts for this by using an initial guess of the inclination \(i\) to bound the scattering angles used in the model's SPF, via \(\theta_{\rm sca,min}\approx90^\circ-i\) and \(\theta_{\rm sca,max}\approx90^\circ+i\).
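As a worked example of these bounds, the hypothetical helpers below convert an inclination into the approximate range of scattering angles. They also express that range as the cosine bounds used by the spline SPF's `backscatt_bound` and `forwardscatt_bound` parameters. This is a sketch of the relations in the text, not GRaTeR-JAX's own routine. Since \(\cos(90^\circ - i) = \sin i\), the cosine bounds are simply \(\pm\sin i\).

```python
import numpy as np

def scattering_angle_bounds(inclination_deg):
    """Approximate (min, max) scattering angles in degrees for inclination i."""
    theta_min = 90.0 - inclination_deg  # most forward-scattering angle probed
    theta_max = 90.0 + inclination_deg  # most back-scattering angle probed
    return theta_min, theta_max

def cosine_bounds(inclination_deg):
    """The same bounds as cos(theta_sca), returned as (backscatt, forwardscatt)."""
    theta_min, theta_max = scattering_angle_bounds(inclination_deg)
    forwardscatt_bound = float(np.cos(np.radians(theta_min)))  # = sin(i)
    backscatt_bound = float(np.cos(np.radians(theta_max)))     # = -sin(i)
    return backscatt_bound, forwardscatt_bound

print(scattering_angle_bounds(60.0))  # (30.0, 150.0)
print(cosine_bounds(60.0))            # roughly (-0.866, 0.866)
```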
GRaTeR-JAX provides three options for the SPF:
- Henyey-Greenstein: an analytic function of the scattering angle \(\theta_{\rm sca}\) and the asymmetry parameter \(g\), where \(g\)=0 indicates isotropic scattering, \(0 < g \leq 1\) forward scattering, and \(-1 \leq g < 0\) back scattering (Henyey and Greenstein 1941):
\begin{equation}
I(\theta_{\rm sca}) = \frac{1}{4\pi} \frac{1-g^2}{\left[1+g^2-2g\cos\theta_{\rm sca}\right]^{3/2}}
\end{equation}
- Double Henyey-Greenstein: a weighted combination of two HG functions, with free parameters \(g_1\) (the first asymmetry parameter), \(g_2\) (the second asymmetry parameter), and the relative weight of the two components.
- Spline: a more flexible framework not based on an analytic function. A user-specified number of knots \(n_k\) is fit at scattering angles between bounds set by the approximate inclination of the disk. The absolute values of the spline knots are degenerate with the overall flux scaling of the disk, so the SPF is normalized to equal 1 at 90° and the flux scaling factor absorbs the overall amplitude. See the spline injection and recovery tutorial notebook for more information on the assumptions in this formalism.
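The analytic SPFs above are simple to evaluate directly. The NumPy sketch below implements the HG formula, a double HG built from it, and the normalization to 1 at 90° used for the spline SPF. It also checks numerically that the DHG still integrates to 1 over the sphere. The function names are hypothetical. The DHG weighting, with `weight` multiplying the \(g_1\) component and `1 - weight` the \(g_2\) component, is an assumed convention, so check the library source before relying on it.

```python
import numpy as np

def henyey_greenstein(cos_theta, g):
    """HG phase function: (1/4pi) (1 - g^2) / (1 + g^2 - 2 g cos_theta)^(3/2)."""
    cos_theta = np.asarray(cos_theta, dtype=float)
    return (1.0 - g**2) / (4.0 * np.pi * (1.0 + g**2 - 2.0 * g * cos_theta) ** 1.5)

def double_henyey_greenstein(cos_theta, g1, g2, weight):
    """Weighted sum of two HG components (ASSUMED: weight applies to g1)."""
    return (weight * henyey_greenstein(cos_theta, g1)
            + (1.0 - weight) * henyey_greenstein(cos_theta, g2))

def normalize_at_90(spf, *args):
    """Wrap an SPF so that it equals 1 at 90 deg (cos_theta = 0)."""
    value_at_90 = spf(0.0, *args)
    return lambda cos_theta: spf(cos_theta, *args) / value_at_90

# Numerical check that the DHG integrates to 1 over the sphere:
# 2*pi * integral of I(mu) d(mu) from mu = -1 to 1, using the trapezoid rule
mu = np.linspace(-1.0, 1.0, 200001)
phase = double_henyey_greenstein(mu, 0.6, -0.2, 0.7)
integral = 2.0 * np.pi * np.sum(0.5 * (phase[1:] + phase[:-1]) * np.diff(mu))
print(integral)  # very close to 1.0
```

For example, `normalize_at_90(double_henyey_greenstein, 0.6, -0.2, 0.7)` returns a function of \(\cos\theta_{\rm sca}\) that equals 1 at 90°. This is the same convention used when comparing the fitted spline with the target SPF at the end of this notebook.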
Setup
First, import the relevant components of GRaTeR-JAX and outside packages. The two environment variables control how JAX allocates GPU memory and help avoid out-of-memory errors; set them before JAX is imported, as done here.
[1]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.15"
from grater_jax.disk_model.objective_functions import objective_model, log_likelihood, Parameter_Index
from grater_jax.disk_model.SLD_utils import DoubleHenyeyGreenstein_SPF, InterpolatedUnivariateSpline_SPF, EMP_PSF, DustEllipticalDistribution2PowerLaws, Winnie_PSF
from grater_jax.disk_model.SLD_ojax import ScatteredLightDisk
from grater_jax.optimization.optimize_framework import Optimizer
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import numpy as np
from astropy.io import fits
jax.config.update("jax_enable_x64", True)
# jax.config.update("jax_debug_nans", True) <-- can use this to trace issues at the cost of runtime
Generating Disk Models with Simple PSFs
This example notebook uses an empirically-measured PSF from Gemini Planet Imager H-band data (Macintosh+ 2014, Crotts+ 2025). More complex PSFs, such as those from JWST, are explored in the JWST Tutorial (Experimental).
[2]:
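# Load the empirical GPI H-band PSF (first slice of the FITS cube)
scale_factor = 1
image = fits.open("tutorial_images/GPI_Hband_PSF.fits")[0].data[0,:,:]
# Optional downsampling by scale_factor (not used further in this notebook)
scaled_image = (image[::scale_factor, ::scale_factor])[1::, 1::]
# Crop to a 141 x 141 pixel cutout
cropped_image = image[70:211, 70:211]

def safe_float32_conversion(value):
    """Cast one value to float32, reporting any value JAX cannot use."""
    try:
        return np.float32(value)
    except (ValueError, TypeError):
        print("This value is unjaxable: " + str(value))

# Replace NaNs with zeros, then cast every pixel to float32
emp_psf_image = np.nan_to_num(cropped_image)
emp_psf_image = np.vectorize(safe_float32_conversion)(emp_psf_image)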
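The PSF image prepared above is the kernel the model image is convolved with. The sketch below illustrates that step roughly, using a hypothetical function name. It normalizes the PSF to unit sum, so that convolution conserves total flux, and applies it with SciPy's FFT convolution. It is not `EMP_PSF`'s actual implementation; that class's printed parameters `scale_factor` and `offset` are not described here. The odd-sized 141 × 141 cutout has a single central pixel. Assuming the PSF peak sits there, `mode="same"` output then stays aligned with the input.

```python
import numpy as np
from scipy.signal import fftconvolve

def convolve_with_psf(model_image, psf_image):
    """Convolve a model image with a PSF normalized to unit total flux."""
    psf = np.nan_to_num(np.asarray(psf_image, dtype=float))
    psf = psf / psf.sum()  # unit-sum kernel: total flux is conserved
    # mode="same" returns an image of the input's shape; an odd-sized kernel
    # with its peak at the center pixel introduces no shift
    return fftconvolve(np.asarray(model_image, dtype=float), psf, mode="same")

# Example: a point source convolved with a small Gaussian stand-in PSF
yy, xx = np.mgrid[-7:8, -7:8]
toy_psf = np.exp(-(xx**2 + yy**2) / (2.0 * 2.0**2))
point_source = np.zeros((64, 64))
point_source[32, 32] = 1.0
blurred = convolve_with_psf(point_source, toy_psf)
```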
Here, we generate a simple disk model with specified parameters. There are four dictionaries: disk_params, spf_params, psf_params, and misc_params. These contain the following parameters relevant to a user:
Disk Parameters (``disk_params``):
- \(\alpha_{\rm in}\) (keyword: alpha_in)
- \(\alpha_{\rm out}\) (alpha_out)
- semimajor axis \(a\) (sma)
- eccentricity \(e\) (e)
- \(\xi_0\) (ksi0)
- \(\gamma\) (gamma)
- \(\beta\) (beta)
- inclination \(i\) (inclination)
- position angle PA (position_angle)
- disk center \(x_c\) in pixels (x_center)
- disk center \(y_c\) in pixels (y_center)
- argument of pericenter \(\omega\) (omega)
SPF Parameters (``spf_params``):
If Single HG (HenyeyGreenstein_SPF):
- asymmetry parameter \(g\) (g)
If Double HG (DoubleHenyeyGreenstein_SPF):
- first asymmetry parameter \(g_1\) (g1)
- second asymmetry parameter \(g_2\) (g2)
- relative weight (weight)
If Spline (InterpolatedUnivariateSpline_SPF):
- number of knots \(n_k\) (num_knots)
- array of knot \(y\) values (knot_values)
- back scattering bound: cosine of the scattering angle; defaults to -1 (equivalent to 180°) but can be set from the inclination (backscatt_bound)
- forward scattering bound: cosine of the scattering angle; defaults to 1 (equivalent to 0°) but can be set from the inclination (forwardscatt_bound)
Miscellaneous Parameters (``misc_params``):
- distance to target (distance)
- pixel scale in arcsec/pixel (pxInArcsec)
- number of pixels in the x dimension (nx)
- number of pixels in the y dimension (ny)
- number of slices calculated in the model; increase it if you encounter computational artifacts, at the cost of a longer runtime (halfNbSlices)
- arbitrary flux scaling (flux_scaling)
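A quick consistency check on these parameters is to convert the semimajor axis into pixels and confirm that the disk fits inside the `nx` × `ny` frame. The sketch assumes that `sma` is in au and `distance` in pc (the VIP convention), so `sma / distance` is the angular size in arcseconds. The function name is hypothetical. The numbers are `sma = 40` plus the defaults printed later in this notebook (`distance = 50`, `pxInArcsec = 0.01414`, `nx = ny = 140`).

```python
def sma_in_pixels(sma_au, distance_pc, px_in_arcsec):
    """Projected semimajor axis in pixels.

    ASSUMPTION: sma in au and distance in pc, so sma / distance is in arcsec
    (1 au seen from 1 pc subtends 1 arcsec, by definition of the parsec).
    """
    return (sma_au / distance_pc) / px_in_arcsec

r_pix = sma_in_pixels(40.0, 50.0, 0.01414)
nx = 140
print(f"{r_pix:.1f} px")  # 56.6 px
print(r_pix < nx / 2)     # True: the semimajor axis fits within half the frame
```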
[3]:
EMP_PSF.img = emp_psf_image
# Default parameter dictionaries from SLD_utils (copied so the class-level defaults stay unchanged)
spf_params = DoubleHenyeyGreenstein_SPF.params.copy()
psf_params = EMP_PSF.params.copy()
# Initializing the disk parameter dictionary (all the possible parameters that can be configured)
disk_params = Parameter_Index.disk_params.copy()
# Setting certain parameters to custom values
disk_params['sma'] = 40
disk_params['inclination'] = 60
disk_params['position_angle'] = 30
# Initializing the misc parameter dictionary
misc_params = Parameter_Index.misc_params.copy()
# Setting flux_scaling to custom value
misc_params['flux_scaling'] = 1e4
simple_disk_img = objective_model(disk_params, spf_params, psf_params, misc_params,
                                  ScatteredLightDisk, DustEllipticalDistribution2PowerLaws, DoubleHenyeyGreenstein_SPF, EMP_PSF)
Disk Model Image
[4]:
plt.imshow(simple_disk_img, origin='lower', cmap='inferno')
plt.title('Simple Scattered Light Disk Model')
Runtime Analysis
[5]:
%timeit objective_model(disk_params, spf_params, psf_params, misc_params, ScatteredLightDisk, DustEllipticalDistribution2PowerLaws, DoubleHenyeyGreenstein_SPF, EMP_PSF)
2.3 ms ± 6.96 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
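`%timeit` reports the mean time per call. Two JAX behaviors are worth keeping in mind when benchmarking. The first call to a JIT-compiled function includes compilation. JAX also dispatches work asynchronously, so a call can return before the computation has finished. The helper below (a hypothetical name, not part of GRaTeR-JAX) warms the function up and then blocks on the result whenever the output exposes `block_until_ready()`, as JAX arrays do. Whether `objective_model` is JIT-compiled is an assumption here.

```python
import time
import numpy as np

def benchmark(fn, *args, n_warmup=1, n_runs=20):
    """Return (mean, std) wall-clock seconds per call of fn(*args)."""
    def run_once():
        out = fn(*args)
        # JAX may return before the computation finishes; wait for the result
        # when the output supports it (NumPy arrays do not need this)
        if hasattr(out, "block_until_ready"):
            out.block_until_ready()
        return out

    # Warm-up calls absorb one-time costs such as tracing and JIT compilation
    for _ in range(n_warmup):
        run_once()

    times = np.empty(n_runs)
    for k in range(n_runs):
        start = time.perf_counter()
        run_once()
        times[k] = time.perf_counter() - start
    return times.mean(), times.std()
```

For example, `benchmark(objective_model, disk_params, spf_params, psf_params, misc_params, ScatteredLightDisk, DustEllipticalDistribution2PowerLaws, DoubleHenyeyGreenstein_SPF, EMP_PSF)` returns the mean and standard deviation in seconds.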
Forward Modeling
Generating a Target Disk
To illustrate GRaTeR-JAX's ability to recover a disk's properties, we add Gaussian noise to the image generated above and use the result as our "data".
[6]:
noise_level = 1
noise = np.random.normal(0, noise_level, simple_disk_img.shape)
target_image = simple_disk_img + noise
err_map = jnp.ones(simple_disk_img.shape)*noise_level
plt.imshow(target_image, origin='lower', cmap='inferno')
plt.title('Target Disk')
Getting the Log-Likelihood
With the target image in place, we can set up the fit. We initialize starting parameters that are deliberately offset from the true values (different sma, inclination, position_angle, and flux_scaling) to demonstrate that the optimizer can recover them. We also switch to the spline SPF, even though the target was generated with a Double Henyey-Greenstein SPF.
A few things to note about the spline SPF parameters:
- num_knots sets the number of free knot values. The spline always has one additional fixed knot at cos_phi = 0 (90 degrees, value = 1.0), so it has num_knots + 1 control points in total.
- backscatt_bound and forwardscatt_bound (not frontscatt_bound) set the range of cos_phi covered by the spline. Values outside this range are extrapolated: linearly on the forward-scattering side (to capture physically motivated forward-scattering behavior) and as a constant beyond the backscattering boundary.
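For intuition, the behavior described in these notes can be mimicked with SciPy. The sketch below builds a cubic interpolating spline through user-supplied free knots plus the fixed knot (cos_phi = 0, value 1.0). Beyond the forward-scattering end it extrapolates linearly, and beyond the backscattering end it holds the value constant. It is not GRaTeR-JAX's JAX implementation. The function name, the explicit knot positions (the library computes them with `get_knots`), and the cubic order are all assumptions.

```python
import numpy as np
from scipy.interpolate import InterpolatedUnivariateSpline

def spline_spf(knot_x, knot_values):
    """SPF(cos_phi) from free knots plus the fixed normalization knot (0, 1)."""
    knot_x = np.asarray(knot_x, dtype=float)
    knot_values = np.asarray(knot_values, dtype=float)
    if np.any(knot_x == 0.0):
        raise ValueError("cos_phi = 0 is reserved for the fixed knot (value 1.0)")
    # Add the fixed knot and sort by cos_phi (spline nodes must be increasing)
    x = np.append(knot_x, 0.0)
    y = np.append(knot_values, 1.0)
    order = np.argsort(x)
    x, y = x[order], y[order]
    spline = InterpolatedUnivariateSpline(x, y, k=3)  # cubic, passes through every knot

    back_edge, fwd_edge = x[0], x[-1]
    back_value = float(spline(back_edge))
    fwd_value = float(spline(fwd_edge))
    fwd_slope = float(spline(fwd_edge, nu=1))  # first derivative at the forward edge

    def spf(cos_phi):
        c = np.asarray(cos_phi, dtype=float)
        inside = spline(np.clip(c, back_edge, fwd_edge))
        forward = fwd_value + fwd_slope * (c - fwd_edge)  # linear extrapolation
        # Constant extrapolation below the backscattering edge
        return np.where(c > fwd_edge, forward,
                        np.where(c < back_edge, back_value, inside))
    return spf

spf = spline_spf([-0.9, -0.45, 0.45, 0.9], [0.8, 0.9, 1.5, 3.0])
print(float(spf(0.0)))  # 1.0 up to floating-point error (the fixed knot)
```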
[7]:
EMP_PSF.img = emp_psf_image
start_spf_params = InterpolatedUnivariateSpline_SPF.params.copy()
start_psf_params = EMP_PSF.params.copy()
start_spf_params['num_knots'] = 5
start_spf_params['knot_values'] = jnp.ones(5)
start_spf_params['backscatt_bound'] = -0.9
start_spf_params['forwardscatt_bound'] = 0.9
start_disk_params = Parameter_Index.disk_params.copy()
# Perturb starting guess away from truth (sma=40, inc=60, PA=30) but within gradient reach
start_disk_params['sma'] = 35
start_disk_params['inclination'] = 50
start_disk_params['position_angle'] = 40
misc_params = Parameter_Index.misc_params.copy()
# Start flux_scaling 2× below the true value of 1e4
misc_params['flux_scaling'] = 5e3
start_image = objective_model(start_disk_params, start_spf_params, start_psf_params, misc_params,
ScatteredLightDisk, DustEllipticalDistribution2PowerLaws, InterpolatedUnivariateSpline_SPF, EMP_PSF)
fig, axes = plt.subplots(1, 2)
axes[0].imshow(start_image, origin='lower', cmap='inferno', vmin=0, vmax=10)
axes[0].set_title('Initial Image')
axes[1].imshow(target_image, origin='lower', cmap='inferno', vmin=0, vmax=10)
axes[1].set_title('Target Image')
plt.show()
print(f'Log-Likelihood: {log_likelihood(start_image, target_image, err_map)}')
Log-Likelihood: -13.169985534888209
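The printed log-likelihood values hint at how it is normalized. With unit noise, a perfect model gives a \(\chi^2\) of about one per pixel. Yet the fitted value reported later in this notebook is about -0.49 for a 140 × 140 image, rather than roughly -9800. That is consistent with a pixel-averaged Gaussian log-likelihood, \(-\frac{1}{2}\,\mathrm{mean}\left[((\mathrm{data}-\mathrm{model})/\sigma)^2\right]\). The sketch below implements that form under this assumption, including masking pixels with zero error as the plotting cell later does. The function name is hypothetical, and `log_likelihood` in the library may differ in detail.

```python
import numpy as np

def pixel_mean_log_likelihood(model, data, err):
    """ASSUMED form: -0.5 * mean(((data - model) / err)**2) over pixels with err > 0."""
    model = np.asarray(model, dtype=float)
    data = np.asarray(data, dtype=float)
    err = np.asarray(err, dtype=float)
    valid = err > 0  # ignore pixels without a usable error estimate
    chi = (data[valid] - model[valid]) / err[valid]
    return -0.5 * np.mean(chi**2)

# Sanity check: unit-variance noise around the true model gives about -0.5
rng = np.random.default_rng(0)
truth = np.zeros((140, 140))
noisy = truth + rng.normal(0.0, 1.0, truth.shape)
print(pixel_mean_log_likelihood(truth, noisy, np.ones_like(truth)))
```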
Disk Fitting using Scipy Optimize Minimize
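Here we perform a simple fit with scipy.optimize.minimize (the convergence message and LbfgsInvHessProduct in the output indicate the bounded L-BFGS-B method). More complex parameter estimation is possible with MCMC, which is illustrated in the advanced disk modeling tutorial.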
[8]:
opt = Optimizer(ScatteredLightDisk, DustEllipticalDistribution2PowerLaws, InterpolatedUnivariateSpline_SPF, EMP_PSF,
start_disk_params, start_spf_params, start_psf_params, misc_params)
fit_keys = ['sma', 'inclination', 'position_angle', 'knot_values', 'flux_scaling']
logscaled_params = ['knot_values', 'flux_scaling']
array_params = ['knot_values']
nk = start_spf_params['num_knots']
bounds = (
[0, 0, 0, np.zeros(nk), 1e0],
[150, 90, 360, 1e3*np.ones(nk), 1e8],
)
soln = opt.scipy_bounded_optimize(fit_keys, bounds, logscaled_params, array_params,
target_image, err_map, use_grad=True, disp_soln=True, iters=500)
message: CONVERGENCE: RELATIVE REDUCTION OF F <= FACTR*EPSMCH
success: True
status: 0
fun: 0.4924025319010353
x: [ 3.975e+01 5.995e+01 2.995e+01 1.400e+00 3.121e-01
-5.851e-02 2.156e-03 2.475e-01 6.225e+00]
nit: 67
jac: [-4.561e-08 4.696e-08 2.973e-08 2.729e-07 3.046e-07
1.161e-07 -1.305e-07 -3.978e-08 7.787e-07]
nfev: 88
njev: 88
hess_inv: <9x9 LbfgsInvHessProduct with dtype=float64>
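The solution vector `x` has nine entries: sma, inclination, and position_angle, followed by the five knot values and the flux scaling. The last six match the natural logarithms of the fitted values printed below (for example, ln 4.056 ≈ 1.400 and ln 505.0 ≈ 6.225), so the parameters in `logscaled_params` appear to be optimized in log space. The hypothetical helpers below sketch how a parameter dictionary and its bounds can be flattened into such a vector for scipy.optimize.minimize. They illustrate the idea only and are not the Optimizer's internal code. Note that a lower bound of 0 on a log-scaled parameter maps to -inf, which L-BFGS-B treats as unbounded.

```python
import numpy as np

def pack(params, fit_keys, log_keys):
    """Flatten the fitted entries of a parameter dict into one optimizer vector."""
    chunks = []
    for key in fit_keys:
        value = np.atleast_1d(np.asarray(params[key], dtype=float))
        # Log-scaled parameters are optimized as natural logarithms
        chunks.append(np.log(value) if key in log_keys else value)
    return np.concatenate(chunks)

def unpack(x, template, fit_keys, log_keys):
    """Inverse of pack: copy template and overwrite the fitted entries from x."""
    out = dict(template)
    start = 0
    for key in fit_keys:
        size = np.size(template[key])
        chunk = np.asarray(x[start:start + size], dtype=float)
        start += size
        if key in log_keys:
            chunk = np.exp(chunk)
        # Keep scalars as floats and arrays as arrays
        out[key] = chunk if np.ndim(template[key]) > 0 else float(chunk[0])
    return out

def pack_bounds(lower, upper, fit_keys, log_keys):
    """Turn per-key lower/upper bounds into (min, max) pairs for L-BFGS-B."""
    lows, highs = [], []
    for key, lo, hi in zip(fit_keys, lower, upper):
        lo = np.atleast_1d(np.asarray(lo, dtype=float))
        hi = np.atleast_1d(np.asarray(hi, dtype=float))
        if key in log_keys:
            with np.errstate(divide="ignore"):
                lo, hi = np.log(lo), np.log(hi)  # log(0) -> -inf (unbounded)
        lows.append(lo)
        highs.append(hi)
    return list(zip(np.concatenate(lows), np.concatenate(highs)))

# Example with values similar to this notebook's fit
params = {"sma": 40.0, "knot_values": np.array([4.0, 1.0]), "flux_scaling": 505.0}
keys, log_keys = ["sma", "knot_values", "flux_scaling"], ["knot_values", "flux_scaling"]
x = pack(params, keys, log_keys)
bounds = pack_bounds([0, np.zeros(2), 1e0], [150, 1e3 * np.ones(2), 1e8], keys, log_keys)
restored = unpack(x, params, keys, log_keys)
```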
Numeric Results
[9]:
fitted_image = opt.get_model()
print(f"Disk Fit Log Likelihood: {log_likelihood(fitted_image, target_image, err_map)}")
opt.print_params()
Disk Fit Log Likelihood: -0.4924025319010353
Disk Params: {'accuracy': 0.005, 'alpha_in': 5, 'alpha_out': -5, 'sma': np.float64(39.747969663268805), 'e': 0.0, 'ksi0': 3.0, 'gamma': 2.0, 'beta': 1.0, 'rmin': 0.0, 'dens_at_r0': 1.0, 'inclination': np.float64(59.952792086584566), 'position_angle': np.float64(29.95292536026525), 'x_center': 70.0, 'y_center': 70.0, 'omega': 0.0}
SPF Params: {'backscatt_bound': -0.9, 'forwardscatt_bound': 0.9, 'num_knots': 5, 'knot_values': array([4.05562732, 1.36630048, 0.94317314, 1.00215794, 1.28085239])}
PSF Params: {'scale_factor': 1.0, 'offset': 1.0}
Stellar PSF Params: None
Misc Params: {'distance': 50.0, 'pxInArcsec': 0.01414, 'nx': 140, 'ny': 140, 'halfNbSlices': 25, 'flux_scaling': np.float64(505.0029853837608)}
Visual Results
[10]:
fig, axes = plt.subplots(2, 2, figsize=(12,8))
fitted_pic = axes[0][0].imshow(fitted_image, origin='lower', cmap='inferno')
axes[0][0].set_title('Disk Fit')
plt.colorbar(fitted_pic, ax=axes[0][0], shrink=0.75)
target_pic = axes[0][1].imshow(target_image, origin='lower', cmap='inferno')
axes[0][1].set_title('Target Disk')
plt.colorbar(target_pic, ax=axes[0][1], shrink=0.75)
residual_pic = axes[1][0].imshow(target_image-fitted_image, origin='lower', cmap='inferno')
axes[1][0].set_title('Residual')
plt.colorbar(residual_pic, ax=axes[1][0], shrink=0.75)
norm_resid_pic = axes[1][1].imshow(np.where(err_map == 0, 0., (target_image-fitted_image)/err_map), origin='lower', cmap='inferno')
axes[1][1].set_title('Normalized Residual (Residual / Error)')
plt.colorbar(norm_resid_pic, ax=axes[1][1], shrink=0.75)
fig.suptitle("Disk Fit Results vs Target Image")
plt.show()
Even with a relatively simple fitting routine, GRaTeR-JAX produces a model that visually matches the target well. Although the target disk was generated with a Double Henyey-Greenstein SPF, we performed the fit using the flexible spline framework. The plot below shows how well the spline recovers the input SPF.
Spline Fit Results
[11]:
x_vals = jnp.linspace(-1, 1, 200)
knot_locs = InterpolatedUnivariateSpline_SPF.get_knots(opt.spf_params)
fitted_spf = InterpolatedUnivariateSpline_SPF.pack_pars(opt.spf_params['knot_values'], knots=knot_locs)
# Normalise the target DHG to 1 at cos(phi) = 0, matching the fitted spline's convention
dhg_vals = DoubleHenyeyGreenstein_SPF.compute_phase_function_from_cosphi(
DoubleHenyeyGreenstein_SPF.pack_pars(spf_params), x_vals)
dhg_at_90 = float(DoubleHenyeyGreenstein_SPF.compute_phase_function_from_cosphi(
DoubleHenyeyGreenstein_SPF.pack_pars(spf_params), jnp.array([0.0]))[0])
dhg_normalised = dhg_vals / dhg_at_90
fwd = float(opt.spf_params['forwardscatt_bound'])
bck = float(opt.spf_params['backscatt_bound'])
fig, ax = plt.subplots()
# Shade scattering angles not probed at this inclination
ax.axvspan(-1.0, bck, alpha=0.15, color='gray', label='Not probed')
ax.axvspan(fwd, 1.0, alpha=0.15, color='gray')
ax.scatter(knot_locs, fitted_spf(knot_locs), zorder=3, label='Fitted spline knots')
ax.plot(x_vals, fitted_spf(x_vals), label='Fitted spline SPF')
ax.plot(x_vals, dhg_normalised, '--', label='Target SPF (normalised to 1 at 90°)')
ax.set_xlabel('cos(phi)')
ax.set_ylabel('Phase function (normalised to 1 at 90°)')
ax.legend()
plt.show()
Takeaways
The objective functions and the basics of the Optimizer class are enough to recover the input disk geometry closely: the fitted sma, inclination, and position angle are all within 1% of their true values. For a simpler way to work with more complex disks, leverage autodifferentiation, and run comprehensive sampling algorithms, refer to the Advanced Disk Modeling tutorial, which covers the Optimizer class in more depth.