
pfjax.models.bm_model

Classes:

- `BMModel`: Brownian motion state space model.

BMModel

Bases: BaseModel

Brownian motion state space model.

The model is:

::

x_0 ~ pi(x_0) \propto 1
x_t ~ N(x_{t-1} + mu dt, sigma^2 dt)
y_t ~ N(x_t, tau^2)
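
As a sanity check on the generative model above, here is a minimal forward simulation. This is a plain-NumPy sketch, not the pfjax API; the function name `simulate_bm` and all parameter values are illustrative:

```python
import numpy as np

def simulate_bm(seed, n_obs, mu, sigma, tau, dt):
    """Forward-simulate x_{0:N} and y_{0:N} from the BM state space model."""
    rng = np.random.default_rng(seed)
    x = np.zeros(n_obs + 1)
    x[0] = 0.0  # flat prior on x_0: any starting point works for illustration
    for t in range(1, n_obs + 1):
        # x_t ~ N(x_{t-1} + mu*dt, sigma^2*dt)
        x[t] = x[t - 1] + mu * dt + sigma * np.sqrt(dt) * rng.normal()
    # y_t ~ N(x_t, tau^2)
    y = x + tau * rng.normal(size=n_obs + 1)
    return x, y

x, y = simulate_bm(seed=0, n_obs=100, mu=0.1, sigma=0.5, tau=0.2, dt=0.1)
```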

Parameters:

- `dt` (required): Interobservation time.
- `unconstrained_theta` (default `False`): Whether to use the regular parameter scale `theta = (mu, sigma, tau)` or the unconstrained scale `theta = (mu, log(sigma), log(tau))`.

Methods:

- `state_lpdf`: Calculates the log-density of `p(x_curr | x_prev, theta)`.
- `state_sample`: Samples from `x_curr ~ p(x_curr | x_prev, theta)`.
- `meas_lpdf`: Log-density of `p(y_curr | x_curr, theta)`.
- `meas_sample`: Sample from `p(y_curr | x_curr, theta)`.
- `pf_init`: Get particles for initial state `x_init = x_state[0]`.
- `loglik_exact`: Marginal loglikelihood of the BM model.

Source code in src/pfjax/models/bm_model.py
class BMModel(BaseModel):
    r"""
    Brownian motion state space model.

    The model is:

    ::

        x_0 ~ pi(x_0) \propto 1
        x_t ~ N(x_{t-1} + mu dt, sigma^2 dt)
        y_t ~ N(x_t, tau^2)


    Args:
        dt: Interobservation time.
        unconstrained_theta: Whether to use the regular parameter scale `theta = (mu, sigma, tau)` or the unconstrained scale `theta = (mu, log(sigma), log(tau))`.
    """

    def __init__(self, dt, unconstrained_theta=False):
        super().__init__(bootstrap=True)
        self._dt = dt
        self._unconstrained_theta = unconstrained_theta

    def _constrain_theta(self, theta):
        r"""
        Convert `theta` to the constrained scale.
        """
        return jnp.array([theta[0], jnp.exp(theta[1]), jnp.exp(theta[2])])

    def state_lpdf(self, x_curr, x_prev, theta):
        r"""
        Calculates the log-density of `p(x_curr | x_prev, theta)`.

        Args:
            x_curr: State variable at current time `t`.
            x_prev: State variable at previous time `t-1`.
            theta: Parameter value.

        Returns:
            The log-density of `p(x_curr | x_prev, theta)`.
        """
        if self._unconstrained_theta:
            theta = self._constrain_theta(theta)
        mu = theta[0]
        sigma = theta[1]
        return jnp.squeeze(
            jsp.stats.norm.logpdf(x_curr, loc=x_prev + mu * self._dt,
                                  scale=sigma * jnp.sqrt(self._dt))
        )

    def state_sample(self, key, x_prev, theta):
        r"""
        Samples from `x_curr ~ p(x_curr | x_prev, theta)`.

        Args:
            key: PRNG key.
            x_prev: State variable at previous time `t-1`.
            theta: Parameter value.

        Returns:
            Sample of the state variable at current time `t`: `x_curr ~ p(x_curr | x_prev, theta)`.
        """
        if self._unconstrained_theta:
            theta = self._constrain_theta(theta)
        mu = theta[0]
        sigma = theta[1]
        x_mean = x_prev + mu * self._dt
        x_sd = sigma * jnp.sqrt(self._dt)
        return x_mean + x_sd * random.normal(key=key)

    def meas_lpdf(self, y_curr, x_curr, theta):
        r"""
        Log-density of `p(y_curr | x_curr, theta)`.

        Args:
            y_curr: Measurement variable at current time `t`.
            x_curr: State variable at current time `t`.
            theta: Parameter value.

        Returns:
            The log-density of `p(y_curr | x_curr, theta)`.
        """
        if self._unconstrained_theta:
            theta = self._constrain_theta(theta)
        tau = theta[2]
        return jnp.squeeze(
            jsp.stats.norm.logpdf(y_curr, loc=x_curr, scale=tau)
        )

    def meas_sample(self, key, x_curr, theta):
        r"""
        Sample from `p(y_curr | x_curr, theta)`.

        Args:
            key: PRNG key.
            x_curr: State variable at current time `t`.
            theta: Parameter value.

        Returns:
            Sample of the measurement variable at current time `t`: `y_curr ~ p(y_curr | x_curr, theta)`.
        """
        if self._unconstrained_theta:
            theta = self._constrain_theta(theta)
        tau = theta[2]
        return x_curr + tau * random.normal(key=key)

    def pf_init(self, key, y_init, theta):
        r"""
        Get particles for initial state `x_init = x_state[0]`. 

        Samples from an importance sampling proposal distribution

        ::

            x_init ~ q(x_init) = q(x_init | y_init, theta)

        and calculates the log weight

        ::

            logw = log p(y_init | x_init, theta) + log p(x_init | theta) - log q(x_init)

        In this case we have an exact proposal 

        ::

                  q(x_init) = p(x_init | y_init, theta)   
            <=>   x_init ~ N(y_init, tau^2)

        Moreover, due to symmetry of arguments we have `q(x_init) = p(y_init | x_init, theta)`, and since `p(x_init | theta) \propto 1` we have `logw = 0`.

        Args:
            key: PRNG key.
            y_init: Measurement variable at initial time `t = 0`.
            theta: Parameter value.

        Returns:

            Tuple:

            - **x_init** - A sample from the proposal distribution for `x_init`.
            - **logw** - The log-weight of `x_init`.
        """
        return self.meas_sample(key, y_init, theta), jnp.zeros(())

    def loglik_exact(self, y_meas, theta):
        r"""
        Marginal loglikelihood of the BM model.

        Actually calculates `log p(y_{1:N} | theta, y_0)`, since for the flat prior on `x_0` the marginal likelihood `p(y_{0:N} | theta)` does not exist.

        Args:
            y_meas: Vector of observations `y_0, ..., y_N`.
            theta: Parameter value.

        Returns:
            float:

            The marginal loglikelihood `log p(y_{1:N} | theta, y_0)`.
        """
        if self._unconstrained_theta:
            theta = self._constrain_theta(theta)
        mu = theta[0]
        sigma2 = theta[1] * theta[1]
        tau2 = theta[2] * theta[2]
        n_obs = y_meas.shape[0]-1  # conditioning on y_0
        t_meas = jnp.arange(1, n_obs+1) * self._dt
        Sigma_y = sigma2 * jax.vmap(lambda t:
                                    jnp.minimum(t, t_meas))(t_meas) + \
            tau2 * (jnp.ones((n_obs, n_obs)) + jnp.eye(n_obs))
        mu_y = y_meas[0] + mu * t_meas
        return jsp.stats.multivariate_normal.logpdf(
            x=jnp.squeeze(y_meas[1:]),
            mean=mu_y,
            cov=Sigma_y
        )

state_lpdf(x_curr, x_prev, theta)

Calculates the log-density of p(x_curr | x_prev, theta).

Parameters:

- `x_curr` (required): State variable at current time `t`.
- `x_prev` (required): State variable at previous time `t-1`.
- `theta` (required): Parameter value.

Returns:

The log-density of `p(x_curr | x_prev, theta)`.

Source code in src/pfjax/models/bm_model.py
def state_lpdf(self, x_curr, x_prev, theta):
    r"""
    Calculates the log-density of `p(x_curr | x_prev, theta)`.

    Args:
        x_curr: State variable at current time `t`.
        x_prev: State variable at previous time `t-1`.
        theta: Parameter value.

    Returns:
        The log-density of `p(x_curr | x_prev, theta)`.
    """
    if self._unconstrained_theta:
        theta = self._constrain_theta(theta)
    mu = theta[0]
    sigma = theta[1]
    return jnp.squeeze(
        jsp.stats.norm.logpdf(x_curr, loc=x_prev + mu * self._dt,
                              scale=sigma * jnp.sqrt(self._dt))
    )
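
The transition density above is just a univariate normal with mean `x_prev + mu*dt` and scale `sigma*sqrt(dt)`. A NumPy re-derivation (with the log-density written out explicitly and arbitrary values for `theta` and `dt`) shows what the method computes:

```python
import numpy as np

def norm_logpdf(x, loc, scale):
    """Univariate normal log-density, written out explicitly."""
    return -0.5 * np.log(2 * np.pi * scale**2) - (x - loc)**2 / (2 * scale**2)

def state_lpdf_np(x_curr, x_prev, theta, dt):
    # Mirrors state_lpdf on the constrained scale:
    # x_curr ~ N(x_prev + mu*dt, sigma^2*dt)
    mu, sigma, _tau = theta
    return norm_logpdf(x_curr, loc=x_prev + mu * dt, scale=sigma * np.sqrt(dt))

theta = np.array([0.1, 0.5, 0.2])  # (mu, sigma, tau): arbitrary values
lp = state_lpdf_np(x_curr=1.0, x_prev=0.9, theta=theta, dt=0.1)
```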

state_sample(key, x_prev, theta)

Samples from x_curr ~ p(x_curr | x_prev, theta).

Parameters:

- `key` (required): PRNG key.
- `x_prev` (required): State variable at previous time `t-1`.
- `theta` (required): Parameter value.

Returns:

Sample of the state variable at current time `t`: `x_curr ~ p(x_curr | x_prev, theta)`.

Source code in src/pfjax/models/bm_model.py
def state_sample(self, key, x_prev, theta):
    r"""
    Samples from `x_curr ~ p(x_curr | x_prev, theta)`.

    Args:
        key: PRNG key.
        x_prev: State variable at previous time `t-1`.
        theta: Parameter value.

    Returns:
        Sample of the state variable at current time `t`: `x_curr ~ p(x_curr | x_prev, theta)`.
    """
    if self._unconstrained_theta:
        theta = self._constrain_theta(theta)
    mu = theta[0]
    sigma = theta[1]
    x_mean = x_prev + mu * self._dt
    x_sd = sigma * jnp.sqrt(self._dt)
    return x_mean + x_sd * random.normal(key=key)
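
The sampler is a location-scale transform of a standard normal draw. A quick Monte Carlo check (NumPy stand-in for the JAX sampler, arbitrary parameter values) confirms the draws have mean `x_prev + mu*dt` and standard deviation `sigma*sqrt(dt)`:

```python
import numpy as np

rng = np.random.default_rng(1)
mu, sigma, dt, x_prev = 0.1, 0.5, 0.1, 2.0
# Same recipe as state_sample: x_mean + x_sd * z, with z ~ N(0, 1)
draws = x_prev + mu * dt + sigma * np.sqrt(dt) * rng.normal(size=100_000)
```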

meas_lpdf(y_curr, x_curr, theta)

Log-density of p(y_curr | x_curr, theta).

Parameters:

- `y_curr` (required): Measurement variable at current time `t`.
- `x_curr` (required): State variable at current time `t`.
- `theta` (required): Parameter value.

Returns:

The log-density of `p(y_curr | x_curr, theta)`.

Source code in src/pfjax/models/bm_model.py
def meas_lpdf(self, y_curr, x_curr, theta):
    r"""
    Log-density of `p(y_curr | x_curr, theta)`.

    Args:
        y_curr: Measurement variable at current time `t`.
        x_curr: State variable at current time `t`.
        theta: Parameter value.

    Returns:
        The log-density of `p(y_curr | x_curr, theta)`.
    """
    if self._unconstrained_theta:
        theta = self._constrain_theta(theta)
    tau = theta[2]
    return jnp.squeeze(
        jsp.stats.norm.logpdf(y_curr, loc=x_curr, scale=tau)
    )

meas_sample(key, x_curr, theta)

Sample from p(y_curr | x_curr, theta).

Parameters:

- `key` (required): PRNG key.
- `x_curr` (required): State variable at current time `t`.
- `theta` (required): Parameter value.

Returns:

Sample of the measurement variable at current time `t`: `y_curr ~ p(y_curr | x_curr, theta)`.

Source code in src/pfjax/models/bm_model.py
def meas_sample(self, key, x_curr, theta):
    r"""
    Sample from `p(y_curr | x_curr, theta)`.

    Args:
        key: PRNG key.
        x_curr: State variable at current time `t`.
        theta: Parameter value.

    Returns:
        Sample of the measurement variable at current time `t`: `y_curr ~ p(y_curr | x_curr, theta)`.
    """
    if self._unconstrained_theta:
        theta = self._constrain_theta(theta)
    tau = theta[2]
    return x_curr + tau * random.normal(key=key)

pf_init(key, y_init, theta)

Get particles for initial state x_init = x_state[0].

Samples from an importance sampling proposal distribution

::

x_init ~ q(x_init) = q(x_init | y_init, theta)

and calculates the log weight

::

logw = log p(y_init | x_init, theta) + log p(x_init | theta) - log q(x_init)

In this case we have an exact proposal

::

      q(x_init) = p(x_init | y_init, theta)   
<=>   x_init ~ N(y_init, tau^2)

Moreover, due to symmetry of arguments we have q(x_init) = p(y_init | x_init, theta), and since p(x_init | theta) \propto 1 we have logw = 0.

Parameters:

- `key` (required): PRNG key.
- `y_init` (required): Measurement variable at initial time `t = 0`.
- `theta` (required): Parameter value.

Returns:

Tuple:

- **x_init** - A sample from the proposal distribution for `x_init`.
- **logw** - The log-weight of `x_init`.
Source code in src/pfjax/models/bm_model.py
def pf_init(self, key, y_init, theta):
    r"""
    Get particles for initial state `x_init = x_state[0]`. 

    Samples from an importance sampling proposal distribution

    ::

        x_init ~ q(x_init) = q(x_init | y_init, theta)

    and calculates the log weight

    ::

        logw = log p(y_init | x_init, theta) + log p(x_init | theta) - log q(x_init)

    In this case we have an exact proposal 

    ::

              q(x_init) = p(x_init | y_init, theta)   
        <=>   x_init ~ N(y_init, tau^2)

    Moreover, due to symmetry of arguments we have `q(x_init) = p(y_init | x_init, theta)`, and since `p(x_init | theta) \propto 1` we have `logw = 0`.

    Args:
        key: PRNG key.
        y_init: Measurement variable at initial time `t = 0`.
        theta: Parameter value.

    Returns:

        Tuple:

        - **x_init** - A sample from the proposal distribution for `x_init`.
        - **logw** - The log-weight of `x_init`.
    """
    return self.meas_sample(key, y_init, theta), jnp.zeros(())
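
The `logw = 0` claim rests on the normal density being symmetric in its argument and mean: `q(x_init) = N(x_init; y_init, tau^2)` equals `p(y_init | x_init) = N(y_init; x_init, tau^2)`, and the flat prior contributes a constant taken as zero. A small NumPy check (arbitrary values for `tau`, `y_init`, and the proposed particle) makes this concrete:

```python
import numpy as np

def norm_logpdf(x, loc, scale):
    # Univariate normal log-density
    return -0.5 * np.log(2 * np.pi * scale**2) - (x - loc)**2 / (2 * scale**2)

tau = 0.2
y_init, x_init = 1.3, 1.45  # arbitrary measurement and proposed particle

# Proposal draws x_init ~ N(y_init, tau^2), so q(x_init) = N(x_init; y_init, tau^2)
log_q = norm_logpdf(x_init, loc=y_init, scale=tau)
# Measurement density p(y_init | x_init) = N(y_init; x_init, tau^2)
log_p_meas = norm_logpdf(y_init, loc=x_init, scale=tau)
log_p_prior = 0.0  # flat prior: log p(x_init | theta) is constant, taken as 0

logw = log_p_meas + log_p_prior - log_q  # symmetry makes this exactly zero
```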

loglik_exact(y_meas, theta)

Marginal loglikelihood of the BM model.

Actually calculates log p(y_{1:N} | theta, y_0), since for the flat prior on x_0 the marginal likelihood p(y_{0:N} | theta) does not exist.

Parameters:

- `y_meas` (required): Vector of observations `y_0, ..., y_N`.
- `theta` (required): Parameter value.

Returns:

float: The marginal loglikelihood `log p(y_{1:N} | theta, y_0)`.

Source code in src/pfjax/models/bm_model.py
def loglik_exact(self, y_meas, theta):
    r"""
    Marginal loglikelihood of the BM model.

    Actually calculates `log p(y_{1:N} | theta, y_0)`, since for the flat prior on `x_0` the marginal likelihood `p(y_{0:N} | theta)` does not exist.

    Args:
        y_meas: Vector of observations `y_0, ..., y_N`.
        theta: Parameter value.

    Returns:
        float:

        The marginal loglikelihood `log p(y_{1:N} | theta, y_0)`.
    """
    if self._unconstrained_theta:
        theta = self._constrain_theta(theta)
    mu = theta[0]
    sigma2 = theta[1] * theta[1]
    tau2 = theta[2] * theta[2]
    n_obs = y_meas.shape[0]-1  # conditioning on y_0
    t_meas = jnp.arange(1, n_obs+1) * self._dt
    Sigma_y = sigma2 * jax.vmap(lambda t:
                                jnp.minimum(t, t_meas))(t_meas) + \
        tau2 * (jnp.ones((n_obs, n_obs)) + jnp.eye(n_obs))
    mu_y = y_meas[0] + mu * t_meas
    return jsp.stats.multivariate_normal.logpdf(
        x=jnp.squeeze(y_meas[1:]),
        mean=mu_y,
        cov=Sigma_y
    )
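
The covariance in `loglik_exact` combines the Brownian-motion kernel `Cov(x_s, x_t) = sigma^2 * min(s, t)` with the measurement noise: conditioning on `y_0` adds `tau^2` to every entry (the shared noise in `y_0`) plus another `tau^2` on the diagonal. A NumPy check (arbitrary parameter values) that the vectorized construction matches the elementwise formula `Sigma[i, j] = sigma^2 * min(t_i, t_j) + tau^2 * (1 + 1{i == j})`:

```python
import numpy as np

sigma2, tau2, dt, n_obs = 0.25, 0.04, 0.1, 5
t_meas = np.arange(1, n_obs + 1) * dt

# Vectorized, mirroring the jax.vmap construction in loglik_exact
Sigma_vec = sigma2 * np.minimum.outer(t_meas, t_meas) \
    + tau2 * (np.ones((n_obs, n_obs)) + np.eye(n_obs))

# Elementwise reference implementation
Sigma_loop = np.empty((n_obs, n_obs))
for i in range(n_obs):
    for j in range(n_obs):
        Sigma_loop[i, j] = sigma2 * min(t_meas[i], t_meas[j]) \
            + tau2 * (1.0 + (i == j))
```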