Skip to content

Instantly share code, notes, and snippets.

@zaxtax
Last active April 2, 2024 11:25
Show Gist options
  • Select an option

  • Save zaxtax/5fd7c881c6ac83a7ca2798d0a7e230b7 to your computer and use it in GitHub Desktop.

Select an option

Save zaxtax/5fd7c881c6ac83a7ca2798d0a7e230b7 to your computer and use it in GitHub Desktop.
Example notebook showing multiple chain progress bars for blackjax
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "83270ba5-de1c-455e-be88-ab88a44a23f6",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import multiprocessing\n",
"\n",
"# XLA reads XLA_FLAGS when JAX creates its CPU backend, so this cell must run\n",
"# before JAX is used. One virtual CPU device per core lets `jax.pmap` place one\n",
"# chain per device. Flags that were already set are kept, not overwritten.\n",
"os.environ[\"XLA_FLAGS\"] = (\n",
"    os.environ.get(\"XLA_FLAGS\", \"\")\n",
"    + \" --xla_force_host_platform_device_count={}\".format(multiprocessing.cpu_count())\n",
").strip()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "327348c3-9d79-4e01-a09c-5f9e4227a6c8",
"metadata": {},
"outputs": [],
"source": [
"import tqdm\n",
"from tqdm.auto import tqdm as tqdm_auto\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import jax.scipy.stats as stats\n",
"import numpy as np\n",
"\n",
"from jax import lax\n",
"\n",
"def progress_bar_factory(num_samples, num_chains):\n",
"    \"\"\"Factory that builds a progress bar decorator for a `lax.scan` body.\n",
"\n",
"    One tqdm bar per chain is created up front and labelled \"Compiling..\" so\n",
"    there is feedback while the loop is being compiled. The returned decorator\n",
"    then drives the bars from inside the loop through `jax.debug.callback`.\n",
"\n",
"    Args:\n",
"        num_samples: number of iterations of the scanned loop (per chain).\n",
"        num_chains: number of chains; chain ids are expected to be\n",
"            0, 1, ..., num_chains - 1.\n",
"\n",
"    Returns:\n",
"        A decorator for a scan body `func(carry, x)`.\n",
"    \"\"\"\n",
"    # Refresh the bars about 20 times per run (every iteration for short runs).\n",
"    print_rate = max(num_samples // 20, 1)\n",
"    # The periodic update below fires after iterations print_rate - 1,\n",
"    # 2 * print_rate - 1, ... but never on the last iteration; the close callback\n",
"    # reports whatever is left, so the bars always end exactly at num_samples.\n",
"    # Keeping the last iteration to a single callback avoids depending on the\n",
"    # relative order of unordered `jax.debug.callback`s (tqdm ignores updates\n",
"    # made after `close()`).\n",
"    final_increment = num_samples - print_rate * ((num_samples - 1) // print_rate)\n",
"\n",
"    tqdm_bars = {}\n",
"\n",
"    def _new_bar(chain):\n",
"        return tqdm_auto(range(num_samples), position=chain)\n",
"\n",
"    for chain in range(num_chains):\n",
"        tqdm_bars[chain] = _new_bar(chain)\n",
"        tqdm_bars[chain].set_description(\"Compiling.. \", refresh=True)\n",
"\n",
"    def _get_bar(chain):\n",
"        # A compiled loop can be executed again without this factory running again\n",
"        # (e.g. a cached `jax.jit`/`jax.pmap` executable). `_close_tqdm` removes the\n",
"        # bars it closes, so later runs get fresh bars here instead of closed ones.\n",
"        if chain not in tqdm_bars:\n",
"            tqdm_bars[chain] = _new_bar(chain)\n",
"        return tqdm_bars[chain]\n",
"\n",
"    def _update_tqdm(arg, chain):\n",
"        chain = int(chain)\n",
"        bar = _get_bar(chain)\n",
"        bar.set_description(f\"Running chain {chain}\", refresh=False)\n",
"        bar.update(arg)\n",
"\n",
"    def _close_tqdm(arg, chain):\n",
"        chain = int(chain)\n",
"        bar = _get_bar(chain)\n",
"        bar.update(arg)\n",
"        bar.close()\n",
"        tqdm_bars.pop(chain, None)\n",
"\n",
"    def _update_progress_bar(iter_num, chain):\n",
"        \"\"\"Advance the bar of `chain` for iteration `iter_num` of the JAX loop.\n",
"\n",
"        Usage: call from the scan body with the current iteration number and chain id.\n",
"        \"\"\"\n",
"        if num_samples > 1:\n",
"            # Relabel the bar once the loop runs. Skipped for single-iteration loops,\n",
"            # whose only iteration is also the last one, handled by `_close_tqdm`.\n",
"            lax.cond(\n",
"                iter_num == 0,\n",
"                lambda _: jax.debug.callback(_update_tqdm, 0, chain),\n",
"                lambda _: None,\n",
"                operand=None,\n",
"            )\n",
"        lax.cond(\n",
"            ((iter_num + 1) % print_rate == 0) & (iter_num != num_samples - 1),\n",
"            lambda _: jax.debug.callback(_update_tqdm, print_rate, chain),\n",
"            lambda _: None,\n",
"            operand=None,\n",
"        )\n",
"        lax.cond(\n",
"            iter_num == num_samples - 1,\n",
"            lambda _: jax.debug.callback(_close_tqdm, final_increment, chain),\n",
"            lambda _: None,\n",
"            operand=None,\n",
"        )\n",
"\n",
"    def _progress_bar_scan(func):\n",
"        \"\"\"Decorator that adds a progress bar to `body_fun` used in `lax.scan`.\n",
"\n",
"        `body_fun` must loop over either `np.arange(num_samples)` (chain 0), or a\n",
"        tuple whose first element is the iteration number and, when\n",
"        `num_chains > 1`, whose second element is the chain id, e.g.\n",
"        `(np.arange(num_samples), chain * np.ones(num_samples), keys)`.\n",
"        `x` is passed on to `body_fun` unchanged.\n",
"        \"\"\"\n",
"\n",
"        def wrapper_progress_bar(carry, x):\n",
"            if type(x) is tuple:\n",
"                if num_chains > 1:\n",
"                    iter_num, chain, *_ = x\n",
"                else:\n",
"                    iter_num, *_ = x\n",
"                    chain = 0\n",
"            else:\n",
"                iter_num = x\n",
"                chain = 0\n",
"            _update_progress_bar(iter_num, chain)\n",
"            return func(carry, x)\n",
"\n",
"        return wrapper_progress_bar\n",
"\n",
"    return _progress_bar_scan"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6ea28709-aa8b-4b27-bcec-2bb22b049de6",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
]
}
],
"source": [
"import numpy as np\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import jax.scipy.stats as stats\n",
"from blackjax.progress_bar import progress_bar_scan\n",
"\n",
"loc, scale = 10, 20\n",
"# Seeded generator so that \"Restart & Run All\" reproduces the same synthetic data.\n",
"observed = np.random.default_rng(0).normal(loc, scale, size=1_000)\n",
"\n",
"\n",
"def logdensity_fn(loc, log_scale, observed=observed):\n",
"    \"\"\"Sum of Normal(loc, exp(log_scale)) log-densities over `observed`.\"\"\"\n",
"    return jnp.sum(stats.norm.logpdf(observed, loc, jnp.exp(log_scale)))\n",
"\n",
"\n",
"def logdensity(x):\n",
"    \"\"\"Evaluate `logdensity_fn` at a position dict; its keys become keyword arguments.\"\"\"\n",
"    return logdensity_fn(**x)\n",
"\n",
"\n",
"def inference_loop(rng_key, kernel, initial_state, chain, num_samples, num_chains):\n",
"    \"\"\"Run `num_samples` steps of `kernel` for one chain and return every state.\n",
"\n",
"    The iteration number and the chain id `chain` are threaded through the scan\n",
"    inputs so that the `progress_bar_factory` decorator can advance the right bar.\n",
"    \"\"\"\n",
"\n",
"    def _step(state, scan_inputs):\n",
"        _iteration, _chain_id, step_key = scan_inputs\n",
"        new_state, _info = kernel(step_key, state)\n",
"        # Carry the new state forward and also emit it, so scan stacks the trace.\n",
"        return new_state, new_state\n",
"\n",
"    add_progress_bar = progress_bar_factory(num_samples, num_chains)\n",
"    step = jax.jit(add_progress_bar(_step))\n",
"\n",
"    step_keys = jax.random.split(rng_key, num_samples)\n",
"    scan_inputs = (np.arange(num_samples), chain * np.ones(num_samples), step_keys)\n",
"    _final_state, trace = jax.lax.scan(step, initial_state, scan_inputs)\n",
"    return trace\n",
"\n",
"from datetime import date\n",
"# Seed the JAX PRNG with today's date as YYYYMMDD, so draws differ from day to day.\n",
"today_seed = int(date.today().strftime(\"%Y%m%d\"))\n",
"rng_key = jax.random.key(today_seed)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "2dbbdc44-fb24-439d-9d40-b92ad2b7d36c",
"metadata": {},
"outputs": [],
"source": [
"import blackjax\n",
"\n",
"\n",
"step_size = 1e-3\n",
"# NOTE(review): 1-D inverse mass matrix, presumably treated by BlackJAX as a\n",
"# diagonal with one entry per position component (loc, log_scale) - confirm.\n",
"inv_mass_matrix = np.array([0.5, 0.01])\n",
"\n",
"nuts = blackjax.nuts(logdensity, step_size, inv_mass_matrix)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e92decf5-9fa7-41c8-a01b-936fd7e08f9b",
"metadata": {},
"outputs": [],
"source": [
"# Number of independent chains; `jax.pmap` runs one chain per device.\n",
"num_chains: int = 4"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "0d440c9a-58c7-4f62-8e3d-68481ae24f82",
"metadata": {},
"outputs": [],
"source": [
"# NOTE(review): every chain starts from the same point (all ones).\n",
"initial_positions = {name: np.ones(num_chains) for name in (\"loc\", \"log_scale\")}\n",
"# One initial NUTS state per chain, by vmapping `init` over the leading axis.\n",
"initial_states = jax.vmap(nuts.init, in_axes=0)(initial_positions)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "0f624640-cbfe-4c88-a20c-a58f2d00e97c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# `jax.pmap` maps one chain per device, so there must be at least `num_chains` devices.\n",
"assert len(jax.devices()) >= num_chains, \"fewer XLA devices than chains\"\n",
"jax.devices()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "29f839ad-3a78-4997-802f-e811053809a8",
"metadata": {},
"outputs": [],
"source": [
"# Map one chain per device: rng_key, initial_state and chain id are mapped over\n",
"# axis 0; kernel, num_samples and num_chains are static and shared. `jax.pmap`\n",
"# requires the mapped axis size (num_chains) to equal the number of devices\n",
"# passed, so use exactly the first num_chains devices rather than all of them\n",
"# (cell 1 creates one CPU device per core, which is usually not num_chains).\n",
"inference_loop_multiple_chains = jax.pmap(\n",
"    inference_loop,\n",
"    in_axes=(0, None, 0, 0, None, None),\n",
"    static_broadcasted_argnums=(1, 4, 5),\n",
"    devices=jax.devices()[:num_chains],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "7a075352-f3ce-4b20-9c00-4493ef898ddc",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e071eb84a0434ecb853ec4c7377e460a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/2000 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "870596fa992b40b29b2067761159250c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/2000 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "994ad5db90e44e52af92dca5ebf2cd71",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/2000 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7d9d785405f941edad938d3553b752a4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/2000 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 38 s, sys: 188 ms, total: 38.2 s\n",
"Wall time: 16.2 s\n"
]
}
],
"source": [
"%%time\n",
"# Derive the sampling key without rebinding `rng_key`, so re-running this cell\n",
"# reproduces the same draws instead of advancing the global key each time.\n",
"_, sample_key = jax.random.split(rng_key)\n",
"sample_keys = jax.random.split(sample_key, num_chains)\n",
"\n",
"pmap_states = inference_loop_multiple_chains(\n",
"    sample_keys, nuts.step, initial_states, jnp.arange(num_chains), 2_000, num_chains,\n",
")\n",
"# Block so that %%time measures the sampling itself, not only the async dispatch.\n",
"_ = pmap_states.position[\"loc\"].block_until_ready()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "b92b34f0-b71d-418a-9854-e1294315884d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(4, 2000)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Leading axis indexes chains (pmap), second axis indexes samples (scan).\n",
"np.shape(pmap_states.position[\"loc\"])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "pymc-dev",
"language": "python",
"name": "pymc-dev"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment