Last active
June 21, 2020 21:06
PyTorch 101: Freezing Layers
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Adv. PyTorch: Freezing Layers\n", | |
| "\n", | |
| "If you're planning to fine-tune a trained model on a different dataset, chances are you're going to freeze some of the early layers and only update the later layers. I won't go into the details of why you may want to freeze some layers and which ones should be frozen, but I'll show you how to do it in PyTorch. Let's get started!\n", | |
| "\n", | |
| "We first need a pre-trained model to start with. The [models subpackage](https://pytorch.org/docs/stable/torchvision/models.html) in the `torchvision` package provides definitions for many of the poplular model architectures for image classification. You can construct these models by simply calling their constructor, which would initialize the model with random weights. To use the pre-trained models from the PyTorch Model Zoo, you can call the constructor with the `pretrained=True` argument. Let's load the pretrained VGG16 model:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import torch\n", | |
| "import torch.nn as nn\n", | |
| "import torchvision.models as models\n", | |
| "\n", | |
| "vgg16 = models.vgg16(pretrained=True)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "This will start downloading the pretrained model into your computer's PyTorch cache folder, which usually is the `.cache/torch/checkpoints` folder under your home directory.\n", | |
| "\n", | |
| "There are multiple ways you can look into the model to see its modules and layers. One way is using the `.modules()` member function, which returns in iterator containing all the member objects of the model. The `.modules()` functions recursively goes thruogh all the modules and submodules of the model:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[VGG(\n", | |
| " (features): Sequential(\n", | |
| " (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| " (1): ReLU(inplace=True)\n", | |
| " (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| " (3): ReLU(inplace=True)\n", | |
| " (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", | |
| " (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| " (6): ReLU(inplace=True)\n", | |
| " (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| " (8): ReLU(inplace=True)\n", | |
| " (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", | |
| " (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| " (11): ReLU(inplace=True)\n", | |
| " (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| " (13): ReLU(inplace=True)\n", | |
| " (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| " (15): ReLU(inplace=True)\n", | |
| " (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", | |
| " (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| " (18): ReLU(inplace=True)\n", | |
| " (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| " (20): ReLU(inplace=True)\n", | |
| " (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| " (22): ReLU(inplace=True)\n", | |
| " (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", | |
| " (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| " (25): ReLU(inplace=True)\n", | |
| " (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| " (27): ReLU(inplace=True)\n", | |
| " (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| " (29): ReLU(inplace=True)\n", | |
| " (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", | |
| " )\n", | |
| " (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))\n", | |
| " (classifier): Sequential(\n", | |
| " (0): Linear(in_features=25088, out_features=4096, bias=True)\n", | |
| " (1): ReLU(inplace=True)\n", | |
| " (2): Dropout(p=0.5, inplace=False)\n", | |
| " (3): Linear(in_features=4096, out_features=4096, bias=True)\n", | |
| " (4): ReLU(inplace=True)\n", | |
| " (5): Dropout(p=0.5, inplace=False)\n", | |
| " (6): Linear(in_features=4096, out_features=1000, bias=True)\n", | |
| " )\n", | |
| "), Sequential(\n", | |
| " (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| " (1): ReLU(inplace=True)\n", | |
| " (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| " (3): ReLU(inplace=True)\n", | |
| " (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", | |
| " (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| " (6): ReLU(inplace=True)\n", | |
| " (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| " (8): ReLU(inplace=True)\n", | |
| " (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", | |
| " (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| " (11): ReLU(inplace=True)\n", | |
| " (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| " (13): ReLU(inplace=True)\n", | |
| " (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| " (15): ReLU(inplace=True)\n", | |
| " (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", | |
| " (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| " (18): ReLU(inplace=True)\n", | |
| " (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| " (20): ReLU(inplace=True)\n", | |
| " (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| " (22): ReLU(inplace=True)\n", | |
| " (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", | |
| " (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| " (25): ReLU(inplace=True)\n", | |
| " (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| " (27): ReLU(inplace=True)\n", | |
| " (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| " (29): ReLU(inplace=True)\n", | |
| " (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", | |
| "), Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False), Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False), Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False), Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False), Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), ReLU(inplace=True), MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False), AdaptiveAvgPool2d(output_size=(7, 7)), Sequential(\n", | |
| " (0): Linear(in_features=25088, out_features=4096, bias=True)\n", | |
| " (1): ReLU(inplace=True)\n", | |
| " (2): Dropout(p=0.5, inplace=False)\n", | |
| " (3): Linear(in_features=4096, out_features=4096, bias=True)\n", | |
| " (4): ReLU(inplace=True)\n", | |
| " (5): Dropout(p=0.5, inplace=False)\n", | |
| " (6): Linear(in_features=4096, out_features=1000, bias=True)\n", | |
| "), Linear(in_features=25088, out_features=4096, bias=True), ReLU(inplace=True), Dropout(p=0.5, inplace=False), Linear(in_features=4096, out_features=4096, bias=True), ReLU(inplace=True), Dropout(p=0.5, inplace=False), Linear(in_features=4096, out_features=1000, bias=True)]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "print(list(vgg16.modules()))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "That's a lot of information spewed out onto the screen! Let's use the `.named_module()` function instead, which returns a (name, module) tuple and only print the names:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "\n", | |
| "features\n", | |
| "features.0\n", | |
| "features.1\n", | |
| "features.2\n", | |
| "features.3\n", | |
| "features.4\n", | |
| "features.5\n", | |
| "features.6\n", | |
| "features.7\n", | |
| "features.8\n", | |
| "features.9\n", | |
| "features.10\n", | |
| "features.11\n", | |
| "features.12\n", | |
| "features.13\n", | |
| "features.14\n", | |
| "features.15\n", | |
| "features.16\n", | |
| "features.17\n", | |
| "features.18\n", | |
| "features.19\n", | |
| "features.20\n", | |
| "features.21\n", | |
| "features.22\n", | |
| "features.23\n", | |
| "features.24\n", | |
| "features.25\n", | |
| "features.26\n", | |
| "features.27\n", | |
| "features.28\n", | |
| "features.29\n", | |
| "features.30\n", | |
| "avgpool\n", | |
| "classifier\n", | |
| "classifier.0\n", | |
| "classifier.1\n", | |
| "classifier.2\n", | |
| "classifier.3\n", | |
| "classifier.4\n", | |
| "classifier.5\n", | |
| "classifier.6\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "for (name, module) in vgg16.named_modules():\n", | |
| " print(name)" | |
| ] | |
| }, | |
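Closely related to `.named_modules()` is `.named_parameters()`, which yields dotted parameter names built from the submodule names plus the parameter's own attribute name. A small self-contained sketch on a toy model (not VGG16, to keep it light):

```python
import torch.nn as nn

# A tiny stand-in model; in nn.Sequential the children are named by index.
toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Linear(8, 2))

# Each name is '<child name>.<attribute>'; the ReLU at index 1 has no parameters.
names = [name for name, _ in toy.named_parameters()]
print(names)  # ['0.weight', '0.bias', '2.weight', '2.bias']
```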
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "That's much better! We can see the top level modules are *features*, *avgpool* and *classifier*. We can also see that the *features* and *calssifier* modules consist of 31 and 7 layers respectively. These layers are not named, and only have numbers associated with them. If you want to see an even more concise representation of the network, you can use the `.named_children()` function which does not go inside the top level modules recursively:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "features\n", | |
| "avgpool\n", | |
| "classifier\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "for (name, module) in vgg16.named_children():\n", | |
| " print(name)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Now let's see what layers are there under the *features* module. Here we use the `.children()` function to get the layers under the *features* module, since these layers are not 'named':" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| "ReLU(inplace=True)\n", | |
| "Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| "ReLU(inplace=True)\n", | |
| "MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", | |
| "Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| "ReLU(inplace=True)\n", | |
| "Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| "ReLU(inplace=True)\n", | |
| "MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", | |
| "Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| "ReLU(inplace=True)\n", | |
| "Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| "ReLU(inplace=True)\n", | |
| "Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| "ReLU(inplace=True)\n", | |
| "MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", | |
| "Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| "ReLU(inplace=True)\n", | |
| "Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| "ReLU(inplace=True)\n", | |
| "Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| "ReLU(inplace=True)\n", | |
| "MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", | |
| "Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| "ReLU(inplace=True)\n", | |
| "Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| "ReLU(inplace=True)\n", | |
| "Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", | |
| "ReLU(inplace=True)\n", | |
| "MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "for (name, module) in vgg16.named_children():\n", | |
| " if name == 'features':\n", | |
| " for layer in module.children():\n", | |
| " print(layer)" | |
| ] | |
| }, | |
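Since *features* is an `nn.Sequential`, you can also index into it directly instead of iterating over `.children()`. A minimal sketch with a toy `Sequential` mirroring the first few VGG16 layers:

```python
import torch.nn as nn

features = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),
)

# nn.Sequential supports integer indexing and len():
first = features[0]
print(type(first).__name__, len(features))  # Conv2d 3
```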
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "We can even go deeper and look at the parameters in each layer. Let's get the parameters of the first layer under the *features* module:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Parameter containing:\n", | |
| "tensor([[[[-5.5373e-01, 1.4270e-01, 5.2896e-01],\n", | |
| " [-5.8312e-01, 3.5655e-01, 7.6566e-01],\n", | |
| " [-6.9022e-01, -4.8019e-02, 4.8409e-01]],\n", | |
| "\n", | |
| " [[ 1.7548e-01, 9.8630e-03, -8.1413e-02],\n", | |
| " [ 4.4089e-02, -7.0323e-02, -2.6035e-01],\n", | |
| " [ 1.3239e-01, -1.7279e-01, -1.3226e-01]],\n", | |
| "\n", | |
| " [[ 3.1303e-01, -1.6591e-01, -4.2752e-01],\n", | |
| " [ 4.7519e-01, -8.2677e-02, -4.8700e-01],\n", | |
| " [ 6.3203e-01, 1.9308e-02, -2.7753e-01]]],\n", | |
| "\n", | |
| "\n", | |
| " [[[ 2.3254e-01, 1.2666e-01, 1.8605e-01],\n", | |
| " [-4.2805e-01, -2.4349e-01, 2.4628e-01],\n", | |
| " [-2.5066e-01, 1.4177e-01, -5.4864e-03]],\n", | |
| "\n", | |
| " [[-1.4076e-01, -2.1903e-01, 1.5041e-01],\n", | |
| " [-8.4127e-01, -3.5176e-01, 5.6398e-01],\n", | |
| " [-2.4194e-01, 5.1928e-01, 5.3915e-01]],\n", | |
| "\n", | |
| " [[-3.1432e-01, -3.7048e-01, -1.3094e-01],\n", | |
| " [-4.7144e-01, -1.5503e-01, 3.4589e-01],\n", | |
| " [ 5.4384e-02, 5.8683e-01, 4.9580e-01]]],\n", | |
| "\n", | |
| "\n", | |
| " [[[ 1.7715e-01, 5.2149e-01, 9.8740e-03],\n", | |
| " [-2.7185e-01, -7.1709e-01, 3.1292e-01],\n", | |
| " [-7.5753e-02, -2.2079e-01, 3.3455e-01]],\n", | |
| "\n", | |
| " [[ 3.0924e-01, 6.7071e-01, 2.0546e-02],\n", | |
| " [-4.6607e-01, -1.0697e+00, 3.3501e-01],\n", | |
| " [-8.0284e-02, -3.0522e-01, 5.4460e-01]],\n", | |
| "\n", | |
| " [[ 3.1572e-01, 4.2335e-01, -3.4976e-01],\n", | |
| " [ 8.6354e-02, -4.6457e-01, 1.1803e-02],\n", | |
| " [ 1.0483e-01, -1.4584e-01, -1.5765e-02]]],\n", | |
| "\n", | |
| "\n", | |
| " ...,\n", | |
| "\n", | |
| "\n", | |
| " [[[ 7.7599e-02, 1.2692e-01, 3.2305e-02],\n", | |
| " [ 2.2131e-01, 2.4681e-01, -4.6637e-02],\n", | |
| " [ 4.6407e-02, 2.8246e-02, 1.7528e-02]],\n", | |
| "\n", | |
| " [[-1.8327e-01, -6.7425e-02, -7.2120e-03],\n", | |
| " [-4.8855e-02, 7.0427e-03, -1.2883e-01],\n", | |
| " [-6.4601e-02, -6.4566e-02, 4.4235e-02]],\n", | |
| "\n", | |
| " [[-2.2547e-01, -1.1931e-01, -2.3425e-02],\n", | |
| " [-9.9171e-02, -1.5143e-02, 9.5385e-04],\n", | |
| " [-2.6137e-02, 1.3567e-03, 1.4282e-01]]],\n", | |
| "\n", | |
| "\n", | |
| " [[[ 1.6520e-02, -3.2225e-02, -3.8450e-03],\n", | |
| " [-6.8206e-02, -1.9445e-01, -1.4166e-01],\n", | |
| " [-6.9528e-02, -1.8340e-01, -1.7422e-01]],\n", | |
| "\n", | |
| " [[ 4.2781e-02, -6.7529e-02, -7.0309e-03],\n", | |
| " [ 1.1765e-02, -1.4958e-01, -1.2361e-01],\n", | |
| " [ 1.0205e-02, -1.0393e-01, -1.1742e-01]],\n", | |
| "\n", | |
| " [[ 1.2661e-01, 8.5046e-02, 1.3066e-01],\n", | |
| " [ 1.7585e-01, 1.1288e-01, 1.1937e-01],\n", | |
| " [ 1.4656e-01, 9.8892e-02, 1.0348e-01]]],\n", | |
| "\n", | |
| "\n", | |
| " [[[ 3.2176e-02, -1.0766e-01, -2.6388e-01],\n", | |
| " [ 2.7957e-01, -3.7416e-02, -2.5471e-01],\n", | |
| " [ 3.4872e-01, 3.0041e-02, -5.5898e-02]],\n", | |
| "\n", | |
| " [[ 2.5063e-01, 1.5543e-01, -1.7432e-01],\n", | |
| " [ 3.9255e-01, 3.2306e-02, -3.5191e-01],\n", | |
| " [ 1.9299e-01, -1.9898e-01, -2.9713e-01]],\n", | |
| "\n", | |
| " [[ 4.6032e-01, 4.3399e-01, 2.8352e-01],\n", | |
| " [ 1.6341e-01, -5.8165e-02, -1.9196e-01],\n", | |
| " [-1.9521e-01, -4.5630e-01, -4.2732e-01]]]], requires_grad=True)\n", | |
| "Parameter containing:\n", | |
| "tensor([ 0.4034, 0.3778, 0.4644, -0.3228, 0.3940, -0.3953, 0.3951, -0.5496,\n", | |
| " 0.2693, -0.7602, -0.3508, 0.2334, -1.3239, -0.1694, 0.3938, -0.1026,\n", | |
| " 0.0460, -0.6995, 0.1549, 0.5628, 0.3011, 0.3425, 0.1073, 0.4651,\n", | |
| " 0.1295, 0.0788, -0.0492, -0.5638, 0.1465, -0.3890, -0.0715, 0.0649,\n", | |
| " 0.2768, 0.3279, 0.5682, -1.2640, -0.8368, -0.9485, 0.1358, 0.2727,\n", | |
| " 0.1841, -0.5325, 0.3507, -0.0827, -1.0248, -0.6912, -0.7711, 0.2612,\n", | |
| " 0.4033, -0.4802, -0.3066, 0.5807, -1.3325, 0.4844, -0.8160, 0.2386,\n", | |
| " 0.2300, 0.4979, 0.5553, 0.5230, -0.2182, 0.0117, -0.5516, 0.2108],\n", | |
| " requires_grad=True)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "for (name, module) in vgg16.named_children():\n", | |
| " if name == 'features':\n", | |
| " for layer in module.children():\n", | |
| " for param in layer.parameters():\n", | |
| " print(param)\n", | |
| " break" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Now that we have access to all the modules, layers and their parameters, we can easily freeze them by setting the parameters' `requires_grad` flag to `False`. This would prevent calculating the gradients for these parameters in the `backward` step which in turn prevents the optimizer from updating them.\n", | |
| "\n", | |
| "Now let's freeze all the parameters in the *features* module:" | |
| ] | |
| }, | |
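The effect of `requires_grad=False` is easy to see on a tiny example: after `backward()`, a frozen parameter's `.grad` remains `None`. A sketch with a toy linear layer (not the VGG model):

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
layer.weight.requires_grad = False  # freeze the weight, leave the bias trainable

out = layer(torch.randn(1, 4)).sum()
out.backward()

print(layer.weight.grad)            # None -- no gradient was computed
print(layer.bias.grad is not None)  # True -- the bias still gets a gradient
```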
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Layer \"0\" in module \"features\" was frozen!\n", | |
| "Layer \"1\" in module \"features\" was frozen!\n", | |
| "Layer \"2\" in module \"features\" was frozen!\n", | |
| "Layer \"3\" in module \"features\" was frozen!\n", | |
| "Layer \"4\" in module \"features\" was frozen!\n", | |
| "Layer \"5\" in module \"features\" was frozen!\n", | |
| "Layer \"6\" in module \"features\" was frozen!\n", | |
| "Layer \"7\" in module \"features\" was frozen!\n", | |
| "Layer \"8\" in module \"features\" was frozen!\n", | |
| "Layer \"9\" in module \"features\" was frozen!\n", | |
| "Layer \"10\" in module \"features\" was frozen!\n", | |
| "Layer \"11\" in module \"features\" was frozen!\n", | |
| "Layer \"12\" in module \"features\" was frozen!\n", | |
| "Layer \"13\" in module \"features\" was frozen!\n", | |
| "Layer \"14\" in module \"features\" was frozen!\n", | |
| "Layer \"15\" in module \"features\" was frozen!\n", | |
| "Layer \"16\" in module \"features\" was frozen!\n", | |
| "Layer \"17\" in module \"features\" was frozen!\n", | |
| "Layer \"18\" in module \"features\" was frozen!\n", | |
| "Layer \"19\" in module \"features\" was frozen!\n", | |
| "Layer \"20\" in module \"features\" was frozen!\n", | |
| "Layer \"21\" in module \"features\" was frozen!\n", | |
| "Layer \"22\" in module \"features\" was frozen!\n", | |
| "Layer \"23\" in module \"features\" was frozen!\n", | |
| "Layer \"24\" in module \"features\" was frozen!\n", | |
| "Layer \"25\" in module \"features\" was frozen!\n", | |
| "Layer \"26\" in module \"features\" was frozen!\n", | |
| "Layer \"27\" in module \"features\" was frozen!\n", | |
| "Layer \"28\" in module \"features\" was frozen!\n", | |
| "Layer \"29\" in module \"features\" was frozen!\n", | |
| "Layer \"30\" in module \"features\" was frozen!\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "layer_counter = 0\n", | |
| "for (name, module) in vgg16.named_children():\n", | |
| " if name == 'features':\n", | |
| " for layer in module.children():\n", | |
| " for param in layer.parameters():\n", | |
| " param.requires_grad = False\n", | |
| " \n", | |
| " print('Layer \"{}\" in module \"{}\" was frozen!'.format(layer_counter, name))\n", | |
| " layer_counter+=1" | |
| ] | |
| }, | |
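Since we only need the parameters, the nested loop above can be written more concisely: `.parameters()` already recurses through submodules, and `Module.requires_grad_(False)` freezes a whole module in one call. A sketch on a toy model with the same *features*/*classifier* structure:

```python
import torch.nn as nn

model = nn.Sequential()
model.add_module('features', nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()))
model.add_module('classifier', nn.Linear(8, 2))

# Equivalent to the nested loop: .parameters() recurses through submodules.
for param in model.features.parameters():
    param.requires_grad = False
# Or, in a single call: model.features.requires_grad_(False)

print(all(not p.requires_grad for p in model.features.parameters()))  # True
print(all(p.requires_grad for p in model.classifier.parameters()))    # True
```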
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Now that some of the parameters are frozen, the optimizer needs to be modified to only get the parameters with `requires_grad=True`. We can do this by writing a Lambda function when constructing the optimizer:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, vgg16.parameters()), lr=0.001)" | |
| ] | |
| }, | |
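You can sanity-check the filter by counting the trainable parameters before handing them to the optimizer. A toy-model sketch (the same pattern applies to `vgg16`):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
for p in model[0].parameters():  # freeze the first Linear layer
    p.requires_grad = False

trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=0.001)

# Only the second Linear's weight and bias remain trainable.
print(len(trainable))                     # 2 tensors
print(sum(p.numel() for p in trainable))  # 8*2 + 2 = 18 scalar parameters
```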
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "You can now start training your partially frozen model!" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3.7.7", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.7.7" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 4 | |
| } |