Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save tomonari-masada/e5ce480092acb275f6ff3f653d111c0d to your computer and use it in GitHub Desktop.

Select an option

Save tomonari-masada/e5ce480092acb275f6ff3f653d111c0d to your computer and use it in GitHub Desktop.
zero_inflated_negative_binomial.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/tomonari-masada/e5ce480092acb275f6ff3f653d111c0d/zero_inflated_negative_binomial.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"id": "12788303",
"metadata": {
"id": "12788303"
},
"source": [
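"# Zero-inflated negative binomial distribution\n",
"\n",
"Simulate counts from a zero-inflated negative binomial distribution, then recover its parameters (zero-inflation proportion, number of successes, success probability) with NumPyro's NUTS sampler."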
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e18fe42a",
"metadata": {
"id": "e18fe42a"
},
"outputs": [],
"source": [
"%pip install numpyro"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18ee6419",
"metadata": {
"id": "18ee6419"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"from scipy.stats import nbinom\n",
"\n",
"import jax.numpy as jnp\n",
"from jax import random\n",
"from jax.scipy.special import expit\n",
"import numpyro\n",
"import numpyro.distributions as dist\n",
"from numpyro.infer import NUTS, MCMC\n",
"\n",
"%config InlineBackend.figure_format = 'retina'\n",
"\n",
"# set_platform only takes effect before JAX initialises a backend, so it must\n",
"# run before any JAX array (including a PRNG key) is created.\n",
"numpyro.set_platform(\"cpu\")\n",
"\n",
"# Seed both the NumPy and the JAX RNGs so that Restart & Run All reproduces\n",
"# the same simulated data and the same posterior.\n",
"SEED = 0\n",
"rng = np.random.default_rng(SEED)\n",
"rng_key = random.PRNGKey(SEED)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ec802eff",
"metadata": {
"id": "ec802eff"
},
"outputs": [],
"source": [
"size = 1000\n",
"\n",
"true_proportion_of_zero_inflation = 0.3\n",
"number_of_successes = 5\n",
"success_probability = 0.4\n",
"\n",
"# Count component: SciPy negative binomial draws driven by the NumPy generator,\n",
"# converted to a JAX array so it can be combined with the JAX mask below.\n",
"negative_binomial_samples = jnp.asarray(\n",
"    nbinom.rvs(n=number_of_successes, p=success_probability, size=size, random_state=rng)\n",
")\n",
"\n",
"# Zero-inflation mask: a Bernoulli draw using a key split off the global rng_key.\n",
"rng_key, rng_key_data = random.split(rng_key)\n",
"is_zero_inflated = random.bernoulli(\n",
"    rng_key_data, p=true_proportion_of_zero_inflation, shape=(size,)\n",
")\n",
"\n",
"# Wherever the mask is True the count is replaced by a structural zero.\n",
"data = jnp.where(is_zero_inflated, 0, negative_binomial_samples)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d9ca6a9",
"metadata": {
"id": "5d9ca6a9",
"outputId": "e0b0f0b5-84c9-494c-e078-40cf329408f7"
},
"outputs": [
{
"data": {
"text/plain": [
"Array([ 0, 0, 2, 5, 0, 0, 0, 4, 0, 3, 5, 1, 0, 0, 4, 15, 0,\n",
" 9, 0, 0, 0, 4, 10, 9, 0, 0, 3, 1, 3, 3, 0, 0, 6, 10,\n",
" 0, 0, 0, 0, 5, 0, 0, 0, 12, 0, 2, 0, 6, 5, 8, 8, 9,\n",
" 0, 6, 0, 8, 0, 10, 5, 0, 5, 10, 11, 10, 13, 8, 0, 24, 0,\n",
" 0, 0, 0, 8, 6, 3, 1, 14, 4, 12, 10, 5, 0, 7, 0, 5, 0,\n",
" 4, 0, 1, 0, 4, 0, 7, 17, 10, 9, 10, 0, 4, 0, 5, 5, 10,\n",
" 0, 7, 0, 3, 9, 3, 6, 0, 0, 11, 0, 2, 0, 5, 7, 13, 15,\n",
" 0, 0, 2, 13, 8, 3, 0, 2, 7, 12, 6, 0, 5, 0, 2, 7, 9,\n",
" 0, 10, 5, 1, 12, 12, 19, 15, 0, 0, 0, 15, 2, 9, 0, 8, 0,\n",
" 10, 9, 7, 3, 6, 3, 11, 5, 7, 0, 3, 9, 0, 0, 9, 4, 0,\n",
" 5, 1, 0, 6, 3, 5, 5, 0, 0, 9, 0, 5, 7, 2, 8, 6, 8,\n",
" 6, 0, 11, 3, 6, 7, 7, 0, 10, 5, 6, 9, 17, 0, 0, 9, 0,\n",
" 0, 5, 4, 0, 11, 0, 0, 2, 9, 6, 9, 4, 11, 0, 0, 0, 8,\n",
" 5, 7, 22, 8, 6, 0, 5, 0, 6, 5, 3, 4, 6, 0, 0, 2, 1,\n",
" 8, 0, 2, 0, 6, 10, 3, 0, 5, 6, 8, 8, 0, 3, 8, 5, 6,\n",
" 1, 9, 2, 9, 0, 4, 14, 4, 0, 0, 0, 0, 2, 13, 0, 0, 9,\n",
" 0, 4, 0, 0, 0, 4, 0, 0, 5, 8, 8, 0, 3, 0, 6, 3, 0,\n",
" 8, 7, 0, 0, 9, 8, 4, 7, 2, 4, 0, 10, 2, 0, 8, 6, 1,\n",
" 7, 8, 0, 0, 0, 0, 0, 13, 5, 8, 7, 0, 0, 0, 5, 1, 4,\n",
" 6, 9, 11, 6, 0, 4, 9, 4, 12, 10, 0, 0, 4, 0, 13, 12, 0,\n",
" 0, 15, 0, 10, 14, 4, 6, 4, 7, 6, 7, 0, 9, 0, 10, 3, 12,\n",
" 7, 4, 0, 0, 8, 10, 10, 0, 0, 0, 2, 4, 5, 13, 0, 1, 5,\n",
" 5, 5, 0, 10, 6, 13, 3, 15, 10, 11, 11, 0, 0, 2, 10, 4, 8,\n",
" 5, 0, 3, 7, 4, 0, 5, 3, 23, 0, 10, 4, 0, 1, 0, 4, 3,\n",
" 2, 5, 6, 0, 2, 0, 3, 0, 0, 8, 0, 0, 0, 9, 18, 7, 14,\n",
" 17, 3, 6, 6, 8, 3, 10, 16, 0, 8, 0, 0, 5, 0, 7, 6, 0,\n",
" 0, 9, 14, 0, 4, 7, 5, 0, 8, 0, 0, 0, 17, 10, 0, 5, 8,\n",
" 5, 5, 6, 9, 0, 2, 5, 0, 3, 9, 0, 0, 1, 0, 0, 0, 0,\n",
" 16, 14, 0, 0, 6, 5, 2, 15, 0, 0, 0, 0, 0, 0, 7, 3, 4,\n",
" 0, 6, 0, 13, 9, 0, 11, 3, 0, 8, 0, 0, 8, 7, 1, 13, 2,\n",
" 11, 4, 2, 1, 12, 11, 5, 2, 0, 11, 5, 7, 0, 14, 7, 11, 0,\n",
" 6, 0, 0, 0, 0, 6, 2, 5, 3, 8, 0, 0, 0, 4, 14, 5, 0,\n",
" 13, 4, 9, 5, 4, 0, 8, 6, 11, 0, 4, 7, 0, 9, 6, 0, 4,\n",
" 5, 0, 6, 6, 6, 4, 6, 1, 4, 2, 4, 7, 0, 15, 11, 0, 8,\n",
" 0, 4, 3, 18, 2, 6, 4, 3, 0, 4, 4, 4, 0, 0, 13, 0, 11,\n",
" 0, 0, 13, 0, 3, 0, 13, 8, 10, 0, 7, 0, 8, 9, 0, 3, 6,\n",
" 10, 0, 7, 8, 3, 0, 3, 15, 9, 0, 0, 2, 8, 6, 10, 10, 3,\n",
" 20, 0, 11, 13, 9, 2, 12, 0, 0, 13, 0, 0, 3, 13, 4, 5, 13,\n",
" 0, 0, 8, 7, 0, 4, 3, 0, 6, 8, 0, 12, 11, 9, 6, 0, 9,\n",
" 0, 0, 8, 8, 10, 4, 0, 0, 5, 5, 3, 8, 2, 12, 0, 0, 9,\n",
" 6, 15, 8, 0, 0, 13, 9, 4, 0, 8, 0, 2, 13, 12, 0, 2, 0,\n",
" 7, 12, 4, 17, 8, 0, 0, 16, 14, 3, 0, 8, 6, 1, 0, 0, 6,\n",
" 21, 4, 7, 0, 0, 11, 0, 3, 7, 9, 3, 3, 14, 7, 9, 3, 0,\n",
" 11, 2, 0, 0, 6, 7, 0, 8, 11, 14, 0, 13, 0, 0, 16, 7, 0,\n",
" 0, 2, 5, 0, 8, 5, 9, 12, 0, 17, 4, 11, 3, 7, 3, 2, 10,\n",
" 7, 0, 0, 0, 0, 9, 7, 5, 13, 14, 3, 0, 3, 4, 17, 5, 0,\n",
" 15, 0, 3, 0, 0, 7, 5, 9, 7, 0, 18, 4, 4, 0, 3, 13, 9,\n",
" 0, 16, 0, 7, 8, 0, 8, 6, 5, 12, 0, 3, 4, 5, 5, 0, 0,\n",
" 5, 0, 5, 10, 1, 9, 4, 11, 9, 6, 2, 0, 0, 5, 0, 2, 9,\n",
" 0, 2, 0, 3, 11, 6, 4, 0, 0, 4, 3, 6, 14, 0, 4, 1, 1,\n",
" 0, 6, 9, 0, 0, 0, 12, 0, 7, 0, 10, 0, 3, 7, 8, 8, 12,\n",
" 4, 0, 2, 15, 0, 0, 4, 9, 7, 3, 5, 10, 0, 10, 10, 0, 3,\n",
" 9, 12, 6, 20, 8, 0, 8, 0, 3, 0, 0, 8, 6, 6, 6, 8, 7,\n",
" 7, 4, 0, 4, 4, 9, 6, 14, 0, 8, 2, 2, 0, 8, 17, 11, 0,\n",
" 8, 10, 2, 6, 3, 2, 8, 2, 7, 0, 0, 8, 0, 6, 0, 13, 6,\n",
" 2, 0, 0, 0, 8, 14, 5, 0, 0, 17, 0, 10, 0, 0, 0, 14, 4,\n",
" 0, 8, 4, 1, 0, 12, 5, 0, 7, 0, 10, 6, 7, 0, 11, 0, 6,\n",
" 3, 6, 15, 4, 5, 2, 5, 0, 2, 0, 11, 0, 0, 1, 11, 6, 4,\n",
" 3, 4, 3, 8, 12, 8, 8, 3, 0, 4, 7, 5, 0, 6], dtype=int32)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data[:20]  # preview only; the histogram below summarises the full sample"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c82450fb",
"metadata": {
"id": "c82450fb",
"outputId": "e55090a7-4fb0-4ac8-c736-b631696a5770"
},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {
"image/png": {
"height": 454,
"width": 571
}
},
"output_type": "display_data"
}
],
"source": [
"# np.bincount returns one count per integer 0..max(data), so values that never\n",
"# occur get an explicit zero-height bar instead of silently disappearing from\n",
"# seaborn's categorical x-axis. The counts are also computed only once.\n",
"counts_per_value = np.bincount(np.asarray(data))\n",
"fig, ax = plt.subplots()\n",
"sns.barplot(x=np.arange(counts_per_value.size), y=counts_per_value, color=\"C0\", ax=ax)\n",
"ax.set(xlabel=\"Value\", ylabel=\"Count\", title=\"Histogram of observed data\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ed039183",
"metadata": {
"id": "ed039183"
},
"outputs": [],
"source": [
"class ZeroInflatedNegativeBinomial(dist.discrete.ZeroInflatedProbs):\n",
"    \"\"\"Zero-inflated negative binomial in SciPy's ``nbinom(n, p)`` parameterisation.\n",
"\n",
"    With probability ``gate`` a draw is a structural zero; otherwise it is drawn\n",
"    from a negative binomial with ``number_of_successes`` (n) and\n",
"    ``success_probability`` (p), i.e. mean n(1-p)/p and variance n(1-p)/p**2,\n",
"    the same moments as ``scipy.stats.nbinom(n, p)`` used to simulate the data.\n",
"    Internally the component is ``NegativeBinomial2(mean=n(1-p)/p, concentration=n)``.\n",
"    \"\"\"\n",
"\n",
"    arg_constraints = {\n",
"        \"gate\": dist.constraints.unit_interval,\n",
"        \"number_of_successes\": dist.constraints.positive,\n",
"        \"success_probability\": dist.constraints.unit_interval,\n",
"    }\n",
"    support = dist.constraints.nonnegative_integer\n",
"    pytree_data_fields = (\"number_of_successes\", \"success_probability\")\n",
"\n",
"    def __init__(self, gate, number_of_successes, success_probability, *, validate_args=None):\n",
"        _, self.number_of_successes = dist.util.promote_shapes(gate, number_of_successes)\n",
"        _, self.success_probability = dist.util.promote_shapes(gate, success_probability)\n",
"        # Build the component from the shape-promoted values stored above so the\n",
"        # public attributes and the underlying distribution always agree.\n",
"        n = self.number_of_successes\n",
"        p = self.success_probability\n",
"        negative_binomial_component = dist.NegativeBinomial2(\n",
"            mean=n * (1 - p) / p,\n",
"            concentration=n,\n",
"        )\n",
"        super().__init__(negative_binomial_component, gate, validate_args=validate_args)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e35abca8",
"metadata": {
"id": "e35abca8"
},
"outputs": [],
"source": [
"def model(y):\n",
"    \"\"\"ZINB model with Normal(0, 10) priors on the logits of the gate and success probability.\n",
"\n",
"    Sample sites, in order: ``ap`` (gate logit), ``r`` (number of successes),\n",
"    ``aq`` (success-probability logit) and the observed ``y``.\n",
"    \"\"\"\n",
"    gate_prob = expit(numpyro.sample(\"ap\", dist.Normal(0, 10)))\n",
"    total_successes = numpyro.sample(\"r\", dist.Exponential(1.0))\n",
"    success_prob = expit(numpyro.sample(\"aq\", dist.Normal(0, 10)))\n",
"    likelihood = ZeroInflatedNegativeBinomial(gate_prob, total_successes, success_prob)\n",
"    numpyro.sample(\"y\", likelihood, obs=y)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3c6461a2",
"metadata": {
"id": "3c6461a2",
"outputId": "ecc1967e-ebf9-42dd-c42a-2dfc48279e59"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"sample: 100%|██████████| 3000/3000 [00:05<00:00, 530.21it/s, 31 steps of size 1.42e-01. acc. prob=0.95]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
" mean std median 5.0% 95.0% n_eff r_hat\n",
" ap -0.77 0.07 -0.77 -0.88 -0.65 819.51 1.00\n",
" aq -0.40 0.10 -0.40 -0.56 -0.23 654.70 1.00\n",
" r 4.75 0.47 4.74 3.93 5.48 637.92 1.00\n",
"\n",
"Number of divergences: 0\n"
]
}
],
"source": [
"# Split off a fresh key for the sampler and advance the global rng_key.\n",
"rng_key, rng_key_mcmc = random.split(rng_key)\n",
"mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=2000)\n",
"mcmc.run(rng_key_mcmc, y=data)\n",
"mcmc.print_summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eb19de18",
"metadata": {
"id": "eb19de18",
"outputId": "8005c5f3-105c-4a8a-88fe-9bea9e3b220a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"proportion of zero inflation: 0.317\n",
"number of successes: 4.754\n",
"success probability: 0.402\n"
]
}
],
"source": [
"post = mcmc.get_samples()\n",
"# Posterior means, taken after mapping the logit-scale samples through expit.\n",
"posterior_means = {\n",
"    \"proportion of zero inflation\": jnp.mean(expit(post[\"ap\"])),\n",
"    \"number of successes\": jnp.mean(post[\"r\"]),\n",
"    \"success probability\": jnp.mean(expit(post[\"aq\"])),\n",
"}\n",
"for label, value in posterior_means.items():\n",
"    print(f\"{label}: {value:0.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9f2c91e9",
"metadata": {
"id": "9f2c91e9"
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "python",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.10"
},
"colab": {
"provenance": [],
"include_colab_link": true
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment