Speed Comparisons (VIP vs GRaTeR-JAX)
This notebook benchmarks GRaTeR-JAX against the previous NumPy-based implementation of GRaTeR in VIP (the Vortex Image Processing package), covering both model generation and gradient evaluation. The performance takeaways are summarized in the Results section at the end.
[1]:
import os
# Configure JAX before any JAX computation runs (including imports that may create arrays):
# preallocate 50% of GPU memory instead of the default, and enable float64 support.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.50'
import jax
jax.config.update("jax_enable_x64", True)
import numpy as np
import matplotlib.pyplot as plt
from vip_hci.fm.scattered_light_disk import ScatteredLightDisk as VIP_ScatteredLightDisk
from vip_hci.var.filters import frame_filter_lowpass
from grater_jax.optimization.optimize_framework import Optimizer
from grater_jax.disk_model.SLD_ojax import ScatteredLightDisk as JAX_ScatteredLightDisk
from grater_jax.disk_model.SLD_utils import (
    DustEllipticalDistribution2PowerLaws,
    HenyeyGreenstein_SPF,
)
from grater_jax.disk_model.objective_functions import Parameter_Index
Model Generation Benchmarks
[2]:
density = {
    "name": "2PowerLaws",
    "a": 60.0,
    "ksi0": 1.0,
    "gamma": 2.0,
    "beta": 1.0,
    "ain": 5.0,
    "aout": -5.0,
    "e": 0.0,
    "dens_at_r0": 1.0,
}
phase = {"name": "HG", "g": 0.3, "polar": False}
vip_disk = VIP_ScatteredLightDisk(
    nx=200,
    ny=200,
    distance=50.0,
    itilt=60.0,
    omega=0.0,
    pxInArcsec=0.01225,
    pa=0.0,
    density_dico=density,
    spf_dico=phase,
    xdo=0.0,
    ydo=0.0,
)
vip_img = vip_disk.compute_scattered_light(halfNbSlices=25)
[3]:
# Define Optimizer inputs
spf_params = HenyeyGreenstein_SPF.params.copy()
spf_params["g"] = 0.3
disk_params = Parameter_Index.disk_params.copy()
misc_params = Parameter_Index.misc_params.copy()
disk_params.update(
    {
        "sma": 60.0,
        "alpha_in": 5.0,
        "alpha_out": -5.0,
        "ksi0": 1.0,
        "gamma": 2.0,
        "beta": 1.0,
        "e": 0.0,
        "dens_at_r0": 1.0,
        "inclination": 60.0,
        "position_angle": 0.0,
        "omega": 0.0,
        "x_center": 100.0,
        "y_center": 100.0,
        "halfNbSlices": 25,
    }
)
misc_params.update(
    {
        "distance": 50.0,
        "pxInArcsec": 0.01225,
        "nx": 200,
        "ny": 200,
        "halfNbSlices": 25,
        "flux_scaling": 1,
    }
)
optimizer = Optimizer(
    DiskModel=JAX_ScatteredLightDisk,
    DistrModel=DustEllipticalDistribution2PowerLaws,
    FuncModel=HenyeyGreenstein_SPF,
    PSFModel=None,
    disk_params=disk_params,
    spf_params=spf_params,
    psf_params=None,
    misc_params=misc_params,
)
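The VIP cell and the Optimizer cell above describe the same disk with different key names. The sketch below records that correspondence as read off these two cells and `make_vip_disk` further down. `vip_to_grater_disk_params` and the two mapping dicts are hypothetical helpers for illustration and are not part of either package. Disk-center offsets are left out because the two packages express them differently: VIP uses the `xdo`/`ydo` offsets, while GRaTeR-JAX uses absolute `x_center`/`y_center`.

```python
# Hypothetical helper: key correspondence read off the VIP and GRaTeR-JAX cells in this notebook.
VIP_DENSITY_TO_GRATER = {
    "a": "sma",
    "ain": "alpha_in",
    "aout": "alpha_out",
    "ksi0": "ksi0",
    "gamma": "gamma",
    "beta": "beta",
    "e": "e",
    "dens_at_r0": "dens_at_r0",
}
VIP_GEOMETRY_TO_GRATER = {"itilt": "inclination", "pa": "position_angle", "omega": "omega"}


def vip_to_grater_disk_params(density_dico, itilt, pa, omega):
    """Translate VIP density/geometry arguments into GRaTeR-JAX disk_params keys.

    Keys without a GRaTeR-JAX counterpart (e.g. "name") are dropped.
    """
    params = {
        VIP_DENSITY_TO_GRATER[key]: value
        for key, value in density_dico.items()
        if key in VIP_DENSITY_TO_GRATER
    }
    geometry = {"itilt": itilt, "pa": pa, "omega": omega}
    params.update({VIP_GEOMETRY_TO_GRATER[key]: value for key, value in geometry.items()})
    return params
```

Calling `vip_to_grater_disk_params(density, itilt=60.0, pa=0.0, omega=0.0)` reproduces the entries passed to `disk_params.update` above, apart from `x_center`, `y_center` and `halfNbSlices`.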
[4]:
# Generate JAX model
jax_img = optimizer.get_model()
# Compute likelihood
err_map = np.ones_like(vip_img) * np.nanstd(vip_img)
ll = optimizer.log_likelihood(vip_img, err_map)
print("Log-likelihood:", ll)
# Compare residuals
residual = vip_img - jax_img
mean_resid = np.mean(np.abs(residual))
max_resid = np.max(np.abs(residual))
print("Mean residual:", mean_resid)
print("Max residual:", max_resid)
Log-likelihood: 450799.6516621609
Mean residual: 4.1688652349367055e-13
Max residual: 3.468051170405547e-12
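The residuals above are in the image's own flux units, so they are easier to judge relative to the signal. Below is a minimal NumPy sketch of a peak-normalized comparison. `residual_stats` is a hypothetical helper, and its output for `vip_img`/`jax_img` is not reported in this notebook.

```python
import numpy as np


def residual_stats(reference, model):
    """Absolute residual statistics, plus the largest residual relative to the reference peak."""
    ref = np.asarray(reference, dtype=float)
    mod = np.asarray(model, dtype=float)
    abs_resid = np.abs(ref - mod)
    peak = np.max(np.abs(ref))
    return {
        "mean_abs": float(np.mean(abs_resid)),
        "max_abs": float(np.max(abs_resid)),
        "max_rel_to_peak": float(np.max(abs_resid) / peak),
    }
```

For this notebook the call would be `residual_stats(vip_img, jax_img)`. `np.asarray` converts a JAX array to NumPy before the comparison.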
GRaTeR-JAX vs VIP Models
[5]:
# Visualization
vmin = min(vip_img.min(), jax_img.min())
vmax = max(vip_img.max(), jax_img.max())
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
im0 = axes[0].imshow(vip_img, cmap="inferno", vmin=vmin, vmax=vmax)
axes[0].set_title("VIP")
fig.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04)
im1 = axes[1].imshow(jax_img, cmap="inferno", vmin=vmin, vmax=vmax)
axes[1].set_title("GRaTeR-JAX (Optimizer)")
fig.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)
im2 = axes[2].imshow(residual, cmap="seismic", vmin=-max_resid, vmax=max_resid)
axes[2].set_title("Residual")
fig.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04)
plt.tight_layout()
plt.show()
GRaTeR-JAX vs VIP Runtimes
[6]:
vip_model_runtime = %timeit -o vip_disk.compute_scattered_light(halfNbSlices=25)
grater_jax_model_runtime = %timeit -o optimizer.get_model()
127 ms ± 1.29 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
2.36 ms ± 149 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
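JAX dispatches work asynchronously. A function that returns a JAX array can return before the computation has finished, so benchmarks should wait on the result, and the first call of a jitted function also includes compilation. The notebook does not show whether `optimizer.get_model()` returns a JAX or a NumPy array, so the timing above may or may not include device synchronization. Below is a hedged, standard-library-only sketch of a timing helper that warms up once and can synchronize every result. `time_call` and its `sync` argument are hypothetical.

```python
import statistics
import timeit
from typing import Any, Callable, Optional


def time_call(
    fn: Callable[[], Any],
    sync: Optional[Callable[[Any], Any]] = None,
    number: int = 10,
    repeat: int = 7,
) -> tuple[float, float]:
    """Return (mean, stdev) seconds per call, in the spirit of %timeit's summary line.

    fn is called once untimed (warm-up, e.g. to trigger JIT compilation). If sync is
    given, it is applied to every result before the clock stops.
    """
    if repeat < 2:
        raise ValueError("repeat must be at least 2 to estimate a standard deviation")

    def run() -> None:
        out = fn()
        if sync is not None:
            sync(out)  # e.g. wait for an asynchronously dispatched JAX result

    run()  # warm-up call, not timed
    totals = timeit.repeat(run, number=number, repeat=repeat)
    per_call = [total / number for total in totals]
    return statistics.mean(per_call), statistics.stdev(per_call)
```

If `get_model` returns a JAX array, `time_call(optimizer.get_model, sync=jax.block_until_ready, number=1000)` would time the finished computation. The one-line equivalent is `%timeit optimizer.get_model().block_until_ready()`.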
GRaTeR-JAX vs VIP Gradient Benchmarks
[7]:
from vip_hci.fm.scattered_light_disk import ScatteredLightDisk as VIP_ScatteredLightDisk
from grater_jax.optimization.optimize_framework import Optimizer, OptimizeUtils
from grater_jax.disk_model.SLD_ojax import ScatteredLightDisk
from grater_jax.disk_model.SLD_utils import (
    DustEllipticalDistribution2PowerLaws,
    InterpolatedUnivariateSpline_SPF,
    EMP_PSF,
)
from grater_jax.disk_model.objective_functions import Parameter_Index
# Synthetic target
def generate_vip_synthetic_target(nx=200, ny=200):
    density = {
        "name": "2PowerLaws",
        "a": 60.0,
        "ksi0": 1.0,
        "gamma": 2.0,
        "beta": 1.0,
        "ain": 5.0,
        "aout": -5.0,
        "e": 0.0,
        "dens_at_r0": 1.0,
    }
    phase = {"name": "HG", "g": 0.3, "polar": False}
    vip_disk = VIP_ScatteredLightDisk(
        nx=nx, ny=ny,
        distance=50.0,
        itilt=60.0,
        omega=0.0,
        pxInArcsec=0.01225,
        pa=0.0,
        density_dico=density,
        spf_dico=phase,
        xdo=0.0, ydo=0.0,
    )
    vip_img = vip_disk.compute_scattered_light(halfNbSlices=25)
    return vip_img
[8]:
# Synthetic VIP target
target_image = generate_vip_synthetic_target(nx=200, ny=200)
err_map = np.ones_like(target_image, dtype=np.float32)
# GRaTeR-JAX parameter setup
start_disk_params = Parameter_Index.disk_params.copy()
start_spf_params = InterpolatedUnivariateSpline_SPF.params.copy()
# No PSF → both the class and params are None
start_psf_params = None
start_misc_params = Parameter_Index.misc_params.copy()
start_disk_params.update({
    "sma": 46.0,
    "inclination": 80.0,
    "position_angle": 27.5,
    "x_center": 100.0,
    "y_center": 100.0,
    "flux_scaling": 1.0,
})
start_spf_params["num_knots"] = 5
start_spf_params["knot_values"] = np.ones(5) * 0.5
start_misc_params.update({
    "distance": 50.0,
    "nx": 200,
    "ny": 200,
    "pxInArcsec": 0.01225,
})
# Initialize optimizer with NO PSF model
opt = Optimizer(
    ScatteredLightDisk,
    DustEllipticalDistribution2PowerLaws,
    InterpolatedUnivariateSpline_SPF,
    PSFModel=None,
    disk_params=start_disk_params,
    spf_params=start_spf_params,
    psf_params=start_psf_params,
    misc_params=start_misc_params,
)
# SPF knots
opt.inc_bound_knots()
opt.estimate_flux_scaling(target_image)
# Compile model + gradient (no PSF path)
opt.jit_compile_model()
opt.jit_compile_gradient(target_image, err_map)
[9]:
from grater_jax.disk_model.objective_functions import log_likelihood
[10]:
fit_keys = ["sma", "alpha_in", "inclination", "position_angle", "x_center", "y_center", "alpha_out", "knot_values"]
def analytic_grad():
    return opt.get_gradient(fit_keys, target_image, err_map)
def numeric_grad(eps=1e-6):
    """
    Finite-difference numerical gradient for the current optimizer state.
    Supports both scalar parameters and vector-valued knot_values.
    """
    base_params = opt.get_values(keys=fit_keys)
    grad = []
    for k in range(len(fit_keys)):
        val = base_params[k]
        if np.isscalar(val):
            params_plus = base_params.copy()
            params_minus = base_params.copy()
            params_plus[k] = val + eps
            params_minus[k] = val - eps
            ll_plus = opt.get_objective_likelihood(
                params_plus, fit_keys, target_image, err_map
            )
            ll_minus = opt.get_objective_likelihood(
                params_minus, fit_keys, target_image, err_map
            )
            grad.append((ll_plus - ll_minus) / (2 * eps))
        else:
            g = np.zeros_like(val)
            for i in range(len(val)):
                params_plus = base_params.copy()
                params_minus = base_params.copy()
                v_plus = val.copy()
                v_minus = val.copy()
                v_plus[i] += eps
                v_minus[i] -= eps
                params_plus[k] = v_plus
                params_minus[k] = v_minus
                ll_plus = opt.get_objective_likelihood(
                    params_plus, fit_keys, target_image, err_map
                )
                ll_minus = opt.get_objective_likelihood(
                    params_minus, fit_keys, target_image, err_map
                )
                g[i] = (ll_plus - ll_minus) / (2 * eps)
            grad.append(g)
    return grad
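`numeric_grad` above uses central differences, so each gradient entry costs two objective evaluations. The 7 scalar keys plus 5 spline knots in `fit_keys` therefore need 24 likelihood calls per gradient. The self-contained NumPy sketch below shows the same scheme on a toy linear model with a Gaussian log-likelihood, where the exact gradient is known in closed form. The toy model and all function names are illustrative and are not GRaTeR-JAX code.

```python
import numpy as np


def central_difference_grad(f, x, eps=1e-6):
    """Central finite-difference gradient of a scalar function f at x.

    Uses 2 * x.size evaluations of f, mirroring numeric_grad above.
    """
    x = np.asarray(x, dtype=float)
    grad = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        grad[i] = (f(x + step) - f(x - step)) / (2 * eps)
    return grad


def toy_log_likelihood(theta, design, data, sigma):
    """Gaussian log-likelihood (up to a constant) of the linear model design @ theta."""
    resid = (data - design @ theta) / sigma
    return -0.5 * np.sum(resid**2)


def toy_exact_grad(theta, design, data, sigma):
    """Closed-form gradient of toy_log_likelihood with respect to theta."""
    return design.T @ ((data - design @ theta) / sigma**2)


rng = np.random.default_rng(0)
design = rng.normal(size=(50, 3))
data = rng.normal(size=50)
sigma = 0.5
theta = np.array([0.1, -0.2, 0.3])

fd = central_difference_grad(lambda t: toy_log_likelihood(t, design, data, sigma), theta)
exact = toy_exact_grad(theta, design, data, sigma)
print("max |FD - exact|:", np.max(np.abs(fd - exact)))
```

Because the toy log-likelihood is quadratic, central differences have no truncation error here and only round-off remains. The point is the cost: the loop calls `f` twice per parameter, which is what makes `numeric_grad` expensive.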
[11]:
def make_vip_disk(params_dict):
    density = {
        "name": "2PowerLaws",
        "a": params_dict["sma"],
        "ksi0": 1.0,
        "gamma": 2.0,
        "beta": 1.0,
        "ain": params_dict["alpha_in"],
        "aout": params_dict["alpha_out"],
        "e": 0.0,
        "dens_at_r0": 1.0,
    }
    phase = {
        "name": "HG",
        "g": params_dict["g"],
        "polar": False,
    }
    return VIP_ScatteredLightDisk(
        nx=200,
        ny=200,
        distance=50.0,
        itilt=params_dict["inclination"],
        omega=0.0,
        pxInArcsec=0.01225,
        pa=params_dict["position_angle"],
        density_dico=density,
        spf_dico=phase,
        xdo=100 - params_dict["x_center"],
        ydo=100 - params_dict["y_center"],
    )
vip_fit_keys = ["sma", "alpha_in", "inclination", "position_angle", "x_center", "y_center", "alpha_out", "g"]
vip_fit_values = [46.0, 5.0, 80.0, 27.5, 100.0, 100.0, -5.0, 0.3]
base_dict = dict(zip(vip_fit_keys, vip_fit_values))
def numeric_grad_vip(eps=1e-6):
    """
    Finite-difference numerical gradient using VIP scattered-light images.
    """
    grad = []
    for k, key in enumerate(vip_fit_keys):
        val = vip_fit_values[k]
        if np.isscalar(val):
            dict_plus = dict(base_dict)
            dict_minus = dict(base_dict)
            dict_plus[key] = val + eps
            dict_minus[key] = val - eps
            img_plus = make_vip_disk(dict_plus).compute_scattered_light(halfNbSlices=25)
            img_minus = make_vip_disk(dict_minus).compute_scattered_light(halfNbSlices=25)
            ll_plus = log_likelihood(img_plus, target_image, err_map)
            ll_minus = log_likelihood(img_minus, target_image, err_map)
            grad.append((ll_plus - ll_minus) / (2 * eps))
    return grad
Numeric vs Analytic Gradients
VIP has no built-in gradient computation, so its gradients must be estimated by finite differences (as `numeric_grad_vip` does above). To isolate the benefit of GRaTeR-JAX's analytic gradients, I compare them with a central finite-difference gradient of the same GRaTeR-JAX objective function.
[12]:
def compress_grad(grad_list):
    out = []
    for g in grad_list:
        if np.isscalar(g):
            out.append(float(g))
        else:
            out.extend(np.asarray(g).ravel().astype(float).tolist())
    return out
def format_list(vals, ndigits=2):
    return "[" + ", ".join(f"{v:.{ndigits}e}" for v in vals) + "]"
print("Analytic Gradient:", format_list(compress_grad(analytic_grad()), 2))
print("Numeric Gradient:", format_list(compress_grad(numeric_grad()), 2))
Analytic Gradient: [7.32e-08, 1.41e-07, -1.18e-07, 2.35e-08, 1.14e-08, -1.84e-09, 9.44e-09, -3.00e-07, -1.78e-07, 2.78e-06, -7.41e-07, -6.93e-07]
Numeric Gradient: [7.32e-08, 1.41e-07, -1.18e-07, 2.35e-08, 2.74e-08, 3.55e-09, 9.44e-09, -3.00e-07, -1.78e-07, 2.78e-06, -7.41e-07, -6.93e-07]
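The two printed gradients agree to the three significant figures shown for every entry except the fifth and sixth. Assuming the gradient is returned in `fit_keys` order, those two entries are `x_center` and `y_center`, and the `y_center` entries even differ in sign. A quick NumPy check over the printed values makes this explicit. Because the inputs are rounded to three significant figures, the check can only flag large disagreements. The labels and the 1% tolerance are choices made for this sketch, and this notebook does not investigate the cause of the centering discrepancy.

```python
import numpy as np

# Values copied from the printed output above (rounded to 3 significant figures).
analytic_printed = np.array([7.32e-08, 1.41e-07, -1.18e-07, 2.35e-08, 1.14e-08, -1.84e-09,
                             9.44e-09, -3.00e-07, -1.78e-07, 2.78e-06, -7.41e-07, -6.93e-07])
numeric_printed = np.array([7.32e-08, 1.41e-07, -1.18e-07, 2.35e-08, 2.74e-08, 3.55e-09,
                            9.44e-09, -3.00e-07, -1.78e-07, 2.78e-06, -7.41e-07, -6.93e-07])

# Entry labels in fit_keys order, with knot_values expanded to its 5 knots.
scalar_keys = ["sma", "alpha_in", "inclination", "position_angle", "x_center", "y_center", "alpha_out"]
labels = scalar_keys + [f"knot_values[{i}]" for i in range(5)]

# Relative difference, scaled by the larger magnitude of the two estimates.
scale = np.maximum(np.abs(analytic_printed), np.abs(numeric_printed))
rel_diff = np.abs(analytic_printed - numeric_printed) / scale

mismatched = [label for label, r in zip(labels, rel_diff) if r > 1e-2]
print(mismatched)  # ['x_center', 'y_center']
```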
VIP (numeric) vs GRaTeR-JAX Gradient Runtimes
[13]:
numeric_grad_runtime = %timeit -o numeric_grad()
analytic_grad_runtime = %timeit -o analytic_grad()
numeric_grad_vip_runtime = %timeit -o numeric_grad_vip()
271 ms ± 2.67 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
16.8 ms ± 21 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.04 s ± 3.64 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
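A back-of-the-envelope check ties these timings to the finite-difference cost. `numeric_grad_vip` perturbs 8 parameters with two VIP renders each. At the ~127 ms per render measured earlier, it should take about 16 × 127 ms ≈ 2.03 s, which is close to the measured 2.04 s. The per-render cost is only approximate here because the gradient is evaluated at a different disk geometry than the earlier benchmark. Applying the same count to `numeric_grad` (24 objective calls) gives roughly 11 ms per GRaTeR-JAX objective evaluation. That is noticeably more than the 2.36 ms `get_model()` timing, and this notebook does not break down where the difference goes. The snippet only reuses numbers printed above.

```python
# Timings copied from the %timeit outputs in this notebook (seconds).
vip_model_s = 127e-3     # one VIP render, cell [6]
vip_fd_grad_s = 2.04     # numeric_grad_vip(), cell [13]
jax_fd_grad_s = 271e-3   # numeric_grad(), cell [13]

# Central differences need two objective evaluations per gradient entry.
n_vip_entries = 8        # len(vip_fit_keys)
n_jax_entries = 7 + 5    # 7 scalar keys + 5 spline knots in fit_keys
vip_evals = 2 * n_vip_entries
jax_evals = 2 * n_jax_entries

predicted_vip_fd_s = vip_evals * vip_model_s        # ≈ 2.03 s
jax_per_eval_ms = 1e3 * jax_fd_grad_s / jax_evals   # ≈ 11.3 ms

print(f"VIP FD gradient: predicted {predicted_vip_fd_s:.2f} s vs measured {vip_fd_grad_s:.2f} s")
print(f"GRaTeR-JAX objective: ~{jax_per_eval_ms:.1f} ms per evaluation ({jax_evals} evaluations)")
```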
Results
[14]:
from IPython.display import Markdown, display
def fmt_time(seconds):
    """
    Format a runtime in seconds using readable units.
    """
    if seconds >= 1:
        return f"{seconds:.3g} s"
    elif seconds >= 1e-3:
        return f"{seconds * 1e3:.3g} ms"
    elif seconds >= 1e-6:
        return f"{seconds * 1e6:.3g} μs"
    else:
        return f"{seconds * 1e9:.3g} ns"
def fmt_runtime(result):
    """
    Format a %timeit -o result as mean ± std.
    """
    return f"{fmt_time(result.average)} ± {fmt_time(result.stdev)}"
model_speedup = vip_model_runtime.average / grater_jax_model_runtime.average
jax_grad_speedup = numeric_grad_runtime.average / analytic_grad_runtime.average
vip_grad_speedup = numeric_grad_vip_runtime.average / analytic_grad_runtime.average
vip_numeric_grad_speedup = numeric_grad_vip_runtime.average / numeric_grad_runtime.average
display(Markdown(f"""
# Performance Takeaways
GRaTeR-JAX is substantially faster than the legacy VIP implementation of GRaTeR in both **model generation** and **gradient evaluation**. The gains come from JAX vectorization, XLA compilation, and analytic gradients computed by automatic differentiation. Together these make GRaTeR-JAX far better suited to optimization and inference workflows.

---

## Benchmark Summary

| Task | Framework | Method | Runtime per Loop | Speedup |
|------|-----------|--------|------------------|---------|
| Model generation | VIP GRaTeR | Native | {fmt_runtime(vip_model_runtime)} | 1× |
| Model generation | GRaTeR-JAX | JAX model | {fmt_runtime(grater_jax_model_runtime)} | **~{model_speedup:.0f}×** |
| Gradient | VIP GRaTeR | Numeric FD | {fmt_runtime(numeric_grad_vip_runtime)} | 1× |
| Gradient | GRaTeR-JAX | Numeric FD | {fmt_runtime(numeric_grad_runtime)} | ~{vip_numeric_grad_speedup:.0f}× |
| Gradient | GRaTeR-JAX | Analytic AD | {fmt_runtime(analytic_grad_runtime)} | **~{vip_grad_speedup:.0f}×** |

Speedups are relative to the VIP GRaTeR row for the same task.

---

## Model Generation Performance
We compare scattered-light model generation between:

- **VIP GRaTeR**: `vip_disk.compute_scattered_light`
- **GRaTeR-JAX**: `optimizer.get_model`

~~~python
vip_model_runtime = %timeit -o vip_disk.compute_scattered_light(halfNbSlices=25)
grater_jax_model_runtime = %timeit -o optimizer.get_model()
~~~

**Results**

- VIP GRaTeR: {fmt_runtime(vip_model_runtime)} per loop
- GRaTeR-JAX: {fmt_runtime(grater_jax_model_runtime)} per loop

**Speedup**

GRaTeR-JAX generates a model **~{model_speedup:.0f}×** faster than VIP GRaTeR. Gradient-based fitting, MCMC sampling, and parameter sweeps evaluate the model repeatedly, so this per-call saving translates into large reductions in their total runtime.

---

## Gradient Computation Performance
Three gradient computations are timed:

~~~python
numeric_grad_runtime = %timeit -o numeric_grad()
analytic_grad_runtime = %timeit -o analytic_grad()
numeric_grad_vip_runtime = %timeit -o numeric_grad_vip()
~~~

**Results**

- GRaTeR-JAX numeric (central finite-difference) gradient: {fmt_runtime(numeric_grad_runtime)} per loop
- GRaTeR-JAX analytic gradient: {fmt_runtime(analytic_grad_runtime)} per loop
- VIP numeric gradient: {fmt_runtime(numeric_grad_vip_runtime)} per loop

### Notes on Comparability

- The **numeric and analytic GRaTeR-JAX gradients** are computed on the **same model** (12 gradient entries: 7 scalar parameters plus 5 spline knots), so they are directly comparable.
- The **VIP numeric gradient** uses a different configuration. VIP does **not support spline-based SPFs**, so it differentiates with respect to a Henyey-Greenstein `g` instead (8 parameters). It is included as a realistic reference for legacy workflows rather than a strict apples-to-apples comparison.

---

## Gradient Speedups

- **Analytic vs. numeric GRaTeR-JAX:** ~{jax_grad_speedup:.1f}× speedup
- **Analytic GRaTeR-JAX vs. numeric VIP:** ~{vip_grad_speedup:.0f}× speedup

Beyond speed, analytic gradients are exact up to floating-point round-off, so they avoid the step-size sensitivity of finite differences. They also scale better. Central finite differences need two objective evaluations per parameter, whereas a reverse-mode AD gradient typically costs a small constant multiple of one evaluation regardless of the number of parameters.

---

## Summary

- Model generation is about **{model_speedup:.0f}× faster** with GRaTeR-JAX.
- Analytic gradients provide:
  - a ~{jax_grad_speedup:.1f}× speedup over finite differences on the same model
  - exact derivatives with no finite-difference step size to tune
  - a cost that grows far more slowly with the number of free parameters than finite differences
- GRaTeR-JAX makes practical several workflows that are prohibitively slow with VIP, including high-dimensional optimization, gradient-based HMC/NUTS sampling, and interactive modeling.

Overall, GRaTeR-JAX is a substantial advance in both **performance** and **scalability** for scattered-light disk modeling.
"""))
Performance Takeaways
GRaTeR-JAX is substantially faster than the legacy VIP implementation of GRaTeR in both model generation and gradient evaluation. The gains come from JAX vectorization, XLA compilation, and analytic gradients computed by automatic differentiation. Together these make GRaTeR-JAX far better suited to optimization and inference workflows.
Benchmark Summary
| Task | Framework | Method | Runtime per Loop | Speedup |
|---|---|---|---|---|
| Model generation | VIP GRaTeR | Native | 127 ms ± 1.29 ms | 1× |
| Model generation | GRaTeR-JAX | JAX model | 2.36 ms ± 149 μs | ~54× |
| Gradient | VIP GRaTeR | Numeric FD | 2.04 s ± 3.64 ms | 1× |
| Gradient | GRaTeR-JAX | Numeric FD | 271 ms ± 2.67 ms | ~8× |
| Gradient | GRaTeR-JAX | Analytic AD | 16.8 ms ± 21 μs | ~121× |
Speedups are relative to the VIP GRaTeR row for the same task.
Model Generation Performance
We compare scattered-light model generation between:
- VIP GRaTeR: vip_disk.compute_scattered_light
- GRaTeR-JAX: optimizer.get_model
vip_model_runtime = %timeit -o vip_disk.compute_scattered_light(halfNbSlices=25)
grater_jax_model_runtime = %timeit -o optimizer.get_model()
Results
- VIP GRaTeR: 127 ms ± 1.29 ms per loop
- GRaTeR-JAX: 2.36 ms ± 149 μs per loop
Speedup
GRaTeR-JAX generates a model ~54× faster than VIP GRaTeR. Gradient-based fitting, MCMC sampling, and parameter sweeps evaluate the model repeatedly, so this per-call saving translates into large reductions in their total runtime.
Gradient Computation Performance
Three gradient computations are timed:
numeric_grad_runtime = %timeit -o numeric_grad()
analytic_grad_runtime = %timeit -o analytic_grad()
numeric_grad_vip_runtime = %timeit -o numeric_grad_vip()
Results
- GRaTeR-JAX numeric (central finite-difference) gradient: 271 ms ± 2.67 ms per loop
- GRaTeR-JAX analytic gradient: 16.8 ms ± 21 μs per loop
- VIP numeric gradient: 2.04 s ± 3.64 ms per loop
Notes on Comparability
- The numeric and analytic GRaTeR-JAX gradients are computed on the same model (12 gradient entries: 7 scalar parameters plus 5 spline knots), so they are directly comparable.
- The VIP numeric gradient uses a different configuration. VIP does not support spline-based SPFs, so it differentiates with respect to a Henyey-Greenstein g instead (8 parameters). It is included as a realistic reference for legacy workflows rather than a strict apples-to-apples comparison.
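Gradient Speedups
- Analytic vs. numeric GRaTeR-JAX: ~16.1× speedup
- Analytic GRaTeR-JAX vs. numeric VIP: ~121× speedup
Beyond speed, analytic gradients are exact up to floating-point round-off, so they avoid the step-size sensitivity of finite differences. They also scale better. Central finite differences need two objective evaluations per parameter, whereas a reverse-mode AD gradient typically costs a small constant multiple of one evaluation regardless of the number of parameters.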
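The step-size sensitivity of finite differences is easy to demonstrate on a function with a known derivative. The NumPy sketch below is an illustration, not GRaTeR-JAX code. It differentiates `exp` at x = 1 with central differences. A large step suffers truncation error, a very small step suffers floating-point cancellation, and only an intermediate step is accurate. An exact gradient has no such parameter to tune.

```python
import numpy as np


def central_difference(f, x, eps):
    """Central finite-difference estimate of f'(x) with step eps."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)


x0 = 1.0
exact = np.exp(x0)  # d/dx exp(x) = exp(x)
errors = {eps: abs(central_difference(np.exp, x0, eps) - exact) for eps in (1e-1, 1e-5, 1e-12)}
for eps, err in errors.items():
    print(f"eps={eps:.0e}  abs error={err:.1e}")
```

The same trade-off applies to the `eps=1e-6` default in `numeric_grad` and `numeric_grad_vip`. Its best value depends on each parameter's scale.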
Summary
- Model generation is about 54× faster with GRaTeR-JAX.
- Analytic gradients provide:
  - a ~16.1× speedup over finite differences on the same model
  - exact derivatives with no finite-difference step size to tune
  - a cost that grows far more slowly with the number of free parameters than finite differences
- GRaTeR-JAX makes practical several workflows that are prohibitively slow with VIP, including high-dimensional optimization, gradient-based HMC/NUTS sampling, and interactive modeling.
Overall, GRaTeR-JAX is a substantial advance in both performance and scalability for scattered-light disk modeling.