Created
November 26, 2018 16:28
-
-
Save sadatnfs/d5004f488ba00371333770059ab99776 to your computer and use it in GitHub Desktop.
Attempting to make PyMC4 work
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": { | |
| "scrolled": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "import tensorflow as tf\n", | |
| "from tensorflow_probability import edward2 as ed\n", | |
| "from tensorflow_probability import distributions as tfd\n", | |
| "import pymc4 as pm4" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "## Create fake data\n", | |
| "alpha_raw = 0.1\n", | |
| "beta_raw = 0.5\n", | |
| "sigma_raw = 0.25\n", | |
| "N = 1500\n", | |
| "x = np.random.normal(size = N) \n", | |
| "y = alpha_raw + beta_raw * x + np.random.normal(scale = sigma_raw**0.5, size = N)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Reset graph\n", | |
| "tf.reset_default_graph()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "<pymc4.model.base.Model at 0x2abef427fc88>" | |
| ] | |
| }, | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# PyMC4 model initialize\n", | |
| "pymc4_ols = pm4.Model(X = x, Y = y)\n", | |
| "pymc4_ols.observe(X = pymc4_ols.cfg.X, Y = pymc4_ols.cfg.Y)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "@pymc4_ols.define\n", | |
| "def process(cfg):\n", | |
| " alpha = ed.Normal(\n", | |
| " loc=0., scale=5., name=\"alpha\")\n", | |
| " beta = ed.Normal(\n", | |
| " loc=0., scale=5., name = \"beta\")\n", | |
| " sigma = tf.exp(ed.Gamma(1., 1., name = 'sigma'))\n", | |
| " Yhat = alpha + beta*cfg.X\n", | |
| " Y = ed.Normal(\n", | |
| " loc=Yhat,\n", | |
| " scale=sigma,\n", | |
| " name=\"Y\")\n", | |
| " return Y" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "{'X': array([-0.83580625, -0.16027781, -0.20688124, ..., 0.01717773,\n", | |
| " -1.28557357, 0.07603539]),\n", | |
| " 'Y': array([-0.33343107, 0.17018252, -0.0079468 , ..., -0.5135459 ,\n", | |
| " 0.10321138, 0.92192937])}" | |
| ] | |
| }, | |
| "execution_count": 15, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "pymc4_ols.observed" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "OrderedDict([('alpha',\n", | |
| " VariableDescription(Dist=<class 'tensorflow_probability.python.distributions.normal.Normal'>, shape=TensorShape([]), rv=<ed.RandomVariable 'alpha/' shape=() dtype=float32>)),\n", | |
| " ('beta',\n", | |
| " VariableDescription(Dist=<class 'tensorflow_probability.python.distributions.normal.Normal'>, shape=TensorShape([]), rv=<ed.RandomVariable 'beta/' shape=() dtype=float32>)),\n", | |
| " ('sigma',\n", | |
| " VariableDescription(Dist=<class 'tensorflow_probability.python.distributions.gamma.Gamma'>, shape=TensorShape([]), rv=<ed.RandomVariable 'sigma/' shape=() dtype=float32>))])" | |
| ] | |
| }, | |
| "execution_count": 16, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "pymc4_ols.unobserved" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 21, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "/opt/conda/lib/python3.6/site-packages/numpy-1.16.0.dev0+b47ed76-py3.6-linux-x86_64.egg/numpy/lib/type_check.py:549: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead\n", | |
| " 'a.item() instead', DeprecationWarning, stacklevel=1)\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Acceptance rate: 0.0\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "pymc4_trace = pm4.sample(pymc4_ols, num_burnin_steps=1000, num_results=20)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 22, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "{'alpha': array([0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,\n", | |
| " 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], dtype=float32),\n", | |
| " 'beta': array([0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,\n", | |
| " 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], dtype=float32),\n", | |
| " 'sigma': array([0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,\n", | |
| " 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], dtype=float32)}" | |
| ] | |
| }, | |
| "execution_count": 22, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "pymc4_trace" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "### However, the 8 schools example works off the box..." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 23, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Data of the Eight Schools Model\n", | |
| "J = 8\n", | |
| "y = np.array([28., 8., -3., 7., -1., 1., 18., 12.])\n", | |
| "sigma = np.array([15., 10., 16., 11., 9., 11., 10., 18.])\n", | |
| "# tau = 25." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 24, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "model = pm4.Model(num_schools=J, y=y, sigma=sigma )\n", | |
| "@model.define\n", | |
| "def process(cfg):\n", | |
| " mu = ed.Normal(loc=0., scale=5., name=\"mu\") # `mu` above\n", | |
| " # Due to the lack of HalfCauchy distribution.\n", | |
| " log_tau = ed.Normal(\n", | |
| " loc=5., scale=1., name=\"log_tau\") # `log(tau)` above\n", | |
| " theta_prime = ed.Normal(\n", | |
| " loc=tf.zeros(cfg.num_schools),\n", | |
| " scale=tf.ones(cfg.num_schools),\n", | |
| " name=\"theta_prime\") # `theta_prime` above\n", | |
| " theta = mu + tf.exp(\n", | |
| " log_tau) * theta_prime # `theta` above\n", | |
| " y = ed.Normal(\n", | |
| " loc=theta,\n", | |
| " scale=np.float32(cfg.sigma),\n", | |
| " name=\"y\") # `y` above\n", | |
| " \n", | |
| " return y" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 25, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "<pymc4.model.base.Model at 0x2ac07cc1cdd8>" | |
| ] | |
| }, | |
| "execution_count": 25, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "model.observe(y = model.cfg.y)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 26, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "/opt/conda/lib/python3.6/site-packages/numpy-1.16.0.dev0+b47ed76-py3.6-linux-x86_64.egg/numpy/lib/type_check.py:549: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead\n", | |
| " 'a.item() instead', DeprecationWarning, stacklevel=1)\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Acceptance rate: 0.6102\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "trace = pm4.sample(model, num_burnin_steps=1000, num_results=5000)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 27, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "{'mu': array([ 1.2623473, 1.2623473, 1.2623473, ..., 10.174107 , 10.101565 ,\n", | |
| " 11.02169 ], dtype=float32),\n", | |
| " 'log_tau': array([2.798712 , 2.798712 , 2.798712 , ..., 2.237354 , 2.206139 ,\n", | |
| " 1.4436612], dtype=float32),\n", | |
| " 'theta_prime': array([[ 1.4495487 , 0.13190484, -0.55798876, ..., -0.39001518,\n", | |
| " 0.49999905, -0.14500381],\n", | |
| " [ 1.4495487 , 0.13190484, -0.55798876, ..., -0.39001518,\n", | |
| " 0.49999905, -0.14500381],\n", | |
| " [ 1.4495487 , 0.13190484, -0.55798876, ..., -0.39001518,\n", | |
| " 0.49999905, -0.14500381],\n", | |
| " ...,\n", | |
| " [ 2.3089426 , 0.6587325 , 0.04801591, ..., -0.14484467,\n", | |
| " 1.0311049 , -0.980293 ],\n", | |
| " [ 0.772265 , -1.3196737 , -2.4158447 , ..., 0.49178684,\n", | |
| " -0.6607568 , 0.24697188],\n", | |
| " [ 0.01715243, -1.6104103 , 0.7611127 , ..., -0.10088314,\n", | |
| " 1.5387309 , 0.6444084 ]], dtype=float32)}" | |
| ] | |
| }, | |
| "execution_count": 27, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "trace" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.6.5" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment