Created March 4, 2016 15:53
An implementation of a finite mixture model with covariates ("latent class model") in TensorFlow
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "import tensorflow as tf \n", | |
| "import numpy as np\n", | |
| "import pandas as pd" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def weight_variable(shape):\n", | |
| " initial = tf.truncated_normal(shape, stddev=0.1)\n", | |
| " return tf.Variable(initial)\n", | |
| "\n", | |
| "def bias_variable(shape):\n", | |
| " initial = tf.constant(0.1, shape=shape)\n", | |
| " return tf.Variable(initial)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Latent class example on data generated in R" | |
| ] | |
| }, | |
| { | |
| "cell_type": "raw", | |
| "metadata": {}, | |
| "source": [ | |
| "n <- 1000 #1e8\n", | |
| "J <- 3\n", | |
| "\n", | |
| "set.seed(654)\n", | |
| "\n", | |
| "Z1 <- rbinom(n, size = 1, prob = 0.5)\n", | |
| "Z2 <- rbinom(n, size = 1, prob = 0.5)\n", | |
| "\n", | |
| "X <- rbinom(n, size = 1, prob = plogis(1 + 0.3*Z1 - 99*Z2))\n", | |
| "\n", | |
| "Y <- matrix(rbinom(n*J, size = 1, prob = 0.3 + 0.4*X), n)\n", | |
| "\n", | |
| "dat <- data.frame(X = X, Y=Y, Z1 = Z1, Z2 = Z2)\n" | |
| ] | |
| }, | |
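| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "For readers without R, the next cell gives a minimal NumPy sketch of the same data-generating process (an added illustration; the seed handling differs from R's, so the draws will not reproduce the R data exactly). The column names match those expected by the training loop below." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Sketch of the R data-generating process above, in NumPy\n", | |
| "n, J = 1000, 3\n", | |
| "rng = np.random.RandomState(654)\n", | |
| "\n", | |
| "Z1 = rng.binomial(1, 0.5, n)\n", | |
| "Z2 = rng.binomial(1, 0.5, n)\n", | |
| "\n", | |
| "def plogis(v):  # inverse logit, as in R\n", | |
| "    return 1.0 / (1.0 + np.exp(-v))\n", | |
| "\n", | |
| "X = rng.binomial(1, plogis(1 + 0.3*Z1 - 99*Z2))\n", | |
| "Y = rng.binomial(1, np.tile((0.3 + 0.4*X)[:, None], (1, J)))\n", | |
| "\n", | |
| "dat = pd.DataFrame({'X': X, 'Y.1': Y[:, 0], 'Y.2': Y[:, 1],\n", | |
| "                    'Y.3': Y[:, 2], 'Z1': Z1, 'Z2': Z2})\n", | |
| "#dat.to_csv('dat.csv', index=False)" | |
| ] | |
| }, | |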
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### LG5.0 output\n", | |
| "\n", | |
| "This shows the output obtained when estimating the correct model in Latent GOLD 5.0.0.14161 on these data, in one go. Default settings are used except Bayes constant which is set to 0. " | |
| ] | |
| }, | |
| { | |
| "cell_type": "raw", | |
| "metadata": {}, | |
| "source": [ | |
| "options\n", | |
| " maxthreads=4;\n", | |
| " algorithm \n", | |
| " tolerance=1e-008 emtolerance=0.01 emiterations=550 nriterations=50 ;\n", | |
| " startvalues\n", | |
| " seed=0 sets=30 tolerance=1e-005 iterations=50;\n", | |
| " bayes\n", | |
| " categorical=0 variances=0 latent=0 poisson=0;\n", | |
| " montecarlo\n", | |
| " seed=0 sets=0 replicates=500 tolerance=1e-008;\n", | |
| " quadrature nodes=10;\n", | |
| " missing includeall;\n", | |
| " output \n", | |
| " parameters=first betaopts=wl standarderrors profile probmeans=posterior\n", | |
| " frequencies bivariateresiduals classification estimatedvalues=regression\n", | |
| " predictionstatistics iterationdetails;\n", | |
| "variables\n", | |
| " dependent Y.1, Y.2, Y.3;\n", | |
| " independent Z1, Z2;\n", | |
| " latent\n", | |
| " Class nominal 2;\n", | |
| "equations\n", | |
| " Class <- 1 + Z1 + Z2;\n", | |
| " Y.1 <- 1 + Class;\n", | |
| " Y.2 <- 1 + Class;\n", | |
| " Y.3 <- 1 + Class;\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "raw", | |
| "metadata": {}, | |
| "source": [ | |
| "Regression Parameters\t\t\t\t\t\t\n", | |
| "term\t\t\tcoef\tWald(0)\tdf\tp-value\n", | |
| "Class(1)\t<-\t=\"1\"\t0.0000\t6.8382\t1\t0.0090\n", | |
| "Class(2)\t<-\t=\"1\"\t-0.7978\t\t\t\n", | |
| "Class(1)\t<-\tZ1\t0.0000\t1.3411\t1\t0.25\n", | |
| "Class(2)\t<-\tZ1\t-0.3994\t\t\t\n", | |
| "Class(1)\t<-\tZ2\t0.0000\t0.0010\t1\t0.97\n", | |
| "Class(2)\t<-\tZ2\t31.6722\t\t\t\n", | |
| "\t\t\t\t\t\t\n", | |
| "Y.1(0)\t<-\t=\"1\"\t0.0000\t21.5333\t1\t3.5e-6\n", | |
| "Y.1(1)\t<-\t=\"1\"\t0.7336\t\t\t\n", | |
| "Y.1\t<-\tClass(1)\t0.0000\t65.0564\t1\t7.3e-16\n", | |
| "Y.1\t<-\tClass(2)\t-1.5342\t\t\t\n", | |
| "\t\t\t\t\t\t\n", | |
| "Y.2(0)\t<-\t=\"1\"\t0.0000\t35.8329\t1\t2.2e-9\n", | |
| "Y.2(1)\t<-\t=\"1\"\t1.1681\t\t\t\n", | |
| "Y.2\t<-\tClass(1)\t0.0000\t87.7631\t1\t7.4e-21\n", | |
| "Y.2\t<-\tClass(2)\t-2.0554\t\t\t\n", | |
| "\t\t\t\t\t\t\n", | |
| "Y.3(0)\t<-\t=\"1\"\t0.0000\t35.4458\t1\t2.6e-9\n", | |
| "Y.3(1)\t<-\t=\"1\"\t0.9487\t\t\t\n", | |
| "Y.3\t<-\tClass(1)\t0.0000\t75.5371\t1\t3.6e-18\n", | |
| "Y.3\t<-\tClass(2)\t-1.6387\t\t\t\n", | |
| "\n", | |
| " \tClass\t \t \n", | |
| " \t1\t2\tOverall\n", | |
| "Size\t0.3722\t0.6278\t \n", | |
| "Y.1\t \t\t\n", | |
| "0\t0.3244\t0.6901\t0.5540\n", | |
| "1\t0.6756\t0.3099\t0.4460\n", | |
| "Mean\t0.6756\t0.3099\t0.4460\n", | |
| "Y.2\t \t\t\n", | |
| "0\t0.2372\t0.7083\t0.5330\n", | |
| "1\t0.7628\t0.2917\t0.4670\n", | |
| "Mean\t0.7628\t0.2917\t0.4670\n", | |
| "Y.3\t \t\t\n", | |
| "0\t0.2791\t0.6660\t0.5220\n", | |
| "1\t0.7209\t0.3340\t0.4780\n", | |
| "Mean\t0.7209\t0.3340\t0.4780\n", | |
| "\n", | |
| " \t \t \t \t \t \tClass\t \t \n", | |
| "Z1\tZ2\tY.1\tY.2\tY.3\tObsFreq\tModal\t1\t2\n", | |
| "0\t0\t0\t0\t0\t27.0000\t2\t0.1278\t0.8722\n", | |
| "0\t0\t0\t0\t1\t33.0000\t2\t0.4300\t0.5700\n", | |
| "0\t0\t0\t1\t0\t28.0000\t1\t0.5337\t0.4663\n", | |
| "0\t0\t0\t1\t1\t33.0000\t1\t0.8549\t0.1451\n", | |
| "0\t0\t1\t0\t0\t18.0000\t2\t0.4046\t0.5954\n", | |
| "0\t0\t1\t0\t1\t35.0000\t1\t0.7777\t0.2223\n", | |
| "0\t0\t1\t1\t0\t32.0000\t1\t0.8414\t0.1586\n", | |
| "0\t0\t1\t1\t1\t72.0000\t1\t0.9647\t0.0353\n", | |
| "0\t1\t0\t0\t0\t78.0000\t2\t0.0000\t1.0000\n", | |
| "0\t1\t0\t0\t1\t40.0000\t2\t0.0000\t1.0000\n", | |
| "0\t1\t0\t1\t0\t29.0000\t2\t0.0000\t1.0000\n", | |
| "0\t1\t0\t1\t1\t13.0000\t2\t0.0000\t1.0000\n", | |
| "0\t1\t1\t0\t0\t40.0000\t2\t0.0000\t1.0000\n", | |
| "0\t1\t1\t0\t1\t18.0000\t2\t0.0000\t1.0000\n", | |
| "0\t1\t1\t1\t0\t16.0000\t2\t0.0000\t1.0000\n", | |
| "0\t1\t1\t1\t1\t6.0000\t2\t0.0000\t1.0000\n", | |
| "1\t0\t0\t0\t0\t23.0000\t2\t0.1793\t0.8207\n", | |
| "1\t0\t0\t0\t1\t18.0000\t1\t0.5293\t0.4707\n", | |
| "1\t0\t0\t1\t0\t25.0000\t1\t0.6305\t0.3695\n", | |
| "1\t0\t0\t1\t1\t34.0000\t1\t0.8978\t0.1022\n", | |
| "1\t0\t1\t0\t0\t13.0000\t1\t0.5033\t0.4967\n", | |
| "1\t0\t1\t0\t1\t19.0000\t1\t0.8391\t0.1609\n", | |
| "1\t0\t1\t1\t0\t28.0000\t1\t0.8878\t0.1122\n", | |
| "1\t0\t1\t1\t1\t75.0000\t1\t0.9760\t0.0240\n", | |
| "1\t1\t0\t0\t0\t75.0000\t2\t0.0000\t1.0000\n", | |
| "1\t1\t0\t0\t1\t39.0000\t2\t0.0000\t1.0000\n", | |
| "1\t1\t0\t1\t0\t45.0000\t2\t0.0000\t1.0000\n", | |
| "1\t1\t0\t1\t1\t14.0000\t2\t0.0000\t1.0000\n", | |
| "1\t1\t1\t0\t0\t35.0000\t2\t0.0000\t1.0000\n", | |
| "1\t1\t1\t0\t1\t22.0000\t2\t0.0000\t1.0000\n", | |
| "1\t1\t1\t1\t0\t10.0000\t2\t0.0000\t1.0000\n", | |
| "1\t1\t1\t1\t1\t7.0000\t2\t0.0000\t1.0000\n", | |
| "\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Logistic regression of Y1 on Z1, Z2\n", | |
| "\n", | |
| "This is just some legacy from when I tried it out just doing a dumb logistic regrssion. It worked so I'm not using this anymore." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "#x = tf.placeholder(tf.float32, [None, 2])\n", | |
| "##y_obs = tf.placeholder(tf.float32, shape=[None, 1])\n", | |
| "\n", | |
| "#a = bias_variable([1])\n", | |
| "#b = weight_variable([2, 1])\n", | |
| "\n", | |
| "#y_pred = tf.nn.sigmoid(tf.matmul(x, b) + a)\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Latent class model with two covariates" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Performs streaming updates of a LCM. The approach is to simply specify the likelihood and let tensorflow do the heavy lifting. Afterwards we can obtain the first and second derivatives, even of parameters not updated in the model. Although TensorFlow makes that more difficult than Theano does." | |
| ] | |
| }, | |
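| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "For reference, the likelihood specified below is the standard two-class mixture\n", | |
| "\n", | |
| "$$P(Y \\mid Z) = \\sum_{x=0}^{1} P(X = x \\mid Z) \\prod_{j=1}^{3} P(Y_j \\mid X = x),$$\n", | |
| "\n", | |
| "with $P(X = 1 \\mid Z) = \\mathrm{sigmoid}(a + Zb)$ and $P(Y_j = 1 \\mid X = x) = \\mathrm{sigmoid}(\\tau_j + x \\lambda_j)$, matching the `a`, `b`, `tau`, and `lam` parameters defined in the next cell." | |
| ] | |
| }, | |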
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "#### Model definition" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "#tf.reset_default_graph()\n", | |
| "\n", | |
| "#sess.close()\n", | |
| "sess = tf.Session()\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Observed variables\n", | |
| "x = tf.placeholder(tf.float32, [None, 2]) # Covariates\n", | |
| "y_obs = tf.placeholder(tf.float32, shape=[None, 3]) # Dependent \"items\"\n", | |
| "\n", | |
| "# Parameters\n", | |
| "a = bias_variable([1]) # LC logistic intercept\n", | |
| "b = weight_variable([2, 1]) # LC logistic slopes wrt Z's\n", | |
| "\n", | |
| "tau = bias_variable([3]) # Item logistic intercepts\n", | |
| "lam = weight_variable([1, 3]) # Item logistic slopes wrt LC\n", | |
| "\n", | |
| "# P(X | Z)\n", | |
| "eta_pred = tf.nn.sigmoid(tf.matmul(x, b) + a)\n", | |
| "\n", | |
| "# Takes a prediction for Y=1 and transforms it to \n", | |
| "# a prediction of Y=1 whereever Y=1 and\n", | |
| "# a prediction of Y=0 whereever Y=0 \n", | |
| "def transform_pred1_to_lik(p):\n", | |
| " return((y_obs * p) + ((1-y_obs) * (1-p)))\n", | |
| "\n", | |
| "# P(Y_j | X)\n", | |
| "# Could include a zero effect of Z1 here and not update to get derivs\n", | |
| "y_pred1 = transform_pred1_to_lik(tf.nn.sigmoid(lam + tau)) \n", | |
| "y_pred2 = transform_pred1_to_lik(tf.nn.sigmoid(tau))\n", | |
| "\n", | |
| "# P(Y | X, Z) = P(Y | X)\n", | |
| "# Takes the prediciton for each item and applies conditional independence rule to yield joint\n", | |
| "ones = np.array([[1,],[1,],[1,],], dtype = np.float32)\n", | |
| "y_pred1_joint = tf.exp(tf.matmul(tf.log(y_pred1), ones))\n", | |
| "y_pred2_joint = tf.exp(tf.matmul(tf.log(y_pred2), ones))\n", | |
| "\n", | |
| "# P(Y | Z)\n", | |
| "# Mixture model\n", | |
| "y_joint = (eta_pred * y_pred1_joint) + ((1 - eta_pred) * y_pred2_joint)\n", | |
| "\n", | |
| "# P(X | Y, Z)\n", | |
| "# Posterior for class 1\n", | |
| "eta_post = (eta_pred * y_pred1_joint) / y_joint\n", | |
| "\n", | |
| " " | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "#### Objective and optimization definition" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "min2ll = -2*tf.reduce_sum(tf.log(y_joint)) # independent obs's\n", | |
| "#cross_entropy = -tf.reduce_sum(y_obs * tf.log(y_pred))\n", | |
| "\n", | |
| "#train_step = tf.train.AdamOptimizer().minimize(min2ll)\n", | |
| "#train_step = tf.train.GradientDescentOptimizer(0.01).minimize(min2ll)\n", | |
| "#train_step = tf.train.RMSPropOptimizer(.1).minimize(min2ll)\n", | |
| "\n", | |
| "global_step = tf.Variable(0, trainable=False)\n", | |
| "\n", | |
| "starter_learning_rate = 0.1\n", | |
| "learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step,\n", | |
| " 100, 0.96, staircase=True)\n", | |
| "train_step = tf.train.AdamOptimizer(learning_rate).minimize(min2ll, global_step=global_step)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "\n", | |
| "init = tf.initialize_all_variables()\n", | |
| "sess.run(init)" | |
| ] | |
| }, | |
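| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "A quick sanity check of the graph on a hand-made batch (an added sketch with made-up values): `y_joint` should return one probability per row and `eta_post` the corresponding class-1 posterior." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Two made-up observations: columns are (Z1, Z2) and (Y.1, Y.2, Y.3)\n", | |
| "toy_x = np.array([[0., 0.], [1., 1.]], dtype=np.float32)\n", | |
| "toy_y = np.array([[1., 0., 1.], [0., 0., 0.]], dtype=np.float32)\n", | |
| "feed_toy = {x: toy_x, y_obs: toy_y}\n", | |
| "\n", | |
| "print sess.run(y_joint, feed_dict=feed_toy)   # P(Y | Z), shape (2, 1)\n", | |
| "print sess.run(eta_post, feed_dict=feed_toy)  # P(X=1 | Y, Z), shape (2, 1)" | |
| ] | |
| }, | |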
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Summaries for TensorBoard. This breaks when rerunnign the Jupyter cells several times but should work the first time. " | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Build the summary operation based on the TF collection of Summaries.\n", | |
| "tf.train.write_graph(sess.graph_def, '/tmp/lca_logs','graph.pbtxt')\n", | |
| "\n", | |
| "tf.histogram_summary(\"a:\", a)\n", | |
| "tf.histogram_summary('b', b)\n", | |
| "tf.histogram_summary('tau', tau)\n", | |
| "tf.histogram_summary('lam', lam)\n", | |
| "tf.scalar_summary('-2*log-likelihood', min2ll)\n", | |
| "\n", | |
| "summary_op = tf.merge_all_summaries()\n", | |
| "summary_writer = tf.train.SummaryWriter('/tmp/lca_logs',sess.graph_def)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "#### Actually running the optimization" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "chunk_size = 1000\n", | |
| "epochs = 1 # One is enough for the 10 Million big dataset; for 1000 records, 10 epochs are needed\n", | |
| "\n", | |
| "#boot_size = 10" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "This makes use of Pandas option to read in files chunk by chunk so I don't have to first read everything into memory and then subselect certain pieces of it. Could use `skiprows = test_size` to denominate some part the validation set. Or just use a separate file. `dat_big` has 10 million rows and `dat` has 1000. " | |
| ] | |
| }, | |
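| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Sketch of the skiprows idea above (test_size is a hypothetical name,\n", | |
| "# not used elsewhere): reserve the first test_size data rows for\n", | |
| "# validation; skiprows starts at 1 so the header row is kept\n", | |
| "#test_size = 100\n", | |
| "#df_valid = pd.read_csv('dat_big.csv', nrows=test_size)\n", | |
| "#df_train = pd.read_csv('dat_big.csv', skiprows=range(1, test_size + 1),\n", | |
| "#                       chunksize=chunk_size, iterator=True)" | |
| ] | |
| }, | |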
| { | |
| "cell_type": "code", | |
| "execution_count": 916, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Not possible with big data set\n", | |
| "#feed_full = {x: df[['Z1', 'Z2',]], y_obs: df[['Y.1','Y.2','Y.3',]]}" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| " Epoch: Obs: -2LL: a: b[0]: b[1]:\n", | |
| " 0 0.0e+00 3770.68 -0.9515 -0.2136 4.8678\n", | |
| " 0 5.0e+05 3788.94 -1.0114 -0.3555 6.0652\n", | |
| " 0 1.0e+06 3814.81 -0.9851 -0.1853 6.4288\n", | |
| " 0 1.5e+06 3830.98 -1.0326 -0.3319 6.6422\n", | |
| " 0 2.0e+06 3779.44 -1.0616 -0.3019 6.7667\n", | |
| " 0 2.5e+06 3825.46 -0.9909 -0.3216 6.9539\n", | |
| " 0 3.0e+06 3843.16 -1.0418 -0.4212 7.1230\n", | |
| " 0 3.5e+06 3866.26 -0.9937 -0.2220 7.3417\n", | |
| " 0 4.0e+06 3834.61 -0.9681 -0.3630 7.3372\n", | |
| " 0 4.5e+06 3796.56 -1.0323 -0.2355 7.4924\n", | |
| " 0 5.0e+06 3857.98 -1.0719 -0.2346 7.3724\n", | |
| " 0 5.5e+06 3777.30 -1.0074 -0.2907 7.4575\n", | |
| " 0 6.0e+06 3819.84 -0.9815 -0.3449 7.5354\n", | |
| " 0 6.5e+06 3848.10 -0.9803 -0.3837 7.7356\n", | |
| " 0 7.0e+06 3781.37 -1.1016 -0.3132 7.7613\n", | |
| " 0 7.5e+06 3841.09 -1.0411 -0.3044 7.8689\n", | |
| " 0 8.0e+06 3841.88 -1.0126 -0.3186 7.8211\n", | |
| " 0 8.5e+06 3797.86 -0.9804 -0.3154 7.8435\n", | |
| " 0 9.0e+06 3851.94 -0.9967 -0.3104 7.8968\n", | |
| " 0 9.5e+06 3822.74 -0.9823 -0.2963 7.9561\n", | |
| " 0 1.0e+07 3818.26 -0.9748 -0.2913 7.9589\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "def print_it(feed_chunk):\n", | |
| " LL = sess.run(min2ll, feed_dict=feed_chunk)\n", | |
| " #g = sess.run(tf.gradients(min2ll, b), feed_dict=feed_chunk)[0]\n", | |
| " print(\"{:10d}{:10.1e}{:10.2f}{:10.4f}{:10.4f}{:10.4f}\".format(j, i*chunk_size, LL, sess.run(a)[0], \n", | |
| " sess.run(b)[0][0], sess.run(b)[1][0])) #, g[0][0], g[1][0]))\n", | |
| "\n", | |
| "\n", | |
| "print(\"{:>10s}{:>10s}{:>10s}{:>10s}{:>10s}{:>10s}\".format(\"Epoch:\", \"Obs:\", \"-2LL:\", \"a:\", \"b[0]:\", \"b[1]:\"))\n", | |
| " #, \"dlL/db[0]\", \"dlL/db[1]\"))\n", | |
| " \n", | |
| "for j in range(epochs):\n", | |
| " \n", | |
| " # Need to read it in again (?) to rewind the file\n", | |
| " df = pd.read_csv(\"/Users/daob/Downloads/tensorflow/dat_big.csv\", chunksize = chunk_size, iterator = True)\n", | |
| " i = 0\n", | |
| " \n", | |
| " for chunk in df:\n", | |
| " #start, end = (i * batch_size, (i + 1) * batch_size)\n", | |
| " #xi = np.asarray(df[['Z1', 'Z2',]][start:end], dtype = \"float32\")\n", | |
| " #yi = np.asarray(df[['Y.1','Y.2','Y.3',]][start:end], dtype = \"float32\")\n", | |
| " #wi = np.array([np.random.poisson(1, batch_size) for i in range(boot_size)])\n", | |
| " \n", | |
| " xi = np.asarray(chunk[['Z1', 'Z2',]], dtype = \"float32\")\n", | |
| " yi = np.asarray(chunk[['Y.1','Y.2','Y.3',]], dtype = \"float32\")\n", | |
| " \n", | |
| " feed_chunk = {x: xi, y_obs: yi}\n", | |
| " sess.run(train_step, feed_dict = feed_chunk)\n", | |
| " \n", | |
| " # TensorBoard stuff\n", | |
| " summary_str = sess.run(summary_op, feed_dict = feed_chunk)\n", | |
| " summary_writer.add_summary(summary_str, i)\n", | |
| " \n", | |
| " if (i % 500 == 0):\n", | |
| " print_it(feed_chunk)\n", | |
| " \n", | |
| " #if i >= 100:\n", | |
| " # break # DEBUG\n", | |
| " \n", | |
| " i += 1\n", | |
| "\n", | |
| "print_it(feed_chunk)\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "#### Some output" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[-0.9747526]\n", | |
| "[[-0.29132286]\n", | |
| " [ 7.95889664]]\n", | |
| "[ 0.85168254 0.84284556 0.84828174]\n", | |
| "[[-1.69697344 -1.69441473 -1.6916467 ]]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "print sess.run(a)\n", | |
| "print sess.run(b)\n", | |
| "print sess.run(tau)\n", | |
| "print sess.run(lam)\n", | |
| "\n", | |
| "# {R} Y~X : -0.8516 1.7104 \n", | |
| "# {R} X ~ Z1 + Z2: 1.021 0.294 -21.742 " | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "\n", | |
| "These parameter estimates give the right answers. A small feature here is that TF apparently does some stabilization of logit coefficients. E.g. instead of getting 99 we get 8 for the covariate effect.\n" | |
| ] | |
| }, | |
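| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "As an extra sanity check (an added sketch), the logistic parameters can be mapped back to class-conditional response probabilities and compared with the Latent GOLD profile table above, keeping in mind that the class labels may be switched." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# sigmoid(tau) and sigmoid(tau + lam) are P(Y_j = 1) in the two classes;\n", | |
| "# compare with the LG profile table (roughly 0.68/0.76/0.72 in one class\n", | |
| "# and 0.31/0.29/0.33 in the other, up to label switching)\n", | |
| "def np_sigmoid(v):\n", | |
| "    return 1.0 / (1.0 + np.exp(-v))\n", | |
| "\n", | |
| "tau_val = sess.run(tau)\n", | |
| "lam_val = sess.run(lam)[0]\n", | |
| "print np_sigmoid(tau_val)            # class without the lam contribution\n", | |
| "print np_sigmoid(tau_val + lam_val)  # class picked out by eta_pred" | |
| ] | |
| }, | |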
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[[ 0.98837703]\n", | |
| " [ 0.10787231]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.99784988]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.10787231]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.39688635]\n", | |
| " [ 0.1395804 ]\n", | |
| " [ 0.46762258]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.99784988]\n", | |
| " [ 0.39627409]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.98837703]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99960315]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.8273958 ]\n", | |
| " [ 0.99710965]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.99960315]\n", | |
| " [ 0.78176117]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.39754918]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.8273958 ]\n", | |
| " [ 0.39754918]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.4682596 ]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.99710965]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.78176117]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.4682596 ]\n", | |
| " [ 0.39754918]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.39754918]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.99783838]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.1395804 ]\n", | |
| " [ 0.99711758]\n", | |
| " [ 0.10787231]\n", | |
| " [ 0.39754918]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.1395804 ]\n", | |
| " [ 0.4682596 ]\n", | |
| " [ 0.39754918]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.99784988]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99711758]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.4682596 ]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.46894887]\n", | |
| " [ 0.99784988]\n", | |
| " [ 0.99784988]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.99960315]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.46894887]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.46894887]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.39754918]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.46762258]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.10787231]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99784988]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.99784988]\n", | |
| " [ 0.99783838]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.8273958 ]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.39627409]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99783838]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.39688635]\n", | |
| " [ 0.39688635]\n", | |
| " [ 0.78176117]\n", | |
| " [ 0.78176117]\n", | |
| " [ 0.4682596 ]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.39688635]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.10787231]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.39688635]\n", | |
| " [ 0.1395804 ]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.99784428]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.39754918]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99783838]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.99960315]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.98837703]\n", | |
| " [ 0.46762258]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.99784428]\n", | |
| " [ 0.99711758]\n", | |
| " [ 0.10787231]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.46894887]\n", | |
| " [ 0.99712497]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.39627409]\n", | |
| " [ 0.10787231]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.99960315]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.99783838]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.78176117]\n", | |
| " [ 0.46894887]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.46762258]\n", | |
| " [ 0.99960315]\n", | |
| " [ 0.4682596 ]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.39688635]\n", | |
| " [ 0.99784428]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.46894887]\n", | |
| " [ 0.46762258]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.4682596 ]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.39688635]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99784428]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.99710965]\n", | |
| " [ 0.99783838]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.10787231]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.4682596 ]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.10787231]\n", | |
| " [ 0.10787231]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.46894887]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.10787231]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.99960315]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.99784428]\n", | |
| " [ 0.1395804 ]\n", | |
| " [ 0.8273958 ]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.8273958 ]\n", | |
| " [ 0.46894887]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99711758]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.99711758]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.99712497]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.78176117]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.39627409]\n", | |
| " [ 0.39754918]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.39688635]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.39754918]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.99960315]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.99712497]\n", | |
| " [ 0.8273958 ]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.39688635]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.10787231]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.8273958 ]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.99960315]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.4682596 ]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.8273958 ]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.1395804 ]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.10787231]\n", | |
| " [ 0.78176117]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.8273958 ]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.8273958 ]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.78176117]\n", | |
| " [ 0.99960315]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.39627409]\n", | |
| " [ 0.1395804 ]\n", | |
| " [ 0.8273958 ]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.10787231]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.10787231]\n", | |
| " [ 0.99710965]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99711758]\n", | |
| " [ 0.10787231]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.46894887]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.39688635]\n", | |
| " [ 0.46762258]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.1395804 ]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.39754918]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.46762258]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.4682596 ]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.8273958 ]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.4682596 ]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.99711758]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99960315]\n", | |
| " [ 0.99960315]\n", | |
| " [ 0.98837703]\n", | |
| " [ 0.99784988]\n", | |
| " [ 0.99711758]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.39688635]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.10787231]\n", | |
| " [ 0.78176117]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.10787231]\n", | |
| " [ 0.39754918]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.39688635]\n", | |
| " [ 0.78176117]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.99711758]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99960315]\n", | |
| " [ 0.99960315]\n", | |
| " [ 0.46894887]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.99711758]\n", | |
| " [ 0.78176117]\n", | |
| " [ 0.10787231]\n", | |
| " [ 0.39688635]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.99710965]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.46894887]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.39627409]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.1395804 ]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.46894887]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.4682596 ]\n", | |
| " [ 0.99710965]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.39627409]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.99712497]\n", | |
| " [ 0.39754918]\n", | |
| " [ 0.39627409]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.46762258]\n", | |
| " [ 0.46894887]\n", | |
| " [ 0.99960315]\n", | |
| " [ 0.99710965]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.39754918]\n", | |
| " [ 0.10787231]\n", | |
| " [ 0.4682596 ]\n", | |
| " [ 0.99783838]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.39754918]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99960315]\n", | |
| " [ 0.98450828]\n", | |
| " [ 0.46762258]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.10787231]\n", | |
| " [ 0.46894887]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.39754918]\n", | |
| " [ 0.99712497]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.1395804 ]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.39688635]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.99710965]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.8273958 ]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.46894887]\n", | |
| " [ 0.99784428]\n", | |
| " [ 0.46762258]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.99784428]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99784988]\n", | |
| " [ 0.39627409]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.99712497]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.4682596 ]\n", | |
| " [ 0.8273958 ]\n", | |
| " [ 0.46762258]\n", | |
| " [ 0.1395804 ]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.99711758]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.98450828]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.99711758]\n", | |
| " [ 0.46762258]\n", | |
| " [ 0.98837703]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.39754918]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.1395804 ]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.10787231]\n", | |
| " [ 0.99960315]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.99960315]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.99784988]\n", | |
| " [ 0.99711758]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.39688635]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.4682596 ]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.4682596 ]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.46762258]\n", | |
| " [ 0.99784428]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.46894887]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.99784428]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.39754918]\n", | |
| " [ 0.99784428]\n", | |
| " [ 0.99960315]\n", | |
| " [ 0.99960315]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.39627409]\n", | |
| " [ 0.46894887]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.10787231]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.4682596 ]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.46894887]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.98450828]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.8273958 ]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.46762258]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99784428]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.10787231]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99712497]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.1395804 ]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.46894887]\n", | |
| " [ 0.46894887]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.10787231]\n", | |
| " [ 0.1395804 ]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.39627409]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.99784988]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.99712497]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.99960315]\n", | |
| " [ 0.8273958 ]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99784428]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.99960315]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.39754918]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.99712497]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.39688635]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.39754918]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.39627409]\n", | |
| " [ 0.99960315]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.10787231]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.1395804 ]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.39688635]\n", | |
| " [ 0.10787231]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.99783838]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.98837703]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99960315]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.39627409]\n", | |
| " [ 0.99710965]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.1395804 ]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.46894887]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.98837703]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.4682596 ]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.39627409]\n", | |
| " [ 0.99712497]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.99712497]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.99710965]\n", | |
| " [ 0.39627409]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99783838]\n", | |
| " [ 0.4682596 ]\n", | |
| " [ 0.4682596 ]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.39688635]\n", | |
| " [ 0.99784988]\n", | |
| " [ 0.1395804 ]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.39627409]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.10787231]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.78176117]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.78176117]\n", | |
| " [ 0.1395804 ]\n", | |
| " [ 0.78176117]\n", | |
| " [ 0.39627409]\n", | |
| " [ 0.39688635]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.99960315]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.99711758]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.39688635]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.1395804 ]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.1395804 ]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.46762258]\n", | |
| " [ 0.4682596 ]\n", | |
| " [ 0.46762258]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.78176117]\n", | |
| " [ 0.39754918]\n", | |
| " [ 0.39754918]\n", | |
| " [ 0.46762258]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.98450828]\n", | |
| " [ 0.39754918]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.39688635]\n", | |
| " [ 0.99712497]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99783838]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99712497]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.46894887]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.46762258]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99711758]\n", | |
| " [ 0.99784988]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.4682596 ]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.39754918]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.46762258]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99711758]\n", | |
| " [ 0.39688635]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.78176117]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.46762258]\n", | |
| " [ 0.99784428]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99784988]\n", | |
| " [ 0.99960315]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.99711758]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.46894887]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.46762258]\n", | |
| " [ 0.4682596 ]\n", | |
| " [ 0.78176117]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.46762258]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.13927341]\n", | |
| " [ 0.98837703]\n", | |
| " [ 0.99784428]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.99783838]\n", | |
| " [ 0.99784428]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.46762258]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.39754918]\n", | |
| " [ 0.99711758]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99784988]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.78176117]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.39688635]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.98837703]\n", | |
| " [ 0.8273958 ]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.8273958 ]\n", | |
| " [ 0.39688635]\n", | |
| " [ 0.46762258]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.02886732]\n", | |
| " [ 0.02173034]\n", | |
| " [ 0.99990243]\n", | |
| " [ 0.99712497]\n", | |
| " [ 0.39627409]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.9999271 ]\n", | |
| " [ 0.39754918]\n", | |
| " [ 0.1389419 ]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.99946916]\n", | |
| " [ 0.99960428]\n", | |
| " [ 0.8273958 ]\n", | |
| " [ 0.99712497]\n", | |
| " [ 0.10811879]\n", | |
| " [ 0.1395804 ]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.99946773]\n", | |
| " [ 0.46894887]\n", | |
| " [ 0.99947059]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.99784428]\n", | |
| " [ 0.9996022 ]\n", | |
| " [ 0.10760621]\n", | |
| " [ 0.39754918]\n", | |
| " [ 0.99960315]\n", | |
| " [ 0.10787231]\n", | |
| " [ 0.46894887]\n", | |
| " [ 0.9999271 ]]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "print sess.run(eta_post, feed_dict = feed_chunk)\n", | |
| "\n", | |
| "# The first observation's (1000 dataset) should be about 0.85 and the second about 0.10.\n", | |
| "# The last observation's (1000 dataset) should be about 0.1451 and the secondtolast about 0.1022 and the thirdtolast 0.22 and 5th 1.000.\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "#### EPC's\n", | |
| "\n", | |
| "I thought it would be really easy to get derivatives and Hessians for all parameters, including new ones. It's not, mostly because TensorFlow (unlike Theano!) makes it difficult for the user to obtain these for sets of parameters together. You need to manually calculate all the off-diagonals of the Hessian, and combine all the results manually as well. A pain. But possible, so here is a proof of concept, calculating an (approcximate) EPC for the direct effects of Z1 and Z2 on the items. It is still a little strange because the hypothetical alternative model has a delta that is different for each Z but the same over Y's. \n", | |
| "\n", | |
| "Maybe the best idea would be to calculate these on the test set since they are a kind of model fit criterion.\n" | |
| ] | |
| }, | |
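| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "As a generic sketch of the by-hand construction described above (an added illustration, not code from the original run), one can differentiate each entry of the gradient in turn to build up a full Hessian for a single Variable:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Sketch: full Hessian of `loss` w.r.t. one Variable `v` of shape [d, 1],\n", | |
| "# built by differentiating each gradient entry separately (this TF\n", | |
| "# version has no built-in Hessian op)\n", | |
| "def hessian_wrt(loss, v, d, feed):\n", | |
| "    g = tf.gradients(loss, v)[0]  # shape [d, 1]\n", | |
| "    H = np.zeros((d, d), dtype=np.float32)\n", | |
| "    for k in range(d):\n", | |
| "        gk = tf.slice(g, [k, 0], [1, 1])  # k-th gradient entry\n", | |
| "        H[k, :] = sess.run(tf.gradients(gk, v)[0], feed_dict=feed)[:, 0]\n", | |
| "    return H\n", | |
| "\n", | |
| "# e.g. hessian_wrt(min2ll, delta, 2, feed_full) should reproduce the\n", | |
| "# 2x2 H assembled by hand below" | |
| ] | |
| }, | |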
| { | |
| "cell_type": "code", | |
| "execution_count": 697, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "delta = tf.Variable(np.asarray([[0.],[0.]], dtype = np.float32)) # Item bias logistic slopes wrt Z's\n", | |
| "\n", | |
| "# New, equivalent, version of the model in which there is item bias delta but it's fixed to zero:\n", | |
| "y_pred1 = transform_pred1_to_lik(tf.nn.sigmoid(lam + tau + tf.matmul(x, delta))) \n", | |
| "y_pred2 = transform_pred1_to_lik(tf.nn.sigmoid(tau + tf.matmul(x, delta)))\n", | |
| "\n", | |
| "# P(Y | X, Z) = P(Y | X)\n", | |
| "# Takes the prediciton for each item and applies conditional independence rule to yield joint\n", | |
| "ones = np.array([[1,],[1,],[1,],], dtype = np.float32)\n", | |
| "y_pred1_joint = tf.exp(tf.matmul(tf.log(y_pred1), ones))\n", | |
| "y_pred2_joint = tf.exp(tf.matmul(tf.log(y_pred2), ones))\n", | |
| "\n", | |
| "# P(Y | Z)\n", | |
| "# Mixture model\n", | |
| "y_joint = (eta_pred * y_pred1_joint) + ((1 - eta_pred) * y_pred2_joint)\n", | |
| "\n", | |
| "# P(X | Y, Z)\n", | |
| "# Posterior for class 1\n", | |
| "eta_post = (eta_pred * y_pred1_joint) / y_joint\n", | |
| "\n", | |
| "min2ll = -2*tf.reduce_sum(tf.log(y_joint)) # independent obs's\n", | |
| "\n", | |
| "# Only initialize the new parameter\n", | |
| "init_new = tf.initialize_variables([delta,])\n", | |
| "sess.run(init_new)\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 724, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[-1.0839479 0.51302063 1.17497635 1.11338735 -1.55259323 -1.98827338\n", | |
| " -1.55796206 -0.39906496 6.30396605]\n", | |
| "[-1.0839479 0.51302063 1.17497635 1.11338735 -1.55259323 -1.98827338\n", | |
| " -1.55796206 -0.39906496 6.30396605 0. 0. ]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "theta = tf.concat(0,[a, tau, tf.reshape(lam, [-1]), tf.reshape(b, [-1])])\n", | |
| "theta_aug = tf.concat(0, [theta, tf.reshape(delta, [-1])])\n", | |
| "\n", | |
| "print sess.run(theta)\n", | |
| "print sess.run(theta_aug)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 761, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "3861.33\n", | |
| "g_delta:\n", | |
| "[[ 23.57310677]\n", | |
| " [ 34.95653152]]\n", | |
| "H_delta:\n", | |
| "[[ 841.83892822]\n", | |
| " [ 932.23809814]]\n", | |
| "d delta0 / d delta:\n", | |
| "313.186\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Check that -2lL is still the same\n", | |
| "# These are not allowed, not sure how to accomplish this:\n", | |
| "#grad_theta = tf.gradients(min2ll, theta)\n", | |
| "#grad_theta_aug = tf.gradients(min2ll, theta_aug)\n", | |
| "\n", | |
| "print sess.run(min2ll, feed_dict=feed_full) \n", | |
| "\n", | |
| "# Unfortunately TensorFlow (unlike Theano!) does not support proper Jacobians/Hessians, \n", | |
| "# so this needs to be done by hand :*(\n", | |
| "g_a = tf.gradients(min2ll, a)\n", | |
| "g_b = tf.gradients(min2ll, b)\n", | |
| "g_tau = tf.gradients(min2ll, tau)\n", | |
| "g_lam = tf.gradients(min2ll, lam)\n", | |
| "g_delta = tf.gradients(min2ll, delta)\n", | |
| "\n", | |
| "g = [g_a, g_b, g_tau, g_lam, g_delta]\n", | |
| "#print [sess.run(g_vec, feed_dict=feed_full) for g_vec in g]\n", | |
| "\n", | |
| "print \"g_delta:\"\n", | |
| "g_delta_val = sess.run(g_delta, feed_dict=feed_full)[0]\n", | |
| "print g_delta_val\n", | |
| "print \"H_delta:\"\n", | |
| "H_delta_val = sess.run(tf.gradients(g_delta, delta), feed_dict=feed_full)[0]\n", | |
| "print H_delta_val\n", | |
| "print \"d delta0 / d delta:\"\n", | |
| "H_delta_val_01 = sess.run(tf.gradients(tf.slice(g_delta[0], [0,0], [1,1]), delta), feed_dict=feed_full)[0][1][0]\n", | |
| "print H_delta_val_01\n", | |
| "\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 771, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[[ 0.01605898]\n", | |
| " [ 0.0321024 ]]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "H = np.array([[H_delta_val[0][0], H_delta_val_01],[H_delta_val_01, H_delta_val[1][0]]])\n", | |
| "\n", | |
| "epc_approx = np.matmul(np.linalg.inv(H), g_delta_val)\n", | |
| "print epc_approx" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "These should be small since the population has been generated with them being zero. \n", | |
| "\n", | |
| "\n", | |
| "**TODO**: \n", | |
| "\n", | |
| " * Here I'm using the full dataset to calculate the gradients, which is \"cheating\". It should be possible to already include delta in the model above but tell the optimizer not to update this Variable. Then a running update can be made of the gradient and hessian of all parameters, including these fixed ones. \n", | |
| "\n", | |
| " * Online Bootstrapping with Poisson weights (do everything once for each bs weight, only changing that -2lL * wi is used)" | |
| ] | |
| }, | |
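| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "A sketch of how the bootstrap item could look (an assumption about the intended design; `train_step_r` below is hypothetical):" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Sketch of the online bootstrap TODO: one Poisson(1) weight per\n", | |
| "# observation, and the -2 log-likelihood weighted by it\n", | |
| "w = tf.placeholder(tf.float32, [None, 1])\n", | |
| "min2ll_w = -2 * tf.reduce_sum(w * tf.log(y_joint))\n", | |
| "\n", | |
| "# Per chunk and bootstrap replicate r:\n", | |
| "#wi = np.random.poisson(1., (chunk_size, 1)).astype(np.float32)\n", | |
| "#sess.run(train_step_r, feed_dict={x: xi, y_obs: yi, w: wi})\n", | |
| "# where train_step_r is a per-replicate optimizer step on min2ll_w\n", | |
| "# (hypothetical, not defined in this notebook)" | |
| ] | |
| }, | |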
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 2", | |
| "language": "python", | |
| "name": "python2" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 2 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython2", | |
| "version": "2.7.11" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 0 | |
| } |