grater_jax.optimization.mcmc_model
mcmc_model.py
MCMC utilities for model fitting.
This module defines the MCMC_model class, a wrapper around the emcee ensemble sampler with additional tools for running MCMC inference, analyzing posterior distributions, and visualizing results. It provides built-in support for HDF5 backends, custom priors, autocorrelation analysis, and plotting.
Main features
MCMC_model.run : Run burn-in and production MCMC chains with emcee.
MCMC_model.get_theta_median, get_theta_percs, get_theta_max : Extract parameter estimates from posterior samples.
MCMC_model.show_corner_plot : Generate corner plots of posterior samples.
MCMC_model.plot_chains : Visualize walker chains for diagnostics.
MCMC_model.auto_corr : Estimate autocorrelation times.
Classes
MCMC_model : Markov Chain Monte Carlo (MCMC) wrapper for the emcee sampler with tools for inference, diagnostics, and visualization.
- class grater_jax.optimization.mcmc_model.MCMC_model(fun, theta_bounds, name)
Bases: object
Markov Chain Monte Carlo (MCMC) wrapper for the emcee sampler with tools for inference, diagnostics, and visualization.
- Parameters:
fun (callable) – The log-likelihood function to evaluate.
theta_bounds (tuple of np.ndarray) – Tuple (lower_bounds, upper_bounds) for parameters.
name (str) – Name of the model (used for backend filename).
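The constructor only needs a log-likelihood callable and a pair of bound arrays. The sketch below shows what these inputs might look like for a toy two-parameter line fit. The Gaussian likelihood, the toy data, and the helpers in_bounds and log_prob are illustrative assumptions. The source does not document how MCMC_model combines fun with theta_bounds internally; treating the bounds as a flat prior is one plausible reading.

```python
import numpy as np

# Toy data for a straight line y = m*x + b with Gaussian noise (illustrative only).
rng = np.random.default_rng(0)
x = np.linspace(0.0, 10.0, 50)
sigma = 0.5
y = 2.0 * x + 1.0 + rng.normal(0.0, sigma, x.size)

def log_likelihood(theta):
    """Gaussian log-likelihood for theta = (m, b), up to an additive constant."""
    m, b = theta
    resid = y - (m * x + b)
    return -0.5 * np.sum((resid / sigma) ** 2)

# theta_bounds as a (lower_bounds, upper_bounds) tuple, one entry per parameter.
theta_bounds = (np.array([0.0, -5.0]), np.array([5.0, 5.0]))

def in_bounds(theta, bounds):
    """Hypothetical helper: True if every parameter lies inside its box bound."""
    lo, hi = bounds
    return bool(np.all(theta >= lo) and np.all(theta <= hi))

def log_prob(theta):
    """Assumed flat prior inside the box: -inf outside, likelihood inside."""
    if not in_bounds(theta, theta_bounds):
        return -np.inf
    return log_likelihood(theta)
```

With these inputs, construction would look something like MCMC_model(log_likelihood, theta_bounds, name="line_fit"), where the name only affects the backend filename.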
- run(initial, nwalkers=500, niter=500, burn_iter=100, nconst=1e-07, resume=False, confirm_overwrite=True, **kwargs)
Run MCMC sampling using emcee.
- Parameters:
initial (np.ndarray) – Initial guess for parameters.
nwalkers (int) – Number of MCMC walkers.
niter (int) – Number of production iterations.
burn_iter (int) – Number of burn-in iterations.
nconst (float) – Perturbation constant for initializing walkers.
resume (bool, optional) – If True, continue sampling from where the existing HDF5 backend left off. No burn-in is run and the chain is extended by niter steps. If False (default), start a fresh run with new burn-in, overwriting the existing backend. See also confirm_overwrite.
confirm_overwrite (bool, optional) – Only relevant when resume=False. If True (default), prompt the user interactively before overwriting the existing backend, which is safe for interactive use. Set to False to overwrite silently without prompting; this is required in scripted or notebook contexts where stdin is not available.
- Returns:
(sampler, chain, log-probs, random state)
- Return type:
tuple
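The docs describe nconst only as a perturbation constant for initializing walkers. A common emcee idiom, and a plausible reading of this parameter (an assumption, not confirmed by the source), is to start every walker in a tiny Gaussian ball around initial. The helper init_walkers is hypothetical.

```python
import numpy as np

def init_walkers(initial, nwalkers=500, nconst=1e-7, seed=None):
    """Hypothetical helper: scatter nwalkers starting points around `initial`.

    Assumes nconst scales an additive N(0, 1) perturbation; MCMC_model.run
    may use a different scheme (e.g. multiplicative noise).
    """
    rng = np.random.default_rng(seed)
    initial = np.asarray(initial, dtype=float)
    # Shape (nwalkers, ndim): the layout emcee expects for the initial state.
    return initial + nconst * rng.standard_normal((nwalkers, initial.size))

p0 = init_walkers([2.0, 1.0], nwalkers=500, nconst=1e-7, seed=42)
```

A small nconst keeps walkers near the starting guess while still giving them distinct positions, which the ensemble moves need in order to propose new points.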
- set_discarded_iters(new_discard_iters)
Set the number of burn-in iterations to discard.
- Parameters:
new_discard_iters (int) – Number of iterations to discard (must be >= 0).
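Discarding burn-in amounts to dropping the first steps of the stored chain before merging the walkers. The sketch below assumes the (nsteps, nwalkers, ndim) layout returned by emcee's get_chain(). The helper flatten_after_discard is hypothetical; how MCMC_model stores its discard count internally is not shown here.

```python
import numpy as np

def flatten_after_discard(chain, discard):
    """Drop the first `discard` steps and merge walkers.

    chain: array of shape (nsteps, nwalkers, ndim).
    Returns an array of shape ((nsteps - discard) * nwalkers, ndim).
    """
    if discard < 0:
        raise ValueError("discard must be >= 0")
    kept = chain[discard:]
    return kept.reshape(-1, chain.shape[-1])

chain = np.arange(10 * 4 * 3, dtype=float).reshape(10, 4, 3)
flat = flatten_after_discard(chain, discard=6)
```

This mirrors what emcee's sampler.get_chain(discard=6, flat=True) returns on a real sampler.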
- get_theta_median(scaled=False)
Return median of posterior samples.
- Parameters:
scaled (bool, optional) – If True, use scaled parameter samples.
- Returns:
Median parameter values.
- Return type:
np.ndarray
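A monotonic transform preserves the median. Taking the median of log-scaled samples and exponentiating therefore gives the same value as the median of the physical samples (exactly, for an odd sample count). The mean does not have this property. The check below assumes that scaled samples are element-wise logarithms of the physical ones, as the show_corner_plot description of scaled samples as log-transformed suggests.

```python
import numpy as np

rng = np.random.default_rng(1)
# Odd sample count so np.median returns an actual sample, not an average of two.
physical = rng.lognormal(mean=0.0, sigma=1.0, size=(1001, 2))
scaled = np.log(physical)  # assumed meaning of "scaled" samples

median_physical = np.median(physical, axis=0)
median_from_scaled = np.exp(np.median(scaled, axis=0))  # back-transformed median

mean_physical = physical.mean(axis=0)
mean_from_scaled = np.exp(scaled.mean(axis=0))  # geometric mean: differs in general
```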
- get_theta_percs(scaled=False)
Return 16th, 50th, and 84th percentiles of posterior samples.
- Parameters:
scaled (bool, optional) – If True, use scaled parameter samples.
- Returns:
Array of shape (3, ndim) with 16th, 50th, 84th percentiles.
- Return type:
np.ndarray
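The 16/50/84 summary can be reproduced with NumPy on flattened samples. The sketch below also shows the conventional conversion to a median with asymmetric ±1σ error bars. The samples array is synthetic, and the error-bar step is an illustration rather than part of the documented API.

```python
import numpy as np

rng = np.random.default_rng(2)
# Synthetic flattened posterior: 20000 draws of two parameters.
samples = rng.normal(loc=[2.0, 1.0], scale=[0.1, 0.3], size=(20000, 2))

# Shape (3, ndim): rows are the 16th, 50th and 84th percentiles.
percs = np.percentile(samples, [16, 50, 84], axis=0)

median = percs[1]
err_minus = percs[1] - percs[0]  # distance down to the 16th percentile
err_plus = percs[2] - percs[1]   # distance up to the 84th percentile
for m, lo, hi in zip(median, err_minus, err_plus):
    print(f"{m:.3f} -{lo:.3f} +{hi:.3f}")
```

For a Gaussian posterior the 16th and 84th percentiles sit close to ±1σ. For skewed posteriors the two error bars differ, which a single standard deviation would hide.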
- get_theta_max(scaled=False)
Return parameter values corresponding to maximum log-probability.
- Parameters:
scaled (bool, optional) – If True, return scaled values.
- Returns:
Parameter values with maximum log probability.
- Return type:
np.ndarray
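Picking the maximum-probability sample is an argmax over the stored log-probabilities. The sketch assumes emcee's layouts: get_chain() returns (nsteps, nwalkers, ndim) and get_log_prob() returns (nsteps, nwalkers). The helper theta_at_max_log_prob is hypothetical, and this reference does not state whether get_theta_max discards burn-in first.

```python
import numpy as np

def theta_at_max_log_prob(chain, log_prob):
    """Return the parameter vector with the highest log-probability.

    chain: (nsteps, nwalkers, ndim); log_prob: (nsteps, nwalkers).
    """
    flat_chain = chain.reshape(-1, chain.shape[-1])
    flat_lp = log_prob.reshape(-1)
    return flat_chain[np.argmax(flat_lp)]

rng = np.random.default_rng(3)
chain = rng.normal(size=(100, 8, 2))
log_prob = -0.5 * np.sum(chain ** 2, axis=-1)  # toy standard-normal log-density
best = theta_at_max_log_prob(chain, log_prob)
```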
- show_corner_plot(labels, truths=None, show_titles=True, plot_datapoints=True, quantiles=[0.16, 0.5, 0.84], quiet=False, scaled=False, range='auto', range_percentiles=(0.5, 99.5), range_pad=0.05)
Generate corner plot using posterior samples.
By default the per-axis range is clipped to a robust percentile window (range_percentiles) with range_pad symmetric padding, so a few straggling walkers cannot stretch the axes and hide the core of the posterior. Pass range=None to restore corner’s default (full extent), or pass an explicit list of (lo, hi) tuples to override.
- Parameters:
labels (list of str) – Parameter names for plot axes.
truths (array-like, optional) – True values to overlay as lines.
show_titles (bool, optional) – Show per-parameter title with summary stats.
plot_datapoints (bool, optional) – Draw individual sample points.
quantiles (list of float, optional) – Quantiles to mark on each 1D panel.
quiet (bool, optional) – Suppress corner’s stdout output.
scaled (bool, optional) – Use scaled (log-transformed) samples instead of physical values.
range ({'auto', None} or list of tuple, optional) – Per-axis range passed to corner.corner. 'auto' (default) uses the robust percentile window described above; None disables cropping; a list of (lo, hi) tuples is passed through unchanged.
range_percentiles (tuple of float, optional) – Low/high percentiles (0–100) used when range='auto'. Defaults to (0.5, 99.5), which trims roughly 1% of outlier tails.
range_pad (float, optional) – Fractional padding applied to the auto window (default 5%).
- Returns:
The corner plot figure.
- Return type:
matplotlib.figure.Figure
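The range='auto' behaviour can be sketched as a per-parameter percentile window widened on each side by range_pad times its width. The exact padding formula is an assumption, and robust_range is a hypothetical helper. The output, a list of (lo, hi) tuples, is a form that corner.corner accepts for its range argument.

```python
import numpy as np

def robust_range(samples, range_percentiles=(0.5, 99.5), range_pad=0.05):
    """Per-axis (lo, hi) window from percentiles, padded symmetrically.

    samples: (nsamples, ndim). Padding is range_pad times the window width
    on each side (assumed convention).
    """
    lo, hi = np.percentile(samples, range_percentiles, axis=0)
    width = hi - lo
    # Guard against a zero-width window (e.g. a fixed parameter) with a unit width.
    width = np.where(width > 0, width, 1.0)
    pad = range_pad * width
    return [(float(a), float(b)) for a, b in zip(lo - pad, hi + pad)]

rng = np.random.default_rng(4)
samples = rng.normal(size=(10000, 2))
samples[:5] = 1e3  # a few straggling walkers far from the bulk
ranges = robust_range(samples)
```

Passing ranges to corner.corner(samples, range=ranges) keeps the axes on the bulk of the posterior. Without the crop, the five outliers at 1000 would squeeze the core into a single bin.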
- plot_chains(labels, cols_per_row=3, scaled=False)
Plot individual walker chains for each parameter.
- Parameters:
labels (list of str) – Names of parameters.
cols_per_row (int, optional) – Number of subplot columns per row.
scaled (bool, optional) – If True, use scaled parameter samples.
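A walker trace plot draws one panel per parameter and one line per walker, laid out in a grid with cols_per_row columns. The sketch below uses matplotlib with a headless backend and assumes the (nsteps, nwalkers, ndim) chain layout. plot_walker_chains is a hypothetical stand-in, not the actual implementation of plot_chains.

```python
import math

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

def plot_walker_chains(chain, labels, cols_per_row=3):
    """Trace plot: one panel per parameter, one line per walker.

    chain: (nsteps, nwalkers, ndim).
    """
    ndim = chain.shape[-1]
    nrows = math.ceil(ndim / cols_per_row)
    fig, axes = plt.subplots(nrows, cols_per_row,
                             figsize=(4 * cols_per_row, 2.5 * nrows),
                             squeeze=False)
    for i, ax in enumerate(axes.flat):
        if i >= ndim:
            ax.set_visible(False)  # hide unused grid cells
            continue
        # A 2D array plots each column (walker) as its own line.
        ax.plot(chain[:, :, i], color="k", alpha=0.2, lw=0.5)
        ax.set_title(labels[i])
        ax.set_xlabel("step")
    fig.tight_layout()
    return fig

rng = np.random.default_rng(6)
chain = rng.normal(size=(200, 16, 4))
fig = plot_walker_chains(chain, ["a", "b", "c", "d"], cols_per_row=3)
```

Well-mixed chains look like overlapping noise bands. Walkers that drift or stay stuck away from the bulk indicate that more burn-in is needed.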
- auto_corr(chain_length=50)
Estimate autocorrelation length using Sokal’s method.
- Parameters:
chain_length (int) – Number of subchain lengths to evaluate.
- Returns:
Autocorrelation estimates for different subchain lengths.
- Return type:
np.ndarray
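Sokal's method estimates the integrated autocorrelation time τ = 1 + 2 Σ ρ(k). The sum over the normalized autocorrelation function ρ runs up to a self-consistent window, the smallest M with M ≥ c·τ(M), where c ≈ 5. The sketch below follows the estimator described in the emcee documentation's autocorrelation tutorial. The helper tau_vs_length and its reading of chain_length as the number of increasing subchain lengths to evaluate are assumptions about what auto_corr does.

```python
import numpy as np

def autocorr_func_1d(x):
    """Normalized autocorrelation function of a 1D series via FFT."""
    x = np.asarray(x, dtype=float)
    n = len(x)
    # Zero-pad to a power of two >= 2n - 1 to avoid circular wrap-around.
    nfft = 1 << (2 * n - 1).bit_length()
    f = np.fft.rfft(x - x.mean(), n=nfft)
    acf = np.fft.irfft(f * np.conjugate(f), n=nfft)[:n]
    return acf / acf[0]

def auto_window(taus, c):
    """Sokal's automatic window: first M with M >= c * tau(M)."""
    m = np.arange(len(taus)) < c * taus
    if np.all(m):
        return len(taus) - 1  # chain too short: the window never closes
    return int(np.argmin(m))

def integrated_autocorr_time(y, c=5.0):
    """Sokal estimate of tau for y of shape (nwalkers, nsteps), ACF averaged over walkers."""
    f = np.mean([autocorr_func_1d(yy) for yy in y], axis=0)
    taus = 2.0 * np.cumsum(f) - 1.0
    return float(taus[auto_window(taus, c)])

def tau_vs_length(y, chain_length=50, min_len=100, c=5.0):
    """Hypothetical: evaluate tau on chain_length increasing prefixes of the chain."""
    nsteps = y.shape[1]
    lengths = np.unique(np.linspace(min_len, nsteps, chain_length).astype(int))
    return lengths, np.array([integrated_autocorr_time(y[:, :n], c) for n in lengths])

# AR(1) test chains: the theoretical tau is (1 + phi) / (1 - phi) = 19.
rng = np.random.default_rng(5)
phi, nwalkers, nsteps = 0.9, 32, 4000
y = np.empty((nwalkers, nsteps))
y[:, 0] = rng.standard_normal(nwalkers)
for t in range(1, nsteps):
    y[:, t] = phi * y[:, t - 1] + rng.standard_normal(nwalkers)

tau = integrated_autocorr_time(y)
lengths, taus = tau_vs_length(y, chain_length=10)
```

If the estimates stabilize as the subchain length grows, and the chain is many times longer than τ, the estimate is trustworthy. On a real emcee sampler, the same estimator is exposed as sampler.get_autocorr_time().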