Architecture & Workflow Overview

This page gives a high-level, diagram-driven introduction to the GRaTeR-JAX codebase and its conceptual workflow. It is intended to help new users and contributors quickly understand how the components fit together: from user scripts and notebooks, through the optimization logic, down to the JAX-compiled disk models and their gradients.


1. What is GRaTeR-JAX?

GRaTeR-JAX is a JAX-native, differentiable implementation of the GRaTeR scattered-light debris disk modeling framework of Augereau et al. (1999). It is designed to:

  • Generate scattered-light images of circumstellar debris disks

  • Perform fast, accurate model fitting using automatic differentiation

  • Support MCMC-based probabilistic inference to estimate disk parameters and uncertainties

  • Run efficiently on CPU/GPU via JAX JIT compilation

The library cleanly separates:

  • Physical disk models

  • Objective functions (likelihoods + gradients)

  • Optimization / inference orchestration


2. End-to-End User Workflow

At a high level, a typical GRaTeR-JAX workflow looks like:

  1. Define a disk model (density profile, scattering phase function, PSF handling)

  2. Specify initial parameter guesses and model configuration

  3. Load a target image, error map, and optional empirical PSF

  4. JIT-compile the forward model and gradients

  5. Run either:

    • Deterministic optimization (SciPy), or

    • Bayesian inference (MCMC)

  6. Inspect best-fit parameters and diagnostics

This flow is summarized visually in the block diagram below.


3. High-Level Optimization Flow Diagram

The diagram below represents the conceptual optimization pipeline exposed to the user.

[Figure: GRaTeR-JAX Optimization Flow]

Key stages

  • Define Disk Model

    • Dust distribution

    • Scattering phase function (SPF), e.g. Henyey-Greenstein or spline-based

    • On-axis / off-axis PSF handling (optional)

  • Define Initial Guess

    • Morphological parameters

    • SPF parameters

    • Throughput maps and miscellaneous parameters (e.g. distance to the system, flux scaling)

  • Initialize Optimizer

    • Creates a central object coordinating data, parameters, and JAX functions

  • Target Image & PSF Setup

    • Image preprocessing

    • Error map construction

    • Optional empirical PSF injection

  • JIT Compilation

    • Forward model

    • Gradients (computed by JAX automatic differentiation, exact to floating-point precision rather than approximated by finite differences)

  • Optimization / MCMC

    • SciPy minimization or

    • MCMC sampling
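SciPy's minimizers operate on a flat 1-D vector, whereas an initial guess naturally comes in groups (morphology, SPF, miscellaneous). One common JAX pattern for bridging the two is `jax.flatten_util.ravel_pytree`, sketched here with hypothetical parameter names and values; GRaTeR-JAX may organize its parameters differently.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical parameter groups mirroring the "Define Initial Guess" stage;
# the keys and values are illustrative, not GRaTeR-JAX's actual names.
initial_guess = {
    "morphology": {"r0": jnp.array(50.0), "inclination": jnp.array(70.0)},
    "spf": {"g": jnp.array(0.3)},
    "misc": {"flux_scaling": jnp.array(1.0)},
}

# Flatten the nested dict into a single 1-D vector (what SciPy expects)
# and get back a function that rebuilds the original nested structure.
flat_x0, unravel = ravel_pytree(initial_guess)
print(flat_x0.shape)

# A vector proposed by the optimizer can be mapped back to named groups
restored = unravel(flat_x0 * 2)
print(restored["spf"]["g"])
```

Keeping parameters as a named pytree inside the model and flattening only at the optimizer boundary keeps model code readable while still satisfying SciPy's interface.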


4. Internal Code Architecture

GRaTeR-JAX is organized around a clean separation of concerns. The diagram below shows how the main modules interact.

[Figure: GRaTeR-JAX Architecture]

Core components

Tutorials (see the Tutorial Page)

  • Disk Modeling Basics (Tutorials)

  • Advanced Disk Modeling (How-Tos)

  • SPFInjectionRecoveryTutorial.ipynb (How-Tos)

  • Speed Comparisons (Tutorials)

Optimizer Class

  • Front-facing API for disk fitting

  • Handles disk component parameters

  • JIT-compiles the objective function and its gradient

  • Executes gradient-based optimization or MCMC fitting
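The MCMC side of the fitting can be illustrated with a generic random-walk Metropolis sampler written in JAX with `jax.lax.scan`. This is a minimal sketch, not the sampler GRaTeR-JAX actually uses; the two-parameter Gaussian `log_prob` is a hypothetical stand-in for a log-posterior built from a disk model and error map.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in log-posterior: independent Gaussians in two parameters.
MU = jnp.array([1.0, -2.0])
SIGMA = jnp.array([0.5, 0.5])

def log_prob(theta):
    return -0.5 * jnp.sum(((theta - MU) / SIGMA) ** 2)

def metropolis(key, theta0, n_steps, step_size):
    """Random-walk Metropolis: Gaussian proposals, accept with prob min(1, exp(dlogp))."""
    def step(carry, key):
        theta, lp = carry
        k_prop, k_acc = jax.random.split(key)
        proposal = theta + step_size * jax.random.normal(k_prop, theta.shape)
        lp_prop = log_prob(proposal)
        accept = jnp.log(jax.random.uniform(k_acc)) < lp_prop - lp
        theta = jnp.where(accept, proposal, theta)
        lp = jnp.where(accept, lp_prop, lp)
        return (theta, lp), (theta, accept)

    keys = jax.random.split(key, n_steps)
    _, (chain, accepted) = jax.lax.scan(step, (theta0, log_prob(theta0)), keys)
    return chain, accepted.mean()

# n_steps must be static because it sets the number of PRNG keys (an array shape)
run = jax.jit(metropolis, static_argnums=(2,))
chain, acc_rate = run(jax.random.PRNGKey(0), jnp.zeros(2), 20000, 0.5)
posterior_mean = chain[2000:].mean(axis=0)  # discard burn-in
print(posterior_mean, acc_rate)
```

Because the whole chain runs inside one compiled `scan`, there is no Python overhead per step; ensemble samplers such as emcee instead call the compiled log-probability from Python.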

Objective Functions

  • objective_model

  • objective_grad

  • Wrap JAX disk models into scalar likelihoods
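A scalar objective of this kind is typically a chi-squared (Gaussian negative log-likelihood) over valid pixels. The sketch below assumes a hypothetical 1-D toy model and shows two details that matter in practice: masked pixels are handled so that no NaN leaks into the gradient, and the autodiff gradient is checked against central finite differences. `toy_objective` and `toy_objective_grad` play the roles of `objective_model` and `objective_grad`, but their signatures are illustrative, not the library's.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)  # float64 for a clean finite-difference check

x = jnp.linspace(-5.0, 5.0, 101)

def toy_model(params):
    """Hypothetical 1-D stand-in for a disk forward model."""
    amp, center = params
    return amp * jnp.exp(-0.5 * (x - center) ** 2)

data = toy_model(jnp.array([2.0, 0.5]))
err = jnp.full_like(x, 0.1).at[:5].set(0.0)  # zero error marks the first 5 pixels as masked

def neg_log_likelihood(params):
    valid = jnp.isfinite(err) & (err > 0)
    # Replace invalid errors *before* dividing: jnp.where alone would still
    # propagate 0 * inf = NaN through the gradient of the unused branch.
    safe_err = jnp.where(valid, err, 1.0)
    resid = jnp.where(valid, (data - toy_model(params)) / safe_err, 0.0)
    return 0.5 * jnp.sum(resid ** 2)

toy_objective = jax.jit(neg_log_likelihood)
toy_objective_grad = jax.jit(jax.grad(neg_log_likelihood))

p = jnp.array([1.5, 0.0])
g_auto = toy_objective_grad(p)

# Independent check: central finite differences along each parameter axis
h = 1e-6
g_fd = np.array([
    (toy_objective(p + h * e) - toy_objective(p - h * e)) / (2 * h)
    for e in jnp.eye(2)
])
print(g_auto, g_fd)
```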

JAX Models / Gradients

  • Pure, functional disk forward models

  • Fully differentiable

  • No side effects
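The Henyey-Greenstein function, one of the SPF options named earlier, makes a compact example of a pure, differentiable model component: p(θ) = (1 - g²) / [4π (1 + g² - 2g cos θ)^(3/2)], where g is the asymmetry parameter. The sketch below is a standalone implementation of that standard formula, not GRaTeR-JAX's own code; it checks the normalization numerically and differentiates the forward-scattering peak with respect to g.

```python
import jax
import jax.numpy as jnp

def henyey_greenstein(cos_theta, g):
    """Henyey-Greenstein phase function, normalized to integrate to 1 over 4*pi sr.

    Pure function: the output depends only on the inputs and nothing is mutated,
    so it can be freely jit-compiled, vmapped, and differentiated.
    """
    return (1.0 - g**2) / (4.0 * jnp.pi * (1.0 + g**2 - 2.0 * g * cos_theta) ** 1.5)

# Scattering angles from 0 (forward) to pi (backward)
theta = jnp.linspace(0.0, jnp.pi, 2001)
integrand = henyey_greenstein(jnp.cos(theta), 0.5) * 2.0 * jnp.pi * jnp.sin(theta)

# Normalization check with an explicit trapezoid rule: should be close to 1
norm = jnp.sum(0.5 * (integrand[1:] + integrand[:-1]) * jnp.diff(theta))

# Sensitivity of the forward-scattering peak p(0) to g, via autodiff
dpeak_dg = jax.grad(lambda g: henyey_greenstein(1.0, g))(0.5)
print(norm, dpeak_dg)
```

At cos θ = 1 the formula reduces to (1 + g) / [4π (1 - g)²], whose derivative at g = 0.5 is 7/π, which gives an analytic check on the autodiff result.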

SLD_utils.py

  • Shared utilities:

    • Image preprocessing

    • PSF handling

    • Data normalization

  • Used by both:

    • Objective functions

    • Optimizer logic
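PSF handling and normalization are the kind of helper logic that fits a shared utilities module. The sketch below normalizes a PSF to unit sum and convolves a model image with it using `jax.scipy.signal.convolve2d`; the function names, the 7x7 Gaussian PSF, and the point-source image are hypothetical, and SLD_utils.py may implement PSF handling differently (for example with FFTs or an empirical PSF).

```python
import jax.numpy as jnp
from jax.scipy.signal import convolve2d

def normalize_psf(psf):
    """Scale a PSF so it sums to 1, so convolution conserves total flux."""
    return psf / jnp.sum(psf)

def convolve_with_psf(image, psf):
    """Convolve a model image with a normalized PSF, keeping the image shape."""
    return convolve2d(image, normalize_psf(psf), mode="same")

# Hypothetical 7x7 Gaussian PSF (sigma = 1.5 px); an empirical PSF would be loaded instead
yy, xx = jnp.mgrid[-3:4, -3:4]
psf = jnp.exp(-0.5 * (xx**2 + yy**2) / 1.5**2)

# A point source in the middle of a 32x32 frame
image = jnp.zeros((32, 32)).at[16, 16].set(10.0)
blurred = convolve_with_psf(image, psf)
print(jnp.sum(image), jnp.sum(blurred))  # totals match while the PSF stays inside the frame
```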

SLD_ojax.py

  • Low-level JAX disk model definitions

  • Performance-critical code
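A common pattern in performance-critical JAX code is vectorizing a single-image model over many parameter sets with `jax.vmap`, for example to evaluate all MCMC walkers in one compiled call. The toy ring model and the walker count here are hypothetical; this illustrates the JAX technique, not the contents of SLD_ojax.py.

```python
import jax
import jax.numpy as jnp

NPIX = 48
yy, xx = jnp.mgrid[0:NPIX, 0:NPIX]
rr = jnp.hypot(xx - NPIX / 2, yy - NPIX / 2)

def ring_image(params):
    """Hypothetical toy forward model: a Gaussian ring with (flux, radius, width)."""
    flux, r0, width = params
    return flux * jnp.exp(-0.5 * ((rr - r0) / width) ** 2)

# vmap maps the model over a leading batch axis; jit compiles the batched version once
batched_ring = jax.jit(jax.vmap(ring_image))

# e.g. 32 MCMC walkers, each with its own (flux, radius, width)
key = jax.random.PRNGKey(1)
walkers = jnp.array([1.0, 15.0, 2.0]) + 0.1 * jax.random.normal(key, (32, 3))
images = batched_ring(walkers)
print(images.shape)  # (32, 48, 48)
```

The model is written once for a single parameter set; batching is added from the outside, which keeps the physics code simple.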


5. JAX JIT Compilation Strategy

One of the key design goals of GRaTeR-JAX is performance without sacrificing clarity.

  • Disk models are written as pure JAX functions

  • Objective functions wrap those models

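  • The optimizer JIT-compiles these functions once, then reuses the compiled kernels on later calls (JAX recompiles only when input shapes, dtypes, or static arguments change)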

This results in:

  • Substantial speedups over legacy NumPy implementations (see the Speed Comparisons tutorial)

  • Stable gradients for both optimization and MCMC

  • Identical code paths for CPU and GPU execution

Note: the autodiff gradient functions are JIT-compiled by default for better runtime performance.
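The compile-once, reuse-many behavior can be observed directly by counting how often JAX traces a function. The Python counter below is a deliberate side effect used only for this demonstration (it runs during tracing, not on every call), which is exactly why real model code should stay pure; the model itself is a hypothetical stand-in.

```python
import jax
import jax.numpy as jnp

trace_count = 0

def model(params, grid):
    global trace_count
    trace_count += 1  # Python side effect: executes only while JAX traces the function
    return params[0] * jnp.exp(-grid**2 / params[1])

fast_model = jax.jit(model)

grid = jnp.linspace(-1.0, 1.0, 256)
fast_model(jnp.array([1.0, 0.5]), grid)  # first call: trace + compile
fast_model(jnp.array([2.0, 0.3]), grid)  # new values, same shapes: cached kernel reused
after_same_shape = trace_count
print(after_same_shape)  # 1

fast_model(jnp.array([1.0, 0.5]), jnp.linspace(-1.0, 1.0, 512))  # new shape: retrace
print(trace_count)  # 2
```

In a fitting loop the parameter values change on every call while the image shape stays fixed, so the compilation cost is paid once.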


6. Where to Go Next

  • Tutorials

    • Basic disk modeling and benchmarking.

  • How-Tos

    • Model real debris disk data using autodiff

    • Perform MCMC fitting and SPF injection-recovery tests

  • API reference

    • See module-level documentation in the optimization and disk_model subpackages

This page is intentionally conceptual. Once this mental model is clear, the individual APIs and classes should feel significantly easier to navigate.