Code modified from https://github.com/rhaps0dy/convnets-as-gps. NOT the final version of the code.
dkern_kernel.py (file name inferred from "import dkern_kernel" in the training script below):
import torch
import torch.nn.functional as F
import gpytorch
from gpytorch.kernels import Kernel
from typing import List
import numpy as np
# import abc

import exkern
from exkern import ElementwiseExKern


class DeepKernelBase(Kernel):
    "General kernel for deep networks"
    def __init__(self,
                 input_shape: List[int],
                 block_sizes: List[int],
                 block_strides: List[int],
                 kernel_size: int,
                 recurse_kern: ElementwiseExKern,
                 conv_stride: int = 1,
                 active_dims: slice = None,
                 input_type=None,
                 name: str = None):
        input_dim = np.prod(input_shape)
        super(DeepKernelBase, self).__init__(input_dim, active_dims)
        self.input_shape = list(np.copy(input_shape))
        self.block_sizes = np.copy(block_sizes).astype(np.int32)
        self.block_strides = np.copy(block_strides).astype(np.int32)
        self.kernel_size = kernel_size
        self.recurse_kern = recurse_kern
        self.conv_stride = conv_stride
        if input_type is None:
            input_type = torch.float32
        self.input_type = input_type
        self.register_parameter(name="var_weight",
                                parameter=torch.nn.Parameter(torch.zeros(1)), prior=None)
        self.register_parameter(name="var_bias",
                                parameter=torch.nn.Parameter(torch.zeros(1)), prior=None)

    def forward(self, x1, x2=None, **params):
        if x2 is None:
            return self.K(x1)
        else:
            if len(x1.size()) == 5:
                x1 = x1[0]
            if len(x2.size()) == 5:
                x2 = x2[0]
            return self.K(x1, x2)

    def K(self, X, X2=None):
        # Concatenate the covariance between X and X2 and their respective
        # variances. Only 1 variance is needed if X2 is None.
        if X.dtype != self.input_type or (
                X2 is not None and X2.dtype != self.input_type):
            raise TypeError("Input dtypes are wrong: {} or {} are not {}"
                            .format(X.dtype, X2.dtype, self.input_type))
        if X2 is None:
            N = N2 = X.size()[0]
            var_z_list = [
                torch.reshape(torch.pow(X, 2), [N] + self.input_shape),
                torch.reshape(X[:, None, :] * X, [N*N] + self.input_shape)]

            def apply_recurse_kern(var_a_all, concat_outputs=True):
                var_a_1 = var_a_all[:N]
                var_a_cross = var_a_all[N:]
                vz = [self.recurse_kern.Kdiag(var_a_1),
                      self.recurse_kern.K(var_a_cross, var_a_1, None)]
                if concat_outputs:
                    return torch.cat(vz, 0)
                return vz
        else:
            N, N2 = X.size()[0], X2.size()[0]
            var_z_list = [
                torch.reshape(torch.pow(X, 2), [N] + self.input_shape),
                torch.reshape(torch.pow(X2, 2), [N2] + self.input_shape),
                torch.reshape(X[:, None, :] * X2, [N*N2] + self.input_shape)]
            cross_start = N + N2

            def apply_recurse_kern(var_a_all, concat_outputs=True):
                var_a_1 = var_a_all[:N]
                var_a_2 = var_a_all[N:cross_start]
                var_a_cross = var_a_all[cross_start:]
                vz = [self.recurse_kern.Kdiag(var_a_1),
                      self.recurse_kern.Kdiag(var_a_2),
                      self.recurse_kern.K(var_a_cross, var_a_1, var_a_2)]
                if concat_outputs:
                    return torch.cat(vz, 0)
                return vz
        inputs = torch.cat(var_z_list, 0)
        if len(self.block_sizes) > 0:
            # Define almost all the network
            inputs = self.headless_network(inputs, apply_recurse_kern)
        # Last nonlinearity before final dense layer
        var_z_list = apply_recurse_kern(inputs, concat_outputs=False)
        # averaging for the final dense layer
        var_z_cross = torch.reshape(var_z_list[-1], [N, N2, -1])
        var_z_cross_last = torch.mean(var_z_cross, 2)
        result = F.softplus(self.var_bias) + F.softplus(self.var_weight) * var_z_cross_last
        # .type(torch.FloatTensor)
        # if self.input_type != torch.float64:
        #     print("Casting kernel from {} to {}"
        #           .format(self.input_type, torch.float64))
        #     return result.float()
        return result

    def Kdiag(self, X):
        if X.dtype != self.input_type:
            raise TypeError("Input dtype is wrong: {} is not {}"
                            .format(X.dtype, self.input_type))
        inputs = torch.reshape(torch.pow(X, 2), [-1] + self.input_shape)
        if len(self.block_sizes) > 0:
            inputs = self.headless_network(inputs, self.recurse_kern.Kdiag)
        # Last dense layer
        inputs = self.recurse_kern.Kdiag(inputs)
        var_z_last = inputs.clone()  # torch tensors use .clone(), not numpy's .copy()
        for i in range(1, len(inputs.shape)):  # average over every non-batch dimension
            var_z_last = torch.mean(var_z_last, 1)
        result = F.softplus(self.var_bias) + F.softplus(self.var_weight) * var_z_last
        # if self.input_type != torch.float64:
        #     print("Casting kernel from {} to {}"
        #           .format(self.input_type, torch.float64))
        #     return result.float()
        return result

    def headless_network(self, inputs, apply_recurse_kern):
        """
        Apply the network that this kernel defines, except the last dense layer.
        The last dense layer is different for K and Kdiag.
        """
        raise NotImplementedError


class DeepKernelTesting(DeepKernelBase):
    """
    Reimplement original DeepKernel to test ResNet
    """
    def headless_network(self, inputs, apply_recurse_kern):
        in_chans = inputs.size()[1]
        W_init = (torch.ones([1, in_chans, self.kernel_size, self.kernel_size]) *
                  F.softplus(self.var_weight) / in_chans)  # .type(torch.FloatTensor)
        inputs = F.conv2d(inputs, W_init, bias=None, stride=[1, 1], padding=0
                          ) + F.softplus(self.var_bias)  # .type(torch.FloatTensor)
        W = (torch.ones([1, 1, self.kernel_size, self.kernel_size]) *
             F.softplus(self.var_weight))  # .type(torch.FloatTensor)  # No dividing by fan_in
        for _ in range(1, len(self.block_sizes)):
            inputs = apply_recurse_kern(inputs)
            inputs = F.conv2d(inputs, W,
                              bias=None, stride=1, padding=0)
            inputs = inputs + F.softplus(self.var_bias)  # .type(torch.FloatTensor)
        return inputs
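

# --- Not part of the original gist: a minimal smoke-test sketch. ---
# It assumes this file is saved as dkern_kernel.py next to exkern.py, and that the
# gpytorch version matches the one the rest of this file was written against.
# The shapes (4 inputs of size 3x8x8) and the single block are arbitrary choices.
if __name__ == "__main__":
    from exkern import ExReLU

    k = DeepKernelTesting(input_shape=[3, 8, 8], block_sizes=[2], block_strides=[1],
                          kernel_size=3, recurse_kern=ExReLU())
    X = torch.randn(4, 3 * 8 * 8)
    print(k.K(X).size())      # expected: torch.Size([4, 4]) -- the kernel matrix
    print(k.Kdiag(X).size())  # expected: torch.Size([4]) -- its diagonal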
exkern.py:
import numpy as np
import torch
import torch.nn.functional as F
from gpytorch.module import Module

__all__ = ['ElementwiseExKern', 'ExReLU', 'ExErf']


class ElementwiseExKern(Module):  # Parameterized
    def K(self, cov, var1, var2=None):
        raise NotImplementedError

    def Kdiag(self, var):
        raise NotImplementedError

    def nlin(self, x):
        """
        The nonlinearity that this is computing the expected inner product of.
        Used for testing.
        """
        raise NotImplementedError


class ExReLU(ElementwiseExKern):
    # TODO: remove name, not used
    def __init__(self, exponent=1, multiply_by_sqrt2=False, name=None):
        super(ExReLU, self).__init__()  # name=name
        self.multiply_by_sqrt2 = multiply_by_sqrt2
        if exponent in {0, 1}:
            self.exponent = exponent
        else:
            raise NotImplementedError

    def K(self, cov, var1, var2=None):
        if var2 is None:
            sqrt1 = sqrt2 = torch.sqrt(var1)
        else:
            sqrt1, sqrt2 = torch.sqrt(var1), torch.sqrt(var2)
        norms_prod = sqrt1[:, None, ...] * sqrt2
        norms_prod = torch.reshape(norms_prod, cov.size())  # tf.reshape(norms_prod, tf.shape(cov))
        cos_theta = torch.clamp(cov / norms_prod, -0.999999, 0.999999)  # tf.clip_by_value
        theta = torch.acos(cos_theta)  # angle wrt the previous RKHS  # tf.acos
        if self.exponent == 0:
            return .5 - theta/(2*np.pi)
        sin_theta = torch.sqrt(1. - cos_theta**2)
        J = sin_theta + (np.pi - theta) * cos_theta
        if self.multiply_by_sqrt2:
            div = np.pi
        else:
            div = 2*np.pi
        return norms_prod / div * J

    def Kdiag(self, var):
        if self.multiply_by_sqrt2:
            if self.exponent == 0:
                return torch.ones_like(var)
            else:
                return var
        else:
            if self.exponent == 0:
                return torch.ones_like(var)/2
            else:
                return var/2

    def nlin(self, x):
        if self.multiply_by_sqrt2:
            if self.exponent == 0:
                return (1 + torch.sign(x)) / np.sqrt(2)
            elif self.exponent == 1:
                return torch.nn.functional.relu(x) * np.sqrt(2)  # tf.nn.relu
        else:
            if self.exponent == 0:
                return (1 + torch.sign(x)) / 2
            elif self.exponent == 1:
                return torch.nn.functional.relu(x)


class ExErf(ElementwiseExKern):
    """The Gaussian error function as a nonlinearity; it is very similar to
    tanh (Williams, 1997)."""
    def K(self, cov, var1, var2=None):
        if var2 is None:
            t1 = t2 = 1 + 2*var1
        else:
            t1, t2 = 1 + 2*var1, 1 + 2*var2
        vs = torch.reshape(t1[:, None, ...] * t2, cov.size())
        sin_theta = 2*cov / torch.sqrt(vs)
        return (2/np.pi) * torch.asin(sin_theta)

    def Kdiag(self, var):
        v2 = 2*var
        return (2/np.pi) * torch.asin(v2 / (1 + v2))

    def nlin(self, x):
        return torch.erf(x)
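

# --- Not part of the original gist: a Monte Carlo sanity-check sketch. ---
# ExReLU.K(cov, var1, var2) should equal E[relu(u) relu(v)] for (u, v) jointly
# Gaussian with variances var1, var2 and covariance cov. The constants below
# (var1, var2, cov, sample count) are arbitrary illustration choices.
if __name__ == "__main__":
    torch.manual_seed(0)
    kern = ExReLU(exponent=1, multiply_by_sqrt2=False)
    var1, var2, cov = 1.5, 0.8, 0.4
    z1, z2 = torch.randn(1000000), torch.randn(1000000)
    # Build correlated Gaussians with the requested second moments.
    u = (var1 ** 0.5) * z1
    v = (cov / var1 ** 0.5) * z1 + ((var2 - cov**2 / var1) ** 0.5) * z2
    mc = (kern.nlin(u) * kern.nlin(v)).mean()
    closed = kern.K(torch.tensor([[cov]]), torch.tensor([var1]), torch.tensor([var2]))
    print(mc.item(), closed.item())  # the two numbers should agree to ~2-3 decimals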
Training script (file name not shown in the gist); it uses dkern_kernel.py and exkern.py above:
import torch
import gpytorch
from exkern import ExReLU
import numpy as np
from gpytorch.mlls.variational_elbo import VariationalELBO
from gpytorch.variational import GridInterpolationVariationalStrategy
import dkern_kernel
from torchvision import datasets, transforms

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

dataloader = datasets.CIFAR10
num_classes = 10
trainset = dataloader(root='./data', train=True, download=True, transform=transform_train)
testset = dataloader(root='./data', train=False, download=False, transform=transform_test)

# only use first 5
train_x = torch.tensor(trainset.train_data).type(torch.FloatTensor).permute(0, 3, 1, 2)[:5]  # NHWC -> NCHW (permute, not reshape)
train_y = torch.tensor(trainset.train_labels)[:5]

from gpytorch.models import AbstractVariationalGP
from gpytorch.variational import CholeskyVariationalDistribution
from gpytorch.variational import VariationalStrategy


class GPClassificationModel(AbstractVariationalGP):
    def __init__(self, train_x):
        variational_distribution = CholeskyVariationalDistribution(train_x.size(0))
        variational_strategy = VariationalStrategy(self, train_x, variational_distribution)
        super(GPClassificationModel, self).__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = dkern_kernel.DeepKernelTesting([3, 32, 32], [2, 3, 4], [4, 5, 6],
                                                           1, ExReLU(multiply_by_sqrt2=True))

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        latent_pred = gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
        return latent_pred


# Initialize model and likelihood
model = GPClassificationModel(train_x)
# will change to something like gpytorch.likelihoods.SoftmaxLikelihood(num_features=1, n_classes=10)
likelihood = gpytorch.likelihoods.BernoulliLikelihood()

# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# "Loss" for GPs - the marginal log likelihood
# num_data refers to the amount of training data
mll = VariationalELBO(likelihood, model, train_y.numel())

training_iter = 50
for i in range(training_iter):
    # Zero backpropped gradients from previous iteration
    optimizer.zero_grad()
    # Get predictive output
    output = model(train_x)
    # Calc loss and backprop gradients
    loss = -mll(output, train_y)
    print('Iter %d/%d - Loss: %.3f' % (i + 1, training_iter, loss.item()))
    loss.backward()
    optimizer.step()
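

# --- Not part of the original gist: a hedged sketch of the standard gpytorch prediction step. ---
# It reuses the Bernoulli likelihood from above, which only really makes sense once the labels
# are binary (or once the likelihood is swapped for the softmax one mentioned in the comment above).
# test_x is prepared the same way as train_x; taking the first 5 test images is arbitrary.
model.eval()
likelihood.eval()
test_x = torch.tensor(testset.test_data).type(torch.FloatTensor).permute(0, 3, 1, 2)[:5]
with torch.no_grad():
    observed_pred = likelihood(model(test_x))
    print(observed_pred.mean)  # predictive probabilities for the 5 held-out images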