Code to accompany my blog post at https://bit.ly/2KfmQ76
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"toc": true
},
"source": [
"<h1>Table of Contents<span class=\"tocSkip\"></span></h1>\n",
"<div class=\"toc\" style=\"margin-top: 1em;\"><ul class=\"toc-item\"><li><span><a href=\"#Introduction\" data-toc-modified-id=\"Introduction-1\"><span class=\"toc-item-num\">1 </span>Introduction</a></span></li><li><span><a href=\"#Libraries\" data-toc-modified-id=\"Libraries-2\"><span class=\"toc-item-num\">2 </span>Libraries</a></span></li><li><span><a href=\"#Explanation\" data-toc-modified-id=\"Explanation-3\"><span class=\"toc-item-num\">3 </span>Explanation</a></span><ul class=\"toc-item\"><li><span><a href=\"#Example:-Single-channel-image-and-convolution-has-only-1-output-channel\" data-toc-modified-id=\"Example:-Single-channel-image-and-convolution-has-only-1-output-channel-3.1\"><span class=\"toc-item-num\">3.1 </span>Example: Single channel image and convolution has only 1 output channel</a></span></li><li><span><a href=\"#Bigger-Example\" data-toc-modified-id=\"Bigger-Example-3.2\"><span class=\"toc-item-num\">3.2 </span>Bigger Example</a></span></li></ul></li></ul></div>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Introduction"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook shows how to express a 2D convolution as a matrix multiplication."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Libraries"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We only need PyTorch:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'1.0.0'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.__version__"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Explanation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example: Single channel image and convolution has only 1 output channel"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let $X$ be a $4 \\times 4$ single channel input image:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[ 1.,  2.,  3.,  4.],\n",
"          [ 5.,  6.,  7.,  8.],\n",
"          [ 9., 10., 11., 12.],\n",
"          [13., 14., 15., 16.]]]])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = torch.arange(1, 17).view(-1, 1, 4, 4).float()\n",
"X"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's define a 2D convolution with the following properties:\n",
"\n",
"* kernel size: $2 \\times 2$\n",
"* padding: 0\n",
"* stride: 1\n",
"* bias: 0\n",
"* output channels: 1\n",
"* initial weights, $W$: $\\begin{bmatrix}\n",
"    1 & 2 \\\\\n",
"    3 & 4 \\\\\n",
"\\end{bmatrix}$"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"conv = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=2, stride=1)\n",
"W = torch.arange(1, 5).view(-1, 1, 2, 2).float()\n",
"\n",
"conv.weight.data = W\n",
"conv.bias.data = torch.zeros([1])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Given the dimensions of the input image and the convolution's parameters, the size of the output (height and width) after applying the convolution to the image is:"
]
},
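{
"cell_type": "markdown",
"metadata": {},
"source": [
"For input size $n$, kernel size $k$, padding $p$ and stride $s$, the output size along each dimension is\n",
"\n",
"$$n_{out} = \\frac{n - k + 2p}{s} + 1$$"
]
},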
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def output_size_after_convolution(image_dim, n_padding, kernel_size, stride):\n",
"    return (image_dim - kernel_size + 2 * n_padding)/stride + 1"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"n_output = output_size_after_convolution(image_dim=4, kernel_size=2, n_padding=0, stride=1)\n",
"n_output = int(n_output)\n",
"n_output"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Compute the result of the convolution using PyTorch's built-in function:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[ 44.,  54.,  64.],\n",
"          [ 84.,  94., 104.],\n",
"          [124., 134., 144.]]]], grad_fn=<MkldnnConvolutionBackward>)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result_conv2d_torch = conv(X)\n",
"result_conv2d_torch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we compute the result of the convolution ourselves using matrix multiplication.\n",
"\n",
"First, we can express all the image patches that the kernel will slide over as a $4 \\times 9$ matrix:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 1.,  2.,  3.,  5.,  6.,  7.,  9., 10., 11.],\n",
"         [ 2.,  3.,  4.,  6.,  7.,  8., 10., 11., 12.],\n",
"         [ 5.,  6.,  7.,  9., 10., 11., 13., 14., 15.],\n",
"         [ 6.,  7.,  8., 10., 11., 12., 14., 15., 16.]]])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"unfold = torch.nn.Unfold(kernel_size=2, padding=0, stride=1)\n",
"X_unfold = unfold(X)\n",
"X_unfold"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice that the i-th column corresponds to the image patch seen by the i-th output neuron."
]
},
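{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, the first column of `X_unfold` is exactly the top-left $2 \\times 2$ patch of $X$, flattened:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# the first output neuron sees the patch X[:, :, 0:2, 0:2]\n",
"first_patch = X[0, 0, 0:2, 0:2].reshape(-1)\n",
"assert torch.equal(first_patch, X_unfold[0, :, 0])"
]
},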
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Similarly, we can express the convolution's kernel as a $1 \\times 4$ matrix:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[1., 2., 3., 4.]]])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"W_unfold = unfold(W).transpose(2, 1)\n",
"W_unfold"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice that each row represents the flattened weights of the i-th kernel."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can do the multiplication:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 44.,  54.,  64.,  84.,  94., 104., 124., 134., 144.]]])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result_conv2d_matmul = torch.matmul(W_unfold, X_unfold)\n",
"result_conv2d_matmul"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"All that is left is to reshape it to the expected shape:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[ 44.,  54.,  64.],\n",
"          [ 84.,  94., 104.],\n",
"          [124., 134., 144.]]]])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result_conv2d_matmul = result_conv2d_matmul.view(-1, conv.out_channels, n_output, n_output)\n",
"result_conv2d_matmul"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Check that the results are as expected:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"assert torch.equal(result_conv2d_matmul, result_conv2d_torch)"
]
},
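{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can package the whole recipe (unfold, multiply, reshape) into a helper. This is a minimal sketch that assumes zero bias and reuses `output_size_after_convolution` from above; `conv2d_as_matmul` is our own name, not a PyTorch function:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def conv2d_as_matmul(X, W, padding=0, stride=1):\n",
"    # X: (batch, in_channels, height, width); W: (out_channels, in_channels, kernel_h, kernel_w)\n",
"    out_channels, _, kernel_h, kernel_w = W.shape\n",
"    n_out_h = int(output_size_after_convolution(X.shape[2], padding, kernel_h, stride))\n",
"    n_out_w = int(output_size_after_convolution(X.shape[3], padding, kernel_w, stride))\n",
"    unfold = torch.nn.Unfold(kernel_size=(kernel_h, kernel_w), padding=padding, stride=stride)\n",
"    X_unfold = unfold(X)                    # (batch, in_channels * kernel_h * kernel_w, n_patches)\n",
"    W_unfold = W.view(out_channels, -1)     # (out_channels, in_channels * kernel_h * kernel_w)\n",
"    out = torch.matmul(W_unfold, X_unfold)  # (batch, out_channels, n_patches)\n",
"    return out.view(-1, out_channels, n_out_h, n_out_w)\n",
"\n",
"assert torch.equal(conv2d_as_matmul(X, W), result_conv2d_torch)"
]
},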
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Bigger Example"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's experiment with a $4 \\times 4$ image with $3$ channels, $X$:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[ 1.,  2.,  3.,  4.],\n",
"          [ 5.,  6.,  7.,  8.],\n",
"          [ 9., 10., 11., 12.],\n",
"          [13., 14., 15., 16.]],\n",
"\n",
"         [[17., 18., 19., 20.],\n",
"          [21., 22., 23., 24.],\n",
"          [25., 26., 27., 28.],\n",
"          [29., 30., 31., 32.]],\n",
"\n",
"         [[33., 34., 35., 36.],\n",
"          [37., 38., 39., 40.],\n",
"          [41., 42., 43., 44.],\n",
"          [45., 46., 47., 48.]]]])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = torch.arange(1, 49).view(-1, 3, 4, 4).float()\n",
"X"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Define the convolution operation to have kernel size $2 \\times 2$, $0$ padding, stride $1$, $0$ bias, and $2$ output channels:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[ 1.,  2.],\n",
"          [ 3.,  4.]],\n",
"\n",
"         [[ 5.,  6.],\n",
"          [ 7.,  8.]],\n",
"\n",
"         [[ 9., 10.],\n",
"          [11., 12.]]],\n",
"\n",
"\n",
"        [[[13., 14.],\n",
"          [15., 16.]],\n",
"\n",
"         [[17., 18.],\n",
"          [19., 20.]],\n",
"\n",
"         [[21., 22.],\n",
"          [23., 24.]]]])"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"conv = torch.nn.Conv2d(in_channels=3, out_channels=2, kernel_size=2, stride=1)\n",
"W = torch.arange(1, 25).view(-1, conv.in_channels, 2, 2).float()\n",
"\n",
"conv.weight.data = W\n",
"conv.bias.data = torch.zeros([conv.out_channels])\n",
"conv.weight.data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Perform the convolution with PyTorch's built-in function:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[2060., 2138., 2216.],\n",
"          [2372., 2450., 2528.],\n",
"          [2684., 2762., 2840.]],\n",
"\n",
"         [[4868., 5090., 5312.],\n",
"          [5756., 5978., 6200.],\n",
"          [6644., 6866., 7088.]]]], grad_fn=<MkldnnConvolutionBackward>)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result_conv2d_torch = conv(X)\n",
"result_conv2d_torch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Compute the convolution using matrix multiplication:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 1.,  2.,  3.,  5.,  6.,  7.,  9., 10., 11.],\n",
"         [ 2.,  3.,  4.,  6.,  7.,  8., 10., 11., 12.],\n",
"         [ 5.,  6.,  7.,  9., 10., 11., 13., 14., 15.],\n",
"         [ 6.,  7.,  8., 10., 11., 12., 14., 15., 16.],\n",
"         [17., 18., 19., 21., 22., 23., 25., 26., 27.],\n",
"         [18., 19., 20., 22., 23., 24., 26., 27., 28.],\n",
"         [21., 22., 23., 25., 26., 27., 29., 30., 31.],\n",
"         [22., 23., 24., 26., 27., 28., 30., 31., 32.],\n",
"         [33., 34., 35., 37., 38., 39., 41., 42., 43.],\n",
"         [34., 35., 36., 38., 39., 40., 42., 43., 44.],\n",
"         [37., 38., 39., 41., 42., 43., 45., 46., 47.],\n",
"         [38., 39., 40., 42., 43., 44., 46., 47., 48.]]])"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_unfold = unfold(X)\n",
"X_unfold"
]
},
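{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each column is still one image patch, but a patch now spans all $3$ input channels, so it has $3 \\times 2 \\times 2 = 12$ entries: rows $1$-$4$ come from the first channel, rows $5$-$8$ from the second, and rows $9$-$12$ from the third."
]
},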
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.]],\n",
"\n",
"        [[13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.]]])"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"W_unfold = unfold(W).transpose(2, 1)\n",
"W_unfold"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[2060., 2138., 2216., 2372., 2450., 2528., 2684., 2762., 2840.]],\n",
"\n",
"        [[4868., 5090., 5312., 5756., 5978., 6200., 6644., 6866., 7088.]]])"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result_conv2d_matmul = torch.matmul(W_unfold, X_unfold)\n",
"result_conv2d_matmul"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[2060., 2138., 2216.],\n",
"          [2372., 2450., 2528.],\n",
"          [2684., 2762., 2840.]],\n",
"\n",
"         [[4868., 5090., 5312.],\n",
"          [5756., 5978., 6200.],\n",
"          [6644., 6866., 7088.]]]])"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result_conv2d_matmul = result_conv2d_matmul.view(-1, conv.out_channels, n_output, n_output)\n",
"result_conv2d_matmul"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"assert torch.equal(result_conv2d_torch, result_conv2d_matmul)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Appendix"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1D Convolution"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Specify an input signal with $3$ channels and length $12$:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.]],\n",
"\n",
"         [[13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.]],\n",
"\n",
"         [[25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36.]]]])"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# (batch size, channel, height, width)\n",
"X = torch.arange(1, 37).view(-1, 3, 1, 12).float()\n",
"X"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Parameters for the convolution operation:"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"in_channels = X.shape[1]\n",
"out_channels = 2\n",
"kernel_size = (1, 4)\n",
"stride = 2\n",
"padding = 0\n",
"bias = False"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Compute the expected output size of the convolution:"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1, 5)"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"_, _, image_h, image_w = X.shape\n",
"kernel_h, kernel_w = kernel_size\n",
"\n",
"output_h = output_size_after_convolution(image_dim=image_h, n_padding=padding, kernel_size=kernel_h, stride=stride)\n",
"output_w = output_size_after_convolution(image_dim=image_w, n_padding=padding, kernel_size=kernel_w, stride=stride)\n",
"\n",
"output_h = int(output_h)\n",
"output_w = int(output_w)\n",
"\n",
"output_h, output_w"
]
},
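{
"cell_type": "markdown",
"metadata": {},
"source": [
"Plugging into the formula: width $(12 - 4 + 2 \\cdot 0)/2 + 1 = 5$ and height $(1 - 1 + 2 \\cdot 0)/2 + 1 = 1$."
]
},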
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Define and perform the convolution operation using PyTorch:"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"conv = torch.nn.Conv1d(in_channels=in_channels,\n",
"                       out_channels=out_channels,\n",
"                       kernel_size=kernel_size,\n",
"                       stride=stride,\n",
"                       padding=padding,\n",
"                       bias=bias)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"initial weights in the convolution operation:\n",
"tensor([[[[ 1.,  2.,  3.,  4.]],\n",
"\n",
"         [[ 5.,  6.,  7.,  8.]],\n",
"\n",
"         [[ 9., 10., 11., 12.]]],\n",
"\n",
"\n",
"        [[[13., 14., 15., 16.]],\n",
"\n",
"         [[17., 18., 19., 20.]],\n",
"\n",
"         [[21., 22., 23., 24.]]]])\n"
]
}
],
"source": [
"W = torch.arange(1, kernel_w * in_channels * out_channels + 1).view(out_channels, in_channels, 1, kernel_w).float()\n",
"print(f'initial weights in the convolution operation:\\n{W}')\n",
"\n",
"conv.weight.data = W"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[1530., 1686., 1842., 1998., 2154.]],\n",
"\n",
"         [[3618., 4062., 4506., 4950., 5394.]]]],\n",
"       grad_fn=<MkldnnConvolutionBackward>)"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result_conv1d_torch = conv(X)\n",
"result_conv1d_torch"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 2, 1, 5])"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result_conv1d_torch.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we do the convolution ourselves using matrix multiplication."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create the matrix of image patches:"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"unfold = torch.nn.Unfold(kernel_size=kernel_size,\n",
"                         padding=padding,\n",
"                         stride=stride)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 1.,  3.,  5.,  7.,  9.],\n",
"         [ 2.,  4.,  6.,  8., 10.],\n",
"         [ 3.,  5.,  7.,  9., 11.],\n",
"         [ 4.,  6.,  8., 10., 12.],\n",
"         [13., 15., 17., 19., 21.],\n",
"         [14., 16., 18., 20., 22.],\n",
"         [15., 17., 19., 21., 23.],\n",
"         [16., 18., 20., 22., 24.],\n",
"         [25., 27., 29., 31., 33.],\n",
"         [26., 28., 30., 32., 34.],\n",
"         [27., 29., 31., 33., 35.],\n",
"         [28., 30., 32., 34., 36.]]])"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_unfold = unfold(X)\n",
"X_unfold"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 12, 5])"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_unfold.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Also flatten the weights of the convolution operation:"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.],\n",
"        [13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24.]])"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"W_unfold = W.view(-1, kernel_h * kernel_w * in_channels)\n",
"W_unfold"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2, 12])"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"W_unfold.shape"
]
},
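{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each row holds one output channel's flattened weights, of length $3 \\times 1 \\times 4 = 12$, matching the $12$ rows of `X_unfold`."
]
},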
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Perform the matrix multiplication and reshape to the expected output shape:"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[1530., 1686., 1842., 1998., 2154.],\n",
"         [3618., 4062., 4506., 4950., 5394.]]])"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result_conv1d_matmul = torch.matmul(W_unfold, X_unfold)\n",
"result_conv1d_matmul"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[1530., 1686., 1842., 1998., 2154.]],\n",
"\n",
"         [[3618., 4062., 4506., 4950., 5394.]]]])"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result_conv1d_matmul = result_conv1d_matmul.view(-1, out_channels, output_h, output_w)\n",
"result_conv1d_matmul"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"assert torch.equal(result_conv1d_torch, result_conv1d_matmul)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "fastai"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.2"
},
"toc": {
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"toc_cell": true,
"toc_position": {},
"toc_section_display": "block",
"toc_window_display": true
}
},
"nbformat": 4,
"nbformat_minor": 2
}