import functools

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledWSConv2d(nn.Conv2d):
    """2D Conv layer with Scaled Weight Standardization."""
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, gain=True, eps=1e-4):
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride,
                           padding, dilation, groups, bias)
        if gain:
            self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
        else:
            self.gain = None
        # Epsilon, a small constant to avoid dividing by zero.
        self.eps = eps

    def get_weight(self):
        # Get Scaled WS weight OIHW;
        fan_in = np.prod(self.weight.shape[1:])
        mean = torch.mean(self.weight, dim=[1, 2, 3], keepdim=True)
        var = torch.var(self.weight, dim=[1, 2, 3], keepdim=True)
        weight = (self.weight - mean) / (var * fan_in + self.eps) ** 0.5
        if self.gain is not None:
            weight = weight * self.gain
        return weight

    def forward(self, x):
        return F.conv2d(x, self.get_weight(), self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
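

# Hypothetical sanity check (not part of the original gist): after weight
# standardization each filter has zero mean, and the layer is a
# shape-compatible drop-in replacement for nn.Conv2d.
def _demo_scaled_ws_conv():
    conv = ScaledWSConv2d(16, 32, kernel_size=3, padding=1)
    w = conv.get_weight()
    print(w.mean(dim=[1, 2, 3]))                  # ~0 for each of the 32 filters
    print(conv(torch.randn(2, 16, 8, 8)).shape)   # torch.Size([2, 32, 8, 8])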


class SqueezeExcite(nn.Module):
    """Simple Squeeze+Excite layers."""
    def __init__(self, in_channels, width, activation):
        super().__init__()
        self.se_conv0 = nn.Conv2d(in_channels, width, kernel_size=1, bias=True)
        self.se_conv1 = nn.Conv2d(width, in_channels, kernel_size=1, bias=True)
        self.activation = activation

    def forward(self, x):
        # Mean pool for NCHW tensors
        h = torch.mean(x, dim=[2, 3], keepdim=True)
        # Apply two linear layers with activation in between
        h = self.se_conv1(self.activation(self.se_conv0(h)))
        # Rescale the sigmoid output and return
        return (torch.sigmoid(h) * 2) * x


class NFBlock(nn.Module):
    """NF-RegNet block."""
    def __init__(self, in_channels, out_channels, stride=1, activation=F.relu,
                 which_conv=ScaledWSConv2d, beta=1.0, alpha=1.0, expansion=2.25,
                 keep_rate=None, se_ratio=0.5, group_size=8):
        super().__init__()
        self.in_ch, self.out_ch = in_channels, out_channels
        self.activation = activation
        self.beta, self.alpha = beta, alpha
        width = int(in_channels * expansion)
        if group_size is None:
            groups = 1
        else:
            groups = width // group_size
            # Round width down to a multiple of group_size if needed
            width = int(group_size * groups)
        self.stride = stride
        self.width = width
        self.conv1x1a = which_conv(in_channels, width, kernel_size=1, padding=0)
        self.conv3x3 = which_conv(width, width, kernel_size=3, stride=stride,
                                  padding=1, groups=groups)
        self.conv1x1b = which_conv(width, out_channels, kernel_size=1, padding=0)
        if stride > 1 or in_channels != out_channels:
            self.conv_shortcut = which_conv(in_channels, out_channels,
                                            kernel_size=1, stride=1, padding=0)
        else:
            self.conv_shortcut = None
        # Hidden size of the S+E MLP
        se_width = max(1, int(width * se_ratio))
        self.se = SqueezeExcite(width, se_width, self.activation)
        self.skipinit_gain = nn.Parameter(torch.zeros(()))

    def forward(self, x):
        out = self.activation(x) / self.beta
        if self.conv_shortcut is not None:
            shortcut = self.conv_shortcut(F.avg_pool2d(out, 2))
        else:
            shortcut = x
        out = self.conv1x1a(out)  # Initial bottleneck conv
        out = self.conv3x3(self.activation(out))  # Spatial conv
        out = self.se(out)  # Apply squeeze + excite to middle block.
        out = self.conv1x1b(self.activation(out))
        return out * self.skipinit_gain * self.alpha + shortcut
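

# Hypothetical behaviour check (not in the original gist): skipinit_gain is
# zero-initialized, so a freshly built NFBlock returns only its shortcut
# branch, and a stride-2 transition block halves the spatial resolution.
def _demo_nf_block():
    block = NFBlock(48, 104, stride=2)
    x = torch.randn(2, 48, 32, 32)
    print(block(x).shape)  # torch.Size([2, 104, 16, 16])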


def count_params(module):
    """Count the number of parameters in a module."""
    return sum(item.numel() for item in module.parameters())


class NFRegNet(nn.Module):
    """Normalizer-Free RegNets."""
    # Nonlinearities. Note that we bake the constant into the
    # nonlinearities rather than the WS layers.
    nonlinearities = {
        'silu': lambda x: x * torch.sigmoid(x) / .5595,
        'relu': lambda x: F.relu(x) / (0.5 * (1 - 1 / np.pi)) ** 0.5,
        'identity': lambda x: x}
    # Block base widths and depths for each variant
    params = {
        'NF-RegNet-B0': {'width': [48, 104, 208, 440], 'depth': [1, 3, 6, 6], 'train_imsize': 192, 'test_imsize': 224, 'drop_rate': 0.2, 'weight_decay': 2e-5},
        'NF-RegNet-B1': {'width': [48, 104, 208, 440], 'depth': [2, 4, 7, 7], 'train_imsize': 240, 'test_imsize': 256, 'drop_rate': 0.2, 'weight_decay': 2e-5},
        'NF-RegNet-B2': {'width': [56, 112, 232, 488], 'depth': [2, 4, 8, 8], 'train_imsize': 240, 'test_imsize': 272, 'drop_rate': 0.3, 'weight_decay': 3e-5},
        'NF-RegNet-B3': {'width': [56, 128, 248, 528], 'depth': [2, 5, 9, 9], 'train_imsize': 288, 'test_imsize': 320, 'drop_rate': 0.3, 'weight_decay': 4e-5},
        'NF-RegNet-B4': {'width': [64, 144, 288, 616], 'depth': [2, 6, 11, 11], 'train_imsize': 320, 'test_imsize': 384, 'drop_rate': 0.4, 'weight_decay': 4e-5},
        'NF-RegNet-B5': {'width': [80, 168, 336, 704], 'depth': [3, 7, 14, 14], 'train_imsize': 384, 'test_imsize': 456, 'drop_rate': 0.4, 'weight_decay': 5e-5},
    }
    def __init__(self, variant='NF-RegNet-B0', num_classes=1000, width=0.75,
                 expansion=2.25, se_ratio=0.5, group_size=8, alpha=0.2,
                 activation='silu', drop_rate=None, stochdepth_rate=0.0,
                 in_chans=3, global_pool='avg'):
        super().__init__()
        self.variant = variant
        self.width = width
        self.expansion = expansion
        self.num_classes = num_classes
        self.se_ratio = se_ratio
        self.group_size = group_size
        self.alpha = alpha
        self.activation = NFRegNet.nonlinearities.get(activation)
        if drop_rate is None:
            self.drop_rate = NFRegNet.params[self.variant]['drop_rate']
        else:
            self.drop_rate = drop_rate
        self.stochdepth_rate = stochdepth_rate
        self.which_conv = functools.partial(ScaledWSConv2d, gain=True, bias=True)
        # Get width and depth pattern
        self.width_pattern = [int(val * self.width) for val in NFRegNet.params[variant]['width']]
        self.depth_pattern = NFRegNet.params[self.variant]['depth']
        # Stem conv
        in_channels = int(self.width_pattern[0])
        self.initial_conv = self.which_conv(in_chans, in_channels, kernel_size=3,
                                            stride=2, padding=1)
        # Body
        blocks = []
        expected_var = 1.0
        index = 0
        for block_width, stage_depth in zip(self.width_pattern, self.depth_pattern):
            for block_index in range(stage_depth):
                # Following EffNets, do not expand first block
                expand_ratio = expansion if index > 0 else 1
                beta = expected_var ** 0.5
                blocks += [NFBlock(in_channels, block_width,
                                   stride=2 if block_index == 0 else 1,
                                   activation=self.activation,
                                   which_conv=self.which_conv,
                                   beta=beta, alpha=self.alpha,
                                   expansion=expand_ratio,
                                   se_ratio=self.se_ratio,
                                   group_size=self.group_size)]
                # Keep track of output channel count
                in_channels = block_width
                # Reset expected var at a transition block
                if block_index == 0:
                    expected_var = 1.
                # Even if reset occurs, increment expected variance
                expected_var += self.alpha ** 2
                index += 1
        self.blocks = nn.Sequential(*blocks)
        # Final convolution, following EffNets
        ch = int(1280 * in_channels // 440)
        self.final_conv = self.which_conv(in_channels, ch, kernel_size=1, padding=0)
        in_channels = ch
        self.num_features = ch
        if self.drop_rate > 0.0:
            self.dropout = nn.Dropout(self.drop_rate)
        # Classifier layer. Initialize this layer's weight with zeros!
        if self.num_classes:
            self.fc = nn.Linear(in_channels, self.num_classes, bias=True)
            torch.nn.init.zeros_(self.fc.weight)
        else:
            self.fc = nn.Identity()
        self.global_pool = global_pool
    def forward(self, x):
        """Return the logits without any [log-]softmax."""
        # Stem
        out = self.initial_conv(x)
        # Blocks
        out = self.blocks(out)
        # Final activation + conv
        out = self.activation(self.final_conv(out))
        # Global average pooling
        pool = torch.mean(out, [2, 3]) if self.global_pool == 'avg' else out
        if self.drop_rate > 0.0 and self.training:
            pool = self.dropout(pool)
        # Return logits
        return self.fc(pool)
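

# Minimal usage sketch (not part of the original gist): build the smallest
# variant and run a dummy forward pass at its training resolution.
if __name__ == '__main__':
    model = NFRegNet(variant='NF-RegNet-B0', num_classes=1000)
    x = torch.randn(2, 3, 192, 192)           # B0 trains at 192x192
    logits = model(x)
    print(count_params(model), logits.shape)  # expect logits of shape [2, 1000]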