grater_jax.optimization.mcmc_model
mcmc_model.py
MCMC utilities for model fitting.
This module defines the MCMC_model class, a wrapper around the emcee ensemble sampler with additional tools for running MCMC inference, analyzing posterior distributions, and visualizing results. It provides built-in support for HDF5 backends, custom priors, autocorrelation analysis, and plotting.
Main features
MCMC_model.run : Run burn-in and production MCMC chains with emcee.
MCMC_model.get_theta_median, get_theta_percs, get_theta_max : Extract parameter estimates from posterior samples.
MCMC_model.show_corner_plot : Generate corner plots of posterior samples.
MCMC_model.plot_chains : Visualize walker chains for diagnostics.
MCMC_model.auto_corr : Estimate autocorrelation times.
Classes
MCMC_model – Markov Chain Monte Carlo (MCMC) wrapper for the emcee sampler with tools for inference, diagnostics, and visualization.
- class grater_jax.optimization.mcmc_model.MCMC_model(fun, theta_bounds, name)
Bases: object
Markov Chain Monte Carlo (MCMC) wrapper for the emcee sampler with tools for inference, diagnostics, and visualization.
- Parameters:
fun (callable) – The log-likelihood function to evaluate.
theta_bounds (tuple of np.ndarray) – Tuple (lower_bounds, upper_bounds) for parameters.
name (str) – Name of the model (used for backend filename).
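The constructor pairs a log-likelihood `fun` with `theta_bounds`. One common way to combine the two for emcee is a flat (uniform) prior that returns `-inf` outside the bounds. The sketch below assumes the bounds are used roughly like this. `make_log_prob` and `gaussian_loglike` are hypothetical names for illustration, not part of grater_jax.

```python
import numpy as np


def make_log_prob(fun, theta_bounds):
    """Hypothetical helper: wrap a log-likelihood with a flat prior inside bounds."""
    lower, upper = (np.asarray(b, dtype=float) for b in theta_bounds)

    def log_prob(theta):
        theta = np.asarray(theta, dtype=float)
        # Outside the box the uniform prior is zero, so the log-probability is -inf.
        if np.any(theta < lower) or np.any(theta > upper):
            return -np.inf
        # Inside the box a flat prior only adds a constant, so return the likelihood.
        return float(fun(theta))

    return log_prob


def gaussian_loglike(theta):
    """Toy log-likelihood: unit Gaussian centred on (1, -2)."""
    return -0.5 * np.sum((theta - np.array([1.0, -2.0])) ** 2)


bounds = (np.array([-5.0, -5.0]), np.array([5.0, 5.0]))
log_prob = make_log_prob(gaussian_loglike, bounds)
print(log_prob([0.0, -2.0]))  # -0.5
print(log_prob([6.0, 0.0]))   # -inf
```

Returning `-inf` is the conventional way to tell emcee that a proposal has zero probability: the move is always rejected, so walkers never leave the bounded region.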
- run(initial, nwalkers=500, niter=500, burn_iter=100, nconst=1e-07, resume=False, confirm_overwrite=True, **kwargs)
Run MCMC sampling using emcee.
- Parameters:
initial (np.ndarray) – Initial guess for parameters.
nwalkers (int) – Number of MCMC walkers.
niter (int) – Number of production iterations.
burn_iter (int) – Number of burn-in iterations.
nconst (float) – Perturbation constant for initializing walkers.
resume (bool, optional) – If True, continue sampling from where the existing HDF5 backend left off. No burn-in is run and the chain is extended by niter steps. If False (default), start a fresh run with new burn-in, overwriting the existing backend. See also confirm_overwrite.
confirm_overwrite (bool, optional) – Only relevant when resume=False. If True (default), prompt the user interactively before overwriting the existing backend, which is safe for interactive use. Set to False to overwrite silently, which is required in scripted or notebook contexts where stdin is not available.
- Returns:
(sampler, chain, log-probs, random state)
- Return type:
tuple
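emcee needs a distinct starting position for every walker. The `nconst` parameter is documented only as a perturbation constant; the sketch below assumes it is the width of a small Gaussian ball around `initial`, a common emcee idiom. `init_walkers` is a hypothetical helper, not the package's code.

```python
import numpy as np


def init_walkers(initial, nwalkers=500, nconst=1e-7, seed=None):
    """Hypothetical helper: scatter walkers in a small Gaussian ball around `initial`.

    Assumes `nconst` is an absolute standard deviation added to every parameter.
    """
    initial = np.asarray(initial, dtype=float)
    rng = np.random.default_rng(seed)
    # Shape (nwalkers, ndim): the layout emcee expects for initial positions.
    return initial + nconst * rng.standard_normal((nwalkers, initial.size))


p0 = init_walkers([1.0, -2.0, 0.5], nwalkers=64, nconst=1e-4, seed=0)
print(p0.shape)  # (64, 3)
```

A nonzero spread matters because emcee 3 checks by default that the initial walkers are linearly independent, and its default stretch move requires at least twice as many walkers as parameters.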
- set_discarded_iters(new_discard_iters)
Set the number of burn-in iterations to discard.
- Parameters:
new_discard_iters (int) – Number of iterations to discard (must be >= 0).
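Discarding burn-in usually means dropping the first steps of every walker and then pooling the remaining samples. The sketch assumes the stored chain uses emcee's `get_chain()` layout `(nsteps, nwalkers, ndim)`; `flatten_after_burn_in` is a hypothetical name. emcee itself exposes the same operation as `sampler.get_chain(discard=..., flat=True)`.

```python
import numpy as np


def flatten_after_burn_in(chain, discard):
    """Drop the first `discard` steps and merge all walkers into one sample array.

    Assumes `chain` has shape (nsteps, nwalkers, ndim).
    """
    if discard < 0:
        raise ValueError("discard must be >= 0")
    kept = chain[discard:]                    # (nsteps - discard, nwalkers, ndim)
    return kept.reshape(-1, chain.shape[-1])  # (n_samples, ndim)


chain = np.arange(100 * 8 * 3, dtype=float).reshape(100, 8, 3)
flat = flatten_after_burn_in(chain, discard=20)
print(flat.shape)  # (640, 3)
```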
- get_theta_median(scaled=False)
Return median of posterior samples.
- Parameters:
scaled (bool, optional) – If True, use scaled parameter samples.
- Returns:
Median parameter values.
- Return type:
np.ndarray
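Assuming the method takes the marginal median of each parameter over the flattened post-burn-in samples, the computation reduces to a single NumPy call. The samples below are synthetic stand-ins for a real posterior.

```python
import numpy as np

rng = np.random.default_rng(1)
# Synthetic flattened posterior samples, shape (n_samples, ndim).
flat = rng.normal(loc=[1.0, -2.0], scale=[0.1, 0.5], size=(20000, 2))

# Median of each parameter independently; result has shape (ndim,).
theta_median = np.median(flat, axis=0)
print(theta_median.shape)  # (2,)
```

Note that the vector of marginal medians is not guaranteed to be a high-probability point of the joint posterior when parameters are strongly correlated.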
- get_theta_percs(scaled=False)
Return 16th, 50th, and 84th percentiles of posterior samples.
- Parameters:
scaled (bool, optional) – If True, use scaled parameter samples.
- Returns:
Array of shape (3, ndim) with 16th, 50th, 84th percentiles.
- Return type:
np.ndarray
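The `(3, ndim)` array is typically turned into a central value with asymmetric error bars. This sketch assumes percentiles are taken over flattened samples, as `np.percentile(..., axis=0)` does; the data are synthetic.

```python
import numpy as np

rng = np.random.default_rng(2)
flat = rng.normal(loc=[1.0, -2.0], scale=[0.1, 0.5], size=(20000, 2))

# Rows are the 16th, 50th and 84th percentiles; shape (3, ndim).
percs = np.percentile(flat, [16, 50, 84], axis=0)
lower, median, upper = percs
err_minus = median - lower  # distance down to the 16th percentile
err_plus = upper - median   # distance up to the 84th percentile

for m, lo, hi in zip(median, err_minus, err_plus):
    print(f"{m:.3f} -{lo:.3f} +{hi:.3f}")
```

For a Gaussian marginal, the 16th and 84th percentiles lie about one standard deviation from the median, so `err_minus` and `err_plus` act like asymmetric 1-sigma errors.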
- get_theta_max(scaled=False)
Return parameter values corresponding to maximum log-probability.
- Parameters:
scaled (bool, optional) – If True, return scaled values.
- Returns:
Parameter values with maximum log probability.
- Return type:
np.ndarray
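Finding the maximum-log-probability sample means locating the `(step, walker)` index of the largest stored log-probability and reading the matching position from the chain. The sketch assumes emcee's layouts, `(nsteps, nwalkers, ndim)` for positions and `(nsteps, nwalkers)` for `sampler.get_log_prob()`; `theta_at_max_log_prob` is a hypothetical name.

```python
import numpy as np


def theta_at_max_log_prob(chain, log_prob):
    """Return the visited sample with the highest log-probability."""
    step, walker = np.unravel_index(np.argmax(log_prob), log_prob.shape)
    return chain[step, walker]


rng = np.random.default_rng(3)
chain = rng.normal(size=(50, 10, 2))
# Standard-normal log-density up to a constant, one value per (step, walker).
log_prob = -0.5 * np.sum(chain**2, axis=-1)
best = theta_at_max_log_prob(chain, log_prob)
```

This returns the best sample the walkers actually visited, which approximates but is not in general equal to the exact posterior mode.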
- show_corner_plot(labels, truths=None, show_titles=True, plot_datapoints=True, quantiles=[0.16, 0.5, 0.84], quiet=False, scaled=False)
Generate corner plot using posterior samples.
- Parameters:
labels (list of str) – List of parameter names for plot axes.
truths (array-like, optional) – True values to show as vertical lines.
show_titles (bool, optional) – If True, show titles with stats above plots.
plot_datapoints (bool, optional) – If True, show individual data points.
quantiles (list of float, optional) – Quantiles to show on plot.
quiet (bool, optional) – Suppress text output.
scaled (bool, optional) – If True, use scaled parameter samples.
- Returns:
The corner plot figure.
- Return type:
matplotlib.figure.Figure
- plot_chains(labels, cols_per_row=3, scaled=False)
Plot individual walker chains for each parameter.
- Parameters:
labels (list of str) – Names of parameters.
cols_per_row (int, optional) – Number of subplot columns per row.
scaled (bool, optional) – If True, use scaled parameter samples.
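A trace plot draws every walker's path for each parameter, arranged on a grid with `cols_per_row` columns. The matplotlib sketch below is an assumption about the layout, not the package's implementation; it assumes a `(nsteps, nwalkers, ndim)` chain and returns the figure for convenience.

```python
import math

import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def plot_chains(chain, labels, cols_per_row=3):
    """Trace plot: one panel per parameter, each walker drawn as a thin line."""
    ndim = chain.shape[-1]
    nrows = math.ceil(ndim / cols_per_row)
    fig, axes = plt.subplots(
        nrows, cols_per_row, figsize=(4 * cols_per_row, 2.5 * nrows), squeeze=False
    )
    axes = axes.ravel()
    for i in range(ndim):
        # A 2-D array plots one line per column, i.e. one line per walker.
        axes[i].plot(chain[:, :, i], color="k", alpha=0.1, lw=0.5)
        axes[i].set_title(labels[i])
        axes[i].set_xlabel("step")
    for ax in axes[ndim:]:  # hide unused cells in the last row
        ax.set_visible(False)
    fig.tight_layout()
    return fig


rng = np.random.default_rng(4)
chain = np.cumsum(rng.normal(size=(200, 16, 4)), axis=0)
fig = plot_chains(chain, ["a", "b", "c", "d"], cols_per_row=3)
```

Well-mixed walkers look like overlapping, stationary noise bands; long drifts or isolated stuck walkers suggest that more burn-in or a better initialization is needed.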
- auto_corr(chain_length=50)
Estimate autocorrelation length using Sokal’s method.
- Parameters:
chain_length (int) – Number of subchain lengths to evaluate.
- Returns:
Autocorrelation estimates for different subchain lengths.
- Return type:
np.ndarray
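The sketch below estimates the integrated autocorrelation time with Sokal's automated windowing (choose the smallest lag M with M ≥ c·τ(M), c = 5), averaging the autocorrelation function over walkers, and repeats the estimate on sub-chains of increasing length. It follows the approach in emcee's autocorrelation tutorial; emcee also ships `emcee.autocorr.integrated_time` for the single-length case. Function names, the `(nwalkers, nsteps)` layout, and treating `chain_length` as the number of sub-chain lengths (`n_lengths` here) are assumptions.

```python
import numpy as np


def autocorr_func_1d(x):
    """Normalised autocorrelation function of a 1-D series, computed via FFT."""
    x = np.asarray(x, dtype=float)
    n = len(x)
    nfft = 1 << (2 * n - 1).bit_length()  # zero-pad to avoid circular wrap-around
    f = np.fft.rfft(x - x.mean(), n=nfft)
    acf = np.fft.irfft(f * np.conjugate(f), n=nfft)[:n]
    return acf / acf[0]


def auto_window(taus, c):
    """Sokal's window: first lag M with M >= c * tau(M)."""
    m = np.arange(len(taus)) < c * taus
    if np.all(m):  # window never closes: the chain is too short to trust
        return len(taus) - 1
    return int(np.argmin(m))


def integrated_autocorr_time(y, c=5.0):
    """tau for one parameter; y has shape (nwalkers, nsteps)."""
    f = np.mean([autocorr_func_1d(walker) for walker in y], axis=0)
    taus = 2.0 * np.cumsum(f) - 1.0
    return taus[auto_window(taus, c)]


def autocorr_vs_length(y, n_lengths=50, min_length=100):
    """Hypothetical helper: tau estimated on sub-chains of increasing length."""
    lengths = np.unique(np.geomspace(min_length, y.shape[1], n_lengths).astype(int))
    taus = np.array([integrated_autocorr_time(y[:, :n]) for n in lengths])
    return lengths, taus


# AR(1) test series with phi = 0.9, whose true tau is (1 + phi) / (1 - phi) = 19.
rng = np.random.default_rng(5)
phi, nwalkers, nsteps = 0.9, 16, 5000
y = np.empty((nwalkers, nsteps))
y[:, 0] = rng.standard_normal(nwalkers)
for t in range(1, nsteps):
    y[:, t] = phi * y[:, t - 1] + rng.standard_normal(nwalkers)
lengths, taus = autocorr_vs_length(y, n_lengths=10)
```

If the estimates keep growing with sub-chain length instead of levelling off, the chain is probably still too short; the emcee tutorial suggests running for at least about 50 τ.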