exponential family notes: Nonlinear Function
Created: February 22, 2022
Modified: May 21, 2022

This page is from my personal notes, and has not been specifically reviewed for public consumption. It might be incomplete, wrong, outdated, or stupid. Caveat lector.

Exponential Families, Conjugacy, Convexity, and Variational Inference

Any parameterized family of probability densities that can be written in the form

$$p(\mathbf{x}) = h(\mathbf{x}) \exp\left(\langle \eta, t(\mathbf{x}) \rangle - \mathcal{A}(\eta) \right)$$

is an exponential family. The important restriction is that the data $\mathbf{x}$ and the parameter $\eta$ interact only via the log-linear term $\exp\left( \langle \eta, t(\mathbf{x}) \rangle\right)$. Surprisingly many common distribution families can be written in this form, including (multivariate) Gaussian, gamma, exponential, Bernoulli/categorical, Poisson, beta, Dirichlet, Wishart, chi-squared, and geometric distributions (see examples below). In this form, we label the terms as

  • sufficient statistics $t(\mathbf{x})$: this is some function, or vector-valued collection of functions, of $\mathbf{x}$ through which the variable interacts with the parameters $\eta$. They are "sufficient" for parameter estimation: if we know the sufficient statistic we may throw away the original observation and still compute (something proportional to) the likelihood. This becomes more interesting when we consider multiple draws from the same exponential family, which form a new exponential family in which the sufficient statistic may be a very condensed representation of the original observations; e.g., the sufficient statistic for a sequence of Bernoulli draws (coinflips) is simply the total number of heads.
  • Natural parameters $\eta$: these index a specific distribution within the family under consideration. They are "natural" because they directly multiply the sufficient statistic, and they have other nice properties we will explore later. They do not in general correspond to the parameters typically used for a given distribution; e.g., for a multivariate Gaussian with sufficient statistics $[\mathbf{x}, \mathbf{x}\mathbf{x}^T]$ the natural parameters are the precision-scaled mean $\Sigma^{-1}\mu$ and the scaled precision matrix $-\frac{1}{2}\Sigma^{-1}$, as opposed to the more common parameterization in terms of the mean $\mu$ and covariance $\Sigma$ (which are closely related to the mean parameters discussed below).
  • Base measure: the $h(\mathbf{x})$ term represents a base measure on $\mathbf{x}$. The important condition is that it does not vary with the parameters $\eta$; otherwise we could write any distribution family as an exponential family by absorbing its density into $h(\mathbf{x})$. In particular, any single distribution is an exponential family in the trivial sense that we can smuggle it in as a base measure; this reemphasizes that exponential-family machinery is about parameterized families of distributions, e.g., the family of all Normal distributions. The base measure is not necessarily a probability measure, i.e., it need not normalize. Sometimes it is factored into a literal "base measure" $\nu(\mathbf{x})$ representing something like Lebesgue or counting measure, and another term $h(\mathbf{x})$ representing a density atop that base measure, but these can always be combined into a single measure.
  • Log normalizer $\mathcal{A}(\eta)$: also called the cumulant function. This is defined to normalize the distribution, i.e., we always have
    $$\mathcal{A}(\eta) = \log \int d\mathbf{x}\, h(\mathbf{x}) \exp( \langle \eta, t(\mathbf{x}) \rangle ).$$
    A "tractable" exponential family is, informally, one for which this integral has a closed-form solution. Note that, for a given exponential family defined by sufficient statistics $t(\mathbf{x})$ and base measure $h(\mathbf{x})$, this integral may be finite only for some values of $\eta$; these values define the natural parameter space.
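To make the four ingredients concrete, here is a minimal sketch for the Poisson family, assuming its standard exponential-family representation ($h(x) = 1/x!$, $t(x) = x$, $\eta = \log\lambda$, $\mathcal{A}(\eta) = e^\eta$). The helper names are hypothetical; the sketch evaluates the log normalizer by truncated summation, compares it with the closed form, and checks the resulting density against the usual Poisson pmf.

```python
import math


def base_measure(x):
    """h(x) = 1 / x! for the Poisson family."""
    return 1.0 / math.factorial(x)


def suff_stat(x):
    """t(x) = x."""
    return float(x)


def log_normalizer(eta):
    """Closed form A(eta) = exp(eta), where eta = log(rate)."""
    return math.exp(eta)


def log_normalizer_by_summation(eta, x_max=100):
    """A(eta) = log sum_x h(x) exp(eta * t(x)), with the infinite sum truncated at x_max."""
    total = sum(base_measure(x) * math.exp(eta * suff_stat(x)) for x in range(x_max + 1))
    return math.log(total)


def log_density(x, eta):
    """log p(x) = log h(x) + eta * t(x) - A(eta)."""
    return math.log(base_measure(x)) + eta * suff_stat(x) - log_normalizer(eta)


rate = 3.0
eta = math.log(rate)
print(log_normalizer_by_summation(eta), log_normalizer(eta))  # both approximately 3.0
# The exponential-family density agrees with the usual pmf rate^x e^{-rate} / x!.
print(math.exp(log_density(2, eta)), rate**2 * math.exp(-rate) / math.factorial(2))
```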

An exponential family is regular if the natural parameter space is a nonempty open set. (TODO: when does this not happen?) It is minimal if there are no linear dependencies among the natural parameters or among the sufficient statistics (with respect to the base measure). For example, $t(\mathbf{x}) = [\mathbf{x}, 1-\mathbf{x}]^T$ would not be the sufficient statistic of a minimal family, since $\mathbf{a}^T t(\mathbf{x}) = 1$ for every scalar $\mathbf{x}$ when $\mathbf{a} = [1, 1]^T$. Non-minimal families can always be reduced to minimal families by suitable reparameterization.
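A small sketch of what non-minimality costs, using the $t(\mathbf{x}) = [\mathbf{x}, 1-\mathbf{x}]^T$ example for a binary $\mathbf{x}$ (function names are hypothetical): shifting $\eta$ along $\mathbf{a} = [1, 1]^T$ leaves the distribution unchanged, and the covariance of $t(\mathbf{x})$ is singular. Only $\eta_0 - \eta_1$ is identified, and reparameterizing in terms of it gives the usual minimal (logit) Bernoulli family.

```python
import math


def suff_stat(x):
    """Non-minimal sufficient statistic t(x) = [x, 1 - x] for x in {0, 1}."""
    return (float(x), 1.0 - x)


def prob(x, eta):
    """p_eta(x) = exp(<eta, t(x)> - A(eta)), with A(eta) = log(exp(<eta, t(0)>) + exp(<eta, t(1)>))."""
    def score(y):
        return sum(e * s for e, s in zip(eta, suff_stat(y)))
    log_norm = math.log(math.exp(score(0)) + math.exp(score(1)))
    return math.exp(score(x) - log_norm)


eta = (0.3, -1.2)
shifted = (eta[0] + 5.0, eta[1] + 5.0)  # move along a = [1, 1]
print(prob(1, eta), prob(1, shifted))  # equal: eta is not identifiable

# Covariance of t(x) under p_eta; a = [1, 1] lies in its null space because a^T t(x) = 1.
probs = {x: prob(x, eta) for x in (0, 1)}
mean = [sum(p * suff_stat(x)[k] for x, p in probs.items()) for k in range(2)]
cov = [
    [sum(p * (suff_stat(x)[i] - mean[i]) * (suff_stat(x)[j] - mean[j]) for x, p in probs.items()) for j in range(2)]
    for i in range(2)
]
print(cov[0][0] * cov[1][1] - cov[0][1] * cov[1][0])  # determinant is (numerically) zero

# Only eta_0 - eta_1 is identified: p(x = 1) = sigmoid(eta_0 - eta_1).
print(prob(1, eta), 1.0 / (1.0 + math.exp(-(eta[0] - eta[1]))))
```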

Exponential Family representations of common distributions

Properties of Exponential Families

The Mean Parameterization

We've introduced exponential families in their natural parameterization, i.e., written as a function of the natural parameter $\eta$. The natural parameter is "the coefficient on the sufficient statistics in the log-density", i.e., it's what you'd get by taking gradients of the log-density with respect to the sufficient statistics. (This fact will become useful in automatic conjugacy detection.) For many distributions, it's more common to use the mean parameterization, given by the expected sufficient statistics:

$$\mu_\eta = E_{p_\eta}[t(\mathbf{x})]$$

For example, the family of Gaussian distributions of varying mean and covariance has sufficient statistics $t(\mathbf{x}) = [\mathbf{x}, \mathbf{x}\mathbf{x}^T]$, so its expected sufficient statistics are

$$\mu_\eta = \left[E[\mathbf{x}], E[\mathbf{x}\mathbf{x}^T]\right] = [\mu, \Sigma + \mu\mu^T],$$

i.e., the standard mean-and-variance parameterization, with the minor difference that the second moment is uncentered. We write an exponential family in its mean parameterization as

$$p_\mu(\mathbf{x}) = h(\mathbf{x}) \exp\left(t(\mathbf{x})^T \eta(\mu) - A_\mu(\mu) \right)$$

where $\eta(\mu)$ denotes the mapping from mean parameters to natural parameters (it's not yet clear what this mapping is; we will explore it further below) and

$$A_\mu(\mu) = \log \int h(\mathbf{x}) \exp\left(t(\mathbf{x})^T \eta(\mu) \right) d \mathbf{x}$$

is the log normalizer expressed in terms of mean parameters.
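For the 1-D Gaussian with $t(x) = [x, x^2]$, the map between natural parameters $(\mu/\sigma^2, -1/(2\sigma^2))$ and mean parameters $(\mu, \sigma^2 + \mu^2)$ can be written down by hand in both directions. The sketch below (hypothetical helper names, arbitrary $\mu$ and $\sigma$) checks the round trip and confirms by Monte Carlo that the mean parameters are the expected sufficient statistics.

```python
import random


def mean_from_natural(eta1, eta2):
    """1-D Gaussian with t(x) = [x, x^2]: natural (mu / s2, -1 / (2 s2)) -> mean (mu, s2 + mu^2)."""
    s2 = -1.0 / (2.0 * eta2)
    mu = eta1 * s2
    return mu, s2 + mu**2


def natural_from_mean(m1, m2):
    """Inverse map: recover mu = m1 and s2 = m2 - m1^2, then return (mu / s2, -1 / (2 s2))."""
    s2 = m2 - m1**2
    return m1 / s2, -1.0 / (2.0 * s2)


mu, sigma = 1.5, 0.5
eta = (mu / sigma**2, -1.0 / (2.0 * sigma**2))
m1, m2 = mean_from_natural(*eta)
print((m1, m2), natural_from_mean(m1, m2), eta)  # the round trip recovers eta

# Monte Carlo check that the mean parameters really are E[t(x)] = [E x, E x^2].
rng = random.Random(0)
samples = [rng.gauss(mu, sigma) for _ in range(200_000)]
print(sum(samples) / len(samples), sum(x * x for x in samples) / len(samples))  # near (1.5, 2.5)
```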

The mapping from natural to mean parameters is 1-1 for any minimal exponential family. It's easy to see that for any natural parameters $\eta$, the corresponding mean parameters are well-defined by the expectation $E_{p_\eta}[t(\mathbf{x})]$. How would we go the other way? Note that the mapping from mean to natural parameters cannot be a well-defined function if we're in a non-minimal family, i.e., if there are multiple natural-parameter settings that define the same distribution (and hence the same expectations, in particular the same expected sufficient statistics, i.e., mean parameters). We'll need some more machinery to understand this mapping; luckily the machinery is also generally useful.

First we present a general approach for computing the expected sufficient statistics $E[t(\mathbf{x})]$, i.e., the mean parameters $\mu_\eta$. Computing expectations is hard in general, because integrals are hard, but for exponential families it turns out this expectation can be computed as a derivative, specifically, the gradient of the log-normalizer. This is nice because derivatives are easy to compute.

$$\begin{align*}
\nabla_\eta A(\eta) &= \nabla_\eta \log \int h(\mathbf{x}) \exp\left(t(\mathbf{x})^T \eta \right) d\mathbf{x}\\
&= \exp(-A(\eta)) \nabla_\eta \int h(\mathbf{x}) \exp\left(t(\mathbf{x})^T \eta \right) d\mathbf{x}\\
&= \exp(-A(\eta)) \int h(\mathbf{x}) \nabla_\eta \exp\left(t(\mathbf{x})^T \eta \right) d\mathbf{x} &\text{by dominated convergence}\\
&= \exp(-A(\eta)) \int h(\mathbf{x}) \exp\left(t(\mathbf{x})^T \eta \right) t(\mathbf{x}) d\mathbf{x}\\
&= \int p_\eta(\mathbf{x}) t(\mathbf{x}) d\mathbf{x}\\
&= E_{p_\eta} [t(\mathbf{x})]
\end{align*}$$
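A quick numerical check of $\nabla_\eta A(\eta) = E_{p_\eta}[t(\mathbf{x})]$, sketched for a minimal three-outcome categorical family. Central finite differences stand in for automatic differentiation; the family, the parameter values, and the function names are illustrative choices.

```python
import math


# Minimal categorical over x in {0, 1, 2}: h(x) = 1, t(x) = [1{x = 1}, 1{x = 2}],
# so A(eta) = log(1 + exp(eta_1) + exp(eta_2)).
def suff_stat(x):
    return [float(x == 1), float(x == 2)]


def log_normalizer(eta):
    return math.log(sum(math.exp(sum(e * s for e, s in zip(eta, suff_stat(x)))) for x in range(3)))


def expected_suff_stats(eta):
    """E_{p_eta}[t(x)] by explicit enumeration (the 'hard' integral, easy here)."""
    a = log_normalizer(eta)
    probs = [math.exp(sum(e * s for e, s in zip(eta, suff_stat(x))) - a) for x in range(3)]
    return [sum(p * suff_stat(x)[k] for x, p in enumerate(probs)) for k in range(len(eta))]


def grad_log_normalizer(eta, eps=1e-6):
    """Central finite differences of A, standing in for automatic differentiation."""
    grad = []
    for k in range(len(eta)):
        up, down = list(eta), list(eta)
        up[k] += eps
        down[k] -= eps
        grad.append((log_normalizer(up) - log_normalizer(down)) / (2 * eps))
    return grad


eta = [0.4, -0.7]
print(grad_log_normalizer(eta))
print(expected_suff_stats(eta))  # matches the gradient up to finite-difference error
```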

For any tractable exponential family (one for which we have $A(\eta)$ in closed form), this gives us an easy recipe for converting natural parameters $\eta$ to mean parameters $\mu$. This relationship goes surprisingly deep, e.g., into connections between exponential families and graphical models. The log-normalizer of a tree-structured graphical model can be computed by one-way message passing, e.g., the forward algorithm in Kalman filters or HMMs. Taking gradients of this log-normalizer with respect to the natural parameters gives us the mean parameters, i.e., the node and edge marginals. This implies that reverse-mode autodiff through the forward pass effectively implements backwards message passing!
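A sketch of the graphical-model claim, assuming a hypothetical binary chain $p(\mathbf{x}) \propto \exp\left(\sum_i u_i x_i + J \sum_i x_i x_{i+1}\right)$ with $x_i \in \{0,1\}$: one forward pass computes $\log Z$, and differentiating it with respect to each unary natural parameter $u_i$ recovers the node marginal $P(x_i = 1)$, checked here against brute-force enumeration. Finite differences stand in for the reverse-mode autodiff that would play the role of the backward pass.

```python
import itertools
import math


def log_partition_forward(unary, coupling):
    """log Z for p(x) ∝ exp(sum_i unary[i] x_i + coupling * sum_i x_i x_{i+1}), x_i in {0, 1},
    computed with a single forward (message-passing) sweep."""
    alpha = [math.exp(unary[0] * x) for x in (0, 1)]
    for u in unary[1:]:
        alpha = [
            math.exp(u * x_next) * sum(alpha[x] * math.exp(coupling * x * x_next) for x in (0, 1))
            for x_next in (0, 1)
        ]
    return math.log(sum(alpha))


def marginals_brute_force(unary, coupling):
    """P(x_i = 1) for every i, by enumerating all 2^n configurations."""
    n = len(unary)
    weights = {}
    for xs in itertools.product((0, 1), repeat=n):
        score = sum(u * x for u, x in zip(unary, xs)) + coupling * sum(a * b for a, b in zip(xs, xs[1:]))
        weights[xs] = math.exp(score)
    z = sum(weights.values())
    return [sum(w for xs, w in weights.items() if xs[i] == 1) / z for i in range(n)]


def marginals_from_gradient(unary, coupling, eps=1e-6):
    """d log Z / d unary[i] = E[x_i] = P(x_i = 1); finite differences stand in for reverse-mode autodiff."""
    grads = []
    for i in range(len(unary)):
        up, down = list(unary), list(unary)
        up[i] += eps
        down[i] -= eps
        grads.append((log_partition_forward(up, coupling) - log_partition_forward(down, coupling)) / (2 * eps))
    return grads


unary, coupling = [0.5, -1.0, 0.2, 0.8], 0.7  # arbitrary natural parameters
print(marginals_from_gradient(unary, coupling))
print(marginals_brute_force(unary, coupling))  # same values
```

With a separate coupling per edge, the derivative with respect to each coupling would similarly give the edge statistic $E[x_i x_{i+1}]$.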

How can we go the other way? We'll need a bit more machinery.

  1. Convexity of $A(\eta)$. Strict convexity guarantees the gradient is strictly monotonic, and thus that we have a 1-1 map.
  2. Conjugate duality. What is the dual? What does it mean?
  3. The dual of $A(\eta)$ is $A^*(\mu) = \mu^T\eta(\mu) - A(\eta(\mu))$.
  4. Gradients of the dual map back to $\eta$. What does this mean? $A^*$ already includes the map $\eta(\mu)$, so how can we compute it or its gradients if we don't already know the map?

Convexity

First we see that $A(\eta)$ is convex. This follows immediately from its cumulant-generating properties. We saw above that the gradient of $A(\eta)$ gives the expected sufficient statistics. It is straightforward to verify (by differentiating under the integral once more) that the second derivative (in general, the Hessian) of $A(\eta)$ is the variance (in general, covariance) of the sufficient statistics, which is nonnegative (positive semidefinite) by definition. This establishes that $A$ is convex.
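Continuing with the same illustrative three-outcome categorical family, this sketch compares a finite-difference Hessian of $A(\eta)$ with the covariance of $t(\mathbf{x})$ computed by enumeration, and checks that the covariance is positive definite (this family is minimal). Helper names are hypothetical.

```python
import math


# Minimal categorical over x in {0, 1, 2} with t(x) = [1{x = 1}, 1{x = 2}] and h(x) = 1.
def suff_stat(x):
    return [float(x == 1), float(x == 2)]


def log_normalizer(eta):
    return math.log(1.0 + math.exp(eta[0]) + math.exp(eta[1]))


def shifted(eta, i, di, j, dj):
    out = list(eta)
    out[i] += di
    out[j] += dj
    return out


def hessian_fd(f, eta, eps=1e-4):
    """Central finite-difference Hessian of a scalar function f."""
    n = len(eta)
    return [
        [
            (f(shifted(eta, i, eps, j, eps)) - f(shifted(eta, i, eps, j, -eps))
             - f(shifted(eta, i, -eps, j, eps)) + f(shifted(eta, i, -eps, j, -eps))) / (4 * eps**2)
            for j in range(n)
        ]
        for i in range(n)
    ]


def suff_stat_covariance(eta):
    a = log_normalizer(eta)
    probs = [math.exp(sum(e * s for e, s in zip(eta, suff_stat(x))) - a) for x in range(3)]
    mean = [sum(p * suff_stat(x)[k] for x, p in enumerate(probs)) for k in range(2)]
    return [
        [sum(p * (suff_stat(x)[i] - mean[i]) * (suff_stat(x)[j] - mean[j]) for x, p in enumerate(probs)) for j in range(2)]
        for i in range(2)
    ]


eta = [0.4, -0.7]
hess, cov = hessian_fd(log_normalizer, eta), suff_stat_covariance(eta)
print(hess)
print(cov)  # matches the Hessian up to finite-difference error
# 2x2 positive definiteness: positive leading entry and positive determinant.
print(cov[0][0] > 0 and cov[0][0] * cov[1][1] - cov[0][1] ** 2 > 0)
```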

We additionally see that $A$ is strictly convex iff the family is minimal. Suppose that $A$ is not strictly convex, i.e., the covariance $\Sigma$ admits some nonzero vector $\mathbf{a}$ such that $\mathbf{a}^T\Sigma\mathbf{a} = 0$. Note that $\mathbf{a}^T \Sigma \mathbf{a} = \text{var}_\eta[\mathbf{a}^T t(\mathbf{x})]$; for this variance to be zero, $\mathbf{a}^T t(\mathbf{x})$ must be constant a.e., but this contradicts the definition of minimality! Conversely, if $A$ is strictly convex, then no such $\mathbf{a}$ can exist, which implies the family is minimal.

We immediately see that the mapping from natural to mean parameters is 1-1 if and only if we are in a minimal family: if the Hessian is always positive definite, the gradient is strictly increasing (this can be made more precise in a multidimensional sense; see Wainwright and Jordan). How do we get the inverse mapping? The answer runs through conjugate duality.

Conjugate duality

The conjugate dual of a function $A(\eta)$ is defined as

$$A^*(\mu) := \sup_{\eta} \left\{ \mu^T \eta - A(\eta) \right\}$$

Conjugate duals are defined even on non-convex functions, but the conjugate dual is always convex (see convex duality).

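In the context of exponential families, the right side has a probabilistic interpretation. It can be interpreted as computing a maximum-likelihood estimate of $\eta$ given observed (average) sufficient statistics $\mu$: setting the gradient to zero gives the moment-matching condition $\nabla A(\eta^*) = \mu$. The optimal value is the negative entropy of the distribution with parameter $\eta^*$ (measured relative to the base measure $h$).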
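For the Bernoulli family ($h(x) = 1$, $t(x) = x$, $A(\eta) = \log(1 + e^\eta)$) the dual can be computed numerically. The sketch below uses Newton's method (an illustrative choice; names are hypothetical) to maximize $\mu\eta - A(\eta)$, and checks that the maximizer is $\operatorname{logit}(\mu)$, that the optimal value is the negative entropy $\mu\log\mu + (1-\mu)\log(1-\mu)$, and that the derivative of $A^*$ with respect to $\mu$ returns $\eta^*$. This speaks to item 4 of the list above: evaluating $A^*$ requires solving for $\eta(\mu)$, and the gradient of $A^*$ then hands that solution back.

```python
import math


def log_normalizer(eta):
    """Bernoulli with h(x) = 1, t(x) = x: A(eta) = log(1 + e^eta)."""
    return math.log1p(math.exp(eta))


def sigmoid(eta):
    return 1.0 / (1.0 + math.exp(-eta))


def conjugate_dual(mu, iters=50):
    """A*(mu) = sup_eta [mu * eta - A(eta)], maximized by Newton's method.

    The objective is concave because A is convex; its stationarity condition
    A'(eta) = mu is exactly moment matching. Returns (value, maximizer)."""
    eta = 0.0
    for _ in range(iters):
        s = sigmoid(eta)
        grad = mu - s          # mu - A'(eta)
        hess = -s * (1.0 - s)  # -A''(eta) = -Var[t(x)]
        eta -= grad / hess
    return mu * eta - log_normalizer(eta), eta


mu = 0.3
value, eta_star = conjugate_dual(mu)
print(eta_star, math.log(mu / (1 - mu)))  # the maximizer is logit(mu): the mean-to-natural map
print(value, mu * math.log(mu) + (1 - mu) * math.log(1 - mu))  # the value is the negative entropy

# The gradient of the dual maps back to natural parameters: dA*/dmu = eta*(mu).
h = 1e-5
print((conjugate_dual(mu + h)[0] - conjugate_dual(mu - h)[0]) / (2 * h), eta_star)
```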

Conjugacy

Bayesian inference is the generally correct means of updating beliefs represented by probability distributions. We begin with a prior belief $p(\mathbf{z})$ about some aspect of the world's state encoded in the variable $\mathbf{z}$, and want to update this belief in light of observed evidence $\mathbf{x}$. To do this we specify a likelihood $p(\mathbf{x}|\mathbf{z})$ giving the probability of the observed $\mathbf{x}$ given an explanation $\mathbf{z}$, and combine this with the prior via Bayes' rule

$$p(\mathbf{z}|\mathbf{x}) \propto p(\mathbf{x}|\mathbf{z}) p(\mathbf{z})$$

to produce a posterior distribution $p(\mathbf{z}|\mathbf{x})$ representing our new belief given the observed evidence. In general the posterior density will not have any 'nice' closed form, even if the prior and likelihood do, because the implicit normalizing constant given by the integral

$$p(\mathbf{x}) = \int d\mathbf{z}\; p(\mathbf{x}|\mathbf{z}) p(\mathbf{z})$$

may not be solvable analytically. This is a problem when implementing Bayesian inference on computers, where a system must commit to a particular representation of beliefs (some concrete data structure) and incorporate evidence into that fixed belief representation. Bayesian updating only works on real computers when the posterior admits the same representation as the prior.

(Typical developments, including this one, focus on the case where our data structure is a closed-form probability density function. However, this is quite artificial; more generally we could consider belief distributions represented by implicit models defined by sampling processes, by families of algorithms for computing densities (closed-form densities are a special case of such algorithms, but we could consider, e.g., all polynomial-time algorithms) or for directly computing measures (not necessarily by integrating over a density), or by ??? (an interesting research question: the brain does not pass around closed-form densities). Also, strictly speaking, it is exact Bayesian updating that only works in this circumstance. Human brains almost certainly do not perform exact Bayesian updating or maintain correct Bayesian posterior beliefs, and in general there is no reason to expect that exact Bayesian updating on a conjugate model will perform better on a task of interest than approximate inference in a more appropriate model.)

For a given likelihood $p(\mathbf{x}|\mathbf{z})$, a conjugate prior is a family of distributions $p_\theta(\mathbf{z})$, which we will take to be parameterized by $\theta$, such that posterior beliefs under that likelihood remain in the same family as the prior; that is, we have $p(\mathbf{z}|\mathbf{x}) = p_{\theta'}(\mathbf{z})$ for some parameter $\theta'$. There may be many conjugate priors for a given likelihood. For example, the family of all probability distributions is conjugate to every likelihood, but this is not particularly useful since we cannot represent arbitrary probability distributions on computers.

In the case where the likelihood is in the exponential family, with natural parameters determined by some link function $\eta(\mathbf{z})$ of the conditioning information,

$$p(\mathbf{x}|\mathbf{z}) = h(\mathbf{x}) \exp\left(\langle \eta (\mathbf{z}), t(\mathbf{x}) \rangle - \mathcal{A}(\eta(\mathbf{z})) \right)$$

we can define an exponential-family prior whose sufficient statistics are, by construction, exactly the information about $\mathbf{z}$ used to parameterize the likelihood:

$$p_\theta(\mathbf{z}) = h(\mathbf{z}) \exp\left(\langle \theta, \eta(\mathbf{z}) \rangle - \mathcal{A}(\theta) \right).$$

This turns out to be a conjugate prior, as we can check:

$$\begin{align*}
p(\mathbf{z} | \mathbf{x}) &\propto p(\mathbf{x}|\mathbf{z}) p_\theta(\mathbf{z}) \\
&= h(\mathbf{x}) h(\mathbf{z}) \exp\left(\langle \eta (\mathbf{z}), t(\mathbf{x}) \rangle + \langle \theta, \eta(\mathbf{z}) \rangle - \mathcal{A}(\theta) - \mathcal{A}(\eta(\mathbf{z})) \right)\\
&\propto h(\mathbf{z}) \exp\left(\langle \theta + t(\mathbf{x}), \eta (\mathbf{z}) \rangle - \mathcal{A}(\eta(\mathbf{z})) \right)\\
&= h'(\mathbf{z})\exp\left(\langle \theta', \eta (\mathbf{z}) \rangle \right)
\end{align*}$$

with new parameter $\theta' = \theta + t(\mathbf{x})$ and new base measure $h'(\mathbf{z}) = h(\mathbf{z}) \exp(-\mathcal{A}(\eta(\mathbf{z})))$. As we incorporate additional points $\mathbf{x}_1, \ldots, \mathbf{x}_N$, we get the posterior

$$p(\mathbf{z} | \mathbf{x}_{1:N}) \propto h_N(\mathbf{z}) \exp\left(\left\langle \theta_N, \eta (\mathbf{z}) \right\rangle \right)$$

where $\theta_N = \theta + \sum_{i=1}^N t(\mathbf{x}_i)$ aggregates the sufficient statistics of all data points with the prior parameters $\theta$ (which we can view as representing the sufficient statistics of "hallucinated" prior data), and $h_N(\mathbf{z}) = h(\mathbf{z}) \exp(-N \mathcal{A}(\eta(\mathbf{z})))$ enforces concentration of the base measure as we observe additional data.
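For a Bernoulli likelihood this construction gives the Beta family: with $\eta(z) = \operatorname{logit} z$ and $\mathcal{A}(\eta(z)) = -\log(1-z)$, a $\mathrm{Beta}(a, b)$ density is proportional to $\exp\left((a-1)\operatorname{logit} z\right)(1-z)^{a+b-2}$, so $\theta = a - 1$ accumulates the number of heads while the factor $(1-z)^N$ in $h_N$ accounts for the number of flips. The sketch below uses the more familiar pseudo-count form (add $[\sum_i x_i, \sum_i (1 - x_i)]$ to $[a, b]$) and checks it against brute-force normalization on a grid; the prior, data, and function names are made up for illustration.

```python
import math


def beta_bernoulli_update(a, b, flips):
    """Posterior Beta parameters: prior pseudo-counts plus summed sufficient statistics [x, 1 - x]."""
    heads = sum(flips)
    return a + heads, b + len(flips) - heads


def grid_posterior_mean(a, b, flips, n_grid=20_000):
    """Brute force: multiply the Beta(a, b) prior by the Bernoulli likelihood on a grid and normalize."""
    zs = [(k + 0.5) / n_grid for k in range(n_grid)]
    log_w = []
    for z in zs:
        log_prior = (a - 1) * math.log(z) + (b - 1) * math.log(1 - z)
        log_lik = sum(x * math.log(z) + (1 - x) * math.log(1 - z) for x in flips)
        log_w.append(log_prior + log_lik)
    top = max(log_w)
    w = [math.exp(lw - top) for lw in log_w]
    return sum(z * wi for z, wi in zip(zs, w)) / sum(w)


a, b = 2.0, 3.0                         # hypothetical prior
flips = [1, 0, 1, 1, 0, 1, 1, 0, 1, 1]  # hypothetical data: 7 heads, 3 tails
a_post, b_post = beta_bernoulli_update(a, b, flips)
print(a_post, b_post)  # 9.0 6.0
print(a_post / (a_post + b_post), grid_posterior_mean(a, b, flips))  # both approximately 0.6
```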

Note that the conjugate families produced by this generic procedure may not be tractable. For example, let $p(\mathbf{x} | \mathbf{z}) = \mathcal{N}( f(\mathbf{z}), I)$ for $d$-dimensional $\mathbf{x}$, where the link function $f$ is a neural decoder. Expanding the Gaussian log-density, we have

$$p(\mathbf{x} | \mathbf{z}) = h(\mathbf{x}) \exp\left( \langle f(\mathbf{z}), \mathbf{x} \rangle - \mathcal{A}(f(\mathbf{z})) \right), \qquad h(\mathbf{x}) = (2\pi)^{-d/2} \exp\left(-\tfrac{1}{2}\mathbf{x}^T\mathbf{x}\right), \quad \mathcal{A}(f(\mathbf{z})) = \tfrac{1}{2}\|f(\mathbf{z})\|^2$$

and the conjugate prior on $\mathbf{z}$ has sufficient statistics given by the decoder output $f(\mathbf{z})$,

$$p(\mathbf{z}) = \exp\left(\langle \eta , f(\mathbf{z}) \rangle - \mathcal{A}(\eta)\right).$$

Even if this normalizes (TODO: does this always normalize?), the log-normalizer $\mathcal{A}(\eta) = \log \int d\mathbf{z} \exp\left(\langle \eta , f(\mathbf{z}) \rangle\right)$ involves integrating through a neural net, and this will not in general have any nice closed form.

In general, a tractable conjugate prior requires the normalization constant to be easily evaluated, which occurs only when the link function is (TODO: what are the special cases where conjugate priors do have nice forms?)

Characterization of conjugacy

I believe there is a true result that any likelihood function that admits a finite-dimensional conjugate prior, and whose support does not depend on the parameter, must be an exponential family. (Possible citation: Diaconis and someone, 1978ish.)

Conditionally conjugate models

A directed model of many variables $\mathbf{x}_{1:M}$ is considered conditionally conjugate if all of the complete conditionals $p(\mathbf{x}_i | \mathbf{x}_{\neg i})$ are in the exponential family. This will occur if the following hold for each variable $\mathbf{x}_i$ (these conditions are sufficient but maybe not necessary; I worked them out myself but someone must have written about this. Is there a broader characterization?):

  1. its parent-conditional distribution $p(\mathbf{x}_i | \Pi( \mathbf{x}_i))$ is in the exponential family, i.e.,
    $$p(\mathbf{x}_i | \Pi( \mathbf{x}_i)) \propto h(\mathbf{x}_i) \exp( \langle \eta(\Pi( \mathbf{x}_i) ), t_i(\mathbf{x}_i)\rangle )$$
    for some sufficient statistic $t_i(\mathbf{x}_i)$ and parameter $\eta(\Pi( \mathbf{x}_i) )$, and
  2. all edges to children $p(\mathbf{x}_j | \mathbf{x}_i)$ are exponential-family likelihoods to which the parent-conditional distribution is conjugate, i.e., are of the form
    $$p(\mathbf{x}_j | \mathbf{x}_i , \ldots) \propto h(\mathbf{x}_j) \exp\left( \langle t_i(\mathbf{x}_i) + \eta(\ldots), t_j(\mathbf{x}_j) \rangle \right)$$
    where $\eta(\ldots)$ represents contributions to the natural parameter from other parents of $\mathbf{x}_j$; treating these as fixed, we are left with an exponential-family likelihood for $\mathbf{x}_j | \mathbf{x}_i$. (In general we just need the likelihood to be conditionally in the exponential family given the other parents. It's sufficient to have this form, where the likelihood sums natural parameters across parents, but is it necessary? In general we could have arbitrary likelihoods that become conditionally exponential-family, e.g., one parent is a discrete variable that switches whether the other conditional is Gaussian/gamma/whatever. How does this relate to condition 1? In this case we'd have parent-conditionals in the exponential family, but the parents could determine not just the parameter but also the sufficient statistic function.)

In this case we get many nice properties:

  1. Gibbs sampling updates are straightforward.
  2. We can perform variational message passing, which updates an exponential-family factor at each variable as a function of exponential-family factors at neighboring variables. (Does this require additional conditions on the link functions to be tractable?)
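As an illustration of item 1, here is a sketch of a Gibbs sampler for a hypothetical model with unknown mean and precision, $x_i \sim \mathcal{N}(\mu, 1/\tau)$, $\mu \sim \mathcal{N}(m_0, 1/k_0)$, $\tau \sim \mathrm{Gamma}(a_0, b_0)$ (rate parameterization). The priors are independent, so the joint posterior is not in a standard family, but each complete conditional is Normal or Gamma and can be sampled exactly. The hyperparameters, synthetic data, and function name are arbitrary choices.

```python
import random


def gibbs_normal_mean_precision(data, n_iter=5000, m0=0.0, k0=1e-2, a0=1.0, b0=1.0, seed=0):
    """Gibbs sampler for x_i ~ N(mu, 1/tau) with independent priors
    mu ~ N(m0, 1/k0) and tau ~ Gamma(a0, rate=b0).

    Both complete conditionals are in the exponential family (Normal and Gamma),
    so each update is an exact draw from a standard distribution."""
    rng = random.Random(seed)
    n, total = len(data), sum(data)
    mu, tau = 0.0, 1.0
    samples = []
    for _ in range(n_iter):
        # mu | tau, data ~ Normal: prior and likelihood precisions add.
        k_n = k0 + n * tau
        m_n = (k0 * m0 + tau * total) / k_n
        mu = rng.gauss(m_n, k_n ** -0.5)
        # tau | mu, data ~ Gamma(a0 + n/2, rate = b0 + sum (x - mu)^2 / 2).
        rate = b0 + 0.5 * sum((x - mu) ** 2 for x in data)
        tau = rng.gammavariate(a0 + 0.5 * n, 1.0 / rate)  # gammavariate takes (shape, scale)
        samples.append((mu, tau))
    return samples


data_rng = random.Random(1)
data = [data_rng.gauss(2.0, 0.5) for _ in range(200)]  # synthetic data, true tau = 4
kept = gibbs_normal_mean_precision(data)[500:]         # discard burn-in
post_mu = sum(m for m, _ in kept) / len(kept)
post_tau = sum(t for _, t in kept) / len(kept)
print(post_mu, post_tau)  # close to the sample mean and 1 / sample variance
```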

Multi-affine representation

For simplicity, consider a chain-structured model

$$p(\mathbf{x}_1, \ldots, \mathbf{x}_N) = \prod_i p(\mathbf{x}_i | \mathbf{x}_{i-1})$$

that is conditionally conjugate, i.e., the parent-conditional of each node is a conjugate prior for its outgoing edge. Then, suppressing the base measure, we can write the model in the form

$$\begin{align*}
\prod_{i=1}^N p(\mathbf{x}_i | \mathbf{x}_{i-1}) &= \prod_{i=1}^N \exp\left(\langle \eta_{i}(\mathbf{x}_{i-1}), t_i(\mathbf{x}_i) \rangle - \mathcal{A}_i(\eta_i(\mathbf{x}_{i-1}))\right)\\
&= \prod_{i=1}^N \exp\left(\langle t_{i-1}(\mathbf{x}_{i-1}), t_i(\mathbf{x}_i) \rangle - \mathcal{A}_i(t_{i-1}(\mathbf{x}_{i-1}))\right)
\end{align*}$$

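where the identity $\eta_i(\mathbf{x}_{i-1}) = t_{i-1}(\mathbf{x}_{i-1})$ encodes our assumption of conditional conjugacy, i.e., the natural parameters of a child are exactly the sufficient statistics of its parent (for the root, $\eta_1 = \eta_0$ is a fixed parameter). Pushing the product inside the exponential, and assuming each log-normalizer term $\mathcal{A}_i(t_{i-1}(\mathbf{x}_{i-1}))$ is itself affine in $t_{i-1}(\mathbf{x}_{i-1})$ (e.g., because it is included among the sufficient statistics of $\mathbf{x}_{i-1}$) so that it can be absorbed into the inner products, we find a multi-affine function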

$$p(\mathbf{x}_1, \ldots, \mathbf{x}_N) \propto \exp\left(\langle \eta_0, t_1(\mathbf{x}_1)\rangle + \sum_{i=2}^N \langle t_{i-1}(\mathbf{x}_{i-1}), t_i(\mathbf{x}_i) \rangle\right),$$

i.e., although the quantity $f(t_1(\mathbf{x}_1), \ldots, t_N(\mathbf{x}_N))$ inside the exponential is not jointly affine in all the sufficient statistics (because there are second-order interaction terms), it is individually affine in any one of them, in the same sense that $f(a,b,c,d) = ab + bc + cd$ is an affine function of any of its arguments when the others are held fixed. Any multi-affine function whose arguments live in vector spaces can be written as a linear function of the tensor product of its arguments once each argument is augmented with a constant entry $1$, so we have

$$p(\mathbf{x}_1, \ldots, \mathbf{x}_N) \propto \exp\left(\langle \eta, t_1(\mathbf{x}_1) \otimes \cdots \otimes t_N(\mathbf{x}_N) \rangle\right)$$

where $t_1(\mathbf{x}_1) \otimes \cdots \otimes t_N(\mathbf{x}_N)$ is the tensor (outer) product of the (constant-augmented) sufficient statistics.
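The tensor-product rewrite can be checked numerically on the scalar example $f(a,b,c,d) = ab + bc + cd$. This sketch (assuming NumPy; names are hypothetical) augments each argument to $[1, v]$, writes $f$ as a linear function of the outer product of the augmented vectors with a sparse coefficient tensor, and spot-checks that $f$ is affine in $a$ with the other arguments held fixed.

```python
import numpy as np


def f(a, b, c, d):
    """The multi-affine example from the text."""
    return a * b + b * c + c * d


# Coefficient tensor acting on augmented vectors [1, a], [1, b], [1, c], [1, d]:
# index 1 picks the variable, index 0 picks the constant 1.
W = np.zeros((2, 2, 2, 2))
W[1, 1, 0, 0] = 1.0  # a * b
W[0, 1, 1, 0] = 1.0  # b * c
W[0, 0, 1, 1] = 1.0  # c * d


def f_via_tensor(a, b, c, d):
    augmented = [np.array([1.0, v]) for v in (a, b, c, d)]
    outer = np.einsum("i,j,k,l->ijkl", *augmented)  # tensor product of augmented vectors
    return float(np.sum(W * outer))                 # linear in that tensor product


vals = np.random.default_rng(0).normal(size=4)
print(f(*vals), f_via_tensor(*vals))  # equal

# Affine in a with (b, c, d) fixed: f at a convex combination equals the combination of values.
a1, a2, lam = 0.3, -2.0, 0.25
b, c, d = vals[1:]
print(f(lam * a1 + (1 - lam) * a2, b, c, d), lam * f(a1, b, c, d) + (1 - lam) * f(a2, b, c, d))
```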

Detecting conjugacy

Suppose we have a model in which the complete conditional for a particular variable $\mathbf{x}_i$ is in the exponential family. We might want to detect this, and construct the complete conditional, by examining the computation graph for the model's log joint density. Recall that conditionals (and in particular, the complete conditional for $\mathbf{x}_i$) are proportional to the joint density. So, examining the log joint density, we want to identify a set of sufficient statistics (functions of $\mathbf{x}_i$) such that the only interactions between $\mathbf{x}_i$ and other model variables are linear in the sufficient statistics. The natural parameters of the complete conditional will then be given by the coefficients of the identified sufficient statistics in the log-joint.

For example, if the log-joint density contains a term $\sin(\mathbf{x}_j) \log \mathbf{x}_i$, we can write this as the sufficient statistic $t(\mathbf{x}_i) = \log \mathbf{x}_i$ with natural parameter $\eta(\mathbf{x}_{\neg i}) = \sin(\mathbf{x}_j)$. On the other hand, if the log-joint contains a term $\sin(\mathbf{x}_i\mathbf{x}_j)$, then we have a nonlinear interaction and the complete conditional is not in the exponential family. More commonly we might encounter a log-joint term of the form $-(\mathbf{x}_j - \tanh(A\mathbf{x}_i))^2$ (arising from a Gaussian likelihood for $\mathbf{x}_j$ with mean defined by a neural-net link function of $\mathbf{x}_i$); expanding the square, we can identify sufficient statistics $\tanh(A\mathbf{x}_i)$ and $\tanh(A\mathbf{x}_i)^2$, but these do not correspond to any tractable exponential family, i.e., an exponential family that we know how to normalize.

Edward implements a version of this approach for computation graphs defined in TensorFlow. Specifically, it

  1. extracts the computation graph for the log-joint with respect to the Markov blanket of $\mathbf{x}_i$;
  2. inspects the graph and does some symbolic algebra voodoo (TODO: understand the voodoo; can we improve it?) to identify sufficient statistics of $\mathbf{x}_i$. The process might fail at this point if the algebra system fails to find the appropriate rewrite that exposes sufficient statistics;
  3. looks up these sufficient statistics in a dictionary of tractable exponential families, to get the form of the complete conditional (e.g., normal, beta, Dirichlet, etc.). The process might fail at this point if the relevant exponential family is not registered with the system;
  4. extracts natural parameters by evaluating the gradient of the log-joint with respect to the sufficient statistics. This works because the log-joint is (by assumption) linear in the sufficient statistics, so taking the gradient gives the coefficients, i.e., the natural parameters. I believe this could in principle have been done by the symbolic algebra system (trivially in some sense; gradients are symbolic algebra), but I guess it's convenient to exploit automatic differentiation when available;
  5. constructs the appropriate distribution using its natural parameters. This requires a constructor that accepts natural parameters.
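Steps 4 and 5 can be sketched by hand for a hypothetical Normal model with unknown precision $\tau$ and a Gamma prior. Here the rewrite of step 2 is done manually: the log-joint is written as an explicit function of the Gamma sufficient statistics $(\log\tau, \tau)$, its gradient with respect to them (by finite differences) gives the natural parameters $(\text{shape} - 1, -\text{rate})$, and the Gamma is rebuilt from those. This is not Edward's implementation or API, just an illustration of the idea; all names and values are assumptions.

```python
import math


def log_joint_in_suff_stats(log_tau, tau, data, mu, a0, b0):
    """Terms of log p(data, tau | mu) for x_i ~ N(mu, 1/tau), tau ~ Gamma(a0, rate=b0), written
    as a function of the Gamma sufficient statistics (log tau, tau). This hand-written form
    stands in for the symbolic rewrite of step 2; tau-free constants are included on purpose."""
    n = len(data)
    log_prior = (a0 - 1) * log_tau - b0 * tau + a0 * math.log(b0) - math.lgamma(a0)
    sq = sum((x - mu) ** 2 for x in data)
    log_lik = 0.5 * n * log_tau - 0.5 * tau * sq - 0.5 * n * math.log(2 * math.pi)
    return log_prior + log_lik


def natural_params_by_gradient(fn, point, eps=1e-5):
    """Step 4: fn is linear in the sufficient statistics, so its gradient is the coefficient vector."""
    grads = []
    for k in range(len(point)):
        up, down = list(point), list(point)
        up[k] += eps
        down[k] -= eps
        grads.append((fn(*up) - fn(*down)) / (2 * eps))
    return grads


data, mu, a0, b0 = [1.2, 0.7, 1.9, 1.4], 1.0, 2.0, 1.5  # hypothetical values


def log_joint(log_tau, tau):
    return log_joint_in_suff_stats(log_tau, tau, data, mu, a0, b0)


eta1, eta2 = natural_params_by_gradient(log_joint, [0.0, 1.0])  # any evaluation point works
# Step 5: construct the Gamma complete conditional from its natural parameters (shape - 1, -rate).
shape, rate = eta1 + 1.0, -eta2
print(shape, rate)
print(a0 + len(data) / 2, b0 + 0.5 * sum((x - mu) ** 2 for x in data))  # textbook closed form
```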

Thoughts on bringing this directly into TF:

  • we would need to add to each distribution (or perhaps somewhere else) some exponential family properties:
      • sufficient statistics, in a form available to a symbolic algebra system (also note that some distributions might have multiple exponential family representations with different parameterizations, e.g., normal with fixed vs. unknown variance);
      • constructors that accept natural parameters.
  • some sort of symbolic algebra to extract sufficient statistics.

SAVI

If the energy function $\log p(\mathbf{x})$ is multi-affine, we get some nice properties.

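  1. Tractable expectations: if $q$ factorizes across the $\mathbf{x}_i$, then for the multi-affine log-density $f$ we have $\mathbb{E}_q[f(t_1(\mathbf{x}_1), \ldots, t_N(\mathbf{x}_N))] = f(\mathbb{E}_q[t_1(\mathbf{x}_1)], \ldots, \mathbb{E}_q[t_N(\mathbf{x}_N)])$, since the expectation of a product of independent factors is the product of their expectations.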
  2. "Conjugate proximal operators" TODO UNDERSTAND THIS
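A small check of item 1 on the same $f(a,b,c,d) = ab + bc + cd$, treating each variable as its own sufficient statistic and assuming a fully factorized $q$ over a few hypothetical discrete values: the exact expectation equals $f$ evaluated at the means, while a function that is not multi-affine (here $a^2$) breaks the identity.

```python
import itertools


def f(a, b, c, d):
    """Multi-affine example from above."""
    return a * b + b * c + c * d


# A hypothetical fully factorized q: each variable independently takes a few values.
q = [
    {0.0: 0.2, 1.0: 0.5, 3.0: 0.3},
    {-1.0: 0.6, 2.0: 0.4},
    {0.5: 0.5, 1.5: 0.5},
    {-2.0: 0.1, 0.0: 0.3, 4.0: 0.6},
]


def exact_expectation(func, factors):
    """E_q[func] by enumerating the product distribution."""
    total = 0.0
    for combo in itertools.product(*(factor.items() for factor in factors)):
        prob = 1.0
        for _, p in combo:
            prob *= p
        total += prob * func(*(v for v, _ in combo))
    return total


means = [sum(v * p for v, p in factor.items()) for factor in q]
print(exact_expectation(f, q), f(*means))  # equal: plugging in means is exact for multi-affine f
# Not multi-affine (quadratic in a): plugging in the mean is no longer exact.
print(exact_expectation(lambda a, b, c, d: a * a, q), means[0] ** 2)
```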

Natural Gradient VI

Natural Gradient

Gradient ascent doesn't really type-check. The standard update

$$\mathbf{w}^{(t+1)} = \mathbf{w}^{(t)} + \alpha \nabla_\mathbf{w} f(\mathbf{w}^{(t)})$$

adds a vector to a gradient. But gradients are not members of the original vector space; they're members of the dual (cotangent) space. (TODO: better understand what goes wrong here. Is it that we're choosing a basis?)

The gradient-ascent update is equivalent to the maximization problem

$$\mathbf{w}^{(t+1)} = \arg\max_{\mathbf{w}} \mathbf{w}^T \left[\nabla_\mathbf{w} f(\mathbf{w}^{(t)}) \right] - \frac{1}{2 \alpha} \|\mathbf{w} - \mathbf{w}^{(t)}\|^2;$$

this can be verified by taking the derivative and setting it to zero. Viewing $1/\alpha$ as a Lagrange multiplier, we interpret this update as finding the step that best aligns with the gradient while remaining within a ball around the original point (whose radius grows with $\alpha$), measured in Euclidean ($\ell^2$) distance.

However, Euclidean distance is not always the most relevant metric. When $\mathbf{w}$ can be viewed as parameterizing a family of probability measures $p_\mathbf{w}$ (that is, our objective $f(\mathbf{w})$ is really defined in terms of the distribution, $f(p_\mathbf{w})$), it makes sense to think about optimizing directly over the space of distributions, rather than their parameters. Some examples of when this comes up:

  • Regression: most regression models define a distribution over predictions, sometimes implicitly (e.g., squared loss is equivalent to a Gaussian negative log-density, and cross-entropy loss is equivalent to a categorical negative log-density).
  • Variational inference: explicit optimization over a probability distribution that describes a Bayesian posterior.

We can do this by minimizing distance traveled in the Fisher metric,

$$\mathbf{w}^{(t+1)} = \arg\max_{\mathbf{w}} \mathbf{w}^T \left[\nabla_\mathbf{w} f(\mathbf{w}^{(t)}) \right] - \frac{1}{2 \alpha} (\mathbf{w} - \mathbf{w}^{(t)})^T \mathcal{F}_{\mathbf{w}^{(t)}} (\mathbf{w} - \mathbf{w}^{(t)}),$$

where

$$\mathcal{F}_{\mathbf{w}^{(t)}} = E_{\mathbf{x} \sim p_{\mathbf{w}^{(t)}}}\left[\left(\nabla_\mathbf{w} \log p_\mathbf{w}(\mathbf{x})\right) \left(\nabla_\mathbf{w} \log p_\mathbf{w}(\mathbf{x})\right)^T \right]$$

is the Fisher information matrix. The Fisher is the covariance of the score function $\nabla_\mathbf{w} \log p_\mathbf{w}(\mathbf{x})$ under the distribution $p_\mathbf{w}$ (the score has mean zero); intuitively it measures the sensitivity of the distribution to movement in parameter space. Half the Fisher metric, $\frac{1}{2}(\mathbf{w} - \mathbf{w}^{(t)})^T \mathcal{F}_{\mathbf{w}^{(t)}} (\mathbf{w} - \mathbf{w}^{(t)})$, is the second-order Taylor approximation of $KL[p_\mathbf{w} \| p_{\mathbf{w}^{(t)}}]$ (it actually doesn't matter which way the KL goes here; it's locally symmetric).
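A numerical sanity check of the local KL approximation, sketched for a Bernoulli distribution parameterized by its logit $w$ (an illustrative choice; names are hypothetical): the score is $x - \sigma(w)$, so the Fisher is $\sigma(w)(1-\sigma(w))$, and for a small step $\delta$ both directions of the KL are close to $\frac{1}{2}\mathcal{F}\delta^2$.

```python
import math


def sigmoid(w):
    return 1.0 / (1.0 + math.exp(-w))


def kl_bernoulli(p, q):
    """KL[Bernoulli(p) || Bernoulli(q)]."""
    return p * math.log(p / q) + (1 - p) * math.log((1 - p) / (1 - q))


def fisher_logit(w):
    """E[score^2] with score = x - sigmoid(w), computed by summing over x in {0, 1}."""
    s = sigmoid(w)
    return sum(p_x * (x - s) ** 2 for x, p_x in ((0, 1 - s), (1, s)))


w, delta = 0.4, 1e-3
quadratic = 0.5 * fisher_logit(w) * delta**2
print(kl_bernoulli(sigmoid(w + delta), sigmoid(w)) / quadratic)  # close to 1
print(kl_bernoulli(sigmoid(w), sigmoid(w + delta)) / quadratic)  # close to 1 (local symmetry)
print(fisher_logit(w), sigmoid(w) * (1 - sigmoid(w)))            # equal
```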

Solving the above maximization problem gives the natural gradient ascent step

$$\mathbf{w}^{(t+1)} = \mathbf{w}^{(t)} + \alpha \mathcal{F}_{\mathbf{w}^{(t)}}^{-1} \nabla_\mathbf{w} f(\mathbf{w}^{(t)}).$$

We refer to the scaled gradient $\tilde{\nabla}_\mathbf{w} f(\mathbf{w}) := \mathcal{F}_{\mathbf{w}}^{-1} \nabla_\mathbf{w} f(\mathbf{w})$ as the natural gradient. Note that if we have a stochastic gradient $\hat{\nabla}_\mathbf{w} f(\mathbf{w})$, i.e., a random variable whose expectation is the true gradient, then we can still compute an unbiased stochastic natural gradient, since multiplying by the (exact) inverse Fisher is linear and so commutes with the expectation. If we also use a stochastic Fisher estimate $\hat{\mathcal{F}}$ that is independent of the gradient estimate (e.g., computed from an independent sample of $\mathbf{x}$), the result is an unbiased estimate of $E[\hat{\mathcal{F}}^{-1}]\,\nabla_\mathbf{w} f(\mathbf{w})$; since in general $E[\hat{\mathcal{F}}^{-1}] \neq \mathcal{F}^{-1}$, this is a valid preconditioned gradient but not, strictly, an unbiased natural gradient.
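The claim that the natural gradient step solves the Fisher-penalized problem can be spot-checked. This sketch assumes NumPy and uses a random positive-definite matrix as a stand-in for the Fisher (helper names are hypothetical); it verifies stationarity of the concave objective at the computed step, and that $\mathcal{F} = I$ recovers the ordinary gradient step.

```python
import numpy as np


def penalized_objective(w, w_t, grad, fisher, alpha):
    """w^T grad - (1 / (2 alpha)) (w - w_t)^T F (w - w_t)."""
    d = w - w_t
    return w @ grad - d @ fisher @ d / (2 * alpha)


def natural_gradient_step(w_t, grad, fisher, alpha):
    """w_t + alpha F^{-1} grad, computed with a linear solve instead of an explicit inverse."""
    return w_t + alpha * np.linalg.solve(fisher, grad)


rng = np.random.default_rng(0)
m = rng.normal(size=(3, 3))
fisher = m @ m.T + 0.5 * np.eye(3)  # arbitrary positive-definite stand-in for the Fisher
w_t, grad, alpha = rng.normal(size=3), rng.normal(size=3), 0.1

w_new = natural_gradient_step(w_t, grad, fisher, alpha)
# Stationarity of the concave objective: grad - F (w - w_t) / alpha = 0 at the maximizer.
print(np.allclose(grad - fisher @ (w_new - w_t) / alpha, 0.0))
# With F = I the update reduces to ordinary gradient ascent.
print(np.allclose(natural_gradient_step(w_t, grad, np.eye(3), alpha), w_t + alpha * grad))
# Moving away from the maximizer lowers the objective.
perturbed = w_new + 1e-3 * rng.normal(size=3)
print(penalized_objective(w_new, w_t, grad, fisher, alpha) > penalized_objective(perturbed, w_t, grad, fisher, alpha))
```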

Fisher matrices

There's some stuff to say about Fisher matrices:

  • Instantaneous KL divergence.
  • Alternative form as negative expected Hessian.
  • Various stochastic approximations.
  • Riemannian geometry and interpretation as a metric tensor. (maybe this is more general machinery, but it should be said somewhere in this document).

Natural gradient in exponential families

Suppose our objective is a KL divergence between two distributions in the same exponential family (same sufficient statistics and base measure), e.g., a conditional distribution in a probability model (the parent-conditional 'prior') and a variational posterior from the conjugate family. The natural (stochastic) gradient

$$\tilde{\nabla}_\eta E_{q_\eta} \left[\log \frac{p(\mathbf{x})}{q_\eta(\mathbf{x})}\right] = \mathcal{F}_\eta^{-1} \hat{\nabla}_\eta \int q_\eta(\mathbf{x}) \log \frac{p(\mathbf{x})}{q_\eta(\mathbf{x})} d\mathbf{x}$$

has a nice form. First let's work out the Fisher matrix:

$$\begin{align*}
\mathcal{F}_\eta &= E_{q_\eta} \left[\left(\nabla_\eta \log q_\eta(\mathbf{x})\right) \left(\nabla_\eta \log q_\eta(\mathbf{x})\right)^T\right]\\
&= E_{q_\eta} \left[\left(t(\mathbf{x}) - \nabla_\eta A(\eta) \right) \left(t(\mathbf{x}) - \nabla_\eta A(\eta) \right)^T\right]\\
&= E_{q_\eta} \left[t(\mathbf{x})t(\mathbf{x})^T - t(\mathbf{x}) \mu(\eta)^T - \mu(\eta) t(\mathbf{x})^T + \mu(\eta) \mu(\eta)^T \right]\\
&= E_{q_\eta} \left[t(\mathbf{x})t(\mathbf{x})^T \right] - \mu(\eta) \mu(\eta)^T \\
&= \mathrm{cov}(t(\mathbf{x})) = \nabla^2 A(\eta)
\end{align*}$$

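Next, we'll work out the gradient, writing $\mu = \nabla_\eta A(\eta)$ for the mean parameters of $q_\eta$ and $\eta_p$ for the natural parameters of $p$. Our general strategy will be to pull out the $t(\mathbf{x})$ terms so we can evaluate the expectation. We'll find that the gradient contains a hidden factor corresponding to the Fisher matrix!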

$$\begin{align*} \nabla_\eta \int q_\eta(\mathbf{x}) \log \frac{p(\mathbf{x})}{q_\eta(\mathbf{x})} d\mathbf{x} &= \int \nabla_\eta \left( q_\eta(\mathbf{x}) \log \frac{p(\mathbf{x})}{q_\eta(\mathbf{x})} \right) d\mathbf{x}\\ &= \int \left(\nabla_\eta q_\eta(\mathbf{x})\right) \log \frac{p(\mathbf{x})}{q_\eta(\mathbf{x})} d\mathbf{x} + \int q_\eta(\mathbf{x}) \nabla_\eta \log \frac{p(\mathbf{x})}{q_\eta(\mathbf{x})} d\mathbf{x}\\ &= E_{q_\eta}\left[ \left(t(\mathbf{x}) - \nabla_\eta A (\eta)\right) \log \frac{p(\mathbf{x})}{q_\eta(\mathbf{x})} - \left(t(\mathbf{x}) - \nabla_\eta A (\eta)\right) \right]\\ &= E_{q_\eta}\left[ \left(t(\mathbf{x}) - \nabla_\eta A (\eta)\right) \left(\log \frac{p(\mathbf{x})}{q_\eta(\mathbf{x})} -1 \right) \right]\\ &= E_{q_\eta}\left[ \left(t(\mathbf{x}) - \nabla_\eta A (\eta)\right) \left(t(\mathbf{x})^T (\eta_p - \eta) - A(\eta_p) + A(\eta) -1 \right) \right]\\ &= E_{q_\eta}\left[ \left(t(\mathbf{x}) - \mu\right) \left(t(\mathbf{x})^T (\eta_p - \eta) + K \right) \right] \text{ for } K = - A(\eta_p) + A(\eta) -1\\ &= E_{q_\eta}\left[ t(\mathbf{x}) t(\mathbf{x})^T (\eta_p - \eta) + t(\mathbf{x}) K - \mu t(\mathbf{x})^T (\eta_p - \eta) - \mu K \right]\\ &= E_{q_\eta}\left[ t(\mathbf{x}) t(\mathbf{x})^T\right] (\eta_p - \eta) + \mu K - \mu \mu ^T (\eta_p - \eta) - \mu K \\ &= \left( E_{q_\eta}\left[ t(\mathbf{x}) t(\mathbf{x})^T\right] - \mu \mu ^T\right) (\eta_p - \eta) \\ &=\mathcal{F}_\eta (\eta_p - \eta) \end{align*}$$

Thus we see that the natural gradient $\mathcal{F}_\eta^{-1} \nabla_\eta E_{q_\eta}\left[\log \frac{p(\mathbf{x})}{q_\eta(\mathbf{x})}\right]$ is simply the difference in natural parameters, $\eta_p - \eta$! This implies that a natural-gradient ascent step with step size $1$ takes $\eta$ directly to $\eta_p$, i.e., $q_\eta$ matches $p$ in a single step.
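Here is a minimal numerical check of this result, assuming NumPy and a hypothetical three-outcome categorical in minimal form (so the Fisher is invertible). The objective $E_{q_\eta}[\log p/q_\eta]$ is computed by exact enumeration and its gradient by central finite differences. The natural gradient is then obtained by solving against $\mathcal{F}_\eta = \operatorname{cov}(t(\mathbf{x}))$. The parameter values are arbitrary.

```python
import numpy as np

# Hypothetical minimal categorical over 3 outcomes (outcome 0 is the reference with t = 0).
T = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])


def log_q(eta):
    """Vector of log q_eta(x) over all outcomes x."""
    logits = T @ eta
    return logits - np.logaddexp.reduce(logits)


def objective(eta, eta_p):
    """E_{q_eta}[log p(x) / q_eta(x)] = -KL(q_eta || p), by exact enumeration."""
    lq = log_q(eta)
    return np.exp(lq) @ (log_q(eta_p) - lq)


def grad_fd(f, eta, h=1e-6):
    """Central finite-difference gradient of a scalar function f at eta."""
    return np.array([(f(eta + e) - f(eta - e)) / (2 * h) for e in np.eye(len(eta)) * h])


def fisher(eta):
    """F_eta = cov(t(x)) under q_eta."""
    q = np.exp(log_q(eta))
    mu = q @ T
    return T.T @ (q[:, None] * T) - np.outer(mu, mu)


eta = np.array([0.5, -0.4])    # current variational parameters
eta_p = np.array([-1.0, 2.0])  # natural parameters of the target p

grad = grad_fd(lambda e: objective(e, eta_p), eta)
nat_grad = np.linalg.solve(fisher(eta), grad)  # F^{-1} grad without forming the inverse

eta_new = eta + 1.0 * nat_grad  # natural-gradient ascent step with step size 1
print(nat_grad, eta_p - eta)
print(objective(eta_new, eta_p))
```

Solving the linear system instead of explicitly inverting $\mathcal{F}_\eta$ is the usual numerically safer choice. Up to finite-difference error, the unit step lands on $\eta_p$, where the objective is zero.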

One caveat: this is the natural gradient wrt $\eta$. Natural gradient is invariant to reparameterization only in the limit of small steps. The Fisher metric is a local (second-order) description of the KL divergence, and finite steps incur discretization error. So for small steps, a natural-gradient step wrt $\eta$ in the $\eta$-parameterization should reach the same variational distribution as a natural-gradient step wrt $\mu$ in the $\mu$-parameterization. For large steps, such as the unit step above, the two will generally differ.
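Here is a minimal sketch of this caveat for a Bernoulli, in plain Python. It compares a natural-gradient step taken in $\eta$ with one taken in $\mu$. The $\mu$ step uses the Bernoulli Fisher in the mean parameterization, $1/(\mu(1-\mu))$, computed here by enumerating the score. For a tiny step the two updates agree to first order. For the unit step they land on different distributions, and only the $\eta$ step reaches $p$. The parameter values and helper names are arbitrary.

```python
import math


def sigmoid(eta):
    return 1.0 / (1.0 + math.exp(-eta))


def objective_mu(mu, mu_p):
    """E_q[log p/q] = -KL(Bern(mu) || Bern(mu_p)) in the mean parameterization."""
    return mu * math.log(mu_p / mu) + (1 - mu) * math.log((1 - mu_p) / (1 - mu))


def fisher_mu(mu):
    """Fisher wrt mu: E[score^2] with score x/mu - (1-x)/(1-mu), by enumerating x in {0, 1}."""
    return mu * (1 / mu) ** 2 + (1 - mu) * (1 / (1 - mu)) ** 2


eta, eta_p = -0.5, 1.5        # current and target natural parameters (arbitrary)
mu, mu_p = sigmoid(eta), sigmoid(eta_p)
h = 1e-6

# Natural gradient in the eta-parameterization: eta_p - eta.
nat_grad_eta = eta_p - eta
# Natural gradient in the mu-parameterization: finite-difference gradient scaled by 1 / F_mu.
grad_mu = (objective_mu(mu + h, mu_p) - objective_mu(mu - h, mu_p)) / (2 * h)
nat_grad_mu = grad_mu / fisher_mu(mu)

gaps = {}
for step in (1e-3, 1.0):
    mu_from_eta = sigmoid(eta + step * nat_grad_eta)  # step in eta, then map to the mean
    mu_from_mu = mu + step * nat_grad_mu              # step taken directly in mu
    gaps[step] = abs(mu_from_eta - mu_from_mu)
    print(step, mu_from_eta, mu_from_mu)
```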

Thoughts from NameRedacted:

  1. Natural gradient wrt the variational parameterization is not necessarily the best idea. Who says we want to take small steps in the Fisher metric? When we optimize a MAP estimate, our variational distribution is a delta and every step is infinitely large in Fisher space, but sometimes this is actually a good start to variational optimization.
Natural gradient in the mean parameterization

Gradients wrt the mean parameters of an exponential family density "look like" natural gradients in some sense? Maybe?

$$\begin{align*} \nabla_\mu E_q\left[\log \frac{p(x)}{q_\mu(x)}\right] &= E_q\left[ \left(\nabla_{\mu} \log q_\mu(x)\right) \log \frac{p(x)}{q_\mu(x)} + \nabla_{\mu}\log \frac{p(x)}{q_\mu(x)} \right]\\ &= E_q\left[ \left(\nabla_{\mu} \log q_\mu(x)\right) \left(\log \frac{p(x)}{q_\mu(x)} - 1\right) \right]\\ &= E_q\left[ \nabla_\mu \left( t(x)^T \eta(\mu) - A(\eta(\mu)) \right) \left(t(x)^T \eta(\mu_p) - A(\eta(\mu_p)) - t(x)^T \eta(\mu) + A(\eta(\mu)) - 1\right)\right]\\ &= E_q\left[ \nabla_\mu \left( t(x)^T \eta(\mu) - A(\eta(\mu)) \right) \left(t(x)^T \left( \eta(\mu_p) - \eta(\mu) \right) + K\right)\right] \text{ for } K = A(\eta(\mu)) - A(\eta(\mu_p)) - 1 \end{align*}$$

I think I need a math interlude to work out some useful pieces.

$$\nabla_\mu A(\eta(\mu)) = \frac{\partial A}{\partial \eta} \frac{\partial \eta}{\partial \mu} \quad \text{in the univariate case.}$$

Here we know $\frac{\partial A}{\partial \eta} = \mu$. What about $\frac{\partial \eta}{\partial \mu}$? By the inverse function theorem, this is the reciprocal of $\frac{\partial \mu}{\partial \eta}$. Since $\mu$ as a function of $\eta$ is the derivative of the log-normalizer, $\frac{\partial \mu}{\partial \eta} = \frac{\partial^2 A(\eta)}{\partial \eta^2}$. Plugging this back into the original expression,

$$\begin{align*} \nabla_\mu A(\eta(\mu)) &= \frac{\partial A}{\partial \eta} \frac{\partial \eta}{\partial \mu} \quad \text{in the univariate case}\\ &= \mu \left(\frac{\partial^2 A(\eta)}{\partial \eta^2}\right)^{-1} \end{align*}$$

In general, the multivariate chain rule gives the following, using the fact that the Jacobian $\frac{\partial \eta}{\partial \mu} = \left(\nabla^2_\eta A(\eta)\right)^{-1}$ is symmetric:

$$\nabla_\mu A(\eta(\mu)) = \left(\nabla^2_\eta A(\eta(\mu))\right)^{-1} \mu = \mathcal{F}_\eta^{-1} \mu,$$

i.e., the gradient wrt the mean parameters is the mean, preconditioned by the inverse Fisher information wrt the natural parameters (the same preconditioner the natural gradient uses). What does this mean?
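Here is a quick numerical check of this chain-rule calculation for the Bernoulli, in plain Python. For the Bernoulli, $A(\eta) = \log(1 + e^\eta)$, $\mu = \sigma(\eta)$, and $\partial^2 A/\partial \eta^2 = \mu(1 - \mu)$. The sketch first compares a finite-difference derivative of $A(\eta(\mu))$ against both $\mu$ divided by $\partial^2 A/\partial \eta^2$ and $\mu$ multiplied by it. It then differentiates the full objective $E_q[\log p/q_\mu]$ wrt $\mu$ and compares the result with $\eta_p - \eta$. The specific values of $\mu$ and $\mu_p$ are arbitrary.

```python
import math


def A_of_mu(mu):
    """Bernoulli log-normalizer evaluated at eta(mu) = logit(mu): A(eta(mu)) = log(1 + e^eta)."""
    eta = math.log(mu / (1 - mu))
    return math.log1p(math.exp(eta))


def objective_mu(mu, mu_p):
    """E_q[log p/q] = -KL(Bern(mu) || Bern(mu_p)) in the mean parameterization."""
    return mu * math.log(mu_p / mu) + (1 - mu) * math.log((1 - mu_p) / (1 - mu))


mu, mu_p, h = 0.3, 0.8, 1e-6  # mean parameters and finite-difference step (arbitrary)
fisher_eta = mu * (1 - mu)    # d^2 A / d eta^2 = mu (1 - mu) for the Bernoulli

# d/dmu A(eta(mu)) by central differences, vs. mu / F (inverse Fisher) and mu * F.
grad_A = (A_of_mu(mu + h) - A_of_mu(mu - h)) / (2 * h)
print(grad_A, mu / fisher_eta, mu * fisher_eta)

# Gradient of the full objective wrt mu, vs. the natural gradient wrt eta.
grad_L = (objective_mu(mu + h, mu_p) - objective_mu(mu - h, mu_p)) / (2 * h)
eta, eta_p = math.log(mu / (1 - mu)), math.log(mu_p / (1 - mu_p))
print(grad_L, eta_p - eta)
```

The agreement of `grad_L` with $\eta_p - \eta$ is not specific to the Bernoulli. By the chain rule, $\nabla_\mu L = \left(\frac{\partial \eta}{\partial \mu}\right)^T \nabla_\eta L = \mathcal{F}_\eta^{-1} \mathcal{F}_\eta (\eta_p - \eta) = \eta_p - \eta$, so the ordinary gradient wrt the mean parameters equals the natural gradient wrt the natural parameters.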