Instantly share code, notes, and snippets.
Created
July 25, 2017 00:17
-
Star
0
(0)
You must be signed in to star a gist -
Fork
0
(0)
You must be signed in to fork a gist
-
-
Save bninopaul/d3359299a383f7ff202c0abb84214f82 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 75, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "import pandas as pd\n", | |
| "import string\n", | |
| "import codecs\n", | |
| "pd.set_option('display.max_columns', None)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Problem 1" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 91, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
class hmm(object):
    """
    Discrete hidden Markov model trained with the scaled Baum-Welch (EM) algorithm.

    Conventions (N hidden states, M observation symbols, T observations):

    * ``A`` has shape (N, N) and is column-stochastic: ``A[i, j]`` is the
      probability of moving to state ``i`` given the current state ``j``.
    * ``B`` has shape (M, N) and is column-stochastic: ``B[k, j]`` is the
      probability of emitting symbol ``k`` from state ``j``.
    * ``pi`` has shape (N,) and is the initial state distribution.
    * Observations are integer symbol ids used directly as row indices of ``B``.
    """

    def __init__(self):
        # Parameters are filled in by fit(); they may also be assigned
        # directly before calling the helper methods.
        self.A = None
        self.B = None
        self.pi = None

    def _forward(self, obs):
        """
        Compute the scaled forward probability matrix and scaling factors.
        Parameters
        ----------
        obs : ndarray of shape (T,)
            The observation sequence
        Returns
        -------
        alpha : ndarray of shape (T,N)
            The scaled forward probability matrix (every row sums to 1)
        c : ndarray of shape (T,)
            The scaling factors c = [c_1,c_2,...,c_T]; the log-probability of
            the observation sequence is -sum(log(c)).
        """
        obs = np.asarray(obs)
        T = len(obs)
        N = self.B.shape[1]
        c, alpha = np.zeros(T), np.zeros((T, N))

        unscaled = self.pi * self.B[obs[0]]
        c[0] = 1.0 / np.sum(unscaled)
        alpha[0] = c[0] * unscaled
        for t in range(1, T):
            # Propagate one step, then weight by the emission probabilities.
            # The product is computed once and reused for both c[t] and alpha[t].
            unscaled = np.dot(self.A, alpha[t - 1]) * self.B[obs[t]]
            c[t] = 1.0 / np.sum(unscaled)
            alpha[t] = c[t] * unscaled
        return alpha, c

    def _backward(self, obs, c):
        """
        Compute the scaled backward probability matrix.
        Parameters
        ----------
        obs : ndarray of shape (T,)
            The observation sequence
        c : ndarray of shape (T,)
            The scaling factors from the forward pass
        Returns
        -------
        beta : ndarray of shape (T,N)
            The scaled backward probability matrix
        """
        obs = np.asarray(obs)
        T = len(obs)
        N = self.B.shape[1]
        beta = np.zeros((T, N))
        beta[T - 1] = c[T - 1]
        for t in range(T - 2, -1, -1):
            # Scale the resulting vector instead of the whole N x N matrix.
            beta[t] = c[t] * np.dot(self.A.T, self.B[obs[t + 1]] * beta[t + 1])
        return beta

    def _delta(self, obs, alpha, beta):
        """
        Compute the delta probabilities.
        Parameters
        ----------
        obs : ndarray of shape (T,)
            The observation sequence
        alpha : ndarray of shape (T,N)
            The scaled forward probability matrix from the forward pass
        beta : ndarray of shape (T,N)
            The scaled backward probability matrix from the backward pass
        Returns
        -------
        delta : ndarray of shape (T-1,N,N)
            The delta probability array;
            delta[t,i,j] = alpha[t,i] * A[j,i] * B[obs[t+1],j] * beta[t+1,j]
        gamma : ndarray of shape (T,N)
            The gamma probability array; gamma[t,i] = sum_j delta[t,i,j] for
            t < T-1, and the normalised alpha*beta product for the last step.
        """
        obs = np.asarray(obs)
        T = len(obs)
        N = self.B.shape[1]
        gamma = np.zeros((T, N))

        # Broadcast over (t, i, j): axis 1 is the state at time t, axis 2 the
        # state at time t+1. A.T[i, j] == A[j, i].
        delta = (alpha[:-1, :, None]
                 * self.A.T[None, :, :]
                 * self.B[obs[1:]][:, None, :]
                 * beta[1:][:, None, :])
        gamma[:-1] = delta.sum(axis=2)
        gamma[T - 1] = alpha[T - 1] * beta[T - 1] / np.dot(alpha[T - 1], beta[T - 1])
        return delta, gamma

    def _estimate(self, obs, delta, gamma):
        """
        Estimate better parameter values.

        Rebinds self.A, self.B and self.pi to freshly allocated arrays, so
        arrays that were previously assigned to these attributes are left
        untouched.
        Parameters
        ----------
        obs : ndarray of shape (T,)
            The observation sequence
        delta : ndarray of shape (T-1,N,N)
            The delta probability array
        gamma : ndarray of shape (T,N)
            The gamma probability array
        """
        obs = np.asarray(obs)  # so that `obs == k` is an elementwise mask
        M, N = self.B.shape

        # A[i, j] = sum_t delta[t, j, i] / sum_{t < T-1} gamma[t, j]
        new_A = delta.sum(axis=0).T / gamma[:-1].sum(axis=0)

        # B[k, j] = sum_{t : obs[t] == k} gamma[t, j] / sum_t gamma[t, j]
        new_B = np.zeros((M, N))
        for k in range(M):
            new_B[k] = gamma[obs == k].sum(axis=0)
        new_B /= gamma.sum(axis=0)

        self.A = new_A
        self.B = new_B
        self.pi = gamma[0].copy()

    def fit(self, obs, N, A=None, B=None, pi=None, max_iter=100, tol=1e-3):
        """
        Fit the model parameters to a given observation sequence.
        Parameters
        ----------
        obs : ndarray of shape (T,)
            Observation sequence on which to train the model.
        N : integer
            Number of hidden states; replaced by A.shape[0] when A is given.
        A : stochastic ndarray of shape (N,N)
            Initialization of state transition matrix
        B : stochastic ndarray of shape (M,N)
            Initialization of state observation matrix
        pi : stochastic ndarray of shape (N,)
            Initialization of initial state distribution
        max_iter : integer
            The maximum number of iterations to take
        tol : float
            The convergence threshold for change in log-probability
        Raises
        ------
        ValueError
            If obs is empty.
        """
        obs = np.asarray(obs)
        if obs.size == 0:
            raise ValueError("obs must contain at least one observation")

        # Supplied initial values are renormalised (columns of A and B, all of
        # pi); missing ones are drawn from a flat Dirichlet distribution.
        if A is not None:
            A = np.asarray(A, dtype=float)
            self.A = A / np.sum(A, axis=0)  # make sure it is column-stochastic
            N = A.shape[0]
        else:
            self.A = np.random.dirichlet(np.ones(N), size=N).T

        if B is not None:
            B = np.asarray(B, dtype=float)
            self.B = B / np.sum(B, axis=0)
        else:
            # Symbols index the rows of B, so the row count must cover the
            # largest symbol id even if some smaller ids never occur.
            M = int(obs.max()) + 1
            self.B = np.random.dirichlet(np.ones(M), size=N).T

        if pi is not None:
            pi = np.asarray(pi, dtype=float)
            self.pi = pi / np.sum(pi)
        else:
            self.pi = np.random.dirichlet(np.ones(N))

        alpha, c = self._forward(obs)
        log_lkhood0 = -np.sum(np.log(c))
        for i in range(max_iter):
            if i % 10 == 0:
                print("Iter: %d, log_lkhood: %f" % (i, log_lkhood0))
            # alpha and c already correspond to the current parameters (from
            # the forward pass that produced log_lkhood0), so reuse them.
            beta = self._backward(obs, c)
            delta, gamma = self._delta(obs, alpha, beta)

            # update params
            self._estimate(obs, delta, gamma)

            alpha, c = self._forward(obs)
            log_lkhood1 = -np.sum(np.log(c))
            if abs(log_lkhood0 - log_lkhood1) < tol:
                break
            log_lkhood0 = log_lkhood1
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 77, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "#toy data\n", | |
| "A = np.array([[.7, .4],[.3, .6]])\n", | |
| "B = np.array([[.1,.7],[.4, .2],[.5, .1]])\n", | |
| "pi = np.array([.6, .4])\n", | |
| "obs = np.array([0, 1, 0, 2])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Problem 2" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 78, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "-4.6429135909\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "#test for forward pass\n", | |
| "h = hmm()\n", | |
| "h.A = A\n", | |
| "h.B = B\n", | |
| "h.pi = pi\n", | |
| "alpha, c = h._forward(obs)\n", | |
| "print(-np.log(c).sum()) # the log prob of observation" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Problem 3" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 79, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[ 3.1361635 , 2.89939354],\n", | |
| " [ 2.86699344, 4.39229044],\n", | |
| " [ 3.898812 , 2.66760821],\n", | |
| " [ 3.56816483, 3.56816483]])" | |
| ] | |
| }, | |
| "execution_count": 79, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "#test the backward pass\n", | |
| "beta = h._backward(obs, c)\n", | |
| "beta" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Problem 4" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 80, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "delta, gamma = h._delta(obs, alpha, beta)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 81, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[[ 0.14166321, 0.0465066 ],\n", | |
| " [ 0.37776855, 0.43406164]],\n", | |
| "\n", | |
| " [[ 0.17015868, 0.34927307],\n", | |
| " [ 0.05871895, 0.4218493 ]],\n", | |
| "\n", | |
| " [[ 0.21080834, 0.01806929],\n", | |
| " [ 0.59317106, 0.17795132]]])" | |
| ] | |
| }, | |
| "execution_count": 81, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "delta" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 82, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[ 0.18816981, 0.81183019],\n", | |
| " [ 0.51943175, 0.48056825],\n", | |
| " [ 0.22887763, 0.77112237],\n", | |
| " [ 0.8039794 , 0.1960206 ]])" | |
| ] | |
| }, | |
| "execution_count": 82, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "gamma" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "source": [ | |
| "### Problem 5" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 83, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[[ 0.55807991 0.49898142]\n", | |
| " [ 0.44192009 0.50101858]]\n", | |
| "[[ 0.23961928 0.70056364]\n", | |
| " [ 0.29844534 0.21268397]\n", | |
| " [ 0.46193538 0.08675238]]\n", | |
| "[ 0.18816981 0.81183019]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "h._estimate(obs, delta, gamma)\n", | |
| "print(h.A)\n", | |
| "print(h.B)\n", | |
| "print(h.pi)\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Problem 6" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 84, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
def vec_translate(a, my_dict):
    """
    Translate every element of an array through a dictionary.

    Used to convert symbols to state numbers or vice versa.
    Parameters
    ----------
    a : array_like
        Values to translate; every element must be a key of my_dict.
    my_dict : dict
        Mapping from the elements of a to their translated values.
    Returns
    -------
    ndarray with the same shape as a, holding my_dict[x] for each element x.
    Raises
    ------
    KeyError
        If an element of a is not a key of my_dict.
    """
    a = np.asarray(a)
    if a.size == 0:
        # Nothing to look up; take the result dtype from the dictionary's
        # values so the empty result is typed like a non-empty one would be.
        return np.empty(a.shape, dtype=np.array(list(my_dict.values())).dtype)
    # tolist() yields native Python scalars, which hash like the dict's keys.
    translated = [my_dict[item] for item in a.ravel().tolist()]
    return np.array(translated).reshape(a.shape)
def prep_data(filename):
    """
    Read a UTF-8 text file and encode it as a sequence of integer symbol ids.

    The text is lower-cased, and ASCII punctuation (string.punctuation) plus
    newline and carriage-return characters are removed; every remaining
    character, including the space, is a symbol.
    Parameters
    ----------
    filename : str
        Path of the text file to read.
    Returns
    -------
    symbols : list of str
        The distinct characters of the cleaned text in sorted order;
        symbols[k] is the character encoded as k.
    obs_sequence : ndarray of shape (T,)
        The cleaned text with each character replaced by its index in symbols.
        Both results are empty when no characters remain after cleaning.
    """
    # Get the data as a single string and convert it to all lower case
    with codecs.open(filename, encoding='utf-8') as f:
        data = f.read().lower()

    # remove punctuation and newlines
    remove_punct_map = {ord(char): None for char in string.punctuation + "\n\r"}
    data = data.translate(remove_punct_map)

    # One array element per character; the explicit dtype keeps an empty text
    # from turning into an empty float array.
    chars = np.array(list(data), dtype="U1")

    # np.unique returns the sorted distinct characters together with, for each
    # input character, its position in that sorted list.
    unique_chars, obs_sequence = np.unique(chars, return_inverse=True)
    return unique_chars.tolist(), obs_sequence
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 85, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "symbols, obs = prep_data('declaration.txt')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Problem 7" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 93, | |
| "metadata": { | |
| "collapsed": false, | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Iter: 0, log_lkhood: -27984.694651\n", | |
| "Iter: 10, log_lkhood: -22096.451772\n", | |
| "Iter: 20, log_lkhood: -21855.613115\n", | |
| "Iter: 30, log_lkhood: -21652.715537\n", | |
| "Iter: 40, log_lkhood: -21579.418872\n", | |
| "Iter: 50, log_lkhood: -21552.504780\n", | |
| "Iter: 60, log_lkhood: -21536.276039\n", | |
| "Iter: 70, log_lkhood: -21522.869274\n", | |
| "Iter: 80, log_lkhood: -21513.910457\n", | |
| "Iter: 90, log_lkhood: -21510.120438\n", | |
| "Iter: 100, log_lkhood: -21508.393164\n", | |
| "Iter: 110, log_lkhood: -21507.256295\n", | |
| "Iter: 120, log_lkhood: -21506.391586\n", | |
| "Iter: 130, log_lkhood: -21505.756749\n", | |
| "Iter: 140, log_lkhood: -21505.322756\n", | |
| "Iter: 150, log_lkhood: -21505.042934\n", | |
| "Iter: 160, log_lkhood: -21504.869024\n", | |
| "Iter: 170, log_lkhood: -21504.762593\n", | |
| "Iter: 180, log_lkhood: -21504.697269\n", | |
| "Iter: 190, log_lkhood: -21504.656458\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "hmm_ = hmm()\n", | |
| "hmm_.fit(obs, N=2, max_iter=200)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 99, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th></th>\n", | |
| " <th>a</th>\n", | |
| " <th>b</th>\n", | |
| " <th>c</th>\n", | |
| " <th>d</th>\n", | |
| " <th>e</th>\n", | |
| " <th>f</th>\n", | |
| " <th>g</th>\n", | |
| " <th>h</th>\n", | |
| " <th>i</th>\n", | |
| " <th>j</th>\n", | |
| " <th>k</th>\n", | |
| " <th>l</th>\n", | |
| " <th>m</th>\n", | |
| " <th>n</th>\n", | |
| " <th>o</th>\n", | |
| " <th>p</th>\n", | |
| " <th>q</th>\n", | |
| " <th>r</th>\n", | |
| " <th>s</th>\n", | |
| " <th>t</th>\n", | |
| " <th>u</th>\n", | |
| " <th>v</th>\n", | |
| " <th>w</th>\n", | |
| " <th>x</th>\n", | |
| " <th>y</th>\n", | |
| " <th>z</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>0.301</td>\n", | |
| " <td>0.131</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.235</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.001</td>\n", | |
| " <td>0.123</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.14</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.057</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.011</td>\n", | |
| " <td></td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>0.046</td>\n", | |
| " <td></td>\n", | |
| " <td>0.023</td>\n", | |
| " <td>0.044</td>\n", | |
| " <td>0.06</td>\n", | |
| " <td></td>\n", | |
| " <td>0.043</td>\n", | |
| " <td>0.031</td>\n", | |
| " <td>0.083</td>\n", | |
| " <td></td>\n", | |
| " <td>0.004</td>\n", | |
| " <td>0.003</td>\n", | |
| " <td>0.055</td>\n", | |
| " <td>0.035</td>\n", | |
| " <td>0.116</td>\n", | |
| " <td></td>\n", | |
| " <td>0.033</td>\n", | |
| " <td>0.001</td>\n", | |
| " <td>0.102</td>\n", | |
| " <td>0.115</td>\n", | |
| " <td>0.153</td>\n", | |
| " <td></td>\n", | |
| " <td>0.018</td>\n", | |
| " <td>0.023</td>\n", | |
| " <td>0.002</td>\n", | |
| " <td>0.009</td>\n", | |
| " <td>0.001</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " a b c d e f g h i j \\\n", | |
| "0 0.301 0.131 0.235 0.001 0.123 \n", | |
| "1 0.046 0.023 0.044 0.06 0.043 0.031 0.083 0.004 \n", | |
| "\n", | |
| " k l m n o p q r s t u \\\n", | |
| "0 0.14 0.057 \n", | |
| "1 0.003 0.055 0.035 0.116 0.033 0.001 0.102 0.115 0.153 \n", | |
| "\n", | |
| " v w x y z \n", | |
| "0 0.011 \n", | |
| "1 0.018 0.023 0.002 0.009 0.001 " | |
| ] | |
| }, | |
| "execution_count": 99, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
# Emission probability table for the 2-state model: one row per hidden state,
# one column per symbol (the letters and the space), rounded to 3 decimals.
B = pd.DataFrame(np.around(hmm_.B, 3), index=symbols).T
# Blank out the entries that round to exactly 0 so the symbols each state
# actually emits stand out. A scalar `to_replace` is the documented form.
B.replace(0, " ")
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Problem 8" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 101, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Iter: 0, log_lkhood: -25103.772576\n", | |
| "Iter: 10, log_lkhood: -21481.699828\n", | |
| "Iter: 20, log_lkhood: -20978.434993\n", | |
| "Iter: 30, log_lkhood: -20951.726801\n", | |
| "Iter: 40, log_lkhood: -20949.497910\n", | |
| "Iter: 50, log_lkhood: -20949.070237\n", | |
| "Iter: 60, log_lkhood: -20948.771320\n", | |
| "Iter: 70, log_lkhood: -20948.301344\n", | |
| "Iter: 80, log_lkhood: -20947.924887\n", | |
| "Iter: 90, log_lkhood: -20947.816960\n", | |
| "Iter: 100, log_lkhood: -20947.796648\n", | |
| "Iter: 110, log_lkhood: -20947.791887\n", | |
| "Iter: 120, log_lkhood: -20947.790252\n", | |
| "Iter: 130, log_lkhood: -20947.789541\n", | |
| "Iter: 140, log_lkhood: -20947.789191\n", | |
| "Iter: 150, log_lkhood: -20947.789006\n", | |
| "Iter: 160, log_lkhood: -20947.788900\n", | |
| "Iter: 170, log_lkhood: -20947.788836\n", | |
| "Iter: 180, log_lkhood: -20947.788795\n", | |
| "Iter: 190, log_lkhood: -20947.788765\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# hmm with 3 states\n", | |
| "hmm_3 = hmm()\n", | |
| "hmm_3.fit(obs, N=3, max_iter=200)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 103, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th></th>\n", | |
| " <th>a</th>\n", | |
| " <th>b</th>\n", | |
| " <th>c</th>\n", | |
| " <th>d</th>\n", | |
| " <th>e</th>\n", | |
| " <th>f</th>\n", | |
| " <th>g</th>\n", | |
| " <th>h</th>\n", | |
| " <th>i</th>\n", | |
| " <th>j</th>\n", | |
| " <th>k</th>\n", | |
| " <th>l</th>\n", | |
| " <th>m</th>\n", | |
| " <th>n</th>\n", | |
| " <th>o</th>\n", | |
| " <th>p</th>\n", | |
| " <th>q</th>\n", | |
| " <th>r</th>\n", | |
| " <th>s</th>\n", | |
| " <th>t</th>\n", | |
| " <th>u</th>\n", | |
| " <th>v</th>\n", | |
| " <th>w</th>\n", | |
| " <th>x</th>\n", | |
| " <th>y</th>\n", | |
| " <th>z</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>0.359</td>\n", | |
| " <td>0.068</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.213</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.096</td>\n", | |
| " <td>0.113</td>\n", | |
| " <td></td>\n", | |
| " <td>0.001</td>\n", | |
| " <td>0.001</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.101</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.003</td>\n", | |
| " <td>0.001</td>\n", | |
| " <td>0.001</td>\n", | |
| " <td>0.038</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.006</td>\n", | |
| " <td></td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.008</td>\n", | |
| " <td>0.043</td>\n", | |
| " <td>0.086</td>\n", | |
| " <td>0.043</td>\n", | |
| " <td>0.056</td>\n", | |
| " <td>0.043</td>\n", | |
| " <td>0.012</td>\n", | |
| " <td></td>\n", | |
| " <td>0.005</td>\n", | |
| " <td>0.004</td>\n", | |
| " <td>0.055</td>\n", | |
| " <td>0.039</td>\n", | |
| " <td>0.049</td>\n", | |
| " <td></td>\n", | |
| " <td>0.024</td>\n", | |
| " <td>0.002</td>\n", | |
| " <td>0.106</td>\n", | |
| " <td>0.129</td>\n", | |
| " <td>0.218</td>\n", | |
| " <td></td>\n", | |
| " <td>0.025</td>\n", | |
| " <td>0.028</td>\n", | |
| " <td>0.002</td>\n", | |
| " <td>0.021</td>\n", | |
| " <td>0.001</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>0.072</td>\n", | |
| " <td>0.156</td>\n", | |
| " <td>0.044</td>\n", | |
| " <td>0.035</td>\n", | |
| " <td></td>\n", | |
| " <td>0.023</td>\n", | |
| " <td>0.01</td>\n", | |
| " <td>0.004</td>\n", | |
| " <td></td>\n", | |
| " <td>0.049</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.038</td>\n", | |
| " <td>0.018</td>\n", | |
| " <td>0.209</td>\n", | |
| " <td>0.111</td>\n", | |
| " <td>0.041</td>\n", | |
| " <td></td>\n", | |
| " <td>0.064</td>\n", | |
| " <td>0.060</td>\n", | |
| " <td></td>\n", | |
| " <td>0.052</td>\n", | |
| " <td></td>\n", | |
| " <td>0.01</td>\n", | |
| " <td>0.003</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " a b c d e f g h i \\\n", | |
| "0 0.359 0.068 0.213 0.096 0.113 \n", | |
| "1 0.008 0.043 0.086 0.043 0.056 0.043 0.012 \n", | |
| "2 0.072 0.156 0.044 0.035 0.023 0.01 0.004 0.049 \n", | |
| "\n", | |
| " j k l m n o p q r s \\\n", | |
| "0 0.001 0.001 0.101 0.003 0.001 \n", | |
| "1 0.005 0.004 0.055 0.039 0.049 0.024 0.002 0.106 0.129 \n", | |
| "2 0.038 0.018 0.209 0.111 0.041 0.064 0.060 \n", | |
| "\n", | |
| " t u v w x y z \n", | |
| "0 0.001 0.038 0.006 \n", | |
| "1 0.218 0.025 0.028 0.002 0.021 0.001 \n", | |
| "2 0.052 0.01 0.003 " | |
| ] | |
| }, | |
| "execution_count": 103, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
# Emission probability table for the 3-state model: one row per hidden state,
# one column per symbol (the letters and the space), rounded to 3 decimals.
B = pd.DataFrame(np.around(hmm_3.B, 3), index=symbols).T
# Blank out the entries that round to exactly 0 so the symbols each state
# actually emits stand out. A scalar `to_replace` is the documented form.
B.replace(0, " ")
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 102, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Iter: 0, log_lkhood: -26491.366237\n", | |
| "Iter: 10, log_lkhood: -21084.722971\n", | |
| "Iter: 20, log_lkhood: -20639.194883\n", | |
| "Iter: 30, log_lkhood: -20594.095902\n", | |
| "Iter: 40, log_lkhood: -20584.392092\n", | |
| "Iter: 50, log_lkhood: -20580.165584\n", | |
| "Iter: 60, log_lkhood: -20577.197153\n", | |
| "Iter: 70, log_lkhood: -20573.903547\n", | |
| "Iter: 80, log_lkhood: -20567.329903\n", | |
| "Iter: 90, log_lkhood: -20549.905328\n", | |
| "Iter: 100, log_lkhood: -20539.691760\n", | |
| "Iter: 110, log_lkhood: -20537.063214\n", | |
| "Iter: 120, log_lkhood: -20533.772442\n", | |
| "Iter: 130, log_lkhood: -20526.231557\n", | |
| "Iter: 140, log_lkhood: -20520.965193\n", | |
| "Iter: 150, log_lkhood: -20516.225241\n", | |
| "Iter: 160, log_lkhood: -20510.880567\n", | |
| "Iter: 170, log_lkhood: -20506.710549\n", | |
| "Iter: 180, log_lkhood: -20504.449060\n", | |
| "Iter: 190, log_lkhood: -20503.471245\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# hmm with 4 states\n", | |
| "hmm_4 = hmm()\n", | |
| "hmm_4.fit(obs, N=4, max_iter=200)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 104, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th></th>\n", | |
| " <th>a</th>\n", | |
| " <th>b</th>\n", | |
| " <th>c</th>\n", | |
| " <th>d</th>\n", | |
| " <th>e</th>\n", | |
| " <th>f</th>\n", | |
| " <th>g</th>\n", | |
| " <th>h</th>\n", | |
| " <th>i</th>\n", | |
| " <th>j</th>\n", | |
| " <th>k</th>\n", | |
| " <th>l</th>\n", | |
| " <th>m</th>\n", | |
| " <th>n</th>\n", | |
| " <th>o</th>\n", | |
| " <th>p</th>\n", | |
| " <th>q</th>\n", | |
| " <th>r</th>\n", | |
| " <th>s</th>\n", | |
| " <th>t</th>\n", | |
| " <th>u</th>\n", | |
| " <th>v</th>\n", | |
| " <th>w</th>\n", | |
| " <th>x</th>\n", | |
| " <th>y</th>\n", | |
| " <th>z</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td></td>\n", | |
| " <td>0.007</td>\n", | |
| " <td>0.004</td>\n", | |
| " <td></td>\n", | |
| " <td>0.002</td>\n", | |
| " <td>0.663</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.001</td>\n", | |
| " <td>0.002</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.192</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.029</td>\n", | |
| " <td>0.044</td>\n", | |
| " <td>0.02</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.037</td>\n", | |
| " <td></td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td></td>\n", | |
| " <td>0.124</td>\n", | |
| " <td>0.044</td>\n", | |
| " <td>0.05</td>\n", | |
| " <td></td>\n", | |
| " <td>0.046</td>\n", | |
| " <td>0.015</td>\n", | |
| " <td>0.017</td>\n", | |
| " <td></td>\n", | |
| " <td>0.056</td>\n", | |
| " <td></td>\n", | |
| " <td>0.001</td>\n", | |
| " <td>0.028</td>\n", | |
| " <td></td>\n", | |
| " <td>0.165</td>\n", | |
| " <td>0.138</td>\n", | |
| " <td>0.037</td>\n", | |
| " <td></td>\n", | |
| " <td>0.021</td>\n", | |
| " <td>0.053</td>\n", | |
| " <td>0.14</td>\n", | |
| " <td>0.043</td>\n", | |
| " <td></td>\n", | |
| " <td>0.023</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>0.635</td>\n", | |
| " <td>0.107</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.001</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.165</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.01</td>\n", | |
| " <td></td>\n", | |
| " <td>0.005</td>\n", | |
| " <td>0.01</td>\n", | |
| " <td></td>\n", | |
| " <td>0.018</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.049</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.031</td>\n", | |
| " <td>0.096</td>\n", | |
| " <td></td>\n", | |
| " <td>0.058</td>\n", | |
| " <td>0.037</td>\n", | |
| " <td>0.134</td>\n", | |
| " <td></td>\n", | |
| " <td>0.006</td>\n", | |
| " <td>0.004</td>\n", | |
| " <td>0.064</td>\n", | |
| " <td>0.047</td>\n", | |
| " <td>0.056</td>\n", | |
| " <td></td>\n", | |
| " <td>0.016</td>\n", | |
| " <td>0.002</td>\n", | |
| " <td>0.133</td>\n", | |
| " <td>0.129</td>\n", | |
| " <td>0.117</td>\n", | |
| " <td></td>\n", | |
| " <td>0.028</td>\n", | |
| " <td>0.019</td>\n", | |
| " <td>0.003</td>\n", | |
| " <td>0.015</td>\n", | |
| " <td>0.002</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " a b c d e f g h i \\\n", | |
| "0 0.007 0.004 0.002 0.663 \n", | |
| "1 0.124 0.044 0.05 0.046 0.015 0.017 0.056 \n", | |
| "2 0.635 0.107 0.001 0.165 \n", | |
| "3 0.031 0.096 0.058 0.037 0.134 \n", | |
| "\n", | |
| " j k l m n o p q r s \\\n", | |
| "0 0.001 0.002 0.192 0.029 \n", | |
| "1 0.001 0.028 0.165 0.138 0.037 0.021 0.053 \n", | |
| "2 0.01 0.005 0.01 0.018 \n", | |
| "3 0.006 0.004 0.064 0.047 0.056 0.016 0.002 0.133 0.129 \n", | |
| "\n", | |
| " t u v w x y z \n", | |
| "0 0.044 0.02 0.037 \n", | |
| "1 0.14 0.043 0.023 \n", | |
| "2 0.049 \n", | |
| "3 0.117 0.028 0.019 0.003 0.015 0.002 " | |
| ] | |
| }, | |
| "execution_count": 104, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
# Emission probability table for the 4-state model: one row per hidden state,
# one column per symbol (the letters and the space), rounded to 3 decimals.
B = pd.DataFrame(np.around(hmm_4.B, 3), index=symbols).T
# Blank out the entries that round to exactly 0 so the symbols each state
# actually emits stand out. A scalar `to_replace` is the documented form.
B.replace(0, " ")
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Problem 9\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 105, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "symbols, obs = prep_data('WarAndPeace.txt')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 108, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Iter: 0, log_lkhood: -207306.469625\n", | |
| "Iter: 10, log_lkhood: -167057.782077\n", | |
| "Iter: 20, log_lkhood: -159928.818875\n", | |
| "Iter: 30, log_lkhood: -158425.865835\n", | |
| "Iter: 40, log_lkhood: -158388.675449\n", | |
| "Iter: 50, log_lkhood: -158386.868491\n", | |
| "Iter: 60, log_lkhood: -158385.870035\n", | |
| "Iter: 70, log_lkhood: -158385.682049\n", | |
| "Iter: 80, log_lkhood: -158385.635383\n", | |
| "Iter: 90, log_lkhood: -158385.618570\n", | |
| "Iter: 100, log_lkhood: -158385.612079\n", | |
| "Iter: 110, log_lkhood: -158385.609413\n", | |
| "Iter: 120, log_lkhood: -158385.608213\n", | |
| "Iter: 130, log_lkhood: -158385.607596\n", | |
| "Iter: 140, log_lkhood: -158385.607222\n", | |
| "Iter: 150, log_lkhood: -158385.606958\n", | |
| "Iter: 160, log_lkhood: -158385.606750\n", | |
| "Iter: 170, log_lkhood: -158385.606576\n", | |
| "Iter: 180, log_lkhood: -158385.606425\n", | |
| "Iter: 190, log_lkhood: -158385.606292\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "hmm_2 = hmm()\n", | |
| "hmm_2.fit(obs, N=2, max_iter=200)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 109, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th></th>\n", | |
| " <th>́</th>\n", | |
| " <th>а</th>\n", | |
| " <th>б</th>\n", | |
| " <th>в</th>\n", | |
| " <th>г</th>\n", | |
| " <th>д</th>\n", | |
| " <th>е</th>\n", | |
| " <th>ж</th>\n", | |
| " <th>з</th>\n", | |
| " <th>и</th>\n", | |
| " <th>й</th>\n", | |
| " <th>к</th>\n", | |
| " <th>л</th>\n", | |
| " <th>м</th>\n", | |
| " <th>н</th>\n", | |
| " <th>о</th>\n", | |
| " <th>п</th>\n", | |
| " <th>р</th>\n", | |
| " <th>с</th>\n", | |
| " <th>т</th>\n", | |
| " <th>у</th>\n", | |
| " <th>ф</th>\n", | |
| " <th>х</th>\n", | |
| " <th>ц</th>\n", | |
| " <th>ч</th>\n", | |
| " <th>ш</th>\n", | |
| " <th>щ</th>\n", | |
| " <th>ъ</th>\n", | |
| " <th>ы</th>\n", | |
| " <th>ь</th>\n", | |
| " <th>э</th>\n", | |
| " <th>ю</th>\n", | |
| " <th>я</th>\n", | |
| " <th>ё</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>0.215</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.025</td>\n", | |
| " <td>0.066</td>\n", | |
| " <td>0.03</td>\n", | |
| " <td>0.039</td>\n", | |
| " <td>0.018</td>\n", | |
| " <td>0.014</td>\n", | |
| " <td>0.025</td>\n", | |
| " <td>0.002</td>\n", | |
| " <td>0.015</td>\n", | |
| " <td>0.050</td>\n", | |
| " <td>0.072</td>\n", | |
| " <td>0.038</td>\n", | |
| " <td>0.097</td>\n", | |
| " <td></td>\n", | |
| " <td>0.035</td>\n", | |
| " <td>0.06</td>\n", | |
| " <td>0.051</td>\n", | |
| " <td>0.078</td>\n", | |
| " <td></td>\n", | |
| " <td>0.002</td>\n", | |
| " <td>0.011</td>\n", | |
| " <td>0.005</td>\n", | |
| " <td>0.017</td>\n", | |
| " <td>0.011</td>\n", | |
| " <td>0.005</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.001</td>\n", | |
| " <td></td>\n", | |
| " <td>0.008</td>\n", | |
| " <td>0.013</td>\n", | |
| " <td></td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>0.088</td>\n", | |
| " <td></td>\n", | |
| " <td>0.176</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.143</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.131</td>\n", | |
| " <td></td>\n", | |
| " <td>0.001</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.241</td>\n", | |
| " <td>0.006</td>\n", | |
| " <td></td>\n", | |
| " <td>0.028</td>\n", | |
| " <td></td>\n", | |
| " <td>0.059</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.004</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.038</td>\n", | |
| " <td>0.043</td>\n", | |
| " <td>0.007</td>\n", | |
| " <td>0.002</td>\n", | |
| " <td>0.033</td>\n", | |
| " <td></td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " ́ а б в г д е ж з и \\\n", | |
| "0 0.215 0.025 0.066 0.03 0.039 0.018 0.014 0.025 0.002 \n", | |
| "1 0.088 0.176 0.143 0.131 \n", | |
| "\n", | |
| " й к л м н о п р с т у \\\n", | |
| "0 0.015 0.050 0.072 0.038 0.097 0.035 0.06 0.051 0.078 \n", | |
| "1 0.001 0.241 0.006 0.028 0.059 \n", | |
| "\n", | |
| " ф х ц ч ш щ ъ ы ь э ю \\\n", | |
| "0 0.002 0.011 0.005 0.017 0.011 0.005 0.001 0.008 \n", | |
| "1 0.004 0.038 0.043 0.007 0.002 \n", | |
| "\n", | |
| " я ё \n", | |
| "0 0.013 \n", | |
| "1 0.033 " | |
| ] | |
| }, | |
| "execution_count": 109, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Emission probabilities of the 2-state model, rounded to 3 decimals:\n", | |
| "# one row per hidden state, one column per symbol (the letters and the space).\n", | |
| "B = pd.DataFrame(np.around(hmm_2.B, 3), index=symbols).T\n", | |
| "# Blank out entries that round to zero so the dominant symbols stand out.\n", | |
| "# The scalar form of replace is used because a masked DataFrame is not a\n", | |
| "# documented to_replace argument.\n", | |
| "B.replace(0, \" \")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 110, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Iter: 0, log_lkhood: -206722.641505\n", | |
| "Iter: 10, log_lkhood: -167016.909821\n", | |
| "Iter: 20, log_lkhood: -166640.153796\n", | |
| "Iter: 30, log_lkhood: -166376.748867\n", | |
| "Iter: 40, log_lkhood: -166239.097332\n", | |
| "Iter: 50, log_lkhood: -166188.786042\n", | |
| "Iter: 60, log_lkhood: -166161.622674\n", | |
| "Iter: 70, log_lkhood: -166141.460107\n", | |
| "Iter: 80, log_lkhood: -166122.366250\n", | |
| "Iter: 90, log_lkhood: -166105.709010\n", | |
| "Iter: 100, log_lkhood: -166087.292232\n", | |
| "Iter: 110, log_lkhood: -166060.262078\n", | |
| "Iter: 120, log_lkhood: -166027.852430\n", | |
| "Iter: 130, log_lkhood: -165995.554904\n", | |
| "Iter: 140, log_lkhood: -165920.602652\n", | |
| "Iter: 150, log_lkhood: -165879.112254\n", | |
| "Iter: 160, log_lkhood: -165862.662407\n", | |
| "Iter: 170, log_lkhood: -165834.125206\n", | |
| "Iter: 180, log_lkhood: -165786.659524\n", | |
| "Iter: 190, log_lkhood: -165752.703862\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Fit a 3-hidden-state HMM to the same observation sequence (max_iter=200).\n", | |
| "# NOTE(review): in the saved run the log-likelihood was still rising at\n", | |
| "# iteration 190 and ended below the 2-state model's value, so this fit had\n", | |
| "# probably not converged; consider more iterations or several restarts.\n", | |
| "hmm_3 = hmm()\n", | |
| "hmm_3.fit(obs, N=3, max_iter=200)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 111, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th></th>\n", | |
| " <th>́</th>\n", | |
| " <th>а</th>\n", | |
| " <th>б</th>\n", | |
| " <th>в</th>\n", | |
| " <th>г</th>\n", | |
| " <th>д</th>\n", | |
| " <th>е</th>\n", | |
| " <th>ж</th>\n", | |
| " <th>з</th>\n", | |
| " <th>и</th>\n", | |
| " <th>й</th>\n", | |
| " <th>к</th>\n", | |
| " <th>л</th>\n", | |
| " <th>м</th>\n", | |
| " <th>н</th>\n", | |
| " <th>о</th>\n", | |
| " <th>п</th>\n", | |
| " <th>р</th>\n", | |
| " <th>с</th>\n", | |
| " <th>т</th>\n", | |
| " <th>у</th>\n", | |
| " <th>ф</th>\n", | |
| " <th>х</th>\n", | |
| " <th>ц</th>\n", | |
| " <th>ч</th>\n", | |
| " <th>ш</th>\n", | |
| " <th>щ</th>\n", | |
| " <th>ъ</th>\n", | |
| " <th>ы</th>\n", | |
| " <th>ь</th>\n", | |
| " <th>э</th>\n", | |
| " <th>ю</th>\n", | |
| " <th>я</th>\n", | |
| " <th>ё</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>0.171</td>\n", | |
| " <td></td>\n", | |
| " <td>0.006</td>\n", | |
| " <td>0.002</td>\n", | |
| " <td>0.055</td>\n", | |
| " <td>0.04</td>\n", | |
| " <td>0.020</td>\n", | |
| " <td>0.064</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.030</td>\n", | |
| " <td>0.008</td>\n", | |
| " <td>0.015</td>\n", | |
| " <td>0.023</td>\n", | |
| " <td>0.017</td>\n", | |
| " <td>0.020</td>\n", | |
| " <td>0.212</td>\n", | |
| " <td>0.053</td>\n", | |
| " <td>0.045</td>\n", | |
| " <td>0.052</td>\n", | |
| " <td>0.092</td>\n", | |
| " <td>0.013</td>\n", | |
| " <td></td>\n", | |
| " <td>0.012</td>\n", | |
| " <td>0.001</td>\n", | |
| " <td>0.021</td>\n", | |
| " <td>0.001</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.02</td>\n", | |
| " <td>0.006</td>\n", | |
| " <td></td>\n", | |
| " <td>0.001</td>\n", | |
| " <td></td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>0.124</td>\n", | |
| " <td></td>\n", | |
| " <td>0.048</td>\n", | |
| " <td>0.048</td>\n", | |
| " <td>0.033</td>\n", | |
| " <td></td>\n", | |
| " <td>0.048</td>\n", | |
| " <td>0.142</td>\n", | |
| " <td>0.028</td>\n", | |
| " <td>0.001</td>\n", | |
| " <td>0.069</td>\n", | |
| " <td>0.019</td>\n", | |
| " <td></td>\n", | |
| " <td>0.046</td>\n", | |
| " <td>0.047</td>\n", | |
| " <td>0.098</td>\n", | |
| " <td>0.036</td>\n", | |
| " <td></td>\n", | |
| " <td>0.052</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.051</td>\n", | |
| " <td>0.002</td>\n", | |
| " <td>0.005</td>\n", | |
| " <td>0.007</td>\n", | |
| " <td>0.001</td>\n", | |
| " <td>0.021</td>\n", | |
| " <td>0.01</td>\n", | |
| " <td>0.001</td>\n", | |
| " <td>0.05</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.014</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>0.195</td>\n", | |
| " <td></td>\n", | |
| " <td>0.180</td>\n", | |
| " <td></td>\n", | |
| " <td>0.024</td>\n", | |
| " <td>0.005</td>\n", | |
| " <td>0.003</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.05</td>\n", | |
| " <td>0.070</td>\n", | |
| " <td></td>\n", | |
| " <td>0.081</td>\n", | |
| " <td>0.068</td>\n", | |
| " <td>0.006</td>\n", | |
| " <td>0.072</td>\n", | |
| " <td></td>\n", | |
| " <td>0.006</td>\n", | |
| " <td>0.008</td>\n", | |
| " <td>0.071</td>\n", | |
| " <td>0.031</td>\n", | |
| " <td>0.010</td>\n", | |
| " <td>0.002</td>\n", | |
| " <td></td>\n", | |
| " <td>0.002</td>\n", | |
| " <td>0.009</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.033</td>\n", | |
| " <td></td>\n", | |
| " <td>0.005</td>\n", | |
| " <td>0.068</td>\n", | |
| " <td></td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " ́ а б в г д е ж з и \\\n", | |
| "0 0.171 0.006 0.002 0.055 0.04 0.020 0.064 0.030 \n", | |
| "1 0.124 0.048 0.048 0.033 0.048 0.142 0.028 0.001 0.069 \n", | |
| "2 0.195 0.180 0.024 0.005 0.003 0.05 0.070 \n", | |
| "\n", | |
| " й к л м н о п р с т \\\n", | |
| "0 0.008 0.015 0.023 0.017 0.020 0.212 0.053 0.045 0.052 0.092 \n", | |
| "1 0.019 0.046 0.047 0.098 0.036 0.052 \n", | |
| "2 0.081 0.068 0.006 0.072 0.006 0.008 0.071 0.031 \n", | |
| "\n", | |
| " у ф х ц ч ш щ ъ ы ь э \\\n", | |
| "0 0.013 0.012 0.001 0.021 0.001 0.02 0.006 \n", | |
| "1 0.051 0.002 0.005 0.007 0.001 0.021 0.01 0.001 0.05 \n", | |
| "2 0.010 0.002 0.002 0.009 0.033 \n", | |
| "\n", | |
| " ю я ё \n", | |
| "0 0.001 \n", | |
| "1 0.014 \n", | |
| "2 0.005 0.068 " | |
| ] | |
| }, | |
| "execution_count": 111, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Emission probabilities of the 3-state model, rounded to 3 decimals:\n", | |
| "# one row per hidden state, one column per symbol (the letters and the space).\n", | |
| "B = pd.DataFrame(np.around(hmm_3.B, 3), index=symbols).T\n", | |
| "# Blank out entries that round to zero so the dominant symbols stand out.\n", | |
| "# The scalar form of replace is used because a masked DataFrame is not a\n", | |
| "# documented to_replace argument.\n", | |
| "B.replace(0, \" \")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 112, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Iter: 0, log_lkhood: -204826.378984\n", | |
| "Iter: 10, log_lkhood: -158387.129848\n", | |
| "Iter: 20, log_lkhood: -153366.552658\n", | |
| "Iter: 30, log_lkhood: -152717.446563\n", | |
| "Iter: 40, log_lkhood: -152568.968634\n", | |
| "Iter: 50, log_lkhood: -152506.637600\n", | |
| "Iter: 60, log_lkhood: -152475.317134\n", | |
| "Iter: 70, log_lkhood: -152454.696123\n", | |
| "Iter: 80, log_lkhood: -152438.634036\n", | |
| "Iter: 90, log_lkhood: -152426.087514\n", | |
| "Iter: 100, log_lkhood: -152417.145420\n", | |
| "Iter: 110, log_lkhood: -152411.385141\n", | |
| "Iter: 120, log_lkhood: -152407.907331\n", | |
| "Iter: 130, log_lkhood: -152405.862370\n", | |
| "Iter: 140, log_lkhood: -152404.664465\n", | |
| "Iter: 150, log_lkhood: -152403.957874\n", | |
| "Iter: 160, log_lkhood: -152403.536247\n", | |
| "Iter: 170, log_lkhood: -152403.281592\n", | |
| "Iter: 180, log_lkhood: -152403.126194\n", | |
| "Iter: 190, log_lkhood: -152403.030619\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Fit a 4-hidden-state HMM to the same observation sequence (max_iter=200).\n", | |
| "# NOTE(review): if fit() draws its initial parameters at random, seed NumPy\n", | |
| "# beforehand so the logged log-likelihood trace is reproducible -- TODO confirm.\n", | |
| "hmm_4 = hmm()\n", | |
| "hmm_4.fit(obs, N=4, max_iter=200)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 113, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th></th>\n", | |
| " <th>́</th>\n", | |
| " <th>а</th>\n", | |
| " <th>б</th>\n", | |
| " <th>в</th>\n", | |
| " <th>г</th>\n", | |
| " <th>д</th>\n", | |
| " <th>е</th>\n", | |
| " <th>ж</th>\n", | |
| " <th>з</th>\n", | |
| " <th>и</th>\n", | |
| " <th>й</th>\n", | |
| " <th>к</th>\n", | |
| " <th>л</th>\n", | |
| " <th>м</th>\n", | |
| " <th>н</th>\n", | |
| " <th>о</th>\n", | |
| " <th>п</th>\n", | |
| " <th>р</th>\n", | |
| " <th>с</th>\n", | |
| " <th>т</th>\n", | |
| " <th>у</th>\n", | |
| " <th>ф</th>\n", | |
| " <th>х</th>\n", | |
| " <th>ц</th>\n", | |
| " <th>ч</th>\n", | |
| " <th>ш</th>\n", | |
| " <th>щ</th>\n", | |
| " <th>ъ</th>\n", | |
| " <th>ы</th>\n", | |
| " <th>ь</th>\n", | |
| " <th>э</th>\n", | |
| " <th>ю</th>\n", | |
| " <th>я</th>\n", | |
| " <th>ё</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>0.504</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.004</td>\n", | |
| " <td>0.052</td>\n", | |
| " <td>0.009</td>\n", | |
| " <td>0.019</td>\n", | |
| " <td>0.056</td>\n", | |
| " <td></td>\n", | |
| " <td>0.01</td>\n", | |
| " <td>0.008</td>\n", | |
| " <td>0.04</td>\n", | |
| " <td>0.034</td>\n", | |
| " <td></td>\n", | |
| " <td>0.008</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.025</td>\n", | |
| " <td>0.002</td>\n", | |
| " <td>0.112</td>\n", | |
| " <td>0.022</td>\n", | |
| " <td>0.001</td>\n", | |
| " <td>0.002</td>\n", | |
| " <td>0.003</td>\n", | |
| " <td></td>\n", | |
| " <td>0.016</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.012</td>\n", | |
| " <td>0.021</td>\n", | |
| " <td>0.039</td>\n", | |
| " <td></td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>0.009</td>\n", | |
| " <td></td>\n", | |
| " <td>0.203</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.001</td>\n", | |
| " <td></td>\n", | |
| " <td>0.159</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.15</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.278</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.067</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.043</td>\n", | |
| " <td>0.052</td>\n", | |
| " <td></td>\n", | |
| " <td>0.003</td>\n", | |
| " <td>0.034</td>\n", | |
| " <td></td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>0.155</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.031</td>\n", | |
| " <td>0.054</td>\n", | |
| " <td>0.046</td>\n", | |
| " <td>0.046</td>\n", | |
| " <td></td>\n", | |
| " <td>0.024</td>\n", | |
| " <td>0.047</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.03</td>\n", | |
| " <td>0.156</td>\n", | |
| " <td>0.07</td>\n", | |
| " <td>0.113</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.056</td>\n", | |
| " <td>0.03</td>\n", | |
| " <td>0.068</td>\n", | |
| " <td></td>\n", | |
| " <td>0.002</td>\n", | |
| " <td>0.019</td>\n", | |
| " <td>0.008</td>\n", | |
| " <td>0.019</td>\n", | |
| " <td>0.02</td>\n", | |
| " <td>0.005</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>0.076</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.035</td>\n", | |
| " <td>0.073</td>\n", | |
| " <td>0.028</td>\n", | |
| " <td>0.042</td>\n", | |
| " <td></td>\n", | |
| " <td>0.016</td>\n", | |
| " <td>0.016</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td>0.071</td>\n", | |
| " <td>0.056</td>\n", | |
| " <td>0.032</td>\n", | |
| " <td>0.154</td>\n", | |
| " <td></td>\n", | |
| " <td>0.074</td>\n", | |
| " <td>0.104</td>\n", | |
| " <td>0.047</td>\n", | |
| " <td>0.121</td>\n", | |
| " <td></td>\n", | |
| " <td>0.002</td>\n", | |
| " <td>0.01</td>\n", | |
| " <td>0.006</td>\n", | |
| " <td>0.018</td>\n", | |
| " <td>0.011</td>\n", | |
| " <td>0.008</td>\n", | |
| " <td>0.001</td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " <td></td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " ́ а б в г д е ж з и \\\n", | |
| "0 0.504 0.004 0.052 0.009 0.019 0.056 0.01 0.008 \n", | |
| "1 0.009 0.203 0.001 0.159 0.15 \n", | |
| "2 0.155 0.031 0.054 0.046 0.046 0.024 0.047 \n", | |
| "3 0.076 0.035 0.073 0.028 0.042 0.016 0.016 \n", | |
| "\n", | |
| " й к л м н о п р с т у \\\n", | |
| "0 0.04 0.034 0.008 0.025 0.002 0.112 0.022 0.001 \n", | |
| "1 0.278 0.067 \n", | |
| "2 0.03 0.156 0.07 0.113 0.056 0.03 0.068 \n", | |
| "3 0.071 0.056 0.032 0.154 0.074 0.104 0.047 0.121 \n", | |
| "\n", | |
| " ф х ц ч ш щ ъ ы ь э \\\n", | |
| "0 0.002 0.003 0.016 0.012 \n", | |
| "1 0.043 0.052 \n", | |
| "2 0.002 0.019 0.008 0.019 0.02 0.005 \n", | |
| "3 0.002 0.01 0.006 0.018 0.011 0.008 0.001 \n", | |
| "\n", | |
| " ю я ё \n", | |
| "0 0.021 0.039 \n", | |
| "1 0.003 0.034 \n", | |
| "2 \n", | |
| "3 " | |
| ] | |
| }, | |
| "execution_count": 113, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Emission probabilities of the 4-state model, rounded to 3 decimals:\n", | |
| "# one row per hidden state, one column per symbol (the letters and the space).\n", | |
| "B = pd.DataFrame(np.around(hmm_4.B, 3), index=symbols).T\n", | |
| "# Blank out entries that round to zero so the dominant symbols stand out.\n", | |
| "# The scalar form of replace is used because a masked DataFrame is not a\n", | |
| "# documented to_replace argument.\n", | |
| "B.replace(0, \" \")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "anaconda-cloud": {}, | |
| "kernelspec": { | |
| "display_name": "Python [conda env:impact]", | |
| "language": "python", | |
| "name": "conda-env-impact-py" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.6.0" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 1 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment