Last active
April 2, 2024 11:25
-
-
Save zaxtax/5fd7c881c6ac83a7ca2798d0a7e230b7 to your computer and use it in GitHub Desktop.
Example notebook showing multiple chain progress bars for blackjax
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "id": "83270ba5-de1c-455e-be88-ab88a44a23f6", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import os\n", | |
| "import multiprocessing\n", | |
| "\n", | |
# Expose one virtual XLA CPU device per host core so that `jax.pmap` has several
# devices to spread chains over. This is set before `jax` is imported in the next
# cell so the flag is in place when the CPU backend is created.
host_device_count = multiprocessing.cpu_count()
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    host_device_count
)
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "id": "327348c3-9d79-4e01-a09c-5f9e4227a6c8", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import tqdm\n", | |
| "from tqdm.auto import tqdm as tqdm_auto\n", | |
| "\n", | |
| "import jax\n", | |
| "import jax.numpy as jnp\n", | |
| "import jax.scipy.stats as stats\n", | |
| "import numpy as np\n", | |
| "\n", | |
| "lax = jax.lax\n", | |
| "\n", | |
def progress_bar_factory(num_samples, num_chains):
    """Build a decorator that adds per-chain tqdm progress bars to a `lax.scan` body.

    One tqdm bar per chain is created immediately when the factory is called
    (i.e. at trace time), labelled "Compiling.. " until the first step runs.
    The returned decorator wraps a scan body so that every step fires host
    callbacks (`jax.debug.callback`) that advance the bar of the running chain.

    Args:
        num_samples: Number of scan iterations per chain; the total of each bar.
        num_chains: Number of chains, i.e. of bars. When greater than 1 the
            scanned element must be a tuple whose first two entries are the
            iteration number and the chain id.

    Returns:
        `_progress_bar_scan`, a decorator for a `lax.scan` body `func(carry, x)`.
    """
    # Refresh roughly 20 times per run; short runs (<= 20 steps) refresh every step.
    print_rate = max(1, num_samples // 20)
    # Size of the trailing partial chunk when num_samples is not a multiple of
    # print_rate; it is added to the bar when the bar is closed.
    remainder = num_samples % print_rate

    tqdm_bars = {}
    for chain in range(num_chains):
        tqdm_bars[chain] = tqdm_auto(range(num_samples), position=chain)
        tqdm_bars[chain].set_description("Compiling.. ", refresh=True)

    def _update_tqdm(arg, chain):
        # Host callback: relabel the bar as running and advance it by `arg` steps.
        chain = int(chain)
        tqdm_bars[chain].set_description(f"Running chain {chain}", refresh=False)
        tqdm_bars[chain].update(arg)

    def _close_tqdm(arg, chain):
        # Host callback: advance the bar by `arg` steps, then close it.
        chain = int(chain)
        tqdm_bars[chain].update(arg)
        tqdm_bars[chain].close()

    def _update_progress_bar(iter_num, chain):
        """Advance the bar of `chain` from inside the traced loop at step `iter_num`.

        The bar advances by `print_rate` at every multiple of `print_rate` whose
        chunk fits entirely within `num_samples`; the leftover `remainder` is added
        when the bar is closed on the last iteration, so the bar ends exactly at
        `num_samples`.
        """
        # First step: swap the "Compiling.." label for "Running chain N" (advances by 0).
        _ = lax.cond(
            iter_num == 0,
            lambda _: jax.debug.callback(_update_tqdm, iter_num, chain),
            lambda _: None,
            operand=None,
        )
        # Only full chunks are counted here; without the `<= num_samples` guard the
        # start of the final partial chunk would add a whole extra `print_rate`, and
        # the bar would overshoot its total by `print_rate`.
        _ = lax.cond(
            ((iter_num % print_rate) == 0) & (iter_num + print_rate <= num_samples),
            lambda _: jax.debug.callback(_update_tqdm, print_rate, chain),
            lambda _: None,
            operand=None,
        )
        _ = lax.cond(
            iter_num == num_samples - 1,
            lambda _: jax.debug.callback(_close_tqdm, remainder, chain),
            lambda _: None,
            operand=None,
        )

    def _progress_bar_scan(func):
        """Decorator that adds a progress bar to `body_fun` used in `lax.scan`.

        `body_fun` must either loop over `np.arange(num_samples)`, or over a tuple
        whose first element is the iteration number (e.g. `np.arange(num_samples)`)
        and, when `num_chains > 1`, whose second element is the chain id
        (e.g. `chain * np.ones(num_samples)`). The scanned element `x` is passed to
        `body_fun` unchanged.
        """

        def wrapper_progress_bar(carry, x):
            if type(x) is tuple:
                if num_chains > 1:
                    iter_num, chain, *_ = x
                else:
                    iter_num, *_ = x
                    chain = 0
            else:
                iter_num = x
                chain = 0
            _update_progress_bar(iter_num, chain)
            return func(carry, x)

        return wrapper_progress_bar

    return _progress_bar_scan
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "id": "6ea28709-aa8b-4b27-bcec-2bb22b049de6", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "import numpy as np\n", | |
| "\n", | |
| "import jax\n", | |
| "import jax.numpy as jnp\n", | |
| "import jax.scipy.stats as stats\n", | |
| "from blackjax.progress_bar import progress_bar_scan\n", | |
| "\n", | |
# Ground-truth parameters of the synthetic Normal data the sampler should recover.
loc, scale = 10, 20
# Draw from a seeded generator (not the global NumPy RNG) so that
# "Restart & Run All" reproduces the same data, and hence the same posterior.
DATA_SEED = 0
data_rng = np.random.default_rng(DATA_SEED)
observed = data_rng.normal(loc, scale, size=1_000)
| "\n", | |
| "\n", | |
def logdensity_fn(loc, log_scale, observed=observed):
    """Log-likelihood of `observed` under a Normal(loc, exp(log_scale)) model.

    The standard deviation is parameterised through its logarithm, so it stays
    positive for any real `log_scale`. Returns the sum of the pointwise
    log-densities.
    """
    pointwise = stats.norm.logpdf(observed, loc, jnp.exp(log_scale))
    return pointwise.sum()
| "\n", | |
| "\n", | |
def logdensity(x):
    """Adapter from a position dict to `logdensity_fn`.

    The keys of `x` are forwarded as keyword arguments, so they must match the
    parameters of `logdensity_fn` (here "loc" and "log_scale").
    """
    position_kwargs = x
    return logdensity_fn(**position_kwargs)
| "\n", | |
| "\n", | |
def inference_loop(rng_key, kernel, initial_state, chain, num_samples, num_chains):
    """Run `num_samples` steps of `kernel` for a single chain, with a progress bar.

    Meant to be `jax.pmap`-ed over chains: `chain` is this chain's id, and it
    selects which of the `num_chains` progress bars is advanced. Returns the
    states of every step, stacked along a new leading axis by `lax.scan`.
    """
    with_progress_bar = progress_bar_factory(num_samples, num_chains)

    def _one_step(state, xs):
        # xs = (iteration index, chain id, per-step PRNG key). The first two
        # entries are consumed by the progress-bar wrapper; only the key is used here.
        _iter_num, _chain_id, step_key = xs
        new_state, _info = kernel(step_key, state)
        return new_state, new_state

    one_step = jax.jit(with_progress_bar(_one_step))

    step_keys = jax.random.split(rng_key, num_samples)
    scan_inputs = (np.arange(num_samples), chain * np.ones(num_samples), step_keys)
    _final_state, states = jax.lax.scan(one_step, initial_state, scan_inputs)
    return states
| "\n", | |
from datetime import date

# The PRNG seed is today's date as an integer (YYYYMMDD). Note that this makes
# the sampled chains differ from one day to the next.
today_as_int = int(date.today().strftime("%Y%m%d"))
rng_key = jax.random.key(today_as_int)
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "2dbbdc44-fb24-439d-9d40-b92ad2b7d36c", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import blackjax\n", | |
| "\n", | |
| "\n", | |
# Fixed (non-adapted) NUTS hyperparameters: a small step size and a diagonal
# inverse mass matrix with two entries, presumably one per position coordinate
# ("loc", "log_scale") in the order blackjax flattens the position dict (confirm).
step_size = 1e-3
inv_mass_matrix = np.array([0.5, 0.01])

nuts = blackjax.nuts(logdensity, step_size, inv_mass_matrix)
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "e92decf5-9fa7-41c8-a01b-936fd7e08f9b", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# Number of independent chains; each is mapped onto its own device by `jax.pmap`
# and gets its own progress bar.
num_chains: int = 4
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "id": "0d440c9a-58c7-4f62-8e3d-68481ae24f82", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# One starting point per chain: every chain starts at loc = 1, log_scale = 1.
initial_positions = {
    "loc": np.ones(num_chains),
    "log_scale": np.ones(num_chains),
}
# Vectorise the NUTS initialiser over the leading (chain) axis of each array.
initial_states = jax.vmap(nuts.init, in_axes=0)(initial_positions)
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "id": "0f624640-cbfe-4c88-a20c-a58f2d00e97c", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]" | |
| ] | |
| }, | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
list(jax.devices())
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "29f839ad-3a78-4997-802f-e811053809a8", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
# Map `inference_loop` over chains, one chain per device:
#   arg 0 rng_key       -> mapped (one key per chain)
#   arg 1 kernel        -> static, shared by all chains
#   arg 2 initial_state -> mapped (one state per chain)
#   arg 3 chain         -> mapped (chain id, selects the progress bar)
#   arg 4 num_samples   -> static
#   arg 5 num_chains    -> static
# `pmap` requires the mapped axis size to equal the number of devices it is given,
# so pass exactly `num_chains` devices rather than every device XLA exposes
# (`multiprocessing.cpu_count()` of them, which only matches 4 on a 4-core host).
available_devices = jax.devices()
if len(available_devices) < num_chains:
    raise RuntimeError(
        f"num_chains={num_chains} but only {len(available_devices)} devices are "
        "available; lower num_chains or increase "
        "--xla_force_host_platform_device_count."
    )
inference_loop_multiple_chains = jax.pmap(
    inference_loop,
    in_axes=(0, None, 0, 0, None, None),
    static_broadcasted_argnums=(1, 4, 5),
    devices=available_devices[:num_chains],
)
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "id": "7a075352-f3ce-4b20-9c00-4493ef898ddc", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "e071eb84a0434ecb853ec4c7377e460a", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/2000 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "870596fa992b40b29b2067761159250c", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/2000 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "994ad5db90e44e52af92dca5ebf2cd71", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/2000 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "7d9d785405f941edad938d3553b752a4", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| " 0%| | 0/2000 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 38 s, sys: 188 ms, total: 38.2 s\n", | |
| "Wall time: 16.2 s\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%time\n", | |
rng_key, sample_key = jax.random.split(rng_key)
sample_keys = jax.random.split(sample_key, num_chains)

pmap_states = inference_loop_multiple_chains(
    sample_keys,             # one PRNG key per chain (mapped)
    nuts.step,               # transition kernel (static)
    initial_states,          # per-chain initial states (mapped)
    jnp.arange(num_chains),  # chain ids, pick each chain's progress bar (mapped)
    2_000,                   # num_samples (static)
    num_chains,              # num_chains (static)
)
# Block on the result so the %%time measurement covers the sampling itself;
# assigning to `_` keeps the array from being displayed.
_ = pmap_states.position["loc"].block_until_ready()
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "id": "b92b34f0-b71d-418a-9854-e1294315884d", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(4, 2000)" | |
| ] | |
| }, | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
np.shape(pmap_states.position["loc"])
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "pymc-dev", | |
| "language": "python", | |
| "name": "pymc-dev" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.11.6" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment