Created
November 5, 2020 07:31
-
-
Save o8r/95f239ef597c87c9b4ff8967a0c47670 to your computer and use it in GitHub Desktop.
MNIST example using PyTorch-Lightning
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# MNIST\n", | |
| "\n", | |
| "pytorch_lightningを用いたMNISTの実装" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## データのセットアップ" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### LightningDataModule派生クラス" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import torch\n", | |
| "import pytorch_lightning as pl\n", | |
| "from torch.utils.data import random_split, DataLoader\n", | |
| "\n", | |
| "# Note - you must have torchvision installed for this example\n", | |
| "from torchvision.datasets import MNIST\n", | |
| "from torchvision import transforms\n", | |
| "\n", | |
| "\n", | |
| "class MNISTDataModule(pl.LightningDataModule):\n", | |
| "\n", | |
| " def __init__(self, data_dir: str = './'):\n", | |
| " super().__init__()\n", | |
| " self.data_dir = data_dir\n", | |
| " self.transform = transforms.Compose([\n", | |
| " transforms.ToTensor(),\n", | |
| " transforms.Normalize((0.1307,), (0.3081,))\n", | |
| " ])\n", | |
| "\n", | |
| " # self.dims is returned when you call dm.size()\n", | |
| " # Setting default dims here because we know them.\n", | |
| " # Could optionally be assigned dynamically in dm.setup()\n", | |
| " self.dims = (1, 28, 28)\n", | |
| "\n", | |
| " def prepare_data(self):\n", | |
| " # download\n", | |
| " MNIST(self.data_dir, train=True, download=True)\n", | |
| " MNIST(self.data_dir, train=False, download=True)\n", | |
| "\n", | |
| " def setup(self, stage=None):\n", | |
| "\n", | |
| " # Assign train/val datasets for use in dataloaders\n", | |
| " if stage == 'fit' or stage is None:\n", | |
| " mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", | |
| " self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", | |
| "\n", | |
| " # Optionally...\n", | |
| " # self.dims = tuple(self.mnist_train[0][0].shape)\n", | |
| "\n", | |
| " # Assign test dataset for use in dataloader(s)\n", | |
| " if stage == 'test' or stage is None:\n", | |
| " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", | |
| "\n", | |
| " # Optionally...\n", | |
| " # self.dims = tuple(self.mnist_test[0][0].shape)\n", | |
| "\n", | |
| " def train_dataloader(self):\n", | |
| " return DataLoader(self.mnist_train, batch_size=32)\n", | |
| "\n", | |
| " def val_dataloader(self):\n", | |
| " return DataLoader(self.mnist_val, batch_size=32)\n", | |
| "\n", | |
| " def test_dataloader(self):\n", | |
| " return DataLoader(self.mnist_test, batch_size=32)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### データセットの取得(ダウンロード)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "dm = MNISTDataModule('./data/')\n", | |
| "\n", | |
| "# Download dataset\n", | |
| "dm.prepare_data()\n", | |
| "\n", | |
| "# Build the train/validation datasets (stage 'fit') used by the dataloaders\n", | |
| "dm.setup('fit')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### データの確認\n", | |
| "\n", | |
| "``xb``(学習用DataLoaderから取り出した最初のミニバッチ)の各要素は,28x28サイズの1チャネル(グレースケール)の画像で,形状``(1, 28, 28)``のテンソルとして保持されている.``imshow``で表示するために,チャネル軸を除いた28x28の形式に変換して表示する." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "<matplotlib.image.AxesImage at 0x1ab3e444640>" | |
| ] | |
| }, | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| }, | |
| { | |
| "data": { | |
| "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAANdUlEQVR4nO3db6xU9Z3H8c9n2RIJ7QP8gyLFbUWNWzeRbsiNSYnpSkpQH0BN3JSYDRrsraZKGzFZog8wwQeNsaAmWnPrn1LTpampVEgaV0JItDEqaFiFkla2YVv0CkvU1D6RBb774B6aW7zzm+ucmTlz+b5fyc3MnO+cc74Z+XjOzPnzc0QIwJnv75puAEB/EHYgCcIOJEHYgSQIO5DE3/dzZbb56R/osYjwRNNrbdltL7X9O9sHbK+tsywAveVOj7Pbnibp95K+IemQpF2SVkTEbwvzsGUHeqwXW/YhSQci4g8RcUzSzyUtq7E8AD1UJ+xzJf1p3OtD1bS/YXvY9m7bu2usC0BNdX6gm2hX4VO76RExImlEYjceaFKdLfshSfPGvf6ipPfqtQOgV+qEfZekS21/2fZ0Sd+StLU7bQHoto534yPiuO07JP2npGmSnoqIfV3rDEBXdXzoraOV8Z0d6LmenFQDYOog7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTR8fjskmT7oKSPJZ2QdDwiFnajKQDdVyvslX+JiKNdWA6AHmI3HkiibthD0ou237A9PNEbbA/b3m17d811AajBEdH5zPaFEfGe7dmStku6MyJeKry/85UBmJSI8ETTa23ZI+K96vGIpC2ShuosD0DvdBx22zNtf+HUc0lLJO3tVmMAuqvOr/HnS9pi+9Ry/iMiXuhKVwC6rtZ39s+8Mr6zAz3Xk+/sAKYOwg4kQdiBJAg7kARhB5LoxoUwaFh1+HNCV1xxRXHehx56qFhfvHhxsb53b/nUiqVLl7asvfvuu8V50V1s2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCa56mwIuu+yyYv2RRx5pWVuyZEm32/lMnn322Za12267rdayr7nmmmJ99uzZLWulviTp6NGpew9VrnoDkiPsQBKEHUiCsANJEHYgCcIOJEHYgSQ4zj4AzjvvvGL9lVdeKdbnz5/fsnbixInivPv27SvWP/roo2L96quvLtZLPvzww2L95MmTxfo555zT8bqHhyccreyvnnjiiY6X3TSOswPJEXYgCcIOJEHYgSQIO5AEYQeSIOxAEtw3vg8uv/zyYv2FF8ojXV900UXFeulciYcffrg47913312sX3vttcX60NBQsX7WWWe1rM2aNas4b13Hjh1rWXv55Zd7uu5B1HbLbvsp20ds7x037Wzb222/Uz329r8agNomsxv/E0mnD+uxVtKOiLhU0o7qNYAB1jbsEfGSpA9Om7xM0qbq+SZJy7vbFoBu6/Q7+/kRMSpJETFqu+XNvmwPSyqfiAyg53r+A11EjEgakbgQBmhSp4feDtueI0nV45HutQSgFzoN+1ZJK6vnKyU93512APRK2+vZbW+W9HVJ50o6LGmdpF9J+oWkiyT9UdKNEXH6j3gTLSvlbvxzzz1XrC9fvrzW8jdt2tSydsstt9Radjvt7ktfGh/+/vvvL847Y8aMjno65emnn25ZW7VqVa1lD7JW17O3/c4eEStalBbX6ghAX3G6LJAEYQeSIOxAEoQdSIKwA0lwiWsfXHLJJbXm37BhQ7F+77331lp+HS+++GKx/sknn7SsTZs2rda6Dxw4UKw3+bkMIrbsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEx9n74M477yzW2w17/MADDxTrpWPZvdbuWPnq1atb1qZPn16c
9/jx48X6mjVrivX333+/WM+GLTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJNH2VtJdXVnSW0mfydqdA9BuSOiS9evXF+vr1q3reNlnsla3kmbLDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJcJwdRbfeemuxvnHjxmJ95syZLWvPPPNMcd6bb765WO/nv92ppOPj7Lafsn3E9t5x0+6z/a7tPdXfdd1sFkD3TWY3/ieSlk4wfWNELKj+ft3dtgB0W9uwR8RLkj7oQy8AeqjOD3R32H6r2s2f1epNtodt77a9u8a6ANTUadh/JGm+pAWSRiX9sNUbI2IkIhZGxMIO1wWgCzoKe0QcjogTEXFS0o8lDXW3LQDd1lHYbc8Z9/Kbkva2ei+AwdD2OLvtzZK+LulcSYclrateL5AUkg5K+k5EjLZdGcfZB87cuXOL9VdffbXW/KV7t99www211o2JtTrO3naQiIhYMcHkJ2t3BKCvOF0WSIKwA0kQdiAJwg4kQdiBJBiyObl2l7C2O7TWzk033dSyxqG1/mLLDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJcJz9DLdo0aJife3atbWWv23btmJ9165dtZaP7mHLDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJMGTzGWDevHktazt37izOe/HFFxfro6PlO4RfeeWVxfrRo0eLdXRfx0M2AzgzEHYgCcIOJEHYgSQIO5AEYQeSIOxAElzPPgVceOGFxfqDDz7YstbuOPrx48eL9fXr1xfrHEefOtpu2W3Ps73T9n7b+2x/r5p+tu3ttt+pHmf1vl0AnZrMbvxxSWsi4h8lXSXpu7a/ImmtpB0RcamkHdVrAAOqbdgjYjQi3qyefyxpv6S5kpZJ2lS9bZOk5T3qEUAXfKbv7La/JOmrkl6TdH5EjEpj/0OwPbvFPMOShmv2CaCmSYfd9ucl/VLS9yPiz/aE59p/SkSMSBqplsGFMEBDJnXozfbnNBb0n0XEc9Xkw7bnVPU5ko70pkUA3dB2y+6xTfiTkvZHxIZxpa2SVkr6QfX4fE86hIaGhor1G2+8seNlP/bYY8X6448/3vGyMVgmsxv/NUn/Jult23uqafdoLOS/sL1K0h8ldf4vDkDPtQ17RPxGUqsv6Iu72w6AXuF0WSAJwg4kQdiBJAg7kARhB5LgEtcBcNVVVxXr7Y6Fl+zdu7dY37x5c8fLxtTClh1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkuA4ex/MmDGjWL/99tuL9QsuuKBYLw27vWXLluK8r732WrGOMwdbdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IwqVjtF1fWdIRYa6//vpifdu2bbWWXxp2efr06bWWjaknIia8GzRbdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IYjLjs8+T9FNJF0g6KWkkIh62fZ+kb0v63+qt90TEr3vV6FS2YMGCni7/rrvu6unycWaYzM0rjktaExFv2v6CpDdsb69qGyPiwd61B6BbJjM++6ik0er5x7b3S5rb68YAdNdn+s5u+0uSvirp1L2M7rD9lu2nbM9qMc+w7d22d9drFUAdkw677c9L+qWk70fEnyX9SNJ8SQs0tuX/4UTzRcRIRCyMiIX12wXQqUmF3fbnNBb0n0XEc5IUEYcj4kREnJT0Y0lDvWsTQF1tw27bkp6UtD8iNoybPmfc274pqTxcKIBGtb3E1fYiSS9Leltjh94k6R5JKzS2Cx+SDkr6TvVjXmlZKS9xXbiw/A3m9ddfL9ZXr15drD/66KMta/28hBmDodUlrpP5Nf43kiaamWPqwBTCGXRAEoQdSIKwA0kQdiAJwg4kQdiBJLiVNHCG4VbSQHKEHUiCsANJEHYgCcIOJEHYgSQIO5DEZO4u201HJf3PuNfnVtMG0aD2Nqh9SfTWqW729g+tCn09qeZTK7d3D+q96Qa1t0HtS6K3TvWrN3bjgSQIO5BE02EfaXj9JYPa
26D2JdFbp/rSW6Pf2QH0T9NbdgB9QtiBJBoJu+2ltn9n+4DttU300Irtg7bftr2n6fHpqjH0jtjeO27a2ba3236nepxwjL2GervP9rvVZ7fH9nUN9TbP9k7b+23vs/29anqjn12hr758bn3/zm57mqTfS/qGpEOSdklaERG/7WsjLdg+KGlhRDR+AobtqyX9RdJPI+KfqmkPSPogIn5Q/Y9yVkT8+4D0dp+kvzQ9jHc1WtGc8cOMS1ou6WY1+NkV+vpX9eFza2LLPiTpQET8ISKOSfq5pGUN9DHwIuIlSR+cNnmZpE3V800a+8fSdy16GwgRMRoRb1bPP5Z0apjxRj+7Ql990UTY50r607jXhzRY472HpBdtv2F7uOlmJnD+qWG2qsfZDfdzurbDePfTacOMD8xn18nw53U1EfaJ7o81SMf/vhYR/yzpWknfrXZXMTmTGsa7XyYYZnwgdDr8eV1NhP2QpHnjXn9R0nsN9DGhiHivejwiaYsGbyjqw6dG0K0ejzTcz18N0jDeEw0zrgH47Joc/ryJsO+SdKntL9ueLulbkrY20Men2J5Z/XAi2zMlLdHgDUW9VdLK6vlKSc832MvfGJRhvFsNM66GP7vGhz+PiL7/SbpOY7/I/7eke5vooUVfF0v6r+pvX9O9Sdqssd26/9PYHtEqSedI2iHpnerx7AHq7RmNDe39lsaCNaeh3hZp7KvhW5L2VH/XNf3ZFfrqy+fG6bJAEpxBByRB2IEkCDuQBGEHkiDsQBKEHUiCsANJ/D8oJzrpWp+8SQAAAABJRU5ErkJggg==\n", | |
| "text/plain": [ | |
| "<Figure size 432x288 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": { | |
| "needs_background": "light" | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "from matplotlib import pyplot\n", | |
| "%matplotlib inline\n", | |
| "\n", | |
| "xb, _ = iter(dm.train_dataloader()).next()\n", | |
| "pyplot.imshow(xb[0].reshape((28, 28)), cmap='gray')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## NNモデルの構築\n", | |
| "\n", | |
| "単純な線形モデル\n", | |
| "$y = w \\times x + b$\n", | |
| "のニューラルネットワークを構築する.\n", | |
| "\n", | |
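| "MNISTは10クラス分類問題なので,$y$は10次元である.入力$x$は28x28の画像を784次元に平坦化したベクトル,$w$は10x784の重み行列,$b$は10次元のバイアスである.損失には$y$にsoftmaxを適用した交差エントロピー(``F.cross_entropy``)を用いるので,このモデルは多クラスロジスティック回帰に相当する." | |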
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from torch import nn, optim\n", | |
| "import torch.nn.functional as F\n", | |
| "import pytorch_lightning as pl\n", | |
| "\n", | |
| "class Mnist_Logistic(pl.LightningModule):\n", | |
| " \n", | |
| " def __init__(self):\n", | |
| " super().__init__()\n", | |
| " self.lin = nn.Linear(784, 10)\n", | |
| " self.train_acc = pl.metrics.Accuracy()\n", | |
| " self.valid_acc = pl.metrics.Accuracy()\n", | |
| " self.test_acc = pl.metrics.Accuracy()\n", | |
| " \n", | |
| " def forward(self, x):\n", | |
| " return self.lin(x.view(-1, 784))\n", | |
| " \n", | |
| " def _loss(self, batch, batch_idx):\n", | |
| " x, y = batch\n", | |
| " y_hat = self(x)\n", | |
| " loss = F.cross_entropy(y_hat, y)\n", | |
| " return y_hat, loss\n", | |
| " \n", | |
| " def training_step(self, batch, batch_idx):\n", | |
| " _, y = batch\n", | |
| " y_hat, loss = self._loss(batch, batch_idx)\n", | |
| " self.train_acc(y_hat, y)\n", | |
| " self.log('train_acc', self.train_acc, on_step=True, on_epoch=False)\n", | |
| " return loss\n", | |
| " \n", | |
| " def training_epoch_end(self, outs):\n", | |
| " self.log('train_acc_epoch', self.train_acc.compute())\n", | |
| " \n", | |
| " def validation_step(self, batch, batch_idx):\n", | |
| " _, y = batch\n", | |
| " y_hat, loss = self._loss(batch, batch_idx)\n", | |
| " self.valid_acc(y_hat, y)\n", | |
| " self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)\n", | |
| " return loss\n", | |
| " \n", | |
| " def test_step(self, batch, batch_idx):\n", | |
| " _, y = batch\n", | |
| " y_hat, loss = self._loss(batch, batch_idx)\n", | |
| " self.test_acc(y_hat, y)\n", | |
| " self.log('test_acc', self.test_acc, on_step=True, on_epoch=True)\n", | |
| " return loss\n", | |
| " \n", | |
| " def configure_optimizers(self):\n", | |
| " optimizer = optim.SGD(self.parameters(), lr=0.5)\n", | |
| " return optimizer" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## 学習" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "#lr = 0.5 # learning rate\n", | |
| "epochs = 2" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Fresh, untrained logistic-regression model.\n", | |
| "model = Mnist_Logistic()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "GPU available: True, used: True\n", | |
| "TPU available: False, using: 0 TPU cores\n", | |
| "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", | |
| "\n", | |
| " | Name | Type | Params\n", | |
| "---------------------------------------\n", | |
| "0 | lin | Linear | 7 K \n", | |
| "1 | train_acc | Accuracy | 0 \n", | |
| "2 | valid_acc | Accuracy | 0 \n", | |
| "3 | test_acc | Accuracy | 0 \n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "e4b05ab537114d4e846c4a2ae7dfbab5", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "1" | |
| ] | |
| }, | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "trainer = pl.Trainer(max_epochs=epochs, auto_select_gpus=True, gpus=1)\n", | |
| "trainer.fit(model, dm)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Train accuracy: tensor(0.8741, device='cuda:0')\n", | |
| "Validation accuracy: tensor(0.8694, device='cuda:0')\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "print('Train accuracy: ', model.train_acc.compute())\n", | |
| "print('Validation accuracy: ', model.valid_acc.compute())" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## EMNISTデータセットによるテスト" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### EMNISTデータセットのダウンロード" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from torchvision.datasets import EMNIST\n", | |
| "\n", | |
| "emnist = EMNIST('./data/emnist/',\n", | |
| " split='digits',\n", | |
| " train=False,\n", | |
| " download=True,\n", | |
| " transform = transforms.Compose([\n", | |
| " transforms.ToTensor(),\n", | |
| " transforms.Normalize((0.1307,), (0.3081,))\n", | |
| " ]))\n", | |
| "\n", | |
| "test_loader = DataLoader(emnist, batch_size=64, shuffle=False)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "9b2464674a6a41168b34f13be7a0a0a4", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| "HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "--------------------------------------------------------------------------------\n", | |
| "DATALOADER:0 TEST RESULTS\n", | |
| "{'test_acc': tensor(0.1761, device='cuda:0'),\n", | |
| " 'test_acc_epoch': tensor(0.1761, device='cuda:0'),\n", | |
| " 'train_acc': tensor(0.9167, device='cuda:0'),\n", | |
| " 'train_acc_epoch': tensor(0.8741, device='cuda:0'),\n", | |
| " 'valid_acc': tensor(0.8694, device='cuda:0'),\n", | |
| " 'valid_acc_epoch': tensor(0.8694, device='cuda:0')}\n", | |
| "--------------------------------------------------------------------------------\n", | |
| "\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[{'train_acc': 0.9166666865348816,\n", | |
| " 'valid_acc_epoch': 0.8694000244140625,\n", | |
| " 'valid_acc': 0.8694000244140625,\n", | |
| " 'train_acc_epoch': 0.8741454482078552,\n", | |
| " 'test_acc_epoch': 0.17612500488758087,\n", | |
| " 'test_acc': 0.17612500488758087}]" | |
| ] | |
| }, | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Evaluate the MNIST-trained model on the EMNIST digits test split; this runs\n", | |
| "# test_step and logs 'test_acc'.\n", | |
| "trainer.test(model, test_loader)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "torch", | |
| "language": "python", | |
| "name": "torch" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.8.5" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 4 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment