Skip to content

Instantly share code, notes, and snippets.

@swo
Created March 7, 2026 00:34
Show Gist options
  • Select an option

  • Save swo/06c57d4bede1590351ce0d6136ce7657 to your computer and use it in GitHub Desktop.

Select an option

Save swo/06c57d4bede1590351ce0d6136ce7657 to your computer and use it in GitHub Desktop.
Particle MCMC example
import altair as alt
import numpy as np
import numpy.random
import polars as pl
import scipy.stats
from scipy.stats import distributions as dist
class Sampler:
    """Particle marginal Metropolis-Hastings (PMMH) sampler over ``sigma``.

    Model, as implemented by the methods below:

    * ``x_0 ~ Uniform(0, 1000)``
    * ``log x_t = log x_{t-1} + Normal(0, sigma^2)``
    * ``y_t ~ Poisson(x_t)``
    * prior ``sigma ~ Uniform(0.1, 3.1)``

    A bootstrap particle filter (:meth:`smc`) estimates the marginal likelihood
    ``p(y | sigma)``. Metropolis-Hastings over ``sigma`` uses that estimate in
    place of the exact likelihood. The whole chain is run inside ``__init__``.

    Attributes:
        sigmas, particles, weights: the initial state followed by every
            *accepted* proposal, in order. Rejections are not recorded here.
        chain: the full MH chain of ``sigma`` values (initial state plus one
            entry per iteration). Rejected proposals repeat the current value,
            so this is the list to use for posterior summaries.
        n_accepted: number of accepted proposals.
        rng: the ``numpy.random.Generator`` used for all random draws.
    """

    def __init__(
        self,
        y: np.ndarray,
        n_particles: int,
        n_mh_iter: int,
        *,
        rng: "np.random.Generator | None" = None,
    ):
        # All randomness flows through this generator, so a seeded generator
        # makes the whole run reproducible.
        self.rng = np.random.default_rng() if rng is None else rng
        self.y = np.asarray(y)
        self.n_particles = n_particles
        self.n_mh_iter = n_mh_iter
        self.T = len(self.y)
        if self.T == 0:
            raise ValueError("y must contain at least one observation")
        if n_particles < 1:
            raise ValueError("n_particles must be positive")

        self.prior_x0 = dist.uniform(loc=0.0, scale=1000.0)
        self.prior_sigma = dist.uniform(loc=1e-1, scale=3.0)
        # standard deviation of the random-walk step on log(sigma)
        self.proposal_scale = 0.1

        # initialize MH from a draw of the prior
        sigma0 = self.prior_sigma.rvs(size=1, random_state=self.rng)[0]
        particles0, weights0 = self.smc(sigma0)
        self.sigmas = [sigma0]
        self.particles = [particles0]
        self.weights = [weights0]
        self.chain = [sigma0]
        self.n_accepted = 0

        # run MH
        for _ in range(n_mh_iter):
            self.mh_iter()

    def propose_sigma(self, sigma: float) -> float:
        """Draw ``exp(log(sigma) + eps)`` with ``eps ~ Normal(0, proposal_scale^2)``.

        The step is symmetric in ``log(sigma)`` but not in ``sigma``.
        :meth:`log_acceptance_ratio` applies the matching Hastings correction.
        """
        assert isinstance(sigma, float) or (
            isinstance(sigma, np.ndarray) and len(sigma) == 1
        )
        step = dist.norm(loc=0.0, scale=self.proposal_scale).rvs(
            size=1, random_state=self.rng
        )[0]
        return np.exp(np.log(sigma) + step)

    def transition(self, x: np.ndarray, sigma: float) -> np.ndarray:
        """Move each particle one step: ``log x_t = log x_{t-1} + Normal(0, sigma^2)``."""
        noise = dist.norm(loc=0.0, scale=sigma).rvs(size=len(x), random_state=self.rng)
        return np.exp(np.log(x) + noise)

    def obs_prob(self, y: float, x: np.ndarray) -> np.ndarray:
        """Poisson likelihood ``p(y | x)`` of one observation, for each particle in ``x``."""
        return dist.poisson(mu=x).pmf(y)

    def amll(self, w: np.ndarray) -> float:
        """Approximate marginal log likelihood ``log p(y | theta)``.

        ``w`` is the ``(T, n_particles)`` array of unnormalized weights returned by
        :meth:`smc`. The estimate is ``sum_t log(mean_i w[t, i])``. It is
        ``-inf`` if any time step has zero total weight.
        """
        # check weights shape
        assert w.ndim == 2
        assert w.shape == (self.T, self.n_particles)
        with np.errstate(divide="ignore"):
            return float(np.sum(np.log(np.mean(w, axis=1))))

    def log_acceptance_ratio(
        self,
        new_sigma: float,
        new_weights: np.ndarray,
        last_sigma: float,
        last_weights: np.ndarray,
    ) -> float:
        """Log MH acceptance ratio for moving from ``last_sigma`` to ``new_sigma``.

        Includes the Hastings correction ``log(new_sigma) - log(last_sigma)``.
        The log-scale random walk of :meth:`propose_sigma` has proposal density
        ``Normal(log s'; log s, scale) / s'`` in ``sigma``, so the reverse/forward
        proposal ratio is ``s' / s``.
        Returns ``-inf`` when ``new_sigma`` lies outside the prior's support.
        """
        log_prior_new = self.prior_sigma.logpdf(new_sigma)
        if not np.isfinite(log_prior_new):
            return -np.inf
        return float(
            self.amll(new_weights)
            + log_prior_new
            + np.log(new_sigma)
            - self.amll(last_weights)
            - self.prior_sigma.logpdf(last_sigma)
            - np.log(last_sigma)
        )

    def mh_iter(self):
        """Run one PMMH step.

        Always appends the resulting state to ``chain``. On acceptance, also
        appends to ``sigmas``/``particles``/``weights``.
        """
        last_sigma = self.sigmas[-1]
        last_weights = self.weights[-1]
        new_sigma = self.propose_sigma(last_sigma)

        if np.isfinite(self.prior_sigma.logpdf(new_sigma)):
            new_particles, new_weights = self.smc(new_sigma)
            log_alpha = self.log_acceptance_ratio(
                new_sigma, new_weights, last_sigma, last_weights
            )
        else:
            # outside the prior's support: reject without running the filter
            log_alpha = -np.inf

        # Accept with probability min(1, exp(log_alpha)). A NaN ratio yields a
        # NaN probability, the comparison is False, and the move is rejected.
        p_accept = np.exp(min(log_alpha, 0.0))
        if self.rng.uniform() < p_accept:
            self.sigmas.append(new_sigma)
            self.particles.append(new_particles)
            self.weights.append(new_weights)
            self.n_accepted += 1
        self.chain.append(self.sigmas[-1])

    def smc(self, sigma: float):
        """Run a bootstrap particle filter with multinomial resampling at fixed ``sigma``.

        Returns ``(particles, weights)``, both of shape ``(T, n_particles)`` with time
        on the first axis.

        * ``particles`` holds full ancestral paths. Earlier rows are re-indexed
          every time the population is resampled.
        * ``weights[t]`` holds the unnormalized observation likelihoods of the
          time-``t`` particles. These rows are not re-indexed.

        Raises:
            ValueError: if every particle has zero weight at a time step that
                must be resampled from.
        """
        n = self.n_particles
        particles = np.empty((self.T, n))
        weights = np.empty((self.T, n))

        particles[0] = self.prior_x0.rvs(size=n, random_state=self.rng)
        weights[0] = self.obs_prob(self.y[0], particles[0])

        for t in range(1, self.T):
            total = weights[t - 1].sum()
            if not total > 0:
                raise ValueError(
                    f"all particle weights are zero at t={t - 1}; cannot resample"
                )
            # Parent i gets counts[i] offspring. np.repeat keeps the offspring
            # grouped in parent order.
            counts = scipy.stats.multinomial(n=n, p=weights[t - 1] / total).rvs(
                size=1, random_state=self.rng
            )[0]
            ancestors = np.repeat(np.arange(n), counts)
            # Fancy indexing copies, so the in-place assignment is safe.
            particles[:t] = particles[:t, ancestors]
            particles[t] = self.transition(x=particles[t - 1], sigma=sigma)
            weights[t] = self.obs_prob(y=self.y[t], x=particles[t])

        return particles, weights
# generate data: a log-normal random walk started at x0, observed with Poisson noise
rng = numpy.random.default_rng(42)
T = 10  # number of random-walk steps; x and y each have T + 1 points
sigma = 0.5
x0 = 100.0
z = rng.normal(loc=0.0, scale=sigma, size=T)
x = np.exp(np.log(x0) + np.cumsum(np.concatenate((np.zeros(1), z))))
y = rng.poisson(x)

sampler = Sampler(y=y, n_particles=10_000, n_mh_iter=100)

# sampler.sigmas starts with the initial state, which is not an accepted proposal
n_accepted = len(sampler.sigmas) - 1
print("acceptance ratio: ", n_accepted / sampler.n_mh_iter)
print("accepted sigmas", np.array(sampler.sigmas))
print("y", y)

# NOTE: this histogram counts accepted states only. In an MH chain, rejected
# iterations repeat the current state, and those repeats are not counted here.
alt.Chart(pl.DataFrame({"sigma": sampler.sigmas})).mark_bar().encode(
    alt.X("sigma", bin=True), alt.Y("count()")
).save("tmp_sigma.svg")

# Particle paths from the last accepted filter run. The array has time on the
# rows and particles on the columns, so build the frame row-oriented.
particle_data = (
    pl.DataFrame(sampler.particles[-1], orient="row")
    .with_row_index("time")
    .unpivot(index="time", variable_name="particle_id")
    .with_columns(pl.col("particle_id").str.replace("^column_", "").cast(pl.Int64))
    .filter(pl.col("particle_id") < 10)
)
alt.Chart(particle_data).mark_line().encode(
    alt.X("time"), alt.Y("value"), alt.Detail("particle_id")
).save("tmp_particles.svg")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment