Last active
February 27, 2016 16:44
-
-
Save jgc128/bd426bb669cb9dc7df6a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Using Theano backend.\n", | |
| "/data1/aromanov/virt_env2/keras_latest/lib/python3.4/site-packages/theano/tensor/signal/downsample.py:5: UserWarning: downsample module has been moved to the pool module.\n", | |
| " warnings.warn(\"downsample module has been moved to the pool module.\")\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "import os\n", | |
| "import pickle\n", | |
| "from itertools import chain\n", | |
| "\n", | |
| "import numpy as np\n", | |
| "import pandas as pd\n", | |
| "\n", | |
| "from sklearn.cross_validation import train_test_split\n", | |
| "from sklearn.metrics import f1_score, accuracy_score\n", | |
| "\n", | |
| "from keras.utils.np_utils import to_categorical\n", | |
| "from keras.models import Sequential\n", | |
| "from keras.layers.core import Dense, RepeatVector, TimeDistributedDense, Masking, Activation, RepeatVector\n", | |
| "from keras.layers.embeddings import Embedding\n", | |
| "from keras.layers.recurrent import LSTM" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from nltk.corpus import brown" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Create training data from the brown corpus" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "nb_brown_samples = 5000" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "brown_corpus = brown.tagged_sents(tagset='universal')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "brown_corpus_selected = list(brown_corpus[0:nb_brown_samples])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Sentences: 5000\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "print('Sentences:', len(brown_corpus_selected))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[('The', 'DET'),\n", | |
| " ('Fulton', 'NOUN'),\n", | |
| " ('County', 'NOUN'),\n", | |
| " ('Grand', 'ADJ'),\n", | |
| " ('Jury', 'NOUN'),\n", | |
| " ('said', 'VERB'),\n", | |
| " ('Friday', 'NOUN'),\n", | |
| " ('an', 'DET'),\n", | |
| " ('investigation', 'NOUN'),\n", | |
| " ('of', 'ADP'),\n", | |
| " (\"Atlanta's\", 'NOUN'),\n", | |
| " ('recent', 'ADJ'),\n", | |
| " ('primary', 'NOUN'),\n", | |
| " ('election', 'NOUN'),\n", | |
| " ('produced', 'VERB'),\n", | |
| " ('``', '.'),\n", | |
| " ('no', 'DET'),\n", | |
| " ('evidence', 'NOUN'),\n", | |
| " (\"''\", '.'),\n", | |
| " ('that', 'ADP'),\n", | |
| " ('any', 'DET'),\n", | |
| " ('irregularities', 'NOUN'),\n", | |
| " ('took', 'VERB'),\n", | |
| " ('place', 'NOUN'),\n", | |
| " ('.', '.')]" | |
| ] | |
| }, | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "brown_corpus_selected[0]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[('His', 'DET'),\n", | |
| " ('petition', 'NOUN'),\n", | |
| " ('charged', 'VERB'),\n", | |
| " ('mental', 'ADJ'),\n", | |
| " ('cruelty', 'NOUN'),\n", | |
| " ('.', '.')]" | |
| ] | |
| }, | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "brown_corpus_selected[31]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "brown_train, brown_test = train_test_split(brown_corpus_selected, test_size=0.25)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Train: 3750\n", | |
| "Test: 1250\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "print('Train:', len(brown_train))\n", | |
| "print('Test:', len(brown_test))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Construct *tok2id* and *label2id*" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| " # i+1 because id 0 is reserved as the masking (padding) value\n", | |
| "tok2id = {t:i+1 for i,t in enumerate(set(t[0].lower() for seq in brown_corpus_selected for t in seq))}\n", | |
| "\n", | |
| "label2id = {t:i for i,t in enumerate(set(t[1] for seq in brown_corpus_selected for t in seq))}" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Tokens: 13735\n", | |
| "Labels: 12\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "print('Tokens:', len(tok2id))\n", | |
| "print('Labels:', len(label2id))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Convert the Brown corpus to input matrices" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def create_XY_seq(data, tok2id, label2id):\n", | |
| " nb_classes = len(label2id)\n", | |
| " \n", | |
| " # data[0] - a sentence\n", | |
| " # data[0][0] - a pair (token, tag)\n", | |
| " \n", | |
| " # X_seq - a list of sentences, each sentence - a list of tokens\n", | |
| " X_seq = [[t[0].lower() for t in seq] for seq in data]\n", | |
| " \n", | |
| " # convert tokens to id\n", | |
| " X_seq_ids = [[tok2id[l] for l in seq] for seq in X_seq]\n", | |
| " \n", | |
| " \n", | |
| " # Y_seq - a list of sequences, each sequence - a list of labels\n", | |
| " Y_seq = [[t[1] for t in seq] for seq in data]\n", | |
| " \n", | |
| " # convert labels to ids\n", | |
| " Y_seq_ids = [[label2id[l] for l in seq] for seq in Y_seq]\n", | |
| " \n", | |
| " # create one-hot representation for labels\n", | |
| " Y_seq_cat = [to_categorical(seq, nb_classes=nb_classes) for seq in Y_seq_ids]\n", | |
| " \n", | |
| " return X_seq_ids, Y_seq_cat" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "X_seq_train, Y_seq_train = create_XY_seq(brown_train, tok2id, label2id)\n", | |
| "X_seq_test, Y_seq_test = create_XY_seq(brown_test, tok2id, label2id)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Train: 3750 3750\n", | |
| "Test: 1250 1250\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "print('Train:', len(X_seq_train), len(Y_seq_train))\n", | |
| "print('Test:', len(X_seq_test), len(Y_seq_test))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[253,\n", | |
| " 11462,\n", | |
| " 7707,\n", | |
| " 4952,\n", | |
| " 11759,\n", | |
| " 3983,\n", | |
| " 12227,\n", | |
| " 2286,\n", | |
| " 3993,\n", | |
| " 8386,\n", | |
| " 12227,\n", | |
| " 9631,\n", | |
| " 1959,\n", | |
| " 2286,\n", | |
| " 6550,\n", | |
| " 12230,\n", | |
| " 13251,\n", | |
| " 730,\n", | |
| " 10213,\n", | |
| " 10598,\n", | |
| " 3185]" | |
| ] | |
| }, | |
| "execution_count": 16, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "X_seq_train[0]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[ 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n", | |
| " [ 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n", | |
| " [ 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n", | |
| " [ 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n", | |
| " [ 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],\n", | |
| " [ 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n", | |
| " [ 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]])" | |
| ] | |
| }, | |
| "execution_count": 17, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "Y_seq_train[0]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Create data matrix" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Input to the network - 2d tensor with shape `(nb_samples, sequence_length)`\n", | |
| "\n", | |
| "Output of the network - 3d tensor with shape `(nb_samples, timesteps, nb_classes)`" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "nb_samples_train = len(X_seq_train)\n", | |
| "nb_samples_test = len(X_seq_test)\n", | |
| "\n", | |
| "nb_classes = len(label2id)\n", | |
| "maxlen = max(chain((len(seq) for seq in X_seq_train), (len(seq) for seq in X_seq_test)))\n", | |
| "input_dim = max(tok2id.values()) + 1" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "nb_samples_train: 3750\n", | |
| "nb_samples_test: 1250\n", | |
| "nb_classes: 12\n", | |
| "maxlen: 102\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "print('nb_samples_train:', nb_samples_train)\n", | |
| "print('nb_samples_test:', nb_samples_test)\n", | |
| "print('nb_classes:', nb_classes)\n", | |
| "print('maxlen:', maxlen)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 20, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def create_data_matrix(X_seq, Y_seq, nb_samples, maxlen, nb_classes):\n", | |
| " X = np.zeros((nb_samples, maxlen))\n", | |
| " Y = np.zeros((nb_samples, maxlen, nb_classes))\n", | |
| " \n", | |
| " for i in range(nb_samples):\n", | |
| " cur_len = len(X_seq[i])\n", | |
| " \n", | |
| " # We pad on the left with zeros, so for short sentences the first elements in the matrix are zeros\n", | |
| " X[i, maxlen - cur_len:] = X_seq[i]\n", | |
| " Y[i, maxlen - cur_len:, :] = Y_seq[i]\n", | |
| " \n", | |
| " return X, Y" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 21, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "X_train, Y_train = create_data_matrix(X_seq_train, Y_seq_train, nb_samples_train, maxlen, nb_classes)\n", | |
| "X_test, Y_test = create_data_matrix(X_seq_test, Y_seq_test, nb_samples_test, maxlen, nb_classes)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 22, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Train: (3750, 102) (3750, 102, 12)\n", | |
| "Test: (1250, 102) (1250, 102, 12)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "print('Train:', X_train.shape, Y_train.shape)\n", | |
| "print('Test:', X_test.shape, Y_test.shape)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 23, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[ 0., 0., 0., ..., 10213., 10598., 3185.],\n", | |
| " [ 0., 0., 0., ..., 11759., 12732., 3185.],\n", | |
| " [ 0., 0., 0., ..., 5503., 8139., 3185.],\n", | |
| " ..., \n", | |
| " [ 0., 0., 0., ..., 11759., 5088., 3185.],\n", | |
| " [ 0., 0., 0., ..., 9025., 7550., 3185.],\n", | |
| " [ 0., 0., 0., ..., 6795., 2419., 3185.]])" | |
| ] | |
| }, | |
| "execution_count": 23, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "X_train" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 24, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[[ 0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., ..., 0., 0., 0.],\n", | |
| " ..., \n", | |
| " [ 0., 0., 1., ..., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., ..., 0., 0., 0.]],\n", | |
| "\n", | |
| " [[ 0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., ..., 0., 0., 0.],\n", | |
| " ..., \n", | |
| " [ 0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., ..., 0., 0., 0.]],\n", | |
| "\n", | |
| " [[ 0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., ..., 0., 0., 0.],\n", | |
| " ..., \n", | |
| " [ 0., 0., 0., ..., 1., 0., 0.],\n", | |
| " [ 0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., ..., 0., 0., 0.]],\n", | |
| "\n", | |
| " ..., \n", | |
| " [[ 0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., ..., 0., 0., 0.],\n", | |
| " ..., \n", | |
| " [ 0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., ..., 0., 0., 0.]],\n", | |
| "\n", | |
| " [[ 0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., ..., 0., 0., 0.],\n", | |
| " ..., \n", | |
| " [ 0., 0., 1., ..., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., ..., 0., 0., 0.]],\n", | |
| "\n", | |
| " [[ 0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., ..., 0., 0., 0.],\n", | |
| " ..., \n", | |
| " [ 0., 0., 0., ..., 1., 0., 0.],\n", | |
| " [ 0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [ 0., 0., 0., ..., 0., 0., 0.]]])" | |
| ] | |
| }, | |
| "execution_count": 24, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "Y_train" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Build a Keras model" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 25, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "batch_size = 128\n", | |
| "nb_epoch = 15" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 26, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "model = Sequential()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 27, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Embedding layer converts the input sequences into a sequence of vectors with dim 200\n", | |
| "model.add(Embedding(output_dim=200, input_dim=input_dim, input_length=maxlen, mask_zero=True))\n", | |
| "\n", | |
| "# LSTM input: (nb_samples, timesteps, input_dim)\n", | |
| "model.add(LSTM(output_dim=128, return_sequences=True))\n", | |
| "\n", | |
| "# TimeDistributedDense is a fully-connected layer that applies the same weight matrix at each timestep\n", | |
| "model.add(TimeDistributedDense(output_dim=50, activation='sigmoid'))\n", | |
| "\n", | |
| "model.add(TimeDistributedDense(output_dim=nb_classes, activation='softmax'))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 28, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "model.compile(loss='categorical_crossentropy', optimizer='adam')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 29, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch 1/15\n", | |
| "3750/3750 [==============================] - 20s - loss: 2.2656 \n", | |
| "Epoch 2/15\n", | |
| "3750/3750 [==============================] - 20s - loss: 2.0316 \n", | |
| "Epoch 3/15\n", | |
| "3750/3750 [==============================] - 20s - loss: 1.7507 \n", | |
| "Epoch 4/15\n", | |
| "3750/3750 [==============================] - 20s - loss: 1.1898 \n", | |
| "Epoch 5/15\n", | |
| "3750/3750 [==============================] - 21s - loss: 0.7586 \n", | |
| "Epoch 6/15\n", | |
| "3750/3750 [==============================] - 23s - loss: 0.5090 \n", | |
| "Epoch 7/15\n", | |
| "3750/3750 [==============================] - 22s - loss: 0.3740 \n", | |
| "Epoch 8/15\n", | |
| "3750/3750 [==============================] - 23s - loss: 0.2862 \n", | |
| "Epoch 9/15\n", | |
| "3750/3750 [==============================] - 23s - loss: 0.2262 \n", | |
| "Epoch 10/15\n", | |
| "3750/3750 [==============================] - 20s - loss: 0.1852 \n", | |
| "Epoch 11/15\n", | |
| "3750/3750 [==============================] - 20s - loss: 0.1559 \n", | |
| "Epoch 12/15\n", | |
| "3750/3750 [==============================] - 20s - loss: 0.1339 \n", | |
| "Epoch 13/15\n", | |
| "3750/3750 [==============================] - 20s - loss: 0.1168 \n", | |
| "Epoch 14/15\n", | |
| "3750/3750 [==============================] - 20s - loss: 0.1029 \n", | |
| "Epoch 15/15\n", | |
| "3750/3750 [==============================] - 20s - loss: 0.0921 \n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "<keras.callbacks.History at 0x7f7a336b1278>" | |
| ] | |
| }, | |
| "execution_count": 29, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 30, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "1250/1250 [==============================] - 1s \n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "pred_classes = model.predict_classes(X_test, batch_size=batch_size)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 31, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "pred_classes: (1250, 102)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "print('pred_classes:', pred_classes.shape)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 32, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[4, 4, 4, ..., 9, 4, 7],\n", | |
| " [4, 4, 4, ..., 3, 5, 7],\n", | |
| " [4, 4, 4, ..., 3, 4, 7],\n", | |
| " ..., \n", | |
| " [4, 4, 4, ..., 4, 4, 7],\n", | |
| " [4, 4, 4, ..., 2, 4, 7],\n", | |
| " [4, 4, 4, ..., 4, 4, 7]])" | |
| ] | |
| }, | |
| "execution_count": 32, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "pred_classes" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Calculate accuracy" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 33, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "X_seq_test_len = [len(seq) for seq in X_seq_test]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 34, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# concatenate all sequences together to get one long list of labels\n", | |
| "# we do not want to consider the starting padding, so we need to get only the actual sentences\n", | |
| "y_true = [Y_seq_test[i][j].argmax() for i in range(nb_samples_test) for j in range(X_seq_test_len[i])]\n", | |
| "y_pred = [pred_classes[i,maxlen - X_seq_test_len[i] + j] for i in range(nb_samples_test) for j in range(X_seq_test_len[i])]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 35, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "y true: 26614\n", | |
| "y pred: 26614\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "print('y true:', len(y_true))\n", | |
| "print('y pred:', len(y_pred))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 36, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "acc = accuracy_score(y_true, y_pred)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 37, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Accuracy: 0.941121214398\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "print('Accuracy:', acc)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.4.3" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 0 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment