import functools

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledWSConv2d(nn.Conv2d):
    """2D Conv layer with Scaled Weight Standardization."""
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, gain=True, eps=1e-4):
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride,
                           padding, dilation, groups, bias)
        if gain:
            self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
        else:
            self.gain = None
        # Epsilon, a small constant to avoid dividing by zero.
        self.eps = eps

    def get_weight(self):
        # Get Scaled WS weight (OIHW layout).
        fan_in = np.prod(self.weight.shape[1:])
        mean = torch.mean(self.weight, dim=[1, 2, 3], keepdim=True)
        var = torch.var(self.weight, dim=[1, 2, 3], keepdim=True)
        weight = (self.weight - mean) / (var * fan_in + self.eps) ** 0.5
        if self.gain is not None:
            weight = weight * self.gain
        return weight

    def forward(self, x):
        return F.conv2d(x, self.get_weight(), self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
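
# A minimal sanity-check sketch (not part of the original gist; the helper name
# below is illustrative only). After Scaled Weight Standardization, each output
# channel's weights should have roughly zero mean and variance ~1/fan_in, so
# var * fan_in should be close to 1.
def _check_scaled_ws():
    conv = ScaledWSConv2d(16, 32, kernel_size=3, gain=False)
    w = conv.get_weight()
    fan_in = np.prod(conv.weight.shape[1:])
    print('max |per-channel mean|:', w.mean(dim=(1, 2, 3)).abs().max().item())
    print('mean per-channel var * fan_in:', (w.var(dim=(1, 2, 3)) * fan_in).mean().item())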
class SqueezeExcite(nn.Module):
    """Simple Squeeze+Excite layers."""
    def __init__(self, in_channels, width, activation):
        super().__init__()
        self.se_conv0 = nn.Conv2d(in_channels, width, kernel_size=1, bias=True)
        self.se_conv1 = nn.Conv2d(width, in_channels, kernel_size=1, bias=True)
        self.activation = activation

    def forward(self, x):
        # Mean pool for NCHW tensors.
        h = torch.mean(x, dim=[2, 3], keepdim=True)
        # Apply two linear layers with activation in between.
        h = self.se_conv1(self.activation(self.se_conv0(h)))
        # Rescale the sigmoid output and return.
        return (torch.sigmoid(h) * 2) * x
class NFBlock(nn.Module):
    """NF-RegNet block."""
    def __init__(self, in_channels, out_channels, stride=1, activation=F.relu,
                 which_conv=ScaledWSConv2d, beta=1.0, alpha=1.0, expansion=2.25,
                 keep_rate=None, se_ratio=0.5, group_size=8):
        super().__init__()
        self.in_ch, self.out_ch = in_channels, out_channels
        self.activation = activation
        self.beta, self.alpha = beta, alpha
        width = int(in_channels * expansion)
        if group_size is None:
            groups = 1
        else:
            groups = width // group_size
            # Round width up if you pick a bad group size.
            width = int(group_size * groups)
        self.stride = stride
        self.width = width
        self.conv1x1a = which_conv(in_channels, width, kernel_size=1, padding=0)
        self.conv3x3 = which_conv(width, width, kernel_size=3, stride=stride,
                                  padding=1, groups=groups)
        self.conv1x1b = which_conv(width, out_channels, kernel_size=1, padding=0)
        if stride > 1 or in_channels != out_channels:
            self.conv_shortcut = which_conv(in_channels, out_channels,
                                            kernel_size=1, stride=1, padding=0)
        else:
            self.conv_shortcut = None
        # Hidden size of the S+E MLP.
        se_width = max(1, int(width * se_ratio))
        self.se = SqueezeExcite(width, se_width, self.activation)
        self.skipinit_gain = nn.Parameter(torch.zeros(()))

    def forward(self, x):
        out = self.activation(x) / self.beta
        if self.conv_shortcut is not None:
            # Average-pool the shortcut path only when the block downsamples.
            shortcut = self.conv_shortcut(F.avg_pool2d(out, 2) if self.stride > 1 else out)
        else:
            shortcut = x
        out = self.conv1x1a(out)                  # Initial bottleneck conv
        out = self.conv3x3(self.activation(out))  # Spatial conv
        out = self.se(out)                        # Apply squeeze + excite to middle block.
        out = self.conv1x1b(self.activation(out))
        return out * self.skipinit_gain * self.alpha + shortcut
class NFRegNet(nn.Module):
    """Normalizer-Free RegNets."""
    # Nonlinearities. Note that we bake the constant into the
    # nonlinearities rather than the WS layers.
    nonlinearities = {
        'silu': lambda x: x * torch.sigmoid(x) / .5595,
        'relu': lambda x: F.relu(x) / (0.5 * (1 - 1 / np.pi)) ** 0.5,
        'identity': lambda x: x}

    # Block base widths and depths for each variant.
    params = {
        'NF-RegNet-B0': {'width': [48, 104, 208, 440], 'depth': [1, 3, 6, 6],
                         'train_imsize': 192, 'test_imsize': 224, 'drop_rate': 0.2, 'weight_decay': 2e-5},
        'NF-RegNet-B1': {'width': [48, 104, 208, 440], 'depth': [2, 4, 7, 7],
                         'train_imsize': 240, 'test_imsize': 256, 'drop_rate': 0.2, 'weight_decay': 2e-5},
        'NF-RegNet-B2': {'width': [56, 112, 232, 488], 'depth': [2, 4, 8, 8],
                         'train_imsize': 240, 'test_imsize': 272, 'drop_rate': 0.3, 'weight_decay': 3e-5},
        'NF-RegNet-B3': {'width': [56, 128, 248, 528], 'depth': [2, 5, 9, 9],
                         'train_imsize': 288, 'test_imsize': 320, 'drop_rate': 0.3, 'weight_decay': 4e-5},
        'NF-RegNet-B4': {'width': [64, 144, 288, 616], 'depth': [2, 6, 11, 11],
                         'train_imsize': 320, 'test_imsize': 384, 'drop_rate': 0.4, 'weight_decay': 4e-5},
        'NF-RegNet-B5': {'width': [80, 168, 336, 704], 'depth': [3, 7, 14, 14],
                         'train_imsize': 384, 'test_imsize': 456, 'drop_rate': 0.4, 'weight_decay': 5e-5},
    }

    @staticmethod
    def count_params(module):
        return sum(item.numel() for item in module.parameters())
    def __init__(self, variant='NF-RegNet-B0', num_classes=1000, width=0.75,
                 expansion=2.25, se_ratio=0.5, group_size=8, alpha=0.2,
                 activation='silu', drop_rate=None, stochdepth_rate=0.0,
                 in_chans=3, global_pool='avg'):
        super().__init__()
        self.variant = variant
        self.width = width
        self.expansion = expansion
        self.num_classes = num_classes
        self.se_ratio = se_ratio
        self.group_size = group_size
        self.alpha = alpha
        self.activation = NFRegNet.nonlinearities.get(activation)
        if drop_rate is None:
            self.drop_rate = NFRegNet.params[self.variant]['drop_rate']
        else:
            self.drop_rate = drop_rate
        self.stochdepth_rate = stochdepth_rate
        self.which_conv = functools.partial(ScaledWSConv2d, gain=True, bias=True)
        # Get width and depth pattern.
        self.width_pattern = [int(val * self.width) for val in NFRegNet.params[variant]['width']]
        self.depth_pattern = NFRegNet.params[self.variant]['depth']
        # Stem conv.
        in_channels = int(self.width_pattern[0])
        self.initial_conv = self.which_conv(in_chans, in_channels, kernel_size=3,
                                            stride=2, padding=1)
        # Body.
        blocks = []
        expected_var = 1.0
        index = 0
        for block_width, stage_depth in zip(self.width_pattern, self.depth_pattern):
            for block_index in range(stage_depth):
                # Following EffNets, do not expand the first block.
                expand_ratio = expansion if index > 0 else 1
                beta = expected_var ** 0.5
                blocks += [NFBlock(in_channels, block_width,
                                   stride=2 if block_index == 0 else 1,
                                   activation=self.activation,
                                   which_conv=self.which_conv,
                                   beta=beta, alpha=self.alpha,
                                   expansion=expand_ratio,
                                   se_ratio=self.se_ratio,
                                   group_size=self.group_size)]
                # Keep track of output channel count.
                in_channels = block_width
                # Reset expected var at a transition block.
                if block_index == 0:
                    expected_var = 1.
                # Even if a reset occurs, increment the expected variance.
                expected_var += self.alpha ** 2
                index += 1
        self.blocks = nn.Sequential(*blocks)
        # Final convolution, following EffNets.
        ch = int(1280 * in_channels // 440)
        self.final_conv = self.which_conv(in_channels, ch, kernel_size=1, padding=0)
        in_channels = ch
        self.num_features = ch
        if self.drop_rate > 0.0:
            self.dropout = nn.Dropout(self.drop_rate)
        # Classifier layer. Initialize this layer's weight with zeros!
        if self.num_classes:
            self.fc = nn.Linear(in_channels, self.num_classes, bias=True)
            torch.nn.init.zeros_(self.fc.weight)
        else:
            self.fc = nn.Identity()
        self.global_pool = global_pool
    def forward(self, x):
        """Return the logits without any [log-]softmax."""
        # Stem.
        out = self.initial_conv(x)
        # Blocks.
        out = self.blocks(out)
        # Final activation + conv.
        out = self.activation(self.final_conv(out))
        # Global average pooling.
        pool = torch.mean(out, [2, 3]) if self.global_pool == 'avg' else out
        if self.drop_rate > 0.0 and self.training:
            pool = self.dropout(pool)
        # Return logits.
        return self.fc(pool)
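
# A minimal usage sketch (not part of the original gist): build the smallest
# variant, run a dummy forward pass at its training resolution (192x192, per
# the 'NF-RegNet-B0' entry in NFRegNet.params), and report the parameter count.
if __name__ == '__main__':
    model = NFRegNet('NF-RegNet-B0', num_classes=1000)
    model.eval()  # deterministic check: disables the classifier dropout
    x = torch.randn(2, 3, 192, 192)
    with torch.no_grad():
        logits = model(x)
    print('logits shape:', tuple(logits.shape))   # expected: (2, 1000)
    print('param count:', NFRegNet.count_params(model))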