Skip to content

Instantly share code, notes, and snippets.

@bmorris3
Last active September 2, 2025 22:51
Show Gist options
  • Select an option

  • Save bmorris3/d63c041aaca8504f24c79a7aa5346e2b to your computer and use it in GitHub Desktop.

Select an option

Save bmorris3/d63c041aaca8504f24c79a7aa5346e2b to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "50911bcb-a862-42c2-aa61-c0fabac6cbe9",
"metadata": {},
"source": [
"Run `numpyro.infer.MCMC` in batches, pausing at checkpoints before all samples have been drawn. At each checkpoint, run a custom visualization or write-out function on the samples collected so far."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1eb6882e-3717-4086-9c47-e6e690237b1d",
"metadata": {},
"outputs": [],
"source": [
"# Set the number of cores on your machine for parallel computing:\n",
"import numpyro\n",
"cpu_cores = 3\n",
"numpyro.set_host_device_count(cpu_cores)\n",
"\n",
"from jax import numpy as jnp\n",
"import shone\n",
"from shone.spectrum import bin_spectrum\n",
"\n",
"from numpyro.infer import MCMC, NUTS\n",
"from numpyro import distributions as dist\n",
"\n",
"# The just-in-time decorator:\n",
"from jax import jit\n",
"\n",
"# random numbers in jax are handled by these:\n",
"from jax import random\n",
"\n",
"# these packages will aid in visualization:\n",
"import arviz\n",
"from corner import corner"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aa0a5ba6-01a6-4a25-ac6f-e3a6cdb1a214",
"metadata": {},
"outputs": [],
"source": [
"a_truth = 0.05\n",
"b_truth = 0.40\n",
"x = jnp.linspace(-5, 5, 100)\n",
"y_err = 0.01\n",
"y_obs = a_truth * x ** 2 + b_truth * x + y_err * random.normal(random.PRNGKey(0), shape=x.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f257f125-d354-41ca-9bd2-a1c530882823",
"metadata": {},
"outputs": [],
"source": [
"def model():\n",
" a = numpyro.sample('a', dist.Uniform(0, 0.1))\n",
" b = numpyro.sample('b', dist.Uniform(0, 0.6))\n",
" y_model = a * x ** 2 + b * x\n",
" return numpyro.sample(\n",
" 'obs', dist.Normal(y_model, y_err), \n",
" obs=y_obs\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c8a9a8b9-3ded-4473-949f-9eb8b6e1cbee",
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"from collections import defaultdict\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from tqdm.auto import tqdm \n",
"from jax import device_get, jit, lax, local_device_count, pmap, random, vmap\n",
"\n",
"from functools import partial\n",
"from numpyro.infer.util import initialize_model\n",
"from numpyro.infer import MCMC, log_likelihood\n",
"from numpyro.util import fori_collect\n",
"from numpyro.infer.hmc import hmc\n",
"from datetime import datetime\n",
"\n",
"rng_seed = 0\n",
"\n",
"\n",
"def hstack_recursive(final_states, checkpoint_states):\n",
" for key in final_states.keys():\n",
" if isinstance(final_states[key], dict):\n",
" hstack_recursive(final_states[key], checkpoint_states[key])\n",
" else:\n",
" final_states[key] = jnp.hstack([\n",
" final_states[key], \n",
" checkpoint_states[key]\n",
" ])\n",
"\n",
"\n",
"def print_big_message(big_message):\n",
" print('\\n\\n')\n",
" print('=' * len(big_message))\n",
" print(big_message)\n",
" print('=' * len(big_message))\n",
" print('\\n\\n')\n",
"\n",
"\n",
"class MCMCWithCheckpoints(MCMC):\n",
" running_states = None\n",
" checkpoint = 0\n",
" start_time = None\n",
" \n",
"\n",
" def run_checkpoints(self, rng_key, *args, extra_fields=(), n_checkpoints=10, \n",
" progress_bar_warmup=True, progress_bar_samples=True, \n",
" init_params=None, on_checkpoint=None, **kwargs):\n",
" \"\"\"\n",
" Run warmup, then draw the samples in batches separated by checkpoints.\n",
"\n",
" :param random.PRNGKey rng_key: Random number generator key to be used for the sampling.\n",
" For multi-chains, a batch of `num_chains` keys can be supplied. If `rng_key`\n",
" does not have batch_size, it will be split into a batch of `num_chains` keys.\n",
" :param args: Arguments to be provided to the :meth:`numpyro.infer.mcmc.MCMCKernel.init` method.\n",
" These are typically the arguments needed by the `model`.\n",
" :param extra_fields: Extra fields (aside from `\"z\"`, `\"diverging\"`) from the\n",
" state object (e.g. :data:`numpyro.infer.hmc.HMCState` for HMC) to be collected\n",
" during the MCMC run. Note that subfields can be accessed using dots, e.g.\n",
" `\"adapt_state.step_size\"` can be used to collect step sizes at each step. Exclude sample sites from\n",
" collection with \"~`sampler.sample_field`.`sample_site`\". e.g. \"~z.a\" will prevent site \"a\" from\n",
" being collected if you're using the NUTS sampler. To collect samples of a site \"a\" in the\n",
" unconstrained space, we can specify the variable here, e.g. `extra_fields=(\"z.a\",)`.\n",
" :type extra_fields: tuple or list of str\n",
" :param init_params: Initial parameters to begin sampling. The type must be consistent\n",
" with the input type to `potential_fn` provided to the kernel. If the kernel is\n",
" instantiated by a numpyro model, the initial parameters here correspond to latent\n",
" values in unconstrained space.\n",
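" :param kwargs: Keyword arguments to be provided to the :meth:`numpyro.infer.mcmc.MCMCKernel.init`\n",
" method. These are typically the keyword arguments needed by the `model`.\n",
" :param int n_checkpoints: Number of batches that the `num_samples` draws are split into.\n",
" :param bool progress_bar_warmup: Value assigned to `self.progress_bar` for the warmup phase.\n",
" :param bool progress_bar_samples: Value assigned to `self.progress_bar` for each sampling batch.\n",
" :param on_checkpoint: Optional callable invoked as `on_checkpoint(self, **kwargs)` at each checkpoint,\n",
" e.g. to plot or save the samples accumulated so far.\n",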
"\n",
" .. note:: jax allows python code to continue even when the compiled code has not finished yet.\n",
" This can cause troubles when trying to profile the code for speed.\n",
" See https://jax.readthedocs.io/en/latest/async_dispatch.html and\n",
" https://jax.readthedocs.io/en/latest/profiling.html for pointers on profiling jax programs.\n",
" \"\"\"\n",
" self.start_time = datetime.now().strftime(\"%Y-%m-%d_%H-%M\")\n",
" num_warmup_total = int(self.num_warmup)\n",
" num_samples_total = int(self.num_samples)\n",
" \n",
" check_point_indices = [\n",
" jnp.arange(num_warmup_total), \n",
" *jnp.array_split(jnp.arange(num_samples_total), n_checkpoints)\n",
" ]\n",
" n_checkpoints = len(check_point_indices)\n",
" rng_keys = random.split(rng_key, n_checkpoints)\n",
" pbar = tqdm(enumerate(zip(rng_keys, check_point_indices)), total=n_checkpoints)\n",
" for checkpoint, (rng_key, bounds) in pbar:\n",
" if checkpoint == 0:\n",
" self.progress_bar = progress_bar_warmup\n",
" pbar.set_description('Run warmup')\n",
" print_big_message(\"Begin warmup\")\n",
" self.warmup(rng_key, *args, extra_fields=extra_fields, init_params=init_params, **kwargs)\n",
" print_big_message(f\"Begin {num_samples_total} samples with {n_checkpoints} checkpoints\")\n",
"\n",
" else:\n",
" pbar.set_description(f'Run samples {bounds.min()} to {bounds.max()}')\n",
"\n",
" self.progress_bar = progress_bar_samples\n",
" self.num_samples = bounds.size\n",
" self.run(rng_key, *args, extra_fields=extra_fields, init_params=init_params, **kwargs)\n",
"\n",
" # add to running states:\n",
" if self.running_states is None:\n",
" self.running_states = dict(self._states)\n",
" else:\n",
" hstack_recursive(self.running_states, self._states)\n",
" \n",
" # ensure that calls to `self.get_samples` will build a new samples array\n",
" # out of the running states:\n",
" self._states_flat = None\n",
" self._states = self.running_states\n",
"\n",
" if on_checkpoint is not None:\n",
" on_checkpoint(self, **kwargs)\n",
" self.checkpoint += 1\n",
" \n",
" pbar.close()\n",
"\n",
" # reset to total number for arviz IO\n",
" self.num_samples = num_samples_total"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "63d9bef2-8218-4427-ae0c-18ab8b49d891",
"metadata": {},
"outputs": [],
"source": [
"def post_batch_viz_save(self, **kwargs):\n",
" \"\"\"\n",
" here we define some tasks to do after each completed checkpoint:\n",
" \"\"\"\n",
" print(f'Corner for checkpoint {self.checkpoint}')\n",
" samples_cumulative = self.get_samples()\n",
" corner(samples_cumulative)\n",
"\n",
" plt.suptitle(f'checkpoint {self.checkpoint}')\n",
" plt.show()\n",
"\n",
" with open(f'samples_cumulative_{self.start_time}_checkpoint_{self.checkpoint:04d}.pkl', 'wb') as file:\n",
" pickle.dump(dict(samples_cumulative), file)"
]
},
{
"cell_type": "markdown",
"id": "a6bc2ebf-7f44-4362-8a9f-a504ae239754",
"metadata": {},
"source": [
"To load a checkpoint:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "68fd41ef-7e11-4a49-8a74-d288064dee8c",
"metadata": {},
"outputs": [],
"source": [
"# with open(\"samples_cumulative_2025-09-02_16-12_checkpoint_0001.pkl\", 'rb') as f:\n",
"# last_checkpoint = pickle.load(f)\n",
"# print(last_checkpoint['a'].shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "48cf6ef8-5d1e-42dd-8d08-a803eedf7360",
"metadata": {},
"outputs": [],
"source": [
"from numpyro.infer import NUTS\n",
"\n",
"rng_key = random.PRNGKey(10)\n",
"mcmc = MCMCWithCheckpoints(\n",
" NUTS(model), \n",
" num_warmup=100, \n",
" num_samples=1000,\n",
" num_chains=cpu_cores\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9ca31ff7-caaf-4563-a940-2083db6849c7",
"metadata": {},
"outputs": [],
"source": [
"mcmc.run_checkpoints(rng_key, n_checkpoints=3, on_checkpoint=post_batch_viz_save)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8e41cf52-c86f-4668-b587-245017bb995a",
"metadata": {},
"outputs": [],
"source": [
"mcmc.print_summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18897a70-fb7c-42a3-a20b-71becb8033e0",
"metadata": {},
"outputs": [],
"source": [
"result = arviz.from_numpyro(mcmc)\n",
"result"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d21584aa-f85c-47dc-b608-8799257b772a",
"metadata": {},
"outputs": [],
"source": [
"arviz.summary(result)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e9b38cb7-c31c-4acb-bfd9-0ea4f9fd18a3",
"metadata": {},
"outputs": [],
"source": [
"corner(result, truths=[a_truth, b_truth]);"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e9afd7e8-a9a1-4414-b02e-4dede20974c6",
"metadata": {},
"outputs": [],
"source": [
"arviz.summary(result)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b4acaddf-8c5f-4db8-85d3-a19a602396ac",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment