Skip to content

Instantly share code, notes, and snippets.

@GStechschulte
Created June 18, 2025 15:27
Show Gist options
  • Select an option

  • Save GStechschulte/c0a125e197d9d8a4444c2683987a6644 to your computer and use it in GitHub Desktop.

Select an option

Save GStechschulte/c0a125e197d9d8a4444c2683987a6644 to your computer and use it in GitHub Desktop.
Hierarchical Gaussian process model with GPyTorch
import torch
import gpytorch
import numpy as np
import matplotlib.pyplot as plt
# Fix both RNGs so the simulated data (and hence the fit) is reproducible.
torch.manual_seed(42)
np.random.seed(42)

## Simulate hierarchical Data ##
# Every group follows the same sin(x) curve on [-5, 5]; groups differ only in
# number of points, vertical offset and observation-noise standard deviation.
group_sizes = [30, 50, 20]
group_means = [0.0, 2.5, -2.5]
group_noises = [0.3, 0.4, 0.2]

all_x, all_y, all_groups = [], [], []
for group_idx, (n_points, offset, noise_sd) in enumerate(
    zip(group_sizes, group_means, group_noises)
):
    xs = torch.linspace(-5, 5, n_points)
    clean_signal = torch.sin(xs) + offset
    # One randn draw per group, in group order.
    all_x.append(xs)
    all_y.append(clean_signal + noise_sd * torch.randn(n_points))
    # The group label is stored as a float so it can share a tensor with x.
    all_groups.append(torch.full((n_points,), float(group_idx)))

# Flatten the per-group pieces into one training set.
# Column 0 of train_x is the continuous input; column 1 is the group index.
train_x_continuous = torch.cat(all_x)
train_x_groups = torch.cat(all_groups)
train_y = torch.cat(all_y)
train_x = torch.stack((train_x_continuous, train_x_groups), dim=1)
## Define and train the GPyTorch model ##
class GroupSpecificMean(gpytorch.means.Mean):
    """Mean function that gives each group its own learnable constant.

    The group of each data point is read from column ``group_dim`` (last
    dimension) of the input. Points in group ``g`` receive the value produced
    by ``self.base_means[g]``, a ``gpytorch.means.ConstantMean``.

    Args:
        num_groups: Number of groups; group indices are expected to lie in
            ``range(num_groups)``.
        group_dim: Index of the input column holding the group index.

    Note:
        Points whose group index falls outside ``range(num_groups)`` keep a
        mean of 0, because no per-group constant is written for them.
    """

    def __init__(self, num_groups, group_dim=1):
        super().__init__()
        self.num_groups = num_groups
        self.group_dim = group_dim
        # One independent ConstantMean per group.
        self.base_means = torch.nn.ModuleList(
            [gpytorch.means.ConstantMean() for _ in range(num_groups)]
        )

    def forward(self, input_tensor):
        """Return the per-point prior mean.

        Args:
            input_tensor: Input of shape ``(..., n, d)``. Column ``group_dim``
                of the last dimension holds the integer-valued group index.

        Returns:
            Tensor of shape ``(..., n)`` with the input's dtype.
        """
        # Index the last dimension rather than dim 1, so batched inputs of
        # shape (..., n, d) also select the group column correctly.
        group_indices = input_tensor[..., self.group_dim].long()
        mean_output = torch.zeros_like(group_indices, dtype=input_tensor.dtype)
        for group, base_mean in enumerate(self.base_means):
            mask = group_indices == group
            if mask.any():
                # The ConstantMean yields one value per selected row. Masked
                # assignment keeps the output differentiable with respect to
                # that group's constant.
                mean_output[mask] = base_mean(input_tensor[mask])
        return mean_output
class HierarchicalGP(gpytorch.models.ExactGP):
    """Exact GP over (x, group) inputs with a shared RBF and a group covariance.

    The prior covariance between points ``(x, g)`` and ``(x', g')`` is

        outputscale * k_rbf(x, x') * B[g, g']

    where ``B`` is the IndexKernel's learned ``num_groups x num_groups`` task
    covariance (a low-rank term plus a diagonal). Off-diagonal entries of ``B``
    let groups share information (the population-level part). The diagonal
    gives each group its own variance. The mean is a separate learned constant
    per group (``GroupSpecificMean``).

    Args:
        train_x: Training inputs. In this script they have shape ``(n, 2)``:
            column 0 is the continuous feature, column 1 the group index.
        train_y: Training targets of shape ``(n,)``.
        likelihood: GPyTorch likelihood (a ``GaussianLikelihood`` in this
            script).
        num_groups: Number of distinct groups.
        index_rank: Rank of the low-rank factor of the IndexKernel's task
            covariance. Defaults to 1, which matches the previous hard-coded
            value.
    """

    def __init__(self, train_x, train_y, likelihood, num_groups, index_rank=1):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = GroupSpecificMean(num_groups=num_groups, group_dim=1)
        rbf_kernel = gpytorch.kernels.RBFKernel(active_dims=[0])
        group_kernel = gpytorch.kernels.IndexKernel(
            num_tasks=num_groups, rank=index_rank, active_dims=[1]
        )
        # Product kernel: every group shares the same RBF lengthscale, i.e.
        # the same smoothness over x. The IndexKernel only rescales and
        # correlates covariance between groups; it does not give groups
        # differently shaped functions.
        # NOTE(review): the IndexKernel already carries its own scale, so the
        # outer ScaleKernel's outputscale is not separately identifiable from
        # B. It is kept here to preserve the model's parameters.
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf_kernel * group_kernel)

    def forward(self, x):
        """Return the GP prior as a ``MultivariateNormal`` evaluated at ``x``."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
# Build the model and its Gaussian observation-noise likelihood.
num_groups = len(group_sizes)
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = HierarchicalGP(train_x, train_y, likelihood, num_groups)

# Fit hyperparameters by maximizing the exact marginal log likelihood
# (i.e. minimizing its negative) with Adam. The model was given the
# likelihood at construction; presumably ExactGP registers it as a submodule,
# so model.parameters() also covers the noise. Confirm against GPyTorch.
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
training_iterations = 150
log_every = 10

model.train()
likelihood.train()
for step in range(1, training_iterations + 1):
    optimizer.zero_grad()
    loss = -mll(model(train_x), train_y)
    loss.backward()
    if step % log_every == 0:
        print(f"Iter {step}/{training_iterations} - Loss: {loss.item():.3f}")
    optimizer.step()
## Make predictions ##
model.eval()
likelihood.eval()

# Grid of test points for the continuous feature. It is slightly wider than
# the training range [-5, 5], so the plot also shows some extrapolation.
test_x_continuous = torch.linspace(-6, 6, 100)

fig, ax = plt.subplots(1, 1, figsize=(12, 8))
colors = ['#1f77b4', '#ff7f0e', '#2ca02c']

# Plot the original training data, one color per group. Colors cycle so
# that more groups than palette entries do not raise IndexError.
for i in range(num_groups):
    color = colors[i % len(colors)]
    mask = train_x[:, 1] == i
    ax.scatter(
        train_x[mask, 0].detach().numpy(),
        train_y[mask].detach().numpy(),
        color=color,
        marker='o',
        label=f'Group {i} Data'
    )

# Plot the predictive mean and confidence band for each group on the grid.
with torch.no_grad(), gpytorch.settings.fast_pred_var():
    for i in range(num_groups):
        color = colors[i % len(colors)]
        # Every grid point gets group index i. full_like keeps the group
        # column exactly as long as the grid, whatever its size.
        test_x_group_index = torch.full_like(test_x_continuous, float(i))
        test_x = torch.stack([test_x_continuous, test_x_group_index], dim=-1)
        # Passing the latent GP through the likelihood gives the predictive
        # distribution of observed y (latent f plus observation noise).
        predictions = likelihood(model(test_x))
        mean = predictions.mean
        # GPyTorch's confidence_region() is mean +/- 2 standard deviations.
        lower, upper = predictions.confidence_region()
        ax.plot(
            test_x_continuous.numpy(),
            mean.numpy(),
            color=color,
            linewidth=2,
            label=f'Group {i} Prediction'
        )
        ax.fill_between(
            test_x_continuous.numpy(),
            lower.numpy(),
            upper.numpy(),
            alpha=0.2,
            color=color
        )

ax.set_title('Hierarchical GP Predictions')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.legend()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment