Last active
April 5, 2020 00:42
-
-
Save yufengg/a6dff912ab48f7a273f5704ad9ab1311 to your computer and use it in GitHub Desktop.
Jupyter notebook for AI Adventures episode 3
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Copyright 2017 Google Inc.\n", | |
| "#\n", | |
| "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", | |
| "# you may not use this file except in compliance with the License.\n", | |
| "# You may obtain a copy of the License at\n", | |
| "#\n", | |
| "# http://www.apache.org/licenses/LICENSE-2.0\n", | |
| "#\n", | |
| "# Unless required by applicable law or agreed to in writing, software\n", | |
| "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", | |
| "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", | |
| "# See the License for the specific language governing permissions and\n", | |
| "# limitations under the License." | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Overview\n", | |
| "All the code we'll look at is in the next cell; the cells after it walk through that code one step at a time." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "scrolled": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "import tensorflow as tf\n", | |
| "import numpy as np\n", | |
| "\n", | |
| "print(tf.__version__)\n", | |
| "\n", | |
| "from tensorflow.contrib.learn.python.learn.datasets import base\n", | |
| "\n", | |
| "# Data files\n", | |
| "IRIS_TRAINING = \"iris_training.csv\"\n", | |
| "IRIS_TEST = \"iris_test.csv\"\n", | |
| "\n", | |
| "# Load datasets.\n", | |
| "training_set = base.load_csv_with_header(filename=IRIS_TRAINING,\n", | |
| " features_dtype=np.float32,\n", | |
| " target_dtype=np.int)\n", | |
| "test_set = base.load_csv_with_header(filename=IRIS_TEST,\n", | |
| " features_dtype=np.float32,\n", | |
| " target_dtype=np.int)\n", | |
| "\n", | |
| "# Specify that all features have real-value data\n", | |
| "feature_name = \"flower_features\"\n", | |
| "feature_columns = [tf.feature_column.numeric_column(feature_name, \n", | |
| " shape=[4])]\n", | |
| "classifier = tf.estimator.LinearClassifier(\n", | |
| " feature_columns=feature_columns,\n", | |
| " n_classes=3,\n", | |
| " model_dir=\"/tmp/iris_model\")\n", | |
| "\n", | |
| "def input_fn(dataset):\n", | |
| " def _fn():\n", | |
| " features = {feature_name: tf.constant(dataset.data)}\n", | |
| " label = tf.constant(dataset.target)\n", | |
| " return features, label\n", | |
| " return _fn\n", | |
| "\n", | |
| "# Fit model.\n", | |
| "classifier.train(input_fn=input_fn(training_set),\n", | |
| " steps=1000)\n", | |
| "print('fit done')\n", | |
| "\n", | |
| "# Evaluate accuracy.\n", | |
| "accuracy_score = classifier.evaluate(input_fn=input_fn(test_set), \n", | |
| " steps=100)[\"accuracy\"]\n", | |
| "print('\\nAccuracy: {0:f}'.format(accuracy_score))\n", | |
| "\n", | |
| "# Export the model for serving\n", | |
| "feature_spec = {'flower_features': tf.FixedLenFeature(shape=[4], dtype=np.float32)}\n", | |
| "\n", | |
| "serving_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)\n", | |
| "\n", | |
| "classifier.export_savedmodel(export_dir_base='/tmp/iris_model' + '/export', \n", | |
| " serving_input_receiver_fn=serving_fn)\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Imports" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "1.3.0\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "import tensorflow as tf\n", | |
| "import numpy as np\n", | |
| "\n", | |
| "print(tf.__version__)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Data set\n", | |
| "From https://en.wikipedia.org/wiki/Iris_flower_data_set\n", | |
| "\n", | |
| "The data set covers 3 species of iris flower:\n", | |
| "\n", | |
| "<img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/5/56/Kosaciec_szczecinkowaty_Iris_setosa.jpg/450px-Kosaciec_szczecinkowaty_Iris_setosa.jpg\" style=\"width: 100px; display:inline\"/>\n", | |
| "<img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/4/41/Iris_versicolor_3.jpg/800px-Iris_versicolor_3.jpg\" style=\"width: 150px;display:inline\"/>\n", | |
| "<img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/9/9f/Iris_virginica.jpg/736px-Iris_virginica.jpg\" style=\"width: 150px;display:inline\"/>\n", | |
| "* Iris Setosa\n", | |
| "* Iris Versicolour\n", | |
| "* Iris Virginica\n", | |
| "\n", | |
| "\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Data Columns:\n", | |
| " 1. sepal length in cm \n", | |
| " 2. sepal width in cm \n", | |
| " 3. petal length in cm \n", | |
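| " 4. petal width in cm\n", | |
| "\n", | |
| "Each row also has a label: the species, encoded as an integer class (0, 1, or 2).\n", | |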
| "\n", | |
| "<img src=\"petal_sepal.png\" style=\"width: 200px;\"/>\n", | |
| "<img src=\"https://storage.googleapis.com/image-uploader/AIA_images/data_table.png\" style=\"width: 450px\"/>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Load the data" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": { | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[[ 6.4000001 2.79999995 5.5999999 2.20000005]\n", | |
| " [ 5. 2.29999995 3.29999995 1. ]\n", | |
| " [ 4.9000001 2.5 4.5 1.70000005]\n", | |
| " [ 4.9000001 3.0999999 1.5 0.1 ]\n", | |
| " [ 5.69999981 3.79999995 1.70000005 0.30000001]\n", | |
| " [ 4.4000001 3.20000005 1.29999995 0.2 ]\n", | |
| " [ 5.4000001 3.4000001 1.5 0.40000001]\n", | |
| " [ 6.9000001 3.0999999 5.0999999 2.29999995]\n", | |
| " [ 6.69999981 3.0999999 4.4000001 1.39999998]\n", | |
| " [ 5.0999999 3.70000005 1.5 0.40000001]\n", | |
| " [ 5.19999981 2.70000005 3.9000001 1.39999998]\n", | |
| " [ 6.9000001 3.0999999 4.9000001 1.5 ]\n", | |
| " [ 5.80000019 4. 1.20000005 0.2 ]\n", | |
| " [ 5.4000001 3.9000001 1.70000005 0.40000001]\n", | |
| " [ 7.69999981 3.79999995 6.69999981 2.20000005]\n", | |
| " [ 6.30000019 3.29999995 4.69999981 1.60000002]\n", | |
| " [ 6.80000019 3.20000005 5.9000001 2.29999995]\n", | |
| " [ 7.5999999 3. 6.5999999 2.0999999 ]\n", | |
| " [ 6.4000001 3.20000005 5.30000019 2.29999995]\n", | |
| " [ 5.69999981 4.4000001 1.5 0.40000001]\n", | |
| " [ 6.69999981 3.29999995 5.69999981 2.0999999 ]\n", | |
| " [ 6.4000001 2.79999995 5.5999999 2.0999999 ]\n", | |
| " [ 5.4000001 3.9000001 1.29999995 0.40000001]\n", | |
| " [ 6.0999999 2.5999999 5.5999999 1.39999998]\n", | |
| " [ 7.19999981 3. 5.80000019 1.60000002]\n", | |
| " [ 5.19999981 3.5 1.5 0.2 ]\n", | |
| " [ 5.80000019 2.5999999 4. 1.20000005]\n", | |
| " [ 5.9000001 3. 5.0999999 1.79999995]\n", | |
| " [ 5.4000001 3. 4.5 1.5 ]\n", | |
| " [ 6.69999981 3. 5. 1.70000005]\n", | |
| " [ 6.30000019 2.29999995 4.4000001 1.29999995]\n", | |
| " [ 5.0999999 2.5 3. 1.10000002]\n", | |
| " [ 6.4000001 3.20000005 4.5 1.5 ]\n", | |
| " [ 6.80000019 3. 5.5 2.0999999 ]\n", | |
| " [ 6.19999981 2.79999995 4.80000019 1.79999995]\n", | |
| " [ 6.9000001 3.20000005 5.69999981 2.29999995]\n", | |
| " [ 6.5 3.20000005 5.0999999 2. ]\n", | |
| " [ 5.80000019 2.79999995 5.0999999 2.4000001 ]\n", | |
| " [ 5.0999999 3.79999995 1.5 0.30000001]\n", | |
| " [ 4.80000019 3. 1.39999998 0.30000001]\n", | |
| " [ 7.9000001 3.79999995 6.4000001 2. ]\n", | |
| " [ 5.80000019 2.70000005 5.0999999 1.89999998]\n", | |
| " [ 6.69999981 3. 5.19999981 2.29999995]\n", | |
| " [ 5.0999999 3.79999995 1.89999998 0.40000001]\n", | |
| " [ 4.69999981 3.20000005 1.60000002 0.2 ]\n", | |
| " [ 6. 2.20000005 5. 1.5 ]\n", | |
| " [ 4.80000019 3.4000001 1.60000002 0.2 ]\n", | |
| " [ 7.69999981 2.5999999 6.9000001 2.29999995]\n", | |
| " [ 4.5999999 3.5999999 1. 0.2 ]\n", | |
| " [ 7.19999981 3.20000005 6. 1.79999995]\n", | |
| " [ 5. 3.29999995 1.39999998 0.2 ]\n", | |
| " [ 6.5999999 3. 4.4000001 1.39999998]\n", | |
| " [ 6.0999999 2.79999995 4. 1.29999995]\n", | |
| " [ 5. 3.20000005 1.20000005 0.2 ]\n", | |
| " [ 7. 3.20000005 4.69999981 1.39999998]\n", | |
| " [ 6. 3. 4.80000019 1.79999995]\n", | |
| " [ 7.4000001 2.79999995 6.0999999 1.89999998]\n", | |
| " [ 5.80000019 2.70000005 5.0999999 1.89999998]\n", | |
| " [ 6.19999981 3.4000001 5.4000001 2.29999995]\n", | |
| " [ 5. 2. 3.5 1. ]\n", | |
| " [ 5.5999999 2.5 3.9000001 1.10000002]\n", | |
| " [ 6.69999981 3.0999999 5.5999999 2.4000001 ]\n", | |
| " [ 6.30000019 2.5 5. 1.89999998]\n", | |
| " [ 6.4000001 3.0999999 5.5 1.79999995]\n", | |
| " [ 6.19999981 2.20000005 4.5 1.5 ]\n", | |
| " [ 7.30000019 2.9000001 6.30000019 1.79999995]\n", | |
| " [ 4.4000001 3. 1.29999995 0.2 ]\n", | |
| " [ 7.19999981 3.5999999 6.0999999 2.5 ]\n", | |
| " [ 6.5 3. 5.5 1.79999995]\n", | |
| " [ 5. 3.4000001 1.5 0.2 ]\n", | |
| " [ 4.69999981 3.20000005 1.29999995 0.2 ]\n", | |
| " [ 6.5999999 2.9000001 4.5999999 1.29999995]\n", | |
| " [ 5.5 3.5 1.29999995 0.2 ]\n", | |
| " [ 7.69999981 3. 6.0999999 2.29999995]\n", | |
| " [ 6.0999999 3. 4.9000001 1.79999995]\n", | |
| " [ 4.9000001 3.0999999 1.5 0.1 ]\n", | |
| " [ 5.5 2.4000001 3.79999995 1.10000002]\n", | |
| " [ 5.69999981 2.9000001 4.19999981 1.29999995]\n", | |
| " [ 6. 2.9000001 4.5 1.5 ]\n", | |
| " [ 6.4000001 2.70000005 5.30000019 1.89999998]\n", | |
| " [ 5.4000001 3.70000005 1.5 0.2 ]\n", | |
| " [ 6.0999999 2.9000001 4.69999981 1.39999998]\n", | |
| " [ 6.5 2.79999995 4.5999999 1.5 ]\n", | |
| " [ 5.5999999 2.70000005 4.19999981 1.29999995]\n", | |
| " [ 6.30000019 3.4000001 5.5999999 2.4000001 ]\n", | |
| " [ 4.9000001 3.0999999 1.5 0.1 ]\n", | |
| " [ 6.80000019 2.79999995 4.80000019 1.39999998]\n", | |
| " [ 5.69999981 2.79999995 4.5 1.29999995]\n", | |
| " [ 6. 2.70000005 5.0999999 1.60000002]\n", | |
| " [ 5. 3.5 1.29999995 0.30000001]\n", | |
| " [ 6.5 3. 5.19999981 2. ]\n", | |
| " [ 6.0999999 2.79999995 4.69999981 1.20000005]\n", | |
| " [ 5.0999999 3.5 1.39999998 0.30000001]\n", | |
| " [ 4.5999999 3.0999999 1.5 0.2 ]\n", | |
| " [ 6.5 3. 5.80000019 2.20000005]\n", | |
| " [ 4.5999999 3.4000001 1.39999998 0.30000001]\n", | |
| " [ 4.5999999 3.20000005 1.39999998 0.2 ]\n", | |
| " [ 7.69999981 2.79999995 6.69999981 2. ]\n", | |
| " [ 5.9000001 3.20000005 4.80000019 1.79999995]\n", | |
| " [ 5.0999999 3.79999995 1.60000002 0.2 ]\n", | |
| " [ 4.9000001 3. 1.39999998 0.2 ]\n", | |
| " [ 4.9000001 2.4000001 3.29999995 1. ]\n", | |
| " [ 4.5 2.29999995 1.29999995 0.30000001]\n", | |
| " [ 5.80000019 2.70000005 4.0999999 1. ]\n", | |
| " [ 5. 3.4000001 1.60000002 0.40000001]\n", | |
| " [ 5.19999981 3.4000001 1.39999998 0.2 ]\n", | |
| " [ 5.30000019 3.70000005 1.5 0.2 ]\n", | |
| " [ 5. 3.5999999 1.39999998 0.2 ]\n", | |
| " [ 5.5999999 2.9000001 3.5999999 1.29999995]\n", | |
| " [ 4.80000019 3.0999999 1.60000002 0.2 ]\n", | |
| " [ 6.30000019 2.70000005 4.9000001 1.79999995]\n", | |
| " [ 5.69999981 2.79999995 4.0999999 1.29999995]\n", | |
| " [ 5. 3. 1.60000002 0.2 ]\n", | |
| " [ 6.30000019 3.29999995 6. 2.5 ]\n", | |
| " [ 5. 3.5 1.60000002 0.60000002]\n", | |
| " [ 5.5 2.5999999 4.4000001 1.20000005]\n", | |
| " [ 5.69999981 3. 4.19999981 1.20000005]\n", | |
| " [ 4.4000001 2.9000001 1.39999998 0.2 ]\n", | |
| " [ 4.80000019 3. 1.39999998 0.1 ]\n", | |
| " [ 5.5 2.4000001 3.70000005 1. ]]\n", | |
| "[2 1 2 0 0 0 0 2 1 0 1 1 0 0 2 1 2 2 2 0 2 2 0 2 2 0 1 2 1 1 1 1 1 2 2 2 2\n", | |
| " 2 0 0 2 2 2 0 0 2 0 2 0 2 0 1 1 0 1 2 2 2 2 1 1 2 2 2 1 2 0 2 2 0 0 1 0 2\n", | |
| " 2 0 1 1 1 2 0 1 1 1 2 0 1 1 1 0 2 1 0 0 2 0 0 2 1 0 0 1 0 1 0 0 0 0 1 0 2\n", | |
| " 1 0 2 0 1 1 0 0 1]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "from tensorflow.contrib.learn.python.learn.datasets import base\n", | |
| "\n", | |
| "# Data files\n", | |
| "IRIS_TRAINING = \"iris_training.csv\"\n", | |
| "IRIS_TEST = \"iris_test.csv\"\n", | |
| "\n", | |
| "# Load datasets.\n", | |
| "training_set = base.load_csv_with_header(filename=IRIS_TRAINING,\n", | |
| " features_dtype=np.float32,\n", | |
| " target_dtype=np.int)\n", | |
| "test_set = base.load_csv_with_header(filename=IRIS_TEST,\n", | |
| " features_dtype=np.float32,\n", | |
| " target_dtype=np.int)\n", | |
| "\n", | |
| "print(training_set.data)\n", | |
| "\n", | |
| "print(training_set.target)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Feature columns and model creation" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "INFO:tensorflow:Using default config.\n", | |
| "INFO:tensorflow:Using config: {'_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_tf_random_seed': 1, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_save_checkpoints_steps': None, '_model_dir': '/tmp/iris_model', '_save_summary_steps': 100}\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Specify that all features have real-value data\n", | |
| "feature_name = \"flower_features\"\n", | |
| "feature_columns = [tf.feature_column.numeric_column(feature_name, \n", | |
| " shape=[4])]\n", | |
| "\n", | |
| "classifier = tf.estimator.LinearClassifier(\n", | |
| " feature_columns=feature_columns,\n", | |
| " n_classes=3,\n", | |
| " model_dir=\"/tmp/iris_model\")\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Input function" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "({'flower_features': <tf.Tensor 'Const:0' shape=(120, 4) dtype=float32>}, <tf.Tensor 'Const_1:0' shape=(120,) dtype=int64>)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "def input_fn(dataset):\n", | |
| " def _fn():\n", | |
| " features = {feature_name: tf.constant(dataset.data)}\n", | |
| " label = tf.constant(dataset.target)\n", | |
| " return features, label\n", | |
| " return _fn\n", | |
| "\n", | |
| "print(input_fn(training_set)())\n", | |
| "\n", | |
| "# raw data -> input function -> feature columns -> model" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Training" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": { | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "INFO:tensorflow:Create CheckpointSaverHook.\n", | |
| "INFO:tensorflow:Restoring parameters from /tmp/iris_model/model.ckpt-3000\n", | |
| "INFO:tensorflow:Saving checkpoints for 3001 into /tmp/iris_model/model.ckpt.\n", | |
| "INFO:tensorflow:loss = 8.50027, step = 3001\n", | |
| "INFO:tensorflow:global_step/sec: 827.439\n", | |
| "INFO:tensorflow:loss = 8.40806, step = 3101 (0.123 sec)\n", | |
| "INFO:tensorflow:global_step/sec: 868.825\n", | |
| "INFO:tensorflow:loss = 8.32063, step = 3201 (0.116 sec)\n", | |
| "INFO:tensorflow:global_step/sec: 959.112\n", | |
| "INFO:tensorflow:loss = 8.23757, step = 3301 (0.104 sec)\n", | |
| "INFO:tensorflow:global_step/sec: 844.444\n", | |
| "INFO:tensorflow:loss = 8.15855, step = 3401 (0.118 sec)\n", | |
| "INFO:tensorflow:global_step/sec: 847.278\n", | |
| "INFO:tensorflow:loss = 8.08324, step = 3501 (0.118 sec)\n", | |
| "INFO:tensorflow:global_step/sec: 825.594\n", | |
| "INFO:tensorflow:loss = 8.01139, step = 3601 (0.120 sec)\n", | |
| "INFO:tensorflow:global_step/sec: 882.98\n", | |
| "INFO:tensorflow:loss = 7.94273, step = 3701 (0.114 sec)\n", | |
| "INFO:tensorflow:global_step/sec: 941.876\n", | |
| "INFO:tensorflow:loss = 7.87704, step = 3801 (0.106 sec)\n", | |
| "INFO:tensorflow:global_step/sec: 889.862\n", | |
| "INFO:tensorflow:loss = 7.81412, step = 3901 (0.112 sec)\n", | |
| "INFO:tensorflow:Saving checkpoints for 4000 into /tmp/iris_model/model.ckpt.\n", | |
| "INFO:tensorflow:Loss for final step: 7.75437.\n", | |
| "fit done\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Fit model.\n", | |
| "classifier.train(input_fn=input_fn(training_set),\n", | |
| " steps=1000)\n", | |
| "print('fit done')\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Evaluation" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "INFO:tensorflow:Starting evaluation at 2017-09-14-15:16:40\n", | |
| "INFO:tensorflow:Restoring parameters from /tmp/iris_model/model.ckpt-4000\n", | |
| "INFO:tensorflow:Evaluation [1/100]\n", | |
| "INFO:tensorflow:Evaluation [2/100]\n", | |
| "INFO:tensorflow:Evaluation [3/100]\n", | |
| "INFO:tensorflow:Evaluation [4/100]\n", | |
| "INFO:tensorflow:Evaluation [5/100]\n", | |
| "INFO:tensorflow:Evaluation [6/100]\n", | |
| "INFO:tensorflow:Evaluation [7/100]\n", | |
| "INFO:tensorflow:Evaluation [8/100]\n", | |
| "INFO:tensorflow:Evaluation [9/100]\n", | |
| "INFO:tensorflow:Evaluation [10/100]\n", | |
| "INFO:tensorflow:Evaluation [11/100]\n", | |
| "INFO:tensorflow:Evaluation [12/100]\n", | |
| "INFO:tensorflow:Evaluation [13/100]\n", | |
| "INFO:tensorflow:Evaluation [14/100]\n", | |
| "INFO:tensorflow:Evaluation [15/100]\n", | |
| "INFO:tensorflow:Evaluation [16/100]\n", | |
| "INFO:tensorflow:Evaluation [17/100]\n", | |
| "INFO:tensorflow:Evaluation [18/100]\n", | |
| "INFO:tensorflow:Evaluation [19/100]\n", | |
| "INFO:tensorflow:Evaluation [20/100]\n", | |
| "INFO:tensorflow:Evaluation [21/100]\n", | |
| "INFO:tensorflow:Evaluation [22/100]\n", | |
| "INFO:tensorflow:Evaluation [23/100]\n", | |
| "INFO:tensorflow:Evaluation [24/100]\n", | |
| "INFO:tensorflow:Evaluation [25/100]\n", | |
| "INFO:tensorflow:Evaluation [26/100]\n", | |
| "INFO:tensorflow:Evaluation [27/100]\n", | |
| "INFO:tensorflow:Evaluation [28/100]\n", | |
| "INFO:tensorflow:Evaluation [29/100]\n", | |
| "INFO:tensorflow:Evaluation [30/100]\n", | |
| "INFO:tensorflow:Evaluation [31/100]\n", | |
| "INFO:tensorflow:Evaluation [32/100]\n", | |
| "INFO:tensorflow:Evaluation [33/100]\n", | |
| "INFO:tensorflow:Evaluation [34/100]\n", | |
| "INFO:tensorflow:Evaluation [35/100]\n", | |
| "INFO:tensorflow:Evaluation [36/100]\n", | |
| "INFO:tensorflow:Evaluation [37/100]\n", | |
| "INFO:tensorflow:Evaluation [38/100]\n", | |
| "INFO:tensorflow:Evaluation [39/100]\n", | |
| "INFO:tensorflow:Evaluation [40/100]\n", | |
| "INFO:tensorflow:Evaluation [41/100]\n", | |
| "INFO:tensorflow:Evaluation [42/100]\n", | |
| "INFO:tensorflow:Evaluation [43/100]\n", | |
| "INFO:tensorflow:Evaluation [44/100]\n", | |
| "INFO:tensorflow:Evaluation [45/100]\n", | |
| "INFO:tensorflow:Evaluation [46/100]\n", | |
| "INFO:tensorflow:Evaluation [47/100]\n", | |
| "INFO:tensorflow:Evaluation [48/100]\n", | |
| "INFO:tensorflow:Evaluation [49/100]\n", | |
| "INFO:tensorflow:Evaluation [50/100]\n", | |
| "INFO:tensorflow:Evaluation [51/100]\n", | |
| "INFO:tensorflow:Evaluation [52/100]\n", | |
| "INFO:tensorflow:Evaluation [53/100]\n", | |
| "INFO:tensorflow:Evaluation [54/100]\n", | |
| "INFO:tensorflow:Evaluation [55/100]\n", | |
| "INFO:tensorflow:Evaluation [56/100]\n", | |
| "INFO:tensorflow:Evaluation [57/100]\n", | |
| "INFO:tensorflow:Evaluation [58/100]\n", | |
| "INFO:tensorflow:Evaluation [59/100]\n", | |
| "INFO:tensorflow:Evaluation [60/100]\n", | |
| "INFO:tensorflow:Evaluation [61/100]\n", | |
| "INFO:tensorflow:Evaluation [62/100]\n", | |
| "INFO:tensorflow:Evaluation [63/100]\n", | |
| "INFO:tensorflow:Evaluation [64/100]\n", | |
| "INFO:tensorflow:Evaluation [65/100]\n", | |
| "INFO:tensorflow:Evaluation [66/100]\n", | |
| "INFO:tensorflow:Evaluation [67/100]\n", | |
| "INFO:tensorflow:Evaluation [68/100]\n", | |
| "INFO:tensorflow:Evaluation [69/100]\n", | |
| "INFO:tensorflow:Evaluation [70/100]\n", | |
| "INFO:tensorflow:Evaluation [71/100]\n", | |
| "INFO:tensorflow:Evaluation [72/100]\n", | |
| "INFO:tensorflow:Evaluation [73/100]\n", | |
| "INFO:tensorflow:Evaluation [74/100]\n", | |
| "INFO:tensorflow:Evaluation [75/100]\n", | |
| "INFO:tensorflow:Evaluation [76/100]\n", | |
| "INFO:tensorflow:Evaluation [77/100]\n", | |
| "INFO:tensorflow:Evaluation [78/100]\n", | |
| "INFO:tensorflow:Evaluation [79/100]\n", | |
| "INFO:tensorflow:Evaluation [80/100]\n", | |
| "INFO:tensorflow:Evaluation [81/100]\n", | |
| "INFO:tensorflow:Evaluation [82/100]\n", | |
| "INFO:tensorflow:Evaluation [83/100]\n", | |
| "INFO:tensorflow:Evaluation [84/100]\n", | |
| "INFO:tensorflow:Evaluation [85/100]\n", | |
| "INFO:tensorflow:Evaluation [86/100]\n", | |
| "INFO:tensorflow:Evaluation [87/100]\n", | |
| "INFO:tensorflow:Evaluation [88/100]\n", | |
| "INFO:tensorflow:Evaluation [89/100]\n", | |
| "INFO:tensorflow:Evaluation [90/100]\n", | |
| "INFO:tensorflow:Evaluation [91/100]\n", | |
| "INFO:tensorflow:Evaluation [92/100]\n", | |
| "INFO:tensorflow:Evaluation [93/100]\n", | |
| "INFO:tensorflow:Evaluation [94/100]\n", | |
| "INFO:tensorflow:Evaluation [95/100]\n", | |
| "INFO:tensorflow:Evaluation [96/100]\n", | |
| "INFO:tensorflow:Evaluation [97/100]\n", | |
| "INFO:tensorflow:Evaluation [98/100]\n", | |
| "INFO:tensorflow:Evaluation [99/100]\n", | |
| "INFO:tensorflow:Evaluation [100/100]\n", | |
| "INFO:tensorflow:Finished evaluation at 2017-09-14-15:16:41\n", | |
| "INFO:tensorflow:Saving dict for global step 4000: accuracy = 0.966667, average_loss = 0.0706094, global_step = 4000, loss = 2.11828\n", | |
| "\n", | |
| "Accuracy: 0.966667\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Evaluate accuracy.\n", | |
| "accuracy_score = classifier.evaluate(input_fn=input_fn(test_set), \n", | |
| " steps=100)[\"accuracy\"]\n", | |
| "print('\\nAccuracy: {0:f}'.format(accuracy_score))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "source": [ | |
| "# Estimators review\n", | |
| "\n", | |
| "### Load datasets.\n", | |
| "\n", | |
| " training_data = load_csv_with_header()\n", | |
| "\n", | |
| "### Define input functions\n", | |
| "\n", | |
| " def input_fn(dataset)\n", | |
| " \n", | |
| "### Define feature columns\n", | |
| "\n", | |
| " feature_columns = [tf.feature_column.numeric_column(feature_name, shape=[4])]\n", | |
| "\n", | |
| "### Create model\n", | |
| "\n", | |
| " classifier = tf.estimator.LinearClassifier()\n", | |
| "\n", | |
| "### Train\n", | |
| "\n", | |
| " classifier.train()\n", | |
| "\n", | |
| "### Evaluate\n", | |
| "\n", | |
| " classifier.evaluate()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Exporting a model for serving predictions\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "INFO:tensorflow:Restoring parameters from /tmp/iris_model/model.ckpt-4000\n", | |
| "INFO:tensorflow:Assets added to graph.\n", | |
| "INFO:tensorflow:No assets to write.\n", | |
| "INFO:tensorflow:SavedModel written to: /tmp/iris_model/export/1505402201/saved_model.pb\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "'/tmp/iris_model/export/1505402201'" | |
| ] | |
| }, | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "feature_spec = {'flower_features': tf.FixedLenFeature(shape=[4], dtype=np.float32)}\n", | |
| "\n", | |
| "serving_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)\n", | |
| "\n", | |
| "classifier.export_savedmodel(export_dir_base='/tmp/iris_model' + '/export', \n", | |
| " serving_input_receiver_fn=serving_fn)\n", | |
| "\n", | |
| "\n", | |
| "\n" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 2", | |
| "language": "python", | |
| "name": "python2" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 2 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython2", | |
| "version": "2.7.13" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment