Skip to content

Instantly share code, notes, and snippets.

@abikoushi
Created February 17, 2026 01:30
Show Gist options
  • Select an option

  • Save abikoushi/faede902e5a69ae6ccc65748772e15da to your computer and use it in GitHub Desktop.

Select an option

Save abikoushi/faede902e5a69ae6ccc65748772e15da to your computer and use it in GitHub Desktop.
Gibbs sampler for Bayesian logistic regression via Polya-Gamma data augmentation
library(BayesLogit)
library(ggplot2)
library(dplyr)
#########
#Bayesian inference for logistic models using Polya-Gamma latent variables (2013)
#Nicholas G. Polson, James G. Scott, Jesse Windle
#https://arxiv.org/abs/1205.0310
#########
# Gibbs sampler for Bayesian binomial logistic regression using Polya-Gamma
# data augmentation (Polson, Scott & Windle, 2013).
#
# Model: Y[i] ~ Binomial(M[i], plogis(X[i, ] %*% w)), prior w ~ N(0, Lambda^-1)
# with prior precision Lambda = diag(lambda, ncol(X)).
#
# Y:      response vector of success counts, length nrow(X).
# X:      explanatory design matrix (N x D).
# M:      number of trials; a scalar or a vector of length N.
# iter:   number of Gibbs iterations (non-negative); iter = 0 returns a
#         0 x D matrix.
# lambda: prior precision; a scalar or a length-D vector giving the diagonal
#         of Lambda.
# Returns an iter x D matrix whose i-th row is the draw of w at iteration i.
gibbs_mlogit <- function(Y, X, M, iter = 1000, lambda = 1) {
  N <- length(Y)
  D <- ncol(X)
  stopifnot(
    nrow(X) == N,
    length(M) == 1L || length(M) == N,
    length(iter) == 1L, iter >= 0
  )
  # kappa = Y - M / 2: conditional on omega, the augmented likelihood is
  # Gaussian in w with this "pseudo-response".
  ydif <- Y - 0.5 * M
  # t(X) %*% kappa does not change between iterations, so compute it once.
  Xt_ydif <- crossprod(X, ydif)
  Lambda <- diag(lambda, D)
  W_hist <- matrix(0, nrow = iter, ncol = D)
  W_tilde <- rnorm(D)
  for (i in seq_len(iter)) {
    eta <- drop(X %*% W_tilde)
    # omega[i] ~ PG(M[i], eta[i])
    omega <- rpg(N, M, eta)
    # Posterior precision t(X) %*% diag(omega) %*% X + Lambda;
    # X * omega scales row i of X by omega[i] without forming diag(omega).
    Vinv <- crossprod(X, X * omega) + Lambda
    # Upper-triangular Cholesky factor: Vinv = t(U) %*% U.
    U <- chol(Vinv)
    # Posterior mean mu = solve(Vinv, t(X) %*% ydif), done as two
    # triangular solves: t(U) a = X'kappa, then U mu = a.
    mu <- backsolve(U, backsolve(U, Xt_ydif, transpose = TRUE))
    # mu + solve(U, z) with z ~ N(0, I) has covariance
    # U^-1 U^-T = solve(Vinv), i.e. an exact draw from the conditional.
    W_tilde <- drop(mu + backsolve(U, rnorm(D)))
    W_hist[i, ] <- W_tilde
  }
  W_hist
}
# ---- Simulated example ----
# Draw 50 standard-normal covariates, prepend an intercept column, and
# generate binomial counts (100 trials each) from the known coefficients W.
set.seed(1234)
W <- c(2, 0.5)
x <- rnorm(n = 50, mean = 0, sd = 1)
X <- cbind(1, x)
prob <- plogis(X %*% W)
Y <- rbinom(n = nrow(X), size = 100, prob = prob)

# Run the sampler for 2000 iterations with prior precision lambda = 1.
out <- gibbs_mlogit(Y, X, M = 100, iter = 2000, lambda = 1)

# Trace plot of every coefficient; dashed grey lines mark the true values.
# png("traceline1.png")
matplot(out, type = "l")
abline(h = W, lty = 2, col = "darkgrey")
# dev.off()
# Unnormalized log posterior of binomial logistic regression with an
# independent N(0, 1 / lambda) prior on each coefficient (the same
# zero-mean, precision-lambda prior that gibbs_mlogit() uses).
#
# beta:   coefficient vector, length ncol(design).
# y:      observed success counts, length nrow(design).
# M:      number of trials (scalar or length nrow(design)).
# design: design matrix; defaults to the global `X` so existing calls that
#         relied on it keep working.
# lambda: prior precision (scalar or length(beta)); the default of 1 gives
#         the original N(0, 1) prior.
# Returns a single numeric log-density value.
lp_binom <- function(beta, y, M, design = X, lambda = 1) {
  eta <- drop(design %*% beta)
  log_lik <- sum(dbinom(y, M, plogis(eta), log = TRUE))
  log_prior <- sum(dnorm(beta, mean = 0, sd = 1 / sqrt(lambda), log = TRUE))
  log_lik + log_prior
}
# Evaluate the unnormalized posterior density on a grid around the true
# coefficients (W = c(2, 0.5)) for the contour overlay. The log posterior is
# shifted by its maximum before exponentiating, so p lies in (0, 1] and
# cannot underflow to zero when the log-likelihood is large in magnitude.
# The temporary `lp` column is removed, so dfpost keeps columns b1, b2, p.
dfpost <- expand.grid(
  b1 = seq(1.5, 2.5, by = 0.005),
  b2 = seq(0, 1, by = 0.005)
) %>%
  rowwise() %>%
  mutate(lp = lp_binom(c(b1, b2), y = Y, M = 100)) %>%
  ungroup() %>%
  mutate(p = exp(lp - max(lp)), lp = NULL)
# Discard the first 100 draws as burn-in, then overlay the retained posterior
# samples on contours of the grid-evaluated posterior density.
burnin <- seq_len(100)
# drop = FALSE keeps a matrix even if only one column/row would remain.
dfrand <- data.frame(out[-burnin, , drop = FALSE]) %>%
  setNames(c("b1", "b2"))
p1 <- ggplot(data = dfrand, aes(x = b1, y = b2)) +
  geom_point(alpha = 0.1) +
  geom_contour(
    data = dfpost,
    aes(z = p, colour = after_stat(level)),
    show.legend = FALSE
  ) +
  scale_color_viridis_c() +
  theme_bw(16)
print(p1)
ggsave(filename = "contour.png", plot = p1, width = 7, height = 7)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment