Last active
July 10, 2024 12:50
-
-
Save brookisme/52079e106255f75c996d8595cd3988b0 to your computer and use it in GitHub Desktop.
UNET with Squeeze and Excitation Blocks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "import torch\n", | |
| "import torch.nn as nn\n", | |
| "import torch.nn.functional as F\n", | |
| "from torchsummary import summary" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "torch.Size([1, 64, 568, 568])\n", | |
| "----------------------------------------------------------------\n", | |
| " Layer (type) Output Shape Param #\n", | |
| "================================================================\n", | |
| " AdaptiveAvgPool2d-1 [-1, 64, 1, 1] 0\n", | |
| " Linear-2 [-1, 4] 260\n", | |
| " ReLU-3 [-1, 4] 0\n", | |
| " Linear-4 [-1, 64] 320\n", | |
| " Sigmoid-5 [-1, 64] 0\n", | |
| "================================================================\n", | |
| "Total params: 580\n", | |
| "Trainable params: 580\n", | |
| "Non-trainable params: 0\n", | |
| "----------------------------------------------------------------\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "class SqueezeExcitation(nn.Module):\n", | |
| " def __init__(self, nb_channels, reduction=16):\n", | |
| " super(SqueezeExcitation, self).__init__()\n", | |
| " self.nb_channels=nb_channels\n", | |
| " self.avg_pool=nn.AdaptiveAvgPool2d(1)\n", | |
| " self.fc=nn.Sequential(\n", | |
| " nn.Linear(nb_channels, nb_channels // reduction),\n", | |
| " nn.ReLU(inplace=True),\n", | |
| " nn.Linear(nb_channels // reduction, nb_channels),\n", | |
| " nn.Sigmoid())\n", | |
| "\n", | |
| " \n", | |
| " def forward(self, x):\n", | |
| " y = self.avg_pool(x).view(-1,self.nb_channels)\n", | |
| " y = self.fc(y).view(-1,self.nb_channels,1,1)\n", | |
| " return x * y\n", | |
| " \n", | |
| "\n", | |
| "print(SqueezeExcitation(64)(torch.rand(1,64,568,568)).shape)\n", | |
| "summary(SqueezeExcitation(64),input_size=(64,568,568))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "(568, 64)\n", | |
| "torch.Size([1, 64, 568, 568])\n", | |
| "----------------------------------------------------------------\n", | |
| " Layer (type) Output Shape Param #\n", | |
| "================================================================\n", | |
| " Conv2d-1 [-1, 64, 570, 570] 640\n", | |
| " Conv2d-2 [-1, 64, 568, 568] 36,928\n", | |
| " BatchNorm2d-3 [-1, 64, 568, 568] 128\n", | |
| " AdaptiveAvgPool2d-4 [-1, 64, 1, 1] 0\n", | |
| " Linear-5 [-1, 4] 260\n", | |
| " ReLU-6 [-1, 4] 0\n", | |
| " Linear-7 [-1, 64] 320\n", | |
| " Sigmoid-8 [-1, 64] 0\n", | |
| " SqueezeExcitation-9 [-1, 64, 568, 568] 0\n", | |
| "================================================================\n", | |
| "Total params: 38,276\n", | |
| "Trainable params: 38,276\n", | |
| "Non-trainable params: 0\n", | |
| "----------------------------------------------------------------\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "class ConvBlock(nn.Module):\n", | |
| "\n", | |
| " def __init__(self,\n", | |
| " in_ch,\n", | |
| " in_size,\n", | |
| " depth=2, \n", | |
| " kernel_size=3, \n", | |
| " stride=1, \n", | |
| " padding=0, \n", | |
| " out_ch=None,\n", | |
| " bn=True,\n", | |
| " se=True,\n", | |
| " act='relu',\n", | |
| " act_kwargs={}):\n", | |
| " super(ConvBlock, self).__init__()\n", | |
| " self.out_ch=out_ch or 2*in_ch\n", | |
| " self._set_post_processes(self.out_ch,bn,se,act,act_kwargs)\n", | |
| " self._set_conv_layers(\n", | |
| " depth,\n", | |
| " in_ch,\n", | |
| " kernel_size,\n", | |
| " stride,\n", | |
| " padding)\n", | |
| " self.out_size=in_size-depth*2*((kernel_size-1)//2-padding)\n", | |
| "\n", | |
| " \n", | |
| " def forward(self, x):\n", | |
| " x=self.conv_layers(x)\n", | |
| " if self.bn:\n", | |
| " x=self.bn(x)\n", | |
| " if self.act:\n", | |
| " x=self._activation(x)\n", | |
| " if self.se:\n", | |
| " x=self.se(x)\n", | |
| " return x\n", | |
| "\n", | |
| " \n", | |
| " def _set_post_processes(self,out_channels,bn,se,act,act_kwargs):\n", | |
| " if bn:\n", | |
| " self.bn=nn.BatchNorm2d(out_channels)\n", | |
| " else:\n", | |
| " self.bn=False\n", | |
| " if se:\n", | |
| " self.se=SqueezeExcitation(out_channels)\n", | |
| " else:\n", | |
| " self.se=False\n", | |
| " self.act=act\n", | |
| " self.act_kwargs=act_kwargs\n", | |
| "\n", | |
| " \n", | |
| " def _set_conv_layers(\n", | |
| " self,\n", | |
| " depth,\n", | |
| " in_ch,\n", | |
| " kernel_size,\n", | |
| " stride,\n", | |
| " padding):\n", | |
| " layers=[]\n", | |
| " for index in range(depth):\n", | |
| " if index!=0:\n", | |
| " in_ch=self.out_ch\n", | |
| " layers.append(\n", | |
| " nn.Conv2d(\n", | |
| " in_channels=in_ch,\n", | |
| " out_channels=self.out_ch,\n", | |
| " kernel_size=kernel_size,\n", | |
| " stride=stride,\n", | |
| " padding=padding))\n", | |
| " self.conv_layers=nn.Sequential(*layers)\n", | |
| "\n", | |
| " \n", | |
| " def _activation(self,x):\n", | |
| " return getattr(F,self.act)(x,**self.act_kwargs)\n", | |
| "\n", | |
| " \n", | |
| "conv_block=ConvBlock(1,572,out_ch=64)\n", | |
| "print(conv_block.out_size,conv_block.out_ch)\n", | |
| "print(conv_block(torch.rand(1,1,572,572)).shape)\n", | |
| "summary(ConvBlock(1,572,out_ch=64),input_size=(1,572,572))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "(276, 128)\n", | |
| "torch.Size([1, 128, 276, 276])\n", | |
| "----------------------------------------------------------------\n", | |
| " Layer (type) Output Shape Param #\n", | |
| "================================================================\n", | |
| " MaxPool2d-1 [-1, 64, 284, 284] 0\n", | |
| " Conv2d-2 [-1, 128, 282, 282] 73,856\n", | |
| " Conv2d-3 [-1, 128, 280, 280] 147,584\n", | |
| " Conv2d-4 [-1, 128, 278, 278] 147,584\n", | |
| " Conv2d-5 [-1, 128, 276, 276] 147,584\n", | |
| " BatchNorm2d-6 [-1, 128, 276, 276] 256\n", | |
| " AdaptiveAvgPool2d-7 [-1, 128, 1, 1] 0\n", | |
| " Linear-8 [-1, 8] 1,032\n", | |
| " ReLU-9 [-1, 8] 0\n", | |
| " Linear-10 [-1, 128] 1,152\n", | |
| " Sigmoid-11 [-1, 128] 0\n", | |
| "SqueezeExcitation-12 [-1, 128, 276, 276] 0\n", | |
| " ConvBlock-13 [-1, 128, 276, 276] 0\n", | |
| "================================================================\n", | |
| "Total params: 519,048\n", | |
| "Trainable params: 519,048\n", | |
| "Non-trainable params: 0\n", | |
| "----------------------------------------------------------------\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "class DownBlock(nn.Module):\n", | |
| " \n", | |
| " def __init__(self,\n", | |
| " in_ch,\n", | |
| " in_size,\n", | |
| " out_ch=None,\n", | |
| " depth=2,\n", | |
| " padding=0,\n", | |
| " bn=True,\n", | |
| " se=True,\n", | |
| " act='relu',\n", | |
| " act_kwargs={}):\n", | |
| " super(DownBlock, self).__init__()\n", | |
| " self.out_size=(in_size//2)-depth*(1-padding)*2\n", | |
| " self.out_ch=out_ch or in_ch*2\n", | |
| " self.down=nn.MaxPool2d(kernel_size=2)\n", | |
| " self.conv_block=ConvBlock(\n", | |
| " in_ch=in_ch,\n", | |
| " out_ch=self.out_ch,\n", | |
| " in_size=in_size//2,\n", | |
| " depth=depth,\n", | |
| " padding=padding,\n", | |
| " bn=bn,\n", | |
| " se=se,\n", | |
| " act=act,\n", | |
| " act_kwargs=act_kwargs)\n", | |
| "\n", | |
| " \n", | |
| " def forward(self, x):\n", | |
| " x=self.down(x)\n", | |
| " return self.conv_block(x)\n", | |
| "\n", | |
| " \n", | |
| "db_out=DownBlock(64,568,depth=4)\n", | |
| "print(db_out.out_size,db_out.out_ch)\n", | |
| "print(db_out(torch.rand(1,64,568,568)).shape)\n", | |
| "summary(db_out,input_size=(64,568,568))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "(196, 128)\n", | |
| "torch.Size([1, 128, 196, 196])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "class UpBlock(nn.Module):\n", | |
| " \n", | |
| " @staticmethod\n", | |
| " def cropping(skip_size,size):\n", | |
| " return (skip_size-size)//2\n", | |
| " \n", | |
| " \n", | |
| " def __init__(self,\n", | |
| " in_ch,\n", | |
| " in_size,\n", | |
| " out_ch=None,\n", | |
| " bilinear=False,\n", | |
| " crop=None,\n", | |
| " depth=2,\n", | |
| " padding=0,\n", | |
| " bn=True,\n", | |
| " se=True,\n", | |
| " act='relu',\n", | |
| " act_kwargs={}):\n", | |
| " super(UpBlock, self).__init__()\n", | |
| " self.crop=crop\n", | |
| " self.padding=padding\n", | |
| " self.out_size=(in_size*2)-depth*(1-padding)*2\n", | |
| " self.out_ch=out_ch or in_ch//2\n", | |
| " if bilinear:\n", | |
| " self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), nn.Conv2d(in_ch, in_ch//2, kernel_size=1))\n", | |
| " else:\n", | |
| " self.up = nn.ConvTranspose2d(in_ch, in_ch//2, 2, stride=2)\n", | |
| " self.conv_block=ConvBlock(\n", | |
| " in_ch,\n", | |
| " in_size*2,\n", | |
| " out_ch=self.out_ch,\n", | |
| " depth=depth,\n", | |
| " padding=padding,\n", | |
| " bn=bn,\n", | |
| " se=se,\n", | |
| " act=act,\n", | |
| " act_kwargs=act_kwargs)\n", | |
| " \n", | |
| " \n", | |
| " def forward(self, x, skip):\n", | |
| " x = self.up(x)\n", | |
| " skip = self._crop(skip,x)\n", | |
| " x = torch.cat([skip, x], dim=1)\n", | |
| " x = self.conv_block(x)\n", | |
| " return x\n", | |
| "\n", | |
| " \n", | |
| " def _crop(self,skip,x):\n", | |
| " if self.padding == 0:\n", | |
| " if self.crop is None:\n", | |
| " self.crop=self.cropping(skip.size()[-1],x.size()[-1])\n", | |
| " skip=skip[:,:,self.crop:-self.crop,self.crop:-self.crop]\n", | |
| " return skip\n", | |
| "\n", | |
| " \n", | |
| "db_out=UpBlock(256,100)\n", | |
| "print(db_out.out_size,db_out.out_ch)\n", | |
| "print(db_out(torch.rand(1,256,100,100),torch.rand(1,128,280,280)).shape)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": { | |
| "scrolled": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "(4, 2)\n", | |
| "(388, 2)\n", | |
| "torch.Size([1, 2, 388, 388])\n", | |
| "----------------------------------------------------------------\n", | |
| " Layer (type) Output Shape Param #\n", | |
| "================================================================\n", | |
| " Conv2d-1 [-1, 64, 570, 570] 640\n", | |
| " Conv2d-2 [-1, 64, 568, 568] 36,928\n", | |
| " BatchNorm2d-3 [-1, 64, 568, 568] 128\n", | |
| " AdaptiveAvgPool2d-4 [-1, 64, 1, 1] 0\n", | |
| " Linear-5 [-1, 4] 260\n", | |
| " ReLU-6 [-1, 4] 0\n", | |
| " Linear-7 [-1, 64] 320\n", | |
| " Sigmoid-8 [-1, 64] 0\n", | |
| " SqueezeExcitation-9 [-1, 64, 568, 568] 0\n", | |
| " ConvBlock-10 [-1, 64, 568, 568] 0\n", | |
| " MaxPool2d-11 [-1, 64, 284, 284] 0\n", | |
| " Conv2d-12 [-1, 128, 282, 282] 73,856\n", | |
| " Conv2d-13 [-1, 128, 280, 280] 147,584\n", | |
| " BatchNorm2d-14 [-1, 128, 280, 280] 256\n", | |
| "AdaptiveAvgPool2d-15 [-1, 128, 1, 1] 0\n", | |
| " Linear-16 [-1, 8] 1,032\n", | |
| " ReLU-17 [-1, 8] 0\n", | |
| " Linear-18 [-1, 128] 1,152\n", | |
| " Sigmoid-19 [-1, 128] 0\n", | |
| "SqueezeExcitation-20 [-1, 128, 280, 280] 0\n", | |
| " ConvBlock-21 [-1, 128, 280, 280] 0\n", | |
| " DownBlock-22 [-1, 128, 280, 280] 0\n", | |
| " MaxPool2d-23 [-1, 128, 140, 140] 0\n", | |
| " Conv2d-24 [-1, 256, 138, 138] 295,168\n", | |
| " Conv2d-25 [-1, 256, 136, 136] 590,080\n", | |
| " BatchNorm2d-26 [-1, 256, 136, 136] 512\n", | |
| "AdaptiveAvgPool2d-27 [-1, 256, 1, 1] 0\n", | |
| " Linear-28 [-1, 16] 4,112\n", | |
| " ReLU-29 [-1, 16] 0\n", | |
| " Linear-30 [-1, 256] 4,352\n", | |
| " Sigmoid-31 [-1, 256] 0\n", | |
| "SqueezeExcitation-32 [-1, 256, 136, 136] 0\n", | |
| " ConvBlock-33 [-1, 256, 136, 136] 0\n", | |
| " DownBlock-34 [-1, 256, 136, 136] 0\n", | |
| " MaxPool2d-35 [-1, 256, 68, 68] 0\n", | |
| " Conv2d-36 [-1, 512, 66, 66] 1,180,160\n", | |
| " Conv2d-37 [-1, 512, 64, 64] 2,359,808\n", | |
| " BatchNorm2d-38 [-1, 512, 64, 64] 1,024\n", | |
| "AdaptiveAvgPool2d-39 [-1, 512, 1, 1] 0\n", | |
| " Linear-40 [-1, 32] 16,416\n", | |
| " ReLU-41 [-1, 32] 0\n", | |
| " Linear-42 [-1, 512] 16,896\n", | |
| " Sigmoid-43 [-1, 512] 0\n", | |
| "SqueezeExcitation-44 [-1, 512, 64, 64] 0\n", | |
| " ConvBlock-45 [-1, 512, 64, 64] 0\n", | |
| " DownBlock-46 [-1, 512, 64, 64] 0\n", | |
| " MaxPool2d-47 [-1, 512, 32, 32] 0\n", | |
| " Conv2d-48 [-1, 1024, 30, 30] 4,719,616\n", | |
| " Conv2d-49 [-1, 1024, 28, 28] 9,438,208\n", | |
| " BatchNorm2d-50 [-1, 1024, 28, 28] 2,048\n", | |
| "AdaptiveAvgPool2d-51 [-1, 1024, 1, 1] 0\n", | |
| " Linear-52 [-1, 64] 65,600\n", | |
| " ReLU-53 [-1, 64] 0\n", | |
| " Linear-54 [-1, 1024] 66,560\n", | |
| " Sigmoid-55 [-1, 1024] 0\n", | |
| "SqueezeExcitation-56 [-1, 1024, 28, 28] 0\n", | |
| " ConvBlock-57 [-1, 1024, 28, 28] 0\n", | |
| " DownBlock-58 [-1, 1024, 28, 28] 0\n", | |
| " ConvTranspose2d-59 [-1, 512, 56, 56] 2,097,664\n", | |
| " Conv2d-60 [-1, 512, 54, 54] 4,719,104\n", | |
| " Conv2d-61 [-1, 512, 52, 52] 2,359,808\n", | |
| " BatchNorm2d-62 [-1, 512, 52, 52] 1,024\n", | |
| "AdaptiveAvgPool2d-63 [-1, 512, 1, 1] 0\n", | |
| " Linear-64 [-1, 32] 16,416\n", | |
| " ReLU-65 [-1, 32] 0\n", | |
| " Linear-66 [-1, 512] 16,896\n", | |
| " Sigmoid-67 [-1, 512] 0\n", | |
| "SqueezeExcitation-68 [-1, 512, 52, 52] 0\n", | |
| " ConvBlock-69 [-1, 512, 52, 52] 0\n", | |
| " UpBlock-70 [-1, 512, 52, 52] 0\n", | |
| " ConvTranspose2d-71 [-1, 256, 104, 104] 524,544\n", | |
| " Conv2d-72 [-1, 256, 102, 102] 1,179,904\n", | |
| " Conv2d-73 [-1, 256, 100, 100] 590,080\n", | |
| " BatchNorm2d-74 [-1, 256, 100, 100] 512\n", | |
| "AdaptiveAvgPool2d-75 [-1, 256, 1, 1] 0\n", | |
| " Linear-76 [-1, 16] 4,112\n", | |
| " ReLU-77 [-1, 16] 0\n", | |
| " Linear-78 [-1, 256] 4,352\n", | |
| " Sigmoid-79 [-1, 256] 0\n", | |
| "SqueezeExcitation-80 [-1, 256, 100, 100] 0\n", | |
| " ConvBlock-81 [-1, 256, 100, 100] 0\n", | |
| " UpBlock-82 [-1, 256, 100, 100] 0\n", | |
| " ConvTranspose2d-83 [-1, 128, 200, 200] 131,200\n", | |
| " Conv2d-84 [-1, 128, 198, 198] 295,040\n", | |
| " Conv2d-85 [-1, 128, 196, 196] 147,584\n", | |
| " BatchNorm2d-86 [-1, 128, 196, 196] 256\n", | |
| "AdaptiveAvgPool2d-87 [-1, 128, 1, 1] 0\n", | |
| " Linear-88 [-1, 8] 1,032\n", | |
| " ReLU-89 [-1, 8] 0\n", | |
| " Linear-90 [-1, 128] 1,152\n", | |
| " Sigmoid-91 [-1, 128] 0\n", | |
| "SqueezeExcitation-92 [-1, 128, 196, 196] 0\n", | |
| " ConvBlock-93 [-1, 128, 196, 196] 0\n", | |
| " UpBlock-94 [-1, 128, 196, 196] 0\n", | |
| " ConvTranspose2d-95 [-1, 64, 392, 392] 32,832\n", | |
| " Conv2d-96 [-1, 64, 390, 390] 73,792\n", | |
| " Conv2d-97 [-1, 64, 388, 388] 36,928\n", | |
| " BatchNorm2d-98 [-1, 64, 388, 388] 128\n", | |
| "AdaptiveAvgPool2d-99 [-1, 64, 1, 1] 0\n", | |
| " Linear-100 [-1, 4] 260\n", | |
| " ReLU-101 [-1, 4] 0\n", | |
| " Linear-102 [-1, 64] 320\n", | |
| " Sigmoid-103 [-1, 64] 0\n", | |
| "SqueezeExcitation-104 [-1, 64, 388, 388] 0\n", | |
| " ConvBlock-105 [-1, 64, 388, 388] 0\n", | |
| " UpBlock-106 [-1, 64, 388, 388] 0\n", | |
| " Conv2d-107 [-1, 2, 388, 388] 130\n", | |
| "================================================================\n", | |
| "Total params: 31,257,786\n", | |
| "Trainable params: 31,257,786\n", | |
| "Non-trainable params: 0\n", | |
| "----------------------------------------------------------------\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "class UNet(nn.Module):\n", | |
| "\n", | |
| " def __init__(self,\n", | |
| " network_depth=4,\n", | |
| " conv_depth=2,\n", | |
| " in_size=572,\n", | |
| " in_ch=1,\n", | |
| " out_ch=2,\n", | |
| " init_ch=64,\n", | |
| " padding=0,\n", | |
| " bn=True,\n", | |
| " se=True,\n", | |
| " act='relu',\n", | |
| " act_kwargs={}):\n", | |
| " super(UNet, self).__init__()\n", | |
| " self.network_depth=network_depth\n", | |
| " self.conv_depth=conv_depth\n", | |
| " self.out_ch=out_ch\n", | |
| " self.padding=padding\n", | |
| " self.input_conv=ConvBlock(\n", | |
| " in_ch=in_ch,\n", | |
| " in_size=in_size,\n", | |
| " out_ch=init_ch,\n", | |
| " depth=self.conv_depth,\n", | |
| " padding=padding,\n", | |
| " bn=bn,\n", | |
| " se=se,\n", | |
| " act=act,\n", | |
| " act_kwargs=act_kwargs)\n", | |
| " down_layers=self._down_layers(\n", | |
| " self.input_conv.out_ch,\n", | |
| " self.input_conv.out_size,\n", | |
| " bn=bn,\n", | |
| " se=se,\n", | |
| " act=act,\n", | |
| " act_kwargs=act_kwargs)\n", | |
| " self.down_blocks=nn.ModuleList(down_layers)\n", | |
| " up_layers=self._up_layers(\n", | |
| " down_layers,\n", | |
| " bn=bn,\n", | |
| " se=se,\n", | |
| " act=act,\n", | |
| " act_kwargs=act_kwargs)\n", | |
| " self.up_blocks=nn.ModuleList(up_layers)\n", | |
| " self.out_size=self.up_blocks[-1].out_size\n", | |
| " self.output_conv=self._output_layer(out_ch)\n", | |
| "\n", | |
| " \n", | |
| " def forward(self, x):\n", | |
| " x=self.input_conv(x)\n", | |
| " skips=[x]\n", | |
| " for block in self.down_blocks:\n", | |
| " x=block(x)\n", | |
| " skips.append(x)\n", | |
| " skips.pop()\n", | |
| " skips=skips[::-1]\n", | |
| " for skip,block in zip(skips,self.up_blocks):\n", | |
| " x=block(x,skip)\n", | |
| " x=self.output_conv(x)\n", | |
| " return x\n", | |
| " \n", | |
| " \n", | |
| " def _down_layers(self,in_ch,in_size,bn,se,act,act_kwargs):\n", | |
| " layers=[]\n", | |
| " for index in range(1,self.network_depth+1):\n", | |
| " layer=DownBlock(\n", | |
| " in_ch,\n", | |
| " in_size,\n", | |
| " depth=self.conv_depth,\n", | |
| " padding=self.padding,\n", | |
| " bn=bn,\n", | |
| " se=se,\n", | |
| " act=act,\n", | |
| " act_kwargs=act_kwargs)\n", | |
| " in_ch=layer.out_ch\n", | |
| " in_size=layer.out_size\n", | |
| " layers.append(layer)\n", | |
| " return layers\n", | |
| "\n", | |
| " \n", | |
| " def _up_layers(self,down_layers,bn,se,act,act_kwargs):\n", | |
| " down_layers=down_layers[::-1]\n", | |
| " down_layers.append(self.input_conv)\n", | |
| " first=down_layers.pop(0)\n", | |
| " in_ch=first.out_ch\n", | |
| " in_size=first.out_size\n", | |
| " layers=[]\n", | |
| " for down_layer in down_layers:\n", | |
| " crop=UpBlock.cropping(down_layer.out_size,2*in_size)\n", | |
| " layer=UpBlock(\n", | |
| " in_ch,\n", | |
| " in_size,\n", | |
| " depth=self.conv_depth,\n", | |
| " crop=crop,\n", | |
| " padding=self.padding,\n", | |
| " bn=bn,\n", | |
| " se=se,\n", | |
| " act=act,\n", | |
| " act_kwargs=act_kwargs)\n", | |
| " in_ch=layer.out_ch\n", | |
| " in_size=layer.out_size\n", | |
| " layers.append(layer)\n", | |
| " return layers\n", | |
| "\n", | |
| " \n", | |
| " def _output_layer(self,out_ch):\n", | |
| " return nn.Conv2d(\n", | |
| " in_channels=self.up_blocks[-1].out_ch,\n", | |
| " out_channels=out_ch,\n", | |
| " kernel_size=1,\n", | |
| " stride=1,\n", | |
| " padding=0)\n", | |
| " \n", | |
| " \n", | |
| "unet=UNet(in_size=572,network_depth=4,conv_depth=2)\n", | |
| "print(unet.network_depth,unet.conv_depth)\n", | |
| "print(unet.out_size,unet.out_ch)\n", | |
| "print(unet(torch.rand(1,1,572,572)).shape)\n", | |
| "summary(unet,input_size=(1,572,572))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "(2, 4)\n", | |
| "(492, 2)\n", | |
| "torch.Size([1, 2, 492, 492])\n", | |
| "----------------------------------------------------------------\n", | |
| " Layer (type) Output Shape Param #\n", | |
| "================================================================\n", | |
| " Conv2d-1 [-1, 64, 570, 570] 640\n", | |
| " Conv2d-2 [-1, 64, 568, 568] 36,928\n", | |
| " Conv2d-3 [-1, 64, 566, 566] 36,928\n", | |
| " Conv2d-4 [-1, 64, 564, 564] 36,928\n", | |
| " BatchNorm2d-5 [-1, 64, 564, 564] 128\n", | |
| " AdaptiveAvgPool2d-6 [-1, 64, 1, 1] 0\n", | |
| " Linear-7 [-1, 4] 260\n", | |
| " ReLU-8 [-1, 4] 0\n", | |
| " Linear-9 [-1, 64] 320\n", | |
| " Sigmoid-10 [-1, 64] 0\n", | |
| "SqueezeExcitation-11 [-1, 64, 564, 564] 0\n", | |
| " ConvBlock-12 [-1, 64, 564, 564] 0\n", | |
| " MaxPool2d-13 [-1, 64, 282, 282] 0\n", | |
| " Conv2d-14 [-1, 128, 280, 280] 73,856\n", | |
| " Conv2d-15 [-1, 128, 278, 278] 147,584\n", | |
| " Conv2d-16 [-1, 128, 276, 276] 147,584\n", | |
| " Conv2d-17 [-1, 128, 274, 274] 147,584\n", | |
| " BatchNorm2d-18 [-1, 128, 274, 274] 256\n", | |
| "AdaptiveAvgPool2d-19 [-1, 128, 1, 1] 0\n", | |
| " Linear-20 [-1, 8] 1,032\n", | |
| " ReLU-21 [-1, 8] 0\n", | |
| " Linear-22 [-1, 128] 1,152\n", | |
| " Sigmoid-23 [-1, 128] 0\n", | |
| "SqueezeExcitation-24 [-1, 128, 274, 274] 0\n", | |
| " ConvBlock-25 [-1, 128, 274, 274] 0\n", | |
| " DownBlock-26 [-1, 128, 274, 274] 0\n", | |
| " MaxPool2d-27 [-1, 128, 137, 137] 0\n", | |
| " Conv2d-28 [-1, 256, 135, 135] 295,168\n", | |
| " Conv2d-29 [-1, 256, 133, 133] 590,080\n", | |
| " Conv2d-30 [-1, 256, 131, 131] 590,080\n", | |
| " Conv2d-31 [-1, 256, 129, 129] 590,080\n", | |
| " BatchNorm2d-32 [-1, 256, 129, 129] 512\n", | |
| "AdaptiveAvgPool2d-33 [-1, 256, 1, 1] 0\n", | |
| " Linear-34 [-1, 16] 4,112\n", | |
| " ReLU-35 [-1, 16] 0\n", | |
| " Linear-36 [-1, 256] 4,352\n", | |
| " Sigmoid-37 [-1, 256] 0\n", | |
| "SqueezeExcitation-38 [-1, 256, 129, 129] 0\n", | |
| " ConvBlock-39 [-1, 256, 129, 129] 0\n", | |
| " DownBlock-40 [-1, 256, 129, 129] 0\n", | |
| " ConvTranspose2d-41 [-1, 128, 258, 258] 131,200\n", | |
| " Conv2d-42 [-1, 128, 256, 256] 295,040\n", | |
| " Conv2d-43 [-1, 128, 254, 254] 147,584\n", | |
| " Conv2d-44 [-1, 128, 252, 252] 147,584\n", | |
| " Conv2d-45 [-1, 128, 250, 250] 147,584\n", | |
| " BatchNorm2d-46 [-1, 128, 250, 250] 256\n", | |
| "AdaptiveAvgPool2d-47 [-1, 128, 1, 1] 0\n", | |
| " Linear-48 [-1, 8] 1,032\n", | |
| " ReLU-49 [-1, 8] 0\n", | |
| " Linear-50 [-1, 128] 1,152\n", | |
| " Sigmoid-51 [-1, 128] 0\n", | |
| "SqueezeExcitation-52 [-1, 128, 250, 250] 0\n", | |
| " ConvBlock-53 [-1, 128, 250, 250] 0\n", | |
| " UpBlock-54 [-1, 128, 250, 250] 0\n", | |
| " ConvTranspose2d-55 [-1, 64, 500, 500] 32,832\n", | |
| " Conv2d-56 [-1, 64, 498, 498] 73,792\n", | |
| " Conv2d-57 [-1, 64, 496, 496] 36,928\n", | |
| " Conv2d-58 [-1, 64, 494, 494] 36,928\n", | |
| " Conv2d-59 [-1, 64, 492, 492] 36,928\n", | |
| " BatchNorm2d-60 [-1, 64, 492, 492] 128\n", | |
| "AdaptiveAvgPool2d-61 [-1, 64, 1, 1] 0\n", | |
| " Linear-62 [-1, 4] 260\n", | |
| " ReLU-63 [-1, 4] 0\n", | |
| " Linear-64 [-1, 64] 320\n", | |
| " Sigmoid-65 [-1, 64] 0\n", | |
| "SqueezeExcitation-66 [-1, 64, 492, 492] 0\n", | |
| " ConvBlock-67 [-1, 64, 492, 492] 0\n", | |
| " UpBlock-68 [-1, 64, 492, 492] 0\n", | |
| " Conv2d-69 [-1, 2, 492, 492] 130\n", | |
| "================================================================\n", | |
| "Total params: 3,795,242\n", | |
| "Trainable params: 3,795,242\n", | |
| "Non-trainable params: 0\n", | |
| "----------------------------------------------------------------\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "unet=UNet(in_size=572,network_depth=2,conv_depth=4)\n", | |
| "print(unet.network_depth,unet.conv_depth)\n", | |
| "print(unet.out_size,unet.out_ch)\n", | |
| "print(unet(torch.rand(1,1,572,572)).shape)\n", | |
| "summary(unet,input_size=(1,572,572))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": { | |
| "scrolled": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "(2, 2)\n", | |
| "(216, 2)\n", | |
| "torch.Size([1, 2, 216, 216])\n", | |
| "----------------------------------------------------------------\n", | |
| " Layer (type) Output Shape Param #\n", | |
| "================================================================\n", | |
| " Conv2d-1 [-1, 64, 254, 254] 640\n", | |
| " Conv2d-2 [-1, 64, 252, 252] 36,928\n", | |
| " BatchNorm2d-3 [-1, 64, 252, 252] 128\n", | |
| " AdaptiveAvgPool2d-4 [-1, 64, 1, 1] 0\n", | |
| " Linear-5 [-1, 4] 260\n", | |
| " ReLU-6 [-1, 4] 0\n", | |
| " Linear-7 [-1, 64] 320\n", | |
| " Sigmoid-8 [-1, 64] 0\n", | |
| " SqueezeExcitation-9 [-1, 64, 252, 252] 0\n", | |
| " ConvBlock-10 [-1, 64, 252, 252] 0\n", | |
| " MaxPool2d-11 [-1, 64, 126, 126] 0\n", | |
| " Conv2d-12 [-1, 128, 124, 124] 73,856\n", | |
| " Conv2d-13 [-1, 128, 122, 122] 147,584\n", | |
| " BatchNorm2d-14 [-1, 128, 122, 122] 256\n", | |
| "AdaptiveAvgPool2d-15 [-1, 128, 1, 1] 0\n", | |
| " Linear-16 [-1, 8] 1,032\n", | |
| " ReLU-17 [-1, 8] 0\n", | |
| " Linear-18 [-1, 128] 1,152\n", | |
| " Sigmoid-19 [-1, 128] 0\n", | |
| "SqueezeExcitation-20 [-1, 128, 122, 122] 0\n", | |
| " ConvBlock-21 [-1, 128, 122, 122] 0\n", | |
| " DownBlock-22 [-1, 128, 122, 122] 0\n", | |
| " MaxPool2d-23 [-1, 128, 61, 61] 0\n", | |
| " Conv2d-24 [-1, 256, 59, 59] 295,168\n", | |
| " Conv2d-25 [-1, 256, 57, 57] 590,080\n", | |
| " BatchNorm2d-26 [-1, 256, 57, 57] 512\n", | |
| "AdaptiveAvgPool2d-27 [-1, 256, 1, 1] 0\n", | |
| " Linear-28 [-1, 16] 4,112\n", | |
| " ReLU-29 [-1, 16] 0\n", | |
| " Linear-30 [-1, 256] 4,352\n", | |
| " Sigmoid-31 [-1, 256] 0\n", | |
| "SqueezeExcitation-32 [-1, 256, 57, 57] 0\n", | |
| " ConvBlock-33 [-1, 256, 57, 57] 0\n", | |
| " DownBlock-34 [-1, 256, 57, 57] 0\n", | |
| " ConvTranspose2d-35 [-1, 128, 114, 114] 131,200\n", | |
| " Conv2d-36 [-1, 128, 112, 112] 295,040\n", | |
| " Conv2d-37 [-1, 128, 110, 110] 147,584\n", | |
| " BatchNorm2d-38 [-1, 128, 110, 110] 256\n", | |
| "AdaptiveAvgPool2d-39 [-1, 128, 1, 1] 0\n", | |
| " Linear-40 [-1, 8] 1,032\n", | |
| " ReLU-41 [-1, 8] 0\n", | |
| " Linear-42 [-1, 128] 1,152\n", | |
| " Sigmoid-43 [-1, 128] 0\n", | |
| "SqueezeExcitation-44 [-1, 128, 110, 110] 0\n", | |
| " ConvBlock-45 [-1, 128, 110, 110] 0\n", | |
| " UpBlock-46 [-1, 128, 110, 110] 0\n", | |
| " ConvTranspose2d-47 [-1, 64, 220, 220] 32,832\n", | |
| " Conv2d-48 [-1, 64, 218, 218] 73,792\n", | |
| " Conv2d-49 [-1, 64, 216, 216] 36,928\n", | |
| " BatchNorm2d-50 [-1, 64, 216, 216] 128\n", | |
| "AdaptiveAvgPool2d-51 [-1, 64, 1, 1] 0\n", | |
| " Linear-52 [-1, 4] 260\n", | |
| " ReLU-53 [-1, 4] 0\n", | |
| " Linear-54 [-1, 64] 320\n", | |
| " Sigmoid-55 [-1, 64] 0\n", | |
| "SqueezeExcitation-56 [-1, 64, 216, 216] 0\n", | |
| " ConvBlock-57 [-1, 64, 216, 216] 0\n", | |
| " UpBlock-58 [-1, 64, 216, 216] 0\n", | |
| " Conv2d-59 [-1, 2, 216, 216] 130\n", | |
| "================================================================\n", | |
| "Total params: 1,877,034\n", | |
| "Trainable params: 1,877,034\n", | |
| "Non-trainable params: 0\n", | |
| "----------------------------------------------------------------\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "SIZE=256\n", | |
| "unet=UNet(in_size=SIZE,network_depth=2)\n", | |
| "print(unet.network_depth,unet.conv_depth)\n", | |
| "print(unet.out_size,unet.out_ch)\n", | |
| "print(unet(torch.rand(1,1,SIZE,SIZE)).shape)\n", | |
| "summary(unet,input_size=(1,SIZE,SIZE))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": { | |
| "scrolled": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "(5, 2)\n", | |
| "(256, 2)\n", | |
| "torch.Size([1, 2, 256, 256])\n", | |
| "----------------------------------------------------------------\n", | |
| " Layer (type) Output Shape Param #\n", | |
| "================================================================\n", | |
| " Conv2d-1 [-1, 64, 256, 256] 640\n", | |
| " Conv2d-2 [-1, 64, 256, 256] 36,928\n", | |
| " BatchNorm2d-3 [-1, 64, 256, 256] 128\n", | |
| " AdaptiveAvgPool2d-4 [-1, 64, 1, 1] 0\n", | |
| " Linear-5 [-1, 4] 260\n", | |
| " ReLU-6 [-1, 4] 0\n", | |
| " Linear-7 [-1, 64] 320\n", | |
| " Sigmoid-8 [-1, 64] 0\n", | |
| " SqueezeExcitation-9 [-1, 64, 256, 256] 0\n", | |
| " ConvBlock-10 [-1, 64, 256, 256] 0\n", | |
| " MaxPool2d-11 [-1, 64, 128, 128] 0\n", | |
| " Conv2d-12 [-1, 128, 128, 128] 73,856\n", | |
| " Conv2d-13 [-1, 128, 128, 128] 147,584\n", | |
| " BatchNorm2d-14 [-1, 128, 128, 128] 256\n", | |
| "AdaptiveAvgPool2d-15 [-1, 128, 1, 1] 0\n", | |
| " Linear-16 [-1, 8] 1,032\n", | |
| " ReLU-17 [-1, 8] 0\n", | |
| " Linear-18 [-1, 128] 1,152\n", | |
| " Sigmoid-19 [-1, 128] 0\n", | |
| "SqueezeExcitation-20 [-1, 128, 128, 128] 0\n", | |
| " ConvBlock-21 [-1, 128, 128, 128] 0\n", | |
| " DownBlock-22 [-1, 128, 128, 128] 0\n", | |
| " MaxPool2d-23 [-1, 128, 64, 64] 0\n", | |
| " Conv2d-24 [-1, 256, 64, 64] 295,168\n", | |
| " Conv2d-25 [-1, 256, 64, 64] 590,080\n", | |
| " BatchNorm2d-26 [-1, 256, 64, 64] 512\n", | |
| "AdaptiveAvgPool2d-27 [-1, 256, 1, 1] 0\n", | |
| " Linear-28 [-1, 16] 4,112\n", | |
| " ReLU-29 [-1, 16] 0\n", | |
| " Linear-30 [-1, 256] 4,352\n", | |
| " Sigmoid-31 [-1, 256] 0\n", | |
| "SqueezeExcitation-32 [-1, 256, 64, 64] 0\n", | |
| " ConvBlock-33 [-1, 256, 64, 64] 0\n", | |
| " DownBlock-34 [-1, 256, 64, 64] 0\n", | |
| " MaxPool2d-35 [-1, 256, 32, 32] 0\n", | |
| " Conv2d-36 [-1, 512, 32, 32] 1,180,160\n", | |
| " Conv2d-37 [-1, 512, 32, 32] 2,359,808\n", | |
| " BatchNorm2d-38 [-1, 512, 32, 32] 1,024\n", | |
| "AdaptiveAvgPool2d-39 [-1, 512, 1, 1] 0\n", | |
| " Linear-40 [-1, 32] 16,416\n", | |
| " ReLU-41 [-1, 32] 0\n", | |
| " Linear-42 [-1, 512] 16,896\n", | |
| " Sigmoid-43 [-1, 512] 0\n", | |
| "SqueezeExcitation-44 [-1, 512, 32, 32] 0\n", | |
| " ConvBlock-45 [-1, 512, 32, 32] 0\n", | |
| " DownBlock-46 [-1, 512, 32, 32] 0\n", | |
| " MaxPool2d-47 [-1, 512, 16, 16] 0\n", | |
| " Conv2d-48 [-1, 1024, 16, 16] 4,719,616\n", | |
| " Conv2d-49 [-1, 1024, 16, 16] 9,438,208\n", | |
| " BatchNorm2d-50 [-1, 1024, 16, 16] 2,048\n", | |
| "AdaptiveAvgPool2d-51 [-1, 1024, 1, 1] 0\n", | |
| " Linear-52 [-1, 64] 65,600\n", | |
| " ReLU-53 [-1, 64] 0\n", | |
| " Linear-54 [-1, 1024] 66,560\n", | |
| " Sigmoid-55 [-1, 1024] 0\n", | |
| "SqueezeExcitation-56 [-1, 1024, 16, 16] 0\n", | |
| " ConvBlock-57 [-1, 1024, 16, 16] 0\n", | |
| " DownBlock-58 [-1, 1024, 16, 16] 0\n", | |
| " MaxPool2d-59 [-1, 1024, 8, 8] 0\n", | |
| " Conv2d-60 [-1, 2048, 8, 8] 18,876,416\n", | |
| " Conv2d-61 [-1, 2048, 8, 8] 37,750,784\n", | |
| " BatchNorm2d-62 [-1, 2048, 8, 8] 4,096\n", | |
| "AdaptiveAvgPool2d-63 [-1, 2048, 1, 1] 0\n", | |
| " Linear-64 [-1, 128] 262,272\n", | |
| " ReLU-65 [-1, 128] 0\n", | |
| " Linear-66 [-1, 2048] 264,192\n", | |
| " Sigmoid-67 [-1, 2048] 0\n", | |
| "SqueezeExcitation-68 [-1, 2048, 8, 8] 0\n", | |
| " ConvBlock-69 [-1, 2048, 8, 8] 0\n", | |
| " DownBlock-70 [-1, 2048, 8, 8] 0\n", | |
| " ConvTranspose2d-71 [-1, 1024, 16, 16] 8,389,632\n", | |
| " Conv2d-72 [-1, 1024, 16, 16] 18,875,392\n", | |
| " Conv2d-73 [-1, 1024, 16, 16] 9,438,208\n", | |
| " BatchNorm2d-74 [-1, 1024, 16, 16] 2,048\n", | |
| "AdaptiveAvgPool2d-75 [-1, 1024, 1, 1] 0\n", | |
| " Linear-76 [-1, 64] 65,600\n", | |
| " ReLU-77 [-1, 64] 0\n", | |
| " Linear-78 [-1, 1024] 66,560\n", | |
| " Sigmoid-79 [-1, 1024] 0\n", | |
| "SqueezeExcitation-80 [-1, 1024, 16, 16] 0\n", | |
| " ConvBlock-81 [-1, 1024, 16, 16] 0\n", | |
| " UpBlock-82 [-1, 1024, 16, 16] 0\n", | |
| " ConvTranspose2d-83 [-1, 512, 32, 32] 2,097,664\n", | |
| " Conv2d-84 [-1, 512, 32, 32] 4,719,104\n", | |
| " Conv2d-85 [-1, 512, 32, 32] 2,359,808\n", | |
| " BatchNorm2d-86 [-1, 512, 32, 32] 1,024\n", | |
| "AdaptiveAvgPool2d-87 [-1, 512, 1, 1] 0\n", | |
| " Linear-88 [-1, 32] 16,416\n", | |
| " ReLU-89 [-1, 32] 0\n", | |
| " Linear-90 [-1, 512] 16,896\n", | |
| " Sigmoid-91 [-1, 512] 0\n", | |
| "SqueezeExcitation-92 [-1, 512, 32, 32] 0\n", | |
| " ConvBlock-93 [-1, 512, 32, 32] 0\n", | |
| " UpBlock-94 [-1, 512, 32, 32] 0\n", | |
| " ConvTranspose2d-95 [-1, 256, 64, 64] 524,544\n", | |
| " Conv2d-96 [-1, 256, 64, 64] 1,179,904\n", | |
| " Conv2d-97 [-1, 256, 64, 64] 590,080\n", | |
| " BatchNorm2d-98 [-1, 256, 64, 64] 512\n", | |
| "AdaptiveAvgPool2d-99 [-1, 256, 1, 1] 0\n", | |
| " Linear-100 [-1, 16] 4,112\n", | |
| " ReLU-101 [-1, 16] 0\n", | |
| " Linear-102 [-1, 256] 4,352\n", | |
| " Sigmoid-103 [-1, 256] 0\n", | |
| "SqueezeExcitation-104 [-1, 256, 64, 64] 0\n", | |
| " ConvBlock-105 [-1, 256, 64, 64] 0\n", | |
| " UpBlock-106 [-1, 256, 64, 64] 0\n", | |
| " ConvTranspose2d-107 [-1, 128, 128, 128] 131,200\n", | |
| " Conv2d-108 [-1, 128, 128, 128] 295,040\n", | |
| " Conv2d-109 [-1, 128, 128, 128] 147,584\n", | |
| " BatchNorm2d-110 [-1, 128, 128, 128] 256\n", | |
| "AdaptiveAvgPool2d-111 [-1, 128, 1, 1] 0\n", | |
| " Linear-112 [-1, 8] 1,032\n", | |
| " ReLU-113 [-1, 8] 0\n", | |
| " Linear-114 [-1, 128] 1,152\n", | |
| " Sigmoid-115 [-1, 128] 0\n", | |
| "SqueezeExcitation-116 [-1, 128, 128, 128] 0\n", | |
| " ConvBlock-117 [-1, 128, 128, 128] 0\n", | |
| " UpBlock-118 [-1, 128, 128, 128] 0\n", | |
| " ConvTranspose2d-119 [-1, 64, 256, 256] 32,832\n", | |
| " Conv2d-120 [-1, 64, 256, 256] 73,792\n", | |
| " Conv2d-121 [-1, 64, 256, 256] 36,928\n", | |
| " BatchNorm2d-122 [-1, 64, 256, 256] 128\n", | |
| "AdaptiveAvgPool2d-123 [-1, 64, 1, 1] 0\n", | |
| " Linear-124 [-1, 4] 260\n", | |
| " ReLU-125 [-1, 4] 0\n", | |
| " Linear-126 [-1, 64] 320\n", | |
| " Sigmoid-127 [-1, 64] 0\n", | |
| "SqueezeExcitation-128 [-1, 64, 256, 256] 0\n", | |
| " ConvBlock-129 [-1, 64, 256, 256] 0\n", | |
| " UpBlock-130 [-1, 64, 256, 256] 0\n", | |
| " Conv2d-131 [-1, 2, 256, 256] 130\n", | |
| "================================================================\n", | |
| "Total params: 125,252,986\n", | |
| "Trainable params: 125,252,986\n", | |
| "Non-trainable params: 0\n", | |
| "----------------------------------------------------------------\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "SIZE=256\n", | |
| "unet=UNet(in_size=SIZE,network_depth=5,padding=1)\n", | |
| "print(unet.network_depth,unet.conv_depth)\n", | |
| "print(unet.out_size,unet.out_ch)\n", | |
| "print(unet(torch.rand(1,1,SIZE,SIZE)).shape)\n", | |
| "summary(unet,input_size=(1,SIZE,SIZE))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "(2, 2)\n", | |
| "(216, 2)\n", | |
| "torch.Size([1, 2, 216, 216])\n", | |
| "----------------------------------------------------------------\n", | |
| " Layer (type) Output Shape Param #\n", | |
| "================================================================\n", | |
| " Conv2d-1 [-1, 64, 254, 254] 640\n", | |
| " Conv2d-2 [-1, 64, 252, 252] 36,928\n", | |
| " ConvBlock-3 [-1, 64, 252, 252] 0\n", | |
| " MaxPool2d-4 [-1, 64, 126, 126] 0\n", | |
| " Conv2d-5 [-1, 128, 124, 124] 73,856\n", | |
| " Conv2d-6 [-1, 128, 122, 122] 147,584\n", | |
| " ConvBlock-7 [-1, 128, 122, 122] 0\n", | |
| " DownBlock-8 [-1, 128, 122, 122] 0\n", | |
| " MaxPool2d-9 [-1, 128, 61, 61] 0\n", | |
| " Conv2d-10 [-1, 256, 59, 59] 295,168\n", | |
| " Conv2d-11 [-1, 256, 57, 57] 590,080\n", | |
| " ConvBlock-12 [-1, 256, 57, 57] 0\n", | |
| " DownBlock-13 [-1, 256, 57, 57] 0\n", | |
| " ConvTranspose2d-14 [-1, 128, 114, 114] 131,200\n", | |
| " Conv2d-15 [-1, 128, 112, 112] 295,040\n", | |
| " Conv2d-16 [-1, 128, 110, 110] 147,584\n", | |
| " ConvBlock-17 [-1, 128, 110, 110] 0\n", | |
| " UpBlock-18 [-1, 128, 110, 110] 0\n", | |
| " ConvTranspose2d-19 [-1, 64, 220, 220] 32,832\n", | |
| " Conv2d-20 [-1, 64, 218, 218] 73,792\n", | |
| " Conv2d-21 [-1, 64, 216, 216] 36,928\n", | |
| " ConvBlock-22 [-1, 64, 216, 216] 0\n", | |
| " UpBlock-23 [-1, 64, 216, 216] 0\n", | |
| " Conv2d-24 [-1, 2, 216, 216] 130\n", | |
| "================================================================\n", | |
| "Total params: 1,861,762\n", | |
| "Trainable params: 1,861,762\n", | |
| "Non-trainable params: 0\n", | |
| "----------------------------------------------------------------\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "SIZE=256\n", | |
| "unet=UNet(in_size=SIZE,network_depth=2,bn=False,se=False)\n", | |
| "print(unet.network_depth,unet.conv_depth)\n", | |
| "print(unet.out_size,unet.out_ch)\n", | |
| "print(unet(torch.rand(1,1,SIZE,SIZE)).shape)\n", | |
| "summary(unet,input_size=(1,SIZE,SIZE))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "(2, 2)\n", | |
| "(216, 2)\n", | |
| "torch.Size([1, 2, 216, 216])\n", | |
| "----------------------------------------------------------------\n", | |
| " Layer (type) Output Shape Param #\n", | |
| "================================================================\n", | |
| " Conv2d-1 [-1, 64, 254, 254] 640\n", | |
| " Conv2d-2 [-1, 64, 252, 252] 36,928\n", | |
| " AdaptiveAvgPool2d-3 [-1, 64, 1, 1] 0\n", | |
| " Linear-4 [-1, 4] 260\n", | |
| " ReLU-5 [-1, 4] 0\n", | |
| " Linear-6 [-1, 64] 320\n", | |
| " Sigmoid-7 [-1, 64] 0\n", | |
| " SqueezeExcitation-8 [-1, 64, 252, 252] 0\n", | |
| " ConvBlock-9 [-1, 64, 252, 252] 0\n", | |
| " MaxPool2d-10 [-1, 64, 126, 126] 0\n", | |
| " Conv2d-11 [-1, 128, 124, 124] 73,856\n", | |
| " Conv2d-12 [-1, 128, 122, 122] 147,584\n", | |
| "AdaptiveAvgPool2d-13 [-1, 128, 1, 1] 0\n", | |
| " Linear-14 [-1, 8] 1,032\n", | |
| " ReLU-15 [-1, 8] 0\n", | |
| " Linear-16 [-1, 128] 1,152\n", | |
| " Sigmoid-17 [-1, 128] 0\n", | |
| "SqueezeExcitation-18 [-1, 128, 122, 122] 0\n", | |
| " ConvBlock-19 [-1, 128, 122, 122] 0\n", | |
| " DownBlock-20 [-1, 128, 122, 122] 0\n", | |
| " MaxPool2d-21 [-1, 128, 61, 61] 0\n", | |
| " Conv2d-22 [-1, 256, 59, 59] 295,168\n", | |
| " Conv2d-23 [-1, 256, 57, 57] 590,080\n", | |
| "AdaptiveAvgPool2d-24 [-1, 256, 1, 1] 0\n", | |
| " Linear-25 [-1, 16] 4,112\n", | |
| " ReLU-26 [-1, 16] 0\n", | |
| " Linear-27 [-1, 256] 4,352\n", | |
| " Sigmoid-28 [-1, 256] 0\n", | |
| "SqueezeExcitation-29 [-1, 256, 57, 57] 0\n", | |
| " ConvBlock-30 [-1, 256, 57, 57] 0\n", | |
| " DownBlock-31 [-1, 256, 57, 57] 0\n", | |
| " ConvTranspose2d-32 [-1, 128, 114, 114] 131,200\n", | |
| " Conv2d-33 [-1, 128, 112, 112] 295,040\n", | |
| " Conv2d-34 [-1, 128, 110, 110] 147,584\n", | |
| "AdaptiveAvgPool2d-35 [-1, 128, 1, 1] 0\n", | |
| " Linear-36 [-1, 8] 1,032\n", | |
| " ReLU-37 [-1, 8] 0\n", | |
| " Linear-38 [-1, 128] 1,152\n", | |
| " Sigmoid-39 [-1, 128] 0\n", | |
| "SqueezeExcitation-40 [-1, 128, 110, 110] 0\n", | |
| " ConvBlock-41 [-1, 128, 110, 110] 0\n", | |
| " UpBlock-42 [-1, 128, 110, 110] 0\n", | |
| " ConvTranspose2d-43 [-1, 64, 220, 220] 32,832\n", | |
| " Conv2d-44 [-1, 64, 218, 218] 73,792\n", | |
| " Conv2d-45 [-1, 64, 216, 216] 36,928\n", | |
| "AdaptiveAvgPool2d-46 [-1, 64, 1, 1] 0\n", | |
| " Linear-47 [-1, 4] 260\n", | |
| " ReLU-48 [-1, 4] 0\n", | |
| " Linear-49 [-1, 64] 320\n", | |
| " Sigmoid-50 [-1, 64] 0\n", | |
| "SqueezeExcitation-51 [-1, 64, 216, 216] 0\n", | |
| " ConvBlock-52 [-1, 64, 216, 216] 0\n", | |
| " UpBlock-53 [-1, 64, 216, 216] 0\n", | |
| " Conv2d-54 [-1, 2, 216, 216] 130\n", | |
| "================================================================\n", | |
| "Total params: 1,875,754\n", | |
| "Trainable params: 1,875,754\n", | |
| "Non-trainable params: 0\n", | |
| "----------------------------------------------------------------\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "SIZE=256\n", | |
| "unet=UNet(in_size=SIZE,network_depth=2,bn=False,se=True)\n", | |
| "print(unet.network_depth,unet.conv_depth)\n", | |
| "print(unet.out_size,unet.out_ch)\n", | |
| "print(unet(torch.rand(1,1,SIZE,SIZE)).shape)\n", | |
| "summary(unet,input_size=(1,SIZE,SIZE))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "(2, 2)\n", | |
| "(216, 2)\n", | |
| "torch.Size([1, 2, 216, 216])\n", | |
| "----------------------------------------------------------------\n", | |
| " Layer (type) Output Shape Param #\n", | |
| "================================================================\n", | |
| " Conv2d-1 [-1, 64, 254, 254] 640\n", | |
| " Conv2d-2 [-1, 64, 252, 252] 36,928\n", | |
| " BatchNorm2d-3 [-1, 64, 252, 252] 128\n", | |
| " ConvBlock-4 [-1, 64, 252, 252] 0\n", | |
| " MaxPool2d-5 [-1, 64, 126, 126] 0\n", | |
| " Conv2d-6 [-1, 128, 124, 124] 73,856\n", | |
| " Conv2d-7 [-1, 128, 122, 122] 147,584\n", | |
| " BatchNorm2d-8 [-1, 128, 122, 122] 256\n", | |
| " ConvBlock-9 [-1, 128, 122, 122] 0\n", | |
| " DownBlock-10 [-1, 128, 122, 122] 0\n", | |
| " MaxPool2d-11 [-1, 128, 61, 61] 0\n", | |
| " Conv2d-12 [-1, 256, 59, 59] 295,168\n", | |
| " Conv2d-13 [-1, 256, 57, 57] 590,080\n", | |
| " BatchNorm2d-14 [-1, 256, 57, 57] 512\n", | |
| " ConvBlock-15 [-1, 256, 57, 57] 0\n", | |
| " DownBlock-16 [-1, 256, 57, 57] 0\n", | |
| " ConvTranspose2d-17 [-1, 128, 114, 114] 131,200\n", | |
| " Conv2d-18 [-1, 128, 112, 112] 295,040\n", | |
| " Conv2d-19 [-1, 128, 110, 110] 147,584\n", | |
| " BatchNorm2d-20 [-1, 128, 110, 110] 256\n", | |
| " ConvBlock-21 [-1, 128, 110, 110] 0\n", | |
| " UpBlock-22 [-1, 128, 110, 110] 0\n", | |
| " ConvTranspose2d-23 [-1, 64, 220, 220] 32,832\n", | |
| " Conv2d-24 [-1, 64, 218, 218] 73,792\n", | |
| " Conv2d-25 [-1, 64, 216, 216] 36,928\n", | |
| " BatchNorm2d-26 [-1, 64, 216, 216] 128\n", | |
| " ConvBlock-27 [-1, 64, 216, 216] 0\n", | |
| " UpBlock-28 [-1, 64, 216, 216] 0\n", | |
| " Conv2d-29 [-1, 2, 216, 216] 130\n", | |
| "================================================================\n", | |
| "Total params: 1,863,042\n", | |
| "Trainable params: 1,863,042\n", | |
| "Non-trainable params: 0\n", | |
| "----------------------------------------------------------------\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "SIZE=256\n", | |
| "unet=UNet(in_size=SIZE,network_depth=2,act='elu',se=False)\n", | |
| "print(unet.network_depth,unet.conv_depth)\n", | |
| "print(unet.out_size,unet.out_ch)\n", | |
| "print(unet(torch.rand(1,1,SIZE,SIZE)).shape)\n", | |
| "summary(unet,input_size=(1,SIZE,SIZE))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 2", | |
| "language": "python", | |
| "name": "python2" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 2 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython2", | |
| "version": "2.7.13" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Uh oh!
There was an error while loading. Please reload this page.