Created
January 12, 2016 17:55
-
-
Save kozo2/f2a391a6d65ed76ab5a8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Logistic Regression with daru and statsample-glm\n", | |
| "\n", | |
| "In this notebook we'll see with some examples how the probability of a given outcome can be predicted with logistic regression using daru and statsample-glm." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "application/javascript": [ | |
| "if(window['d3'] === undefined ||\n", | |
| " window['Nyaplot'] === undefined){\n", | |
| " var path = {\"d3\":\"https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.5/d3.min\",\"downloadable\":\"http://cdn.rawgit.com/domitry/d3-downloadable/master/d3-downloadable\"};\n", | |
| "\n", | |
| "\n", | |
| "\n", | |
| " var shim = {\"d3\":{\"exports\":\"d3\"},\"downloadable\":{\"exports\":\"downloadable\"}};\n", | |
| "\n", | |
| " require.config({paths: path, shim:shim});\n", | |
| "\n", | |
| "\n", | |
| "require(['d3'], function(d3){window['d3']=d3;console.log('finished loading d3');require(['downloadable'], function(downloadable){window['downloadable']=downloadable;console.log('finished loading downloadable');\n", | |
| "\n", | |
| "\tvar script = d3.select(\"head\")\n", | |
| "\t .append(\"script\")\n", | |
| "\t .attr(\"src\", \"http://cdn.rawgit.com/domitry/Nyaplotjs/master/release/nyaplot.js\")\n", | |
| "\t .attr(\"async\", true);\n", | |
| "\n", | |
| "\tscript[0][0].onload = script[0][0].onreadystatechange = function(){\n", | |
| "\n", | |
| "\n", | |
| "\t var event = document.createEvent(\"HTMLEvents\");\n", | |
| "\t event.initEvent(\"load_nyaplot\",false,false);\n", | |
| "\t window.dispatchEvent(event);\n", | |
| "\t console.log('Finished loading Nyaplotjs');\n", | |
| "\n", | |
| "\t};\n", | |
| "\n", | |
| "\n", | |
| "});});\n", | |
| "}\n" | |
| ], | |
| "text/plain": [ | |
| "\"if(window['d3'] === undefined ||\\n window['Nyaplot'] === undefined){\\n var path = {\\\"d3\\\":\\\"https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.5/d3.min\\\",\\\"downloadable\\\":\\\"http://cdn.rawgit.com/domitry/d3-downloadable/master/d3-downloadable\\\"};\\n\\n\\n\\n var shim = {\\\"d3\\\":{\\\"exports\\\":\\\"d3\\\"},\\\"downloadable\\\":{\\\"exports\\\":\\\"downloadable\\\"}};\\n\\n require.config({paths: path, shim:shim});\\n\\n\\nrequire(['d3'], function(d3){window['d3']=d3;console.log('finished loading d3');require(['downloadable'], function(downloadable){window['downloadable']=downloadable;console.log('finished loading downloadable');\\n\\n\\tvar script = d3.select(\\\"head\\\")\\n\\t .append(\\\"script\\\")\\n\\t .attr(\\\"src\\\", \\\"http://cdn.rawgit.com/domitry/Nyaplotjs/master/release/nyaplot.js\\\")\\n\\t .attr(\\\"async\\\", true);\\n\\n\\tscript[0][0].onload = script[0][0].onreadystatechange = function(){\\n\\n\\n\\t var event = document.createEvent(\\\"HTMLEvents\\\");\\n\\t event.initEvent(\\\"load_nyaplot\\\",false,false);\\n\\t window.dispatchEvent(event);\\n\\t console.log('Finished loading Nyaplotjs');\\n\\n\\t};\\n\\n\\n});});\\n}\\n\"" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "true" | |
| ] | |
| }, | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "require 'daru'\n", | |
| "require 'statsample-glm'" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "For this notebook, we will utilize [this dataset](http://www.ats.ucla.edu/stat/data/binary.csv) denoting whether students got admission for a graduate degree program depending on their GRE scores, GPA and rank of the institute they did an undergraduate degree in (ranked from 1 to 4).\n", | |
| "\n", | |
| "It should be noted that statsample-glm does not yet support categorical data, so the ranks will be treated as continuous." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "5489" | |
| ] | |
| }, | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "require 'open-uri'\n", | |
| "content = open('http://www.ats.ucla.edu/stat/data/binary.csv')\n", | |
| "File.write(\"binary.csv\", content.read)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<table><tr><th colspan=\"5\">Daru::DataFrame:22838900 rows: 400 cols: 4<tr><th></th><th>admit</th><th>gpa</th><th>gre</th><th>rank</th></tr><tr><td>0</td><td>0</td><td>3.61</td><td>380</td><td>3</td></tr><tr><td>1</td><td>1</td><td>3.67</td><td>660</td><td>3</td></tr><tr><td>2</td><td>1</td><td>4</td><td>800</td><td>1</td></tr><tr><td>3</td><td>1</td><td>3.19</td><td>640</td><td>4</td></tr><tr><td>4</td><td>0</td><td>2.93</td><td>520</td><td>4</td></tr><tr><td>5</td><td>1</td><td>3</td><td>760</td><td>2</td></tr><tr><td>6</td><td>1</td><td>2.98</td><td>560</td><td>1</td></tr><tr><td>7</td><td>0</td><td>3.08</td><td>400</td><td>2</td></tr><tr><td>8</td><td>1</td><td>3.39</td><td>540</td><td>3</td></tr><tr><td>9</td><td>0</td><td>3.92</td><td>700</td><td>2</td></tr><tr><td>10</td><td>0</td><td>4</td><td>800</td><td>4</td></tr><tr><td>11</td><td>0</td><td>3.22</td><td>440</td><td>1</td></tr><tr><td>12</td><td>1</td><td>4</td><td>760</td><td>1</td></tr><tr><td>13</td><td>0</td><td>3.08</td><td>700</td><td>2</td></tr><tr><td>14</td><td>1</td><td>4</td><td>700</td><td>1</td></tr><tr><td>15</td><td>0</td><td>3.44</td><td>480</td><td>3</td></tr><tr><td>16</td><td>0</td><td>3.87</td><td>780</td><td>4</td></tr><tr><td>17</td><td>0</td><td>2.56</td><td>360</td><td>3</td></tr><tr><td>18</td><td>0</td><td>3.75</td><td>800</td><td>2</td></tr><tr><td>19</td><td>1</td><td>3.81</td><td>540</td><td>1</td></tr><tr><td>20</td><td>0</td><td>3.17</td><td>500</td><td>3</td></tr><tr><td>21</td><td>1</td><td>3.63</td><td>660</td><td>2</td></tr><tr><td>22</td><td>0</td><td>2.82</td><td>600</td><td>4</td></tr><tr><td>23</td><td>0</td><td>3.19</td><td>680</td><td>4</td></tr><tr><td>24</td><td>1</td><td>3.35</td><td>760</td><td>2</td></tr><tr><td>25</td><td>1</td><td>3.66</td><td>800</td><td>1</td></tr><tr><td>26</td><td>1</td><td>3.61</td><td>620</td><td>1</td></tr><tr><td>27</td><td>1</td><td>3.74</td><td>520</td><td>4</td></tr><tr><td>28</td><td>1</td><td>3.22</td><td>780</td><td>2</td></
tr><tr><td>29</td><td>0</td><td>3.29</td><td>520</td><td>1</td></tr><tr><td>30</td><td>0</td><td>3.78</td><td>540</td><td>4</td></tr><tr><td>31</td><td>0</td><td>3.35</td><td>760</td><td>3</td></tr><tr><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td></tr><tr><td>399</td><td>0</td><td>3.89</td><td>600</td><td>3</td></tr></table>" | |
| ], | |
| "text/plain": [ | |
| "\n", | |
| "#<Daru::DataFrame:22838900 @name = 722f9759-a5c9-4724-9181-f2b71cbbcedc @size = 400>\n", | |
| " admit gpa gre rank \n", | |
| " 0 0 3.61 380 3 \n", | |
| " 1 1 3.67 660 3 \n", | |
| " 2 1 4 800 1 \n", | |
| " 3 1 3.19 640 4 \n", | |
| " 4 0 2.93 520 4 \n", | |
| " 5 1 3 760 2 \n", | |
| " 6 1 2.98 560 1 \n", | |
| " 7 0 3.08 400 2 \n", | |
| " 8 1 3.39 540 3 \n", | |
| " 9 0 3.92 700 2 \n", | |
| " 10 0 4 800 4 \n", | |
| " 11 0 3.22 440 1 \n", | |
| " 12 1 4 760 1 \n", | |
| " 13 0 3.08 700 2 \n", | |
| " 14 1 4 700 1 \n", | |
| " ... ... ... ... ... \n" | |
| ] | |
| }, | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "df = Daru::DataFrame.from_csv \"binary.csv\"" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Use the `Statsample::GLM.compute` method for logistic regression analysis.\n", | |
| "\n", | |
| "The first argument to the `compute` method is the DataFrame object, followed by the name of the Vector that is to be the dependent variable, and then the method to be used for the link function, which can be :logistic, :probit, :poisson or :normal.\n", | |
| "\n", | |
| "The `coefficients` method calculates the coefficients of the GLM and returns them as a Hash." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "ename": "NoMethodError", | |
| "evalue": "undefined method `each' for :admit:Symbol", | |
| "output_type": "error", | |
| "traceback": [ | |
| "\u001b[31mNoMethodError\u001b[0m: undefined method `each' for :admit:Symbol", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/daru-0.1.1/lib/daru/dataframe.rb:2118:in `access_vector'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/daru-0.1.1/lib/daru/dataframe.rb:346:in `[]'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/statsample-glm-0.2.1/lib/statsample-glm/glm/base.rb:17:in `initialize'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/statsample-glm-0.2.1/lib/statsample-glm/glm/logistic.rb:8:in `initialize'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/statsample-glm-0.2.1/lib/statsample-glm/glm.rb:32:in `new'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/statsample-glm-0.2.1/lib/statsample-glm/glm.rb:32:in `compute'\u001b[0m", | |
| "\u001b[37m(pry):7:in `<main>'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/pry-0.10.3/lib/pry/pry_instance.rb:355:in `eval'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/pry-0.10.3/lib/pry/pry_instance.rb:355:in `evaluate_ruby'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/pry-0.10.3/lib/pry/pry_instance.rb:323:in `handle_line'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/pry-0.10.3/lib/pry/pry_instance.rb:243:in `block (2 levels) in eval'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/pry-0.10.3/lib/pry/pry_instance.rb:242:in `catch'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/pry-0.10.3/lib/pry/pry_instance.rb:242:in `block in eval'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/pry-0.10.3/lib/pry/pry_instance.rb:241:in `catch'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/pry-0.10.3/lib/pry/pry_instance.rb:241:in `eval'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/iruby-0.2.8/lib/iruby/backend.rb:65:in `eval'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/iruby-0.2.8/lib/iruby/backend.rb:12:in `eval'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/iruby-0.2.8/lib/iruby/kernel.rb:87:in `execute_request'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/iruby-0.2.8/lib/iruby/kernel.rb:47:in `dispatch'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/iruby-0.2.8/lib/iruby/kernel.rb:37:in `run'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/iruby-0.2.8/lib/iruby/command.rb:70:in `run_kernel'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/iruby-0.2.8/lib/iruby/command.rb:34:in `run'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/iruby-0.2.8/bin/iruby:5:in `<top (required)>'\u001b[0m", | |
| "\u001b[37m/usr/local/bin/iruby:23:in `load'\u001b[0m", | |
| "\u001b[37m/usr/local/bin/iruby:23:in `<main>'\u001b[0m" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "glm = Statsample::GLM::compute df, :admit, :logistic, constant: 1\n", | |
| "c = glm.coefficients :hash" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The logistic regression coefficients give the change in the log odds of the outcome for a one unit increase in the predictor variable.\n", | |
| "\n", | |
| "Therefore, to interpret each of the above co-efficients:\n", | |
| "* For every one unit change in gre, the log odds of admission (versus non-admission) increases by **0.002**.\n", | |
| "* For a one unit increase in gpa, the log odds of being admitted to graduate school increases by **0.777**.\n", | |
| "* For every one unit increase in the rank number of the institute (i.e. a decrease in the quality of the institute), the log odds of being admitted to graduate school decrease by **0.56**.\n", | |
| "\n", | |
| "Log odds become a little difficult to interpret, so we'll exponentiate each of the co-efficients so that each co-efficient can be interpreted as an odds-ratio." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<table><tr><th colspan=\"2\">Daru::Vector:23665000 size: 0</th></tr><tr><th> </th><th>nil</th></tr></table>" | |
| ], | |
| "text/plain": [ | |
| "\n", | |
| "#<Daru::Vector:23665000 @name = nil @size = 0 >\n", | |
| " nil\n" | |
| ] | |
| }, | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "Daru::Vector.new(c).exp # Calling `#exp` on Daru::Vector exponentiates each element of the Vector." | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "We can now compute the probability of gaining admission into a graduate college based on the rank of the undergraduate college, by keeping the GRE score and GPA constant.\n", | |
| "\n", | |
| "As you can see in the result below, the `rankp` Vector shows the probability of admission based on the rank. The person from the most highly rated undergrad school (rank 1) has a probability of **0.49** of getting admitted into graduate school." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "ename": "NoMethodError", | |
| "evalue": "undefined method `each' for :gre:Symbol", | |
| "output_type": "error", | |
| "traceback": [ | |
| "\u001b[31mNoMethodError\u001b[0m: undefined method `each' for :gre:Symbol", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/daru-0.1.1/lib/daru/dataframe.rb:2118:in `access_vector'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/daru-0.1.1/lib/daru/dataframe.rb:346:in `[]'\u001b[0m", | |
| "\u001b[37m(pry):12:in `<main>'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/pry-0.10.3/lib/pry/pry_instance.rb:355:in `eval'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/pry-0.10.3/lib/pry/pry_instance.rb:355:in `evaluate_ruby'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/pry-0.10.3/lib/pry/pry_instance.rb:323:in `handle_line'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/pry-0.10.3/lib/pry/pry_instance.rb:243:in `block (2 levels) in eval'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/pry-0.10.3/lib/pry/pry_instance.rb:242:in `catch'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/pry-0.10.3/lib/pry/pry_instance.rb:242:in `block in eval'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/pry-0.10.3/lib/pry/pry_instance.rb:241:in `catch'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/pry-0.10.3/lib/pry/pry_instance.rb:241:in `eval'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/iruby-0.2.8/lib/iruby/backend.rb:65:in `eval'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/iruby-0.2.8/lib/iruby/backend.rb:12:in `eval'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/iruby-0.2.8/lib/iruby/kernel.rb:87:in `execute_request'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/iruby-0.2.8/lib/iruby/kernel.rb:47:in `dispatch'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/iruby-0.2.8/lib/iruby/kernel.rb:37:in `run'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/iruby-0.2.8/lib/iruby/command.rb:70:in `run_kernel'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/iruby-0.2.8/lib/iruby/command.rb:34:in `run'\u001b[0m", | |
| "\u001b[37m/var/lib/gems/2.2.0/gems/iruby-0.2.8/bin/iruby:5:in `<top (required)>'\u001b[0m", | |
| "\u001b[37m/usr/local/bin/iruby:23:in `load'\u001b[0m", | |
| "\u001b[37m/usr/local/bin/iruby:23:in `<main>'\u001b[0m" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "e = Math::E\n", | |
| "new_data = Daru::DataFrame.new({\n", | |
| " gre: [df[:gre].mean]*4,\n", | |
| " gpa: [df[:gpa].mean]*4,\n", | |
| " rank: df[:rank].factors\n", | |
| " })\n", | |
| "\n", | |
| "new_data[:rankp] = new_data.collect(:row) do |x|\n", | |
| " 1 / (1 + e ** -(c[:constant] + x[:gre] * c[:gre] + x[:gpa] * c[:gpa] + x[:rank] * c[:rank]))\n", | |
| "end\n", | |
| "\n", | |
| "new_data.sort! [:rank]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "To demonstrate with another example, let's create a hypothetical dataset consisting of the body weight of 20 people and whether they survived or not.\n", | |
| "\n", | |
| "For this example we will simply assume that people with a lower body weight have a lower chance of survival." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<table><tr><th colspan=\"3\">Daru::DataFrame:23916960 rows: 20 cols: 2<tr><th></th><th>body_weight</th><th>survive</th></tr><tr><td>0</td><td>25.685645200554255</td><td>0</td></tr><tr><td>1</td><td>26.78483156683682</td><td>0</td></tr><tr><td>2</td><td>26.87777624900039</td><td>0</td></tr><tr><td>3</td><td>26.90663536835868</td><td>0</td></tr><tr><td>4</td><td>27.222761100195754</td><td>0</td></tr><tr><td>5</td><td>28.367501018305376</td><td>1</td></tr><tr><td>6</td><td>28.76992478247493</td><td>0</td></tr><tr><td>7</td><td>30.100726926660045</td><td>1</td></tr><tr><td>8</td><td>30.225681682863485</td><td>0</td></tr><tr><td>9</td><td>30.27850709664521</td><td>0</td></tr><tr><td>10</td><td>30.696817029906686</td><td>1</td></tr><tr><td>11</td><td>31.229602685861874</td><td>1</td></tr><tr><td>12</td><td>31.459910308413384</td><td>0</td></tr><tr><td>13</td><td>31.637355412858362</td><td>1</td></tr><tr><td>14</td><td>31.912897748166323</td><td>1</td></tr><tr><td>15</td><td>32.20231433535488</td><td>1</td></tr><tr><td>16</td><td>32.47059457772508</td><td>0</td></tr><tr><td>17</td><td>33.24768811130859</td><td>1</td></tr><tr><td>18</td><td>34.41320950803006</td><td>1</td></tr><tr><td>19</td><td>35.720554742060145</td><td>1</td></tr></table>" | |
| ], | |
| "text/plain": [ | |
| "\n", | |
| "#<Daru::DataFrame:23916960 @name = 09f0729e-acca-426d-b194-56933b7d9048 @size = 20>\n", | |
| " body_weigh survive \n", | |
| " 0 25.6856452 0 \n", | |
| " 1 26.7848315 0 \n", | |
| " 2 26.8777762 0 \n", | |
| " 3 26.9066353 0 \n", | |
| " 4 27.2227611 0 \n", | |
| " 5 28.3675010 1 \n", | |
| " 6 28.7699247 0 \n", | |
| " 7 30.1007269 1 \n", | |
| " 8 30.2256816 0 \n", | |
| " 9 30.2785070 0 \n", | |
| " 10 30.6968170 1 \n", | |
| " 11 31.2296026 1 \n", | |
| " 12 31.4599103 0 \n", | |
| " 13 31.6373554 1 \n", | |
| " 14 31.9128977 1 \n", | |
| " ... ... ... \n" | |
| ] | |
| }, | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "require 'distribution'\n", | |
| "\n", | |
| "# Create a normally distributed Vector with mean 30 and standard deviation 2\n", | |
| "rng = Distribution::Normal.rng(30,2)\n", | |
| "body_weight = Daru::Vector.new(20.times.map { rng.call }.sort)\n", | |
| "\n", | |
| "# Populate chances of survival, assume that people with less body weight on average\n", | |
| "# are less likely to survive.\n", | |
| "survive = Daru::Vector.new [0,0,0,0,0,1,0,1,0,0,1,1,0,1,1,1,0,1,1,1]\n", | |
| "\n", | |
| "df = Daru::DataFrame.new({\n", | |
| " body_weight: body_weight,\n", | |
| " survive: survive\n", | |
| "})" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Compute the logistic regression co-efficients." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "{:body_weight=>0.6908834795544777, :constant=>-20.97009374018127}" | |
| ] | |
| }, | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "glm = Statsample::GLM.compute df, :survive, :logistic, constant: 1\n", | |
| "coeffs = glm.coefficients :hash" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Based on the coefficients, we compute the predicted probabilities for each number in the Vector :body_weight and store them in another Vector called `:survive_pred`." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<table><tr><th colspan=\"4\">Daru::DataFrame:23916960 rows: 20 cols: 3<tr><th></th><th>body_weight</th><th>survive</th><th>survive_pred</th></tr><tr><td>0</td><td>25.685645200554255</td><td>0</td><td>0.038261227923132045</td></tr><tr><td>1</td><td>26.78483156683682</td><td>0</td><td>0.07835602733619414</td></tr><tr><td>2</td><td>26.87777624900039</td><td>0</td><td>0.08312069297554081</td></tr><tr><td>3</td><td>26.90663536835868</td><td>0</td><td>0.08465290752507156</td></tr><tr><td>4</td><td>27.222761100195754</td><td>0</td><td>0.10318391508304724</td></tr><tr><td>5</td><td>28.367501018305376</td><td>1</td><td>0.20238472081599393</td></tr><tr><td>6</td><td>28.76992478247493</td><td>0</td><td>0.25097331344020524</td></tr><tr><td>7</td><td>30.100726926660045</td><td>1</td><td>0.45660972095200497</td></tr><tr><td>8</td><td>30.225681682863485</td><td>0</td><td>0.47809662543704295</td></tr><tr><td>9</td><td>30.27850709664521</td><td>0</td><td>0.487209440732106</td></tr><tr><td>10</td><td>30.696817029906686</td><td>1</td><td>0.5591788228860437</td></tr><tr><td>11</td><td>31.229602685861874</td><td>1</td><td>0.6470101841079281</td></tr><tr><td>12</td><td>31.459910308413384</td><td>0</td><td>0.6824466587497141</td></tr><tr><td>13</td><td>31.637355412858362</td><td>1</td><td>0.7084013521153286</td></tr><tr><td>14</td><td>31.912897748166323</td><td>1</td><td>0.7461153340564617</td></tr><tr><td>15</td><td>32.20231433535488</td><td>1</td><td>0.782101170759652</td></tr><tr><td>16</td><td>32.47059457772508</td><td>0</td><td>0.8120374356450982</td></tr><tr><td>17</td><td>33.24768811130859</td><td>1</td><td>0.8808164699056995</td></tr><tr><td>18</td><td>34.41320950803006</td><td>1</td><td>0.9429682341211665</td></tr><tr><td>19</td><td>35.720554742060145</td><td>1</td><td>0.9760757454325684</td></tr></table>" | |
| ], | |
| "text/plain": [ | |
| "\n", | |
| "#<Daru::DataFrame:23916960 @name = 09f0729e-acca-426d-b194-56933b7d9048 @size = 20>\n", | |
| " body_weigh survive survive_pr \n", | |
| " 0 25.6856452 0 0.03826122 \n", | |
| " 1 26.7848315 0 0.07835602 \n", | |
| " 2 26.8777762 0 0.08312069 \n", | |
| " 3 26.9066353 0 0.08465290 \n", | |
| " 4 27.2227611 0 0.10318391 \n", | |
| " 5 28.3675010 1 0.20238472 \n", | |
| " 6 28.7699247 0 0.25097331 \n", | |
| " 7 30.1007269 1 0.45660972 \n", | |
| " 8 30.2256816 0 0.47809662 \n", | |
| " 9 30.2785070 0 0.48720944 \n", | |
| " 10 30.6968170 1 0.55917882 \n", | |
| " 11 31.2296026 1 0.64701018 \n", | |
| " 12 31.4599103 0 0.68244665 \n", | |
| " 13 31.6373554 1 0.70840135 \n", | |
| " 14 31.9128977 1 0.74611533 \n", | |
| " ... ... ... ... \n" | |
| ] | |
| }, | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "e = Math::E\n", | |
| "df[:survive_pred] = df[:body_weight].map { |x| 1 / (1 + e ** -(coeffs[:constant] + x*coeffs[:body_weight])) }\n", | |
| "df" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The above results can then be plotted using the `plot` function.\n", | |
| "\n", | |
| "The curve looks like an ideal logistic regression curve." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div id='vis-2a3f16da-dab6-4dae-897c-75d065627124'></div>\n", | |
| "<script>\n", | |
| "(function(){\n", | |
| " var render = function(){\n", | |
| " var model = {\"panes\":[{\"diagrams\":[{\"type\":\"scatter\",\"options\":{\"x\":\"body_weight\",\"y\":\"survive_pred\"},\"data\":\"83108d6e-9bba-4ba7-b155-31f1f379b89a\"},{\"type\":\"line\",\"options\":{\"x\":\"body_weight\",\"y\":\"survive_pred\"},\"data\":\"83108d6e-9bba-4ba7-b155-31f1f379b89a\"}],\"options\":{\"x_label\":\"Body Weight\",\"y_label\":\"Probability of Survival\",\"zoom\":true,\"width\":700,\"xrange\":[25.685645200554255,35.720554742060145],\"yrange\":[0.038261227923132045,0.9760757454325684]}}],\"data\":{\"83108d6e-9bba-4ba7-b155-31f1f379b89a\":[{\"body_weight\":25.685645200554255,\"survive\":0,\"survive_pred\":0.038261227923132045},{\"body_weight\":26.78483156683682,\"survive\":0,\"survive_pred\":0.07835602733619414},{\"body_weight\":26.87777624900039,\"survive\":0,\"survive_pred\":0.08312069297554081},{\"body_weight\":26.90663536835868,\"survive\":0,\"survive_pred\":0.08465290752507156},{\"body_weight\":27.222761100195754,\"survive\":0,\"survive_pred\":0.10318391508304724},{\"body_weight\":28.367501018305376,\"survive\":1,\"survive_pred\":0.20238472081599393},{\"body_weight\":28.76992478247493,\"survive\":0,\"survive_pred\":0.25097331344020524},{\"body_weight\":30.100726926660045,\"survive\":1,\"survive_pred\":0.45660972095200497},{\"body_weight\":30.225681682863485,\"survive\":0,\"survive_pred\":0.47809662543704295},{\"body_weight\":30.27850709664521,\"survive\":0,\"survive_pred\":0.487209440732106},{\"body_weight\":30.696817029906686,\"survive\":1,\"survive_pred\":0.5591788228860437},{\"body_weight\":31.229602685861874,\"survive\":1,\"survive_pred\":0.6470101841079281},{\"body_weight\":31.459910308413384,\"survive\":0,\"survive_pred\":0.6824466587497141},{\"body_weight\":31.637355412858362,\"survive\":1,\"survive_pred\":0.7084013521153286},{\"body_weight\":31.912897748166323,\"survive\":1,\"survive_pred\":0.7461153340564617},{\"body_weight\":32.20231433535488,\"survive\":1,\"survive_pred\":0.782101170759652},{\"body_weight\":32.4705945777250
8,\"survive\":0,\"survive_pred\":0.8120374356450982},{\"body_weight\":33.24768811130859,\"survive\":1,\"survive_pred\":0.8808164699056995},{\"body_weight\":34.41320950803006,\"survive\":1,\"survive_pred\":0.9429682341211665},{\"body_weight\":35.720554742060145,\"survive\":1,\"survive_pred\":0.9760757454325684}]},\"extension\":[]}\n", | |
| " var id_name = '#vis-2a3f16da-dab6-4dae-897c-75d065627124';\n", | |
| " Nyaplot.core.parse(model, id_name);\n", | |
| "\n", | |
| " require(['downloadable'], function(downloadable){\n", | |
| " var svg = d3.select(id_name).select(\"svg\");\n", | |
| "\t if(!svg.empty())\n", | |
| "\t svg.call(downloadable().filename('fig'));\n", | |
| "\t});\n", | |
| " };\n", | |
| " if(window['Nyaplot']==undefined){\n", | |
| " window.addEventListener('load_nyaplot', render, false);\n", | |
| "\treturn;\n", | |
| " } else {\n", | |
| " render();\n", | |
| " }\n", | |
| "})();\n", | |
| "</script>\n" | |
| ], | |
| "text/plain": [ | |
| "#<Nyaplot::Frame:0x00000003004800 @properties={:panes=>[#<Nyaplot::Plot:0x00000002ff1750 @properties={:diagrams=>[#<Nyaplot::Diagram:0x000000030053b8 @properties={:type=>:scatter, :options=>{:x=>:body_weight, :y=>:survive_pred}, :data=>\"83108d6e-9bba-4ba7-b155-31f1f379b89a\"}, @xrange=[25.685645200554255, 35.720554742060145], @yrange=[0.038261227923132045, 0.9760757454325684]>, #<Nyaplot::Diagram:0x00000003004dc8 @properties={:type=>:line, :options=>{:x=>:body_weight, :y=>:survive_pred}, :data=>\"83108d6e-9bba-4ba7-b155-31f1f379b89a\"}, @xrange=[25.685645200554255, 35.720554742060145], @yrange=[0.038261227923132045, 0.9760757454325684]>], :options=>{:x_label=>\"Body Weight\", :y_label=>\"Probability of Survival\", :zoom=>true, :width=>700, :xrange=>[25.685645200554255, 35.720554742060145], :yrange=>[0.038261227923132045, 0.9760757454325684]}}>], :data=>{\"83108d6e-9bba-4ba7-b155-31f1f379b89a\"=>#<Nyaplot::DataFrame:0x00000003005fe8 @name=\"83108d6e-9bba-4ba7-b155-31f1f379b89a\", @rows=[{:body_weight=>25.685645200554255, :survive=>0, :survive_pred=>0.038261227923132045}, {:body_weight=>26.78483156683682, :survive=>0, :survive_pred=>0.07835602733619414}, {:body_weight=>26.87777624900039, :survive=>0, :survive_pred=>0.08312069297554081}, {:body_weight=>26.90663536835868, :survive=>0, :survive_pred=>0.08465290752507156}, {:body_weight=>27.222761100195754, :survive=>0, :survive_pred=>0.10318391508304724}, {:body_weight=>28.367501018305376, :survive=>1, :survive_pred=>0.20238472081599393}, {:body_weight=>28.76992478247493, :survive=>0, :survive_pred=>0.25097331344020524}, {:body_weight=>30.100726926660045, :survive=>1, :survive_pred=>0.45660972095200497}, {:body_weight=>30.225681682863485, :survive=>0, :survive_pred=>0.47809662543704295}, {:body_weight=>30.27850709664521, :survive=>0, :survive_pred=>0.487209440732106}, {:body_weight=>30.696817029906686, :survive=>1, :survive_pred=>0.5591788228860437}, {:body_weight=>31.229602685861874, :survive=>1, 
:survive_pred=>0.6470101841079281}, {:body_weight=>31.459910308413384, :survive=>0, :survive_pred=>0.6824466587497141}, {:body_weight=>31.637355412858362, :survive=>1, :survive_pred=>0.7084013521153286}, {:body_weight=>31.912897748166323, :survive=>1, :survive_pred=>0.7461153340564617}, {:body_weight=>32.20231433535488, :survive=>1, :survive_pred=>0.782101170759652}, {:body_weight=>32.47059457772508, :survive=>0, :survive_pred=>0.8120374356450982}, {:body_weight=>33.24768811130859, :survive=>1, :survive_pred=>0.8808164699056995}, {:body_weight=>34.41320950803006, :survive=>1, :survive_pred=>0.9429682341211665}, {:body_weight=>35.720554742060145, :survive=>1, :survive_pred=>0.9760757454325684}]>}, :extension=>[]}>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "df.plot type: [:scatter,:line], x: [:body_weight]*2, y: [:survive_pred]*2 do |plot, diagram|\n", | |
| " plot.x_label \"Body Weight\"\n", | |
| " plot.y_label \"Probability of Survival\"\n", | |
| "end" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Ruby 2.2.3", | |
| "language": "ruby", | |
| "name": "ruby" | |
| }, | |
| "language_info": { | |
| "file_extension": ".rb", | |
| "mimetype": "application/x-ruby", | |
| "name": "ruby", | |
| "version": "2.2.3" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 0 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment