Created
March 23, 2022 07:32
-
-
Save moarshy/0b1edde8afd538e5073fb771b2753315 to your computer and use it in GitHub Desktop.
MNUnetModel.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "name": "MNUnetModel.ipynb", | |
| "provenance": [], | |
| "collapsed_sections": [ | |
| "nl_Y1QW9OsZR" | |
| ], | |
| "authorship_tag": "ABX9TyNFaFW9sZHD/tLhto6ay9vQ", | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| }, | |
| "language_info": { | |
| "name": "python" | |
| }, | |
| "accelerator": "GPU", | |
| "widgets": { | |
| "application/vnd.jupyter.widget-state+json": { | |
| "8eb841b4c62b4042bd9a5af15342c012": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "HBoxModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_dom_classes": [], | |
| "_model_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "_model_name": "HBoxModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/controls", | |
| "_view_module_version": "1.5.0", | |
| "_view_name": "HBoxView", | |
| "box_style": "", | |
| "children": [ | |
| "IPY_MODEL_dd011ba65aa04664bd63eb8f79ca29c8", | |
| "IPY_MODEL_46490715f4934745aa454ae8e119f40d", | |
| "IPY_MODEL_f7444e1bb42a4c6f913e05514445ce59" | |
| ], | |
| "layout": "IPY_MODEL_cd7643d9744340ec99319a52c2faf2c9" | |
| } | |
| }, | |
| "dd011ba65aa04664bd63eb8f79ca29c8": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "HTMLModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_dom_classes": [], | |
| "_model_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "_model_name": "HTMLModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/controls", | |
| "_view_module_version": "1.5.0", | |
| "_view_name": "HTMLView", | |
| "description": "", | |
| "description_tooltip": null, | |
| "layout": "IPY_MODEL_a1860ff990b042a5b50f87066e30bb9b", | |
| "placeholder": "", | |
| "style": "IPY_MODEL_18da3006e55d432d9038da4f74558e4f", | |
| "value": "100%" | |
| } | |
| }, | |
| "46490715f4934745aa454ae8e119f40d": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "FloatProgressModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_dom_classes": [], | |
| "_model_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "_model_name": "FloatProgressModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/controls", | |
| "_view_module_version": "1.5.0", | |
| "_view_name": "ProgressView", | |
| "bar_style": "success", | |
| "description": "", | |
| "description_tooltip": null, | |
| "layout": "IPY_MODEL_5ca90a8d4e0c4be3888dbccafe6aafce", | |
| "max": 178793939, | |
| "min": 0, | |
| "orientation": "horizontal", | |
| "style": "IPY_MODEL_56aa6084019b49b1ad09c0f1aaabdf01", | |
| "value": 178793939 | |
| } | |
| }, | |
| "f7444e1bb42a4c6f913e05514445ce59": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "HTMLModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_dom_classes": [], | |
| "_model_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "_model_name": "HTMLModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/controls", | |
| "_view_module_version": "1.5.0", | |
| "_view_name": "HTMLView", | |
| "description": "", | |
| "description_tooltip": null, | |
| "layout": "IPY_MODEL_e43806d23dad4c86afb2428ce56504b3", | |
| "placeholder": "", | |
| "style": "IPY_MODEL_4eec62679381415ab632253d8f043108", | |
| "value": " 171M/171M [00:03<00:00, 44.6MB/s]" | |
| } | |
| }, | |
| "cd7643d9744340ec99319a52c2faf2c9": { | |
| "model_module": "@jupyter-widgets/base", | |
| "model_name": "LayoutModel", | |
| "model_module_version": "1.2.0", | |
| "state": { | |
| "_model_module": "@jupyter-widgets/base", | |
| "_model_module_version": "1.2.0", | |
| "_model_name": "LayoutModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "_view_module_version": "1.2.0", | |
| "_view_name": "LayoutView", | |
| "align_content": null, | |
| "align_items": null, | |
| "align_self": null, | |
| "border": null, | |
| "bottom": null, | |
| "display": null, | |
| "flex": null, | |
| "flex_flow": null, | |
| "grid_area": null, | |
| "grid_auto_columns": null, | |
| "grid_auto_flow": null, | |
| "grid_auto_rows": null, | |
| "grid_column": null, | |
| "grid_gap": null, | |
| "grid_row": null, | |
| "grid_template_areas": null, | |
| "grid_template_columns": null, | |
| "grid_template_rows": null, | |
| "height": null, | |
| "justify_content": null, | |
| "justify_items": null, | |
| "left": null, | |
| "margin": null, | |
| "max_height": null, | |
| "max_width": null, | |
| "min_height": null, | |
| "min_width": null, | |
| "object_fit": null, | |
| "object_position": null, | |
| "order": null, | |
| "overflow": null, | |
| "overflow_x": null, | |
| "overflow_y": null, | |
| "padding": null, | |
| "right": null, | |
| "top": null, | |
| "visibility": null, | |
| "width": null | |
| } | |
| }, | |
| "a1860ff990b042a5b50f87066e30bb9b": { | |
| "model_module": "@jupyter-widgets/base", | |
| "model_name": "LayoutModel", | |
| "model_module_version": "1.2.0", | |
| "state": { | |
| "_model_module": "@jupyter-widgets/base", | |
| "_model_module_version": "1.2.0", | |
| "_model_name": "LayoutModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "_view_module_version": "1.2.0", | |
| "_view_name": "LayoutView", | |
| "align_content": null, | |
| "align_items": null, | |
| "align_self": null, | |
| "border": null, | |
| "bottom": null, | |
| "display": null, | |
| "flex": null, | |
| "flex_flow": null, | |
| "grid_area": null, | |
| "grid_auto_columns": null, | |
| "grid_auto_flow": null, | |
| "grid_auto_rows": null, | |
| "grid_column": null, | |
| "grid_gap": null, | |
| "grid_row": null, | |
| "grid_template_areas": null, | |
| "grid_template_columns": null, | |
| "grid_template_rows": null, | |
| "height": null, | |
| "justify_content": null, | |
| "justify_items": null, | |
| "left": null, | |
| "margin": null, | |
| "max_height": null, | |
| "max_width": null, | |
| "min_height": null, | |
| "min_width": null, | |
| "object_fit": null, | |
| "object_position": null, | |
| "order": null, | |
| "overflow": null, | |
| "overflow_x": null, | |
| "overflow_y": null, | |
| "padding": null, | |
| "right": null, | |
| "top": null, | |
| "visibility": null, | |
| "width": null | |
| } | |
| }, | |
| "18da3006e55d432d9038da4f74558e4f": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "DescriptionStyleModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_model_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "_model_name": "DescriptionStyleModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "_view_module_version": "1.2.0", | |
| "_view_name": "StyleView", | |
| "description_width": "" | |
| } | |
| }, | |
| "5ca90a8d4e0c4be3888dbccafe6aafce": { | |
| "model_module": "@jupyter-widgets/base", | |
| "model_name": "LayoutModel", | |
| "model_module_version": "1.2.0", | |
| "state": { | |
| "_model_module": "@jupyter-widgets/base", | |
| "_model_module_version": "1.2.0", | |
| "_model_name": "LayoutModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "_view_module_version": "1.2.0", | |
| "_view_name": "LayoutView", | |
| "align_content": null, | |
| "align_items": null, | |
| "align_self": null, | |
| "border": null, | |
| "bottom": null, | |
| "display": null, | |
| "flex": null, | |
| "flex_flow": null, | |
| "grid_area": null, | |
| "grid_auto_columns": null, | |
| "grid_auto_flow": null, | |
| "grid_auto_rows": null, | |
| "grid_column": null, | |
| "grid_gap": null, | |
| "grid_row": null, | |
| "grid_template_areas": null, | |
| "grid_template_columns": null, | |
| "grid_template_rows": null, | |
| "height": null, | |
| "justify_content": null, | |
| "justify_items": null, | |
| "left": null, | |
| "margin": null, | |
| "max_height": null, | |
| "max_width": null, | |
| "min_height": null, | |
| "min_width": null, | |
| "object_fit": null, | |
| "object_position": null, | |
| "order": null, | |
| "overflow": null, | |
| "overflow_x": null, | |
| "overflow_y": null, | |
| "padding": null, | |
| "right": null, | |
| "top": null, | |
| "visibility": null, | |
| "width": null | |
| } | |
| }, | |
| "56aa6084019b49b1ad09c0f1aaabdf01": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "ProgressStyleModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_model_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "_model_name": "ProgressStyleModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "_view_module_version": "1.2.0", | |
| "_view_name": "StyleView", | |
| "bar_color": null, | |
| "description_width": "" | |
| } | |
| }, | |
| "e43806d23dad4c86afb2428ce56504b3": { | |
| "model_module": "@jupyter-widgets/base", | |
| "model_name": "LayoutModel", | |
| "model_module_version": "1.2.0", | |
| "state": { | |
| "_model_module": "@jupyter-widgets/base", | |
| "_model_module_version": "1.2.0", | |
| "_model_name": "LayoutModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "_view_module_version": "1.2.0", | |
| "_view_name": "LayoutView", | |
| "align_content": null, | |
| "align_items": null, | |
| "align_self": null, | |
| "border": null, | |
| "bottom": null, | |
| "display": null, | |
| "flex": null, | |
| "flex_flow": null, | |
| "grid_area": null, | |
| "grid_auto_columns": null, | |
| "grid_auto_flow": null, | |
| "grid_auto_rows": null, | |
| "grid_column": null, | |
| "grid_gap": null, | |
| "grid_row": null, | |
| "grid_template_areas": null, | |
| "grid_template_columns": null, | |
| "grid_template_rows": null, | |
| "height": null, | |
| "justify_content": null, | |
| "justify_items": null, | |
| "left": null, | |
| "margin": null, | |
| "max_height": null, | |
| "max_width": null, | |
| "min_height": null, | |
| "min_width": null, | |
| "object_fit": null, | |
| "object_position": null, | |
| "order": null, | |
| "overflow": null, | |
| "overflow_x": null, | |
| "overflow_y": null, | |
| "padding": null, | |
| "right": null, | |
| "top": null, | |
| "visibility": null, | |
| "width": null | |
| } | |
| }, | |
| "4eec62679381415ab632253d8f043108": { | |
| "model_module": "@jupyter-widgets/controls", | |
| "model_name": "DescriptionStyleModel", | |
| "model_module_version": "1.5.0", | |
| "state": { | |
| "_model_module": "@jupyter-widgets/controls", | |
| "_model_module_version": "1.5.0", | |
| "_model_name": "DescriptionStyleModel", | |
| "_view_count": null, | |
| "_view_module": "@jupyter-widgets/base", | |
| "_view_module_version": "1.2.0", | |
| "_view_name": "StyleView", | |
| "description_width": "" | |
| } | |
| } | |
| } | |
| } | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/moarshy/0b1edde8afd538e5073fb771b2753315/mnunetmodel.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "!pip install timm fastai -Uqq" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "cJ_Y_aOreCYt", | |
| "outputId": "bdc1576e-a1fa-4618-9123-e69fd9b926be" | |
| }, | |
| "execution_count": 1, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "\u001b[K |████████████████████████████████| 431 kB 4.0 MB/s \n", | |
| "\u001b[K |████████████████████████████████| 189 kB 45.0 MB/s \n", | |
| "\u001b[K |████████████████████████████████| 55 kB 3.9 MB/s \n", | |
| "\u001b[?25h" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 23, | |
| "metadata": { | |
| "id": "ynxvBvjJMH8F" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "import torch\n", | |
| "import torch.nn as nn\n", | |
| "import torch.nn.functional as F\n", | |
| "\n", | |
| "import timm\n", | |
| "from timm import create_model\n", | |
| "from timm.models.efficientnet_blocks import DepthwiseSeparableConv\n", | |
| "\n", | |
| "from fastai.vision.all import *\n", | |
| "from fastai.callback.all import *" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "In this notebook,\n", | |
| "- [x] Implement this [paper](https://openaccess.thecvf.com/content/CVPR2021W/MAI/papers/Zhang_A_Simple_Baseline_for_Fast_and_Accurate_Depth_Estimation_on_CVPRW_2021_paper.pdf)\n", | |
| "- [x] Train using knowledge distillation\n" | |
| ], | |
| "metadata": { | |
| "id": "QBoa9XfUjYsV" | |
| } | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "**The paper abstract**\n", | |
| "\n", | |
| " " | |
| ], | |
| "metadata": { | |
| "id": "UseLzzkckySY" | |
| } | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "**The architecture**\n", | |
| "" | |
| ], | |
| "metadata": { | |
| "id": "k5lA8zzHk-z9" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Parts of this code are based on https://gist.github.com/rwightman/f8b24f4e6f5504aba03e999e02460d31\n", | |
| "\n", | |
| "class Conv2dBnAct(nn.Module):\n", | |
| " def __init__(self, \n", | |
| " in_channels, \n", | |
| " out_channels, \n", | |
| " kernel_size, \n", | |
| " padding=0,\n", | |
| " stride=1, \n", | |
| " act_layer=nn.ReLU, \n", | |
| " norm_layer=nn.BatchNorm2d\n", | |
| " ):\n", | |
| " super().__init__()\n", | |
| " self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False)\n", | |
| " self.bn = norm_layer(out_channels)\n", | |
| " self.act = act_layer(inplace=True)\n", | |
| "\n", | |
| " def forward(self, x):\n", | |
| " x = self.conv(x)\n", | |
| " x = self.bn(x)\n", | |
| " x = self.act(x)\n", | |
| " return x\n", | |
| "\n", | |
| "\n", | |
| "class FeatureFusionModule(nn.Module):\n", | |
| " def __init__(self, \n", | |
| " enc_in_channels, # Encoder in channels\n", | |
| " enc_out_channels, # Encoder out channels\n", | |
| " dec_in_channels, # Decoder in channels\n", | |
| " out_channels, # Final out channels\n", | |
| " ):\n", | |
| " super().__init__()\n", | |
| " #encoderoutput\n", | |
| " self.enc_conv1 = nn.Conv2d(enc_in_channels, enc_out_channels, kernel_size=1, stride=1, padding='same')\n", | |
| " self.enc_up = nn.ConvTranspose2d(enc_out_channels, enc_out_channels, kernel_size=1)\n", | |
| " self.enc_dconv = DepthwiseSeparableConv(enc_out_channels, enc_out_channels)\n", | |
| " self.enc_conv2 = nn.Conv2d(enc_out_channels, enc_out_channels, kernel_size=1, stride=1, padding='same')\n", | |
| " \n", | |
| " #decoderoutput\n", | |
| " self.dec_dconv = DepthwiseSeparableConv(enc_out_channels+dec_in_channels, enc_out_channels+dec_in_channels)\n", | |
| " self.dec_conv1 = nn.Conv2d(enc_out_channels+dec_in_channels, out_channels, kernel_size=1, stride=1, padding='same')\n", | |
| "\n", | |
| "\n", | |
| " def forward(self, enc_x, dec_x):\n", | |
| " enc_x = self.enc_conv1(enc_x)\n", | |
| " enc_x = self.enc_up(enc_x)\n", | |
| " enc_x = self.enc_dconv(enc_x)\n", | |
| " enc_x = self.enc_conv2(enc_x)\n", | |
| " \n", | |
| " x = torch.cat([enc_x, dec_x], dim=1)\n", | |
| "\n", | |
| " dec_x = self.dec_dconv(x)\n", | |
| " dec_x = self.dec_conv1(dec_x)\n", | |
| " \n", | |
| " return dec_x\n", | |
| "\n", | |
| "\n", | |
| "class DecoderBlock(nn.Module):\n", | |
| " def __init__(self, \n", | |
| " enc_channels,\n", | |
| " dec_prev_channels, \n", | |
| " dec_channels,\n", | |
| " act_layer=nn.ReLU, \n", | |
| " norm_layer=nn.BatchNorm2d,\n", | |
| " ffm=True,\n", | |
| " ):\n", | |
| " super().__init__()\n", | |
| " conv_args = dict(kernel_size=3, padding=1, act_layer=act_layer)\n", | |
| " self.ffm = ffm\n", | |
| " \n", | |
| " if ffm:\n", | |
| " self.ffm = FeatureFusionModule(enc_channels, enc_channels, dec_prev_channels, dec_channels)\n", | |
| "\n", | |
| " self.conv1 = Conv2dBnAct(enc_channels, dec_channels, norm_layer=norm_layer, **conv_args)\n", | |
| " self.conv2 = Conv2dBnAct(dec_channels, dec_channels, norm_layer=norm_layer, **conv_args)\n", | |
| " \n", | |
| "\n", | |
| " def forward(self, x_enc, x_dec):\n", | |
| " if self.ffm:\n", | |
| " x = self.ffm(x_enc, x_dec)\n", | |
| "\n", | |
| " x = F.interpolate(x_enc, scale_factor=2, mode='nearest')\n", | |
| "\n", | |
| " x = self.conv1(x)\n", | |
| " x = self.conv2(x)\n", | |
| "\n", | |
| " return x\n", | |
| "\n", | |
| "\n", | |
| "class UnetDecoder(nn.Module):\n", | |
| "\n", | |
| " def __init__(self,\n", | |
| " encoder_channels,\n", | |
| " decoder_channels=(256, 128, 64, 32, 16),\n", | |
| " final_channels=3,\n", | |
| " norm_layer=nn.BatchNorm2d,\n", | |
| " ):\n", | |
| " super().__init__()\n", | |
| "\n", | |
| " self.decoders = nn.ModuleList()\n", | |
| " for i, (e_ch, d_ch) in enumerate(zip(encoder_channels, decoder_channels)):\n", | |
| " if i== 0:\n", | |
| " self.decoders.append(DecoderBlock(enc_channels=e_ch, \n", | |
| " dec_prev_channels=None, \n", | |
| " dec_channels=d_ch,\n", | |
| " act_layer=nn.ReLU, \n", | |
| " norm_layer=nn.BatchNorm2d,\n", | |
| " ffm=False,\n", | |
| " ))\n", | |
| "\n", | |
| " else:\n", | |
| " self.decoders.append(DecoderBlock(enc_channels=e_ch, \n", | |
| " dec_prev_channels=decoder_channels[i-1], \n", | |
| " dec_channels=d_ch,\n", | |
| " act_layer=nn.ReLU, \n", | |
| " norm_layer=nn.BatchNorm2d,\n", | |
| " ffm=True,\n", | |
| " ))\n", | |
| "\n", | |
| " self.final_conv = nn.Conv2d(decoder_channels[-1], final_channels, kernel_size=(1, 1))\n", | |
| " self.tensor_base = ToTensorBase()\n", | |
| " self._init_weight()\n", | |
| "\n", | |
| " def _init_weight(self):\n", | |
| " for m in self.modules():\n", | |
| " if isinstance(m, nn.Conv2d):\n", | |
| " torch.nn.init.kaiming_normal_(m.weight)\n", | |
| " elif isinstance(m, nn.BatchNorm2d):\n", | |
| " m.weight.data.fill_(1)\n", | |
| " m.bias.data.zero_()\n", | |
| "\n", | |
| "\n", | |
| " def forward(self, x):\n", | |
| " enc_outs_r = x\n", | |
| " dec_out = None\n", | |
| " for i, each in enumerate(self.decoders):\n", | |
| " dec_out = each(enc_outs_r[i], dec_out)\n", | |
| " x = self.final_conv(dec_out)\n", | |
| " x = self.tensor_base(x)\n", | |
| " return x\n", | |
| "\n", | |
| "\n", | |
| "class Unet(nn.Module):\n", | |
| " def __init__(self,\n", | |
| " backbone='resnet50',\n", | |
| " backbone_kwargs=None,\n", | |
| " backbone_indices=None,\n", | |
| " decoder_use_batchnorm=True,\n", | |
| " decoder_channels=(256, 128, 64, 32, 16),\n", | |
| " in_chans=3,\n", | |
| " num_classes=3,\n", | |
| " norm_layer=nn.BatchNorm2d,\n", | |
| " pretrained=True,\n", | |
| " ):\n", | |
| " super().__init__()\n", | |
| " backbone_kwargs = backbone_kwargs or {}\n", | |
| " # NOTE some models need different backbone indices specified based on the alignment of features\n", | |
| " # and some models won't have a full enough range of feature strides to work properly.\n", | |
| " encoder = create_model(\n", | |
| " backbone, features_only=True, out_indices=backbone_indices, in_chans=in_chans,\n", | |
| " pretrained=pretrained, **backbone_kwargs)\n", | |
| " encoder_channels = encoder.feature_info.channels()[::-1]\n", | |
| " self.encoder = encoder\n", | |
| "\n", | |
| " self.decoder = UnetDecoder(\n", | |
| " encoder_channels=encoder_channels,\n", | |
| " decoder_channels=decoder_channels,\n", | |
| " final_channels=num_classes,\n", | |
| " norm_layer=norm_layer,\n", | |
| " )\n", | |
| "\n", | |
| " def forward(self, x: torch.Tensor):\n", | |
| " x = self.encoder(x)\n", | |
| " x.reverse() # torchscript doesn't work with [::-1]\n", | |
| " x = self.decoder(x)\n", | |
| " return x" | |
| ], | |
| "metadata": { | |
| "id": "t0IVmDv-BZsN" | |
| }, | |
| "execution_count": 24, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "# Test individual blocks" | |
| ], | |
| "metadata": { | |
| "id": "nl_Y1QW9OsZR" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "model = Unet('mobilenetv3_rw')" | |
| ], | |
| "metadata": { | |
| "id": "acc4eR-HNati" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "o = model(torch.randn(2,3,128,160))" | |
| ], | |
| "metadata": { | |
| "id": "JMxOElmXNyfJ" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "o.shape" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "s5ITHIt_N3Za", | |
| "outputId": "985af779-f5a7-4c0e-aded-398b5cce175f" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "torch.Size([2, 3, 128, 160])" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 10 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "encoder = create_model(\n", | |
| " 'mobilenetv3_rw', features_only=True, out_indices=None, in_chans=3,\n", | |
| " pretrained=True,)" | |
| ], | |
| "metadata": { | |
| "id": "d6nwuci0vEJi" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "enc_outs = encoder(torch.randn(2, 3, 128, 160))\n", | |
| "enc_outs_r = enc_outs[::-1]" | |
| ], | |
| "metadata": { | |
| "id": "cpIXAnerTZ_v" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "for e in enc_outs:\n", | |
| " print(e.shape)" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "oAE8LNQQTqwG", | |
| "outputId": "512ff956-62a4-4bd1-ad36-25c4bd418ef7" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "torch.Size([2, 16, 64, 80])\n", | |
| "torch.Size([2, 24, 32, 40])\n", | |
| "torch.Size([2, 40, 16, 20])\n", | |
| "torch.Size([2, 112, 8, 10])\n", | |
| "torch.Size([2, 960, 4, 5])\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "enc_channels = encoder.feature_info.channels()[::-1]; enc_channels" | |
| ], | |
| "metadata": { | |
| "id": "YhvGiHYNNgcH", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "outputId": "acf5e4d3-7f1e-4332-a7ad-08ae84ee10ef" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "[960, 112, 40, 24, 16]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 159 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "dec_channels = [256, 128, 64, 32, 16]" | |
| ], | |
| "metadata": { | |
| "id": "yOoZwawcNvaD" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "dec1 = DecoderBlock(enc_channels[0], \n", | |
| " None, \n", | |
| " dec_channels[0],\n", | |
| " ffm=False)" | |
| ], | |
| "metadata": { | |
| "id": "oZLm9YwkODQv" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "dec1_out = dec1(enc_outs_r[0], None)" | |
| ], | |
| "metadata": { | |
| "id": "3h-2jsN2OVNO" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "dec1_out.shape" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "oE-X8RKNOl8e", | |
| "outputId": "51ccf30e-782d-462f-fcfc-28a0b8347e7b" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "torch.Size([2, 256, 8, 10])" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 163 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "dec2 = DecoderBlock(enc_channels[1], \n", | |
| " dec_channels[1-1], \n", | |
| " dec_channels[1],\n", | |
| " ffm=True)" | |
| ], | |
| "metadata": { | |
| "id": "uchSrgNlwUg5" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "dec2_out = dec2(enc_outs_r[1], dec1_out)" | |
| ], | |
| "metadata": { | |
| "id": "PTGDO0ooweXJ" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "dec2_out.shape" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "2iqQyBVvwpFx", | |
| "outputId": "3c1e185f-2d82-49ae-92b8-64d52000c875" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "torch.Size([2, 128, 16, 20])" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 154 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "dec3 = DecoderBlock(enc_channels[2], \n", | |
| " dec_channels[2-1], \n", | |
| " dec_channels[2],\n", | |
| " ffm=True)" | |
| ], | |
| "metadata": { | |
| "id": "YI38w3tCwq9K" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "dec3_out = dec3(enc_outs_r[2], dec2_out)" | |
| ], | |
| "metadata": { | |
| "id": "Id5PcFlDw4M6" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "dec3_out.shape" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "QR3fRqsh1Hrx", | |
| "outputId": "76152af4-972f-45c9-c5d1-dcc795edec2a" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "torch.Size([2, 64, 32, 40])" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 169 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "dec4 = DecoderBlock(enc_channels[3], \n", | |
| " dec_channels[3-1], \n", | |
| " dec_channels[3],\n", | |
| " ffm=True)" | |
| ], | |
| "metadata": { | |
| "id": "ytatkui20UNi" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "dec4_out = dec4(enc_outs_r[3], dec3_out)" | |
| ], | |
| "metadata": { | |
| "id": "0YxQrVnB0dDl" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "dec4_out.shape" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "2G3Uht6T1Kbi", | |
| "outputId": "7df41391-e29d-4c2f-ffa9-439f7033717d" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "torch.Size([2, 32, 64, 80])" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 172 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "dec5 = DecoderBlock(enc_channels[4], \n", | |
| " dec_channels[4-1], \n", | |
| " dec_channels[4],\n", | |
| " ffm=True)" | |
| ], | |
| "metadata": { | |
| "id": "dITgOOl81QcL" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "dec5_out = dec5(enc_outs_r[4], dec4_out)" | |
| ], | |
| "metadata": { | |
| "id": "LqHdur5d1WCI" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "dec5_out.shape" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "s1SuPses1dKL", | |
| "outputId": "3c23592b-b428-43df-ea7c-6335f7ddd58b" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "torch.Size([2, 16, 128, 160])" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 175 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "" | |
| ], | |
| "metadata": { | |
| "id": "VzKVZSeTPXx0" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "dec = UnetDecoder(enc_channels)" | |
| ], | |
| "metadata": { | |
| "id": "HCu10Ce_jhrj" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "x = dec(enc_outs_r)" | |
| ], | |
| "metadata": { | |
| "id": "QSZ27lM3jr7L" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "" | |
| ], | |
| "metadata": { | |
| "id": "9QPdsDft2SrY" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "# Test Unet training" | |
| ], | |
| "metadata": { | |
| "id": "axnPMcaJOvGp" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "path = untar_data(URLs.CAMVID)\n", | |
| "path.ls()" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "QJjxTv5xOxNF", | |
| "outputId": "a72a4983-fd95-44e6-dd81-b3eddd77291b" | |
| }, | |
| "execution_count": 25, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "(#4) [Path('/root/.fastai/data/camvid/codes.txt'),Path('/root/.fastai/data/camvid/images'),Path('/root/.fastai/data/camvid/valid.txt'),Path('/root/.fastai/data/camvid/labels')]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 25 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "codes = np.loadtxt(path/'codes.txt', dtype=str)\n", | |
| "codes" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "PCkr63ykO-LR", | |
| "outputId": "b1c7586b-9e84-4d89-9c19-7cf98384e74a" | |
| }, | |
| "execution_count": 26, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "array(['Animal', 'Archway', 'Bicyclist', 'Bridge', 'Building', 'Car',\n", | |
| " 'CartLuggagePram', 'Child', 'Column_Pole', 'Fence', 'LaneMkgsDriv',\n", | |
| " 'LaneMkgsNonDriv', 'Misc_Text', 'MotorcycleScooter', 'OtherMoving',\n", | |
| " 'ParkingBlock', 'Pedestrian', 'Road', 'RoadShoulder', 'Sidewalk',\n", | |
| " 'SignSymbol', 'Sky', 'SUVPickupTruck', 'TrafficCone',\n", | |
| " 'TrafficLight', 'Train', 'Tree', 'Truck_Bus', 'Tunnel',\n", | |
| " 'VegetationMisc', 'Void', 'Wall'], dtype='<U17')" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 26 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "fnames = get_image_files(path/\"images\")" | |
| ], | |
| "metadata": { | |
| "id": "dKQMJJ80PJwY" | |
| }, | |
| "execution_count": 27, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "def label_func(fn): return path/\"labels\"/f\"{fn.stem}_P{fn.suffix}\"" | |
| ], | |
| "metadata": { | |
| "id": "9SZzi4m6PMpS" | |
| }, | |
| "execution_count": 28, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "name2id = {v:k for k,v in enumerate(codes)}\n", | |
| "void_code = name2id['Void']\n", | |
| "def acc_camvid(inp, targ):\n", | |
| " targ = targ.squeeze(1)\n", | |
| " mask = targ != void_code\n", | |
| " return np.mean(inp.argmax(dim=1)[mask].cpu().numpy()==targ[mask].cpu().numpy())" | |
| ], | |
| "metadata": { | |
| "id": "lU6X1jDGQExa" | |
| }, | |
| "execution_count": 29, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "dls = SegmentationDataLoaders.from_label_func(path, \n", | |
| " bs=8, \n", | |
| " fnames = fnames, \n", | |
| " label_func = label_func, \n", | |
| " codes = codes,\n", | |
| " item_tfms=Resize((128, 160)))" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "bj7NRyDCPNjl", | |
| "outputId": "939b58d5-9db2-429e-dc88-c12f2a1220c3" | |
| }, | |
| "execution_count": 30, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stderr", | |
| "text": [ | |
| "/usr/local/lib/python3.7/dist-packages/torch/_tensor.py:1051: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n", | |
| " ret = func(*args, **kwargs)\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## with Dice Loss" | |
| ], | |
| "metadata": { | |
| "id": "vWuMb0ExqbSM" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "model = Unet('mobilenetv3_rw',\n", | |
| " num_classes=32)\n", | |
| "\n", | |
| "learn = Learner(dls, \n", | |
| " model,\n", | |
| " loss_func=DiceLoss(),\n", | |
| " metrics=acc_camvid)" | |
| ], | |
| "metadata": { | |
| "id": "Utq5Utm0PQA0", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "outputId": "2b494288-aaf0-4463-a33d-86ace3ea2ce9" | |
| }, | |
| "execution_count": 12, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stderr", | |
| "text": [ | |
| "Downloading: \"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth\" to /root/.cache/torch/hub/checkpoints/mobilenetv3_100-35495452.pth\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "learn.fit_one_cycle(10, 1e-3)" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 418 | |
| }, | |
| "id": "RQkxMCbWQTgR", | |
| "outputId": "6d31edfb-ed01-4f7c-a434-8fac551ec433" | |
| }, | |
| "execution_count": 13, | |
| "outputs": [ | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ], | |
| "text/html": [ | |
| "\n", | |
| "<style>\n", | |
| " /* Turns off some styling */\n", | |
| " progress {\n", | |
| " /* gets rid of default border in Firefox and Opera. */\n", | |
| " border: none;\n", | |
| " /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
| " background-size: auto;\n", | |
| " }\n", | |
| " .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
| " background: #F44336;\n", | |
| " }\n", | |
| "</style>\n" | |
| ] | |
| }, | |
| "metadata": {} | |
| }, | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ], | |
| "text/html": [ | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: left;\">\n", | |
| " <th>epoch</th>\n", | |
| " <th>train_loss</th>\n", | |
| " <th>valid_loss</th>\n", | |
| " <th>acc_camvid</th>\n", | |
| " <th>time</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <td>0</td>\n", | |
| " <td>252.215271</td>\n", | |
| " <td>246.689651</td>\n", | |
| " <td>0.183169</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>1</td>\n", | |
| " <td>245.648544</td>\n", | |
| " <td>237.859009</td>\n", | |
| " <td>0.438296</td>\n", | |
| " <td>00:26</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>2</td>\n", | |
| " <td>240.166595</td>\n", | |
| " <td>233.821426</td>\n", | |
| " <td>0.527989</td>\n", | |
| " <td>00:31</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>3</td>\n", | |
| " <td>236.032440</td>\n", | |
| " <td>230.155685</td>\n", | |
| " <td>0.579975</td>\n", | |
| " <td>00:26</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>4</td>\n", | |
| " <td>232.820297</td>\n", | |
| " <td>227.664886</td>\n", | |
| " <td>0.598176</td>\n", | |
| " <td>00:26</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>5</td>\n", | |
| " <td>230.536667</td>\n", | |
| " <td>226.120468</td>\n", | |
| " <td>0.612691</td>\n", | |
| " <td>00:27</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>6</td>\n", | |
| " <td>229.252609</td>\n", | |
| " <td>225.270096</td>\n", | |
| " <td>0.618059</td>\n", | |
| " <td>00:26</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>7</td>\n", | |
| " <td>228.376007</td>\n", | |
| " <td>224.750214</td>\n", | |
| " <td>0.621125</td>\n", | |
| " <td>00:26</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>8</td>\n", | |
| " <td>227.940201</td>\n", | |
| " <td>224.663452</td>\n", | |
| " <td>0.622177</td>\n", | |
| " <td>00:26</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>9</td>\n", | |
| " <td>227.806488</td>\n", | |
| " <td>224.604568</td>\n", | |
| " <td>0.624282</td>\n", | |
| " <td>00:27</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>" | |
| ] | |
| }, | |
| "metadata": {} | |
| }, | |
| { | |
| "output_type": "stream", | |
| "name": "stderr", | |
| "text": [ | |
| "/usr/local/lib/python3.7/dist-packages/torch/_tensor.py:1051: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n", | |
| " ret = func(*args, **kwargs)\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "learn.summary()" | |
| ], | |
| "metadata": { | |
| "id": "L53Vk3fCZKDE" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## with CrossEntropy loss" | |
| ], | |
| "metadata": { | |
| "id": "tF_pW798qdgI" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "model = Unet('mobilenetv3_rw',\n", | |
| " num_classes=32)\n", | |
| "\n", | |
| "learn = Learner(dls, \n", | |
| " model,\n", | |
| " metrics=acc_camvid)" | |
| ], | |
| "metadata": { | |
| "id": "ZSLhn68ZmkmZ" | |
| }, | |
| "execution_count": 31, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "learn.fit_one_cycle(25, 1e-3)" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 888 | |
| }, | |
| "id": "qLJF8V6zmrs3", | |
| "outputId": "5adf8515-9496-4f3a-be10-2ccbe7952a43" | |
| }, | |
| "execution_count": 32, | |
| "outputs": [ | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ], | |
| "text/html": [ | |
| "\n", | |
| "<style>\n", | |
| " /* Turns off some styling */\n", | |
| " progress {\n", | |
| " /* gets rid of default border in Firefox and Opera. */\n", | |
| " border: none;\n", | |
| " /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
| " background-size: auto;\n", | |
| " }\n", | |
| " .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
| " background: #F44336;\n", | |
| " }\n", | |
| "</style>\n" | |
| ] | |
| }, | |
| "metadata": {} | |
| }, | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ], | |
| "text/html": [ | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: left;\">\n", | |
| " <th>epoch</th>\n", | |
| " <th>train_loss</th>\n", | |
| " <th>valid_loss</th>\n", | |
| " <th>acc_camvid</th>\n", | |
| " <th>time</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <td>0</td>\n", | |
| " <td>3.952452</td>\n", | |
| " <td>3.837103</td>\n", | |
| " <td>0.009524</td>\n", | |
| " <td>00:33</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>1</td>\n", | |
| " <td>3.651547</td>\n", | |
| " <td>3.323963</td>\n", | |
| " <td>0.034026</td>\n", | |
| " <td>00:27</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>2</td>\n", | |
| " <td>3.141333</td>\n", | |
| " <td>2.725803</td>\n", | |
| " <td>0.219234</td>\n", | |
| " <td>00:27</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>3</td>\n", | |
| " <td>2.604731</td>\n", | |
| " <td>2.217971</td>\n", | |
| " <td>0.475758</td>\n", | |
| " <td>00:27</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>4</td>\n", | |
| " <td>2.168982</td>\n", | |
| " <td>1.879862</td>\n", | |
| " <td>0.516642</td>\n", | |
| " <td>00:29</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>5</td>\n", | |
| " <td>1.846897</td>\n", | |
| " <td>1.635621</td>\n", | |
| " <td>0.553690</td>\n", | |
| " <td>00:27</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>6</td>\n", | |
| " <td>1.600007</td>\n", | |
| " <td>1.454400</td>\n", | |
| " <td>0.619151</td>\n", | |
| " <td>00:27</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>7</td>\n", | |
| " <td>1.441168</td>\n", | |
| " <td>1.355009</td>\n", | |
| " <td>0.629975</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>8</td>\n", | |
| " <td>1.343089</td>\n", | |
| " <td>1.295351</td>\n", | |
| " <td>0.633226</td>\n", | |
| " <td>00:27</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>9</td>\n", | |
| " <td>1.287046</td>\n", | |
| " <td>1.254252</td>\n", | |
| " <td>0.639729</td>\n", | |
| " <td>00:27</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>10</td>\n", | |
| " <td>1.246888</td>\n", | |
| " <td>1.224334</td>\n", | |
| " <td>0.643034</td>\n", | |
| " <td>00:27</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>11</td>\n", | |
| " <td>1.220407</td>\n", | |
| " <td>1.203212</td>\n", | |
| " <td>0.646837</td>\n", | |
| " <td>00:27</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>12</td>\n", | |
| " <td>1.192617</td>\n", | |
| " <td>1.186723</td>\n", | |
| " <td>0.649282</td>\n", | |
| " <td>00:27</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>13</td>\n", | |
| " <td>1.178188</td>\n", | |
| " <td>1.172160</td>\n", | |
| " <td>0.654529</td>\n", | |
| " <td>00:27</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>14</td>\n", | |
| " <td>1.167311</td>\n", | |
| " <td>1.163162</td>\n", | |
| " <td>0.654595</td>\n", | |
| " <td>00:27</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>15</td>\n", | |
| " <td>1.166651</td>\n", | |
| " <td>1.157911</td>\n", | |
| " <td>0.654737</td>\n", | |
| " <td>00:27</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>16</td>\n", | |
| " <td>1.148694</td>\n", | |
| " <td>1.148224</td>\n", | |
| " <td>0.660711</td>\n", | |
| " <td>00:27</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>17</td>\n", | |
| " <td>1.145139</td>\n", | |
| " <td>1.146453</td>\n", | |
| " <td>0.659014</td>\n", | |
| " <td>00:27</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>18</td>\n", | |
| " <td>1.135504</td>\n", | |
| " <td>1.138023</td>\n", | |
| " <td>0.661375</td>\n", | |
| " <td>00:27</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>19</td>\n", | |
| " <td>1.139854</td>\n", | |
| " <td>1.135780</td>\n", | |
| " <td>0.661414</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>20</td>\n", | |
| " <td>1.132565</td>\n", | |
| " <td>1.136258</td>\n", | |
| " <td>0.660869</td>\n", | |
| " <td>00:28</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>21</td>\n", | |
| " <td>1.126904</td>\n", | |
| " <td>1.133178</td>\n", | |
| " <td>0.661520</td>\n", | |
| " <td>00:27</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>22</td>\n", | |
| " <td>1.126309</td>\n", | |
| " <td>1.132679</td>\n", | |
| " <td>0.663100</td>\n", | |
| " <td>00:27</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>23</td>\n", | |
| " <td>1.129516</td>\n", | |
| " <td>1.131822</td>\n", | |
| " <td>0.663083</td>\n", | |
| " <td>00:27</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>24</td>\n", | |
| " <td>1.132021</td>\n", | |
| " <td>1.131812</td>\n", | |
| " <td>0.662620</td>\n", | |
| " <td>00:27</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>" | |
| ] | |
| }, | |
| "metadata": {} | |
| }, | |
| { | |
| "output_type": "stream", | |
| "name": "stderr", | |
| "text": [ | |
| "/usr/local/lib/python3.7/dist-packages/torch/_tensor.py:1051: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n", | |
| " ret = func(*args, **kwargs)\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "# fastai unet-resnet101 performance" | |
| ], | |
| "metadata": { | |
| "id": "ZlkdkgOAjTGP" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "teacher_learn = unet_learner(dls, \n", | |
| " resnet101, \n", | |
| " metrics=acc_camvid)" | |
| ], | |
| "metadata": { | |
| "id": "sy8eP1xjbI4h", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 121, | |
| "referenced_widgets": [ | |
| "8eb841b4c62b4042bd9a5af15342c012", | |
| "dd011ba65aa04664bd63eb8f79ca29c8", | |
| "46490715f4934745aa454ae8e119f40d", | |
| "f7444e1bb42a4c6f913e05514445ce59", | |
| "cd7643d9744340ec99319a52c2faf2c9", | |
| "a1860ff990b042a5b50f87066e30bb9b", | |
| "18da3006e55d432d9038da4f74558e4f", | |
| "5ca90a8d4e0c4be3888dbccafe6aafce", | |
| "56aa6084019b49b1ad09c0f1aaabdf01", | |
| "e43806d23dad4c86afb2428ce56504b3", | |
| "4eec62679381415ab632253d8f043108" | |
| ] | |
| }, | |
| "outputId": "f3d18777-eb58-4522-e37d-04532869eb1c" | |
| }, | |
| "execution_count": 34, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stderr", | |
| "text": [ | |
| "/usr/local/lib/python3.7/dist-packages/torch/_tensor.py:1051: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n", | |
| " ret = func(*args, **kwargs)\n", | |
| "Downloading: \"https://download.pytorch.org/models/resnet101-63fe2227.pth\" to /root/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth\n" | |
| ] | |
| }, | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| " 0%| | 0.00/171M [00:00<?, ?B/s]" | |
| ], | |
| "application/vnd.jupyter.widget-view+json": { | |
| "version_major": 2, | |
| "version_minor": 0, | |
| "model_id": "8eb841b4c62b4042bd9a5af15342c012" | |
| } | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "teacher_learn.fit_one_cycle(10, 1e-3)" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 418 | |
| }, | |
| "id": "ojFW8MxuouxA", | |
| "outputId": "06e9580e-f19b-4a31-ce60-63b522b03020" | |
| }, | |
| "execution_count": 35, | |
| "outputs": [ | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ], | |
| "text/html": [ | |
| "\n", | |
| "<style>\n", | |
| " /* Turns off some styling */\n", | |
| " progress {\n", | |
| " /* gets rid of default border in Firefox and Opera. */\n", | |
| " border: none;\n", | |
| " /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
| " background-size: auto;\n", | |
| " }\n", | |
| " .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
| " background: #F44336;\n", | |
| " }\n", | |
| "</style>\n" | |
| ] | |
| }, | |
| "metadata": {} | |
| }, | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ], | |
| "text/html": [ | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: left;\">\n", | |
| " <th>epoch</th>\n", | |
| " <th>train_loss</th>\n", | |
| " <th>valid_loss</th>\n", | |
| " <th>acc_camvid</th>\n", | |
| " <th>time</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <td>0</td>\n", | |
| " <td>1.079316</td>\n", | |
| " <td>1.945745</td>\n", | |
| " <td>0.791928</td>\n", | |
| " <td>02:54</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>1</td>\n", | |
| " <td>1.672722</td>\n", | |
| " <td>3.414593</td>\n", | |
| " <td>0.467281</td>\n", | |
| " <td>02:33</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>2</td>\n", | |
| " <td>2.836250</td>\n", | |
| " <td>1.068983</td>\n", | |
| " <td>0.708052</td>\n", | |
| " <td>02:31</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>3</td>\n", | |
| " <td>1.262702</td>\n", | |
| " <td>1.077650</td>\n", | |
| " <td>0.819915</td>\n", | |
| " <td>02:31</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>4</td>\n", | |
| " <td>0.755539</td>\n", | |
| " <td>0.741171</td>\n", | |
| " <td>0.839445</td>\n", | |
| " <td>02:30</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>5</td>\n", | |
| " <td>0.586151</td>\n", | |
| " <td>0.615320</td>\n", | |
| " <td>0.857607</td>\n", | |
| " <td>02:31</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>6</td>\n", | |
| " <td>0.495401</td>\n", | |
| " <td>0.540893</td>\n", | |
| " <td>0.869873</td>\n", | |
| " <td>02:31</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>7</td>\n", | |
| " <td>0.429804</td>\n", | |
| " <td>0.520235</td>\n", | |
| " <td>0.879431</td>\n", | |
| " <td>02:31</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>8</td>\n", | |
| " <td>0.395290</td>\n", | |
| " <td>0.510423</td>\n", | |
| " <td>0.879006</td>\n", | |
| " <td>02:31</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>9</td>\n", | |
| " <td>0.374698</td>\n", | |
| " <td>0.482728</td>\n", | |
| " <td>0.882320</td>\n", | |
| " <td>02:31</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>" | |
| ] | |
| }, | |
| "metadata": {} | |
| }, | |
| { | |
| "output_type": "stream", | |
| "name": "stderr", | |
| "text": [ | |
| "/usr/local/lib/python3.7/dist-packages/torch/_tensor.py:1051: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n", | |
| " ret = func(*args, **kwargs)\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "teacher_learn.loss_func" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "_ZMICwZzo0MN", | |
| "outputId": "659ea2c1-8cc6-4aaa-8776-95e17d8c9e02" | |
| }, | |
| "execution_count": 36, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "FlattenedLoss of CrossEntropyLoss()" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 36 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "# Knowledge Distillation" | |
| ], | |
| "metadata": { | |
| "id": "IVvuN2gEqjv6" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "class DistillationLoss(nn.Module):\n", | |
| " def __init__(self):\n", | |
| " super(DistillationLoss, self).__init__()\n", | |
| " self.distillation_loss = nn.KLDivLoss(reduction='batchmean')\n", | |
| " \n", | |
| " def forward(self,\n", | |
| " student_preds, \n", | |
| " teacher_preds, \n", | |
| " acutal_target, \n", | |
| " T, \n", | |
| " alpha\n", | |
| " ):\n", | |
| "\n", | |
| " return self.distillation_loss(F.softmax(student_preds / T, dim=1).reshape(-1),\n", | |
| " F.softmax(teacher_preds / T, dim=1).reshape(-1))\n", | |
| " \n", | |
| "\n", | |
| "\n", | |
| "class KnowledgeDistillation(Callback):\n", | |
| " def __init__(self, \n", | |
| " teacher:Learner, \n", | |
| " T:float=20., \n", | |
| " a:float=0.7):\n", | |
| " super(KnowledgeDistillation, self).__init__()\n", | |
| " self.teacher = teacher\n", | |
| " self.T, self.a = T, a\n", | |
| " self.distillation_loss = DistillationLoss()\n", | |
| " \n", | |
| " def after_loss(self):\n", | |
| " teacher_preds = self.teacher.model(self.learn.xb[0])\n", | |
| " student_loss = self.learn.loss_grad * self.a\n", | |
| " distillation_loss = self.distillation_loss(self.learn.pred, # Student preds\n", | |
| " teacher_preds, # Teacher preds\n", | |
| " self.learn.yb, # Ground truth\n", | |
| " self.T, \n", | |
| " self.a) * (1 - self.a)\n", | |
| " self.learn.loss_grad = student_loss + distillation_loss" | |
| ], | |
| "metadata": { | |
| "id": "f1IO5Vw4tPXX" | |
| }, | |
| "execution_count": 37, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "model = Unet('mobilenetv3_rw',\n", | |
| " num_classes=32)\n", | |
| "\n", | |
| "student_learn = Learner(dls, \n", | |
| " model,\n", | |
| " metrics=acc_camvid,\n", | |
| " cbs=[KnowledgeDistillation(teacher=teacher_learn)])" | |
| ], | |
| "metadata": { | |
| "id": "Si1Zl_2uy43B" | |
| }, | |
| "execution_count": 38, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "student_learn.fit_one_cycle(25, 1e-3)" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 888 | |
| }, | |
| "id": "QjXcKthW0B4t", | |
| "outputId": "3aaf9f7b-5598-433e-8d36-f20d44a06b57" | |
| }, | |
| "execution_count": 39, | |
| "outputs": [ | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ], | |
| "text/html": [ | |
| "\n", | |
| "<style>\n", | |
| " /* Turns off some styling */\n", | |
| " progress {\n", | |
| " /* gets rid of default border in Firefox and Opera. */\n", | |
| " border: none;\n", | |
| " /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
| " background-size: auto;\n", | |
| " }\n", | |
| " .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
| " background: #F44336;\n", | |
| " }\n", | |
| "</style>\n" | |
| ] | |
| }, | |
| "metadata": {} | |
| }, | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ], | |
| "text/html": [ | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: left;\">\n", | |
| " <th>epoch</th>\n", | |
| " <th>train_loss</th>\n", | |
| " <th>valid_loss</th>\n", | |
| " <th>acc_camvid</th>\n", | |
| " <th>time</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <td>0</td>\n", | |
| " <td>4.149306</td>\n", | |
| " <td>4.083494</td>\n", | |
| " <td>0.012385</td>\n", | |
| " <td>02:23</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>1</td>\n", | |
| " <td>3.843522</td>\n", | |
| " <td>3.565912</td>\n", | |
| " <td>0.096796</td>\n", | |
| " <td>02:23</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>2</td>\n", | |
| " <td>3.263084</td>\n", | |
| " <td>2.797809</td>\n", | |
| " <td>0.369143</td>\n", | |
| " <td>02:23</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>3</td>\n", | |
| " <td>2.626706</td>\n", | |
| " <td>2.134217</td>\n", | |
| " <td>0.512857</td>\n", | |
| " <td>02:24</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>4</td>\n", | |
| " <td>2.120926</td>\n", | |
| " <td>1.793929</td>\n", | |
| " <td>0.577740</td>\n", | |
| " <td>02:23</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>5</td>\n", | |
| " <td>1.782979</td>\n", | |
| " <td>1.582902</td>\n", | |
| " <td>0.600909</td>\n", | |
| " <td>02:23</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>6</td>\n", | |
| " <td>1.565306</td>\n", | |
| " <td>1.450105</td>\n", | |
| " <td>0.618120</td>\n", | |
| " <td>02:23</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>7</td>\n", | |
| " <td>1.431738</td>\n", | |
| " <td>1.363392</td>\n", | |
| " <td>0.626684</td>\n", | |
| " <td>02:23</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>8</td>\n", | |
| " <td>1.350202</td>\n", | |
| " <td>1.306182</td>\n", | |
| " <td>0.632427</td>\n", | |
| " <td>02:23</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>9</td>\n", | |
| " <td>1.291344</td>\n", | |
| " <td>1.265584</td>\n", | |
| " <td>0.640540</td>\n", | |
| " <td>02:23</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>10</td>\n", | |
| " <td>1.253604</td>\n", | |
| " <td>1.234486</td>\n", | |
| " <td>0.646430</td>\n", | |
| " <td>02:24</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>11</td>\n", | |
| " <td>1.225404</td>\n", | |
| " <td>1.220467</td>\n", | |
| " <td>0.650743</td>\n", | |
| " <td>02:24</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>12</td>\n", | |
| " <td>1.201750</td>\n", | |
| " <td>1.195458</td>\n", | |
| " <td>0.652415</td>\n", | |
| " <td>02:24</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>13</td>\n", | |
| " <td>1.185998</td>\n", | |
| " <td>1.175060</td>\n", | |
| " <td>0.658728</td>\n", | |
| " <td>02:24</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>14</td>\n", | |
| " <td>1.171920</td>\n", | |
| " <td>1.162316</td>\n", | |
| " <td>0.664814</td>\n", | |
| " <td>02:23</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>15</td>\n", | |
| " <td>1.160010</td>\n", | |
| " <td>1.154163</td>\n", | |
| " <td>0.666426</td>\n", | |
| " <td>02:24</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>16</td>\n", | |
| " <td>1.153085</td>\n", | |
| " <td>1.149265</td>\n", | |
| " <td>0.667741</td>\n", | |
| " <td>02:23</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>17</td>\n", | |
| " <td>1.141003</td>\n", | |
| " <td>1.143118</td>\n", | |
| " <td>0.670634</td>\n", | |
| " <td>02:23</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>18</td>\n", | |
| " <td>1.138742</td>\n", | |
| " <td>1.137632</td>\n", | |
| " <td>0.671683</td>\n", | |
| " <td>02:23</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>19</td>\n", | |
| " <td>1.129604</td>\n", | |
| " <td>1.136192</td>\n", | |
| " <td>0.672162</td>\n", | |
| " <td>02:23</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>20</td>\n", | |
| " <td>1.125707</td>\n", | |
| " <td>1.133340</td>\n", | |
| " <td>0.672295</td>\n", | |
| " <td>02:23</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>21</td>\n", | |
| " <td>1.130975</td>\n", | |
| " <td>1.133293</td>\n", | |
| " <td>0.670490</td>\n", | |
| " <td>02:23</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>22</td>\n", | |
| " <td>1.134533</td>\n", | |
| " <td>1.132475</td>\n", | |
| " <td>0.670508</td>\n", | |
| " <td>02:23</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>23</td>\n", | |
| " <td>1.129584</td>\n", | |
| " <td>1.130723</td>\n", | |
| " <td>0.672191</td>\n", | |
| " <td>02:23</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <td>24</td>\n", | |
| " <td>1.126617</td>\n", | |
| " <td>1.130564</td>\n", | |
| " <td>0.672605</td>\n", | |
| " <td>02:23</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>" | |
| ] | |
| }, | |
| "metadata": {} | |
| }, | |
| { | |
| "output_type": "stream", | |
| "name": "stderr", | |
| "text": [ | |
| "/usr/local/lib/python3.7/dist-packages/torch/_tensor.py:1051: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n", | |
| " ret = func(*args, **kwargs)\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "" | |
| ], | |
| "metadata": { | |
| "id": "sgCb1pswUtv5" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| } | |
| ] | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment