Created
September 28, 2019 12:18
-
-
Save leifvan/1f1a8c957dcf9f3b76bfd3bb687f88b7 to your computer and use it in GitHub Desktop.
ELMo pooling issue
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "name": "ELMo pooling issue", | |
| "provenance": [], | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| } | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/leifvan/1f1a8c957dcf9f3b76bfd3bb687f88b7/elmo-pooling-issue.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "JOebV9VKJ9Rv", | |
| "colab_type": "code", | |
| "colab": {} | |
| }, | |
| "source": [ | |
| "import tensorflow as tf\n", | |
| "import tensorflow_hub as hub\n", | |
| "import numpy as np\n", | |
| "from numpy.linalg import norm" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "0qoSkWezKEdg", | |
| "colab_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 84 | |
| }, | |
| "outputId": "67e7e0dd-cf1a-4ff4-d6fe-9d6eaee0d244" | |
| }, | |
| "source": [ | |
| "model = hub.Module('https://tfhub.dev/google/elmo/2')\n", | |
| "input_placeholder = tf.placeholder(tf.string, shape=(None,))\n", | |
| "embed_default = model(input_placeholder, signature=\"default\", as_dict=True)['default']\n", | |
| "embed_elmo = model(input_placeholder, signature=\"default\", as_dict=True)['elmo']" | |
| ], | |
| "execution_count": 9, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "INFO:tensorflow:Saver not created because there are no variables in the graph to restore\n" | |
| ], | |
| "name": "stdout" | |
| }, | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "INFO:tensorflow:Saver not created because there are no variables in the graph to restore\n" | |
| ], | |
| "name": "stderr" | |
| }, | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "INFO:tensorflow:Saver not created because there are no variables in the graph to restore\n" | |
| ], | |
| "name": "stdout" | |
| }, | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "INFO:tensorflow:Saver not created because there are no variables in the graph to restore\n" | |
| ], | |
| "name": "stderr" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "ZV6IxVJbKIi8", | |
| "colab_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 101 | |
| }, | |
| "outputId": "3f869d8e-6dbb-43f9-bd40-a18a4e8fec09" | |
| }, | |
| "source": [ | |
| "token = \"word\"\n", | |
| "sent_1 = \"A short sentence.\"\n", | |
| "sent_2 = \"Another short sentence.\"\n", | |
| "sent_3 = \"A sentence that is quite a bit longer than the previous sentences.\"\n", | |
| "with tf.Session() as sess:\n", | |
| " sess.run(tf.global_variables_initializer())\n", | |
| " vecs_1 = sess.run(embed_default, feed_dict={input_placeholder: [token]})\n", | |
| " vecs_2 = sess.run(embed_default, feed_dict={input_placeholder: [token, sent_1]})\n", | |
| " vecs_3 = sess.run(embed_default, feed_dict={input_placeholder: [token, sent_1, sent_2]})\n", | |
| " vecs_4 = sess.run(embed_default, feed_dict={input_placeholder: [token, sent_1, sent_2, sent_3]})\n", | |
| "\n", | |
| "# we would expect that the vector of `token` is equal for all runs, but it is not\n", | |
| "print('1 vs 2:', norm(vecs_1[0]-vecs_2[0]))\n", | |
| "print('1 vs 3:', norm(vecs_1[0]-vecs_3[0]))\n", | |
| "print('1 vs 4:', norm(vecs_1[0]-vecs_4[0]))\n", | |
| "print('but:')\n", | |
| "print('2 vs 3:', norm(vecs_2[0]-vecs_3[0]))\n", | |
| "\n", | |
| "# the last print hints at the fact, that the change in value depends on the\n", | |
| "# maximum number of tokens " | |
| ], | |
| "execution_count": 10, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "1 vs 2: 3.234873\n", | |
| "1 vs 3: 3.234873\n", | |
| "1 vs 4: 17.791801\n", | |
| "but:\n", | |
| "2 vs 3: 5.242001e-06\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "5VcF6Nj7MBcJ", | |
| "colab_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 101 | |
| }, | |
| "outputId": "b5c9dd05-674f-497b-e5f2-fa7b396e4ae5" | |
| }, | |
| "source": [ | |
| "# we can fix it with manually calculating the mean pooling\n", | |
| "def pool_vecs(arr, lengths):\n", | |
| " out = []\n", | |
| " for vec, length in zip(arr, lengths):\n", | |
| " out.append(np.sum(vec[:length], axis=0)/length)\n", | |
| " return out\n", | |
| "\n", | |
| "with tf.Session() as sess:\n", | |
| " sess.run(tf.global_variables_initializer())\n", | |
| " raw_vecs_1 = sess.run(embed_elmo, feed_dict={input_placeholder: [token]})\n", | |
| " raw_vecs_2 = sess.run(embed_elmo, feed_dict={input_placeholder: [token, sent_1]})\n", | |
| " raw_vecs_3 = sess.run(embed_elmo, feed_dict={input_placeholder: [token, sent_1, sent_2]})\n", | |
| " raw_vecs_4 = sess.run(embed_elmo, feed_dict={input_placeholder: [token, sent_1, sent_2, sent_3]})\n", | |
| " manual_vecs_1 = pool_vecs(raw_vecs_1, [1])\n", | |
| " manual_vecs_2 = pool_vecs(raw_vecs_2, [1,3])\n", | |
| " manual_vecs_3 = pool_vecs(raw_vecs_3, [1,3,3])\n", | |
| " manual_vecs_4 = pool_vecs(raw_vecs_4, [1,3,3,12])\n", | |
| " \n", | |
| "# now it works as expected\n", | |
| "print('1 vs 2:', norm(manual_vecs_1[0]-manual_vecs_2[0]))\n", | |
| "print('1 vs 3:', norm(manual_vecs_1[0]-manual_vecs_3[0]))\n", | |
| "print('1 vs 4:', norm(manual_vecs_1[0]-manual_vecs_4[0]))\n", | |
| "print()\n", | |
| "print('2 vs 3:', norm(manual_vecs_2[0]-manual_vecs_3[0]))" | |
| ], | |
| "execution_count": 11, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "1 vs 2: 8.415249e-06\n", | |
| "1 vs 3: 8.378789e-06\n", | |
| "1 vs 4: 9.699437e-06\n", | |
| "\n", | |
| "2 vs 3: 5.1718102e-06\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "1gwD625dR_iw", | |
| "colab_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 302 | |
| }, | |
| "outputId": "1d041ac9-ef81-4e42-cce2-8b39bf14f685" | |
| }, | |
| "source": [ | |
| "# the issue seems to stem from the assumption that elmo outputs zeros for the\n", | |
| "# padded parts of the input, but that is not true\n", | |
| "\n", | |
| "# e.g. in raw_vecs_2 we get 3 vectors to pool for both inputs, because the\n", | |
| "# longest input has 3 tokens.\n", | |
| "print(\"Shape of raw_vecs_2:\")\n", | |
| "print(raw_vecs_2.shape)\n", | |
| "print()\n", | |
| "\n", | |
| "# We would expect that for the first sentence (that consists of one token)\n", | |
| "# the two vectors of the padded part should be 0, but in fact they are not\n", | |
| "print('Valid part of output, should be >0')\n", | |
| "print(norm(raw_vecs_2[0,0]))\n", | |
| "print('Invalid part of output, should be =0')\n", | |
| "print(norm(raw_vecs_2[0,1:]))\n", | |
| "print()\n", | |
| "\n", | |
| "# furthermore, the value of the valid part stays the same, so its not like\n", | |
| "# that the information is somehow distributed over all vectors\n", | |
| "print('distances of the valid parts:')\n", | |
| "print(norm(raw_vecs_1[0,0]-raw_vecs_2[0,0]))\n", | |
| "print(norm(raw_vecs_1[0,0]-raw_vecs_3[0,0]))\n", | |
| "print(norm(raw_vecs_1[0,0]-raw_vecs_4[0,0]))\n", | |
| "print()\n", | |
| "\n", | |
| "# in fact, we can see that this was not considered in the 'default' output of\n", | |
| "# elmo, as it just outputs the sum of all values divided by the length of\n", | |
| "# the valid part (in this case 1)\n", | |
| "elmo_default_vec_4 = np.sum(raw_vecs_4[0], axis=0)\n", | |
| "print('If 0, how we pooled here is exactly as elmos default output does')\n", | |
| "print(norm(elmo_default_vec_4 - vecs_4[0]))\n", | |
| "print('And, again, its not the same as the method considering only the valid part:')\n", | |
| "print(norm(elmo_default_vec_4 - manual_vecs_4[0]))" | |
| ], | |
| "execution_count": 12, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "Shape of raw_vecs_2:\n", | |
| "(2, 3, 1024)\n", | |
| "\n", | |
| "Valid part of output, should be >0\n", | |
| "14.730557\n", | |
| "Invalid part of output, should be =0\n", | |
| "2.2874007\n", | |
| "\n", | |
| "distances of the valid parts:\n", | |
| "8.415249e-06\n", | |
| "8.378789e-06\n", | |
| "9.699437e-06\n", | |
| "\n", | |
| "If 0, how we pooled here is exactly as elmos default output does\n", | |
| "0.0\n", | |
| "And, again, its not the same as the method considering only the valid part:\n", | |
| "17.7918\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| } | |
| ] | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment