Last active
September 2, 2025 22:51
-
-
Save bmorris3/d63c041aaca8504f24c79a7aa5346e2b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "id": "50911bcb-a862-42c2-aa61-c0fabac6cbe9", | |
| "metadata": {}, | |
| "source": [ | |
| "Run `numpyro.infer.MCMC` in batches, pausing at checkpoints before all samples have been drawn. At each checkpoint, call a user-supplied function, e.g. to visualize or write out the samples collected so far." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "1eb6882e-3717-4086-9c47-e6e690237b1d", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Set the number of cores on your machine for parallel computing:\n", | |
| "import numpyro\n", | |
| "cpu_cores = 3\n", | |
| "numpyro.set_host_device_count(cpu_cores)\n", | |
| "\n", | |
| "from jax import numpy as jnp\n", | |
| "import shone\n", | |
| "from shone.spectrum import bin_spectrum\n", | |
| "\n", | |
| "from numpyro.infer import MCMC, NUTS\n", | |
| "from numpyro import distributions as dist\n", | |
| "\n", | |
| "# The just-in-time decorator:\n", | |
| "from jax import jit\n", | |
| "\n", | |
| "# random numbers in jax are handled by these:\n", | |
| "from jax import random\n", | |
| "\n", | |
| "# these packages will aid in visualization:\n", | |
| "import arviz\n", | |
| "from corner import corner" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "aa0a5ba6-01a6-4a25-ac6f-e3a6cdb1a214", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "a_truth = 0.05\n", | |
| "b_truth = 0.40\n", | |
| "x = jnp.linspace(-5, 5, 100)\n", | |
| "y_err = 0.01\n", | |
| "y_obs = a_truth * x ** 2 + b_truth * x + y_err * random.normal(random.PRNGKey(0), shape=x.shape)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "f257f125-d354-41ca-9bd2-a1c530882823", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def model():\n", | |
| " a = numpyro.sample('a', dist.Uniform(0, 0.1))\n", | |
| " b = numpyro.sample('b', dist.Uniform(0, 0.6))\n", | |
| " y_model = a * x ** 2 + b * x\n", | |
| " return numpyro.sample(\n", | |
| " 'obs', dist.Normal(y_model, y_err), \n", | |
| " obs=y_obs\n", | |
| " )" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "c8a9a8b9-3ded-4473-949f-9eb8b6e1cbee", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import pickle\n", | |
| "from collections import defaultdict\n", | |
| "import matplotlib.pyplot as plt\n", | |
| "\n", | |
| "from tqdm.auto import tqdm \n", | |
| "from jax import device_get, jit, lax, local_device_count, pmap, random, vmap\n", | |
| "\n", | |
| "from functools import partial\n", | |
| "from numpyro.infer.util import initialize_model\n", | |
| "from numpyro.infer import MCMC, log_likelihood\n", | |
| "from numpyro.util import fori_collect\n", | |
| "from numpyro.infer.hmc import hmc\n", | |
| "from datetime import datetime\n", | |
| "\n", | |
| "rng_seed = 0\n", | |
| "\n", | |
| "\n", | |
| "def hstack_recursive(final_states, checkpoint_states):\n", | |
| " for key in final_states.keys():\n", | |
| " if isinstance(final_states[key], dict):\n", | |
| " hstack_recursive(final_states[key], checkpoint_states[key])\n", | |
| " else:\n", | |
| " final_states[key] = jnp.hstack([\n", | |
| " final_states[key], \n", | |
| " checkpoint_states[key]\n", | |
| " ])\n", | |
| "\n", | |
| "\n", | |
| "def print_big_message(big_message):\n", | |
| " print('\\n\\n')\n", | |
| " print('=' * len(big_message))\n", | |
| " print(big_message)\n", | |
| " print('=' * len(big_message))\n", | |
| " print('\\n\\n')\n", | |
| "\n", | |
| "\n", | |
| "class MCMCWithCheckpoints(MCMC):\n", | |
| " running_states = None\n", | |
| " checkpoint = 0\n", | |
| " start_time = None\n", | |
| " \n", | |
| "\n", | |
| " def run_checkpoints(self, rng_key, *args, extra_fields=(), n_checkpoints=10, \n", | |
| " progress_bar_warmup=True, progress_bar_samples=True, \n", | |
| " init_params=None, on_checkpoint=None, **kwargs):\n", | |
| " \"\"\"\n", | |
| " Run warmup, then draw the samples in `n_checkpoints` batches, calling `on_checkpoint` after warmup and after each batch.\n", | |
| "\n", | |
| " :param random.PRNGKey rng_key: Random number generator key to be used for the sampling.\n", | |
| " For multi-chains, a batch of `num_chains` keys can be supplied. If `rng_key`\n", | |
| " does not have batch_size, it will be split into a batch of `num_chains` keys.\n", | |
| " :param args: Arguments to be provided to the :meth:`numpyro.infer.mcmc.MCMCKernel.init` method.\n", | |
| " These are typically the arguments needed by the `model`.\n", | |
| " :param extra_fields: Extra fields (aside from `\"z\"`, `\"diverging\"`) from the\n", | |
| " state object (e.g. :data:`numpyro.infer.hmc.HMCState` for HMC) to be collected\n", | |
| " during the MCMC run. Note that subfields can be accessed using dots, e.g.\n", | |
| " `\"adapt_state.step_size\"` can be used to collect step sizes at each step. Exclude sample sites from\n", | |
| " collection with \"~`sampler.sample_field`.`sample_site`\". e.g. \"~z.a\" will prevent site \"a\" from\n", | |
| " being collected if you're using the NUTS sampler. To collect samples of a site \"a\" in the\n", | |
| " unconstrained space, we can specify the variable here, e.g. `extra_fields=(\"z.a\",)`.\n", | |
| " :type extra_fields: tuple or list of str\n", | |
| " :param init_params: Initial parameters to begin sampling. The type must be consistent\n", | |
| " with the input type to `potential_fn` provided to the kernel. If the kernel is\n", | |
| " instantiated by a numpyro model, the initial parameters here correspond to latent\n", | |
| " values in unconstrained space.\n", | |
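| " :param int n_checkpoints: Number of batches that the `num_samples` draws are split into.\n", | |
| " :param bool progress_bar_warmup: Whether to show the sampler's progress bar during warmup.\n", | |
| " :param bool progress_bar_samples: Whether to show the sampler's progress bar while drawing samples.\n", | |
| " :param on_checkpoint: Optional callable, invoked as `on_checkpoint(self, **kwargs)` after warmup\n", | |
| " and after each batch of samples.\n", | |
| " :param kwargs: Keyword arguments to be provided to the :meth:`numpyro.infer.mcmc.MCMCKernel.init`\n", | |
| " method. These are typically the keyword arguments needed by the `model`.\n", | |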
| "\n", | |
| " .. note:: jax allows python code to continue even when the compiled code has not finished yet.\n", | |
| " This can cause troubles when trying to profile the code for speed.\n", | |
| " See https://jax.readthedocs.io/en/latest/async_dispatch.html and\n", | |
| " https://jax.readthedocs.io/en/latest/profiling.html for pointers on profiling jax programs.\n", | |
| " \"\"\"\n", | |
| " self.start_time = datetime.now().strftime(\"%Y-%m-%d_%H-%M\")\n", | |
| " num_warmup_total = int(self.num_warmup)\n", | |
| " num_samples_total = int(self.num_samples)\n", | |
| " \n", | |
| " check_point_indices = [\n", | |
| " jnp.arange(num_warmup_total), \n", | |
| " *jnp.array_split(jnp.arange(num_samples_total), n_checkpoints)\n", | |
| " ]\n", | |
| " n_checkpoints = len(check_point_indices)\n", | |
| " rng_keys = random.split(rng_key, n_checkpoints)\n", | |
| " pbar = tqdm(enumerate(zip(rng_keys, check_point_indices)), total=n_checkpoints)\n", | |
| " for checkpoint, (rng_key, bounds) in pbar:\n", | |
| " if checkpoint == 0:\n", | |
| " self.progress_bar = progress_bar_warmup\n", | |
| " pbar.set_description('Run warmup')\n", | |
| " print_big_message(\"Begin warmup\")\n", | |
| " self.warmup(rng_key, *args, extra_fields=extra_fields, init_params=init_params, **kwargs)\n", | |
| " print_big_message(f\"Begin {num_samples_total} samples with {n_checkpoints} checkpoints\")\n", | |
| "\n", | |
| " else:\n", | |
| " pbar.set_description(f'Run samples {bounds.min()} to {bounds.max()}')\n", | |
| "\n", | |
| " self.progress_bar = progress_bar_samples\n", | |
| " self.num_samples = bounds.size\n", | |
| " self.run(rng_key, *args, extra_fields=extra_fields, init_params=init_params, **kwargs)\n", | |
| "\n", | |
| " # add to running states:\n", | |
| " if self.running_states is None:\n", | |
| " self.running_states = dict(self._states)\n", | |
| " else:\n", | |
| " hstack_recursive(self.running_states, self._states)\n", | |
| " \n", | |
| " # ensure that calls to `self.get_samples` will build a new samples array\n", | |
| " # out of the running states:\n", | |
| " self._states_flat = None\n", | |
| " self._states = self.running_states\n", | |
| "\n", | |
| " if on_checkpoint is not None:\n", | |
| " on_checkpoint(self, **kwargs)\n", | |
| " self.checkpoint += 1\n", | |
| " \n", | |
| " pbar.close()\n", | |
| "\n", | |
| " # reset to total number for arviz IO\n", | |
| " self.num_samples = num_samples_total" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "63d9bef2-8218-4427-ae0c-18ab8b49d891", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def post_batch_viz_save(self, **kwargs):\n", | |
| " \"\"\"\n", | |
| " Tasks to run after each completed checkpoint: show a corner plot of the cumulative samples and pickle them to disk.\n", | |
| " \"\"\"\n", | |
| " print(f'Corner for checkpoint {self.checkpoint}')\n", | |
| " samples_cumulative = self.get_samples()\n", | |
| " corner(samples_cumulative)\n", | |
| "\n", | |
| " plt.suptitle(f'checkpoint {self.checkpoint}')\n", | |
| " plt.show()\n", | |
| "\n", | |
| " with open(f'samples_cumulative_{self.start_time}_checkpoint_{self.checkpoint:04d}.pkl', 'wb') as file:\n", | |
| " pickle.dump(dict(samples_cumulative), file)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "a6bc2ebf-7f44-4362-8a9f-a504ae239754", | |
| "metadata": {}, | |
| "source": [ | |
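| "To load the samples saved at a checkpoint, uncomment the cell below and set the filename to one of the pickle files written by `post_batch_viz_save`:" | |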
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "68fd41ef-7e11-4a49-8a74-d288064dee8c", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# with open(\"samples_cumulative_2025-09-02_16-12_checkpoint_0001.pkl\", 'rb') as f:\n", | |
| "# last_checkpoint = pickle.load(f)\n", | |
| "# print(last_checkpoint['a'].shape)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "48cf6ef8-5d1e-42dd-8d08-a803eedf7360", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from numpyro.infer import NUTS\n", | |
| "\n", | |
| "rng_key = random.PRNGKey(10)\n", | |
| "mcmc = MCMCWithCheckpoints(\n", | |
| " NUTS(model), \n", | |
| " num_warmup=100, \n", | |
| " num_samples=1000,\n", | |
| " num_chains=cpu_cores\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "9ca31ff7-caaf-4563-a940-2083db6849c7", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "mcmc.run_checkpoints(rng_key, n_checkpoints=3, on_checkpoint=post_batch_viz_save)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "8e41cf52-c86f-4668-b587-245017bb995a", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "mcmc.print_summary()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "18897a70-fb7c-42a3-a20b-71becb8033e0", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "result = arviz.from_numpyro(mcmc)\n", | |
| "result" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "d21584aa-f85c-47dc-b608-8799257b772a", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "arviz.summary(result)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "e9b38cb7-c31c-4acb-bfd9-0ea4f9fd18a3", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "corner(result, truths=[a_truth, b_truth]);" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "e9afd7e8-a9a1-4414-b02e-4dede20974c6", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "arviz.summary(result)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "b4acaddf-8c5f-4db8-85d3-a19a602396ac", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3 (ipykernel)", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.11.10" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment