@o8r
Created November 5, 2020 07:31
MNIST example using PyTorch-Lightning
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# MNIST\n",
"\n",
"pytorch_lightningを用いたMNISTの実装"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## データのセットアップ"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### LightningDataModule派生クラス"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pytorch_lightning as pl\n",
"from torch.utils.data import random_split, DataLoader\n",
"\n",
"# Note - you must have torchvision installed for this example\n",
"from torchvision.datasets import MNIST\n",
"from torchvision import transforms\n",
"\n",
"\n",
"class MNISTDataModule(pl.LightningDataModule):\n",
"\n",
" def __init__(self, data_dir: str = './'):\n",
" super().__init__()\n",
" self.data_dir = data_dir\n",
" self.transform = transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.1307,), (0.3081,))\n",
" ])\n",
"\n",
" # self.dims is returned when you call dm.size()\n",
" # Setting default dims here because we know them.\n",
" # Could optionally be assigned dynamically in dm.setup()\n",
" self.dims = (1, 28, 28)\n",
"\n",
" def prepare_data(self):\n",
" # download\n",
" MNIST(self.data_dir, train=True, download=True)\n",
" MNIST(self.data_dir, train=False, download=True)\n",
"\n",
" def setup(self, stage=None):\n",
"\n",
" # Assign train/val datasets for use in dataloaders\n",
" if stage == 'fit' or stage is None:\n",
" mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n",
" self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n",
"\n",
" # Optionally...\n",
" # self.dims = tuple(self.mnist_train[0][0].shape)\n",
"\n",
" # Assign test dataset for use in dataloader(s)\n",
" if stage == 'test' or stage is None:\n",
" self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n",
"\n",
" # Optionally...\n",
" # self.dims = tuple(self.mnist_test[0][0].shape)\n",
"\n",
" def train_dataloader(self):\n",
" return DataLoader(self.mnist_train, batch_size=32)\n",
"\n",
" def val_dataloader(self):\n",
" return DataLoader(self.mnist_val, batch_size=32)\n",
"\n",
" def test_dataloader(self):\n",
" return DataLoader(self.mnist_test, batch_size=32)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### データセットの取得(ダウンロード)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"dm = MNISTDataModule('./data/')\n",
"\n",
"# Download dataset\n",
"dm.prepare_data()\n",
"\n",
"# Get the first batch\n",
"dm.setup('fit')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### データの確認\n",
"\n",
"``x_train``の各要素は,28x28サイズの1チャネル(グレースケール)の画像だが,28x28=784要素の1次元配列のように保持されている.28x28の形式に変換して表示する."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x1ab3e444640>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAANdUlEQVR4nO3db6xU9Z3H8c9n2RIJ7QP8gyLFbUWNWzeRbsiNSYnpSkpQH0BN3JSYDRrsraZKGzFZog8wwQeNsaAmWnPrn1LTpampVEgaV0JItDEqaFiFkla2YVv0CkvU1D6RBb774B6aW7zzm+ucmTlz+b5fyc3MnO+cc74Z+XjOzPnzc0QIwJnv75puAEB/EHYgCcIOJEHYgSQIO5DE3/dzZbb56R/osYjwRNNrbdltL7X9O9sHbK+tsywAveVOj7Pbnibp95K+IemQpF2SVkTEbwvzsGUHeqwXW/YhSQci4g8RcUzSzyUtq7E8AD1UJ+xzJf1p3OtD1bS/YXvY9m7bu2usC0BNdX6gm2hX4VO76RExImlEYjceaFKdLfshSfPGvf6ipPfqtQOgV+qEfZekS21/2fZ0Sd+StLU7bQHoto534yPiuO07JP2npGmSnoqIfV3rDEBXdXzoraOV8Z0d6LmenFQDYOog7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTR8fjskmT7oKSPJZ2QdDwiFnajKQDdVyvslX+JiKNdWA6AHmI3HkiibthD0ou237A9PNEbbA/b3m17d811AajBEdH5zPaFEfGe7dmStku6MyJeKry/85UBmJSI8ETTa23ZI+K96vGIpC2ShuosD0DvdBx22zNtf+HUc0lLJO3tVmMAuqvOr/HnS9pi+9Ry/iMiXuhKVwC6rtZ39s+8Mr6zAz3Xk+/sAKYOwg4kQdiBJAg7kARhB5LoxoUwaFh1+HNCV1xxRXHehx56qFhfvHhxsb53b/nUiqVLl7asvfvuu8V50V1s2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCa56mwIuu+yyYv2RRx5pWVuyZEm32/lMnn322Za12267rdayr7nmmmJ99uzZLWulviTp6NGpew9VrnoDkiPsQBKEHUiCsANJEHYgCcIOJEHYgSQ4zj4AzjvvvGL9lVdeKdbnz5/fsnbixInivPv27SvWP/roo2L96quvLtZLPvzww2L95MmTxfo555zT8bqHhyccreyvnnjiiY6X3TSOswPJEXYgCcIOJEHYgSQIO5AEYQeSIOxAEtw3vg8uv/zyYv2FF8ojXV900UXFeulciYcffrg47913312sX3vttcX60NBQsX7WWWe1rM2aNas4b13Hjh1rWXv55Zd7uu5B1HbLbvsp20ds7x037Wzb222/Uz329r8agNomsxv/E0mnD+uxVtKOiLhU0o7qNYAB1jbsEfGSpA9Om7xM0qbq+SZJy7vbFoBu6/Q7+/kRMSpJETFqu+XNvmwPSyqfiAyg53r+A11EjEgakbgQBmhSp4feDtueI0nV45HutQSgFzoN+1ZJK6vnKyU93512APRK2+vZbW+W9HVJ50o6LGmdpF9J+oWkiyT9UdKNEXH6j3gTLSvlbvxzzz1XrC9fvrzW8jdt2tSydsstt9Radjvt7ktfGh/+/vvvL847Y8aMjno65emnn25ZW7VqVa1lD7JW17O3/c4eEStalBbX6ghAX3G6LJAEYQeSIOxAEoQdSIKwA0lwiWsfXHLJJbXm37BhQ7F+77331lp+HS+++GKx/sknn7SsTZs2rda6Dxw4UKw3+bkMIrbsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEx9n74M477yzW2w17/MADDxTrpWPZvdbuWPnq1atb1qZPn16c9/
jx48X6mjVrivX333+/WM+GLTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJNH2VtJdXVnSW0mfydqdA9BuSOiS9evXF+vr1q3reNlnsla3kmbLDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJcJwdRbfeemuxvnHjxmJ95syZLWvPPPNMcd6bb765WO/nv92ppOPj7Lafsn3E9t5x0+6z/a7tPdXfdd1sFkD3TWY3/ieSlk4wfWNELKj+ft3dtgB0W9uwR8RLkj7oQy8AeqjOD3R32H6r2s2f1epNtodt77a9u8a6ANTUadh/JGm+pAWSRiX9sNUbI2IkIhZGxMIO1wWgCzoKe0QcjogTEXFS0o8lDXW3LQDd1lHYbc8Z9/Kbkva2ei+AwdD2OLvtzZK+LulcSYclrateL5AUkg5K+k5EjLZdGcfZB87cuXOL9VdffbXW/KV7t99www211o2JtTrO3naQiIhYMcHkJ2t3BKCvOF0WSIKwA0kQdiAJwg4kQdiBJBiyObl2l7C2O7TWzk033dSyxqG1/mLLDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJcJz9DLdo0aJife3atbWWv23btmJ9165dtZaP7mHLDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJMGTzGWDevHktazt37izOe/HFFxfro6PlO4RfeeWVxfrRo0eLdXRfx0M2AzgzEHYgCcIOJEHYgSQIO5AEYQeSIOxAElzPPgVceOGFxfqDDz7YstbuOPrx48eL9fXr1xfrHEefOtpu2W3Ps73T9n7b+2x/r5p+tu3ttt+pHmf1vl0AnZrMbvxxSWsi4h8lXSXpu7a/ImmtpB0RcamkHdVrAAOqbdgjYjQi3qyefyxpv6S5kpZJ2lS9bZOk5T3qEUAXfKbv7La/JOmrkl6TdH5EjEpj/0OwPbvFPMOShmv2CaCmSYfd9ucl/VLS9yPiz/aE59p/SkSMSBqplsGFMEBDJnXozfbnNBb0n0XEc9Xkw7bnVPU5ko70pkUA3dB2y+6xTfiTkvZHxIZxpa2SVkr6QfX4fE86hIaGhor1G2+8seNlP/bYY8X6448/3vGyMVgmsxv/NUn/Jult23uqafdoLOS/sL1K0h8ldf4vDkDPtQ17RPxGUqsv6Iu72w6AXuF0WSAJwg4kQdiBJAg7kARhB5LgEtcBcNVVVxXr7Y6Fl+zdu7dY37x5c8fLxtTClh1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkuA4ex/MmDGjWL/99tuL9QsuuKBYLw27vWXLluK8r732WrGOMwdbdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IwqVjtF1fWdIRYa6//vpifdu2bbWWXxp2efr06bWWjaknIia8GzRbdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IYjLjs8+T9FNJF0g6KWkkIh62fZ+kb0v63+qt90TEr3vV6FS2YMGCni7/rrvu6unycWaYzM0rjktaExFv2v6CpDdsb69qGyPiwd61B6BbJjM++6ik0er5x7b3S5rb68YAdNdn+s5u+0uSvirp1L2M7rD9lu2nbM9qMc+w7d22d9drFUAdkw677c9L+qWk70fEnyX9SNJ8SQs0tuX/4UTzRcRIRCyMiIX12wXQqUmF3fbnNBb0n0XEc5IUEYcj4kREnJT0Y0lDvWsTQF1tw27bkp6UtD8iNoybPmfc274pqTxcKIBGtb3E1fYiSS9Leltjh94k6R5JKzS2Cx+SDkr6TvVjXmlZKS9xXbiw/A3m9ddfL9ZXr15drD/66KMta/28hBmDodUlrpP5Nf43kiaamWPqwBTCGXRAEoQdSIKwA0kQdiAJwg4kQdiBJLiVNHCG4VbSQHKEHUiCsANJEHYgCcIOJEHYgSQIO5DEZO4u201HJf3PuNfnVtMG0aD2Nqh9SfTWqW729g+tCn09qeZTK7d3D+q96Qa1t0HtS6K3TvWrN3bjgSQIO5BE02EfaXj9JYPa26
D2JdFbp/rSW6Pf2QH0T9NbdgB9QtiBJBoJu+2ltn9n+4DttU300Irtg7bftr2n6fHpqjH0jtjeO27a2ba3236nepxwjL2GervP9rvVZ7fH9nUN9TbP9k7b+23vs/29anqjn12hr758bn3/zm57mqTfS/qGpEOSdklaERG/7WsjLdg+KGlhRDR+AobtqyX9RdJPI+KfqmkPSPogIn5Q/Y9yVkT8+4D0dp+kvzQ9jHc1WtGc8cOMS1ou6WY1+NkV+vpX9eFza2LLPiTpQET8ISKOSfq5pGUN9DHwIuIlSR+cNnmZpE3V800a+8fSdy16GwgRMRoRb1bPP5Z0apjxRj+7Ql990UTY50r607jXhzRY472HpBdtv2F7uOlmJnD+qWG2qsfZDfdzurbDePfTacOMD8xn18nw53U1EfaJ7o81SMf/vhYR/yzpWknfrXZXMTmTGsa7XyYYZnwgdDr8eV1NhP2QpHnjXn9R0nsN9DGhiHivejwiaYsGbyjqw6dG0K0ejzTcz18N0jDeEw0zrgH47Joc/ryJsO+SdKntL9ueLulbkrY20Men2J5Z/XAi2zMlLdHgDUW9VdLK6vlKSc832MvfGJRhvFsNM66GP7vGhz+PiL7/SbpOY7/I/7eke5vooUVfF0v6r+pvX9O9Sdqssd26/9PYHtEqSedI2iHpnerx7AHq7RmNDe39lsaCNaeh3hZp7KvhW5L2VH/XNf3ZFfrqy+fG6bJAEpxBByRB2IEkCDuQBGEHkiDsQBKEHUiCsANJ/D8oJzrpWp+8SQAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from matplotlib import pyplot\n",
"%matplotlib inline\n",
"\n",
"xb, _ = iter(dm.train_dataloader()).next()\n",
"pyplot.imshow(xb[0].reshape((28, 28)), cmap='gray')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## NNモデルの構築\n",
"\n",
"単純な線形モデル\n",
"$y = w \\times x + b$\n",
"のニューラルネットワークを構築する.\n",
"\n",
"MNISTは10クラス分類問題なので,$y$は10次元である."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from torch import nn, optim\n",
"import torch.nn.functional as F\n",
"import pytorch_lightning as pl\n",
"\n",
"class Mnist_Logistic(pl.LightningModule):\n",
" \n",
" def __init__(self):\n",
" super().__init__()\n",
" self.lin = nn.Linear(784, 10)\n",
" self.train_acc = pl.metrics.Accuracy()\n",
" self.valid_acc = pl.metrics.Accuracy()\n",
" self.test_acc = pl.metrics.Accuracy()\n",
" \n",
" def forward(self, x):\n",
" return self.lin(x.view(-1, 784))\n",
" \n",
" def _loss(self, batch, batch_idx):\n",
" x, y = batch\n",
" y_hat = self(x)\n",
" loss = F.cross_entropy(y_hat, y)\n",
" return y_hat, loss\n",
" \n",
" def training_step(self, batch, batch_idx):\n",
" _, y = batch\n",
" y_hat, loss = self._loss(batch, batch_idx)\n",
" self.train_acc(y_hat, y)\n",
" self.log('train_acc', self.train_acc, on_step=True, on_epoch=False)\n",
" return loss\n",
" \n",
" def training_epoch_end(self, outs):\n",
" self.log('train_acc_epoch', self.train_acc.compute())\n",
" \n",
" def validation_step(self, batch, batch_idx):\n",
" _, y = batch\n",
" y_hat, loss = self._loss(batch, batch_idx)\n",
" self.valid_acc(y_hat, y)\n",
" self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)\n",
" return loss\n",
" \n",
" def test_step(self, batch, batch_idx):\n",
" _, y = batch\n",
" y_hat, loss = self._loss(batch, batch_idx)\n",
" self.test_acc(y_hat, y)\n",
" self.log('test_acc', self.test_acc, on_step=True, on_epoch=True)\n",
" return loss\n",
" \n",
" def configure_optimizers(self):\n",
" optimizer = optim.SGD(self.parameters(), lr=0.5)\n",
" return optimizer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 学習"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"#lr = 0.5 # learning rate\n",
"epochs = 2"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"model = Mnist_Logistic()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: True, used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
"\n",
" | Name | Type | Params\n",
"---------------------------------------\n",
"0 | lin | Linear | 7 K \n",
"1 | train_acc | Accuracy | 0 \n",
"2 | valid_acc | Accuracy | 0 \n",
"3 | test_acc | Accuracy | 0 \n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e4b05ab537114d4e846c4a2ae7dfbab5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer = pl.Trainer(max_epochs=epochs, auto_select_gpus=True, gpus=1)\n",
"trainer.fit(model, dm)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train accuracy: tensor(0.8741, device='cuda:0')\n",
"Validation accuracy: tensor(0.8694, device='cuda:0')\n"
]
}
],
"source": [
"print('Train accuracy: ', model.train_acc.compute())\n",
"print('Validation accuracy: ', model.valid_acc.compute())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## EMNISTデータセットによるテスト"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### EMNISTデータセットのダウンロード"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"from torchvision.datasets import EMNIST\n",
"\n",
"emnist = EMNIST('./data/emnist/',\n",
" split='digits',\n",
" train=False,\n",
" download=True,\n",
" transform = transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.1307,), (0.3081,))\n",
" ]))\n",
"\n",
"test_loader = DataLoader(emnist, batch_size=64, shuffle=False)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9b2464674a6a41168b34f13be7a0a0a4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"--------------------------------------------------------------------------------\n",
"DATALOADER:0 TEST RESULTS\n",
"{'test_acc': tensor(0.1761, device='cuda:0'),\n",
" 'test_acc_epoch': tensor(0.1761, device='cuda:0'),\n",
" 'train_acc': tensor(0.9167, device='cuda:0'),\n",
" 'train_acc_epoch': tensor(0.8741, device='cuda:0'),\n",
" 'valid_acc': tensor(0.8694, device='cuda:0'),\n",
" 'valid_acc_epoch': tensor(0.8694, device='cuda:0')}\n",
"--------------------------------------------------------------------------------\n",
"\n"
]
},
{
"data": {
"text/plain": [
"[{'train_acc': 0.9166666865348816,\n",
" 'valid_acc_epoch': 0.8694000244140625,\n",
" 'valid_acc': 0.8694000244140625,\n",
" 'train_acc_epoch': 0.8741454482078552,\n",
" 'test_acc_epoch': 0.17612500488758087,\n",
" 'test_acc': 0.17612500488758087}]"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.test(model, test_loader)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "torch",
"language": "python",
"name": "torch"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}