Code modified from https://github.com/rhaps0dy/convnets-as-gps. NOT the final version of the code.
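# ============================ dkern_kernel.py ============================
# Deep-network kernel module (imported further below by the CIFAR-10 test
# script as `dkern_kernel`).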
import torch
import torch.nn.functional as F
import gpytorch
from gpytorch.kernels import Kernel
from typing import List
import numpy as np
# import abc
import exkern
from exkern import ElementwiseExKern
class DeepKernelBase(Kernel):
    """General kernel for deep networks"""
    def __init__(self,
                 input_shape: List[int],
                 block_sizes: List[int],
                 block_strides: List[int],
                 kernel_size: int,
                 recurse_kern: ElementwiseExKern,
                 conv_stride: int = 1,
                 active_dims: slice = None,
                 input_type=None,
                 name: str = None):
        input_dim = np.prod(input_shape)
        super(DeepKernelBase, self).__init__(input_dim, active_dims)
        self.input_shape = list(np.copy(input_shape))
        self.block_sizes = np.copy(block_sizes).astype(np.int32)
        self.block_strides = np.copy(block_strides).astype(np.int32)
        self.kernel_size = kernel_size
        self.recurse_kern = recurse_kern
        self.conv_stride = conv_stride
        if input_type is None:
            input_type = torch.float32
        self.input_type = input_type
        self.register_parameter(name="var_weight",
                                parameter=torch.nn.Parameter(torch.zeros(1)), prior=None)
        self.register_parameter(name="var_bias",
                                parameter=torch.nn.Parameter(torch.zeros(1)), prior=None)
    def forward(self, x1, x2=None, **params):
        if x2 is None:
            return self.K(x1)
        else:
            # drop a leading batch dimension if one is present
            if len(x1.size()) == 5:
                x1 = x1[0]
            if len(x2.size()) == 5:
                x2 = x2[0]
            return self.K(x1, x2)
    def K(self, X, X2=None):
        # Concatenate the covariance between X and X2 and their respective
        # variances. Only 1 variance is needed if X2 is None.
        if X.dtype != self.input_type or (
                X2 is not None and X2.dtype != self.input_type):
            raise TypeError("Input dtypes are wrong: {} or {} are not {}"
                            .format(X.dtype, X2.dtype, self.input_type))
        if X2 is None:
            N = N2 = X.size()[0]
            var_z_list = [
                torch.reshape(torch.pow(X, 2), [N] + self.input_shape),
                torch.reshape(X[:, None, :] * X, [N*N] + self.input_shape)]

            def apply_recurse_kern(var_a_all, concat_outputs=True):
                var_a_1 = var_a_all[:N]
                var_a_cross = var_a_all[N:]
                vz = [self.recurse_kern.Kdiag(var_a_1),
                      self.recurse_kern.K(var_a_cross, var_a_1, None)]
                if concat_outputs:
                    return torch.cat(vz, 0)
                return vz
        else:
            N, N2 = X.size()[0], X2.size()[0]
            var_z_list = [
                torch.reshape(torch.pow(X, 2), [N] + self.input_shape),
                torch.reshape(torch.pow(X2, 2), [N2] + self.input_shape),
                torch.reshape(X[:, None, :] * X2, [N*N2] + self.input_shape)]
            cross_start = N + N2

            def apply_recurse_kern(var_a_all, concat_outputs=True):
                var_a_1 = var_a_all[:N]
                var_a_2 = var_a_all[N:cross_start]
                var_a_cross = var_a_all[cross_start:]
                vz = [self.recurse_kern.Kdiag(var_a_1),
                      self.recurse_kern.Kdiag(var_a_2),
                      self.recurse_kern.K(var_a_cross, var_a_1, var_a_2)]
                if concat_outputs:
                    return torch.cat(vz, 0)
                return vz
        inputs = torch.cat(var_z_list, 0)
        if len(self.block_sizes) > 0:
            # Define almost all the network
            inputs = self.headless_network(inputs, apply_recurse_kern)
        # Last nonlinearity before final dense layer
        var_z_list = apply_recurse_kern(inputs, concat_outputs=False)
        # averaging for the final dense layer
        var_z_cross = torch.reshape(var_z_list[-1], [N, N2, -1])
        var_z_cross_last = torch.mean(var_z_cross, 2)
        result = F.softplus(self.var_bias) + F.softplus(self.var_weight) * var_z_cross_last
        # if self.input_type != torch.float64:
        #     print("Casting kernel from {} to {}"
        #           .format(self.input_type, torch.float64))
        #     return result.float()
        return result
    def Kdiag(self, X):
        if X.dtype != self.input_type:
            raise TypeError("Input dtype is wrong: {} is not {}"
                            .format(X.dtype, self.input_type))
        inputs = torch.reshape(torch.pow(X, 2), [-1] + self.input_shape)
        if len(self.block_sizes) > 0:
            inputs = self.headless_network(inputs, self.recurse_kern.Kdiag)
        # Last dense layer
        inputs = self.recurse_kern.Kdiag(inputs)
        # average over all non-batch dimensions
        var_z_last = inputs.clone()
        for _ in range(len(inputs.shape) - 1):
            var_z_last = torch.mean(var_z_last, 1)
        result = F.softplus(self.var_bias) + F.softplus(self.var_weight) * var_z_last
        # if self.input_type != torch.float64:
        #     print("Casting kernel from {} to {}"
        #           .format(self.input_type, torch.float64))
        #     return result.float()
        return result
    def headless_network(self, inputs, apply_recurse_kern):
        """
        Apply the network that this kernel defines, except the last dense layer.
        The last dense layer is different for K and Kdiag.
        """
        raise NotImplementedError
class DeepKernelTesting(DeepKernelBase):
    """
    Reimplement original DeepKernel to test ResNet
    """
    def headless_network(self, inputs, apply_recurse_kern):
        in_chans = inputs.size()[1]
        # First convolution: weights are divided by the fan-in
        W_init = (torch.ones([1, in_chans, self.kernel_size, self.kernel_size]) *
                  F.softplus(self.var_weight) / in_chans)
        inputs = F.conv2d(inputs, W_init, bias=None, stride=1, padding=0
                          ) + F.softplus(self.var_bias)
        W = (torch.ones([1, 1, self.kernel_size, self.kernel_size]) *
             F.softplus(self.var_weight))  # No dividing by fan_in
        for _ in range(1, len(self.block_sizes)):
            inputs = apply_recurse_kern(inputs)
            inputs = F.conv2d(inputs, W, bias=None, stride=1, padding=0)
            inputs = inputs + F.softplus(self.var_bias)
        return inputs
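# Hypothetical usage sketch (not part of the original gist): a quick shape
# check of the kernel matrix that DeepKernelTesting.K produces for a handful
# of flattened 3x32x32 inputs. The constructor arguments mirror the CIFAR-10
# test script further below; this is only an illustration and assumes the
# gpytorch version the gist was written against.
if __name__ == "__main__":
    from exkern import ExReLU
    demo_kern = DeepKernelTesting([3, 32, 32], block_sizes=[2, 3, 4],
                                  block_strides=[4, 5, 6], kernel_size=1,
                                  recurse_kern=ExReLU(multiply_by_sqrt2=True))
    demo_x = torch.randn(5, 3 * 32 * 32)   # 5 flattened images
    print(demo_kern.K(demo_x).size())      # expected: torch.Size([5, 5])


# ================================ exkern.py ================================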
import numpy as np
import torch
import torch.nn.functional as F
from gpytorch.module import Module
__all__ = ['ElementwiseExKern', 'ExReLU', 'ExErf']
class ElementwiseExKern(Module):  # Parameterized
    def K(self, cov, var1, var2=None):
        raise NotImplementedError

    def Kdiag(self, var):
        raise NotImplementedError

    def nlin(self, x):
        """
        The nonlinearity that this is computing the expected inner product of.
        Used for testing.
        """
        raise NotImplementedError
class ExReLU(ElementwiseExKern):
    # TODO: remove name, not used
    def __init__(self, exponent=1, multiply_by_sqrt2=False, name=None):
        super(ExReLU, self).__init__()  # name=name
        self.multiply_by_sqrt2 = multiply_by_sqrt2
        if exponent in {0, 1}:
            self.exponent = exponent
        else:
            raise NotImplementedError

    def K(self, cov, var1, var2=None):
        if var2 is None:
            sqrt1 = sqrt2 = torch.sqrt(var1)
        else:
            sqrt1, sqrt2 = torch.sqrt(var1), torch.sqrt(var2)
        norms_prod = sqrt1[:, None, ...] * sqrt2
        norms_prod = torch.reshape(norms_prod, cov.size())
        cos_theta = torch.clamp(cov / norms_prod, -0.999999, 0.999999)
        theta = torch.acos(cos_theta)  # angle wrt the previous RKHS
        if self.exponent == 0:
            return .5 - theta/(2*np.pi)
        sin_theta = torch.sqrt(1. - cos_theta**2)
        J = sin_theta + (np.pi - theta) * cos_theta
        if self.multiply_by_sqrt2:
            div = np.pi
        else:
            div = 2*np.pi
        return norms_prod / div * J

    def Kdiag(self, var):
        if self.multiply_by_sqrt2:
            if self.exponent == 0:
                return torch.ones_like(var)
            else:
                return var
        else:
            if self.exponent == 0:
                return torch.ones_like(var)/2
            else:
                return var/2

    def nlin(self, x):
        if self.multiply_by_sqrt2:
            if self.exponent == 0:
                return (1 + torch.sign(x)) / np.sqrt(2)
            elif self.exponent == 1:
                return torch.nn.functional.relu(x) * np.sqrt(2)
        else:
            if self.exponent == 0:
                return (1 + torch.sign(x)) / 2
            elif self.exponent == 1:
                return torch.nn.functional.relu(x)
class ExErf(ElementwiseExKern):
    """The Gaussian error function as a nonlinearity. It's very similar to the
    tanh. Williams 1997"""
    def K(self, cov, var1, var2=None):
        if var2 is None:
            t1 = t2 = 1 + 2*var1
        else:
            t1, t2 = 1 + 2*var1, 1 + 2*var2
        vs = torch.reshape(t1[:, None, ...] * t2, cov.size())
        sin_theta = 2*cov / torch.sqrt(vs)
        return (2/np.pi) * torch.asin(sin_theta)

    def Kdiag(self, var):
        v2 = 2*var
        return (2/np.pi) * torch.asin(v2 / (1 + v2))

    def nlin(self, x):
        return torch.erf(x)
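# Hypothetical sanity check (not in the original gist): compare the closed-form
# ExReLU.K against a Monte Carlo estimate of E[nlin(u) * nlin(v)] for a
# correlated Gaussian pair (u, v). The variance/covariance values below are
# arbitrary and only serve as an illustration.
if __name__ == "__main__":
    torch.manual_seed(0)
    kern = ExReLU(multiply_by_sqrt2=True)
    var1, var2, cov = torch.tensor([1.0]), torch.tensor([1.5]), torch.tensor([0.3])
    closed_form = kern.K(cov, var1, var2)
    mvn = torch.distributions.MultivariateNormal(
        torch.zeros(2), covariance_matrix=torch.tensor([[1.0, 0.3], [0.3, 1.5]]))
    uv = mvn.sample((200000,))                  # (200000, 2) draws of (u, v)
    mc = torch.mean(kern.nlin(uv[:, 0]) * kern.nlin(uv[:, 1]))
    print(closed_form.item(), mc.item())        # the two values should roughly agree


# ========================= CIFAR-10 test script =========================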
import torch
import gpytorch
from exkern import ExReLU
import numpy as np
from gpytorch.mlls.variational_elbo import VariationalELBO
from gpytorch.variational import GridInterpolationVariationalStrategy
import dkern_kernel
from torchvision import datasets, transforms
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
dataloader = datasets.CIFAR10
num_classes = 10
trainset = dataloader(root='./data', train=True, download=True, transform=transform_train)
testset = dataloader(root='./data', train=False, download=False, transform=transform_test)
# only use the first 5 images for a quick smoke test
# NOTE: the transforms above are bypassed here; raw pixel values are fed in.
# train_data is stored as (50000, 32, 32, 3), so permute HWC -> CHW
train_x = torch.tensor(trainset.train_data).type(torch.FloatTensor).permute(0, 3, 1, 2)[:5]
train_y = torch.tensor(trainset.train_labels)[:5]
from gpytorch.models import AbstractVariationalGP
from gpytorch.variational import CholeskyVariationalDistribution
from gpytorch.variational import VariationalStrategy
class GPClassificationModel(AbstractVariationalGP):
    def __init__(self, train_x):
        variational_distribution = CholeskyVariationalDistribution(train_x.size(0))
        variational_strategy = VariationalStrategy(self, train_x, variational_distribution)
        super(GPClassificationModel, self).__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = dkern_kernel.DeepKernelTesting(
            [3, 32, 32], [2, 3, 4], [4, 5, 6], 1, ExReLU(multiply_by_sqrt2=True))

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        latent_pred = gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
        return latent_pred
# Initialize model and likelihood
model = GPClassificationModel(train_x)
# will change to something like gpytorch.likelihoods.SoftmaxLikelihood(num_features=1, n_classes=10)
likelihood = gpytorch.likelihoods.BernoulliLikelihood()
# Find optimal model hyperparameters
model.train()
likelihood.train()
# Use the adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# "Loss" for GPs - the marginal log likelihood
# num_data refers to the amount of training data
mll = VariationalELBO(likelihood, model, train_y.numel())
training_iter = 50
for i in range(training_iter):
    # Zero backpropped gradients from previous iteration
    optimizer.zero_grad()
    # Get predictive output
    output = model(train_x)
    # Calc loss and backprop gradients
    loss = -mll(output, train_y)
    print('Iter %d/%d - Loss: %.3f' % (i + 1, training_iter, loss.item()))
    loss.backward()
    optimizer.step()
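# Hypothetical follow-up (not in the original gist): after training, switch to
# eval mode and push the latent GP through the likelihood to get predictive
# probabilities. With the placeholder BernoulliLikelihood this only yields a
# binary probability per input; it is meant as a template, not a final design.
model.eval()
likelihood.eval()
with torch.no_grad():
    observed_pred = likelihood(model(train_x))
    print(observed_pred.mean)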