functional-style jax models (take 3)
# ===== linear_regression.py =====
from jax import random, jit
import jax.numpy as jnp
from jax.scipy import stats

from util import (
    ravelize_function,
    make_log_density,
    constrain,
    positive,
    real,
    spec_to_pytree,
)
# These are the primary exports of this module:
__all__ = [
    "log_density",
    "log_density_vec",
    "parameter_spec",
    "generated_quantities",
    "generated_quantities_vec",
]
# data is closed over, so we need some fake data
x = jnp.array([[1.0, 2.0, 3.0], [0.2, 0.1, 0.4]]).T  # 3x2
y = jnp.array([2.1, 3.7, 6.5])

## Parameter definitions
# This can be partial (or even entirely omitted) if you do not require any reshaping
# utilities (ravelize_function, spec_to_pytree, init_random, etc.);
# otherwise it must cover all parameters so the shapes/dtypes come out correct.
parameter_spec = {
    "alpha": real(),
    "beta": real(shape=(x.shape[1],)),
    "sigma": positive(),
}
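# For reference (sketch, not in the original gist): spec_to_pytree(parameter_spec)
# is {"alpha": zeros(()), "beta": zeros((2,)), "sigma": zeros(())}, so the
# flattened (unconstrained) parameter vector used below has length 1 + 2 + 1 = 4.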
# log density components
# model (for reference): y_n ~ Normal(alpha + x_n . beta, sigma)
def log_prior(alpha, beta, sigma):
    lp_alpha = jnp.sum(stats.norm.logpdf(alpha, loc=0.0, scale=1.0))
    lp_beta = jnp.sum(stats.norm.logpdf(beta, loc=0.0, scale=1.0))
    # note: SciPy's exponential is parameterized by scale = 1/rate
    lp_sigma = jnp.sum(stats.expon.logpdf(sigma, scale=1.0))
    return lp_alpha + lp_beta + lp_sigma


def log_likelihood(alpha, beta, sigma):
    mu = alpha + x @ beta
    return jnp.sum(stats.norm.logpdf(y, loc=mu, scale=sigma))
# a log density function
log_density = make_log_density(log_prior, log_likelihood, parameter_spec=parameter_spec)
# We can also provide a flattened version, automatically,
# using the structure of the parameters defined above.
log_density_vec = ravelize_function(log_density, spec_to_pytree(parameter_spec))
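# Example (illustrative sketch, not part of the original gist): the pytree and
# flat versions agree at the same point, e.g.
#   theta = {"alpha": jnp.zeros(()), "beta": jnp.zeros(2), "sigma": jnp.zeros(())}
#   theta_flat, _ = jax.flatten_util.ravel_pytree(theta)   # shape (4,)
#   log_density(theta) == log_density_vec(theta_flat)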
# we might also want something like "generated quantities"
@jit
def generated_quantities(rng, x_new, **params):
    constrained, _ = constrain(parameter_spec, **params)
    alpha, beta, sigma = constrained["alpha"], constrained["beta"], constrained["sigma"]
    mu_new = alpha + x_new @ beta
    # one posterior-predictive draw per new observation
    y_new = mu_new + sigma * random.normal(rng, shape=mu_new.shape)
    return {"alpha": alpha, "beta": beta, "sigma": sigma, "y_new": y_new}
# and a flattened version
@jit
def generated_quantities_vec(rng, x_new, params_vec):
    gq = lambda param_dict: generated_quantities(rng, x_new, **param_dict)
    return ravelize_function(gq, spec_to_pytree(parameter_spec))(params_vec)
# ===== sampling script =====
import functools

import blackjax
import jax
import jax.numpy as jnp

from util import init_random
from linear_regression import (
    log_density,
    generated_quantities,
    parameter_spec,
)
from linear_regression import log_density_vec, generated_quantities_vec
def stan_sample(log_density, initial, steps=1_000, rng_key=None):
    # completely copied from https://blackjax-devs.github.io/blackjax/examples/quickstart.html
    def inference_loop(rng_key, kernel, initial_state, num_samples):
        @jax.jit
        def one_step(state, rng_key):
            state, _ = kernel(rng_key, state)
            return state, state

        keys = jax.random.split(rng_key, num_samples)
        _, states = jax.lax.scan(one_step, initial_state, keys)
        return states

    warmup = blackjax.window_adaptation(blackjax.nuts, log_density)
    rng_key, warmup_key, sample_key = jax.random.split(rng_key, 3)
    (state, parameters), _ = warmup.run(warmup_key, initial, num_steps=steps)
    kernel = blackjax.nuts(log_density, **parameters).step
    states = inference_loop(sample_key, kernel, state, steps)
    return states
if __name__ == "__main__":
    N = 1000
    rng_key = jax.random.key(4567)
    init_key, sample_key, gq_key = jax.random.split(rng_key, 3)

    # for "generated quantities"-like behavior:
    rngs = jax.random.split(gq_key, N)
    x_new = jnp.array([0.1, 0.4])

    # sample
    init_draw = init_random(parameter_spec, init_key)
    states = stan_sample(log_density, init_draw, N, sample_key)

    # postprocess draws - constrains and does generated quantities.
    # in_axes only covers the positional (rng, x_new) arguments; the parameter
    # kwargs are mapped over their leading axis by default.
    draws = jax.vmap(generated_quantities, (0, None))(rngs, x_new, **states.position)
    print(jax.tree.map(functools.partial(jnp.mean, axis=0), draws))

    # ------------- "flat" version -------------
    # 4 unconstrained parameters: alpha (1) + beta (2) + sigma (1)
    init_draw_vec = jax.random.uniform(init_key, shape=(4,))
    states_vec = stan_sample(log_density_vec, init_draw_vec, N, sample_key)
    draws_vec = jax.vmap(generated_quantities_vec, (0, None, 0))(
        rngs, x_new, states_vec.position
    )
    # note: because generated_quantities returns a pytree, we're no longer
    # in the flattened realm
    print(jax.tree.map(functools.partial(jnp.mean, axis=0), draws_vec))
# ===== util.py =====
import jax
import jax.flatten_util  # needed for jax.flatten_util.ravel_pytree below
import jax.numpy as jnp

from typing import Mapping, Protocol


def ravelize_function(f, pytree):
    """
    Takes a function that accepts a PyTree, plus an example PyTree,
    and produces a function that accepts a flat array instead.
    """
    # note: ravel_pytree is only really safe when we
    # know all the dtypes are the same. See
    # https://jax.readthedocs.io/en/latest/_autosummary/jax.flatten_util.ravel_pytree.html
    # This is usually true in stats models
    _, unravel = jax.flatten_util.ravel_pytree(pytree)
    return lambda x: f(unravel(x))
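# For example (illustrative sketch, not in the original gist):
#   g = ravelize_function(lambda t: t["a"].sum() + t["b"],
#                         {"a": jnp.zeros(2), "b": jnp.zeros(())})
#   g(jnp.array([1.0, 2.0, 3.0]))  # same as f({"a": [1., 2.], "b": 3.}) -> 6.0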
class Shaped(Protocol):
    shape: tuple[int, ...]
    dtype: jnp.dtype


class ParameterConstraint:
    def __init__(self, shape=(), dtype=jnp.float32):
        if any(s < 0 for s in shape):
            raise ValueError("Shape dimensions must be non-negative")
        self.shape = shape
        self.dtype = dtype

    def __call__(self, x):
        return x

    def inverse(self, y):
        return y

    def jacobian(self, _):
        return 0.0


# simple alias: the base class is the identity transform (unconstrained reals)
real = ParameterConstraint
# basic example of a positive constraint
class positive(ParameterConstraint):
    def __call__(self, x):
        return jnp.exp(x)

    def inverse(self, y):
        return jnp.log(y)

    def jacobian(self, x):
        # log |d exp(x) / dx| = x, summed so the log-Jacobian is a scalar
        return jnp.sum(x)
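# A hedged sketch (not in the original gist) of how another constraint could
# follow the same pattern: a lower bound at `lb`, mapping R -> (lb, inf).
class lower_bound(ParameterConstraint):
    def __init__(self, lb=0.0, **kwargs):
        super().__init__(**kwargs)
        self.lb = lb

    def __call__(self, x):
        # constrain: y = lb + exp(x)
        return self.lb + jnp.exp(x)

    def inverse(self, y):
        # unconstrain: x = log(y - lb)
        return jnp.log(y - self.lb)

    def jacobian(self, x):
        # log |dy/dx| = x, summed to a scalar
        return jnp.sum(x)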
# note: shape will need some care for something like a simplex,
# which should also include axes=... in its definition. The shape
# should be that of the unconstrained parameter.
class simplex(ParameterConstraint):
    def __init__(self, shape, axes=..., **kwargs):
        self.axes = axes
        # the unconstrained parameter has one fewer element along each simplex axis
        shape_n = jnp.array(shape, dtype=int)
        shape_n = shape_n.at[axes,].set(shape_n[axes,] - 1)
        super().__init__(shape=tuple(shape_n), **kwargs)

    def __call__(self, x):
        raise NotImplementedError("Simplex constraint not yet implemented")

    def inverse(self, y):
        raise NotImplementedError("Simplex constraint not yet implemented")

    def jacobian(self, x):
        raise NotImplementedError("Simplex constraint not yet implemented")
def spec_to_pytree(
    parameter_spec: Mapping[str, Shaped],
) -> dict[str, jnp.ndarray]:
    """
    Turns a dictionary from parameter names to shapes/dtypes
    into a pytree of zeros with those shapes/dtypes.

    Note that the Shaped protocol is satisfied by ParameterConstraint
    objects, but also by something like a jax.numpy array.
    """
    return {k: jnp.zeros(v.shape, dtype=v.dtype) for k, v in parameter_spec.items()}
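# For example (sketch): spec_to_pytree({"beta": real(shape=(2,)), "sigma": positive()})
# returns {"beta": zeros((2,), float32), "sigma": zeros((), float32)}.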
def constrain(parameter_spec: Mapping[str, ParameterConstraint], **kwargs):
    """
    Constrain parameters according to the provided spec and compute the log Jacobian.
    Anything missing from the spec is assumed to have no constraints.
    """
    jacobian = 0.0
    parameters = {}
    for param in kwargs:
        if param in parameter_spec:
            parameters[param] = parameter_spec[param](kwargs[param])
            jacobian += parameter_spec[param].jacobian(kwargs[param])
        else:
            parameters[param] = kwargs[param]
    return parameters, jacobian
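# For instance (sketch): with spec = {"sigma": positive()},
#   constrain(spec, sigma=jnp.array(0.0), alpha=jnp.array(1.5))
# returns ({"sigma": 1.0, "alpha": 1.5}, 0.0) -- alpha passes through untouched.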
def unconstrain(parameter_spec: Mapping[str, ParameterConstraint], **kwargs):
    """
    Inverse of constrain: given constrained parameters, return unconstrained versions.
    Anything missing from the spec is assumed to have no constraints.
    """
    parameters = {}
    for param in kwargs:
        if param in parameter_spec:
            parameters[param] = parameter_spec[param].inverse(kwargs[param])
        else:
            parameters[param] = kwargs[param]
    return parameters
# version that assumes data is closed over in
# the passed-in functions.
# Could easily change to pass data explicitly later.
def make_log_density(
    log_prior,
    log_likelihood,
    parameter_spec: Mapping[str, ParameterConstraint] = dict(),
):
    """
    Make a log_density function from a log_prior, a log_likelihood,
    and (optionally) a parameter spec describing how to constrain parameters.

    Parameters
    ----------
    log_prior : function
        This function will be passed the parameters
        by name, and should return the log of the prior density.
    log_likelihood : function
        This function will be passed the parameters
        by name, and should return the log of the likelihood.
    parameter_spec : dict
        A dictionary mapping parameter names to ParameterConstraint
        objects.

    Returns
    -------
    function
        A function that computes the log density of the model.
    """

    @jax.jit
    def log_density(unc_params):
        params, log_det_jac = constrain(parameter_spec, **unc_params)
        return log_det_jac + log_prior(**params) + log_likelihood(**params)

    return log_density
# similar to a solution found at https://github.com/jax-ml/jax/discussions/9508#discussioncomment-2144076,
# but uses ravel_pytree to avoid needing to split the key
def init_random(parameter_spec, rng_key, radius=2):
    """
    Given a parameter spec and a random key, return a pytree with the spec's structure
    but with each leaf filled with random uniform values in the range [-radius, radius].
    """
    d, unravel = jax.flatten_util.ravel_pytree(spec_to_pytree(parameter_spec))
    uniforms = jax.random.uniform(rng_key, shape=d.shape, minval=-radius, maxval=radius)
    return unravel(uniforms)
def init(parameter_spec, parameters):
    return unconstrain(parameter_spec, **parameters)
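# For example (sketch): with a spec containing {"sigma": positive()},
#   init(spec, {"sigma": jnp.array(2.0)})  ->  {"sigma": log(2.0)}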