Skip to content

Instantly share code, notes, and snippets.

@keisukefukuda
Last active December 28, 2024 13:26
Show Gist options
  • Select an option

  • Save keisukefukuda/47b191fdffb2d6a90811042e03a47187 to your computer and use it in GitHub Desktop.

Select an option

Save keisukefukuda/47b191fdffb2d6a90811042e03a47187 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from abc import ABC, abstractmethod\n",
"import random\n",
"import math\n",
"\n",
"\n",
"# Abstract base class for discrete probability distributions (currently 1-D only).\n",
"class DiscreteProb(ABC):\n",
"    \"\"\"Abstract 1-D discrete probability distribution.\n",
"\n",
"    Subclasses implement ``pmf`` (the probability mass function) and\n",
"    ``__call__`` (draw one sample of the random variable).\n",
"    \"\"\"\n",
"\n",
"    @abstractmethod\n",
"    def pmf(self, x: int) -> float:\n",
"        \"\"\"Return the probability mass P(X = x).\"\"\"\n",
"        pass\n",
"\n",
"    @abstractmethod\n",
"    def __call__(self) -> int:\n",
"        \"\"\"Draw and return one sample of the random variable X.\"\"\"\n",
"        pass\n",
"\n",
"\n",
"class Bern(DiscreteProb):\n",
"    \"\"\"Bernoulli distribution Bern(x | mu) over {0, 1}.\"\"\"\n",
"\n",
"    def __init__(self, mu: float):\n",
"        # mu is the probability assigned to x == 1.\n",
"        self.mu = mu\n",
"\n",
"    def pmf(self, x: int):\n",
"        \"\"\"Return mu for x == 1, 1 - mu for x == 0, and 0 for anything else.\"\"\"\n",
"        if x == 1:\n",
"            return self.mu\n",
"        if x == 0:\n",
"            return 1 - self.mu\n",
"        return 0\n",
"\n",
"    def __call__(self):\n",
"        \"\"\"Draw one sample: 1 when a uniform [0, 1) draw falls below mu, else 0.\"\"\"\n",
"        hit = random.random() < self.mu\n",
"        return int(hit)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Expectation and entropy computation by sampling.\n",
"\n",
"\n",
"class SimpleSampler:\n",
"    \"\"\"Plain Monte Carlo estimator of E_p[f(X)].\"\"\"\n",
"\n",
"    def __init__(self):\n",
"        pass\n",
"\n",
"    def __call__(self, p, L: int, f=None) -> float:\n",
"        \"\"\"Estimate E_p[f(X)] as the mean of f over L samples drawn from p.\n",
"\n",
"        p: distribution; calling p() draws one sample.\n",
"        L: number of samples (L == 0 raises ZeroDivisionError).\n",
"        f: function applied to each sample; the identity when None.\n",
"        \"\"\"\n",
"        transform = (lambda v: v) if f is None else f\n",
"        total = sum(transform(p()) for _ in range(L))\n",
"        return total / L\n",
"\n",
"\n",
"class Entropy:\n",
"    \"\"\"Monte Carlo estimate of the entropy H[p] = -E_p[ln p(X)].\"\"\"\n",
"\n",
"    def __init__(self, sampler=SimpleSampler()):\n",
"        # The default sampler is created once and shared by every Entropy();\n",
"        # harmless because SimpleSampler keeps no state.\n",
"        self.sampler = sampler\n",
"\n",
"    def __call__(self, p, n_samples: int = 10000) -> float:\n",
"        \"\"\"Return the sampled entropy of p (natural log, i.e. nats).\"\"\"\n",
"        def log_prob(x):\n",
"            return math.log(p.pmf(x))\n",
"\n",
"        expected_log_prob = self.sampler(p, n_samples, log_prob)\n",
"        return -expected_log_prob\n",
"\n",
"\n",
"class KLDiv(object):\n",
"    \"\"\"Monte Carlo estimate of the KL divergence KL[q(x) || p(x)].\"\"\"\n",
"\n",
"    def __init__(self, sampler=SimpleSampler()):\n",
"        # Shared default instance; fine because SimpleSampler keeps no state.\n",
"        self.sampler = sampler\n",
"\n",
"    def __call__(self, *, p, q, L: int = 10000) -> float:\n",
"        \"\"\"Estimate KL[q(x) || p(x)] = E_q[ln q(X) - ln p(X)] from L samples of q.\n",
"\n",
"        p, q: distributions exposing pmf(x); q is also called to draw samples.\n",
"        L: number of samples drawn from q.\n",
"        \"\"\"\n",
"        # Use math.log (imported in the first cell) instead of the bare name\n",
"        # ``log``, which is only bound once a later cell runs ``from math import log``.\n",
"        expected = self.sampler(q, f=lambda x: math.log(p.pmf(x) / q.pmf(x)), L=L)\n",
"        # abs() turns the -0.0 produced for identical p and q into 0.0.\n",
"        # NOTE(review): it also flips the sign of a negative Monte Carlo estimate\n",
"        # instead of exposing it; the true divergence is never negative.\n",
"        return abs(-1 * expected)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"近似値:0.64078858\n",
"理論値:0.63651417\n"
]
}
],
"source": [
"# ================================================\n",
"# 2.1\n",
"# ================================================\n",
"\n",
"\n",
"# 2.1.5 サンプリングによる期待値の近似計算\n",
"# 例題:単純な確率分布のエントロピー計算\n",
"\n",
"from math import log\n",
"\n",
"\n",
"class MyDist(DiscreteProb):\n",
"    \"\"\"Toy distribution on {0, 1} with P(0) = 2/3 and P(1) = 1/3.\"\"\"\n",
"\n",
"    def pmf(self, x: int) -> float:\n",
"        assert x in [0, 1]\n",
"        return 2.0 / 3 if x == 0 else 1.0 / 3\n",
"\n",
"    def __call__(self) -> int:\n",
"        # Draw 0 with probability 2/3, otherwise 1.\n",
"        return 0 if random.random() < 2.0 / 3 else 1\n",
"\n",
"\n",
"p = MyDist()\n",
"H = Entropy()\n",
"# Sampled estimate vs. the closed-form entropy -(1/3 ln 1/3 + 2/3 ln 2/3).\n",
"approx = H(p)\n",
"exact = -(1/3 * log(1/3) + 2/3 * log(2/3))\n",
"print(f\"近似値:{approx:.8f}\")\n",
"print(f\"理論値:{exact:.8f}\")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0]\n",
"[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n",
"E[Bern(x|0.5)] = 0.5\n",
"E[Bern(x|0.9)] = 0.8957\n",
"H[Bern(x|0.5)] = 0.6931471805600546\n",
"H[Bern(x|0.9)] = 0.3220068589832419\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>0.1</th>\n",
" <th>0.5</th>\n",
" <th>0.9</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0.1</th>\n",
" <td>0.000000</td>\n",
" <td>0.512364</td>\n",
" <td>1.760416</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0.5</th>\n",
" <td>0.368372</td>\n",
" <td>0.000000</td>\n",
" <td>0.370635</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0.9</th>\n",
" <td>1.766085</td>\n",
" <td>0.509353</td>\n",
" <td>0.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" 0.1 0.5 0.9\n",
"0.1 0.000000 0.512364 1.760416\n",
"0.5 0.368372 0.000000 0.370635\n",
"0.9 1.766085 0.509353 0.000000"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# ================================================\n",
"# 2.2.1 Bernoulli distribution\n",
"# ================================================\n",
"\n",
"# Twenty raw samples from each distribution.\n",
"b1 = Bern(0.5)\n",
"print([b1() for _ in range(20)])\n",
"b2 = Bern(0.9)\n",
"print([b2() for _ in range(20)])\n",
"\n",
"# Sampled expectations E[x]; they should be close to mu.\n",
"sampler = SimpleSampler()\n",
"print(f\"E[Bern(x|0.5)] = {sampler(b1, 10000)}\")\n",
"print(f\"E[Bern(x|0.9)] = {sampler(b2, 10000)}\")\n",
"\n",
"# Sampled entropies.\n",
"H = Entropy()\n",
"print(f\"H[Bern(x|0.5)] = {H(b1)}\")\n",
"print(f\"H[Bern(x|0.9)] = {H(b2)}\")\n",
"\n",
"from matplotlib import pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Sampled entropy as a function of mu (100 grid points, 100000 draws each).\n",
"x = np.linspace(0, 1, 100)\n",
"y = [H(Bern(mu), 100000) for mu in x]\n",
"fig, ax = plt.subplots()\n",
"ax.plot(x, y)\n",
"ax.set_ylim(0, 0.7)\n",
"ax.set_xlim(0, 1.0)\n",
"ax.set_xlabel(\"μ\")\n",
"ax.set_ylabel(\"H[Bern(x|μ)]\")\n",
"\n",
"import pandas as pd\n",
"\n",
"# Table of KL[q || p]: rows are p = Bern(mu2), columns are q = Bern(mu1).\n",
"kl = KLDiv()\n",
"x = [0.1, 0.5, 0.9]\n",
"table = []\n",
"for mu2 in x:\n",
"    table.append([kl(p=Bern(mu2), q=Bern(mu1), L=100000) for mu1 in x])\n",
"data = pd.DataFrame(table, index=x, columns=x)\n",
"\n",
"data\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.0, 0.35)"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# ================================================\n",
"# 2.2.2 二項分布\n",
"# ================================================\n",
"\n",
"class Bin(DiscreteProb):\n",
"    \"\"\"Binomial distribution Bin(m | M, mu): successes in M Bernoulli(mu) trials.\"\"\"\n",
"\n",
"    def __init__(self, M: int, mu: float):\n",
"        self.M = M\n",
"        self.mu = mu\n",
"\n",
"    def pmf(self, m: int) -> float:\n",
"        \"\"\"Return P(X = m); 0 outside the support {0, ..., M}, like Bern.pmf.\"\"\"\n",
"        if not 0 <= m <= self.M:\n",
"            # Without this guard, math.comb raises ValueError for m < 0, and\n",
"            # (1 - mu) ** (M - m) raises ZeroDivisionError for m > M when mu == 1.\n",
"            return 0.0\n",
"        return math.comb(self.M, m) * self.mu**m * (1 - self.mu)**(self.M - m)\n",
"\n",
"    def __call__(self) -> int:\n",
"        \"\"\"Draw one sample by counting successes in M independent Bernoulli(mu) trials.\"\"\"\n",
"        return sum(1 if random.random() < self.mu else 0 for _ in range(self.M))\n",
"\n",
"\n",
"# Plot the pmf of several binomial distributions, one figure each.\n",
"for (M, mu) in [(10, 0.5), (10, 0.2), (10, 0.9), (100, 0.5)]:\n",
"    b = Bin(M=M, mu=mu)\n",
"    # Cover the whole support {0, ..., M}; a fixed range(0, 21) hides the\n",
"    # M=100 case, whose mean M * mu is 50.\n",
"    x = range(0, M + 1)\n",
"    y = [b.pmf(m) for m in x]\n",
"\n",
"    fig, ax = plt.subplots()\n",
"    ax.bar(x, y)\n",
"    # Keep the 0.35 scale, but grow it when a bar is taller\n",
"    # (M=10, mu=0.9 peaks at about 0.387).\n",
"    ax.set_ylim(0, max(0.35, max(y) * 1.05))\n",
"    # Axes.title is a Text artist, not a method; calling it raises TypeError.\n",
"    ax.set_title(f\"M={M}, μ={mu}\")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.11.7 (main, Dec 4 2023, 18:10:11) [Clang 15.0.0 (clang-1500.1.0.2.5)]\n",
"Requirement already satisfied: matplotlib in /opt/homebrew/lib/python3.11/site-packages (3.8.1)\n",
"Requirement already satisfied: pandas in /opt/homebrew/lib/python3.11/site-packages (2.2.3)\n",
"Requirement already satisfied: numpy in /opt/homebrew/lib/python3.11/site-packages (1.26.1)\n",
"Requirement already satisfied: contourpy>=1.0.1 in /opt/homebrew/lib/python3.11/site-packages (from matplotlib) (1.1.1)\n",
"Requirement already satisfied: cycler>=0.10 in /opt/homebrew/lib/python3.11/site-packages (from matplotlib) (0.12.1)\n",
"Requirement already satisfied: fonttools>=4.22.0 in /opt/homebrew/lib/python3.11/site-packages (from matplotlib) (4.43.1)\n",
"Requirement already satisfied: kiwisolver>=1.3.1 in /opt/homebrew/lib/python3.11/site-packages (from matplotlib) (1.4.5)\n",
"Requirement already satisfied: packaging>=20.0 in /Users/keisukefukuda/Library/Python/3.11/lib/python/site-packages (from matplotlib) (23.2)\n",
"Requirement already satisfied: pillow>=8 in /opt/homebrew/lib/python3.11/site-packages (from matplotlib) (10.1.0)\n",
"Requirement already satisfied: pyparsing>=2.3.1 in /opt/homebrew/lib/python3.11/site-packages (from matplotlib) (3.1.1)\n",
"Requirement already satisfied: python-dateutil>=2.7 in /Users/keisukefukuda/Library/Python/3.11/lib/python/site-packages (from matplotlib) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /opt/homebrew/lib/python3.11/site-packages (from pandas) (2024.2)\n",
"Requirement already satisfied: tzdata>=2022.7 in /opt/homebrew/lib/python3.11/site-packages (from pandas) (2024.2)\n",
"Requirement already satisfied: six>=1.5 in /opt/homebrew/lib/python3.11/site-packages (from python-dateutil>=2.7->matplotlib) (1.16.0)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.3.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython3.11 -m pip install --upgrade pip\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"CompletedProcess(args=['/opt/homebrew/opt/python@3.11/bin/python3.11', '-m', 'pip', 'install', 'matplotlib', 'pandas', 'numpy'], returncode=0)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import math\n",
"import sys\n",
"from subprocess import run\n",
"\n",
"# Evaluated but not displayed: it is not the cell's last expression.\n",
"math.comb(10, 3)\n",
"\n",
"# Print the interpreter version, then pip-install the packages the earlier\n",
"# cells import. NOTE(review): this cell runs after those imports; consider\n",
"# moving it to the top of the notebook.\n",
"print(sys.version)\n",
"pip_args = [sys.executable, \"-m\", \"pip\", \"install\", \"matplotlib\", \"pandas\", \"numpy\"]\n",
"run(pip_args)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "3.11",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment