Last active
October 16, 2020 19:28
-
-
Save ahartikainen/8713171d259718cf737d8a483500e0c2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import pystan\n", | |
| "import numpy as np" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "stan_code = \"\"\"\n", | |
| "parameters {\n", | |
| " real<lower=0> a;\n", | |
| " matrix[3,4] B;\n", | |
| "}\n", | |
| "model {\n", | |
| " a ~ normal(0,1);\n", | |
| " for (n in 1:3) {\n", | |
| " for (m in 1:4) {\n", | |
| " B[n,m] ~ normal(0,2);\n", | |
| " }\n", | |
| " }\n", | |
| "}\n", | |
| "\"\"\"" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_a71ba528c20fc622bc4c49e3064eafab NOW.\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "model = pystan.StanModel(model_code=stan_code)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "fit = model.sampling(iter=1, warmup=0, init=0, seed=1, control={\"adapt_engaged\": False}, check_hmc_diagnostics=False)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "sample = fit.extract(permuted=True)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "lp = sample[\"lp__\"]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "sample = {key: values for key, values in sample.items() if not key.endswith(\"__\")}" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "{'a': array([0.3855873 , 0.98054701, 0.49286373, 1.11726411]),\n", | |
| " 'B': array([[[ 0.31295903, 1.59176432, 0.25751408, -0.54202137],\n", | |
| " [-0.28700276, 0.00954776, 0.80052233, 0.21879722],\n", | |
| " [ 0.07881552, -0.20168992, 1.10617236, 2.33979838]],\n", | |
| " \n", | |
| " [[ 4.1422781 , 3.68825208, -0.17252071, -1.68395211],\n", | |
| " [-2.86501378, 0.33404798, -1.44207571, -1.80443805],\n", | |
| " [-1.10181232, 0.2255809 , -0.37776586, 0.5961288 ]],\n", | |
| " \n", | |
| " [[-0.22384381, 0.3285261 , 0.27917157, 2.15246902],\n", | |
| " [ 1.5110607 , -1.25719988, 0.80260026, -0.40612884],\n", | |
| " [ 1.11420704, 0.18069843, -1.24951964, -0.30942338]],\n", | |
| " \n", | |
| " [[ 0.27401254, -1.66957992, -0.26767151, -0.44180366],\n", | |
| " [ 0.25835625, -0.16191208, 0.95659342, 1.78234786],\n", | |
| " [ 0.26586376, 0.91323483, 0.18991228, 0.13852963]]])}" | |
| ] | |
| }, | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "sample" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([-2.34083794, -6.63107518, -2.38813117, -1.54752603])" | |
| ] | |
| }, | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "lp" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "a (4,)\n", | |
| "B (4, 3, 4)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "for key, values in sample.items():\n", | |
| " print(key, values.shape)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "{'a': 0.38558730390174756,\n", | |
| " 'B': array([[ 0.31295903, 1.59176432, 0.25751408, -0.54202137],\n", | |
| " [-0.28700276, 0.00954776, 0.80052233, 0.21879722],\n", | |
| " [ 0.07881552, -0.20168992, 1.10617236, 2.33979838]])}" | |
| ] | |
| }, | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# take the first draw of every parameter\n", | |
| "example_dict = {key: values[0] for key, values in sample.items()}\n", | |
| "example_dict" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "a ()\n", | |
| "B (3, 4)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "for key, values in example_dict.items():\n", | |
| " print(key, values.shape)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[[], [3, 4]]" | |
| ] | |
| }, | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "fit.par_dims" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([-0.95298764, 0.31295903, -0.28700276, 0.07881552, 1.59176432,\n", | |
| " 0.00954776, -0.20168992, 0.25751408, 0.80052233, 1.10617236,\n", | |
| " -0.54202137, 0.21879722, 2.33979838])" | |
| ] | |
| }, | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "unconstrained = fit.unconstrain_pars(example_dict)\n", | |
| "unconstrained" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "['a',\n", | |
| " 'B.1.1',\n", | |
| " 'B.2.1',\n", | |
| " 'B.3.1',\n", | |
| " 'B.1.2',\n", | |
| " 'B.2.2',\n", | |
| " 'B.3.2',\n", | |
| " 'B.1.3',\n", | |
| " 'B.2.3',\n", | |
| " 'B.3.3',\n", | |
| " 'B.1.4',\n", | |
| " 'B.2.4',\n", | |
| " 'B.3.4']" | |
| ] | |
| }, | |
| "execution_count": 15, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# unconstrained parameter order (B is column-major); unconstrain_pars does the reordering\n", | |
| "fit.unconstrained_param_names()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Calculate log_prob at the first draw and compare it with the stored lp__" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "-2.3408379411644353" | |
| ] | |
| }, | |
| "execution_count": 17, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "fit.log_prob(unconstrained)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "-2.3408379411644353" | |
| ] | |
| }, | |
| "execution_count": 18, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "lp[0]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "True" | |
| ] | |
| }, | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "fit.log_prob(unconstrained) == lp[0]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 20, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([-0.14867757, -0.07823976, 0.07175069, -0.01970388, -0.39794108,\n", | |
| " -0.00238694, 0.05042248, -0.06437852, -0.20013058, -0.27654309,\n", | |
| " 0.13550534, -0.0546993 , -0.5849496 ])" | |
| ] | |
| }, | |
| "execution_count": 20, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "fit.grad_log_prob(unconstrained, adjust_transform=False)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 21, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([ 0.85132243, -0.07823976, 0.07175069, -0.01970388, -0.39794108,\n", | |
| " -0.00238694, 0.05042248, -0.06437852, -0.20013058, -0.27654309,\n", | |
| " 0.13550534, -0.0546993 , -0.5849496 ])" | |
| ] | |
| }, | |
| "execution_count": 21, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# with the Jacobian adjustment: only the gradient for log(a) changes (by exactly +1)\n", | |
| "fit.grad_log_prob(unconstrained, adjust_transform=True)" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.7.3" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 4 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment