
pfjax.mvn_bridge

Multivariate normal bridge proposals.

Suppose we have the multivariate normal model

         W ~ N(mu_W, Sigma_W)
     X | W ~ N(W + mu_XW, Sigma_XW)
  Y | X, W ~ N(AX, Omega).

We are interested in calculating the mean and variance of p(W|Y).
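The formulas behind both functions follow from rewriting the model with explicit noise terms. Write X = W + mu_XW + eps_X and Y = AX + eps_Y, where eps_X ~ N(0, Sigma_XW) and eps_Y ~ N(0, Omega) are independent of each other and of W. Then (W, Y) is jointly normal, and standard Gaussian conditioning gives p(W|Y). This is a derivation sketch in the notation above, not text taken from the pfjax source:

```latex
\begin{align*}
\mu_Y &= A(\mu_W + \mu_{XW}), \\
\operatorname{Cov}(Y, W) &= A\,\operatorname{Cov}(X, W) = A\Sigma_W, \\
\Sigma_Y &= A(\Sigma_W + \Sigma_{XW})A^\top + \Omega, \\
\operatorname{E}[W \mid Y] &= \mu_W + (A\Sigma_W)^\top \Sigma_Y^{-1} (Y - \mu_Y), \\
\operatorname{Var}(W \mid Y) &= \Sigma_W - (A\Sigma_W)^\top \Sigma_Y^{-1} (A\Sigma_W).
\end{align*}
```

`mvn_bridge_pars` computes the first three quantities, which do not depend on the observed Y. `mvn_bridge_mv` computes the last two.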

Functions:

  mvn_bridge_pars  Calculate the unconditional mean of Y, the variance of Y, and the covariance between Y and W.
  mvn_bridge_mv    Calculate the mean and variance of p(W|Y).
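The two functions are meant to be called in sequence. First compute the Y-independent moments, then condition on the observed Y. The sketch below shows that call order using hypothetical NumPy stand-ins, `bridge_pars` and `bridge_mv`; the real pfjax functions take `jax.numpy` arrays. It then draws one proposal from p(W|Y). The toy model, with an identity Sigma_W and an A that observes only the first two coordinates of X, is an illustrative assumption.

```python
import numpy as np


# Hypothetical NumPy stand-ins for mvn_bridge_pars and mvn_bridge_mv.
def bridge_pars(mu_W, Sigma_W, mu_XW, Sigma_XW, A, Omega):
    mu_Y = A @ (mu_W + mu_XW)
    AS_W = A @ Sigma_W  # Cov(Y, W)
    Sigma_Y = A @ (Sigma_W + Sigma_XW) @ A.T + Omega
    return mu_Y, AS_W, Sigma_Y


def bridge_mv(mu_W, Sigma_W, mu_Y, AS_W, Sigma_Y, Y):
    # Solve Sigma_Y^{-1} [Y - mu_Y | AS_W] in a single call.
    sol = AS_W.T @ np.linalg.solve(Sigma_Y, np.column_stack([Y - mu_Y, AS_W]))
    return mu_W + sol[:, 0], Sigma_W - sol[:, 1:]


rng = np.random.default_rng(1)
n_W, n_Y = 3, 2
mu_W = np.zeros(n_W)
Sigma_W = np.eye(n_W)
mu_XW = np.full(n_W, 0.5)
Sigma_XW = 0.1 * np.eye(n_W)
A = np.array([[1.0, 0.0, 0.0],
              [0.0, 1.0, 0.0]])  # observe the first two coordinates of X
Omega = 0.01 * np.eye(n_Y)
Y = np.array([1.2, -0.3])

# Step 1: moments that do not depend on the observed Y.
mu_Y, AS_W, Sigma_Y = bridge_pars(mu_W, Sigma_W, mu_XW, Sigma_XW, A, Omega)
# Step 2: condition on the observed Y.
mu_WY, Sigma_WY = bridge_mv(mu_W, Sigma_W, mu_Y, AS_W, Sigma_Y, Y)
# Draw one proposal from p(W | Y).
W_prop = rng.multivariate_normal(mu_WY, Sigma_WY)
print(mu_WY, np.diag(Sigma_WY), W_prop)
```

Here Sigma_Y = 1.11 I. The first two coordinates of W are pulled toward Y - mu_XW, and their variance shrinks from 1 to about 0.099. The third coordinate is not seen through A and Sigma_W is diagonal, so it keeps its prior mean 0 and variance 1.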

mvn_bridge_pars(mu_W, Sigma_W, mu_XW, Sigma_XW, A, Omega)

Calculate the unconditional mean of Y, the variance of Y and the covariance between W and Y.

Parameters (all required):

  mu_W      Mean of W.
  Sigma_W   Variance of W.
  mu_XW     Offset in the mean of X|W: E[X|W] = W + mu_XW.
  Sigma_XW  Variance of X|W.
  A         Matrix giving the mean of Y: E[Y|X,W] = AX.
  Omega     Variance of Y|X,W.

Returns:

  Tuple:
  • mu_Y - Unconditional mean of Y.
  • AS_W - Covariance of Y and W, i.e., A Sigma_W.
  • Sigma_Y - Unconditional variance of Y.
Source code in src/pfjax/mvn_bridge.py
import jax.numpy as jnp


def mvn_bridge_pars(mu_W, Sigma_W, mu_XW, Sigma_XW, A, Omega):
    r"""
    Calculate the unconditional mean of Y, the variance of Y and the covariance between W and Y.

    Args:
        mu_W: Mean of W.
        Sigma_W: Variance of W.
        mu_XW: Offset in the mean of X|W: E[X|W] = W + mu_XW.
        Sigma_XW: Variance of X|W.
        A: Matrix giving the mean of Y: E[Y|X,W] = AX.
        Omega: Variance of Y|X,W.

    Returns:
        Tuple:

        - **mu_Y** - Unconditional mean of Y.
        - **AS_W** - Covariance of Y and W, i.e., A Sigma_W.
        - **Sigma_Y** - Unconditional variance of Y.

    """
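    # E[Y] = A E[X], with E[X] = mu_W + mu_XW
    mu_Y = jnp.matmul(A, mu_W + mu_XW)
    # Cov(Y, W) = A Cov(X, W) = A Sigma_W
    AS_W = jnp.matmul(A, Sigma_W)
    # Var(Y) = A Var(X) A^T + Omega, with Var(X) = Sigma_W + Sigma_XW
    Sigma_Y = jnp.linalg.multi_dot([A, Sigma_W + Sigma_XW, A.T]) + Omega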
    return mu_Y, AS_W, Sigma_Y
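You can sanity-check these closed-form moments by Monte Carlo. The sketch simulates W, then X given W, then Y given X. It then compares the empirical moments with a NumPy transcription of the three formulas. The helper names `bridge_pars_np` and `random_spd`, the dimensions, and the seed are illustrative assumptions, and NumPy stands in for `jax.numpy`.

```python
import numpy as np


def bridge_pars_np(mu_W, Sigma_W, mu_XW, Sigma_XW, A, Omega):
    # Hypothetical NumPy transcription of the formulas in mvn_bridge_pars.
    mu_Y = A @ (mu_W + mu_XW)
    AS_W = A @ Sigma_W  # Cov(Y, W), shape (n_Y, n_W)
    Sigma_Y = A @ (Sigma_W + Sigma_XW) @ A.T + Omega
    return mu_Y, AS_W, Sigma_Y


def random_spd(rng, n):
    # Random symmetric positive-definite matrix for testing.
    M = rng.normal(size=(n, n))
    return M @ M.T + n * np.eye(n)


rng = np.random.default_rng(0)
n_W, n_Y, n_sim = 3, 2, 200_000
mu_W, mu_XW = rng.normal(size=n_W), rng.normal(size=n_W)
Sigma_W, Sigma_XW = random_spd(rng, n_W), random_spd(rng, n_W)
A = rng.normal(size=(n_Y, n_W))
Omega = random_spd(rng, n_Y)

# Simulate the hierarchy: W, then X | W, then Y | X, W.
W = rng.multivariate_normal(mu_W, Sigma_W, size=n_sim)
X = W + mu_XW + rng.multivariate_normal(np.zeros(n_W), Sigma_XW, size=n_sim)
Y = X @ A.T + rng.multivariate_normal(np.zeros(n_Y), Omega, size=n_sim)

mu_Y, AS_W, Sigma_Y = bridge_pars_np(mu_W, Sigma_W, mu_XW, Sigma_XW, A, Omega)

# Empirical moments from the simulated draws.
emp_mu_Y = Y.mean(axis=0)
emp_Sigma_Y = np.cov(Y, rowvar=False)
emp_AS_W = (Y - emp_mu_Y).T @ (W - W.mean(axis=0)) / (n_sim - 1)

# Relative errors should be small (Monte Carlo noise only).
print(np.linalg.norm(emp_Sigma_Y - Sigma_Y) / np.linalg.norm(Sigma_Y))
print(np.linalg.norm(emp_AS_W - AS_W) / np.linalg.norm(AS_W))
```

The check also confirms the orientation of `AS_W`. It matches the empirical Cov(Y, W), an n_Y by n_W matrix, rather than its transpose.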

mvn_bridge_mv(mu_W, Sigma_W, mu_Y, AS_W, Sigma_Y, Y)

Calculate the mean and variance of p(W|Y).

Parameters (all required):

  mu_W      Mean of W.
  Sigma_W   Variance of W.
  mu_Y      Unconditional mean of Y.
  AS_W      Covariance of Y and W (A Sigma_W, as returned by mvn_bridge_pars).
  Sigma_Y   Unconditional variance of Y.
  Y         Observed Y.

Returns:

  Tuple:
  • mu_WY - Mean of W|Y.
  • Sigma_WY - Variance of W|Y.
Source code in src/pfjax/mvn_bridge.py
import jax.numpy as jnp
import jax.scipy as jsp


def mvn_bridge_mv(mu_W, Sigma_W, mu_Y, AS_W, Sigma_Y, Y):
    r"""
    Calculate the mean and variance of `p(W|Y)`.

    Args:
        mu_W: Mean of W.
        Sigma_W: Variance of W.
        mu_Y: Unconditional mean of Y.
        AS_W: Covariance of Y and W (A Sigma_W).
        Sigma_Y: Unconditional variance of Y.
        Y: Observed Y.

    Returns:
        Tuple:

        - **mu_WY** - Mean of W|Y.
        - **Sigma_WY** - Variance of W|Y.

    """
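    # Solve Sigma_Y^{-1} (Y - mu_Y) and Sigma_Y^{-1} AS_W simultaneously by
    # stacking both right-hand sides as columns of a single matrix.
    # Alternative using a general (non-Cholesky) solve:
    # sol = jnp.matmul(AS_W.T, jnp.linalg.solve(
    #     Sigma_Y, jnp.hstack([jnp.array([Y-mu_Y]).T, AS_W])))

    # Sigma_Y is symmetric positive definite, so factor it once (lower=True).
    Sigma_chol = jsp.linalg.cho_factor(Sigma_Y, True)
    sol = jnp.matmul(
        AS_W.T,
        jsp.linalg.cho_solve(Sigma_chol, jnp.hstack([jnp.array([Y - mu_Y]).T, AS_W])),
    )
    # sol[:, 0] = AS_W^T Sigma_Y^{-1} (Y - mu_Y): correction to the mean.
    # sol[:, 1:] = AS_W^T Sigma_Y^{-1} AS_W: reduction in the variance.
    return mu_W + jnp.squeeze(sol[:, 0]), Sigma_W - sol[:, 1:]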
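`mvn_bridge_mv` factors Sigma_Y once and reuses the factor for both the mean and the variance correction. It does this by solving against n_W + 1 stacked right-hand-side columns, [Y - mu_Y | AS_W]. Below is a NumPy sketch of the same idea, checked against the information (precision) form of the same posterior. Two substitutions are assumptions made to keep the block dependency-light. First, `jsp.linalg.cho_factor`/`cho_solve` are replaced by `np.linalg.cholesky` followed by two `np.linalg.solve` calls. Second, the name `cho_solve_np` is hypothetical.

```python
import numpy as np


def cho_solve_np(Sigma, B):
    # Solve Sigma @ X = B through the Cholesky factor L (Sigma = L @ L.T):
    # first L @ Z = B, then L.T @ X = Z. np.linalg.solve does not exploit
    # triangularity; scipy.linalg.solve_triangular would.
    L = np.linalg.cholesky(Sigma)
    Z = np.linalg.solve(L, B)
    return np.linalg.solve(L.T, Z)


def random_spd(rng, n):
    M = rng.normal(size=(n, n))
    return M @ M.T + n * np.eye(n)


rng = np.random.default_rng(2)
n_W, n_Y = 4, 3
mu_W, mu_XW = rng.normal(size=n_W), rng.normal(size=n_W)
Sigma_W, Sigma_XW = random_spd(rng, n_W), random_spd(rng, n_W)
A = rng.normal(size=(n_Y, n_W))
Omega = random_spd(rng, n_Y)
Y = rng.normal(size=n_Y)

# Covariance form, following mvn_bridge_pars and mvn_bridge_mv.
mu_Y = A @ (mu_W + mu_XW)
AS_W = A @ Sigma_W
Sigma_Y = A @ (Sigma_W + Sigma_XW) @ A.T + Omega
B = np.column_stack([Y - mu_Y, AS_W])  # shape (n_Y, 1 + n_W)
sol = AS_W.T @ cho_solve_np(Sigma_Y, B)  # one factorization, all columns
mu_WY, Sigma_WY = mu_W + sol[:, 0], Sigma_W - sol[:, 1:]

# Information form: Y | W ~ N(A (W + mu_XW), R) with R = A Sigma_XW A^T + Omega,
# so the posterior precision is Sigma_W^{-1} + A^T R^{-1} A.
R_inv = np.linalg.inv(A @ Sigma_XW @ A.T + Omega)
Sigma_W_inv = np.linalg.inv(Sigma_W)
Sigma_info = np.linalg.inv(Sigma_W_inv + A.T @ R_inv @ A)
mu_info = Sigma_info @ (Sigma_W_inv @ mu_W + A.T @ R_inv @ (Y - A @ mu_XW))

print(np.max(np.abs(mu_WY - mu_info)), np.max(np.abs(Sigma_WY - Sigma_info)))
```

The two forms agree by the Woodbury identity. The covariance form only needs to factor the n_Y by n_Y matrix Sigma_Y, whereas the information form inverts n_W by n_W matrices. That makes the covariance form the cheaper choice when Y has fewer dimensions than W.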