Hierarchical Gaussian process model with GPyTorch
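Simulates three groups of noisy sinusoidal data with different vertical offsets, then fits an exact GP whose mean and covariance both depend on a group index: a group-specific constant mean absorbs each offset, while an RBF kernel multiplied by an IndexKernel shares the sinusoidal shape across groups and learns between-group correlations.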
import torch
import gpytorch
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(42)
np.random.seed(42)

## Simulate hierarchical data ##
group_sizes = [30, 50, 20]
group_means = [0.0, 2.5, -2.5]
group_noises = [0.3, 0.4, 0.2]

all_x, all_y, all_groups = [], [], []
for i, (size, mean, noise) in enumerate(zip(group_sizes, group_means, group_noises)):
    x_cont = torch.linspace(-5, 5, size)
    y = torch.sin(x_cont) + mean
    y_noisy = y + torch.randn(size) * noise
    all_x.append(x_cont)
    all_y.append(y_noisy)
    all_groups.append(torch.full((size,), float(i)))

# Combine all groups into single training tensors
train_x_continuous = torch.cat(all_x)
train_y = torch.cat(all_y)
train_x_groups = torch.cat(all_groups)
# Each row is (continuous input, group index), giving shape (100, 2)
train_x = torch.stack([train_x_continuous, train_x_groups], dim=-1)

## Define and train the GPyTorch model ##
class GroupSpecificMean(gpytorch.means.Mean):
    """A GPyTorch mean function that applies a different ConstantMean to data
    points based on their group index.
    """
    def __init__(self, num_groups, group_dim=1):
        super().__init__()
        self.num_groups = num_groups
        self.group_dim = group_dim
        # Create a list of ConstantMean modules, one for each group
        self.base_means = torch.nn.ModuleList(
            [gpytorch.means.ConstantMean() for _ in range(num_groups)]
        )

    def forward(self, input_tensor):
        # Extract group indices from the input tensor
        group_indices = input_tensor[:, self.group_dim].long()
        # Create a tensor to store the output means
        mean_output = torch.zeros_like(group_indices, dtype=input_tensor.dtype)
        # Apply the correct mean for each group
        for i in range(self.num_groups):
            mask = group_indices == i
            if mask.any():
                # Get the mean from the corresponding ConstantMean module
                group_mean = self.base_means[i](input_tensor[mask])
                mean_output[mask] = group_mean
        return mean_output
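
# Aside: a minimal sketch of a hypothetical vectorized alternative. The class
# name and its `constants` parameter are illustrative (not GPyTorch API):
# since each group mean is a single scalar, one (num_groups,) parameter
# indexed by group can replace the ModuleList and the per-group loop.
class VectorizedGroupMean(gpytorch.means.Mean):
    def __init__(self, num_groups, group_dim=1):
        super().__init__()
        self.group_dim = group_dim
        self.constants = torch.nn.Parameter(torch.zeros(num_groups))

    def forward(self, input_tensor):
        group_indices = input_tensor[:, self.group_dim].long()
        # Advanced indexing gathers each row's group constant in one shot
        return self.constants[group_indices]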

class HierarchicalGP(gpytorch.models.ExactGP):
    """A hierarchical GP model that combines kernels to capture
    population-level structure and group-specific effects.
    """
    def __init__(self, train_x, train_y, likelihood, num_groups):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = GroupSpecificMean(num_groups=num_groups, group_dim=1)
        rbf_kernel = gpytorch.kernels.RBFKernel(active_dims=[0])
        group_kernel = gpytorch.kernels.IndexKernel(
            num_tasks=num_groups, rank=1, active_dims=[1]
        )
        # Combine kernels: the RBF kernel models the continuous feature and
        # the IndexKernel models between-group covariance. Their product
        # shares a single RBF lengthscale across groups while the IndexKernel
        # learns how strongly the groups are correlated.
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf_kernel * group_kernel)

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
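
# Note: the product kernel above is an intrinsic coregionalization model.
# For inputs (x, i) and (x', j) it evaluates to
#   k((x, i), (x', j)) = outputscale * k_RBF(x, x') * B[i, j],
# where the IndexKernel parameterizes B = W @ W.T + diag(v) with W of shape
# (num_groups, rank), so rank=1 gives a rank-one coupling between groups.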

num_groups = len(group_sizes)
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = HierarchicalGP(train_x, train_y, likelihood, num_groups)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
training_iterations = 150

model.train()
likelihood.train()
for i in range(training_iterations):
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()
    if (i + 1) % 10 == 0:
        print(f"Iter {i+1}/{training_iterations} - Loss: {loss.item():.3f}")
    optimizer.step()

## Make predictions ##
model.eval()
likelihood.eval()

# Create a grid of test points for the continuous feature
test_x_continuous = torch.linspace(-6, 6, 100)

fig, ax = plt.subplots(1, 1, figsize=(12, 8))
colors = ['#1f77b4', '#ff7f0e', '#2ca02c']

# Plot the original training data
for i in range(num_groups):
    mask = train_x[:, 1] == i
    ax.scatter(
        train_x[mask, 0].detach().numpy(),
        train_y[mask].detach().numpy(),
        color=colors[i],
        marker='o',
        label=f'Group {i} Data'
    )

# Plot predictions for each group using a grid of data
with torch.no_grad(), gpytorch.settings.fast_pred_var():
    for i in range(num_groups):
        test_x_group_index = torch.full((100,), float(i))
        test_x = torch.stack([test_x_continuous, test_x_group_index], dim=-1)
        predictions = likelihood(model(test_x))
        mean = predictions.mean
        lower, upper = predictions.confidence_region()
        ax.plot(
            test_x_continuous.numpy(),
            mean.numpy(),
            color=colors[i],
            linewidth=2,
            label=f'Group {i} Prediction'
        )
        ax.fill_between(
            test_x_continuous.numpy(),
            lower.numpy(),
            upper.numpy(),
            alpha=0.2,
            color=colors[i]
        )

ax.set_title('Hierarchical GP Predictions')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.legend()
plt.show()
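
# A follow-up sketch: inspect the learned group-level quantities. Attribute
# paths (base_kernel.kernels, covar_factor, var) assume GPyTorch's
# ProductKernel/IndexKernel internals; double-check against your version.
for i, m in enumerate(model.mean_module.base_means):
    print(f"Group {i} learned constant mean: {m.constant.item():.3f}")

index_kernel = model.covar_module.base_kernel.kernels[1]
B = index_kernel.covar_factor @ index_kernel.covar_factor.t() + torch.diag(index_kernel.var)
print("Learned between-group covariance B:\n", B.detach())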