fastai 04_minibatch_training_withoutbody.ipynb
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 121, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import pickle,gzip,math,os,time,shutil,torch,matplotlib as mpl,numpy as np,matplotlib.pyplot as plt\n", | |
| "from pathlib import Path\n", | |
| "from torch import tensor,nn\n", | |
| "import torch.nn.functional as F\n", | |
| "from fastcore.test import test_close\n", | |
| "\n", | |
| "torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)\n", | |
| "torch.manual_seed(1)\n", | |
| "mpl.rcParams['image.cmap'] = 'gray'\n", | |
| "\n", | |
| "path_data = Path('data')\n", | |
| "path_gz = path_data/'mnist.pkl.gz'\n", | |
| "with gzip.open(path_gz, 'rb') as f: ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')\n", | |
| "x_train, y_train, x_valid, y_valid = map(tensor, [x_train, y_train, x_valid, y_valid])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Initial setup" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "heading_collapsed": true | |
| }, | |
| "source": [ | |
| "### Data" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 122, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "n,m = x_train.shape\n", | |
| "c = y_train.max()+1\n", | |
| "nh = 50" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 123, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "class Model(nn.Module):\n", | |
| " def __init__(self, n_in, nh, n_out):\n", | |
| " super().__init__()\n", | |
| " self.layers = [nn.Linear(n_in,nh), nn.ReLU(), nn.Linear(nh,n_out)]\n", | |
| " \n", | |
| " def __call__(self, x):\n", | |
| " for l in self.layers: x = l(x)\n", | |
| " return x" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 124, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "torch.Size([50000, 10])" | |
| ] | |
| }, | |
| "execution_count": 124, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "model = Model(m, nh, 10)\n", | |
| "pred = model(x_train)\n", | |
| "pred.shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "heading_collapsed": true | |
| }, | |
| "source": [ | |
| "### Cross entropy loss" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "source": [ | |
| "First, we will need to compute the softmax of our activations. This is defined by:\n", | |
| "\n", | |
| "$$\\hbox{softmax(x)}_{i} = \\frac{e^{x_{i}}}{e^{x_{0}} + e^{x_{1}} + \\cdots + e^{x_{n-1}}}$$\n", | |
| "\n", | |
| "or more concisely:\n", | |
| "\n", | |
| "$$\\hbox{softmax(x)}_{i} = \\frac{e^{x_{i}}}{\\sum_{0 \\leq j \\leq n-1} e^{x_{j}}}$$ \n", | |
| "\n", | |
| "In practice, we will need the log of the softmax when we calculate the loss." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 125, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def log_softmax(x): ..." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 126, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[-2.37, -2.49, -2.36, ..., -2.31, -2.28, -2.22],\n", | |
| " [-2.37, -2.44, -2.44, ..., -2.27, -2.26, -2.16],\n", | |
| " [-2.48, -2.33, -2.28, ..., -2.30, -2.30, -2.27],\n", | |
| " ...,\n", | |
| " [-2.33, -2.52, -2.34, ..., -2.31, -2.21, -2.16],\n", | |
| " [-2.38, -2.38, -2.33, ..., -2.29, -2.26, -2.17],\n", | |
| " [-2.33, -2.55, -2.36, ..., -2.29, -2.27, -2.16]], grad_fn=<LogBackward0>)" | |
| ] | |
| }, | |
| "execution_count": 126, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "log_softmax(pred)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "source": [ | |
| "Note that the formula \n", | |
| "\n", | |
| "$$\\log \\left ( \\frac{a}{b} \\right ) = \\log(a) - \\log(b)$$ \n", | |
| "\n", | |
| "gives a simplification when we compute the log softmax, which was previously defined as `(x.exp()/(x.exp().sum(-1,keepdim=True))).log()`" | |
| ] | |
| }, | |
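| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "source": [ | |
| "As a worked step (a sketch using the same symbols as the softmax definition earlier), applying this identity to the softmax gives:\n", | |
| "\n", | |
| "$$\\log \\left ( \\hbox{softmax(x)}_{i} \\right ) = \\log \\left ( \\frac{e^{x_{i}}}{\\sum_{0 \\leq j \\leq n-1} e^{x_{j}}} \\right ) = x_{i} - \\log \\left ( \\sum_{0 \\leq j \\leq n-1} e^{x_{j}} \\right )$$\n", | |
| "\n", | |
| "so only one logarithm of a sum of exponentials is needed per row, and it is subtracted from every activation in that row." | |
| ] | |
| }, | |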
| { | |
| "cell_type": "code", | |
| "execution_count": 127, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def log_softmax(x): ..." | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "source": [ | |
| "The log of the sum of exponentials can be computed in a more numerically stable way with the [LogSumExp trick](https://en.wikipedia.org/wiki/LogSumExp), which uses the following identity:\n", | |
| "\n", | |
| "$$\\log \\left ( \\sum_{j=0}^{n-1} e^{x_{j}} \\right ) = \\log \\left ( e^{a} \\sum_{j=0}^{n-1} e^{x_{j}-a} \\right ) = a + \\log \\left ( \\sum_{j=0}^{n-1} e^{x_{j}-a} \\right )$$\n", | |
| "\n", | |
| "where $a$ is the maximum of the $x_{j}$." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 128, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def logsumexp(x):\n", | |
| " ... " | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "source": [ | |
| "This way, we avoid an overflow when taking the exponential of a large activation, because every exponent $x_{j}-a$ is at most 0. PyTorch already implements this for us as `logsumexp`." | |
| ] | |
| }, | |
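| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "source": [ | |
| "A minimal sketch of the trick, assuming a 2D tensor of activations with one row per example. The names `logsumexp_sketch` and `log_softmax_sketch` are hypothetical and are not the notebook's own solutions:\n", | |
| "\n", | |
| "```python\n", | |
| "import torch\n", | |
| "\n", | |
| "def logsumexp_sketch(x):\n", | |
| "    # a is the row-wise maximum, so every exponent x - a is <= 0 and exp() cannot overflow\n", | |
| "    a = x.max(-1, keepdim=True)[0]\n", | |
| "    return (a + (x - a).exp().sum(-1, keepdim=True).log()).squeeze(-1)\n", | |
| "\n", | |
| "def log_softmax_sketch(x):\n", | |
| "    # log(softmax(x)) = x - logsumexp(x), broadcast over the class dimension\n", | |
| "    return x - logsumexp_sketch(x).unsqueeze(-1)\n", | |
| "\n", | |
| "x = torch.tensor([[1000., 1001., 1002.]])  # a naive exp() would overflow to inf here\n", | |
| "print(log_softmax_sketch(x), torch.log_softmax(x, -1))\n", | |
| "```\n", | |
| "\n", | |
| "The two printed tensors should agree up to floating-point error." | |
| ] | |
| }, | |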
| { | |
| "cell_type": "code", | |
| "execution_count": 129, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def log_softmax(x): ... # numerical stability achieved using logsumexp" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 130, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[-2.37, -2.49, -2.36, ..., -2.31, -2.28, -2.22],\n", | |
| " [-2.37, -2.44, -2.44, ..., -2.27, -2.26, -2.16],\n", | |
| " [-2.48, -2.33, -2.28, ..., -2.30, -2.30, -2.27],\n", | |
| " ...,\n", | |
| " [-2.33, -2.52, -2.34, ..., -2.31, -2.21, -2.16],\n", | |
| " [-2.38, -2.38, -2.33, ..., -2.29, -2.26, -2.17],\n", | |
| " [-2.33, -2.55, -2.36, ..., -2.29, -2.27, -2.16]], grad_fn=<SubBackward0>)" | |
| ] | |
| }, | |
| "execution_count": 130, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "test_close(logsumexp(pred), pred.logsumexp(-1))\n", | |
| "sm_pred = log_softmax(pred)\n", | |
| "sm_pred" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "source": [ | |
| "The cross entropy loss for some target $x$ and some prediction $p(x)$ is given by:\n", | |
| "\n", | |
| "$$ -\\sum x\\, \\log p(x) $$\n", | |
| "\n", | |
| "Since our targets $x$ are one-hot encoded, this reduces to $-\\log(p_{i})$, where $i$ is the index of the desired target.\n", | |
| "\n", | |
| "We can pick out these values using NumPy-style [integer array indexing](https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#integer-array-indexing). PyTorch supports all the advanced indexing methods discussed in that link." | |
| ] | |
| }, | |
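| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "source": [ | |
| "One way this indexing could be turned into a loss is sketched here, assuming `log_probs` holds log-probabilities of shape `(batch, classes)` and `target` holds integer class indices. The name `nll_sketch` is hypothetical:\n", | |
| "\n", | |
| "```python\n", | |
| "import torch\n", | |
| "import torch.nn.functional as F\n", | |
| "\n", | |
| "def nll_sketch(log_probs, target):\n", | |
| "    # pick log_probs[k, target[k]] for every row k, then average the negated values\n", | |
| "    rows = torch.arange(target.shape[0])\n", | |
| "    return -log_probs[rows, target].mean()\n", | |
| "\n", | |
| "log_probs = torch.log_softmax(torch.randn(4, 10), -1)  # random stand-in data\n", | |
| "target = torch.tensor([5, 0, 4, 1])\n", | |
| "print(nll_sketch(log_probs, target), F.nll_loss(log_probs, target))\n", | |
| "```\n", | |
| "\n", | |
| "Both printed values should match, since `F.nll_loss` averages over the batch by default." | |
| ] | |
| }, | |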
| { | |
| "cell_type": "code", | |
| "execution_count": 131, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([5, 0, 4])" | |
| ] | |
| }, | |
| "execution_count": 131, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "y_train[:3]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 132, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(tensor(-2.20, grad_fn=<SelectBackward0>),\n", | |
| " tensor(-2.37, grad_fn=<SelectBackward0>),\n", | |
| " tensor(-2.36, grad_fn=<SelectBackward0>))" | |
| ] | |
| }, | |
| "execution_count": 132, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "sm_pred[0,5],sm_pred[1,0],sm_pred[2,4]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 133, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([-2.20, -2.37, -2.36], grad_fn=<IndexBackward0>)" | |
| ] | |
| }, | |
| "execution_count": 133, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "sm_pred[[0,1,2], y_train[:3]]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 134, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def nll(input, target): ..." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 135, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor(2.30, grad_fn=<NegBackward0>)" | |
| ] | |
| }, | |
| "execution_count": 135, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "loss = nll(sm_pred, y_train)\n", | |
| "loss" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "source": [ | |
| "Then use PyTorch's implementation." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 136, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "test_close(F.nll_loss(F.log_softmax(pred, -1), y_train), loss, 1e-3)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "source": [ | |
| "In PyTorch, `F.log_softmax` and `F.nll_loss` are combined in one optimized function, `F.cross_entropy`." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 137, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "test_close(F.cross_entropy(pred, y_train), loss, 1e-3)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "heading_collapsed": true | |
| }, | |
| "source": [ | |
| "## Basic training loop" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "source": [ | |
| "The training loop repeats the following steps:\n", | |
| "- get the output of the model on a batch of inputs\n", | |
| "- compare the output to the labels we have and compute a loss\n", | |
| "- calculate the gradients of the loss with respect to every parameter of the model\n", | |
| "- update said parameters with those gradients to make them a little bit better" | |
| ] | |
| }, | |
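| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "source": [ | |
| "The four steps can be sketched as one function. This assumes the model's layers are registered so that `model.parameters()` yields them (for example an `nn.Sequential`); the name `fit_sketch` and its arguments are hypothetical:\n", | |
| "\n", | |
| "```python\n", | |
| "import torch\n", | |
| "\n", | |
| "def fit_sketch(model, x, y, loss_func, lr=0.5, epochs=3, bs=64):\n", | |
| "    n = x.shape[0]\n", | |
| "    for epoch in range(epochs):\n", | |
| "        for i in range(0, n, bs):\n", | |
| "            s = slice(i, min(n, i + bs))\n", | |
| "            xb, yb = x[s], y[s]\n", | |
| "            loss = loss_func(model(xb), yb)  # forward pass and loss\n", | |
| "            loss.backward()                  # gradients of the loss w.r.t. every parameter\n", | |
| "            with torch.no_grad():\n", | |
| "                for p in model.parameters():\n", | |
| "                    p -= p.grad * lr         # plain gradient-descent update\n", | |
| "                model.zero_grad()            # reset gradients before the next batch\n", | |
| "```\n", | |
| "\n", | |
| "The update runs under `torch.no_grad()` so that the parameter change itself is not tracked by autograd." | |
| ] | |
| }, | |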
| { | |
| "cell_type": "code", | |
| "execution_count": 138, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "loss_func = F.cross_entropy" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 139, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(tensor([-0.09, -0.21, -0.08, 0.10, -0.04, 0.08, -0.04, -0.03, 0.01, 0.06], grad_fn=<SelectBackward0>),\n", | |
| " torch.Size([64, 10]))" | |
| ] | |
| }, | |
| "execution_count": 139, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "bs=64 # batch size\n", | |
| "\n", | |
| "xb = x_train[0:bs] # a mini-batch from x\n", | |
| "preds = model(xb) # predictions\n", | |
| "preds[0], preds.shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 140, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor(2.30, grad_fn=<NllLossBackward0>)" | |
| ] | |
| }, | |
| "execution_count": 140, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "yb = y_train[0:bs]\n", | |
| "loss_func(preds, yb)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 141, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([3, 9, 3, 8, 5, 9, 3, 9, 3, 9, 5, 3, 9, 9, 3, 9, 9, 5, 8, 7, 9, 5, 3, 8, 9, 5, 9, 5, 5, 9, 3, 5, 9, 7, 5, 7, 9, 9, 3, 9, 3, 5, 3, 8,\n", | |
| " 3, 5, 9, 5, 9, 5, 3, 9, 3, 8, 9, 5, 9, 5, 9, 5, 8, 8, 9, 8])" | |
| ] | |
| }, | |
| "execution_count": 141, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "torch.argmax(preds, dim=1)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 142, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def accuracy(out, yb): ..." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 143, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor(0.09)" | |
| ] | |
| }, | |
| "execution_count": 143, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "accuracy(preds, yb)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 144, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "lr = 0.5 # learning rate\n", | |
| "epochs = 3 # how many epochs to train for" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 145, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "2.3036487102508545 0.09375\n", | |
| "0.12374822050333023 0.96875\n", | |
| "0.09232541173696518 0.96875\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "for epoch in range(epochs):\n", | |
| " for i in range(0, n, bs):\n", | |
| " ... " | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "heading_collapsed": true | |
| }, | |
| "source": [ | |
| "## Using parameters and optim" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "heading_collapsed": true, | |
| "hidden": true | |
| }, | |
| "source": [ | |
| "### Parameters" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "source": [ | |
| "Use `nn.Module.__setattr__`:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 146, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "class Model(nn.Module):\n", | |
| " def __init__(self, n_in, nh, n_out):\n", | |
| " ... \n", | |
| " \n", | |
| " def __call__(self, x): ..." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 147, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "model = Model(m, nh, 10)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 148, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "l1: Linear(in_features=784, out_features=50, bias=True)\n", | |
| "l2: Linear(in_features=50, out_features=10, bias=True)\n", | |
| "relu: ReLU()\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "for name,l in model.named_children(): print(f\"{name}: {l}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 149, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "Model(\n", | |
| " (l1): Linear(in_features=784, out_features=50, bias=True)\n", | |
| " (l2): Linear(in_features=50, out_features=10, bias=True)\n", | |
| " (relu): ReLU()\n", | |
| ")" | |
| ] | |
| }, | |
| "execution_count": 149, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "model" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 150, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "Linear(in_features=784, out_features=50, bias=True)" | |
| ] | |
| }, | |
| "execution_count": 150, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "model.l1" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 151, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def fit():\n", | |
| " for epoch in range(epochs):\n", | |
| " for i in range(0, n, bs):\n", | |
| " ..." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 152, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "2.309433937072754 0.0625\n", | |
| "0.20068740844726562 0.953125\n", | |
| "0.18196897208690643 0.9375\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "fit()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "source": [ | |
| "Behind the scenes, PyTorch overrides the `__setattr__` function in `nn.Module` so that the submodules you define are properly registered as children of the model, which is what makes their parameters visible through `model.parameters()`." | |
| ] | |
| }, | |
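| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "source": [ | |
| "A rough illustration of the idea in a plain Python class. This is a simplified stand-in, not PyTorch's actual implementation, and `RegistrySketch` is a hypothetical name:\n", | |
| "\n", | |
| "```python\n", | |
| "from torch import nn\n", | |
| "\n", | |
| "class RegistrySketch:\n", | |
| "    def __init__(self, n_in, nh, n_out):\n", | |
| "        self._modules = {}              # must exist before any layer is assigned\n", | |
| "        self.l1 = nn.Linear(n_in, nh)\n", | |
| "        self.l2 = nn.Linear(nh, n_out)\n", | |
| "\n", | |
| "    def __setattr__(self, k, v):\n", | |
| "        # record every public attribute, then store it as usual\n", | |
| "        if not k.startswith('_'): self._modules[k] = v\n", | |
| "        super().__setattr__(k, v)\n", | |
| "\n", | |
| "    def parameters(self):\n", | |
| "        for l in self._modules.values(): yield from l.parameters()\n", | |
| "\n", | |
| "print([p.shape for p in RegistrySketch(784, 50, 10).parameters()])\n", | |
| "```" | |
| ] | |
| }, | |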
| { | |
| "cell_type": "code", | |
| "execution_count": 153, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "class DummyModule():\n", | |
| " def __init__(self, n_in, nh, n_out):\n", | |
| " ... \n", | |
| " \n", | |
| " def __setattr__(self,k,v):\n", | |
| " ...\n", | |
| " \n", | |
| " def __repr__(self): return f'{self._modules}'\n", | |
| " \n", | |
| " def parameters(self):\n", | |
| " ..." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 154, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "{'l1': Linear(in_features=784, out_features=50, bias=True), 'l2': Linear(in_features=50, out_features=10, bias=True)}" | |
| ] | |
| }, | |
| "execution_count": 154, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "mdl = DummyModule(m,nh,10)\n", | |
| "mdl" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 155, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[torch.Size([50, 784]),\n", | |
| " torch.Size([50]),\n", | |
| " torch.Size([10, 50]),\n", | |
| " torch.Size([10])]" | |
| ] | |
| }, | |
| "execution_count": 155, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "[o.shape for o in mdl.parameters()]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "heading_collapsed": true, | |
| "hidden": true | |
| }, | |
| "source": [ | |
| "### Registering modules" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "source": [ | |
| "We can use the original `layers` approach, but we have to register the modules." | |
| ] | |
| }, | |
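| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "source": [ | |
| "One way to register a plain Python list of layers is `nn.Module.add_module`, sketched here. The class name `RegisteredModel` is hypothetical, and the sketch defines `forward` (the usual PyTorch convention) where the notebook overrides `__call__`:\n", | |
| "\n", | |
| "```python\n", | |
| "from torch import nn\n", | |
| "\n", | |
| "class RegisteredModel(nn.Module):\n", | |
| "    def __init__(self, layers):\n", | |
| "        super().__init__()\n", | |
| "        self.layers = layers  # a plain list is not registered on its own\n", | |
| "        for i, l in enumerate(self.layers): self.add_module(f'layer_{i}', l)\n", | |
| "\n", | |
| "    def forward(self, x):\n", | |
| "        for l in self.layers: x = l(x)\n", | |
| "        return x\n", | |
| "\n", | |
| "model = RegisteredModel([nn.Linear(784, 50), nn.ReLU(), nn.Linear(50, 10)])\n", | |
| "print(model)\n", | |
| "```" | |
| ] | |
| }, | |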
| { | |
| "cell_type": "code", | |
| "execution_count": 156, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "layers = [nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10)]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 157, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "class Model(nn.Module):\n", | |
| " def __init__(self, layers):\n", | |
| " ...\n", | |
| " \n", | |
| " def __call__(self, x):\n", | |
| " ..." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 158, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "model = Model(layers)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 159, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "Model(\n", | |
| " (layer_0): Linear(in_features=784, out_features=50, bias=True)\n", | |
| " (layer_1): ReLU()\n", | |
| " (layer_2): Linear(in_features=50, out_features=10, bias=True)\n", | |
| ")" | |
| ] | |
| }, | |
| "execution_count": 159, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "model" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "heading_collapsed": true, | |
| "hidden": true | |
| }, | |
| "source": [ | |
| "### nn.ModuleList" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "source": [ | |
| "`nn.ModuleList` does this for us." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 160, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "class SequentialModel(nn.Module):\n", | |
| " def __init__(self, layers):\n", | |
| " ...\n", | |
| " \n", | |
| " def __call__(self, x):\n", | |
| " ..." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 161, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "model = SequentialModel(layers)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 162, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "SequentialModel(\n", | |
| " (layers): ModuleList(\n", | |
| " (0): Linear(in_features=784, out_features=50, bias=True)\n", | |
| " (1): ReLU()\n", | |
| " (2): Linear(in_features=50, out_features=10, bias=True)\n", | |
| " )\n", | |
| ")" | |
| ] | |
| }, | |
| "execution_count": 162, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "model" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 163, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "2.3222596645355225 0.03125\n", | |
| "0.14386768639087677 0.96875\n", | |
| "0.08797654509544373 0.96875\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(tensor(0.02, grad_fn=<NllLossBackward0>), tensor(1.))" | |
| ] | |
| }, | |
| "execution_count": 163, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "fit()\n", | |
| "loss_func(model(xb), yb), accuracy(model(xb), yb)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "heading_collapsed": true, | |
| "hidden": true | |
| }, | |
| "source": [ | |
| "### nn.Sequential" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "source": [ | |
| "`nn.Sequential` is a convenient class which does the same as the above:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 164, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 165, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "2.3087124824523926 0.09375\n", | |
| "0.20316630601882935 0.90625\n", | |
| "0.20330585539340973 0.921875\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(tensor(0.01, grad_fn=<NllLossBackward0>), tensor(1.))" | |
| ] | |
| }, | |
| "execution_count": 165, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "fit()\n", | |
| "loss_func(model(xb), yb), accuracy(model(xb), yb)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 166, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "Sequential(\n", | |
| " (0): Linear(in_features=784, out_features=50, bias=True)\n", | |
| " (1): ReLU()\n", | |
| " (2): Linear(in_features=50, out_features=10, bias=True)\n", | |
| ")" | |
| ] | |
| }, | |
| "execution_count": 166, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "model" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "heading_collapsed": true, | |
| "hidden": true | |
| }, | |
| "source": [ | |
| "### optim" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "source": [ | |
| "Let's replace our previous manually coded optimization step:\n", | |
| "\n", | |
| "```python\n", | |
| "with torch.no_grad():\n", | |
| " for p in model.parameters(): p -= p.grad * lr\n", | |
| " model.zero_grad()\n", | |
| "```\n", | |
| "\n", | |
| "and instead use just:\n", | |
| "\n", | |
| "```python\n", | |
| "opt.step()\n", | |
| "opt.zero_grad()\n", | |
| "```" | |
| ] | |
| }, | |
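| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "source": [ | |
| "A minimal sketch of an optimizer exposing those two methods. The name `SGDSketch` is hypothetical, and the sketch assumes `backward()` has already populated `.grad` on every parameter:\n", | |
| "\n", | |
| "```python\n", | |
| "import torch\n", | |
| "\n", | |
| "class SGDSketch:\n", | |
| "    def __init__(self, params, lr=0.5): self.params, self.lr = list(params), lr\n", | |
| "\n", | |
| "    def step(self):\n", | |
| "        # update in place without recording the update in the autograd graph\n", | |
| "        with torch.no_grad():\n", | |
| "            for p in self.params: p -= p.grad * self.lr\n", | |
| "\n", | |
| "    def zero_grad(self):\n", | |
| "        # gradients accumulate across backward() calls, so clear them after each step\n", | |
| "        for p in self.params: p.grad.zero_()\n", | |
| "```\n", | |
| "\n", | |
| "Storing `list(params)` matters because `model.parameters()` returns a generator, which would be exhausted after the first pass." | |
| ] | |
| }, | |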
| { | |
| "cell_type": "code", | |
| "execution_count": 167, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "class Optimizer():\n", | |
| " def __init__(self, params, lr=0.5): self.params,self.lr=list(params),lr\n", | |
| " \n", | |
| " def step(self):\n", | |
| " ...\n", | |
| "\n", | |
| " def zero_grad(self):\n", | |
| " ..." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 168, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 169, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "opt = Optimizer(model.parameters())" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 170, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "2.30017352104187 0.0625\n", | |
| "0.13068024814128876 0.96875\n", | |
| "0.11748667806386948 0.96875\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "for epoch in range(epochs):\n", | |
| " for i in range(0, n, bs):\n", | |
| " ..." | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "source": [ | |
| "PyTorch already provides this exact functionality in `optim.SGD` (it also handles features like momentum, which we'll look at later)." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 171, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from torch import optim" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 172, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def get_model():\n", | |
| " ..." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 173, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor(2.30, grad_fn=<NllLossBackward0>)" | |
| ] | |
| }, | |
| "execution_count": 173, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "model,opt = get_model()\n", | |
| "loss_func(model(xb), yb)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 174, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "2.312685012817383 0.078125\n", | |
| "0.21422098577022552 0.90625\n", | |
| "0.17829009890556335 0.921875\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "for epoch in range(epochs):\n", | |
| " for i in range(0, n, bs):\n", | |
| " ..." | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Dataset and DataLoader" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "heading_collapsed": true | |
| }, | |
| "source": [ | |
| "### Dataset" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "source": [ | |
| "It's clunky to iterate through minibatches of x and y values separately:\n", | |
| "\n", | |
| "```python\n", | |
| " xb = x_train[s]\n", | |
| " yb = y_train[s]\n", | |
| "```\n", | |
| "\n", | |
| "Instead, let's do these two steps together, by introducing a `Dataset` class:\n", | |
| "\n", | |
| "```python\n", | |
| " xb,yb = train_ds[s]\n", | |
| "```" | |
| ] | |
| }, | |
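| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "source": [ | |
| "A minimal sketch of such a class, assuming `x` and `y` are tensors with the same first dimension. The name `DatasetSketch` is hypothetical:\n", | |
| "\n", | |
| "```python\n", | |
| "import torch\n", | |
| "\n", | |
| "class DatasetSketch:\n", | |
| "    def __init__(self, x, y): self.x, self.y = x, y\n", | |
| "    def __len__(self): return len(self.x)\n", | |
| "    # indexing with an int or a slice returns the matching inputs and labels together\n", | |
| "    def __getitem__(self, i): return self.x[i], self.y[i]\n", | |
| "\n", | |
| "ds = DatasetSketch(torch.randn(10, 784), torch.arange(10))  # random stand-in data\n", | |
| "xb, yb = ds[0:5]\n", | |
| "print(xb.shape, yb.shape)\n", | |
| "```" | |
| ] | |
| }, | |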
| { | |
| "cell_type": "code", | |
| "execution_count": 175, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "class Dataset():\n", | |
| " def __init__(self, x, y): ...\n", | |
| " def __len__(self): return ...\n", | |
| " def __getitem__(self, i): return ..." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 176, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "train_ds,valid_ds = Dataset(x_train, y_train),Dataset(x_valid, y_valid)\n", | |
| "assert len(train_ds)==len(x_train)\n", | |
| "assert len(valid_ds)==len(x_valid)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 177, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(tensor([[0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [0., 0., 0., ..., 0., 0., 0.]]),\n", | |
| " tensor([5, 0, 4, 1, 9]))" | |
| ] | |
| }, | |
| "execution_count": 177, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "xb,yb = train_ds[0:5]\n", | |
| "assert xb.shape==(5,28*28)\n", | |
| "assert yb.shape==(5,)\n", | |
| "xb,yb" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 178, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "model,opt = get_model()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 179, | |
| "metadata": { | |
| "hidden": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "2.2979648113250732 0.09375\n", | |
| "0.21198976039886475 0.953125\n", | |
| "0.17290501296520233 0.921875\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "for epoch in range(epochs):\n", | |
| "    for i in range(0, n, bs):\n", | |
| "        ..." | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### DataLoader" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Previously, our loop iterated over batches (xb, yb) like this:\n", | |
| "\n", | |
| "```python\n", | |
| "for i in range(0, n, bs):\n", | |
| "    xb,yb = train_ds[i:min(n,i+bs)]\n", | |
| "    ...\n", | |
| "```\n", | |
| "\n", | |
| "Let's make our loop much cleaner, using a data loader:\n", | |
| "\n", | |
| "```python\n", | |
| "for xb,yb in train_dl:\n", | |
| "    ...\n", | |
| "```" | |
| ] | |
| }, | |
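| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "As a rough sketch of the idea (one possible approach, not necessarily the intended solution to the skeleton in this notebook), a data loader only needs an `__iter__` that yields one slice of the dataset per batch. The names `ToyDataset` and `SimpleDataLoader` are made up for this illustration, and plain Python lists stand in for tensors:\n", | |
| "\n", | |
| "```python\n", | |
| "class ToyDataset:\n", | |
| "    # Hypothetical stand-in for the notebook's Dataset, backed by lists.\n", | |
| "    def __init__(self, x, y): self.x, self.y = x, y\n", | |
| "    def __len__(self): return len(self.x)\n", | |
| "    def __getitem__(self, i): return self.x[i], self.y[i]\n", | |
| "\n", | |
| "class SimpleDataLoader:\n", | |
| "    # Hypothetical name; yields consecutive slices of at most bs items.\n", | |
| "    def __init__(self, ds, bs): self.ds, self.bs = ds, bs\n", | |
| "    def __iter__(self):\n", | |
| "        for i in range(0, len(self.ds), self.bs):\n", | |
| "            yield self.ds[i:i + self.bs]\n", | |
| "\n", | |
| "ds = ToyDataset(list(range(10)), [v % 2 for v in range(10)])\n", | |
| "batches = list(SimpleDataLoader(ds, 4))\n", | |
| "```\n", | |
| "\n", | |
| "Because slicing past the end of a sequence is allowed, the final batch simply comes out shorter when the dataset size is not a multiple of the batch size." | |
| ] | |
| }, | |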
| { | |
| "cell_type": "code", | |
| "execution_count": 180, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class DataLoader():\n", | |
| "    def __init__(self, ds, bs): self.ds,self.bs = ds,bs\n", | |
| "    def __iter__(self): ..." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 181, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "train_dl = DataLoader(train_ds, bs)\n", | |
| "valid_dl = DataLoader(valid_ds, bs)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 182, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "torch.Size([64, 784])" | |
| ] | |
| }, | |
| "execution_count": 182, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "xb,yb = next(iter(valid_dl))\n", | |
| "xb.shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 183, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([3, 8, 6, 9, 6, 4, 5, 3, 8, 4, 5, 2, 3, 8, 4, 8, 1, 5, 0, 5, 9, 7, 4, 1, 0, 3, 0, 6, 2, 9, 9, 4, 1, 3, 6, 8, 0, 7, 7, 6, 8, 9, 0, 3,\n", | |
| " 8, 3, 7, 7, 8, 4, 4, 1, 2, 9, 8, 1, 1, 0, 6, 6, 5, 0, 1, 1])" | |
| ] | |
| }, | |
| "execution_count": 183, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "yb" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 184, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor(3)" | |
| ] | |
| }, | |
| "execution_count": 184, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| }, | |
| { | |
| "data": { | |
| "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAANeElEQVR4nO3df6hc9ZnH8c9HTTExQaNBTdJo2hv/2GUxZhVZMSzVYnFFiBVcGnDJxsCtUKHVVVayQkUphGVbBf+IpBiSXbuWmtg1VCWKhPUXFOOP1djY+INsEnNzgwY0otKNPvvHPVmuyT3fuZlfZ/Y+7xdcZuY8c855GPLJOTPfM/N1RAjA1HdS0w0A6A/CDiRB2IEkCDuQBGEHkjilnzuzzUf/QI9FhCda3tGR3fbVtv9o+13bd3ayLQC95XbH2W2fLGmXpKsk7ZP0sqTlEfGHwjoc2YEe68WR/VJJ70bE+xHxJ0m/lrSsg+0B6KFOwj5f0t5xj/dVy77G9rDt7ba3d7AvAB3q5AO6iU4VjjtNj4h1ktZJnMYDTerkyL5P0oJxj78paX9n7QDolU7C/rKkC2x/y/Y3JP1A0pbutAWg29o+jY+II7ZvkbRV0smS1kfEW13rDEBXtT301tbOeM8O9FxPLqoB8P8HYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJ9HXKZrRn8eLFxfqtt95aWxsaGiquO2PGjGJ99erVxfrpp59erD/11FO1tcOHDxfXRXdxZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJJjFdQDMnDmzWN+zZ0+xfsYZZ3Sxm+764IMPamul6wMkadOmTd1uJ4W6WVw7uqjG9m5JhyV9KelIRFzSyfYA9E43rqC7IiI+7MJ2APQQ79mBJDoNe0h62vYrtocneoLtYdvbbW/vcF8AOtDpafzlEbHf9tmSnrH9dkQ8N/4JEbFO0jqJD+iAJnV0ZI+I/dXtQUm/lXRpN5oC0H1th932abZnHb0v6XuSdnSrMQDd1fY4u+1va+xoLo29Hfj3iPhZi3U4jZ/ArFmzivUnn3yyWP/oo49qa6+99lpx3SVLlhTr559/frG+YMGCYn369Om1tdHR0eK6l112WbHeav2suj7OHhHvSyr/qgKAgcHQG5AEYQeSIOxAEoQdSIKwA0nwFVd0ZM6cOcX6HXfc0VZNklauXFmsb9y4sVjPqm7ojSM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBlM3oyIcfln9r9MUXX6yttRpnb/X1W8bZTwxHdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnF2dGT27NnF+urVq9ve9rx589peF8fjyA4kQdiBJAg7kARhB5Ig7EAShB1IgrADSfC78ShavLg8Ue+jjz5arC9atKi2tmvXruK6V111VbG+d+/eYj2rtn833vZ62wdt7xi37Ezbz9h+p7otX1kBoHGTOY3fIOnqY5bdKenZiLhA0rPVYwADrGXYI+I5SYeOWbxM0tHfBNoo6brutgWg29q9Nv6ciBiRpIgYsX123RNtD0sabnM/ALqk51+EiYh1ktZJfEAHNKndobdR23Mlqbo92L2WAPRCu2HfImlFdX+FpMe70w6AXmk5zm77EUnfkTRH0qikn0r6D0m/kXSepD2SboiIYz/Em2hbnMYPmBUrVhTr99xzT7G+YMGCYv3zzz+vrV177bXFdbdt21asY2J14+wt37NHxPKa0nc76ghAX3G5LJAEYQeSIOxAEoQdSIKwA0nwU9JTwMyZM2trt99+e3Hdu+66q1g/6aTy8eDQofKI69KlS2trb7/9dnFddBdHdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH2KWDDhg21teuvv76jbW/atKlYv//++4t1xtIHB0d2IAnCDiRB2IEkCDuQ
BGEHkiDsQBKEHUiCcfYpYGhoqGfbXrt2bbH+0ksv9Wzf6C6O7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBOPsU8DTTz9dW1u8eHHPti21Hodfs2ZNbW3//v1t9YT2tDyy215v+6DtHeOW3W37A9uvV3/X9LZNAJ2azGn8BklXT7D8voi4qPp7srttAei2lmGPiOcklef4ATDwOvmA7hbbb1Sn+bPrnmR72PZ229s72BeADrUb9rWShiRdJGlE0s/rnhgR6yLikoi4pM19AeiCtsIeEaMR8WVEfCXpl5Iu7W5bALqtrbDbnjvu4fcl7ah7LoDB4IgoP8F+RNJ3JM2RNCrpp9XjiySFpN2SfhgRIy13Zpd3hrZMnz69tvbwww8X17344ouL9fPOO6+tno46cOBAbW3lypXFdbdu3drRvrOKCE+0vOVFNRGxfILFD3XcEYC+4nJZIAnCDiRB2IEkCDuQBGEHkmg59NbVnTH01nennnpqsX7KKeUBmU8++aSb7XzNF198UazfdtttxfqDDz7YzXamjLqhN47sQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AE4+wouvDCC4v1++67r1i/4oor2t73nj17ivWFCxe2ve2pjHF2IDnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCcfYBMGPGjGL9s88+61MnJ2727NqZvyRJ69evr60tW7aso33Pnz+/WB8Zafnr5lMS4+xAcoQdSIKwA0kQdiAJwg4kQdiBJAg7kETLWVzRuaGhoWL9hRdeKNafeOKJYn3Hjh21tVZjzatWrSrWp02bVqy3GutetGhRsV7y3nvvFetZx9Hb1fLIbnuB7W22d9p+y/aPq+Vn2n7G9jvVbfnqCgCNmsxp/BFJ/xARfybpryT9yPafS7pT0rMRcYGkZ6vHAAZUy7BHxEhEvFrdPyxpp6T5kpZJ2lg9baOk63rUI4AuOKH37LYXSloi6feSzomIEWnsPwTbZ9esMyxpuMM+AXRo0mG3PVPSZkk/iYhP7AmvtT9ORKyTtK7aBl+EARoyqaE329M0FvRfRcRj1eJR23Or+lxJB3vTIoBuaHlk99gh/CFJOyPiF+NKWyStkLSmun28Jx1OATfccEOxfu655xbrN910UzfbOSGtzuA6+Yr0p59+WqzffPPNbW8bx5vMafzlkv5O0pu2X6+WrdZYyH9je5WkPZLK/6IBNKpl2CPiBUl1/71/t7vtAOgVLpcFkiDsQBKEHUiCsANJEHYgCb7i2gdnnXVW0y30zObNm4v1e++9t7Z28GD5OqwDBw601RMmxpEdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JgyuY+aPVzzFdeeWWxfuONNxbr8+bNq619/PHHxXVbeeCBB4r1559/vlg/cuRIR/vHiWPKZiA5wg4kQdiBJAg7kARhB5Ig7EAShB1IgnF2YIphnB1IjrADSRB2IAnCDiRB2IEkCDuQBGEHkmgZdtsLbG+zvdP2W7Z/XC2/2/YHtl+v/q7pfbsA2tXyohrbcyXNjYhXbc+S9Iqk6yT9raRPI+JfJr0zLqoBeq7uoprJzM8+Immkun/Y9k5J87vbHoBeO6H37LYXSloi6ffVoltsv2F7ve3ZNesM295ue3tnrQLoxKSvjbc9U9J/SvpZRDxm+xxJH0oKSfdq7FT/phbb4DQe6LG60/hJhd32NEm/k7Q1In4xQX2hpN9FxF+02A5hB3qs7S/C2LakhyTtHB/06oO7o74vaUenTQLoncl8Gr9U0vOS3pT0VbV4taTlki7S2Gn8bkk/rD7MK22LIzvQYx2dxncLYQd6j++zA8kRdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkmj5g5Nd9qGk/x73eE61bBANam+D2pdEb+3qZm/n1xX6+n3243Zub4+ISxproGBQexvUviR6a1e/euM0HkiCsANJNB32dQ3vv2RQexvUviR6a1df
emv0PTuA/mn6yA6gTwg7kEQjYbd9te0/2n7X9p1N9FDH9m7bb1bTUDc6P101h95B2zvGLTvT9jO236luJ5xjr6HeBmIa78I0442+dk1Pf9739+y2T5a0S9JVkvZJelnS8oj4Q18bqWF7t6RLIqLxCzBs/7WkTyX969GptWz/s6RDEbGm+o9ydkT844D0drdOcBrvHvVWN83436vB166b05+3o4kj+6WS3o2I9yPiT5J+LWlZA30MvIh4TtKhYxYvk7Sxur9RY/9Y+q6mt4EQESMR8Wp1/7Cko9OMN/raFfrqiybCPl/S3nGP92mw5nsPSU/bfsX2cNPNTOCco9NsVbdnN9zPsVpO491Px0wzPjCvXTvTn3eqibBPNDXNII3/XR4RfynpbyT9qDpdxeSslTSksTkARyT9vMlmqmnGN0v6SUR80mQv403QV19etybCvk/SgnGPvylpfwN9TCgi9le3ByX9VmNvOwbJ6NEZdKvbgw33838iYjQivoyIryT9Ug2+dtU045sl/SoiHqsWN/7aTdRXv163JsL+sqQLbH/L9jck/UDSlgb6OI7t06oPTmT7NEnf0+BNRb1F0orq/gpJjzfYy9cMyjTeddOMq+HXrvHpzyOi73+SrtHYJ/LvSfqnJnqo6evbkv6r+nur6d4kPaKx07r/0dgZ0SpJZ0l6VtI71e2ZA9Tbv2lsau83NBasuQ31tlRjbw3fkPR69XdN069doa++vG5cLgskwRV0QBKEHUiCsANJEHYgCcIOJEHYgSQIO5DE/wLwpj8ONnyk5wAAAABJRU5ErkJggg==", | |
| "text/plain": [ | |
| "<Figure size 432x288 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": { | |
| "needs_background": "light" | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "plt.imshow(xb[0].view(28,28))\n", | |
| "yb[0]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 185, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "model,opt = get_model()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 186, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def fit():\n", | |
| "    for epoch in range(epochs):\n", | |
| "        for xb,yb in train_dl:\n", | |
| "            ..." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 187, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(tensor(0.05, grad_fn=<NllLossBackward0>), tensor(1.))" | |
| ] | |
| }, | |
| "execution_count": 187, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "fit()\n", | |
| "loss_func(model(xb), yb), accuracy(model(xb), yb)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Random sampling" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "We want our training set to be in a random order, and that order should differ each epoch (that is, each time we iterate over it). The validation set, however, shouldn't be randomized." | |
| ] | |
| }, | |
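| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "One way to picture this, as a sketch under assumptions rather than the notebook's own solution: a sampler yields indices (reshuffled every time it is iterated), and a separate step groups those indices into batches. `IndexSampler` and `chunked` are hypothetical names, and the sampler here takes a length `n` instead of a dataset to keep the example self-contained:\n", | |
| "\n", | |
| "```python\n", | |
| "import random\n", | |
| "from itertools import islice\n", | |
| "\n", | |
| "class IndexSampler:\n", | |
| "    # Hypothetical name; yields indices 0..n-1, optionally shuffled.\n", | |
| "    def __init__(self, n, shuffle=False): self.n, self.shuffle = n, shuffle\n", | |
| "    def __iter__(self):\n", | |
| "        idxs = list(range(self.n))\n", | |
| "        if self.shuffle: random.shuffle(idxs)  # fresh order on every iteration\n", | |
| "        return iter(idxs)\n", | |
| "\n", | |
| "def chunked(it, bs):\n", | |
| "    # Hypothetical helper: group an iterable into lists of up to bs items.\n", | |
| "    it = iter(it)\n", | |
| "    while True:\n", | |
| "        chunk = list(islice(it, bs))\n", | |
| "        if not chunk: return\n", | |
| "        yield chunk\n", | |
| "\n", | |
| "valid_batches = list(chunked(IndexSampler(10), 4))\n", | |
| "train_batches = list(chunked(IndexSampler(10, shuffle=True), 4))\n", | |
| "```\n", | |
| "\n", | |
| "The unshuffled sampler always produces the same batches in order, while the shuffled one visits every index exactly once per pass but in a different order each time." | |
| ] | |
| }, | |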
| { | |
| "cell_type": "code", | |
| "execution_count": 188, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import random" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 189, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class Sampler():\n", | |
| "    def __init__(self, ds, shuffle=False): self.n,self.shuffle = len(ds),shuffle\n", | |
| "    def __iter__(self):\n", | |
| "        ..." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 190, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from itertools import islice" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 191, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "ss = Sampler(train_ds)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 192, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "0\n", | |
| "1\n", | |
| "2\n", | |
| "3\n", | |
| "4\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "it = iter(ss)\n", | |
| "for o in range(5): print(next(it))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 193, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[0, 1, 2, 3, 4]" | |
| ] | |
| }, | |
| "execution_count": 193, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "list(islice(ss, 5))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 194, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[2468, 34785, 22293, 22313, 36680]" | |
| ] | |
| }, | |
| "execution_count": 194, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "ss = Sampler(train_ds, shuffle=True)\n", | |
| "list(islice(ss, 5))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 195, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import fastcore.all as fc" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 196, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class BatchSampler():\n", | |
| "    def __init__(self, sampler, bs, drop_last=False): fc.store_attr()\n", | |
| "    def __iter__(self): ..." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 197, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[[3338, 3713, 47999, 33349],\n", | |
| " [1382, 28497, 19584, 35095],\n", | |
| " [33760, 20524, 1959, 7968],\n", | |
| " [40952, 25061, 32207, 20443],\n", | |
| " [11419, 11479, 45286, 40070]]" | |
| ] | |
| }, | |
| "execution_count": 197, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "batchs = BatchSampler(ss, 4)\n", | |
| "list(islice(batchs, 5))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 198, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def collate(b):\n", | |
| "    ..." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 199, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class DataLoader():\n", | |
| "    def __init__(self, ds, batchs, collate_fn=collate): fc.store_attr()\n", | |
| "    def __iter__(self): ..." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 200, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "train_samp = BatchSampler(Sampler(train_ds, shuffle=True ), bs)\n", | |
| "valid_samp = BatchSampler(Sampler(valid_ds, shuffle=False), bs)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 201, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "train_dl = DataLoader(train_ds, batchs=train_samp, collate_fn=collate)\n", | |
| "valid_dl = DataLoader(valid_ds, batchs=valid_samp, collate_fn=collate)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 202, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor(3)" | |
| ] | |
| }, | |
| "execution_count": 202, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| }, | |
| { | |
| "data": { | |
| "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAANeElEQVR4nO3df6hc9ZnH8c9HTTExQaNBTdJo2hv/2GUxZhVZMSzVYnFFiBVcGnDJxsCtUKHVVVayQkUphGVbBf+IpBiSXbuWmtg1VCWKhPUXFOOP1djY+INsEnNzgwY0otKNPvvHPVmuyT3fuZlfZ/Y+7xdcZuY8c855GPLJOTPfM/N1RAjA1HdS0w0A6A/CDiRB2IEkCDuQBGEHkjilnzuzzUf/QI9FhCda3tGR3fbVtv9o+13bd3ayLQC95XbH2W2fLGmXpKsk7ZP0sqTlEfGHwjoc2YEe68WR/VJJ70bE+xHxJ0m/lrSsg+0B6KFOwj5f0t5xj/dVy77G9rDt7ba3d7AvAB3q5AO6iU4VjjtNj4h1ktZJnMYDTerkyL5P0oJxj78paX9n7QDolU7C/rKkC2x/y/Y3JP1A0pbutAWg29o+jY+II7ZvkbRV0smS1kfEW13rDEBXtT301tbOeM8O9FxPLqoB8P8HYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJ9HXKZrRn8eLFxfqtt95aWxsaGiquO2PGjGJ99erVxfrpp59erD/11FO1tcOHDxfXRXdxZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJJjFdQDMnDmzWN+zZ0+xfsYZZ3Sxm+764IMPamul6wMkadOmTd1uJ4W6WVw7uqjG9m5JhyV9KelIRFzSyfYA9E43rqC7IiI+7MJ2APQQ79mBJDoNe0h62vYrtocneoLtYdvbbW/vcF8AOtDpafzlEbHf9tmSnrH9dkQ8N/4JEbFO0jqJD+iAJnV0ZI+I/dXtQUm/lXRpN5oC0H1th932abZnHb0v6XuSdnSrMQDd1fY4u+1va+xoLo29Hfj3iPhZi3U4jZ/ArFmzivUnn3yyWP/oo49qa6+99lpx3SVLlhTr559/frG+YMGCYn369Om1tdHR0eK6l112WbHeav2suj7OHhHvSyr/qgKAgcHQG5AEYQeSIOxAEoQdSIKwA0nwFVd0ZM6cOcX6HXfc0VZNklauXFmsb9y4sVjPqm7ojSM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBlM3oyIcfln9r9MUXX6yttRpnb/X1W8bZTwxHdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnF2dGT27NnF+urVq9ve9rx589peF8fjyA4kQdiBJAg7kARhB5Ig7EAShB1IgrADSfC78ShavLg8Ue+jjz5arC9atKi2tmvXruK6V111VbG+d+/eYj2rtn833vZ62wdt7xi37Ezbz9h+p7otX1kBoHGTOY3fIOnqY5bdKenZiLhA0rPVYwADrGXYI+I5SYeOWbxM0tHfBNoo6brutgWg29q9Nv6ciBiRpIgYsX123RNtD0sabnM/ALqk51+EiYh1ktZJfEAHNKndobdR23Mlqbo92L2WAPRCu2HfImlFdX+FpMe70w6AXmk5zm77EUnfkTRH0qikn0r6D0m/kXSepD2SboiIYz/Em2hbnMYPmBUrVhTr99xzT7G+YMGCYv3zzz+vrV177bXFdbdt21asY2J14+wt37NHxPKa0nc76ghAX3G5LJAEYQeSIOxAEoQdSIKwA0nwU9JTwMyZM2trt99+e3Hdu+66q1g/6aTy8eDQofKI69KlS2trb7/9dnFddBdHdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH2KWDDhg21teuvv76jbW/atKlYv//++4t1xtIHB0d2IAnCDiRB2IEkCDuQ
BGEHkiDsQBKEHUiCcfYpYGhoqGfbXrt2bbH+0ksv9Wzf6C6O7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBOPsU8DTTz9dW1u8eHHPti21Hodfs2ZNbW3//v1t9YT2tDyy215v+6DtHeOW3W37A9uvV3/X9LZNAJ2azGn8BklXT7D8voi4qPp7srttAei2lmGPiOcklef4ATDwOvmA7hbbb1Sn+bPrnmR72PZ229s72BeADrUb9rWShiRdJGlE0s/rnhgR6yLikoi4pM19AeiCtsIeEaMR8WVEfCXpl5Iu7W5bALqtrbDbnjvu4fcl7ah7LoDB4IgoP8F+RNJ3JM2RNCrpp9XjiySFpN2SfhgRIy13Zpd3hrZMnz69tvbwww8X17344ouL9fPOO6+tno46cOBAbW3lypXFdbdu3drRvrOKCE+0vOVFNRGxfILFD3XcEYC+4nJZIAnCDiRB2IEkCDuQBGEHkmg59NbVnTH01nennnpqsX7KKeUBmU8++aSb7XzNF198UazfdtttxfqDDz7YzXamjLqhN47sQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AE4+wouvDCC4v1++67r1i/4oor2t73nj17ivWFCxe2ve2pjHF2IDnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCcfYBMGPGjGL9s88+61MnJ2727NqZvyRJ69evr60tW7aso33Pnz+/WB8Zafnr5lMS4+xAcoQdSIKwA0kQdiAJwg4kQdiBJAg7kETLWVzRuaGhoWL9hRdeKNafeOKJYn3Hjh21tVZjzatWrSrWp02bVqy3GutetGhRsV7y3nvvFetZx9Hb1fLIbnuB7W22d9p+y/aPq+Vn2n7G9jvVbfnqCgCNmsxp/BFJ/xARfybpryT9yPafS7pT0rMRcYGkZ6vHAAZUy7BHxEhEvFrdPyxpp6T5kpZJ2lg9baOk63rUI4AuOKH37LYXSloi6feSzomIEWnsPwTbZ9esMyxpuMM+AXRo0mG3PVPSZkk/iYhP7AmvtT9ORKyTtK7aBl+EARoyqaE329M0FvRfRcRj1eJR23Or+lxJB3vTIoBuaHlk99gh/CFJOyPiF+NKWyStkLSmun28Jx1OATfccEOxfu655xbrN910UzfbOSGtzuA6+Yr0p59+WqzffPPNbW8bx5vMafzlkv5O0pu2X6+WrdZYyH9je5WkPZLK/6IBNKpl2CPiBUl1/71/t7vtAOgVLpcFkiDsQBKEHUiCsANJEHYgCb7i2gdnnXVW0y30zObNm4v1e++9t7Z28GD5OqwDBw601RMmxpEdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JgyuY+aPVzzFdeeWWxfuONNxbr8+bNq619/PHHxXVbeeCBB4r1559/vlg/cuRIR/vHiWPKZiA5wg4kQdiBJAg7kARhB5Ig7EAShB1IgnF2YIphnB1IjrADSRB2IAnCDiRB2IEkCDuQBGEHkmgZdtsLbG+zvdP2W7Z/XC2/2/YHtl+v/q7pfbsA2tXyohrbcyXNjYhXbc+S9Iqk6yT9raRPI+JfJr0zLqoBeq7uoprJzM8+Immkun/Y9k5J87vbHoBeO6H37LYXSloi6ffVoltsv2F7ve3ZNesM295ue3tnrQLoxKSvjbc9U9J/SvpZRDxm+xxJH0oKSfdq7FT/phbb4DQe6LG60/hJhd32NEm/k7Q1In4xQX2hpN9FxF+02A5hB3qs7S/C2LakhyTtHB/06oO7o74vaUenTQLoncl8Gr9U0vOS3pT0VbV4taTlki7S2Gn8bkk/rD7MK22LIzvQYx2dxncLYQd6j++zA8kRdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkmj5g5Nd9qGk/x73eE61bBANam+D2pdEb+3qZm/n1xX6+n3243Zub4+ISxproGBQexvUviR6a1e/euM0HkiCsANJNB32dQ3vv2RQexvUviR6a1df
emv0PTuA/mn6yA6gTwg7kEQjYbd9te0/2n7X9p1N9FDH9m7bb1bTUDc6P101h95B2zvGLTvT9jO236luJ5xjr6HeBmIa78I0442+dk1Pf9739+y2T5a0S9JVkvZJelnS8oj4Q18bqWF7t6RLIqLxCzBs/7WkTyX969GptWz/s6RDEbGm+o9ydkT844D0drdOcBrvHvVWN83436vB166b05+3o4kj+6WS3o2I9yPiT5J+LWlZA30MvIh4TtKhYxYvk7Sxur9RY/9Y+q6mt4EQESMR8Wp1/7Cko9OMN/raFfrqiybCPl/S3nGP92mw5nsPSU/bfsX2cNPNTOCco9NsVbdnN9zPsVpO491Px0wzPjCvXTvTn3eqibBPNDXNII3/XR4RfynpbyT9qDpdxeSslTSksTkARyT9vMlmqmnGN0v6SUR80mQv403QV19etybCvk/SgnGPvylpfwN9TCgi9le3ByX9VmNvOwbJ6NEZdKvbgw33838iYjQivoyIryT9Ug2+dtU045sl/SoiHqsWN/7aTdRXv163JsL+sqQLbH/L9jck/UDSlgb6OI7t06oPTmT7NEnf0+BNRb1F0orq/gpJjzfYy9cMyjTeddOMq+HXrvHpzyOi73+SrtHYJ/LvSfqnJnqo6evbkv6r+nur6d4kPaKx07r/0dgZ0SpJZ0l6VtI71e2ZA9Tbv2lsau83NBasuQ31tlRjbw3fkPR69XdN069doa++vG5cLgskwRV0QBKEHUiCsANJEHYgCcIOJEHYgSQIO5DE/wLwpj8ONnyk5wAAAABJRU5ErkJggg==", | |
| "text/plain": [ | |
| "<Figure size 432x288 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": { | |
| "needs_background": "light" | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "xb,yb = next(iter(valid_dl))\n", | |
| "plt.imshow(xb[0].view(28,28))\n", | |
| "yb[0]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 203, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(torch.Size([64, 784]), torch.Size([64]))" | |
| ] | |
| }, | |
| "execution_count": 203, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "xb.shape,yb.shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 204, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(tensor(0.04, grad_fn=<NllLossBackward0>), tensor(1.))" | |
| ] | |
| }, | |
| "execution_count": 204, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "model,opt = get_model()\n", | |
| "fit()\n", | |
| "\n", | |
| "loss_func(model(xb), yb), accuracy(model(xb), yb)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Multiprocessing DataLoader" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 205, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import torch.multiprocessing as mp\n", | |
| "from fastcore.basics import store_attr" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 206, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class DataLoader():\n", | |
| "    def __init__(self, ds, batchs, n_workers=1, collate_fn=collate): fc.store_attr()\n", | |
| "    def __iter__(self):\n", | |
| "        ..." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 207, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "train_dl = DataLoader(train_ds, batchs=train_samp, collate_fn=collate, n_workers=2)\n", | |
| "it = iter(train_dl)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 208, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(torch.Size([64, 784]), torch.Size([64]))" | |
| ] | |
| }, | |
| "execution_count": 208, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "xb,yb = next(it)\n", | |
| "xb.shape,yb.shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### PyTorch DataLoader" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 209, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from torch.utils.data import DataLoader, SequentialSampler, RandomSampler, BatchSampler" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 210, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "train_samp = BatchSampler(RandomSampler(train_ds), bs, drop_last=False)\n", | |
| "valid_samp = BatchSampler(SequentialSampler(valid_ds), bs, drop_last=False)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 211, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "train_dl = DataLoader(train_ds, batch_sampler=train_samp, collate_fn=collate)\n", | |
| "valid_dl = DataLoader(valid_ds, batch_sampler=valid_samp, collate_fn=collate)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 212, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(tensor(0.09, grad_fn=<NllLossBackward0>), tensor(0.98))" | |
| ] | |
| }, | |
| "execution_count": 212, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "model,opt = get_model()\n", | |
| "fit()\n", | |
| "loss_func(model(xb), yb), accuracy(model(xb), yb)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "PyTorch can auto-generate the BatchSampler for us:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 213, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "train_dl = DataLoader(train_ds, bs, sampler=RandomSampler(train_ds), collate_fn=collate)\n", | |
| "valid_dl = DataLoader(valid_ds, bs, sampler=SequentialSampler(valid_ds), collate_fn=collate)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "PyTorch can also generate the Sequential/RandomSamplers for us:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 214, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "train_dl = DataLoader(train_ds, bs, shuffle=True, drop_last=True, num_workers=2)\n", | |
| "valid_dl = DataLoader(valid_ds, bs, shuffle=False, num_workers=2)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 215, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(tensor(0.05, grad_fn=<NllLossBackward0>), tensor(0.98))" | |
| ] | |
| }, | |
| "execution_count": 215, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "model,opt = get_model()\n", | |
| "fit()\n", | |
| "\n", | |
| "loss_func(model(xb), yb), accuracy(model(xb), yb)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Our dataset already knows how to return a whole batch when indexed with a list of indices:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 216, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(tensor([[0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [0., 0., 0., ..., 0., 0., 0.],\n", | |
| " [0., 0., 0., ..., 0., 0., 0.]]),\n", | |
| " tensor([9, 1, 3]))" | |
| ] | |
| }, | |
| "execution_count": 216, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "train_ds[[4,6,7]]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
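| "...which means we can skip the batch_sampler and collate_fn entirely by passing our batch sampler as the plain `sampler`. Each item the loader fetches is then a whole batch, and the default batching adds a leading dimension of size 1:" | |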
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 217, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "train_dl = DataLoader(train_ds, sampler=train_samp)\n", | |
| "valid_dl = DataLoader(valid_ds, sampler=valid_samp)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 218, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(torch.Size([1, 64, 784]), torch.Size([1, 64]))" | |
| ] | |
| }, | |
| "execution_count": 218, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "xb,yb = next(iter(train_dl))\n", | |
| "xb.shape,yb.shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## Validation" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "You should **always** have a [validation set](http://www.fast.ai/2017/11/13/validation-sets/) as well, so that you can tell whether you are overfitting.\n", | |
| "\n", | |
| "We will calculate and print the validation loss at the end of each epoch.\n", | |
| "\n", | |
| "(Note that we always call `model.train()` before training, and `model.eval()` before inference, because these are used by layers such as `nn.BatchNorm2d` and `nn.Dropout` to ensure appropriate behaviour for these different phases.)" | |
| ] | |
| }, | |
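| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "To see why the mode matters, here is a small sketch that uses `nn.Dropout` on its own (an illustrative setup, not part of the model in this notebook). In training mode, dropout zeroes some entries and rescales the rest by `1/(1-p)`; in evaluation mode it passes the input through unchanged:\n", | |
| "\n", | |
| "```python\n", | |
| "import torch\n", | |
| "from torch import nn\n", | |
| "\n", | |
| "drop = nn.Dropout(p=0.5)\n", | |
| "x = torch.ones(8)\n", | |
| "\n", | |
| "drop.train()         # training mode: dropout is active\n", | |
| "train_out = drop(x)  # each entry is either 0 or rescaled to 2\n", | |
| "\n", | |
| "drop.eval()          # evaluation mode: dropout does nothing\n", | |
| "eval_out = drop(x)\n", | |
| "```\n", | |
| "\n", | |
| "Calling `train()` or `eval()` on a whole model switches every such layer it contains, which is why the calls are made on the model rather than on individual layers." | |
| ] | |
| }, | |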
| { | |
| "cell_type": "code", | |
| "execution_count": 105, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def fit(epochs, model, loss_func, opt, train_dl, valid_dl):\n", | |
| "    for epoch in range(epochs):\n", | |
| "        ...\n", | |
| "        print(epoch, tot_loss/count, tot_acc/count)\n", | |
| "    return tot_loss/count, tot_acc/count" | |
| ] | |
| }, | |
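| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "The `tot_loss/count` in the skeleton above suggests that per-batch results are accumulated and then divided by a running count. One plausible reading, stated here as an assumption since the body is left as an exercise, is that each batch mean is weighted by its batch size so that a smaller final batch does not distort the epoch average. The helper name `weighted_mean` and the numbers below are made up for illustration:\n", | |
| "\n", | |
| "```python\n", | |
| "def weighted_mean(batch_values, batch_sizes):\n", | |
| "    # Hypothetical helper: combine per-batch means, weighting by batch size.\n", | |
| "    total = sum(v * n for v, n in zip(batch_values, batch_sizes))\n", | |
| "    count = sum(batch_sizes)\n", | |
| "    return total / count\n", | |
| "\n", | |
| "# Two full batches of 64 and a final partial batch of 16 (made-up values).\n", | |
| "losses = [0.5, 0.25, 1.0]\n", | |
| "sizes = [64, 64, 16]\n", | |
| "epoch_loss = weighted_mean(losses, sizes)\n", | |
| "naive_loss = sum(losses) / len(losses)\n", | |
| "```\n", | |
| "\n", | |
| "Here the unweighted average gives the small last batch the same influence as a full one, so it overstates the loss compared with the size-weighted figure." | |
| ] | |
| }, | |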
| { | |
| "cell_type": "code", | |
| "execution_count": 106, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def get_dls(train_ds, valid_ds, bs, **kwargs):\n", | |
| "    return ..." | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Now, our whole process of obtaining the data loaders and fitting the model can be run in 3 lines of code:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 107, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "0 0.2599395067691803 0.9152\n", | |
| "1 0.1375726773455739 0.9599\n", | |
| "2 0.10632649689242243 0.9696\n", | |
| "3 0.11931819728948176 0.9643\n", | |
| "4 0.15087997979782522 0.9549\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "train_dl,valid_dl = get_dls(train_ds, valid_ds, bs)\n", | |
| "model,opt = get_model()\n", | |
| "loss,acc = fit(5, model, loss_func, opt, train_dl, valid_dl)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3 (ipykernel)", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.9.10" | |
| }, | |
| "toc": { | |
| "base_numbering": 1, | |
| "nav_menu": {}, | |
| "number_sections": false, | |
| "sideBar": true, | |
| "skip_h1_title": false, | |
| "title_cell": "Table of Contents", | |
| "title_sidebar": "Contents", | |
| "toc_cell": false, | |
| "toc_position": {}, | |
| "toc_section_display": true, | |
| "toc_window_display": false | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |