functional-style jax models (take 3)
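Three pieces: linear_regression.py defines the model as plain log-prior / log-likelihood functions plus a parameter spec; a driver script samples it with BlackJAX's NUTS, in both pytree and flattened form; and util.py provides the constraint and (un)raveling utilities.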
linear_regression.py:
from jax import random, jit
import jax.numpy as jnp
from jax.scipy import stats

from util import (
    ravelize_function,
    make_log_density,
    constrain,
    positive,
    real,
    spec_to_pytree,
)
# These are the primary exports of this module:
__all__ = [
    "log_density",
    "log_density_vec",
    "parameter_spec",
    "generated_quantities",
    "generated_quantities_vec",
]
# data is closed over, so we need some fake data
x = jnp.array([[1.0, 2.0, 3.0], [0.2, 0.1, 0.4]]).T  # 3x2
y = jnp.array([2.1, 3.7, 6.5])

## Parameter definitions
# This can be partial (or even entirely omitted) if you do not require any
# reshaping utilities (ravelize_function, spec_to_pytree, init_random, etc.);
# otherwise it must cover all parameters to get the shapes/dtypes correct.
parameter_spec = {
    "alpha": real(),
| "beta": real(shape=x.shape[1]), | |
| "sigma": positive(), | |
| } | |
| # log density components | |
| def log_prior(alpha, beta, sigma): | |
| lp_alpha = jnp.sum(stats.norm.logpdf(alpha, loc=0.0, scale=1.0)) | |
| lp_beta = jnp.sum(stats.norm.logpdf(beta, loc=0.0, scale=1.0)) | |
| # "scale" is rate of exponential distribution (bad SciPy) | |
| lp_sigma = jnp.sum(stats.expon.logpdf(sigma, scale=1.0)) | |
    return lp_alpha + lp_beta + lp_sigma


def log_likelihood(alpha, beta, sigma):
    mu = alpha + x @ beta
    return jnp.sum(stats.norm.logpdf(y, loc=mu, scale=sigma))


# a log density function
log_density = make_log_density(log_prior, log_likelihood, parameter_spec=parameter_spec)

# We can also provide a flattened version, automatically,
# using the structure of the parameters defined above.
log_density_vec = ravelize_function(log_density, spec_to_pytree(parameter_spec))


# we might also want something like "generated quantities"
@jit
def generated_quantities(rng, x_new, **params):
    constrained, _ = constrain(parameter_spec, **params)
    alpha, beta, sigma = constrained["alpha"], constrained["beta"], constrained["sigma"]
    mu_new = alpha + x_new @ beta
    # one noise draw per predicted mean (using x_new.shape here would draw
    # one value per covariate rather than per prediction)
    y_new = mu_new + sigma * random.normal(rng, shape=jnp.shape(mu_new))
| return {"alpha": alpha, "beta": beta, "sigma": sigma, "y_new": y_new} | |
| # and a flattened version | |
| @jit | |
| def generated_quantities_vec(rng, x_new, params_vec): | |
| gq = lambda param_dict: generated_quantities(rng, x_new, **param_dict) | |
| return ravelize_function(gq, spec_to_pytree(parameter_spec))(params_vec) |
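As a quick sanity check, the dict and flat interfaces agree at the same unconstrained point (a minimal sketch, run as a separate script; the parameter values are arbitrary):

import jax.flatten_util
import jax.numpy as jnp
from linear_regression import log_density, log_density_vec

params = {"alpha": jnp.array(0.5), "beta": jnp.array([0.1, -0.2]), "sigma": jnp.array(0.3)}
# shape (4,); dict keys flatten in sorted order, matching spec_to_pytree
flat, _ = jax.flatten_util.ravel_pytree(params)
assert jnp.allclose(log_density(params), log_density_vec(flat))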
The sampling driver:
import functools

import blackjax
import jax
import jax.numpy as jnp

from util import init_random
from linear_regression import (
    log_density,
    log_density_vec,
    generated_quantities,
    generated_quantities_vec,
    parameter_spec,
)
def stan_sample(log_density, initial, steps=1_000, rng_key=None):
    if rng_key is None:
        rng_key = jax.random.key(0)  # fall back to a fixed seed

    # completely copied from https://blackjax-devs.github.io/blackjax/examples/quickstart.html
    def inference_loop(rng_key, kernel, initial_state, num_samples):
        @jax.jit
        def one_step(state, rng_key):
            state, _ = kernel(rng_key, state)
            return state, state

        keys = jax.random.split(rng_key, num_samples)
        _, states = jax.lax.scan(one_step, initial_state, keys)
        return states

    warmup = blackjax.window_adaptation(blackjax.nuts, log_density)
    rng_key, warmup_key, sample_key = jax.random.split(rng_key, 3)
    (state, parameters), _ = warmup.run(warmup_key, initial, num_steps=steps)
    kernel = blackjax.nuts(log_density, **parameters).step
    states = inference_loop(sample_key, kernel, state, steps)
    return states
if __name__ == "__main__":
    N = 1000
    rng_key = jax.random.key(4567)
    init_key, sample_key, gq_key = jax.random.split(rng_key, 3)

    # for "generated quantities"-like behavior:
    rngs = jax.random.split(gq_key, N)
    x_new = jnp.array([0.1, 0.4])

    # sample
    init_draw = init_random(parameter_spec, init_key)
    states = stan_sample(log_density, init_draw, N, sample_key)

    # postprocess draws - constrains and does generated quantities.
    # in_axes covers the positional args; the parameters passed as
    # keyword arguments are vmapped over their leading axis
    draws = jax.vmap(generated_quantities, (0, None))(rngs, x_new, **states.position)
    print(jax.tree.map(functools.partial(jnp.mean, axis=0), draws))
| # ------------- "flat" version ------------- | |
| init_draw_vec = jax.random.uniform(init_key, shape=(4,)) | |
| states_vec = stan_sample(log_density_vec, init_draw_vec, N, sample_key) | |
| draws_vec = jax.vmap(generated_quantities_vec, (0, None, 0))( | |
| rngs, x_new, states_vec.position | |
| ) | |
| # note: because generated_quantities returns a pytree, we're no longer | |
| # in the flattened realm | |
| print(jax.tree.map(functools.partial(jnp.mean, axis=0), draws_vec)) |
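To recover named (still unconstrained) draws from the flat chain without running generated quantities, the unravel half of ravel_pytree can itself be vmapped over the (N, 4) positions. A sketch, assuming states_vec from the script above:

import jax
import jax.flatten_util
from util import spec_to_pytree
from linear_regression import parameter_spec

_, unravel = jax.flatten_util.ravel_pytree(spec_to_pytree(parameter_spec))
named_draws = jax.vmap(unravel)(states_vec.position)  # dict of (N,)- and (N, 2)-shaped arrays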
util.py:
import jax
import jax.flatten_util  # needs an explicit import; ravel_pytree lives in a submodule
import jax.numpy as jnp
from typing import Mapping, Protocol
def ravelize_function(f, pytree):
    """
    Given a function that accepts a pytree, and an example pytree with the
    desired structure, produce a function that accepts a flat array instead.
    """
    # note: ravel_pytree is only really safe when we
    # know all the dtypes are the same. See
    # https://jax.readthedocs.io/en/latest/_autosummary/jax.flatten_util.ravel_pytree.html
    # This is usually true in stats models
    _, unravel = jax.flatten_util.ravel_pytree(pytree)
    return lambda x: f(unravel(x))
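# Example (a sketch): with pytree {"a": jnp.zeros(2), "b": jnp.zeros(())},
# ravelize_function(f, pytree)(jnp.array([1.0, 2.0, 3.0])) calls
# f({"a": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)});
# dict keys are flattened in sorted order.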
class Shaped(Protocol):
    shape: tuple[int, ...]
    dtype: jnp.dtype


class ParameterConstraint:
    def __init__(self, shape=(), dtype=jnp.float32):
        if any(s < 0 for s in shape):
            raise ValueError("Shape dimensions must be non-negative")
        self.shape = shape
        self.dtype = dtype

    def __call__(self, x):
        return x

    def inverse(self, y):
        return y

    def jacobian(self, _):
        return 0.0
# simple alias: the base class is the identity transform
real = ParameterConstraint
# basic example of a positive constraint
class positive(ParameterConstraint):
    def __call__(self, x):
        return jnp.exp(x)

    def inverse(self, y):
        return jnp.log(y)

    def jacobian(self, x):
        # log|d exp(x)/dx| = x; summed so the contribution stays a
        # scalar even for array-shaped parameters
        return jnp.sum(x)
# note: shape will need some care for something like a simplex,
# which should also include axes=... in its definition. The stored
# shape should be that of the unconstrained parameter (one fewer
# element along each simplex axis).
class simplex(ParameterConstraint):
    def __init__(self, shape, axes=..., **kwargs):
        self.axes = axes
        # convert the constrained shape to the unconstrained one
        shape_n = jnp.array(shape, dtype=int)
        shape_n = shape_n.at[axes,].set(shape_n[axes,] - 1)
        super().__init__(shape=tuple(shape_n), **kwargs)

    def __call__(self, x):
        raise NotImplementedError("Simplex constraint not yet implemented")

    def inverse(self, y):
        raise NotImplementedError("Simplex constraint not yet implemented")

    def jacobian(self, x):
        raise NotImplementedError("Simplex constraint not yet implemented")
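# e.g. (shape bookkeeping only, since the transform is unimplemented):
# simplex(shape=(3,)) stores an unconstrained shape of (2,), because a
# length-3 simplex has two free parameters.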
def spec_to_pytree(
    parameter_spec: Mapping[str, Shaped],
) -> dict[str, jnp.ndarray]:
    """
    Turns a dictionary from parameter names to shapes/dtypes
    into a pytree of zeros with those shapes/dtypes.

    Note that the Shaped protocol is satisfied by ParameterConstraint
    objects, but also by something like a jax.numpy array.
    """
    return {k: jnp.zeros(v.shape, dtype=v.dtype) for k, v in parameter_spec.items()}


def constrain(parameter_spec: Mapping[str, ParameterConstraint], **kwargs):
    """
    Constrain parameters according to the provided spec and compute the log Jacobian.
    Anything missing from the spec is assumed to have no constraints.
    """
    jacobian = 0.0
    parameters = {}
    for param in kwargs:
        if param in parameter_spec:
            parameters[param] = parameter_spec[param](kwargs[param])
            jacobian += parameter_spec[param].jacobian(kwargs[param])
        else:
            parameters[param] = kwargs[param]
    return parameters, jacobian
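# e.g. with the positive() constraint above:
# constrain({"sigma": positive()}, sigma=-1.0)
# returns ({"sigma": exp(-1.0)}, -1.0) -- the value is mapped through exp
# and the log-Jacobian contribution x = -1.0 is accumulated.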
def unconstrain(parameter_spec: Mapping[str, ParameterConstraint], **kwargs):
    """
    Inverse of constrain: given constrained parameters, return unconstrained versions.
    Anything missing from the spec is assumed to have no constraints.
    """
    parameters = {}
    for param in kwargs:
        if param in parameter_spec:
            parameters[param] = parameter_spec[param].inverse(kwargs[param])
        else:
            parameters[param] = kwargs[param]
    return parameters
# version that assumes data is closed over in
# the passed-in functions.
# Could easily change to pass data explicitly later
def make_log_density(
    log_prior,
    log_likelihood,
    parameter_spec: Mapping[str, ParameterConstraint] = dict(),
):
    """
    Make a log_density function from a log_prior, a log_likelihood,
    and (optionally) a parameter spec used to constrain parameters.

    Parameters
    ----------
    log_prior : function
        This function will be passed the parameters
        by name, and should return the log of the prior density.
    log_likelihood : function
        This function will be passed the parameters
        by name, and should return the log of the likelihood.
    parameter_spec : dict
        A dictionary mapping parameter names to ParameterConstraint
        objects.

    Returns
    -------
    function
        A function that computes the log density of the model
        (including the change-of-variables Jacobian) given a dict
        of unconstrained parameters.
    """

    @jax.jit
    def log_density(unc_params):
        params, log_det_jac = constrain(parameter_spec, **unc_params)
        return log_det_jac + log_prior(**params) + log_likelihood(**params)

    return log_density
# similar to a solution found at https://github.com/jax-ml/jax/discussions/9508#discussioncomment-2144076,
# but uses ravel_pytree to avoid needing to split the key
def init_random(parameter_spec, rng_key, radius=2):
    """
    Given a parameter spec and a random key, return a pytree matching the
    spec, with each leaf filled with uniform values in [-radius, radius].
    """
    d, unravel = jax.flatten_util.ravel_pytree(spec_to_pytree(parameter_spec))
    uniforms = jax.random.uniform(rng_key, shape=d.shape, minval=-radius, maxval=radius)
    return unravel(uniforms)


def init(parameter_spec, parameters):
    return unconstrain(parameter_spec, **parameters)
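# e.g. init({"sigma": positive()}, {"sigma": 2.0}) returns {"sigma": log(2.0)},
# mapping a constrained initial value back to the unconstrained space
# that make_log_density expects.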