@william-r-s
Last active May 25, 2020 00:17

t5-trivia
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"accelerator": "TPU",
"colab": {
"name": "t5-trivia",
"provenance": [],
"collapsed_sections": [
"zrtR2urJV3ST",
"-pFvyrHmm6Mx"
],
"toc_visible": true,
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/william-r-s/05b8cc3289997408bbbe36b37f51d556/t5-trivia.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "YONnGjpAYUdU"
},
"source": [
"\n",
"<a href=\"https://colab.research.google.com/github/google-research/text-to-text-transfer-transformer/blob/master/notebooks/t5-trivia.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "zrtR2urJV3ST"
},
"source": [
"##### Copyright 2020 The T5 Authors\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\");"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "DWdCSqJ6WHBh",
"colab": {}
},
"source": [
"# Copyright 2019 The T5 Authors. All Rights Reserved.\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# http://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n",
"# =============================================================================="
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "zSeyoqE7WMwu"
},
"source": [
"# Fine-Tuning the Text-To-Text Transfer Transformer (T5) for Closed-Book Question Answering\n",
"## _Or: What does T5 know?_\n",
"\n",
"*The following tutorial guides you through the process of fine-tuning a pre-trained T5 model, evaluating its accuracy, and using it for prediction,\n",
"all on a free Google Cloud TPU <a href=\"https://colab.research.google.com/github/google-research/text-to-text-transfer-transformer/blob/master/notebooks/t5-trivia.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>.*\n",
"\n",
"### Background\n",
"\n",
"T5 was introduced in the paper [_Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer_](https://arxiv.org/abs/1910.10683). In that paper, we provided a comprehensive picture of how we pre-trained a standard text-to-text Transformer model on a large text corpus, achieving state-of-the-art results on many NLP tasks after fine-tuning.\n",
"\n",
"We pre-trained T5 on a mixture of supervised and unsupervised tasks, with the majority of data coming from an unlabeled dataset we developed called [C4](https://www.tensorflow.org/datasets/catalog/c4). C4 is based on a massive scrape of the web produced by [Common Crawl](https://commoncrawl.org). Loosely speaking, pre-training on C4 ideally gives T5 an understanding of natural language in addition to general world knowledge.\n",
"\n",
"### How can we assess what T5 knows?\n",
"\n",
"As the name implies, T5 is a text-to-text model, which enables us to train it on arbitrary tasks involving a textual input and output. As we showed in our paper, a huge variety of NLP tasks can be cast in this format, including translation, summarization, and even classification and regression tasks.\n",
"\n",
"One way to use this text-to-text framework is for reading comprehension problems, where the model is fed some context along with a question and is trained to predict the question's answer. For example, we might feed the model the text from the Wikipedia article about [Hurricane Connie](https://en.wikipedia.org/wiki/Hurricane_Connie) along with the question \"On what date did Hurricane Connie occur?\" and train the model to predict the answer \"August 3rd, 1955\".\n",
"A related task is open-domain question answering (QA), where the model is not provided with this oracle context. Typically, open-domain QA systems include a mechanism to look up information in an external knowledge source. This setting is similar to an \"open-book\" exam.\n",
"\n",
"In this notebook, we'll be training T5 on a variant of this task which we call **closed-book question answering**. In closed-book QA, we feed the model a question *without any context or access to external knowledge* and train it to predict the answer. Since the model doesn't receive any context, it must rely primarily on the \"knowledge\" it obtained during pre-training. We don't expect T5 to contain highly specific information, so we will be focusing on two question-answering datasets which largely include trivia questions (i.e., facts about well-known subjects). [Similar](https://arxiv.org/abs/1909.01066) [investigations](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) have recently been done to test the knowledge stored by BERT and GPT-2.\n",
"\n",
"T5 was not pre-trained on closed-book QA, so in this notebook we'll first create two new tasks and then use the [`t5`](https://github.com/google-research/text-to-text-transfer-transformer) library to fine-tune, evaluate, and obtain predictions from T5. In the end, T5's performance on closed-book QA can give us a sense of what kind (and how much) information T5 managed to learn during pre-training.\n",
"\n",
"## State-of-the-art Results\n",
"We published a [more in-depth investigation](https://arxiv.org/abs/2002.08910) of closed-book QA with T5, where we achieved state-of-the-art results on open-domain variants of WebQuestions and TriviaQA in addition to surprisingly strong results on Natural Questions. The code in this notebook is a simplified version of those experiments but still produces good results.\n",
"\n",
"For code to reproduce our best results, please see the [t5_closed_book_qa](https://github.com/google-research/google-research/tree/master/t5_closed_book_qa) repo.\n",
"\n",
"\n",
"### Caveats\n",
"\n",
"* While we provide instructions for running on a [Cloud TPU](https://cloud.google.com/tpu/) via Colab for free, a [Google Cloud Storage (GCS)](http://console.cloud.google.com/storage) bucket is required for storing model parameters and data. The [GCS free tier](https://cloud.google.com/free/) provides 5 GB of storage, which should be enough to train the `large` and smaller models, but not the `3B` or `11B` parameter models. You can use part of your initial $300 credit to get more space.\n",
"* The Cloud TPU provided by Colab (a `v2-8`) does not have enough memory to fine-tune the `11B` parameter model. For this model, you will need to fine-tune inside of a GCP instance (see [README](https://github.com/google-research/text-to-text-transfer-transformer/)).\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "yAb_APDrefs6"
},
"source": [
"# Set Up"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "eDeE_yVuHMYg"
},
"source": [
"<h3><a href=\"https://cloud.google.com/tpu/\"><img valign=\"middle\" src=\"https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png\" width=\"50\"></a> &nbsp;&nbsp;Train on TPU</h3>\n",
"\n",
"\n",
"\n",
"\n",
" 1. Create a Cloud Storage bucket for your data and model checkpoints at http://console.cloud.google.com/storage, and fill in the `BASE_DIR` parameter in the following form. There is a [free tier](https://cloud.google.com/free/) if you do not yet have an account.\n",
" \n",
" 1. On the main menu, click Runtime and select **Change runtime type**. Set \"TPU\" as the hardware accelerator.\n",
" 1. Run the following cell and follow instructions to:\n",
" * Set up a Colab TPU running environment\n",
" * Verify that you are connected to a TPU device\n",
" * Upload your credentials to TPU to access your GCS bucket\n"
]
},
{
"cell_type": "code",
"metadata": {
"cellView": "both",
"colab_type": "code",
"id": "xYh-IaN4C7Z1",
"colab": {}
},
"source": [
"print(\"Installing dependencies...\")\n",
"%tensorflow_version 2.x\n",
"!pip install --upgrade \"git+https://github.com/tensorflow/mesh.git#egg=mesh-tensorflow\"\n",
"!pip install --upgrade \"git+https://github.com/google-research/text-to-text-transfer-transformer.git#egg=t5\"\n",
"\n",
"import functools\n",
"import os\n",
"import time\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\", category=DeprecationWarning)\n",
"\n",
"import tensorflow.compat.v1 as tf\n",
"import tensorflow_datasets as tfds\n",
"\n",
"import t5\n",
"\n",
"BASE_DIR = \"gs://\" #@param { type: \"string\" }\n",
"if not BASE_DIR or BASE_DIR == \"gs://\":\n",
"  raise ValueError(\"You must enter a BASE_DIR.\")\n",
"DATA_DIR = os.path.join(BASE_DIR, \"data\")\n",
"MODELS_DIR = os.path.join(BASE_DIR, \"models\")\n",
"ON_CLOUD = True\n",
"\n",
"\n",
"if ON_CLOUD:\n",
"  print(\"Setting up GCS access...\")\n",
"  import tensorflow_gcs_config\n",
"  from google.colab import auth\n",
"  # Set credentials for GCS reading/writing from Colab and TPU.\n",
"  TPU_TOPOLOGY = \"2x2\"\n",
"  try:\n",
"    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection\n",
"    TPU_ADDRESS = tpu.get_master()\n",
"    print('Running on TPU:', TPU_ADDRESS)\n",
"  except ValueError:\n",
"    raise RuntimeError('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!')\n",
"  auth.authenticate_user()\n",
"  tf.config.experimental_connect_to_host(TPU_ADDRESS)\n",
"  tensorflow_gcs_config.configure_gcs_from_colab_auth()\n",
"\n",
"tf.disable_v2_behavior()\n",
"\n",
"# Improve logging.\n",
"from contextlib import contextmanager\n",
"import logging as py_logging\n",
"\n",
"if ON_CLOUD:\n",
"  tf.get_logger().propagate = False\n",
"  py_logging.root.setLevel('INFO')\n",
"\n",
"@contextmanager\n",
"def tf_verbosity_level(level):\n",
"  \"\"\"Temporarily set the TF logging verbosity, restoring it even on errors.\"\"\"\n",
"  og_level = tf.logging.get_verbosity()\n",
"  tf.logging.set_verbosity(level)\n",
"  try:\n",
"    yield\n",
"  finally:\n",
"    tf.logging.set_verbosity(og_level)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "dMoJ-G9mqDqa"
},
"source": [
"# Creating new Tasks and Mixture"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "zwoLPQhE6bef"
},
"source": [
"Two core components of the T5 library are `Task` and `Mixture` objects.\n",
"\n",
"A `Task` is a dataset along with preprocessing functions and evaluation metrics. A `Mixture` is a collection of `Task` objects along with a mixing rate or a function defining how to compute a mixing rate based on the properties of the constituent `Tasks`.\n",
"\n",
"For this example, we will fine-tune the model to do closed-book question answering."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "152zECujzPMk"
},
"source": [
"### Natural Questions\n",
"\n",
"[Natural Questions (NQ)](https://ai.google.com/research/NaturalQuestions) is a challenging corpus for open-domain QA. Each example includes a question along with an entire Wikipedia article that may or may not contain its answer. The goal is to produce the correct answer given this context. In our case, we will ignore the provided context in the hope that the model can answer from the world knowledge it acquired during pre-training.\n",
"\n",
"Since the raw data splits are stored as JSONL files, we will first need to convert them to TSV format to make them parseable in TensorFlow. We will also take the opportunity to drop information we will not be using, keep only questions that have exactly one short answer, and do a bit of cleaning of the text."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "OjEonhK3zNRu",
"outputId": "52f5a5be-bea0-4ee9-d9e5-37afc8a3c549",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"import gzip\n",
"import json\n",
"\n",
"# Public directory of Natural Questions data on GCS.\n",
"NQ_JSONL_DIR = \"gs://natural_questions/v1.0-simplified/\"\n",
"NQ_SPLIT_FNAMES = {\n",
"    \"train\": \"simplified-nq-train.jsonl.gz\",\n",
"    \"validation\": \"nq-dev-all.jsonl.gz\"\n",
"}\n",
"nq_counts_path = os.path.join(DATA_DIR, \"nq-counts.json\")\n",
"nq_tsv_path = {\n",
"    \"train\": os.path.join(DATA_DIR, \"nq-train.tsv\"),\n",
"    \"validation\": os.path.join(DATA_DIR, \"nq-validation.tsv\")\n",
"}\n",
"\n",
"def nq_jsonl_to_tsv(in_fname, out_fname):\n",
"\n",
"  def extract_answer(tokens, span):\n",
"    \"\"\"Reconstruct answer from token span and remove extra spaces.\"\"\"\n",
"    start, end = span[\"start_token\"], span[\"end_token\"]\n",
"    ans = \" \".join(tokens[start:end])\n",
"    # Remove incorrect spacing around punctuation.\n",
"    ans = ans.replace(\" ,\", \",\").replace(\" .\", \".\").replace(\" %\", \"%\")\n",
"    ans = ans.replace(\" - \", \"-\").replace(\" : \", \":\").replace(\" / \", \"/\")\n",
"    ans = ans.replace(\"( \", \"(\").replace(\" )\", \")\")\n",
"    ans = ans.replace(\"`` \", \"\\\"\").replace(\" ''\", \"\\\"\")\n",
"    ans = ans.replace(\" 's\", \"'s\").replace(\"s ' \", \"s' \")\n",
"    return ans\n",
"\n",
"  count = 0\n",
"  with tf.io.gfile.GFile(in_fname, \"rb\") as infile,\\\n",
"       tf.io.gfile.GFile(out_fname, \"w\") as outfile:\n",
"    for line in gzip.open(infile):\n",
"      ex = json.loads(line)\n",
"      # Keep only examples with exactly one short answer (this drops both\n",
"      # unanswerable questions and questions with multiple answers).\n",
"      if len(ex['annotations'][0]['short_answers']) != 1:\n",
"        continue\n",
"      # Questions in NQ do not include a question mark.\n",
"      question = ex[\"question_text\"] + \"?\"\n",
"      answer_span = ex['annotations'][0]['short_answers'][0]\n",
"      # Handle the two document formats in NQ (tokens or text).\n",
"      if \"document_tokens\" in ex:\n",
"        tokens = [t[\"token\"] for t in ex[\"document_tokens\"]]\n",
"      elif \"document_text\" in ex:\n",
"        tokens = ex[\"document_text\"].split(\" \")\n",
"      answer = extract_answer(tokens, answer_span)\n",
"      # Write this line as <question>\\t<answer>\n",
"      outfile.write(\"%s\\t%s\\n\" % (question, answer))\n",
"      count += 1\n",
"      tf.logging.log_every_n(\n",
"          tf.logging.INFO,\n",
"          \"Wrote %d examples to %s.\" % (count, out_fname),\n",
"          1000)\n",
"  return count\n",
"\n",
"if tf.io.gfile.exists(nq_counts_path):\n",
"  # Use cached data and counts.\n",
"  tf.logging.info(\"Loading NQ from cache.\")\n",
"  with tf.io.gfile.GFile(nq_counts_path) as f:\n",
"    num_nq_examples = json.load(f)\n",
"else:\n",
"  # Create TSVs and get counts.\n",
"  tf.logging.info(\"Generating NQ TSVs.\")\n",
"  num_nq_examples = {}\n",
"  for split, fname in NQ_SPLIT_FNAMES.items():\n",
"    num_nq_examples[split] = nq_jsonl_to_tsv(\n",
"        os.path.join(NQ_JSONL_DIR, fname), nq_tsv_path[split])\n",
"  # Close the file explicitly so the counts are flushed to GCS.\n",
"  with tf.io.gfile.GFile(nq_counts_path, \"w\") as f:\n",
"    json.dump(num_nq_examples, f)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"INFO:tensorflow:Loading NQ from cache.\n"
],
"name": "stdout"
}
]
},
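{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The punctuation clean-up inside `extract_answer` can be checked on its own. The sketch below copies the same `replace` chain into a standalone helper, `detokenize_span` (a hypothetical name, not part of the `t5` library), and applies it to a made-up token list rather than real NQ data, so it runs without TensorFlow or GCS access. As in `extract_answer`, the end index is exclusive.\n",
"\n",
"```py\n",
"def detokenize_span(tokens, start, end):\n",
"  \"\"\"Join tokens[start:end] and undo tokenizer spacing (same rules as extract_answer).\"\"\"\n",
"  ans = \" \".join(tokens[start:end])\n",
"  ans = ans.replace(\" ,\", \",\").replace(\" .\", \".\").replace(\" %\", \"%\")\n",
"  ans = ans.replace(\" - \", \"-\").replace(\" : \", \":\").replace(\" / \", \"/\")\n",
"  ans = ans.replace(\"( \", \"(\").replace(\" )\", \")\")\n",
"  ans = ans.replace(\"`` \", \"\\\"\").replace(\" ''\", \"\\\"\")\n",
"  ans = ans.replace(\" 's\", \"'s\").replace(\"s ' \", \"s' \")\n",
"  return ans\n",
"\n",
"# Made-up document tokens, for illustration only.\n",
"tokens = [\"The\", \"storm\", \"hit\", \"on\", \"August\", \"3\", \",\", \"1955\",\n",
"          \"(\", \"est\", \".\", \")\", \".\"]\n",
"print(detokenize_span(tokens, 4, 12))  # August 3, 1955 (est.)\n",
"```\n",
"\n",
"Because the replacements run in a fixed order, each rule only sees the output of the rules before it; adding a new rule may therefore require checking its interaction with the earlier ones."
]
},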
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "R-Ja8akCX1dR"
},
"source": [
"Next, we define a function to load the TSV data as a `tf.data.Dataset` in TensorFlow."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "KPOteeqctpzw",
"outputId": "98ed07e4-19b5-4d23-b320-7375505edf7d",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 139
}
},
"source": [
"def nq_dataset_fn(split, shuffle_files=False):\n",
"  # We only have one file for each split.\n",
"  del shuffle_files\n",
"\n",
"  # Load lines from the text file as examples.\n",
"  ds = tf.data.TextLineDataset(nq_tsv_path[split])\n",
"  # Split each \"<question>\\t<answer>\" example into (question, answer) tuple.\n",
"  ds = ds.map(\n",
"      functools.partial(tf.io.decode_csv, record_defaults=[\"\", \"\"],\n",
"                        field_delim=\"\\t\", use_quote_delim=False),\n",
"      num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
"  # Map each tuple to a {\"question\": ... \"answer\": ...} dict.\n",
"  ds = ds.map(lambda *ex: dict(zip([\"question\", \"answer\"], ex)))\n",
"  return ds\n",
"\n",
"print(\"A few raw validation examples...\")\n",
"for ex in tfds.as_numpy(nq_dataset_fn(\"validation\").take(5)):\n",
"  print(ex)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"A few raw validation examples...\n",
"{'question': b'what do the 3 dots mean in math?', 'answer': b'the therefore sign (\\xe2\\x88\\xb4) is generally used before a logical consequence, such as the conclusion of a syllogism'}\n",
"{'question': b'who is playing the halftime show at super bowl 2016?', 'answer': b'Coldplay with special guest performers Beyonc\\xc3\\xa9 and Bruno Mars'}\n",
"{'question': b'who won the 2017 sports personality of the year?', 'answer': b'Mo Farah'}\n",
"{'question': b'where was the world economic forum held this year?', 'answer': b'Davos, a mountain resort in Graub\\xc3\\xbcnden, in the eastern Alps region of Switzerland'}\n",
"{'question': b'who has made the most premier league appearances?', 'answer': b'Gareth Barry'}\n"
],
"name": "stdout"
}
]
},
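{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The TSV files written above use no quoting or escaping, which is why `nq_dataset_fn` passes `use_quote_delim=False`: answers may contain literal double quotes (for example, where `extract_answer` converts tokenizer quote marks to `\"`), and each line has to split on its single tab into exactly two fields. The plain-Python sketch below mimics that round trip with two hypothetical helpers, `to_tsv_line` and `parse_tsv_line`, which stand in for the `outfile.write` call and `tf.io.decode_csv`; they are not part of TensorFlow or `t5`.\n",
"\n",
"```py\n",
"def to_tsv_line(question, answer):\n",
"  # Same layout as the `outfile.write(\"%s\\t%s\\n\" % ...)` call above.\n",
"  return \"%s\\t%s\\n\" % (question, answer)\n",
"\n",
"def parse_tsv_line(line):\n",
"  # Split on tabs only; double quotes are kept as ordinary characters,\n",
"  # analogous to `use_quote_delim=False` in `tf.io.decode_csv`.\n",
"  fields = line.rstrip(\"\\n\").split(\"\\t\")\n",
"  if len(fields) != 2:\n",
"    raise ValueError(\"Expected 2 tab-separated fields, got %d\" % len(fields))\n",
"  return {\"question\": fields[0], \"answer\": fields[1]}\n",
"\n",
"line = to_tsv_line('who sang \"yellow\"?', \"coldplay\")\n",
"print(parse_tsv_line(line))  # {'question': 'who sang \"yellow\"?', 'answer': 'coldplay'}\n",
"```\n",
"\n",
"A stray tab inside a question or answer would yield three fields; the helper raises a `ValueError` in that case, and `tf.io.decode_csv` with two `record_defaults` would likely fail on such a line as well."
]
},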
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "MCUYT7JmX9Tj"
},
"source": [
"Now, we write a preprocess function to convert the examples in the `tf.data.Dataset` into a text-to-text format, with both `inputs` and `targets` fields. The preprocessor also normalizes the text by lowercasing it and stripping the outermost pair of single quotes, since the answers are sometimes formatted in odd ways. Finally, we prepend 'trivia question:' to the inputs so that the model knows what task it's trying to solve."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "x8tNn6HMYLMb",
"colab": {}
},
"source": [
"def trivia_preprocessor(ds):\n",
"  def normalize_text(text):\n",
"    \"\"\"Lowercase and remove quotes from a TensorFlow string.\"\"\"\n",
"    text = tf.strings.lower(text)\n",
"    text = tf.strings.regex_replace(text, \"'(.*)'\", r\"\\1\")\n",
"    return text\n",
"\n",
"  def to_inputs_and_targets(ex):\n",
"    \"\"\"Map {\"question\": ..., \"answer\": ...}->{\"inputs\": ..., \"targets\": ...}.\"\"\"\n",
"    return {\n",
"        \"inputs\":\n",
"            tf.strings.join(\n",
"                [\"trivia question: \", normalize_text(ex[\"question\"])]),\n",
"        \"targets\": normalize_text(ex[\"answer\"])\n",
"    }\n",
"  return ds.map(to_inputs_and_targets,\n",
"                num_parallel_calls=tf.data.experimental.AUTOTUNE)"
],
"execution_count": 0,
"outputs": []
},
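{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"`tf.strings.regex_replace` replaces every match of an RE2 pattern, and `'(.*)'` is greedy, so `normalize_text` removes the first and last single quote in a string (when it contains at least two) and keeps everything between them. As a side effect, an apostrophe elsewhere in the string can be dropped. The sketch below reproduces this with Python's `re` module, which we assume behaves like RE2 for this simple pattern; `normalize_text_py` is a hypothetical name, and the examples are ASCII because `str.lower` also lowercases non-ASCII letters, which `tf.strings.lower` may not do by default.\n",
"\n",
"```py\n",
"import re\n",
"\n",
"def normalize_text_py(text):\n",
"  # Lowercase, then drop the outermost pair of single quotes (greedy match).\n",
"  text = text.lower()\n",
"  return re.sub(r\"'(.*)'\", r\"\\1\", text)\n",
"\n",
"print(normalize_text_py(\"The 'Beatles'\"))    # the beatles\n",
"print(normalize_text_py(\"Don't stop\"))       # don't stop (one quote, no match)\n",
"print(normalize_text_py(\"rock 'n' roll's\"))  # rock n' rolls\n",
"```"
]
},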
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "gm1Pm2aRZ9Ow"
},
"source": [
"Finally, we put everything together to create a `Task`."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "yJyRavOpZ7UW",
"colab": {}
},
"source": [
"t5.data.TaskRegistry.add(\n",
" \"nq_context_free\",\n",
" # Supply a function which returns a tf.data.Dataset.\n",
" dataset_fn=nq_dataset_fn,\n",
" splits=[\"train\", \"validation\"],\n",
" # Supply a function which preprocesses text from the tf.data.Dataset.\n",
" text_preprocessor=[trivia_preprocessor],\n",
" # Use the same vocabulary that we used for pre-training.\n",
" sentencepiece_model_path=t5.data.DEFAULT_SPM_PATH,\n",
" # Lowercase targets before computing metrics.\n",
" postprocess_fn=t5.data.postprocessors.lower_text, \n",
" # We'll use accuracy as our evaluation metric.\n",
" metric_fns=[t5.evaluation.metrics.accuracy],\n",
" # Not required, but helps for mixing and auto-caching.\n",
" num_input_examples=num_nq_examples\n",
")"
],
"execution_count": 0,
"outputs": []
},
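{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Because the `Task` uses `lower_text` as its postprocessor and `accuracy` as its metric, a prediction counts as correct only when it matches the target string exactly after lowercasing; partially correct answers earn no credit. The plain-Python sketch below illustrates that exact-match computation with a hypothetical `exact_match_accuracy` helper and made-up predictions. It is not the library's source, although, to our understanding, `t5.evaluation.metrics.accuracy` similarly reports a percentage under an `\"accuracy\"` key.\n",
"\n",
"```py\n",
"def exact_match_accuracy(targets, predictions):\n",
"  # Lowercase both sides (like `lower_text`), then count exact string matches.\n",
"  if len(targets) != len(predictions):\n",
"    raise ValueError(\"targets and predictions must have the same length\")\n",
"  correct = sum(t.lower() == p.lower() for t, p in zip(targets, predictions))\n",
"  return {\"accuracy\": 100.0 * correct / len(targets)}\n",
"\n",
"# Made-up targets and predictions, for illustration only.\n",
"targets = [\"mo farah\", \"gareth barry\", \"davos\"]\n",
"predictions = [\"Mo Farah\", \"ryan giggs\", \"davos, switzerland\"]\n",
"# Only the first prediction matches exactly, so accuracy is about 33.3%.\n",
"print(exact_match_accuracy(targets, predictions))\n",
"```\n",
"\n",
"Note that `\"davos, switzerland\"` earns no credit even though it contains the target, which is one reason exact-match accuracy can understate what the model knows."
]
},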
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "qe4o_0jFbP-p"
},
"source": [
"Let's look at a few preprocessed examples from the validation set. Note that they contain both the tokenized (integer) and plain-text inputs and targets.\n"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "I64TqHGxbOJ2",
"colab": {}
},
"source": [
"nq_task = t5.data.TaskRegistry.get(\"nq_context_free\")\n",
"ds = nq_task.get_dataset(split=\"validation\", sequence_length={\"inputs\": 128, \"targets\": 32})\n",
"print(\"A few preprocessed validation examples...\")\n",
"for ex in tfds.as_numpy(ds.take(5)):\n",
" print(ex)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "G1ktIEGePdBr"
},
"source": [
"**Note**: Instead of defining `nq_dataset_fn` and the `Task` above, we could also have used the `TextLineTask` class with the `parse_tsv` preprocessor for equivalent results, as follows:\n",
"\n",
"```py\n",
"t5.data.TaskRegistry.add(\n",
"    \"nq_context_free\",\n",
"    t5.data.TextLineTask,\n",
"    split_to_filepattern=nq_tsv_path,\n",
"    text_preprocessor=[\n",
"        functools.partial(\n",
"            t5.data.preprocessors.parse_tsv,\n",
"            field_names=[\"question\", \"answer\"]),\n",
"        trivia_preprocessor\n",
"    ],\n",
"    sentencepiece_model_path=t5.data.DEFAULT_SPM_PATH,\n",
"    postprocess_fn=t5.data.postprocessors.lower_text,\n",
"    metric_fns=[t5.evaluation.metrics.accuracy],\n",
"    num_input_examples=num_nq_examples\n",
")\n",
"```\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "y4_1gpcK9i3W"
},
"source": [
"## TriviaQA\n",
"\n",
"The second dataset we will use is based on [TriviaQA](https://nlp.cs.washington.edu/triviaqa/). It is also intended for reading comprehension, but, once again, we will modify the task here by ignoring the provided context.\n",
"\n",
"Since the dataset has been imported into [TensorFlow Datasets (TFDS)](https://www.tensorflow.org/datasets/catalog/trivia_qa), we can let it handle the data parsing for us. It will take a few minutes to download and preprocess the first time, but we'll be able to access it instantly from our data directory afterward."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "mQTQHS94Z0Tq",
"colab": {}
},
"source": [
"ds = tfds.load(\n",
" \"trivia_qa/unfiltered.nocontext\",\n",
" data_dir=DATA_DIR,\n",
" # Download data locally for preprocessing to avoid using GCS space.\n",
" download_and_prepare_kwargs={\"download_dir\": \"./downloads\"})\n",
"print(\"A few raw validation examples...\")\n",
"for ex in tfds.as_numpy(ds[\"validation\"].take(2)):\n",
" print(ex)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "gq5U_rjDb1bn"
},
"source": [
"As with Natural Questions, we need to preprocess the raw examples into `inputs` and `targets` features. We can reuse the `trivia_preprocessor` above, but first we need to convert the TriviaQA examples into the correct format, ignoring the fields we don't need for our task.\n",
"\n",
"We'll then define our `Task` and print out a few preprocessed examples from the validation set.\n",
"\n",
"Note that we do not need to specify the splits or number of examples since that information is provided by TFDS."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "6rU32DjyeLuL",
"colab": {}
},
"source": [
"def triviaqa_extract_qa(ds):\n",
"  def extract_qa(ex):\n",
"    return {\n",
"        \"question\": ex[\"question\"],\n",
"        \"answer\": ex[\"answer\"][\"value\"]\n",
"    }\n",
"  return ds.map(extract_qa, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
"\n",
"t5.data.TaskRegistry.add(\n",
"    \"triviaqa_context_free\",\n",
"    # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.\n",
"    t5.data.TfdsTask,\n",
"    tfds_name=\"trivia_qa/unfiltered.nocontext:1.1.0\",\n",
"    tfds_data_dir=DATA_DIR,\n",
"    sentencepiece_model_path=t5.data.DEFAULT_SPM_PATH,\n",
"    text_preprocessor=[triviaqa_extract_qa, trivia_preprocessor],\n",
"    postprocess_fn=t5.data.postprocessors.lower_text,\n",
"    metric_fns=[t5.evaluation.metrics.accuracy]\n",
")\n",
"\n",
"# Load and print a few examples.\n",
"triviaqa_task = t5.data.TaskRegistry.get(\"triviaqa_context_free\")\n",
"ds = triviaqa_task.get_dataset(split=\"validation\", sequence_length={\"inputs\": 128, \"targets\": 32})\n",
"print(\"A few preprocessed validation examples...\")\n",
"for ex in tfds.as_numpy(ds.take(3)):\n",
"  print(ex)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "wlghm_3rAd-M"
},
"source": [
"## Dataset Mixture\n",
"\n",
"We now create a `Mixture` from the above `Tasks`, which we will fine-tune on.\n",
"\n",
"There are different ways to automatically set the rate (for example, based on the number of examples using `rate_num_examples`), but we will just hardcode an equal mixture for simplicity."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "Zgs-s3eDAU37",
"colab": {}
},
"source": [
"t5.data.MixtureRegistry.remove(\"trivia_all\")\n",
"t5.data.MixtureRegistry.add(\n",
" \"trivia_all\",\n",
" [\"nq_context_free\", \"triviaqa_context_free\"],\n",
" default_rate=1.0\n",
")"
],
"execution_count": 0,
"outputs": []
},
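{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"With `default_rate=1.0`, both `Task`s get the same rate, so, assuming rates are normalized into sampling probabilities, training examples are drawn about equally from NQ and TriviaQA regardless of their sizes. The sketch below contrasts that with rates proportional to example counts, the idea behind `rate_num_examples`. The helper `mixing_probabilities` and the example counts are hypothetical placeholders, not the library's implementation or the real dataset sizes.\n",
"\n",
"```py\n",
"def mixing_probabilities(rates):\n",
"  # Normalize per-task rates into sampling probabilities that sum to 1.\n",
"  total = float(sum(rates.values()))\n",
"  return {name: rate / total for name, rate in rates.items()}\n",
"\n",
"# Hypothetical example counts, NOT the real NQ/TriviaQA sizes.\n",
"num_examples = {\"nq_context_free\": 75000, \"triviaqa_context_free\": 25000}\n",
"\n",
"equal = mixing_probabilities({name: 1.0 for name in num_examples})\n",
"proportional = mixing_probabilities(num_examples)\n",
"print(equal)         # {'nq_context_free': 0.5, 'triviaqa_context_free': 0.5}\n",
"print(proportional)  # {'nq_context_free': 0.75, 'triviaqa_context_free': 0.25}\n",
"```\n",
"\n",
"Equal rates oversample the smaller task relative to its size, which can help it avoid being drowned out, at the cost of repeating its examples more often."
]
},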
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "CUkorodCENGw"
},
"source": [
"# Transferring to new Tasks\n",
"\n",
"We are now ready to fine-tune one of the pre-trained T5 models on our new mixture of closed-book QA tasks.\n",
"\n",
"First, we'll instantiate a `Model` object using the model size of your choice. Note that larger models are slower to train and use but will likely achieve higher accuracy. You may also be able to increase accuracy by training longer with more `FINETUNE_STEPS` below.\n",
"\n",
"\n",
"## Caveats\n",
"\n",
"* Due to its memory requirements, you will not be able to train the `11B` parameter model on the TPU provided by Colab. Instead, you will need to fine-tune inside of a GCP instance (see [README](https://github.com/google-research/text-to-text-transfer-transformer/)).\n",
"* Due to the checkpoint size, you will not be able to use the 5GB GCS free tier for the `3B` parameter model. You will need at least 25GB of space, which you can purchase with your $300 of initial credit on GCP.\n",
"* While `large` can achieve decent results, it is recommended that you fine-tune at least the `3B` parameter model.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "syte5n0nnMOC"
},
"source": [
"## Define Model"
]
},
{
"cell_type": "code",
"metadata": {
"cellView": "form",
"colab_type": "code",
"id": "yGQ-zpgy3raf",
"colab": {}
},
"source": [
"MODEL_SIZE = \"base\" #@param[\"small\", \"base\", \"large\", \"3B\", \"11B\"]\n",
"# Public GCS path for T5 pre-trained model checkpoints\n",
"BASE_PRETRAINED_DIR = \"gs://t5-data/pretrained_models\"\n",
"PRETRAINED_DIR = os.path.join(BASE_PRETRAINED_DIR, MODEL_SIZE)\n",
"MODEL_DIR = os.path.join(MODELS_DIR, MODEL_SIZE)\n",
"\n",
"if ON_CLOUD and MODEL_SIZE == \"3B\":\n",
"  tf.logging.warning(\n",
"      \"The `3B` model is too large to use with the 5GB GCS free tier. \"\n",
"      \"Make sure you have at least 25GB on GCS before continuing.\"\n",
"  )\n",
"elif ON_CLOUD and MODEL_SIZE == \"11B\":\n",
"  raise ValueError(\n",
"      \"The `11B` parameter model is too large to fine-tune on the `v2-8` TPU \"\n",
"      \"provided by Colab. Please comment out this Error if you're running \"\n",
"      \"on a larger TPU.\"\n",
"  )\n",
"\n",
"# Set parallelism and batch size to fit on v2-8 TPU (if possible).\n",
"# Limit number of checkpoints to fit within 5GB (if possible).\n",
"model_parallelism, train_batch_size, keep_checkpoint_max = {\n",
"    \"small\": (1, 256, 16),\n",
"    \"base\": (2, 128, 8),\n",
"    \"large\": (8, 64, 4),\n",
"    \"3B\": (8, 16, 1),\n",
"    \"11B\": (8, 16, 1)}[MODEL_SIZE]\n",
"\n",
"tf.io.gfile.makedirs(MODEL_DIR)\n",
"# The models from our paper are based on the Mesh TensorFlow Transformer.\n",
"model = t5.models.MtfModel(\n",
"    model_dir=MODEL_DIR,\n",
"    tpu=TPU_ADDRESS,\n",
"    tpu_topology=TPU_TOPOLOGY,\n",
"    model_parallelism=model_parallelism,\n",
"    batch_size=train_batch_size,\n",
"    sequence_length={\"inputs\": 128, \"targets\": 32},\n",
"    learning_rate_schedule=0.003,\n",
"    save_checkpoints_steps=5000,\n",
"    keep_checkpoint_max=keep_checkpoint_max if ON_CLOUD else None,\n",
"    iterations_per_loop=100,\n",
")"
],
"execution_count": 0,
"outputs": []
},
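{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The `(model_parallelism, train_batch_size)` pairs above are chosen to fit the Colab `v2-8` TPU, which has 8 cores (4 chips in the `2x2` topology set earlier). Our understanding is that Mesh TensorFlow splits each copy of the model across `model_parallelism` cores and uses the remaining factor of the core count for data parallelism. Under that assumption, the sketch below uses a hypothetical `data_parallel_layout` helper to compute the number of model replicas and the examples each replica processes per step.\n",
"\n",
"```py\n",
"NUM_CORES = 8  # A v2-8 TPU has 8 cores (4 chips in a 2x2 topology).\n",
"\n",
"def data_parallel_layout(model_parallelism, batch_size, num_cores=NUM_CORES):\n",
"  # Assumption: cores form (num_cores // model_parallelism) replicas, each\n",
"  # holding one copy of the model sharded over `model_parallelism` cores.\n",
"  if num_cores % model_parallelism:\n",
"    raise ValueError(\"model_parallelism must divide the number of cores\")\n",
"  replicas = num_cores // model_parallelism\n",
"  return replicas, batch_size // replicas\n",
"\n",
"# (model_parallelism, train_batch_size) pairs from the cell above.\n",
"settings = {\"small\": (1, 256), \"base\": (2, 128), \"large\": (8, 64), \"3B\": (8, 16)}\n",
"for size, (mp, batch) in settings.items():\n",
"  replicas, per_replica = data_parallel_layout(mp, batch)\n",
"  print(\"%s: %d replica(s), %d examples per replica\" % (size, replicas, per_replica))\n",
"```\n",
"\n",
"For `base`, this gives 4 replicas of 32 examples each; for `large`, a single replica spans all 8 cores, which is why its batch size is smaller."
]
},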
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "dInuo63ZQrFi"
},
"source": [
"Before we continue, let's load a [TensorBoard](https://www.tensorflow.org/tensorboard) visualizer so that we can monitor our progress. The page should automatically update as fine-tuning and evaluation proceed."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "M5mPyYATNsVT",
"colab": {}
},
"source": [
"if ON_CLOUD:\n",
"  %reload_ext tensorboard\n",
"import tensorboard as tb\n",
"tb.notebook.start(\"--logdir \" + MODELS_DIR)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "DZhAd0U_4B_o"
},
"source": [
"## Fine-tune\n",
"\n",
"We are now ready to fine-tune our model. This will take a while (~2 hours with default settings), so please be patient! The larger the model and the more `FINETUNE_STEPS` you use, the longer it will take.\n",
"\n",
"Don't worry, you can always come back later and increase the number of steps, and it will automatically pick up where you left off."
]
},
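{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Since `model` was built with `save_checkpoints_steps=5000` and a size-dependent `keep_checkpoint_max`, the number of periodic checkpoints written during fine-tuning grows with `FINETUNE_STEPS`, but only the most recent `keep_checkpoint_max` should remain on GCS, which bounds storage use. The back-of-the-envelope sketch below uses a hypothetical `checkpoint_plan` helper and ignores any extra checkpoint saved at the start or end of training, which can depend on the library version.\n",
"\n",
"```py\n",
"def checkpoint_plan(finetune_steps, save_every=5000, keep_max=8):\n",
"  # Approximate count of periodic checkpoints written vs. retained.\n",
"  written = finetune_steps // save_every\n",
"  kept = min(written, keep_max)\n",
"  return written, kept\n",
"\n",
"# Default `base` settings: 25000 steps, a checkpoint every 5000 steps, keep 8.\n",
"print(checkpoint_plan(25000))   # (5, 5)\n",
"print(checkpoint_plan(100000))  # (20, 8)\n",
"```"
]
},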
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "V7t7a25LBTj9",
"outputId": "c1ce8edc-aeb1-4fda-f58e-a241d65671ae",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
}
},
"source": [
"FINETUNE_STEPS = 25000 #@param {type: \"integer\"}\n",
"\n",
"model.finetune(\n",
" mixture_or_task_name=\"trivia_all\",\n",
" pretrained_model_dir=PRETRAINED_DIR,\n",
" finetune_steps=FINETUNE_STEPS\n",
")"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"INFO:tensorflow:Using config: {'_model_dir': 'gs://wrsaunde_gmail_com_0/t5-trivia/models/base', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true\n",
"cluster_def {\n",
" job {\n",
" name: \"worker\"\n",
" tasks {\n",
" key: 0\n",
" value: \"10.30.157.114:8470\"\n",
" }\n",
" }\n",
"}\n",
"isolate_session_state: true\n",
", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.30.157.114:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.30.157.114:8470', '_evaluation_master': 'grpc://10.30.157.114:8470', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=100, num_shards=None, num_cores_per_replica=1, per_host_input_for_training=4, tpu_job_name=None, initial_infeed_sleep_secs=None, input_partition_dims=None, eval_training_input_configuration=2, experimental_host_call_every_n_steps=1), '_cluster': <tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver.TPUClusterResolver object at 0x7fab590ecb70>}\n",
"INFO:tensorflow:_TPUContext: eval_on_tpu True\n",
"INFO:tensorflow:Querying Tensorflow master (grpc://10.30.157.114:8470) for TPU system metadata.\n",
"INFO:tensorflow:Initializing TPU system (master: grpc://10.30.157.114:8470) to fetch topology for model parallelism. This might take a while.\n",
"INFO:tensorflow:Found TPU system:\n",
"INFO:tensorflow:*** Num TPU Cores: 8\n",
"INFO:tensorflow:*** Num TPU Workers: 1\n",
"INFO:tensorflow:*** Num TPU Cores Per Worker: 8\n",
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 8998188137571915786)\n",
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, -8710970079490415130)\n",
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 7112257790700540229)\n",
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, -1173290236248598252)\n",
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, 5862886712787155308)\n",
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, -3736053305016653854)\n",
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, 558450301430855215)\n",
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 17179869184, 4824533807886938680)\n",
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 17179869184, 5865745603249128026)\n",
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 8589934592, 3436610714463999572)\n",
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, -53807693441898264)\n",
"INFO:tensorflow:Calling model_fn.\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"INFO:absl:Load dataset info from gs://wrsaunde_gmail_com_0/t5-trivia/data/trivia_qa/unfiltered.nocontext/1.1.0\n",
"INFO:absl:Reusing dataset trivia_qa (gs://wrsaunde_gmail_com_0/t5-trivia/data/trivia_qa/unfiltered.nocontext/1.1.0)\n",
"INFO:absl:Constructing tf.data.Dataset for split train, from gs://wrsaunde_gmail_com_0/t5-trivia/data/trivia_qa/unfiltered.nocontext/1.1.0\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:enable_2d_tiling: False\n",
"INFO:tensorflow:num_cores_per_replica: 1\n",
"INFO:tensorflow:computation_shape: [1, 1, 1]\n",
"INFO:tensorflow:num_replicas: 8\n",
"INFO:tensorflow:device_assignment.topology.device_coordinates: [[[0 0 0]\n",
" [0 0 1]\n",
" [1 0 0]\n",
" [1 0 1]\n",
" [0 1 0]\n",
" [0 1 1]\n",
" [1 1 0]\n",
" [1 1 1]]]\n",
"INFO:tensorflow:device_assignment.core_assignment: [[[0 0 0]]\n",
"\n",
" [[0 0 1]]\n",
"\n",
" [[0 1 0]]\n",
"\n",
" [[0 1 1]]\n",
"\n",
" [[1 0 0]]\n",
"\n",
" [[1 0 1]]\n",
"\n",
" [[1 1 0]]\n",
"\n",
" [[1 1 1]]]\n",
"INFO:tensorflow:auto_logical_to_physical_tpu logical_shape=[4, 2] physical_shape=[2, 2, 2]\n",
"WARNING:tensorflow:Unrecognized format for tpu physical shape\n",
"INFO:tensorflow:auto_logical_to_physical_tpu logical_shape=[2] physical_shape=[1, 1, 2]\n",
"WARNING:tensorflow:Unrecognized format for tpu physical shape\n",
"WARNING:tensorflow:SimdMeshImpl ignoring devices ['', '', '', '', '', '', '', '']\n",
"INFO:tensorflow:SimdMeshImpl init: Shape[batch=4, model=2] LayoutRules{('d_ff', 'model'), ('ensemble', 'ensemble'), ('experts', 'batch'), ('heads', 'model'), ('vocab', 'model'), ('batch', 'batch')}\n",
"INFO:tensorflow:Device Assignment: <tensorflow.python.tpu.device_assignment.DeviceAssignment object at 0x7fab59336b00>\n",
"INFO:tensorflow:serialize_num_microbatches: tokens_per_microbatch_per_replica=8192 batch_dim=Dimension(name='batch', size=128) sequence_length={'inputs': 128, 'targets': 32} batch_per_replica=32 num_microbatches=1\n",
"WARNING:tensorflow:Using default tf glorot_uniform_initializer for variable encoder/block_000/layer_000/SelfAttention/relative_attention_bias The initialzer will guess the input and output dimensions based on dimension order.\n",
"WARNING:tensorflow:Using default tf glorot_uniform_initializer for variable decoder/block_000/layer_000/SelfAttention/relative_attention_bias The initialzer will guess the input and output dimensions based on dimension order.\n",
"INFO:tensorflow:Create pnum_tensor\n",
"INFO:tensorflow:Casting <dtype: 'int32'> to float32 for allreduce\n",
"INFO:tensorflow:Variable decoder/block_000/layer_000/SelfAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_000/layer_000/SelfAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_000/layer_000/SelfAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_000/layer_000/SelfAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_000/layer_001/EncDecAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_000/layer_001/EncDecAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_000/layer_001/EncDecAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_000/layer_001/EncDecAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_000/layer_002/DenseReluDense/wi/kernel size 2359296 slice_size 1179648 Shape[d_model=768, d_ff=3072] \n",
"INFO:tensorflow:Variable decoder/block_000/layer_002/DenseReluDense/wo/kernel size 2359296 slice_size 1179648 Shape[d_ff=3072, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_001/layer_000/SelfAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_001/layer_000/SelfAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_001/layer_000/SelfAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_001/layer_000/SelfAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_001/layer_001/EncDecAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_001/layer_001/EncDecAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_001/layer_001/EncDecAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_001/layer_001/EncDecAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_001/layer_002/DenseReluDense/wi/kernel size 2359296 slice_size 1179648 Shape[d_model=768, d_ff=3072] \n",
"INFO:tensorflow:Variable decoder/block_001/layer_002/DenseReluDense/wo/kernel size 2359296 slice_size 1179648 Shape[d_ff=3072, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_002/layer_000/SelfAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_002/layer_000/SelfAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_002/layer_000/SelfAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_002/layer_000/SelfAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_002/layer_001/EncDecAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_002/layer_001/EncDecAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_002/layer_001/EncDecAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_002/layer_001/EncDecAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_002/layer_002/DenseReluDense/wi/kernel size 2359296 slice_size 1179648 Shape[d_model=768, d_ff=3072] \n",
"INFO:tensorflow:Variable decoder/block_002/layer_002/DenseReluDense/wo/kernel size 2359296 slice_size 1179648 Shape[d_ff=3072, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_003/layer_000/SelfAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_003/layer_000/SelfAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_003/layer_000/SelfAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_003/layer_000/SelfAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_003/layer_001/EncDecAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_003/layer_001/EncDecAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_003/layer_001/EncDecAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_003/layer_001/EncDecAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_003/layer_002/DenseReluDense/wi/kernel size 2359296 slice_size 1179648 Shape[d_model=768, d_ff=3072] \n",
"INFO:tensorflow:Variable decoder/block_003/layer_002/DenseReluDense/wo/kernel size 2359296 slice_size 1179648 Shape[d_ff=3072, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_004/layer_000/SelfAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_004/layer_000/SelfAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_004/layer_000/SelfAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_004/layer_000/SelfAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_004/layer_001/EncDecAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_004/layer_001/EncDecAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_004/layer_001/EncDecAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_004/layer_001/EncDecAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_004/layer_002/DenseReluDense/wi/kernel size 2359296 slice_size 1179648 Shape[d_model=768, d_ff=3072] \n",
"INFO:tensorflow:Variable decoder/block_004/layer_002/DenseReluDense/wo/kernel size 2359296 slice_size 1179648 Shape[d_ff=3072, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_005/layer_000/SelfAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_005/layer_000/SelfAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_005/layer_000/SelfAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_005/layer_000/SelfAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_005/layer_001/EncDecAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_005/layer_001/EncDecAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_005/layer_001/EncDecAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_005/layer_001/EncDecAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_005/layer_002/DenseReluDense/wi/kernel size 2359296 slice_size 1179648 Shape[d_model=768, d_ff=3072] \n",
"INFO:tensorflow:Variable decoder/block_005/layer_002/DenseReluDense/wo/kernel size 2359296 slice_size 1179648 Shape[d_ff=3072, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_006/layer_000/SelfAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_006/layer_000/SelfAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_006/layer_000/SelfAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_006/layer_000/SelfAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_006/layer_001/EncDecAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_006/layer_001/EncDecAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_006/layer_001/EncDecAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_006/layer_001/EncDecAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_006/layer_002/DenseReluDense/wi/kernel size 2359296 slice_size 1179648 Shape[d_model=768, d_ff=3072] \n",
"INFO:tensorflow:Variable decoder/block_006/layer_002/DenseReluDense/wo/kernel size 2359296 slice_size 1179648 Shape[d_ff=3072, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_007/layer_000/SelfAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_007/layer_000/SelfAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_007/layer_000/SelfAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_007/layer_000/SelfAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_007/layer_001/EncDecAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_007/layer_001/EncDecAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_007/layer_001/EncDecAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_007/layer_001/EncDecAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_007/layer_002/DenseReluDense/wi/kernel size 2359296 slice_size 1179648 Shape[d_model=768, d_ff=3072] \n",
"INFO:tensorflow:Variable decoder/block_007/layer_002/DenseReluDense/wo/kernel size 2359296 slice_size 1179648 Shape[d_ff=3072, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_008/layer_000/SelfAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_008/layer_000/SelfAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_008/layer_000/SelfAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_008/layer_000/SelfAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_008/layer_001/EncDecAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_008/layer_001/EncDecAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_008/layer_001/EncDecAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_008/layer_001/EncDecAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_008/layer_002/DenseReluDense/wi/kernel size 2359296 slice_size 1179648 Shape[d_model=768, d_ff=3072] \n",
"INFO:tensorflow:Variable decoder/block_008/layer_002/DenseReluDense/wo/kernel size 2359296 slice_size 1179648 Shape[d_ff=3072, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_009/layer_000/SelfAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_009/layer_000/SelfAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_009/layer_000/SelfAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_009/layer_000/SelfAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_009/layer_001/EncDecAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_009/layer_001/EncDecAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_009/layer_001/EncDecAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_009/layer_001/EncDecAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_009/layer_002/DenseReluDense/wi/kernel size 2359296 slice_size 1179648 Shape[d_model=768, d_ff=3072] \n",
"INFO:tensorflow:Variable decoder/block_009/layer_002/DenseReluDense/wo/kernel size 2359296 slice_size 1179648 Shape[d_ff=3072, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_010/layer_000/SelfAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_010/layer_000/SelfAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_010/layer_000/SelfAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_010/layer_000/SelfAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_010/layer_001/EncDecAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_010/layer_001/EncDecAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_010/layer_001/EncDecAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_010/layer_001/EncDecAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_010/layer_002/DenseReluDense/wi/kernel size 2359296 slice_size 1179648 Shape[d_model=768, d_ff=3072] \n",
"INFO:tensorflow:Variable decoder/block_010/layer_002/DenseReluDense/wo/kernel size 2359296 slice_size 1179648 Shape[d_ff=3072, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_011/layer_000/SelfAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_011/layer_000/SelfAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_011/layer_000/SelfAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_011/layer_000/SelfAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_011/layer_001/EncDecAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_011/layer_001/EncDecAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable decoder/block_011/layer_001/EncDecAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_011/layer_001/EncDecAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable decoder/block_011/layer_002/DenseReluDense/wi/kernel size 2359296 slice_size 1179648 Shape[d_model=768, d_ff=3072] \n",
"INFO:tensorflow:Variable decoder/block_011/layer_002/DenseReluDense/wo/kernel size 2359296 slice_size 1179648 Shape[d_ff=3072, d_model=768] \n",
"INFO:tensorflow:Variable encoder/block_000/layer_000/SelfAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_000/layer_000/SelfAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable encoder/block_000/layer_000/SelfAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_000/layer_000/SelfAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_000/layer_001/DenseReluDense/wi/kernel size 2359296 slice_size 1179648 Shape[d_model=768, d_ff=3072] \n",
"INFO:tensorflow:Variable encoder/block_000/layer_001/DenseReluDense/wo/kernel size 2359296 slice_size 1179648 Shape[d_ff=3072, d_model=768] \n",
"INFO:tensorflow:Variable encoder/block_001/layer_000/SelfAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_001/layer_000/SelfAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable encoder/block_001/layer_000/SelfAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_001/layer_000/SelfAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_001/layer_001/DenseReluDense/wi/kernel size 2359296 slice_size 1179648 Shape[d_model=768, d_ff=3072] \n",
"INFO:tensorflow:Variable encoder/block_001/layer_001/DenseReluDense/wo/kernel size 2359296 slice_size 1179648 Shape[d_ff=3072, d_model=768] \n",
"INFO:tensorflow:Variable encoder/block_002/layer_000/SelfAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_002/layer_000/SelfAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable encoder/block_002/layer_000/SelfAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_002/layer_000/SelfAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_002/layer_001/DenseReluDense/wi/kernel size 2359296 slice_size 1179648 Shape[d_model=768, d_ff=3072] \n",
"INFO:tensorflow:Variable encoder/block_002/layer_001/DenseReluDense/wo/kernel size 2359296 slice_size 1179648 Shape[d_ff=3072, d_model=768] \n",
"INFO:tensorflow:Variable encoder/block_003/layer_000/SelfAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_003/layer_000/SelfAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable encoder/block_003/layer_000/SelfAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_003/layer_000/SelfAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_003/layer_001/DenseReluDense/wi/kernel size 2359296 slice_size 1179648 Shape[d_model=768, d_ff=3072] \n",
"INFO:tensorflow:Variable encoder/block_003/layer_001/DenseReluDense/wo/kernel size 2359296 slice_size 1179648 Shape[d_ff=3072, d_model=768] \n",
"INFO:tensorflow:Variable encoder/block_004/layer_000/SelfAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_004/layer_000/SelfAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable encoder/block_004/layer_000/SelfAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_004/layer_000/SelfAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_004/layer_001/DenseReluDense/wi/kernel size 2359296 slice_size 1179648 Shape[d_model=768, d_ff=3072] \n",
"INFO:tensorflow:Variable encoder/block_004/layer_001/DenseReluDense/wo/kernel size 2359296 slice_size 1179648 Shape[d_ff=3072, d_model=768] \n",
"INFO:tensorflow:Variable encoder/block_005/layer_000/SelfAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_005/layer_000/SelfAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable encoder/block_005/layer_000/SelfAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_005/layer_000/SelfAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_005/layer_001/DenseReluDense/wi/kernel size 2359296 slice_size 1179648 Shape[d_model=768, d_ff=3072] \n",
"INFO:tensorflow:Variable encoder/block_005/layer_001/DenseReluDense/wo/kernel size 2359296 slice_size 1179648 Shape[d_ff=3072, d_model=768] \n",
"INFO:tensorflow:Variable encoder/block_006/layer_000/SelfAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_006/layer_000/SelfAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable encoder/block_006/layer_000/SelfAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_006/layer_000/SelfAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_006/layer_001/DenseReluDense/wi/kernel size 2359296 slice_size 1179648 Shape[d_model=768, d_ff=3072] \n",
"INFO:tensorflow:Variable encoder/block_006/layer_001/DenseReluDense/wo/kernel size 2359296 slice_size 1179648 Shape[d_ff=3072, d_model=768] \n",
"INFO:tensorflow:Variable encoder/block_007/layer_000/SelfAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_007/layer_000/SelfAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable encoder/block_007/layer_000/SelfAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_007/layer_000/SelfAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_007/layer_001/DenseReluDense/wi/kernel size 2359296 slice_size 1179648 Shape[d_model=768, d_ff=3072] \n",
"INFO:tensorflow:Variable encoder/block_007/layer_001/DenseReluDense/wo/kernel size 2359296 slice_size 1179648 Shape[d_ff=3072, d_model=768] \n",
"INFO:tensorflow:Variable encoder/block_008/layer_000/SelfAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_008/layer_000/SelfAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable encoder/block_008/layer_000/SelfAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_008/layer_000/SelfAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_008/layer_001/DenseReluDense/wi/kernel size 2359296 slice_size 1179648 Shape[d_model=768, d_ff=3072] \n",
"INFO:tensorflow:Variable encoder/block_008/layer_001/DenseReluDense/wo/kernel size 2359296 slice_size 1179648 Shape[d_ff=3072, d_model=768] \n",
"INFO:tensorflow:Variable encoder/block_009/layer_000/SelfAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_009/layer_000/SelfAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable encoder/block_009/layer_000/SelfAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_009/layer_000/SelfAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_009/layer_001/DenseReluDense/wi/kernel size 2359296 slice_size 1179648 Shape[d_model=768, d_ff=3072] \n",
"INFO:tensorflow:Variable encoder/block_009/layer_001/DenseReluDense/wo/kernel size 2359296 slice_size 1179648 Shape[d_ff=3072, d_model=768] \n",
"INFO:tensorflow:Variable encoder/block_010/layer_000/SelfAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_010/layer_000/SelfAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable encoder/block_010/layer_000/SelfAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_010/layer_000/SelfAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_010/layer_001/DenseReluDense/wi/kernel size 2359296 slice_size 1179648 Shape[d_model=768, d_ff=3072] \n",
"INFO:tensorflow:Variable encoder/block_010/layer_001/DenseReluDense/wo/kernel size 2359296 slice_size 1179648 Shape[d_ff=3072, d_model=768] \n",
"INFO:tensorflow:Variable encoder/block_011/layer_000/SelfAttention/k size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_011/layer_000/SelfAttention/o size 589824 slice_size 294912 Shape[heads=768, d_model=768] \n",
"INFO:tensorflow:Variable encoder/block_011/layer_000/SelfAttention/q size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_011/layer_000/SelfAttention/v size 589824 slice_size 294912 Shape[d_model=768, heads=768] \n",
"INFO:tensorflow:Variable encoder/block_011/layer_001/DenseReluDense/wi/kernel size 2359296 slice_size 1179648 Shape[d_model=768, d_ff=3072] \n",
"INFO:tensorflow:Variable encoder/block_011/layer_001/DenseReluDense/wo/kernel size 2359296 slice_size 1179648 Shape[d_ff=3072, d_model=768] \n",
"INFO:tensorflow:Variable shared/embedding size 24576000 slice_size 12288000 Shape[vocab=32000, d_model=768] \n",
"INFO:tensorflow:Variable stacked/encoder/block_000/layer_000/SelfAttention/relative_attention_bias size 768 slice_size 384 Shape[stacked=2, heads=12, buckets=32] \n",
"INFO:tensorflow: encoder/block_000/layer_000/SelfAttention/relative_attention_bias\n",
"INFO:tensorflow: decoder/block_000/layer_000/SelfAttention/relative_attention_bias\n",
"INFO:tensorflow:Variable stacked/encoder/block_000/layer_000/layer_norm/scale size 47616 slice_size 47616 Shape[stacked=62, d_model=768] \n",
"INFO:tensorflow: encoder/block_000/layer_000/layer_norm/scale\n",
"INFO:tensorflow: encoder/block_000/layer_001/layer_norm/scale\n",
"INFO:tensorflow: encoder/block_001/layer_000/layer_norm/scale\n",
"INFO:tensorflow: encoder/block_001/layer_001/layer_norm/scale\n",
"INFO:tensorflow: encoder/block_002/layer_000/layer_norm/scale\n",
"INFO:tensorflow: encoder/block_002/layer_001/layer_norm/scale\n",
"INFO:tensorflow: encoder/block_003/layer_000/layer_norm/scale\n",
"INFO:tensorflow: encoder/block_003/layer_001/layer_norm/scale\n",
"INFO:tensorflow: encoder/block_004/layer_000/layer_norm/scale\n",
"INFO:tensorflow: encoder/block_004/layer_001/layer_norm/scale\n",
"INFO:tensorflow: encoder/block_005/layer_000/layer_norm/scale\n",
"INFO:tensorflow: encoder/block_005/layer_001/layer_norm/scale\n",
"INFO:tensorflow: encoder/block_006/layer_000/layer_norm/scale\n",
"INFO:tensorflow: encoder/block_006/layer_001/layer_norm/scale\n",
"INFO:tensorflow: encoder/block_007/layer_000/layer_norm/scale\n",
"INFO:tensorflow: encoder/block_007/layer_001/layer_norm/scale\n",
"INFO:tensorflow: encoder/block_008/layer_000/layer_norm/scale\n",
"INFO:tensorflow: encoder/block_008/layer_001/layer_norm/scale\n",
"INFO:tensorflow: encoder/block_009/layer_000/layer_norm/scale\n",
"INFO:tensorflow: encoder/block_009/layer_001/layer_norm/scale\n",
"INFO:tensorflow: encoder/block_010/layer_000/layer_norm/scale\n",
"INFO:tensorflow: encoder/block_010/layer_001/layer_norm/scale\n",
"INFO:tensorflow: encoder/block_011/layer_000/layer_norm/scale\n",
"INFO:tensorflow: encoder/block_011/layer_001/layer_norm/scale\n",
"INFO:tensorflow: encoder/final_layer_norm/scale\n",
"INFO:tensorflow: decoder/block_000/layer_000/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_000/layer_001/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_000/layer_002/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_001/layer_000/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_001/layer_001/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_001/layer_002/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_002/layer_000/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_002/layer_001/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_002/layer_002/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_003/layer_000/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_003/layer_001/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_003/layer_002/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_004/layer_000/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_004/layer_001/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_004/layer_002/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_005/layer_000/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_005/layer_001/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_005/layer_002/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_006/layer_000/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_006/layer_001/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_006/layer_002/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_007/layer_000/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_007/layer_001/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_007/layer_002/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_008/layer_000/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_008/layer_001/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_008/layer_002/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_009/layer_000/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_009/layer_001/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_009/layer_002/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_010/layer_000/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_010/layer_001/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_010/layer_002/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_011/layer_000/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_011/layer_001/layer_norm/scale\n",
"INFO:tensorflow: decoder/block_011/layer_002/layer_norm/scale\n",
"INFO:tensorflow: decoder/final_layer_norm/scale\n",
"INFO:tensorflow:Trainable Variables count: 195 Total size: 222805248 Total slice_size: 111426432 \n",
"INFO:tensorflow:All Variables count: 203 Total size: 223291904 Total slice_size: 111767680 \n",
"INFO:tensorflow:Counters:\n",
"allconcat: 3.28e+04\n",
" allconcat/0: 3.28e+04\n",
" allconcat/0/reshape_op: 3.28e+04\n",
"allreduce: 2.63e+09\n",
" allreduce/[0]: 9.01e+08\n",
" allreduce/[0]/einsum_op: 8.91e+08\n",
" allreduce/[0]/reduce_op: 1e+07\n",
" allreduce/[1]: 1.73e+09\n",
" allreduce/[1]/einsum_op: 1.72e+09\n",
" allreduce/[1]/reduce_op: 1.22e+06\n",
"einsum: 8.07e+12\n",
"einsum_unique: 8.07e+12\n",
"output: 5.02e+10\n",
" output/AddOperation: 6.37e+09\n",
" output/BinaryOpWithBroadcasting: 8.53e+08\n",
" output/BroadcastOperation: 1.52e+09\n",
" output/Constant: 8\n",
" output/EinsumOperation: 2.58e+10\n",
" output/ImportOperation: 4.98e+05\n",
" output/MinMaxOperation: 1.77e+06\n",
" output/OneHotOperation: 8.4e+08\n",
" output/RandomOperation: 2.73e+07\n",
" output/RangeOperation: 1.28e+03\n",
" output/ReduceOperation: 4.15e+08\n",
" output/ReshapeOperation: 3.07e+09\n",
" output/ScalarAddOperation: 8.96e+08\n",
" output/ScalarMultiplyOperation: 9.86e+08\n",
" output/ShiftOperation: 8.19e+03\n",
" output/SlicewiseOperation: 7.95e+09\n",
" output/StackOperation: 2.99e+06\n",
" output/StackedVariable: 2.99e+06\n",
" output/StopGradient: 5.27e+08\n",
" output/UnstackOperation: 2.99e+06\n",
" output/Variable: 8.91e+08\n",
"output_unique: 3.24e+10\n",
" output_unique/AddOperation: 4.28e+09\n",
" output_unique/BinaryOpWithBroadcasting: 8.1e+08\n",
" output_unique/BroadcastOperation: 1.09e+09\n",
" output_unique/Constant: 1\n",
" output_unique/EinsumOperation: 1.68e+10\n",
" output_unique/ImportOperation: 6.22e+04\n",
" output_unique/MinMaxOperation: 2.22e+05\n",
" output_unique/OneHotOperation: 7.93e+08\n",
" output_unique/RandomOperation: 2.1e+07\n",
" output_unique/RangeOperation: 160\n",
" output_unique/ReduceOperation: 3.72e+08\n",
" output_unique/ReshapeOperation: 2.67e+09\n",
" output_unique/ScalarAddOperation: 2.24e+08\n",
" output_unique/ScalarMultiplyOperation: 4.95e+08\n",
" output_unique/ShiftOperation: 4.1e+03\n",
" output_unique/SlicewiseOperation: 4.09e+09\n",
" output_unique/StackOperation: 5.03e+05\n",
" output_unique/StackedVariable: 5.03e+05\n",
" output_unique/StopGradient: 5.27e+08\n",
" output_unique/UnstackOperation: 5.03e+05\n",
" output_unique/Variable: 2.23e+08\n",
"variables: 2.23e+08\n",
" variables/trainable: 2.23e+08\n",
" variables/untrainable: 4.87e+05\n",
"INFO:tensorflow:Initializing variables from gs://t5-data/pretrained_models/base/model.ckpt-999900:\n",
"INFO:tensorflow:Variables in gs://t5-data/pretrained_models/base/model.ckpt-999900 but not in graph:\n",
"INFO:tensorflow:\n",
"INFO:tensorflow:Variables in graph but not in gs://t5-data/pretrained_models/base/model.ckpt-999900:\n",
"INFO:tensorflow:\n",
"INFO:tensorflow:training_loop marked as finished\n",
"WARNING:tensorflow:Reraising captured error\n"
],
"name": "stdout"
},
{
"output_type": "error",
"ename": "ValueError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-25-bb6835b845e8>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mmixture_or_task_name\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"trivia_all\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mpretrained_model_dir\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mPRETRAINED_DIR\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mfinetune_steps\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mFINETUNE_STEPS\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m )\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/t5/models/mtf_model.py\u001b[0m in \u001b[0;36mfinetune\u001b[0;34m(self, mixture_or_task_name, finetune_steps, pretrained_model_dir, pretrained_checkpoint_step, split)\u001b[0m\n\u001b[1;32m 293\u001b[0m self.train(mixture_or_task_name, checkpoint_step + finetune_steps,\n\u001b[1;32m 294\u001b[0m \u001b[0minit_checkpoint\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpretrained_model_dir\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_ckpt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 295\u001b[0;31m split=split)\n\u001b[0m\u001b[1;32m 296\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 297\u001b[0m def predict(self, input_file, output_file, checkpoint_steps=-1,\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/t5/models/mtf_model.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, mixture_or_task_name, steps, init_checkpoint, split)\u001b[0m\n\u001b[1;32m 235\u001b[0m utils.train_model(self.estimator(vocabulary, init_checkpoint), vocabulary,\n\u001b[1;32m 236\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_sequence_length\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdataset_fn\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 237\u001b[0;31m steps, self._ensemble_inputs, dataset_split=split)\n\u001b[0m\u001b[1;32m 238\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 239\u001b[0m def eval(self, mixture_or_task_name, checkpoint_steps=None, summary_dir=None,\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/mesh_tensorflow/transformer/utils.py\u001b[0m in \u001b[0;36mtrain_model\u001b[0;34m(estimator, vocabulary, sequence_length, batch_size, train_dataset_fn, train_steps, ensemble_inputs, dataset_split)\u001b[0m\n\u001b[1;32m 1446\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1447\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1448\u001b[0;31m \u001b[0mestimator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_fn\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minput_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_steps\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtrain_steps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1449\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1450\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, input_fn, hooks, steps, max_steps, saving_listeners)\u001b[0m\n\u001b[1;32m 3081\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3082\u001b[0m \u001b[0mrendezvous\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrecord_done\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'training_loop'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3083\u001b[0;31m \u001b[0mrendezvous\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mraise_errors\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3084\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3085\u001b[0m def evaluate(self,\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/tpu/error_handling.py\u001b[0m in \u001b[0;36mraise_errors\u001b[0;34m(self, timeout_sec)\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mv1\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogging\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwarn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Reraising captured error'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 150\u001b[0;31m \u001b[0msix\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreraise\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtyp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraceback\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 151\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 152\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtyp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtraceback\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mkept_errors\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/six.py\u001b[0m in \u001b[0;36mreraise\u001b[0;34m(tp, value, tb)\u001b[0m\n\u001b[1;32m 691\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_assertNotRegex\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 692\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 693\u001b[0;31m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 694\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mPY3\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 695\u001b[0m \u001b[0mexec_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmoves\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuiltins\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"exec\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, input_fn, hooks, steps, max_steps, saving_listeners)\u001b[0m\n\u001b[1;32m 3076\u001b[0m \u001b[0msteps\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msteps\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3077\u001b[0m \u001b[0mmax_steps\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmax_steps\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3078\u001b[0;31m saving_listeners=saving_listeners)\n\u001b[0m\u001b[1;32m 3079\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# pylint: disable=broad-except\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3080\u001b[0m \u001b[0mrendezvous\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrecord_error\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'training_loop'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexc_info\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/estimator.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, input_fn, hooks, steps, max_steps, saving_listeners)\u001b[0m\n\u001b[1;32m 347\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 348\u001b[0m \u001b[0msaving_listeners\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_check_listeners_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msaving_listeners\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 349\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_train_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msaving_listeners\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 350\u001b[0m \u001b[0mlogging\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Loss for final step: %s.'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 351\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/estimator.py\u001b[0m in \u001b[0;36m_train_model\u001b[0;34m(self, input_fn, hooks, saving_listeners)\u001b[0m\n\u001b[1;32m 1180\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_train_model_distributed\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msaving_listeners\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1181\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1182\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_train_model_default\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msaving_listeners\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1183\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1184\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_train_model_default\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msaving_listeners\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/estimator.py\u001b[0m in \u001b[0;36m_train_model_default\u001b[0;34m(self, input_fn, hooks, saving_listeners)\u001b[0m\n\u001b[1;32m 1209\u001b[0m \u001b[0mworker_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mextend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_hooks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1210\u001b[0m estimator_spec = self._call_model_fn(features, labels, ModeKeys.TRAIN,\n\u001b[0;32m-> 1211\u001b[0;31m self.config)\n\u001b[0m\u001b[1;32m 1212\u001b[0m \u001b[0mglobal_step_tensor\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mv1\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_global_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1213\u001b[0m return self._train_with_estimator_spec(estimator_spec, worker_hooks,\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py\u001b[0m in \u001b[0;36m_call_model_fn\u001b[0;34m(self, features, labels, mode, config)\u001b[0m\n\u001b[1;32m 2913\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2914\u001b[0m return super(TPUEstimator, self)._call_model_fn(features, labels, mode,\n\u001b[0;32m-> 2915\u001b[0;31m config)\n\u001b[0m\u001b[1;32m 2916\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2917\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_model_fn_for_inference\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeatures\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/estimator.py\u001b[0m in \u001b[0;36m_call_model_fn\u001b[0;34m(self, features, labels, mode, config)\u001b[0m\n\u001b[1;32m 1168\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1169\u001b[0m \u001b[0mlogging\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Calling model_fn.'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1170\u001b[0;31m \u001b[0mmodel_fn_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_model_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1171\u001b[0m \u001b[0mlogging\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Done calling model_fn.'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1172\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py\u001b[0m in \u001b[0;36m_model_fn\u001b[0;34m(features, labels, mode, config, params)\u001b[0m\n\u001b[1;32m 3204\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmode\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mmodel_fn_lib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mModeKeys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTRAIN\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3205\u001b[0m compile_op, loss, host_call, scaffold_fn, training_hooks = (\n\u001b[0;32m-> 3206\u001b[0;31m _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn))\n\u001b[0m\u001b[1;32m 3207\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0membedding_config\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3208\u001b[0m \u001b[0mg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mv1\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_default_graph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py\u001b[0m in \u001b[0;36m_train_on_tpu_system\u001b[0;34m(ctx, model_fn_wrapper, dequeue_fn)\u001b[0m\n\u001b[1;32m 3646\u001b[0m \u001b[0mnum_shards\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_replicas\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3647\u001b[0m \u001b[0moutputs_from_all_shards\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3648\u001b[0;31m device_assignment=ctx.device_assignment)\n\u001b[0m\u001b[1;32m 3649\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3650\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/tpu/tpu.py\u001b[0m in \u001b[0;36msplit_compile_and_shard\u001b[0;34m(computation, inputs, num_shards, input_shard_axes, outputs_from_all_shards, output_shard_axes, infeed_queue, device_assignment, name)\u001b[0m\n\u001b[1;32m 1563\u001b[0m \u001b[0minfeed_queue\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minfeed_queue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1564\u001b[0m \u001b[0mdevice_assignment\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice_assignment\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1565\u001b[0;31m name=name)\n\u001b[0m\u001b[1;32m 1566\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1567\u001b[0m \u001b[0;31m# There must be at least one shard since num_shards > 0.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/tpu/tpu.py\u001b[0m in \u001b[0;36msplit_compile_and_replicate\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 1278\u001b[0m \u001b[0mvscope\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_custom_getter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcustom_getter\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1279\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1280\u001b[0;31m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcomputation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mcomputation_inputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1281\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1282\u001b[0m \u001b[0mvscope\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_use_resource\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msaved_use_resource\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py\u001b[0m in \u001b[0;36mmulti_tpu_train_steps_on_single_shard\u001b[0;34m(replica_id)\u001b[0m\n\u001b[1;32m 3632\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0miterations_per_loop_var\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3633\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msingle_tpu_train_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3634\u001b[0;31m inputs=[0, _INITIAL_LOSS])\n\u001b[0m\u001b[1;32m 3635\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3636\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/tpu/training_loop.py\u001b[0m in \u001b[0;36mwhile_loop\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 176\u001b[0m \u001b[0minputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0marray_ops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconstant\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 177\u001b[0m return control_flow_ops.while_loop(\n\u001b[0;32m--> 178\u001b[0;31m condition_wrapper, body_wrapper, inputs, name=\"\", parallel_iterations=1)\n\u001b[0m\u001b[1;32m 179\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 180\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py\u001b[0m in \u001b[0;36mwhile_loop\u001b[0;34m(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)\u001b[0m\n\u001b[1;32m 2764\u001b[0m \u001b[0mops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_to_collection\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mGraphKeys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mWHILE_CONTEXT\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloop_context\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2765\u001b[0m result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants,\n\u001b[0;32m-> 2766\u001b[0;31m return_same_structure)\n\u001b[0m\u001b[1;32m 2767\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmaximum_iterations\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2768\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py\u001b[0m in \u001b[0;36mBuildLoop\u001b[0;34m(self, pred, body, loop_vars, shape_invariants, return_same_structure)\u001b[0m\n\u001b[1;32m 2246\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_default_graph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_mutation_lock\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# pylint: disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2247\u001b[0m original_body_result, exit_vars = self._BuildLoop(\n\u001b[0;32m-> 2248\u001b[0;31m pred, body, original_loop_vars, loop_vars, shape_invariants)\n\u001b[0m\u001b[1;32m 2249\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2250\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mExit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py\u001b[0m in \u001b[0;36m_BuildLoop\u001b[0;34m(self, pred, body, original_loop_vars, loop_vars, shape_invariants)\u001b[0m\n\u001b[1;32m 2171\u001b[0m expand_composites=True)\n\u001b[1;32m 2172\u001b[0m \u001b[0mpre_summaries\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_collection\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mGraphKeys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_SUMMARY_COLLECTION\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# pylint: disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2173\u001b[0;31m \u001b[0mbody_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbody\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mpacked_vars_for_body\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2174\u001b[0m \u001b[0mpost_summaries\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_collection\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mGraphKeys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_SUMMARY_COLLECTION\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# pylint: disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2175\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mnest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_sequence_or_composite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbody_result\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/tpu/training_loop.py\u001b[0m in \u001b[0;36mbody_wrapper\u001b[0;34m(*inputs)\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[0mdequeue_ops\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 121\u001b[0;31m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbody\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mdequeue_ops\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 122\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0;31m# If the computation only returned one value, make it a tuple.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py\u001b[0m in \u001b[0;36m<lambda>\u001b[0;34m(i, loss)\u001b[0m\n\u001b[1;32m 3631\u001b[0m outputs = training_loop.while_loop(\n\u001b[1;32m 3632\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0miterations_per_loop_var\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3633\u001b[0;31m \u001b[0;32mlambda\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msingle_tpu_train_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3634\u001b[0m inputs=[0, _INITIAL_LOSS])\n\u001b[1;32m 3635\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py\u001b[0m in \u001b[0;36mtrain_step\u001b[0;34m(step)\u001b[0m\n\u001b[1;32m 1751\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1752\u001b[0m estimator_spec = self._verify_estimator_spec(\n\u001b[0;32m-> 1753\u001b[0;31m self._call_model_fn(features, labels))\n\u001b[0m\u001b[1;32m 1754\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_op\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mestimator_spec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mestimator_spec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_op\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1755\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py\u001b[0m in \u001b[0;36m_call_model_fn\u001b[0;34m(self, features, labels, is_export_mode)\u001b[0m\n\u001b[1;32m 2029\u001b[0m \u001b[0m_add_item_to_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_CTX_KEY\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0muser_context\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2030\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2031\u001b[0;31m \u001b[0mestimator_spec\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_model_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2032\u001b[0m if (running_on_cpu and\n\u001b[1;32m 2033\u001b[0m isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)): # pylint: disable=protected-access\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/mesh_tensorflow/transformer/utils.py\u001b[0m in \u001b[0;36mmy_model_fn\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 651\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogging\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"\\n\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msorted\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mglobal_vars\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mckpt_vars\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 652\u001b[0m tf.train.init_from_checkpoint(\n\u001b[0;32m--> 653\u001b[0;31m \u001b[0minit_checkpoint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrestore_vars\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 654\u001b[0m )\n\u001b[1;32m 655\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/checkpoint_utils.py\u001b[0m in \u001b[0;36minit_from_checkpoint\u001b[0;34m(ckpt_dir_or_file, assignment_map)\u001b[0m\n\u001b[1;32m 290\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 291\u001b[0m distribution_strategy_context.get_replica_context().merge_call(\n\u001b[0;32m--> 292\u001b[0;31m init_from_checkpoint_fn)\n\u001b[0m\u001b[1;32m 293\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 294\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py\u001b[0m in \u001b[0;36mmerge_call\u001b[0;34m(self, merge_fn, args, kwargs)\u001b[0m\n\u001b[1;32m 2418\u001b[0m merge_fn = autograph.tf_convert(\n\u001b[1;32m 2419\u001b[0m merge_fn, autograph_ctx.control_status_ctx(), convert_by_default=False)\n\u001b[0;32m-> 2420\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_merge_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmerge_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2421\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2422\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_merge_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmerge_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py\u001b[0m in \u001b[0;36m_merge_call\u001b[0;34m(self, merge_fn, args, kwargs)\u001b[0m\n\u001b[1;32m 2425\u001b[0m distribution_strategy_context._CrossReplicaThreadMode(self._strategy)) # pylint: disable=protected-access\n\u001b[1;32m 2426\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2427\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mmerge_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_strategy\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2428\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2429\u001b[0m \u001b[0m_pop_per_thread_mode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 280\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 281\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mag_ctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mControlStatusCtx\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstatus\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mag_ctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mStatus\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mUNSPECIFIED\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 282\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 283\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 284\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0minspect\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misfunction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0minspect\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mismethod\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/checkpoint_utils.py\u001b[0m in \u001b[0;36m<lambda>\u001b[0;34m(_)\u001b[0m\n\u001b[1;32m 285\u001b[0m \"\"\"\n\u001b[1;32m 286\u001b[0m init_from_checkpoint_fn = lambda _: _init_from_checkpoint(\n\u001b[0;32m--> 287\u001b[0;31m ckpt_dir_or_file, assignment_map)\n\u001b[0m\u001b[1;32m 288\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mdistribution_strategy_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_cross_replica_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 289\u001b[0m \u001b[0minit_from_checkpoint_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/checkpoint_utils.py\u001b[0m in \u001b[0;36m_init_from_checkpoint\u001b[0;34m(ckpt_dir_or_file, assignment_map)\u001b[0m\n\u001b[1;32m 328\u001b[0m \"tensor %s (%s) from checkpoint reader.\" % (\n\u001b[1;32m 329\u001b[0m \u001b[0mvar\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvar\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_shape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 330\u001b[0;31m \u001b[0mtensor_name_in_ckpt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvariable_map\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtensor_name_in_ckpt\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 331\u001b[0m ))\n\u001b[1;32m 332\u001b[0m \u001b[0mvar_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvar\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mValueError\u001b[0m: Shape of variable shared/embedding:0 ((32000, 768)) doesn't match with shape of tensor shared/embedding ([32128, 768]) from checkpoint reader."
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "-pFvyrHmm6Mx"
},
"source": [
"## Expected Results [SPOILER ALERT]"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "i_-7qYemnEHl"
},
"source": [
"Below are the expected accuracies on the Natural Questions (NQ) and TriviaQA validation sets for various model sizes. The full 11B model produces the exact text of the answer 34.5% and 25.1% of the time on TriviaQA and NQ, respectively. The 3B-parameter model, the largest that can be trained on the free Cloud TPU available in Colab, achieves 29.7% and 23.7%, respectively.\n",
"\n",
"These numbers understate the model's real performance, because exact match is an overly strict metric, as you'll see in the examples below. This also helps explain why the model appears to do better on TriviaQA than on NQ: NQ tends to include more long-form answers extracted from the context, which are harder to reproduce word for word.\n",
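"\n",
"To make the strictness concrete, here is a minimal sketch of exact-match scoring in plain Python. The `normalize` and `exact_match_accuracy` helpers and the example answers are hypothetical illustrations, not part of the T5 library. The lowercase-and-trim step only loosely approximates the text normalization applied during preprocessing.\n",
"\n",
"```python\n",
"def normalize(text):\n",
"  # Hypothetical normalization: trim whitespace and lowercase.\n",
"  return text.strip().lower()\n",
"\n",
"\n",
"def exact_match_accuracy(targets, predictions):\n",
"  # Percentage of predictions whose normalized text equals the target's.\n",
"  matches = [normalize(t) == normalize(p)\n",
"             for t, p in zip(targets, predictions)]\n",
"  return 100.0 * sum(matches) / len(matches)\n",
"\n",
"\n",
"# Hypothetical answers: all three predictions are arguably correct.\n",
"targets = ['the beatles', 'china', '32']\n",
"predictions = ['beatles', 'China', 'thirty-two']\n",
"print(exact_match_accuracy(targets, predictions))\n",
"```\n",
"\n",
"Only `'China'` matches after lowercasing, so the score is about 33%, even though every prediction names the right answer.\n",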
"\n",
"Please see our [paper on closed-book QA](https://tiny.cc/t5-qa), where we achieved even better results.\n",
"\n",
"<img src=\"https://storage.googleapis.com/t5-data/assets/t5_trivia_expected.png\">"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "eYeciUZ_D7T2"
},
"source": [
"## Evaluate\n",
"\n",
"We now evaluate on the validation sets of the tasks in our mixture. Accuracy results will be logged and added to the TensorBoard above."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "bz6CJRHzNfd3",
"outputId": "62d01d17-2f5c-47ee-87fc-86d619807929",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 649
}
},
"source": [
"# Use a larger batch size for evaluation, which requires less memory.\n",
"model.batch_size = train_batch_size * 4\n",
"model.eval(\n",
" mixture_or_task_name=\"trivia_all\",\n",
" checkpoint_steps=\"all\"\n",
")"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"INFO:tensorflow:Using config: {'_model_dir': 'gs://wrsaunde_gmail_com_0/t5-trivia/models/3B', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true\n",
"cluster_def {\n",
" job {\n",
" name: \"worker\"\n",
" tasks {\n",
" key: 0\n",
" value: \"10.30.157.114:8470\"\n",
" }\n",
" }\n",
"}\n",
"isolate_session_state: true\n",
", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.30.157.114:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.30.157.114:8470', '_evaluation_master': 'grpc://10.30.157.114:8470', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=100, num_shards=None, num_cores_per_replica=1, per_host_input_for_training=4, tpu_job_name=None, initial_infeed_sleep_secs=None, input_partition_dims=None, eval_training_input_configuration=2, experimental_host_call_every_n_steps=1), '_cluster': <tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver.TPUClusterResolver object at 0x7fab4f0b0a20>}\n",
"INFO:tensorflow:_TPUContext: eval_on_tpu True\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"INFO:absl:Load dataset info from gs://wrsaunde_gmail_com_0/t5-trivia/data/trivia_qa/unfiltered.nocontext/1.1.0\n",
"INFO:absl:Reusing dataset trivia_qa (gs://wrsaunde_gmail_com_0/t5-trivia/data/trivia_qa/unfiltered.nocontext/1.1.0)\n",
"INFO:absl:Constructing tf.data.Dataset for split validation, from gs://wrsaunde_gmail_com_0/t5-trivia/data/trivia_qa/unfiltered.nocontext/1.1.0\n"
],
"name": "stderr"
},
{
"output_type": "error",
"ename": "KeyboardInterrupt",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-23-e650f699e4b4>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m model.eval(\n\u001b[1;32m 4\u001b[0m \u001b[0mmixture_or_task_name\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"trivia_all\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mcheckpoint_steps\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"all\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m )\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/t5/models/mtf_model.py\u001b[0m in \u001b[0;36meval\u001b[0;34m(self, mixture_or_task_name, checkpoint_steps, summary_dir, split)\u001b[0m\n\u001b[1;32m 265\u001b[0m utils.eval_model(self.estimator(vocabulary), vocabulary,\n\u001b[1;32m 266\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_sequence_length\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msplit\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 267\u001b[0;31m self._model_dir, dataset_fn, summary_dir, checkpoint_steps)\n\u001b[0m\u001b[1;32m 268\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 269\u001b[0m def finetune(self, mixture_or_task_name, finetune_steps, pretrained_model_dir,\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/mesh_tensorflow/transformer/utils.py\u001b[0m in \u001b[0;36meval_model\u001b[0;34m(estimator, vocabulary, sequence_length, batch_size, dataset_split, model_dir, eval_dataset_fn, eval_summary_dir, eval_checkpoint_step)\u001b[0m\n\u001b[1;32m 1572\u001b[0m \u001b[0mds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0meval_dataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1573\u001b[0m \u001b[0;31m# Create list of postprocessed text targets\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1574\u001b[0;31m \u001b[0mexamples\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mex\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mex\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtfds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_numpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1575\u001b[0m targets = [\n\u001b[1;32m 1576\u001b[0m eval_dataset.postprocess_fn( # pylint:disable=g-complex-comprehension\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/mesh_tensorflow/transformer/utils.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 1572\u001b[0m \u001b[0mds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0meval_dataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1573\u001b[0m \u001b[0;31m# Create list of postprocessed text targets\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1574\u001b[0;31m \u001b[0mexamples\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mex\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mex\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtfds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_numpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1575\u001b[0m targets = [\n\u001b[1;32m 1576\u001b[0m eval_dataset.postprocess_fn( # pylint:disable=g-complex-comprehension\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_datasets/core/dataset_utils.py\u001b[0m in \u001b[0;36m_graph_dataset_iterator\u001b[0;34m(ds_iter, graph)\u001b[0m\n\u001b[1;32m 174\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 175\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 176\u001b[0;31m \u001b[0;32myield\u001b[0m \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mds_item\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 177\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOutOfRangeError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 956\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 957\u001b[0m result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[0;32m--> 958\u001b[0;31m run_metadata_ptr)\n\u001b[0m\u001b[1;32m 959\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 960\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 1179\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mhandle\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mfeed_dict_tensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1180\u001b[0m results = self._do_run(handle, final_targets, final_fetches,\n\u001b[0;32m-> 1181\u001b[0;31m feed_dict_tensor, options, run_metadata)\n\u001b[0m\u001b[1;32m 1182\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1183\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_do_run\u001b[0;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 1357\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1358\u001b[0m return self._do_call(_run_fn, feeds, fetches, targets, options,\n\u001b[0;32m-> 1359\u001b[0;31m run_metadata)\n\u001b[0m\u001b[1;32m 1360\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1361\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_prun_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeeds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetches\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_do_call\u001b[0;34m(self, fn, *args)\u001b[0m\n\u001b[1;32m 1363\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1364\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1365\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1366\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1367\u001b[0m \u001b[0mmessage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_run_fn\u001b[0;34m(feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[1;32m 1348\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_extend_graph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1349\u001b[0m return self._call_tf_sessionrun(options, feed_dict, fetch_list,\n\u001b[0;32m-> 1350\u001b[0;31m target_list, run_metadata)\n\u001b[0m\u001b[1;32m 1351\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1352\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_call_tf_sessionrun\u001b[0;34m(self, options, feed_dict, fetch_list, target_list, run_metadata)\u001b[0m\n\u001b[1;32m 1441\u001b[0m return tf_session.TF_SessionRun_wrapper(self._session, options, feed_dict,\n\u001b[1;32m 1442\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1443\u001b[0;31m run_metadata)\n\u001b[0m\u001b[1;32m 1444\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1445\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_tf_sessionprun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "92dClA1SWwIx"
},
"source": [
"Let's look at a few random predictions from the validation sets. Note that we measure accuracy based on an *exact match* of the predicted answer and the ground-truth answer. As a result, some of the answers are semantically correct but are counted wrong by the exact match score."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "-FuqHRuvxOct",
"colab": {}
},
"source": [
"import random\n",
"\n",
"def print_random_predictions(task_name, n=10):\n",
" \"\"\"Print n predictions from the validation split of a task.\"\"\"\n",
" # Grab the dataset for this task.\n",
" ds = t5.data.TaskRegistry.get(task_name).get_dataset(\n",
" split=\"validation\",\n",
" sequence_length={\"inputs\": 128, \"targets\": 32},\n",
" shuffle=False)\n",
"\n",
" def _prediction_file_to_ckpt(path):\n",
" \"\"\"Extract the global step from a prediction filename.\"\"\"\n",
" return int(path.split(\"_\")[-2])\n",
"\n",
" # Grab the paths of all logged predictions.\n",
" prediction_files = tf.io.gfile.glob(\n",
" os.path.join(\n",
" MODEL_DIR,\n",
" \"validation_eval/%s_*_predictions\" % task_name))\n",
" # Get most recent prediction file by sorting by their step.\n",
" latest_prediction_file = sorted(\n",
" prediction_files, key=_prediction_file_to_ckpt)[-1]\n",
"\n",
" # Collect (inputs, targets, prediction) from the dataset and predictions file\n",
" results = []\n",
" with tf.io.gfile.GFile(latest_prediction_file) as preds:\n",
" for ex, pred in zip(tfds.as_numpy(ds), preds):\n",
" results.append((tf.compat.as_text(ex[\"inputs_plaintext\"]),\n",
" tf.compat.as_text(ex[\"targets_plaintext\"]),\n",
" pred.strip()))\n",
"\n",
" print(\"<== Random predictions for %s using checkpoint %s ==>\\n\" %\n",
" (task_name, \n",
" _prediction_file_to_ckpt(latest_prediction_file)))\n",
"\n",
"  # Sample up to n distinct examples (without replacement) to display.\n",
"  for inp, tgt, pred in random.sample(results, k=min(n, len(results))):\n",
" print(\"Input:\", inp)\n",
" print(\"Target:\", tgt)\n",
" print(\"Prediction:\", pred)\n",
" print(\"Counted as Correct?\", tgt == pred)\n",
" print()\n",
"\n",
"print_random_predictions(\"triviaqa_context_free\")\n",
"print_random_predictions(\"nq_context_free\")"
],
"execution_count": 0,
"outputs": []
},
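{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The helper above sorts prediction files with `key=_prediction_file_to_ckpt` instead of plain string order. The minimal sketch below shows why that matters. It uses made-up filenames that follow the `<task>_<step>_predictions` pattern globbed above, and its step numbers are hypothetical. A lexicographic sort ranks step 999900 after step 1000000, while sorting on the parsed integer step picks the true latest checkpoint.\n",
"\n",
"```python\n",
"def prediction_file_to_step(path):\n",
"  # The step is the second-to-last underscore-separated field.\n",
"  return int(path.split('_')[-2])\n",
"\n",
"\n",
"# Hypothetical filenames following the '<task>_<step>_predictions' pattern.\n",
"paths = [\n",
"    'validation_eval/nq_context_free_999900_predictions',\n",
"    'validation_eval/nq_context_free_1000000_predictions',\n",
"]\n",
"\n",
"# Plain string sort: '9' > '1', so the older checkpoint comes last.\n",
"print(sorted(paths)[-1])\n",
"# Numeric sort on the parsed step: the newest checkpoint comes last.\n",
"print(sorted(paths, key=prediction_file_to_step)[-1])\n",
"```"
]
},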
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "vbqiq2Ab4PJk"
},
"source": [
"## Predict\n",
"\n",
"Now that we have fine-tuned the model, we can feed T5 arbitrary questions and have it predict the answers!\n",
"\n",
"Initializing the model carries significant overhead, so this cell may take a few minutes to run each time, even though the prediction itself is quite fast.\n",
"\n",
"\n",
"To avoid this overhead, you might consider exporting a `SavedModel` and running it on [Cloud ML Engine](https://cloud.google.com/ml-engine/).\n",
"\n"
]
},
{
"cell_type": "code",
"metadata": {
"cellView": "both",
"colab_type": "code",
"id": "xatHPuCJsPns",
"colab": {}
},
"source": [
"question_1 = \"Where is the Google headquarters located?\" #@param {type:\"string\"}\n",
"question_2 = \"What is the most populous country in the world?\" #@param {type:\"string\"}\n",
"question_3 = \"Who are the 4 members of The Beatles?\" #@param {type:\"string\"}\n",
"question_4 = \"How many teeth do humans have?\" #@param {type:\"string\"}\n",
"\n",
"questions = [question_1, question_2, question_3, question_4]\n",
"\n",
"now = time.time()\n",
"# Write out the supplied questions to text files.\n",
"predict_inputs_path = os.path.join(MODEL_DIR, \"predict_inputs_%d.txt\" % now)\n",
"predict_outputs_path = os.path.join(MODEL_DIR, \"predict_outputs_%d.txt\" % now)\n",
"# Manually apply preprocessing by prepending \"trivia question:\" and lowercasing.\n",
"with tf.io.gfile.GFile(predict_inputs_path, \"w\") as f:\n",
" for q in questions:\n",
" f.write(\"trivia question: %s\\n\" % q.lower())\n",
"\n",
"# Ignore any logging so that we only see the model's answers to the questions.\n",
"with tf_verbosity_level('ERROR'):\n",
" model.batch_size = 8 # Min size for small model on v2-8 with parallelism 1.\n",
" model.predict(\n",
" input_file=predict_inputs_path,\n",
" output_file=predict_outputs_path,\n",
" # Select the most probable output token at each step.\n",
" temperature=0,\n",
" )\n",
"\n",
"# The output filename has the checkpoint step appended, so we glob to find\n",
"# the latest.\n",
"prediction_files = sorted(tf.io.gfile.glob(predict_outputs_path + \"*\"))\n",
"print(\"\\nPredictions using checkpoint %s:\\n\" % prediction_files[-1].split(\"-\")[-1])\n",
"with tf.io.gfile.GFile(prediction_files[-1]) as f:\n",
" for q, a in zip(questions, f):\n",
" if q:\n",
" print(\"Q: \" + q)\n",
" print(\"A: \" + a)\n",
" print()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "xxElhphBZMD5"
},
"source": [
"# Export Model for Serving\n",
"\n",
"As mentioned in the previous section, exporting a [`SavedModel`](https://www.tensorflow.org/guide/saved_model) can be useful for improving performance during inference or allowing your model to be deployed on a variety of platforms (e.g., TFLite, TensorFlow.js, TensorFlow Serving, or TensorFlow Hub).\n",
"\n",
"**Note:** we currently only support exporting a SavedModel that runs on both CPU and GPU, not TPU."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "l_YuEL9FZ-UR"
},
"source": [
"## Export SavedModel\n",
"\n",
"We first export the SavedModel. We set a batch size of 1 for simplicity, but a larger batch size may be more efficient if you want to handle multiple requests per call.\n",
"\n",
"For the 3B and 11B models, the export takes approximately 30-45 minutes."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "eWu8lbh3aHjc",
"colab": {}
},
"source": [
"export_dir = os.path.join(MODEL_DIR, \"export\")\n",
"\n",
"model.batch_size = 1 # make one prediction per call\n",
"saved_model_path = model.export(\n",
" export_dir,\n",
" checkpoint_step=-1, # use most recent\n",
" beam_size=1, # no beam search\n",
" temperature=1.0, # sample according to predicted distribution\n",
")\n",
"print(\"Model saved to:\", saved_model_path)"
],
"execution_count": 0,
"outputs": []
},
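{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"### Optional: Inspect the Exported Signature\n",
"\n",
"Before loading the model, it can help to confirm which inputs and outputs the exported `serving_default` signature expects. The cell below is a minimal sketch using the `saved_model_cli` tool that ships with TensorFlow; it assumes the tool is on the runtime's `PATH` (as it normally is in Colab) and normalizes `saved_model_path` to `str` in case `model.export` returns it as `bytes`. Judging from the loading code in the next section, we expect a string input named `input` and a string output named `outputs`."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"import tensorflow as tf\n",
"\n",
"# Normalize the export path to str (it may be returned as bytes) for the shell.\n",
"export_path = tf.compat.as_str(saved_model_path)\n",
"# Print only the serving signature; this reads the SavedModel proto.\n",
"!saved_model_cli show --dir {export_path} --tag_set serve --signature_def serving_default"
],
"execution_count": 0,
"outputs": []
},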
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "PZ8WXpFkaNTP"
},
"source": [
"## Load SavedModel\n",
"\n",
"One way to test our model is to load it, either in eager mode or in a TF 1.x graph-mode session, so that we can predict from it repeatedly without reloading the graph and weights each time.\n",
"\n",
"We pay the loading overhead once here; it shouldn't take more than a few minutes.\n",
"\n",
"\n",
"### Optional: Switch to GPU Runtime\n",
"\n",
"Changing the runtime type to GPU in the `Runtime` menu above before loading the SavedModel will speed up inference by running it on the GPU instead of the CPU.\n",
"\n"
]
},
{
"cell_type": "code",
"metadata": {
"cellView": "form",
"colab_type": "code",
"id": "LyBuc4WH-cyB",
"colab": {}
},
"source": [
"#@title Optional: Run this cell to re-initialize if you switched to GPU runtime.\n",
"%tensorflow_version 2.x\n",
"!pip install tensorflow-text\n",
"from google.colab import auth\n",
"auth.authenticate_user()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "1TpeMFGhaN7r",
"colab": {}
},
"source": [
"import tensorflow as tf\n",
"import tensorflow_text # Required to run exported model.\n",
"\n",
"def load_predict_fn(model_path):\n",
"  if tf.executing_eagerly():\n",
"    print(\"Loading SavedModel in eager mode.\")\n",
"    imported = tf.saved_model.load(model_path, [\"serve\"])\n",
"    return lambda x: imported.signatures['serving_default'](tf.constant(x))['outputs'].numpy()\n",
"  else:\n",
"    print(\"Loading SavedModel in tf 1.x graph mode.\")\n",
"    tf.compat.v1.reset_default_graph()\n",
"    sess = tf.compat.v1.Session()\n",
"    meta_graph_def = tf.compat.v1.saved_model.load(sess, [\"serve\"], model_path)\n",
"    signature_def = meta_graph_def.signature_def[\"serving_default\"]\n",
"    return lambda x: sess.run(\n",
"        fetches=signature_def.outputs[\"outputs\"].name,\n",
"        feed_dict={signature_def.inputs[\"input\"].name: x}\n",
"    )\n",
"\n",
"predict_fn = load_predict_fn(saved_model_path)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "YbGC8xefaYtV"
},
"source": [
"## Predict\n",
"\n",
"We can now call `predict_fn` repeatedly with different inputs and get results relatively quickly."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "3WA0BYI9abgv",
"colab": {}
},
"source": [
"def answer(question):\n",
" return predict_fn([question])[0].decode('utf-8')\n",
"\n",
"for question in [\"trivia question: where is the google headquarters?\",\n",
" \"trivia question: what is the most populous country in the world?\",\n",
" \"trivia question: who are the 4 members of the beatles?\",\n",
" \"trivia question: how many teeth do humans have?\"]:\n",
" print(answer(question))"
],
"execution_count": 0,
"outputs": []
},
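{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"The exported model expects inputs formatted like the ones written out before `model.predict` earlier: the question lowercased and prefixed with `trivia question: `. The cell below sketches a small helper that applies this formatting so you can pass raw questions to `answer`; `to_model_input` is a hypothetical name of our own, not part of the T5 library."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"source": [
"def to_model_input(question):\n",
"  # Hypothetical helper mirroring the manual preprocessing used with\n",
"  # model.predict above: lowercase and prepend \"trivia question: \".\n",
"  return \"trivia question: \" + question.lower()\n",
"\n",
"assert to_model_input(\"Who wrote Hamlet?\") == \"trivia question: who wrote hamlet?\"\n",
"print(answer(to_model_input(\"What is the capital of France?\")))"
],
"execution_count": 0,
"outputs": []
},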
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "ePXPEQhDafmV"
},
"source": [
"## Deploy SavedModel\n",
"\n",
"You can now deploy your SavedModel for serving (e.g., with [TensorFlow Serving](https://www.tensorflow.org/tfx/tutorials/serving/rest_simple))."
]
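},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"### Optional: Query a TensorFlow Serving Endpoint\n",
"\n",
"Once the model is served, clients can query it over HTTP. The cell below is a minimal sketch of a client for the TensorFlow Serving REST API that uses only the Python standard library. It assumes (hypothetically) that the server listens on `localhost:8501` and that the model was deployed under the name `t5`; adjust `SERVING_URL` to match your deployment. Since we exported with a batch size of 1, it sends one question per request, using the row (`instances`) request format and the same `trivia question: ` preprocessing as above."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"colab": {}
},
"execution_count": 0,
"outputs": [],
"source": [
"import json\n",
"import urllib.request\n",
"\n",
"# Assumed endpoint: adjust host, port, and model name (\"t5\") to your deployment.\n",
"SERVING_URL = \"http://localhost:8501/v1/models/t5:predict\"\n",
"\n",
"def build_request_body(question):\n",
"  # Row-format request with one preprocessed instance.\n",
"  return json.dumps({\n",
"      \"signature_name\": \"serving_default\",\n",
"      \"instances\": [\"trivia question: \" + question.lower()],\n",
"  })\n",
"\n",
"def query_server(question, url=SERVING_URL):\n",
"  request = urllib.request.Request(\n",
"      url,\n",
"      data=build_request_body(question).encode(\"utf-8\"),\n",
"      headers={\"Content-Type\": \"application/json\"},\n",
"  )\n",
"  with urllib.request.urlopen(request) as response:\n",
"    # Row-format responses have the form {\"predictions\": [...]}.\n",
"    return json.loads(response.read().decode(\"utf-8\"))[\"predictions\"][0]\n",
"\n",
"print(build_request_body(\"Who wrote Hamlet?\"))\n",
"# print(query_server(\"Who wrote Hamlet?\"))  # Requires a running server."
]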
}
]
}