Created
March 7, 2026 00:34
-
-
Save swo/06c57d4bede1590351ce0d6136ce7657 to your computer and use it in GitHub Desktop.
Particle MCMC example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import altair as alt | |
| import numpy as np | |
| import numpy.random | |
| import polars as pl | |
| import scipy.stats | |
| from scipy.stats import distributions as dist | |
class Sampler:
    """Particle-marginal Metropolis-Hastings (PMMH) sampler for the noise scale sigma.

    Model implemented by the methods below:

        x_0 ~ Uniform(0, 1000)
        log x_t = log x_{t-1} + Normal(0, sigma^2)
        y_t | x_t ~ Poisson(x_t)
        sigma ~ Uniform(0.1, 3.1)

    `smc` runs a bootstrap particle filter whose weights give an estimate of
    p(y | sigma) (see `amll`); that estimate drives a Metropolis-Hastings chain
    over sigma with a log-normal random-walk proposal. The whole chain is run
    by the constructor.

    Attributes filled by the constructor:
        sigmas, particles, weights: the initial state followed by each
            *accepted* state, so ``len(sigmas) - 1`` proposals were accepted.
        sigma_chain: the full MCMC trace (initial value plus one entry per
            iteration; a rejected proposal repeats the current value). Use
            this, not ``sigmas``, for posterior summaries.
        n_accepted: number of accepted proposals.
    """

    def __init__(
        self,
        y: np.ndarray,
        n_particles: int,
        n_mh_iter: int,
        rng: "np.random.Generator | None" = None,
    ):
        """Store the data and settings, then run `n_mh_iter` MH iterations.

        Args:
            y: observed counts, one per time step.
            n_particles: number of particles in each particle-filter run.
            n_mh_iter: number of Metropolis-Hastings iterations.
            rng: generator used for every random draw; a fresh unseeded
                generator when None. Pass a seeded one for reproducible runs.
        """
        self.y = y
        self.n_particles = n_particles
        self.n_mh_iter = n_mh_iter
        self.T = len(self.y)
        self.rng = np.random.default_rng() if rng is None else rng
        self.prior_x0 = dist.uniform(loc=0.0, scale=1000.0)
        self.prior_sigma = dist.uniform(loc=1e-1, scale=3.0)
        self.proposal_scale = 0.1
        # initialize mh from a prior draw
        sigma0 = self.prior_sigma.rvs(size=1, random_state=self.rng)[0]
        particles0, weights0 = self.smc(sigma0)
        self.sigmas = [sigma0]
        self.particles = [particles0]
        self.weights = [weights0]
        self.sigma_chain = [sigma0]
        self.n_accepted = 0
        # run mh
        for _ in range(n_mh_iter):
            self.mh_iter()

    def propose_sigma(self, sigma: float) -> float:
        """Draw a log-normal random-walk proposal.

        log sigma' = log sigma + Normal(0, proposal_scale^2). This proposal is
        not symmetric in sigma; `log_accept_ratio` adds the matching Hastings
        term log(sigma') - log(sigma).
        """
        assert isinstance(sigma, float) or (
            isinstance(sigma, np.ndarray) and len(sigma) == 1
        )
        step = dist.norm(loc=0.0, scale=self.proposal_scale).rvs(
            size=1, random_state=self.rng
        )[0]
        return np.exp(np.log(sigma) + step)

    def transition(self, x: np.ndarray, sigma: float) -> np.ndarray:
        """Advance every particle one step of the log-normal random walk."""
        noise = dist.norm(loc=0.0, scale=sigma).rvs(size=len(x), random_state=self.rng)
        return np.exp(np.log(x) + noise)

    def obs_prob(self, y: float, x: np.ndarray) -> np.ndarray:
        """Poisson observation likelihood p(y | x) for each particle value in x."""
        return dist.poisson(mu=x).pmf(y)

    def amll(self, w: np.ndarray) -> float:
        """Approximate marginal log likelihood: log p(y | theta).

        Bootstrap-filter estimator sum_t log(mean_i w[t, i]). Returns -inf when
        some time step has zero total weight (the likelihood estimate is 0).
        """
        # check weights shape
        assert w.ndim == 2
        assert w.shape == (self.T, self.n_particles)
        # log(0) -> -inf is the intended result for a zero-weight row.
        with np.errstate(divide="ignore"):
            return np.sum(np.log(np.sum(w, axis=1) / self.n_particles))

    def log_accept_ratio(
        self,
        new_sigma: float,
        new_weights: np.ndarray,
        last_sigma: float,
        last_weights: np.ndarray,
    ) -> float:
        """Log Metropolis-Hastings ratio for moving from last_sigma to new_sigma.

        Estimated log posterior difference plus the Hastings correction
        log(new_sigma / last_sigma). The correction is needed because the
        log-normal proposal's forward and reverse densities differ by that
        Jacobian factor. Returns -inf when the proposed state has zero
        estimated posterior.
        """
        log_target_new = self.amll(new_weights) + self.prior_sigma.logpdf(new_sigma)
        if log_target_new == -np.inf:
            # Never accept; also avoids the NaN from (-inf) - (-inf).
            return -np.inf
        log_target_last = self.amll(last_weights) + self.prior_sigma.logpdf(last_sigma)
        return (
            log_target_new
            - log_target_last
            + np.log(new_sigma)
            - np.log(last_sigma)
        )

    def mh_iter(self):
        """Run one Metropolis-Hastings step and record the resulting state."""
        last_sigma = self.sigmas[-1]
        last_weights = self.weights[-1]
        new_sigma = self.propose_sigma(last_sigma)
        # A proposal outside the prior's support can never be accepted, so the
        # (expensive) particle filter is skipped for it.
        if np.isfinite(self.prior_sigma.logpdf(new_sigma)):
            new_particles, new_weights = self.smc(new_sigma)
            log_alpha = self.log_accept_ratio(
                new_sigma, new_weights, last_sigma, last_weights
            )
            # Accept with probability min(1, exp(log_alpha)), compared in log
            # space to avoid overflow. 1 - U is in (0, 1], so its log is finite.
            if np.log1p(-self.rng.random()) < log_alpha:
                self.sigmas.append(new_sigma)
                self.particles.append(new_particles)
                self.weights.append(new_weights)
                self.n_accepted += 1
        # On rejection the chain stays put, so the current value is repeated.
        self.sigma_chain.append(self.sigmas[-1])

    def smc(self, sigma: float):
        """Run a bootstrap particle filter for a fixed sigma.

        Returns (particles, weights), both of shape (T, n_particles). Row t
        holds the particle values and their unnormalized Poisson weights at
        time t. Before each propagation, whole particle histories are
        resampled by the previous step's weights, so each column is an
        ancestral trajectory. The final row's weights are not used for
        resampling.
        """
        n = self.n_particles
        x = self.prior_x0.rvs(size=n, random_state=self.rng)
        # `particles` has time on the rows (first axis) and particles on the columns
        particles = x[np.newaxis, :]
        weights = self.obs_prob(self.y[0], x)[np.newaxis, :]
        for t in range(1, self.T):
            last_weights = weights[-1, :]
            total = last_weights.sum()
            if total > 0:
                probs = last_weights / total
            else:
                # Every particle has zero likelihood, so the estimate is already
                # exactly 0 (amll -> -inf). Resample uniformly only to keep the
                # filter running with well-defined shapes.
                probs = np.full(n, 1.0 / n)
            counts = scipy.stats.multinomial(n=n, p=probs).rvs(
                size=1, random_state=self.rng
            )[0]
            # Particle i repeated counts[i] times, in index order.
            ancestors = np.repeat(np.arange(n), counts)
            particles = particles[:, ancestors]
            x = self.transition(x=particles[-1, :], sigma=sigma)
            particles = np.vstack((particles, x))
            weights = np.vstack((weights, self.obs_prob(y=self.y[t], x=x)))
        assert particles.shape == (self.T, n)
        assert weights.shape == particles.shape
        return particles, weights
# generate data
# NOTE: `rng` must stay module-level; Sampler.mh_iter may read it as a global.
rng = numpy.random.default_rng(42)
T = 10
sigma = 0.5
x0 = 100.0
# Latent log-normal random walk starting at x0. Prepending a zero step gives
# x (and therefore y) T + 1 entries: x[0] == x0, followed by T random-walk steps.
z = rng.normal(loc=0.0, scale=sigma, size=T)
x = np.exp(np.log(x0) + np.cumsum(np.concat((np.zeros(1), z))))
y = rng.poisson(x)

sampler = Sampler(y=y, n_particles=int(1e4), n_mh_iter=100)
# sampler.sigmas holds the initial value followed by each accepted proposal;
# subtract 1 so the initial value is not counted as an acceptance.
print("acceptance ratio: ", (len(sampler.sigmas) - 1) / sampler.n_mh_iter)
print("accepted sigmas", np.array(sampler.sigmas))
print("y", y)

# NOTE(review): `sigmas` holds only the initial and accepted states (a rejected
# step does not repeat the current value), so this histogram is not weighted by
# how long the chain stayed at each value. It is not the full MCMC trace.
alt.Chart(pl.DataFrame({"sigma": sampler.sigmas})).mark_bar().encode(
    alt.X("sigma", bin=True), alt.Y("count()")
).save("tmp_sigma.svg")

# Plot the first 10 particle trajectories from the last accepted state.
# Slice before building the frame rather than materializing every particle
# column and filtering afterwards. orient="row" makes rows = time steps
# explicitly instead of relying on polars' orientation inference.
particle_data = (
    pl.DataFrame(sampler.particles[-1][:, :10], orient="row")
    .with_row_index("time")
    .unpivot(index="time", variable_name="particle_id")
    .with_columns(pl.col("particle_id").str.replace("^column_", "").cast(pl.Int64))
)
alt.Chart(particle_data).mark_line().encode(
    alt.X("time"), alt.Y("value"), alt.Detail("particle_id")
).save("tmp_particles.svg")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment