Created
February 5, 2020 16:24
-
-
Save SharanSMenon/89137be312ad3bb2d9402ed1659a42ca to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "name": "question-classification-cnn.ipynb", | |
| "provenance": [] | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| }, | |
| "accelerator": "GPU" | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "bFGMr3CQGFJz", | |
| "colab_type": "code", | |
| "colab": {} | |
| }, | |
| "source": [ | |
| "import torch\n", | |
| "from torchtext import data\n", | |
| "from torchtext import datasets\n", | |
| "import random\n", | |
| "\n", | |
| "SEED = 1234\n", | |
| "\n", | |
| "# Seed Python's and PyTorch's generators up front so repeated runs match.\n", | |
| "random.seed(SEED)\n", | |
| "torch.manual_seed(SEED)\n", | |
| "torch.backends.cudnn.deterministic = True\n", | |
| "\n", | |
| "# TEXT tokenizes each question with spaCy; LABEL carries the question's class.\n", | |
| "TEXT = data.Field(tokenize = 'spacy')\n", | |
| "LABEL = data.LabelField()\n", | |
| "\n", | |
| "# fine_grained=True requests the two-level labels such as 'DESC:def'.\n", | |
| "train_data, test_data = datasets.TREC.splits(TEXT, LABEL, fine_grained=True)\n", | |
| "\n", | |
| "# random.seed() returns None, so it cannot serve as random_state itself;\n", | |
| "# hand split() the state of the generator that was seeded above instead.\n", | |
| "train_data, valid_data = train_data.split(random_state = random.getstate())" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "oDMpsS7oGH2V", | |
| "colab_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 34 | |
| }, | |
| "outputId": "c8814a4d-15fb-4b9f-a5d8-d8243096c3d8" | |
| }, | |
| "source": [ | |
| "vars(train_data[-1])" | |
| ], | |
| "execution_count": 4, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "{'label': 'DESC:def', 'text': ['What', 'is', 'a', 'Cartesian', 'Diver', '?']}" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 4 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "laidwNkvGRwX", | |
| "colab_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 51 | |
| }, | |
| "outputId": "bc61e40f-82c2-4f3a-f091-07e61f507608" | |
| }, | |
| "source": [ | |
| "MAX_VOCAB_SIZE = 45_000\n", | |
| "\n", | |
| "# Label vocabulary, built from the training split only.\n", | |
| "LABEL.build_vocab(train_data)\n", | |
| "\n", | |
| "# Token vocabulary, capped at MAX_VOCAB_SIZE and paired with 100-d GloVe vectors;\n", | |
| "# torch.Tensor.normal_ is supplied as the initializer for unknown-token vectors.\n", | |
| "TEXT.build_vocab(\n", | |
| "    train_data,\n", | |
| "    max_size = MAX_VOCAB_SIZE,\n", | |
| "    vectors = \"glove.6B.100d\",\n", | |
| "    unk_init = torch.Tensor.normal_,\n", | |
| ")" | |
| ], | |
| "execution_count": 5, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| ".vector_cache/glove.6B.zip: 862MB [06:32, 2.20MB/s] \n", | |
| "100%|█████████▉| 398009/400000 [00:16<00:00, 24050.78it/s]" | |
| ], | |
| "name": "stderr" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "07mXbobyGghA", | |
| "colab_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 54 | |
| }, | |
| "outputId": "e86b6e4b-a093-4015-cd66-0915b7e75bfd" | |
| }, | |
| "source": [ | |
| "print(LABEL.vocab.stoi)" | |
| ], | |
| "execution_count": 6, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "defaultdict(<function _default_unk_index at 0x7f7fb0960158>, {'HUM:ind': 0, 'LOC:other': 1, 'DESC:def': 2, 'NUM:count': 3, 'DESC:manner': 4, 'DESC:desc': 5, 'NUM:date': 6, 'DESC:reason': 7, 'HUM:gr': 8, 'ENTY:other': 9, 'ENTY:cremat': 10, 'LOC:country': 11, 'LOC:city': 12, 'ENTY:animal': 13, 'ENTY:dismed': 14, 'ENTY:food': 15, 'ENTY:termeq': 16, 'ABBR:exp': 17, 'NUM:money': 18, 'NUM:period': 19, 'LOC:state': 20, 'ENTY:event': 21, 'ENTY:sport': 22, 'HUM:desc': 23, 'NUM:other': 24, 'ENTY:product': 25, 'ENTY:color': 26, 'ENTY:techmeth': 27, 'ENTY:substance': 28, 'ENTY:word': 29, 'ENTY:veh': 30, 'NUM:dist': 31, 'HUM:title': 32, 'NUM:perc': 33, 'LOC:mount': 34, 'ENTY:body': 35, 'ABBR:abb': 36, 'ENTY:lang': 37, 'ENTY:instru': 38, 'ENTY:plant': 39, 'NUM:code': 40, 'NUM:temp': 41, 'NUM:volsize': 42, 'NUM:weight': 43, 'ENTY:letter': 44, 'ENTY:symbol': 45, 'ENTY:religion': 46, 'NUM:ord': 47, 'NUM:speed': 48, 'ENTY:currency': 49})\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "gKQbNTskIQGp", | |
| "colab_type": "code", | |
| "colab": {} | |
| }, | |
| "source": [ | |
| "BATCH_SIZE = 64\n", | |
| "\n", | |
| "# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.\n", | |
| "if torch.cuda.is_available():\n", | |
| "    device = torch.device('cuda')\n", | |
| "else:\n", | |
| "    device = torch.device('cpu')\n", | |
| "\n", | |
| "# One iterator per split, all sharing the same batch size and target device.\n", | |
| "train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(\n", | |
| "    (train_data, valid_data, test_data),\n", | |
| "    batch_size = BATCH_SIZE,\n", | |
| "    device = device,\n", | |
| ")" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "PjCKXbh5IW-D", | |
| "colab_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 34 | |
| }, | |
| "outputId": "a41eb153-a353-494d-b2f3-e4133b6b4c77" | |
| }, | |
| "source": [ | |
| "device" | |
| ], | |
| "execution_count": 8, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "device(type='cuda')" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 8 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "FjdHxNeGIX3E", | |
| "colab_type": "code", | |
| "colab": {} | |
| }, | |
| "source": [ | |
| "import torch.nn as nn\n", | |
| "import torch.nn.functional as F" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "ZqBeymR9IaKS", | |
| "colab_type": "code", | |
| "colab": {} | |
| }, | |
| "source": [ | |
| "class CNN(nn.Module):\n", | |
| "    \"\"\"Convolutional text classifier over word embeddings.\n", | |
| "\n", | |
| "    One Conv2d per entry of ``filter_sizes`` slides over the embedded tokens. Each\n", | |
| "    feature map is max-pooled over the sequence; the pooled vectors are concatenated,\n", | |
| "    passed through dropout and projected to ``output_dim`` logits.\n", | |
| "\n", | |
| "    Args:\n", | |
| "        vocab_size: number of rows in the embedding table.\n", | |
| "        embedding_dim: width of each embedding vector (also the kernel width).\n", | |
| "        n_filters: output channels of every convolution.\n", | |
| "        filter_sizes: iterable of kernel heights, i.e. n-gram lengths in tokens.\n", | |
| "        output_dim: number of classes.\n", | |
| "        dropout: dropout probability applied to the concatenated features.\n", | |
| "        pad_idx: vocabulary index of the padding token.\n", | |
| "    \"\"\"\n", | |
| "\n", | |
| "    def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim,\n", | |
| "                 dropout, pad_idx):\n", | |
| "        super().__init__()\n", | |
| "        # pad_idx used to be accepted but ignored; registering it as padding_idx\n", | |
| "        # keeps the padding row out of the gradient updates.\n", | |
| "        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx = pad_idx)\n", | |
| "        self.convs = nn.ModuleList([\n", | |
| "            nn.Conv2d(in_channels = 1,\n", | |
| "                      out_channels = n_filters,\n", | |
| "                      kernel_size = (fs, embedding_dim))\n", | |
| "            for fs in filter_sizes\n", | |
| "        ])\n", | |
| "        self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)\n", | |
| "        self.dropout = nn.Dropout(dropout)\n", | |
| "\n", | |
| "    def forward(self, text):\n", | |
| "        \"\"\"Return class logits for a batch of token indices.\n", | |
| "\n", | |
| "        ``text`` is a 2-D index tensor whose two axes are swapped before the\n", | |
| "        embedding lookup; the returned logits carry the second input axis first.\n", | |
| "        The first input axis must be at least as long as the largest filter size,\n", | |
| "        otherwise the convolution cannot be applied.\n", | |
| "        \"\"\"\n", | |
| "        text = text.permute(1, 0)\n", | |
| "        # Add a singleton channel axis so Conv2d sees one \"image\" per sentence.\n", | |
| "        embedded = self.embedding(text).unsqueeze(1)\n", | |
| "        conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs]\n", | |
| "        # Max over the whole remaining sequence axis of every feature map.\n", | |
| "        pooled = [F.max_pool1d(fmap, fmap.shape[2]).squeeze(2) for fmap in conved]\n", | |
| "        cat = self.dropout(torch.cat(pooled, dim = 1))\n", | |
| "        return self.fc(cat)" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "YLUvwoEmIkZ4", | |
| "colab_type": "code", | |
| "colab": {} | |
| }, | |
| "source": [ | |
| "# Sizes derived from the vocabularies built from the training split.\n", | |
| "INPUT_DIM = len(TEXT.vocab)\n", | |
| "OUTPUT_DIM = len(LABEL.vocab)\n", | |
| "PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]\n", | |
| "\n", | |
| "# Tunable hyperparameters; EMBEDDING_DIM matches the 100-d GloVe vectors.\n", | |
| "EMBEDDING_DIM = 100\n", | |
| "N_FILTERS = 100\n", | |
| "FILTER_SIZES = [2, 3, 4]\n", | |
| "DROPOUT = 0.5\n", | |
| "\n", | |
| "model = CNN(\n", | |
| "    vocab_size = INPUT_DIM,\n", | |
| "    embedding_dim = EMBEDDING_DIM,\n", | |
| "    n_filters = N_FILTERS,\n", | |
| "    filter_sizes = FILTER_SIZES,\n", | |
| "    output_dim = OUTPUT_DIM,\n", | |
| "    dropout = DROPOUT,\n", | |
| "    pad_idx = PAD_IDX,\n", | |
| ")" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "5Asdh-TZImgX", | |
| "colab_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 187 | |
| }, | |
| "outputId": "0cb48953-4a8c-4051-8058-ffec3d6e4fdf" | |
| }, | |
| "source": [ | |
| "print(model)" | |
| ], | |
| "execution_count": 12, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "CNN(\n", | |
| " (embedding): Embedding(7503, 100)\n", | |
| " (convs): ModuleList(\n", | |
| " (0): Conv2d(1, 100, kernel_size=(2, 100), stride=(1, 1))\n", | |
| " (1): Conv2d(1, 100, kernel_size=(3, 100), stride=(1, 1))\n", | |
| " (2): Conv2d(1, 100, kernel_size=(4, 100), stride=(1, 1))\n", | |
| " )\n", | |
| " (fc): Linear(in_features=300, out_features=50, bias=True)\n", | |
| " (dropout): Dropout(p=0.5, inplace=False)\n", | |
| ")\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "A6YCAJKmIn5i", | |
| "colab_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 34 | |
| }, | |
| "outputId": "75060a64-db33-4a6f-ddda-7f4e0c1520e0" | |
| }, | |
| "source": [ | |
| "def count_parameters(model):\n", | |
| "    \"\"\"Return how many scalar values the model's trainable parameters hold.\"\"\"\n", | |
| "    total = 0\n", | |
| "    for param in model.parameters():\n", | |
| "        # Frozen parameters (requires_grad == False) are left out of the count.\n", | |
| "        if param.requires_grad:\n", | |
| "            total += param.numel()\n", | |
| "    return total\n", | |
| "\n", | |
| "print(f'The model has {count_parameters(model):,} trainable parameters')" | |
| ], | |
| "execution_count": 13, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "The model has 855,650 trainable parameters\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "Q3OmgQkzIpuq", | |
| "colab_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 136 | |
| }, | |
| "outputId": "0a84a28c-4188-47e1-83ff-1e89de0c4133" | |
| }, | |
| "source": [ | |
| "# Vectors attached to the TEXT vocabulary when it was built.\n", | |
| "pretrained_embeddings = TEXT.vocab.vectors\n", | |
| "\n", | |
| "# copy_ overwrites the embedding table in place; being the last expression of\n", | |
| "# the cell, the updated weight tensor is also what gets displayed.\n", | |
| "model.embedding.weight.data.copy_(pretrained_embeddings)" | |
| ], | |
| "execution_count": 14, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[-0.1117, -0.4966, 0.1631, ..., 1.2647, -0.2753, -0.1325],\n", | |
| " [-0.8555, -0.7208, 1.3755, ..., 0.0825, -1.1314, 0.3997],\n", | |
| " [ 0.1638, 0.6046, 1.0789, ..., -0.3140, 0.1844, 0.3624],\n", | |
| " ...,\n", | |
| " [-0.3110, -0.3398, 1.0308, ..., 0.5317, 0.2836, -0.0640],\n", | |
| " [ 0.0091, 0.2810, 0.7356, ..., -0.7508, 0.8967, -0.7631],\n", | |
| " [ 0.4306, 1.2011, 0.0873, ..., 0.8817, 0.3722, 0.3458]])" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 14 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "ZmzT0-vNIrga", | |
| "colab_type": "code", | |
| "colab": {} | |
| }, | |
| "source": [ | |
| "UNK_IDX = TEXT.vocab.stoi[TEXT.unk_token]\n", | |
| "\n", | |
| "# Reset the rows of the unknown and padding tokens to all zeros.\n", | |
| "for special_idx in (UNK_IDX, PAD_IDX):\n", | |
| "    model.embedding.weight.data[special_idx] = torch.zeros(EMBEDDING_DIM)" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "miLY5rvXIs_s", | |
| "colab_type": "code", | |
| "colab": {} | |
| }, | |
| "source": [ | |
| "import torch.optim as optim\n", | |
| "\n", | |
| "criterion = nn.CrossEntropyLoss()\n", | |
| "\n", | |
| "# Place the model and the loss on the target device before the optimizer is\n", | |
| "# built, so the optimizer is constructed from the parameters it will update.\n", | |
| "model = model.to(device)\n", | |
| "criterion = criterion.to(device)\n", | |
| "\n", | |
| "optimizer = optim.Adam(model.parameters())" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "dC5vB7DYIuyV", | |
| "colab_type": "code", | |
| "colab": {} | |
| }, | |
| "source": [ | |
| "def categorical_accuracy(preds, y):\n", | |
| "    \"\"\"Return the fraction of rows in ``preds`` whose arg-max equals ``y``.\n", | |
| "\n", | |
| "    Args:\n", | |
| "        preds: 2-D tensor of per-class scores, one row per example.\n", | |
| "        y: 1-D tensor of target class indices, one per row of ``preds``.\n", | |
| "\n", | |
| "    Returns:\n", | |
| "        A scalar float tensor on the same device as ``preds``; call ``.item()``\n", | |
| "        to obtain a Python float.\n", | |
| "    \"\"\"\n", | |
| "    top_class = preds.argmax(dim = 1)\n", | |
| "    # Averaging the 0/1 hits stays on the batch's device and avoids dividing by\n", | |
| "    # a freshly allocated CPU tensor.\n", | |
| "    return top_class.eq(y).float().mean()" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "r_ING1JEIyIx", | |
| "colab_type": "code", | |
| "colab": {} | |
| }, | |
| "source": [ | |
| "def train(model, iterator, optimizer, criterion):\n", | |
| "    \"\"\"Run one optimisation pass over ``iterator``.\n", | |
| "\n", | |
| "    Returns a ``(loss, accuracy)`` pair, each the plain average of the\n", | |
| "    per-batch values (batches are not weighted by their size).\n", | |
| "    \"\"\"\n", | |
| "    model.train()\n", | |
| "    running_loss, running_acc = 0, 0\n", | |
| "    for batch in iterator:\n", | |
| "        optimizer.zero_grad()\n", | |
| "        logits = model(batch.text)\n", | |
| "        batch_loss = criterion(logits, batch.label)\n", | |
| "        batch_acc = categorical_accuracy(logits, batch.label)\n", | |
| "        batch_loss.backward()\n", | |
| "        optimizer.step()\n", | |
| "        running_loss += batch_loss.item()\n", | |
| "        running_acc += batch_acc.item()\n", | |
| "    n_batches = len(iterator)\n", | |
| "    return running_loss / n_batches, running_acc / n_batches" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "TtfacWc-I4D4", | |
| "colab_type": "code", | |
| "colab": {} | |
| }, | |
| "source": [ | |
| "def evaluate(model, iterator, criterion):\n", | |
| "    \"\"\"Score ``model`` on ``iterator`` without updating any weights.\n", | |
| "\n", | |
| "    Returns a ``(loss, accuracy)`` pair, each the plain average of the\n", | |
| "    per-batch values (batches are not weighted by their size).\n", | |
| "    \"\"\"\n", | |
| "    model.eval()\n", | |
| "    running_loss, running_acc = 0, 0\n", | |
| "    # Gradients are not needed for scoring, so skip tracking them.\n", | |
| "    with torch.no_grad():\n", | |
| "        for batch in iterator:\n", | |
| "            logits = model(batch.text)\n", | |
| "            running_loss += criterion(logits, batch.label).item()\n", | |
| "            running_acc += categorical_accuracy(logits, batch.label).item()\n", | |
| "    n_batches = len(iterator)\n", | |
| "    return running_loss / n_batches, running_acc / n_batches" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "s1gqEvvjI9OY", | |
| "colab_type": "code", | |
| "colab": {} | |
| }, | |
| "source": [ | |
| "import time\n", | |
| "\n", | |
| "def epoch_time(start_time, end_time):\n", | |
| "    \"\"\"Split the span between two timestamps (in seconds) into whole minutes and seconds.\"\"\"\n", | |
| "    elapsed = end_time - start_time\n", | |
| "    minutes = int(elapsed / 60)\n", | |
| "    # Whatever is left after removing the whole minutes, truncated to an int.\n", | |
| "    seconds = int(elapsed - 60 * minutes)\n", | |
| "    return minutes, seconds" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "dlTp7h7PI-yw", | |
| "colab_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 170 | |
| }, | |
| "outputId": "68c7b04c-9402-4d3f-a018-c6c31cbc5c1a" | |
| }, | |
| "source": [ | |
| "N_EPOCHS = 18\n", | |
| "\n", | |
| "# Lowest validation loss seen so far; any real loss beats the initial infinity.\n", | |
| "best_valid_loss = float('inf')\n", | |
| "\n", | |
| "for epoch in range(N_EPOCHS):\n", | |
| "    start_time = time.time()\n", | |
| "\n", | |
| "    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)\n", | |
| "    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)\n", | |
| "\n", | |
| "    end_time = time.time()\n", | |
| "    epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n", | |
| "\n", | |
| "    # Checkpoint only when the validation loss improves.\n", | |
| "    if valid_loss < best_valid_loss:\n", | |
| "        best_valid_loss = valid_loss\n", | |
| "        torch.save(model.state_dict(), 'tut5-model.pt')\n", | |
| "\n", | |
| "    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')\n", | |
| "    print(f'\\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')\n", | |
| "    print(f'\\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')" | |
| ], | |
| "execution_count": 24, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "Epoch: 01 | Epoch Time: 0m 0s\n", | |
| "\tTrain Loss: 0.166 | Train Acc: 96.91%\n", | |
| "\t Val. Loss: 0.924 | Val. Acc: 76.97%\n", | |
| "Epoch: 02 | Epoch Time: 0m 0s\n", | |
| "\tTrain Loss: 0.138 | Train Acc: 97.78%\n", | |
| "\t Val. Loss: 0.926 | Val. Acc: 76.14%\n", | |
| "Epoch: 03 | Epoch Time: 0m 0s\n", | |
| "\tTrain Loss: 0.122 | Train Acc: 97.84%\n", | |
| "\t Val. Loss: 0.918 | Val. Acc: 76.74%\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "qV4_NbJ_JCn6", | |
| "colab_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 34 | |
| }, | |
| "outputId": "3d954780-057d-41e5-ffd9-eefaf50904ef" | |
| }, | |
| "source": [ | |
| "# Restore the checkpoint with the lowest validation loss. map_location keeps the\n", | |
| "# load working when the file was written on a different device than the current one.\n", | |
| "model.load_state_dict(torch.load('tut5-model.pt', map_location = device))\n", | |
| "\n", | |
| "test_loss, test_acc = evaluate(model, test_iterator, criterion)\n", | |
| "\n", | |
| "print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')" | |
| ], | |
| "execution_count": 25, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "Test Loss: 0.888 | Test Acc: 75.78%\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "xdjKc2K6JSLl", | |
| "colab_type": "code", | |
| "colab": {} | |
| }, | |
| "source": [ | |
| "import spacy\n", | |
| "nlp = spacy.load('en')\n", | |
| "\n", | |
| "def predict_class(model, sentence, min_len = 4):\n", | |
| "    \"\"\"Return the predicted class index for a single raw sentence.\n", | |
| "\n", | |
| "    Args:\n", | |
| "        model: classifier that accepts a 2-D tensor of token indices.\n", | |
| "        sentence: raw text; it is tokenized with the module-level spaCy pipeline.\n", | |
| "        min_len: minimum number of tokens fed to the model; shorter inputs are\n", | |
| "            right-padded. The default of 4 equals the largest filter size\n", | |
| "            configured in this notebook.\n", | |
| "\n", | |
| "    Returns:\n", | |
| "        The arg-max class index as a Python int (look it up in LABEL.vocab.itos).\n", | |
| "    \"\"\"\n", | |
| "    model.eval()\n", | |
| "    tokens = [tok.text for tok in nlp.tokenizer(sentence)]\n", | |
| "    missing = min_len - len(tokens)\n", | |
| "    if missing > 0:\n", | |
| "        # Use the field's own pad token so padding stays in sync with PAD_IDX.\n", | |
| "        tokens += [TEXT.pad_token] * missing\n", | |
| "    indexed = [TEXT.vocab.stoi[t] for t in tokens]\n", | |
| "    # Shape the indices as a column: one sentence, i.e. a batch of size 1.\n", | |
| "    tensor = torch.LongTensor(indexed).to(device).unsqueeze(1)\n", | |
| "    # Inference only: no gradient bookkeeping needed.\n", | |
| "    with torch.no_grad():\n", | |
| "        preds = model(tensor)\n", | |
| "    return preds.argmax(dim = 1).item()" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "JUV5rLcIJUVh", | |
| "colab_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 34 | |
| }, | |
| "outputId": "cd43a816-433e-4d71-ffb3-974f75f6cda1" | |
| }, | |
| "source": [ | |
| "pred_class = predict_class(model, \"Who is Keyser Söze?\")\n", | |
| "print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')" | |
| ], | |
| "execution_count": 27, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "Predicted class is: 23 = HUM:desc\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "lfqNlaTkJWf0", | |
| "colab_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 34 | |
| }, | |
| "outputId": "b90ee9eb-4bc6-4a01-95cf-6de89bae0036" | |
| }, | |
| "source": [ | |
| "pred_class = predict_class(model, \"How many minutes are in six hundred and eighteen hours?\")\n", | |
| "print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')" | |
| ], | |
| "execution_count": 28, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "Predicted class is: 3 = NUM:count\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "HNPZ59e_JbXx", | |
| "colab_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 34 | |
| }, | |
| "outputId": "93383363-e477-40fa-f303-b202995fd00c" | |
| }, | |
| "source": [ | |
| "pred_class = predict_class(model, \"What continent is Bulgaria in?\")\n", | |
| "print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')" | |
| ], | |
| "execution_count": 29, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "Predicted class is: 1 = LOC:other\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "M3l56FQvJdLq", | |
| "colab_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 34 | |
| }, | |
| "outputId": "a37dd7bd-ad6e-47ff-ae9d-c0c633139256" | |
| }, | |
| "source": [ | |
| "pred_class = predict_class(model, \"What does WYSIWYG stand for?\")\n", | |
| "print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')" | |
| ], | |
| "execution_count": 30, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "Predicted class is: 17 = ABBR:exp\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "swn5acMVJe_U", | |
| "colab_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 34 | |
| }, | |
| "outputId": "28548e80-f2e0-46b1-fc68-e8451be62070" | |
| }, | |
| "source": [ | |
| "pred_class = predict_class(model, \"Where is New York City?\")\n", | |
| "print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')" | |
| ], | |
| "execution_count": 31, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "Predicted class is: 1 = LOC:other\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "Gm65tDkyJj7Q", | |
| "colab_type": "code", | |
| "colab": {} | |
| }, | |
| "source": [ | |
| "" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| } | |
| ] | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment