
Bayesian Inference with PFJAX

Michelle Ko, Martin Lysy 2022-05-18

\( \newcommand{\bm}[1]{\boldsymbol{#1}} \newcommand{\ud}{\mathop{}\!{\mathrm{d}}} \newcommand{\iid}{\stackrel {\mathrm{iid}}{\sim}} \newcommand{\ind}{\stackrel {\mathrm{ind}}{\sim}} \newcommand{\del}[2][]{\frac{\partial^{#1}}{\partial {#2}^{#1}}} \newcommand{\der}[2][]{\frac{\textnormal{d}^{#1}}{\textnormal{d} {#2}^{#1}}} \newcommand{\fdel}[3][]{\frac{\partial^{#1}#3}{\partial{#2}^{#1}}} \newcommand{\fder}[3][]{\frac{\textnormal{d}^{#1}#3}{\textnormal{d} {#2}^{#1}}} \DeclareMathOperator{\logit}{logit} \DeclareMathOperator*{\argmax}{arg\,max} \DeclareMathOperator*{\argmin}{arg\,min} \newcommand{\E}{\mathop{}\!E\!\mathop{}} \DeclareMathOperator{\ibm}{IBM} \DeclareMathOperator{\var}{var} \DeclareMathOperator{\cov}{cov} \DeclareMathOperator{\cor}{cor} \DeclareMathOperator{\sd}{sd} \DeclareMathOperator{\se}{se} \DeclareMathOperator{\diag}{diag} \newcommand{\rv}[3][1]{#2_{#1},\ldots,#2_{#3}} \newcommand{\N}{\mathcal N} \newcommand{\bO}{\mathcal{O}} \newcommand{\correct}[1]{\textbf{[{\color{red}#1}]}} \newcommand{\bz}{{\bm 0}} \renewcommand{\AA}{{\bm A}} \newcommand{\BB}{{\bm B}} \newcommand{\CC}{{\bm C}} \newcommand{\bb}{{\bm b}} \newcommand{\ff}{{\bm f}} \newcommand{\GG}{{\bm G}} \newcommand{\QQ}{{\bm Q}} \newcommand{\RR}{{\bm R}} \newcommand{\WW}{\bm{W}} \newcommand{\XX}{\bm{X}} \newcommand{\ZZ}{\bm{Z}} \newcommand{\xx}{\bm{x}} \newcommand{\yy}{\bm{y}} \newcommand{\YY}{\bm{Y}} \newcommand{\UU}{\bm{U}} \newcommand{\II}{\bm{I}} \newcommand{\eps}{{\bm \epsilon}} \newcommand{\bbe}{{\bm \beta}} \newcommand{\eet}{{\bm \eta}} \newcommand{\lla}{{\bm \lambda}} \newcommand{\pph}{{\bm \phi}} \newcommand{\rrh}{{\bm \rho}} \newcommand{\gga}{{\bm \gamma}} \newcommand{\tta}{{\bm \tau}} \newcommand{\TTh}{\bm{\Theta}} \newcommand{\tth}{\bm{\theta}} \newcommand{\hess}[1]{\frac{\partial^2}{\partial{#1}\partial{#1}'}} \newcommand{\dt}{\Delta t} \newcommand{\Id}{\bm{I}} \newcommand{\mmu}{{\bm \mu}} \newcommand{\SSi}{{\bm \Sigma}} \newcommand{\OOm}{{\bm \Omega}} \newcommand{\ssi}{{\bm \sigma}} \newcommand{\dr}{{\bm \Lambda}} \newcommand{\df}{{\bm \Sigma}} \newcommand{\zzero}{{\bm 0}} \newcommand{\up}[1]{^{(#1)}} \newcommand{\lap}{\textnormal{Lap}} \newcommand{\ella}{\ell_{\lap}} \newcommand{\Lhat}{\hat {\mathcal L}} \newcommand{\Ell}{\mathcal{L}} \newcommand{\pot}{\mathcal{U}} \renewcommand{\vec}{\operatorname{vec}} \)

Summary

In the Introduction to PFJAX, we saw how to set up a model class for a state-space model and how to use a particle filter to estimate its marginal likelihood

\[ p(\yy_{0:T} \mid \tth) = \int \pi(\xx_0 \mid \tth) \times \prod_{t=0}^T g(\yy_t \mid \xx_t, \tth) \times \prod_{t=1}^T f(\xx_t \mid \xx_{t-1}, \tth) \ud \xx_{0:T}. \]

In this tutorial, we’ll use PFJAX to sample from the posterior distribution \(p(\tth \mid \yy_{0:T}) \propto p(\yy_{0:T} \mid \tth) \pi(\tth)\) via Markov chain Monte Carlo (MCMC). There are essentially two ways to go about this:

  1. Do MCMC directly on the posterior distribution \(p(\tth \mid \yy_{0:T})\).
  2. Do MCMC on the joint posterior distribution \(p(\xx_{0:T}, \tth \mid \yy_{0:T})\) using a Gibbs sampler alternating between draws from \(p(\tth \mid \xx_{0:T}, \yy_{0:T})\) and from \(p(\xx_{0:T} \mid \tth, \yy_{0:T})\), where the latter is obtained with a particle filter/smoother.

Here we’ll be implementing a version of the latter, i.e., a Gibbs sampler in which the parameter draws from \(p(\tth \mid \xx_{0:T}, \yy_{0:T})\) are obtained using adaptive random walk Metropolis-within-Gibbs sampling.
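Schematically, each iteration of the sampler we'll build alternates between the two conditional updates

\[ \begin{aligned} \xx_{0:T} & \sim p(\xx_{0:T} \mid \tth, \yy_{0:T}) && \textnormal{(particle filter + particle smoother)} \\ \tth & \sim p(\tth \mid \xx_{0:T}, \yy_{0:T}) && \textnormal{(adaptive random walk Metropolis-within-Gibbs)}, \end{aligned} \]

so the parameter update only ever conditions on the currently sampled latent states rather than on the marginal likelihood \(p(\yy_{0:T} \mid \tth)\).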

Benchmark Model

We’ll be using a Bootstrap filter for the Brownian motion with drift model defined in the Introduction:

\[ \begin{aligned} x_0 & \sim \N(0, \sigma^2 \dt) \\ x_t & \sim \N(x_{t-1} + \mu \dt, \sigma^2 \dt) \\ y_t & \sim \N(x_t, \tau^2), \end{aligned} \]

where the model parameters are \(\tth = (\mu, \log \sigma, \log \tau)\). The details of setting up the appropriate model class are provided in the Introduction. Here we’ll use the version of this model provided with PFJAX: pfjax.models.BMModel.

# jax
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.random as random
from functools import partial
# plotting
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
# pfjax
import pfjax as pf
import pfjax.mcmc
from pfjax.models import BMModel

Simulate Data

# initial key for random numbers
key = random.PRNGKey(0)

# parameter values
mu, sigma, tau = 1., .5, .8
theta_true = jnp.array([mu, jnp.log(sigma), jnp.log(tau)])

# data specification
dt = .2
n_obs = 100
x_init = jnp.array([0.])

# simulate data
bm_model = BMModel(
    dt=dt, 
    unconstrained_theta=True # puts theta on the unconstrained scale
) 
key, subkey = random.split(key)
y_meas, x_state = pf.simulate(
    model=bm_model,
    key=subkey,
    n_obs=n_obs,
    x_init=x_init,
    theta=theta_true
)

# plot data
plot_df = (pd.DataFrame({"time": jnp.arange(n_obs) * dt,
                         "state": jnp.squeeze(x_state),
                         "meas": jnp.squeeze(y_meas)}))

g = sns.FacetGrid(plot_df, height = 6)
g = g.map(plt.scatter, "time", "meas", color="grey")
plt.plot(plot_df['time'], plot_df['state'], color='black')
plt.legend(labels=["Measurement","State"])
plt.xlabel("Time")
plt.ylabel("Value")
plt.show()

Particle Gibbs Sampler

  • We’ll use a flat prior distribution \(\pi(\tth) \propto 1\); see the note following this list for what this implies for the parameter update.

  • The update for \(p(\xx_{0:T} \mid \tth, \yy_{0:T})\) uses the functions pfjax.particle_filter() and pfjax.particle_smooth().

  • The update for \(p(\tth \mid \xx_{0:T}, \yy_{0:T})\) uses the adaptive Metropolis-within-Gibbs sampler provided by the class pfjax.mcmc.AdaptiveMWG. The design of this sampler is heavily inspired by BlackJAX.
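Under the flat prior, the conditional parameter distribution targeted by the MWG update reduces to the complete-data likelihood of the state-space model,

\[ p(\tth \mid \xx_{0:T}, \yy_{0:T}) \propto p(\xx_{0:T}, \yy_{0:T} \mid \tth) = \pi(\xx_0 \mid \tth) \times \prod_{t=0}^T g(\yy_t \mid \xx_t, \tth) \times \prod_{t=1}^T f(\xx_t \mid \xx_{t-1}, \tth), \]

which the sampler below evaluates on the log scale with pf.loglik_full().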

def particle_gibbs(key, n_iter, theta_init, x_state_init, n_particles, rw_sd):
    """
    Sample from the joint posterior distribution of parameters and latent states using a Particle Gibbs sampler.

    Args:
        key: PRNG key.
        n_iter: Number of MCMC iterations.
        theta_init: A vector of `n_params` initial parameter values on the unconstrained scale.
        x_state_init: JAX PyTree of initial state variables.
        n_particles: Number of particles for the particle filter.
        rw_sd: Vector of `n_params` initial standard deviations for the adaptive MWG proposal.

    Returns:
        A dictionary with elements

        - **x_state** - MCMC output for the state variables, with leading dimension `n_iter`.
        - **theta** - MCMC output for the unconstrained parameters, with leading dimension `n_iter`.
        - **accept_rate** - Vector of `n_params` acceptance rates.  These should be close to 0.44.
    """
    # initialize the sampler
    n_params = theta_init.size
    amwg = pfjax.mcmc.AdaptiveMWG()
    # initial state of MWG sampler
    initial_state = {
        "theta": theta_init,
        "x_state": x_state_init,
        "adapt_pars": amwg.init(rw_sd),
    }

    def mcmc_update(key, theta, x_state, adapt_pars):
        """
        MCMC update for parameters and latent variables.

        Use Adaptive MWG for the former and a particle filter for the latter.
        """
        keys = jax.random.split(key, num=3) # one for particle_filter, one for particle_smooth, one for amwg
        # latent variable update
        pf_out = pf.particle_filter(
            model=bm_model,
            key=keys[0],
            y_meas=y_meas,
            theta=theta,
            n_particles=n_particles,
            history=True
        )
        x_state = pf.particle_smooth(
            key=keys[1],
            logw=pf_out["logw"][y_meas.shape[0]-1],
            x_particles=pf_out["x_particles"],
            ancestors=pf_out["resample_out"]["ancestors"]
        )

        # parameter update
        def logpost(theta):
            """
            Log-posterior of the conditional parameter distribution.
            """
            return pf.loglik_full(
                model=bm_model,
                theta=theta,
                x_state=x_state,
                y_meas=y_meas
            )
        theta_state, accept = amwg.step(
            key=keys[2],
            position=theta,
            logprob_fn=logpost,
            rw_sd=adapt_pars["rw_sd"]
        )
        # adapt random walk jump sizes
        adapt_pars = amwg.adapt(pars=adapt_pars, accept=accept)
        return theta_state, x_state, adapt_pars, accept

    @jax.jit
    def step(state, key):
        """
        One step of MCMC update.
        """
        theta, x_state, adapt_pars, accept = mcmc_update(
            key=key,
            theta=state["theta"],
            x_state=state["x_state"],
            adapt_pars=state["adapt_pars"]
        )
        new_state = {
            "theta": theta, 
            "x_state": x_state, 
            "adapt_pars": adapt_pars
        }
        stack_state = {
            "theta": theta, 
            "x_state": x_state
        }
        return new_state, stack_state

    keys = jax.random.split(key, num=n_iter)
    state, out = jax.lax.scan(step, initial_state, keys)
    # calculate acceptance rate
    out["accept_rate"] = (1.0 * state["adapt_pars"]["n_accept"]) / n_iter
    return out

Run Sampler

n_particles = 50
rw_sd = 1. * jnp.array([1., 1., 1.])
n_iter = 20_000

key, subkey = jax.random.split(key)
pg_out = particle_gibbs(
    key=subkey, 
    n_iter=n_iter, 
    theta_init=theta_true, 
    x_state_init=x_state, 
    n_particles=n_particles, 
    rw_sd=rw_sd
)

pg_out["accept_rate"] # should be close to 0.44
Array([0.44425, 0.435  , 0.43565], dtype=float32)

Plot MCMC Output

First, we plot the posterior of the latent states \(p(\xx_{0:T} \mid \yy_{0:T})\) against their true values.

# plot data
plot_pg = (pd.DataFrame({"time": jnp.arange(n_obs) * dt,
                         "state": jnp.squeeze(x_state),
                         "meas": jnp.squeeze(y_meas),
                         "med": jnp.squeeze(jnp.median(pg_out["x_state"],axis=0)),
                         "2.5th": jnp.squeeze(jnp.percentile(pg_out["x_state"], 2.5, axis=0)),
                         "97.5th": jnp.squeeze(jnp.percentile(pg_out["x_state"], 97.5, axis=0))}))

g = sns.FacetGrid(plot_pg, height = 6)
g = g.map(plt.scatter, "time", "meas", color="grey")
plt.plot(plot_pg['time'], plot_pg['state'], color='black')
plt.plot(plot_pg['time'], plot_pg['med'], color='deepskyblue')
plt.fill_between(plot_pg['time'], plot_pg['2.5th'], plot_pg['97.5th'], color='skyblue', alpha=0.5)

plt.legend(labels=["Measurement", "State: True", "State: Posterior Median"])
plt.xlabel("Time")
plt.ylabel("Value")
plt.show()

Histograms for the sampled values of each parameter are shown below.

plot_pg = pd.DataFrame({"iter": jnp.arange(n_iter),
                         "mu": pg_out['theta'][:,0],
                         "sigma": jnp.exp(pg_out['theta'][:,1]),
                         "tau": jnp.exp(pg_out['theta'][:,2])})
plot_pg = pd.melt(plot_pg, id_vars=['iter'], value_vars=['mu', 'sigma', 'tau'])

hp = sns.displot(
    data=plot_pg, 
    x="value", 
    col="variable",
    kde=True
)
hp.set_titles(col_template="{col_name}")
hp.set(xlabel=None)
# add true parameter values
for ax, theta in zip(hp.axes.flat, jnp.array([mu, sigma, tau])):
    ax.axvline(theta, color="black")

Comparison to True Posterior

To further confirm that the implementation of the sampler is correct, we compare the MCMC output against that of an importance sampler from the exact posterior \(p(\tth \mid \yy_{0:T})\) of the BM model. While this is generally unavailable, for the BM model it can be computed from the exact loglikelihood pfjax.models.BMModel.loglik_exact().

To implement the importance sampler, we first find the mode and quadrature (curvature-based variance) of \(\log p(\tth \mid \yy_{0:T})\). Our importance distribution is then a multivariate normal with mean given by the mode and variance given by the quadrature times an inflation factor.

The mode of the BM log-likelihood function is found by simple gradient ascent, initialized at the true parameter values. The Hessian at the mode is then used to compute the quadrature.
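In symbols, with the flat prior so that \(p(\tth \mid \yy_{0:T}) \propto p(\yy_{0:T} \mid \tth)\), the quantities computed below are

\[ \tth_{\textnormal{mode}} = \argmax_{\tth} \log p(\yy_{0:T} \mid \tth), \qquad \QQ = -\left[ \hess{\tth} \log p(\yy_{0:T} \mid \tth) \,\Big|_{\tth = \tth_{\textnormal{mode}}} \right]^{-1}. \]

The proposal is then \(q(\tth) = \N(\tth_{\textnormal{mode}}, c \cdot \QQ)\) for an inflation factor \(c\), and the unnormalized log importance weights are \(\log w(\tth) = \log p(\yy_{0:T} \mid \tth) - \log q(\tth)\), exactly as computed in the code below.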

grad_fun = jax.jit(jax.grad(bm_model.loglik_exact, argnums = 1))

# Gradient ascent learning rate
learning_rate = 0.01

theta_mode = theta_true

for i in range(1000):
    grads = grad_fun(y_meas, theta_mode)
    # Update parameters via gradient ascent
    theta_mode = theta_mode + learning_rate * grads

# Hessian of the exact log-likelihood at the mode, used to compute the quadrature
hess = jax.jacfwd(jax.jacrev(bm_model.loglik_exact, argnums=1), argnums=1)(y_meas, theta_mode)
theta_quad = -jnp.linalg.inv(hess)

print(theta_mode)
print(theta_true)
print(theta_quad)
[ 1.0411831  -1.0344263  -0.07254256]
[ 1.         -0.6931472  -0.22314353]
[[ 7.1736779e-03  1.7677188e-03 -1.5563695e-04]
 [ 1.7676934e-03  1.9352004e-01 -1.1542980e-02]
 [-1.5563211e-04 -1.1542980e-02  6.5114573e-03]]

From the mode-quadrature proposal distribution, parameter samples are drawn with an inflation factor of 1.5 for the covariance. The draws are then weighted via importance sampling (IS) and resampled according to those weights. The resulting kernel density estimates are shown below, with an overlay of those obtained from the particle Gibbs (PG) sampler above.
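The resampling step uses probabilities equal to the normalized importance weights,

\[ p^{(i)} = \frac{\exp(\log w^{(i)})}{\sum_{j=1}^{M} \exp(\log w^{(j)})}, \qquad i = 1, \ldots, M, \]

where \(M\) is the number of proposal draws (here n_iter); this is the conversion performed (up to numerical stabilization) by pf.utils.logw_to_prob() in the code below.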

# Draw from the mode-quadrature distribution
infl = 1.5  # Inflation factor
key, subkey = jax.random.split(key)
draws = random.multivariate_normal(
    subkey, mean=theta_mode, cov=infl*theta_quad, shape=(n_iter,))

# under the flat prior, the log-posterior equals the exact log-likelihood (up to an additive constant)
logpost = jax.jit(bm_model.loglik_exact)

# Importance sampling: mode-quadrature proposal, exact BM log-likelihood target
logq_x = jsp.stats.multivariate_normal.logpdf(
    draws, mean=theta_mode, cov=infl*theta_quad)
logp_x = jax.vmap(
    fun=logpost,
    in_axes=(None, 0)
)(y_meas, draws)

# Log importance weights: target minus proposal, then normalized to probabilities
logw = logp_x - logq_x
prob = pf.utils.logw_to_prob(logw)

# resample the draws in proportion to their importance weights
key, subkey = jax.random.split(key)
imp_index = jax.random.choice(
    subkey, a=n_iter, p=prob, shape=(n_iter,), replace=True)
theta_imp = draws[imp_index, :]
plot_imp = pd.DataFrame({"iter": jnp.arange(n_iter),
                           "mu": theta_imp[:,0],
                           "sigma": jnp.exp(theta_imp[:,1]),
                           "tau": jnp.exp(theta_imp[:,2])})

plot_imp = pd.melt(plot_imp, id_vars=['iter'], value_vars=['mu', 'sigma', 'tau'])

plot_df = pd.concat([plot_pg, plot_imp], ignore_index=True)
plot_df["method"] = np.repeat(["PG", "IS"], len(plot_pg["variable"]))


hp = sns.displot(
    data=plot_df, 
    x="value", 
    hue="method",
    col="variable",
    kind="kde"
)
hp.set_titles(col_template="{col_name}")
hp.set(xlabel=None)
# add true parameter values
for ax, theta in zip(hp.axes.flat, jnp.array([mu, sigma, tau])):
    ax.axvline(theta, color="black")