grater_jax.optimization.mcmc_model

mcmc_model.py

MCMC utilities for model fitting.

This module defines the MCMC_model class, a wrapper around the emcee ensemble sampler with additional tools for running MCMC inference, analyzing posterior distributions, and visualizing results. It provides built-in support for HDF5 backends, custom priors, autocorrelation analysis, and plotting.

Main features

  • MCMC_model.run : Run burn-in and production MCMC chains with emcee.

  • MCMC_model.get_theta_median, get_theta_percs, get_theta_max : Extract parameter estimates from posterior samples.

  • MCMC_model.show_corner_plot : Generate corner plots of posterior samples.

  • MCMC_model.plot_chains : Visualize walker chains for diagnostics.

  • MCMC_model.auto_corr : Estimate autocorrelation times.

Classes

MCMC_model(fun, theta_bounds, name)

Markov Chain Monte Carlo (MCMC) wrapper around the emcee sampler, with tools for inference, diagnostics, and visualization.

class grater_jax.optimization.mcmc_model.MCMC_model(fun, theta_bounds, name)

Bases: object

Markov Chain Monte Carlo (MCMC) wrapper around the emcee sampler, with tools for inference, diagnostics, and visualization.

Parameters:
  • fun (callable) – The log-likelihood function to evaluate.

  • theta_bounds (tuple of np.ndarray) – Parameter bounds, given as a tuple (lower_bounds, upper_bounds).

  • name (str) – Name of the model (used for the HDF5 backend filename).
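The sketch below shows one way to build the fun and theta_bounds arguments for a simple straight-line fit. The synthetic data, the names log_likelihood and log_prob, and the flat-prior check are illustrative assumptions. This documentation does not say whether MCMC_model itself enforces theta_bounds as a prior, so log_prob only shows what such a check could look like.

```python
import numpy as np

# Synthetic data for a straight-line model y = m*x + b (illustrative only).
rng = np.random.default_rng(0)
x = np.linspace(0.0, 10.0, 50)
sigma = 0.5
y = 2.0 * x + 1.0 + sigma * rng.standard_normal(x.size)

# Bounds in the (lower_bounds, upper_bounds) form described for theta_bounds.
theta_bounds = (np.array([-10.0, -10.0]), np.array([10.0, 10.0]))


def log_likelihood(theta):
    """Gaussian log-likelihood (up to a constant) for theta = (m, b)."""
    m, b = theta
    resid = y - (m * x + b)
    return -0.5 * np.sum((resid / sigma) ** 2)


def log_prob(theta):
    """Hypothetical flat prior over theta_bounds plus the likelihood."""
    lo, hi = theta_bounds
    if np.any(theta < lo) or np.any(theta > hi):
        return -np.inf  # outside the bounds: zero prior probability
    return log_likelihood(theta)
```

Under these assumptions the model would be constructed as MCMC_model(log_likelihood, theta_bounds, "line_fit"), where "line_fit" is a hypothetical name that would also determine the backend filename.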

run(initial, nwalkers=500, niter=500, burn_iter=100, nconst=1e-07, resume=False, confirm_overwrite=True, **kwargs)

Run MCMC sampling using emcee.

Parameters:
  • initial (np.ndarray) – Initial guess for parameters.

  • nwalkers (int) – Number of MCMC walkers.

  • niter (int) – Number of production iterations.

  • burn_iter (int) – Number of burn-in iterations.

  • nconst (float) – Perturbation constant for initializing walkers.

  • resume (bool, optional) – If True, continue sampling from where the existing HDF5 backend left off. No burn-in is run and the chain is extended by niter steps. If False (default), start a fresh run with new burn-in, overwriting the existing backend. See also confirm_overwrite.

  • confirm_overwrite (bool, optional) – Only relevant when resume=False. If True (default), prompt the user interactively before overwriting the existing backend, which is the safe choice for interactive use. Set to False to overwrite silently without prompting; this is required in scripted or notebook contexts where stdin is not available.

Returns:

(sampler, chain, log-probs, random state)

Return type:

tuple
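The role of nconst can be illustrated with the common emcee idiom of scattering walkers in a small Gaussian ball around the initial guess. This is a minimal sketch assuming nconst multiplies standard-normal offsets. The helper init_walkers is hypothetical and not part of the documented API. In emcee, an array like this is typically passed as the initial state to EnsembleSampler.run_mcmc for burn-in, after which the sampler is reset and production sampling continues from the final burn-in state.

```python
import numpy as np


def init_walkers(initial, nwalkers=500, nconst=1e-7, seed=None):
    """Hypothetical helper: starting positions of shape (nwalkers, ndim).

    Assumes each walker is `initial` plus nconst times a standard-normal draw.
    """
    initial = np.asarray(initial, dtype=float)
    rng = np.random.default_rng(seed)
    offsets = rng.standard_normal((nwalkers, initial.size))
    return initial + nconst * offsets


# Example: 500 walkers tightly clustered around a three-parameter guess.
p0 = init_walkers([1.0, 2.0, 3.0], nwalkers=500, nconst=1e-7, seed=0)
```

A tiny nconst keeps every walker in the neighbourhood of initial while still giving them distinct positions, which the ensemble moves need in order to explore the space.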

set_discarded_iters(new_discard_iters)

Set the number of burn-in iterations to discard.

Parameters:

new_discard_iters (int) – Number of iterations to discard (must be >= 0).
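Discarding burn-in amounts to dropping the first iterations of every walker before the samples are pooled. The sketch below assumes the chain uses emcee's (nsteps, nwalkers, ndim) layout and reproduces what emcee's get_chain(discard=..., flat=True) returns when thin=1. flat_samples is a hypothetical helper; the estimators documented below presumably operate on samples flattened in a similar way.

```python
import numpy as np


def flat_samples(chain, discard=0):
    """Hypothetical: drop the first `discard` steps, then pool all walkers.

    `chain` has shape (nsteps, nwalkers, ndim); the result has shape
    ((nsteps - discard) * nwalkers, ndim), matching emcee's
    get_chain(discard=discard, flat=True) with thin=1.
    """
    if discard < 0:
        raise ValueError("discard must be >= 0")
    chain = np.asarray(chain)
    return chain[discard:].reshape(-1, chain.shape[-1])


# Example: 10 steps, 4 walkers, 2 parameters; discard 3 steps of burn-in.
chain = np.arange(10 * 4 * 2).reshape(10, 4, 2)
flat = flat_samples(chain, discard=3)
```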

get_theta_median(scaled=False)

Return median of posterior samples.

Parameters:

scaled (bool, optional) – If True, use scaled parameter samples.

Returns:

Median parameter values.

Return type:

np.ndarray
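The median is taken per parameter over the flattened samples. A useful property, sketched below, is that the median commutes with monotone transforms. Assuming scaled samples are a logarithm of the physical ones (the show_corner_plot description calls them log-transformed), the scaled and unscaled medians should agree up to that transform, whereas means would not. The lognormal stand-in samples and the natural-log scaling are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(42)
# Stand-in posterior: 1001 flattened samples of 2 positive parameters (odd count,
# so each median is a single sample rather than an average of two).
samples = rng.lognormal(mean=0.0, sigma=1.0, size=(1001, 2))
scaled = np.log(samples)  # assumed scaling: natural logarithm

median_physical = np.median(samples, axis=0)
median_scaled = np.median(scaled, axis=0)

# The median maps through the transform exactly; the mean does not.
print(np.allclose(np.log(median_physical), median_scaled))  # True
print(np.allclose(np.log(samples.mean(axis=0)), scaled.mean(axis=0)))  # False
```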

get_theta_percs(scaled=False)

Return 16th, 50th, and 84th percentiles of posterior samples.

Parameters:

scaled (bool, optional) – If True, use scaled parameter samples.

Returns:

Array of shape (3, ndim) with 16th, 50th, 84th percentiles.

Return type:

np.ndarray
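For a roughly Gaussian marginal, the 16th and 84th percentiles sit about one standard deviation either side of the median, which is why these three are the usual choice for quoting asymmetric error bars. The sketch below computes the same (3, ndim) array with np.percentile and splits it into the median and the lower and upper errors. theta_percs and summarize are hypothetical names, and the Gaussian stand-in samples are illustrative.

```python
import numpy as np


def theta_percs(flat):
    """16th, 50th and 84th percentiles per parameter; shape (3, ndim)."""
    return np.percentile(flat, [16, 50, 84], axis=0)


def summarize(percs):
    """Median plus lower/upper errors from a (3, ndim) percentile array."""
    lo, med, hi = percs
    return med, med - lo, hi - med


# Stand-in posterior: two Gaussian parameters with sigma = 1 and sigma = 2.
rng = np.random.default_rng(0)
flat = rng.normal(loc=[0.0, 5.0], scale=[1.0, 2.0], size=(200_000, 2))
percs = theta_percs(flat)
med, err_minus, err_plus = summarize(percs)
```

Results are then typically quoted as med (+err_plus, -err_minus) for each parameter.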

get_theta_max(scaled=False)

Return the parameter values of the sample with the maximum log-probability.

Parameters:

scaled (bool, optional) – If True, return scaled values.

Returns:

Parameter values of the sample with the highest log-probability.

Return type:

np.ndarray
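Selecting the maximum-log-probability sample only requires that samples and log-probabilities be flattened in the same order. This holds for emcee's get_chain(flat=True) and get_log_prob(flat=True) when both are called with the same discard and thin. The helper theta_max below is a hypothetical sketch of that lookup. Note that it returns the best sample visited, which is not the same as an optimized maximum a posteriori estimate.

```python
import numpy as np


def theta_max(flat_samples, flat_log_prob):
    """Return the row of flat_samples whose log-probability is largest."""
    flat_samples = np.asarray(flat_samples)
    flat_log_prob = np.asarray(flat_log_prob)
    if flat_samples.shape[0] != flat_log_prob.shape[0]:
        raise ValueError("samples and log-probabilities must have the same length")
    return flat_samples[np.argmax(flat_log_prob)]


# Toy example: unit-Gaussian log-probability centred on (1, -2).
rng = np.random.default_rng(3)
samples = rng.uniform(-5.0, 5.0, size=(5000, 2))
log_prob = -0.5 * np.sum((samples - np.array([1.0, -2.0])) ** 2, axis=1)
best = theta_max(samples, log_prob)
```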

show_corner_plot(labels, truths=None, show_titles=True, plot_datapoints=True, quantiles=[0.16, 0.5, 0.84], quiet=False, scaled=False, range='auto', range_percentiles=(0.5, 99.5), range_pad=0.05)

Generate corner plot using posterior samples.

By default, each axis range is clipped to a robust percentile window (set by range_percentiles) and padded symmetrically by range_pad, so that a few straggling walkers cannot stretch the axes and hide the core of the posterior. Pass range=None to restore corner’s default behavior (the full extent of the samples), or pass an explicit list of (lo, hi) tuples to override the window.

Parameters:
  • labels (list of str) – Parameter names for plot axes.

  • truths (array-like, optional) – True values to overlay as lines.

  • show_titles (bool, optional) – Show per-parameter title with summary stats.

  • plot_datapoints (bool, optional) – Draw individual sample points.

  • quantiles (list of float, optional) – Quantiles to mark on each 1D panel.

  • quiet (bool, optional) – Suppress corner’s stdout output.

  • scaled (bool, optional) – Use scaled (log-transformed) samples instead of physical values.

  • range ({'auto', None} or list of tuple, optional) – Per-axis range passed to corner.corner. 'auto' (default) uses the robust percentile window described above; None disables cropping; a list of (lo, hi) tuples is passed through.

  • range_percentiles (tuple of float, optional) – Low/high percentiles (0–100) used when range='auto'. Defaults to (0.5, 99.5), which trims 0.5% from each tail (about 1% of the samples per axis).

  • range_pad (float, optional) – Fractional padding applied symmetrically to the auto window (default 0.05, i.e. 5%).

Returns:

The corner plot figure.

Return type:

matplotlib.figure.Figure
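The 'auto' window described above can be reproduced with per-axis percentiles plus padding. This sketch assumes range_pad is a fraction of the percentile window's width, added to each side. auto_corner_range is a hypothetical helper; its output, a list of (lo, hi) tuples, has the form corner.corner accepts for its range argument.

```python
import numpy as np


def auto_corner_range(samples, range_percentiles=(0.5, 99.5), range_pad=0.05):
    """Hypothetical: robust (lo, hi) per parameter for corner.corner's range.

    Assumes padding = range_pad * (hi - lo) on each side of the percentile window.
    """
    samples = np.asarray(samples, dtype=float)
    lo, hi = np.percentile(samples, range_percentiles, axis=0)
    # Guard against a zero-width window (e.g. a parameter that never moved).
    width = np.maximum(hi - lo, 1e-12)
    pad = range_pad * width
    return [(float(a - p), float(b + p)) for a, b, p in zip(lo, hi, pad)]


# A single wild outlier barely affects the window.
rng = np.random.default_rng(0)
x = np.concatenate([rng.standard_normal(1000), [1e6]])[:, None]
window = auto_corner_range(x)
```

With the default full-extent range, the 1e6 outlier would squeeze the bulk of these samples into a sliver of the axis.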

plot_chains(labels, cols_per_row=3, scaled=False)

Plot individual walker chains for each parameter.

Parameters:
  • labels (list of str) – Names of parameters.

  • cols_per_row (int, optional) – Number of subplot columns per row.

  • scaled (bool, optional) – If True, use scaled parameter samples.
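The cols_per_row option implies a grid in which parameters fill the rows left to right. The sketch below computes such a layout under that assumption (row-major filling, with the last row possibly only partly used). chain_grid is a hypothetical helper, and the actual figure construction in plot_chains may differ.

```python
import math


def chain_grid(ndim, cols_per_row=3):
    """Hypothetical layout for per-parameter chain panels.

    Returns (nrows, ncols, positions), where positions[i] is the (row, col)
    of parameter i, filling rows left to right.
    """
    if cols_per_row < 1:
        raise ValueError("cols_per_row must be >= 1")
    ncols = min(cols_per_row, ndim)
    nrows = math.ceil(ndim / cols_per_row)
    positions = [divmod(i, cols_per_row) for i in range(ndim)]
    return nrows, ncols, positions


# Seven parameters with three columns: rows of 3, 3 and 1 panels.
nrows, ncols, positions = chain_grid(7, cols_per_row=3)
```

With matplotlib, nrows and ncols would be passed to plt.subplots, every walker's trace for parameter i drawn on the panel at positions[i], and any unused panels in the last row hidden.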

auto_corr(chain_length=50)

Estimate the integrated autocorrelation time using Sokal’s method.

Parameters:

chain_length (int) – Number of subchain lengths to evaluate.

Returns:

Autocorrelation estimates for different subchain lengths.

Return type:

np.ndarray
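Sokal's approach estimates the integrated autocorrelation time as tau(M) = 1 + 2 * sum_{k=1}^{M} rho(k), where rho is the normalized autocorrelation function. The window M is chosen automatically as the smallest M with M >= c * tau(M), with c = 5 as the usual choice. The sketch below follows the estimator from emcee's autocorrelation-analysis tutorial, which averages the per-walker autocorrelation functions before summing, and evaluates it on up to chain_length log-spaced subchain lengths. The helper names, the min_steps default, and the log spacing are assumptions rather than the documented behaviour of auto_corr.

```python
import numpy as np


def autocorr_func_1d(x):
    """Normalized autocorrelation function of a 1D series, computed via FFT."""
    x = np.asarray(x, dtype=float)
    n = len(x)
    nfft = 1 << (2 * n - 1).bit_length()  # zero-pad so the FFT does not wrap around
    f = np.fft.rfft(x - x.mean(), n=nfft)
    acf = np.fft.irfft(f * np.conjugate(f), n=nfft)[:n]
    return acf / acf[0]


def auto_window(taus, c=5.0):
    """Sokal's window: first M with M >= c * tau(M), else the last lag."""
    m = np.arange(len(taus)) < c * taus
    if np.all(m):
        return len(taus) - 1  # window never closed: chain too short to trust
    return int(np.argmin(m))  # index of the first False


def integrated_time(y, c=5.0):
    """Estimate tau for one parameter; y has shape (nwalkers, nsteps)."""
    f = np.mean([autocorr_func_1d(walker) for walker in y], axis=0)
    taus = 2.0 * np.cumsum(f) - 1.0  # tau(M) = 1 + 2 * sum_{k=1}^{M} rho(k)
    return taus[auto_window(taus, c)]


def auto_corr_curve(chain, chain_length=50, min_steps=100, param=0, c=5.0):
    """Hypothetical: tau estimates on up to `chain_length` log-spaced subchains.

    `chain` has shape (nsteps, nwalkers, ndim). Returns (lengths, taus).
    """
    chain = np.asarray(chain, dtype=float)
    nsteps = chain.shape[0]
    start = min(min_steps, nsteps)
    lengths = np.unique(np.rint(np.geomspace(start, nsteps, chain_length)).astype(int))
    y = chain[:, :, param].T  # (nwalkers, nsteps)
    taus = np.array([integrated_time(y[:, :n], c) for n in lengths])
    return lengths, taus


# Synthetic AR(1) chain whose true tau is (1 + phi) / (1 - phi) = 19.
rng = np.random.default_rng(1)
nsteps, nwalkers, phi = 4000, 32, 0.9
noise = rng.standard_normal((nsteps, nwalkers))
chain = np.zeros((nsteps, nwalkers, 1))
for t in range(1, nsteps):
    chain[t, :, 0] = phi * chain[t - 1, :, 0] + noise[t]
lengths, taus = auto_corr_curve(chain, chain_length=10)
```

Plotting taus against lengths shows whether the estimate has levelled off. For comparison, emcee's own integrated_time uses c=5 and tol=50 by default, flagging estimates from chains shorter than 50 autocorrelation times as unreliable.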