Modified: May 21, 2022
exponential family notes
This page is from my personal notes, and has not been specifically reviewed for public consumption. It might be incomplete, wrong, outdated, or stupid. Caveat lector.

Exponential Families, Conjugacy, Convexity, and Variational Inference
Any parameterized family of probability densities that can be written in the form
$$p(x \mid \eta) = h(x)\,\exp\!\left(\eta^\top T(x) - A(\eta)\right)$$
is an exponential family. The important restriction is that the data and the parameter interact only via the log-linear term $\eta^\top T(x)$. Surprisingly many common distribution families can be written in this form, including (multivariate) Gaussians, gamma, exponential, Bernoulli/categorical, Poisson, beta, Dirichlet, Wishart, chi-squared, and geometric (see examples below). In this form, we label the terms as
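As a concrete instance, the Bernoulli distribution fits this form with $T(x) = x$, natural parameter equal to the log-odds, and $A(\eta) = \log(1 + e^\eta)$. A minimal pure-Python sketch (function names are my own):

```python
import math

def bernoulli_expfam_pmf(x, eta):
    # p(x | eta) = h(x) exp(eta * T(x) - A(eta)) with T(x) = x, h(x) = 1,
    # natural parameter eta = log-odds, and A(eta) = log(1 + e^eta).
    A = math.log(1.0 + math.exp(eta))
    return math.exp(eta * x - A)

# Recover the usual parameterization p(x) = p^x (1 - p)^(1 - x).
p = 0.3
eta = math.log(p / (1.0 - p))  # natural parameter: log-odds of p
assert abs(bernoulli_expfam_pmf(1, eta) - p) < 1e-12
assert abs(bernoulli_expfam_pmf(0, eta) - (1.0 - p)) < 1e-12
```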
- Sufficient statistics $T(x)$: some function, or vector-valued collection of functions, of $x$ by which the variable interacts with the parameters $\eta$. They are "sufficient" in that they are sufficient for parameter estimation; if we know the sufficient statistic we may throw away the original observation and still compute (something proportional to) the likelihood. This becomes more interesting when we consider multiple draws from the same exponential family, which form a new exponential family in which the sufficient statistic may be a very condensed representation of the original observations; e.g., the sufficient statistic for a sequence of Bernoulli draws (coin flips) is simply the total number of heads.
- Natural parameters $\eta$: these index a specific distribution within the family under consideration. They are called "natural" because they directly multiply the sufficient statistic, and they have other nice properties we may explore later. They do not in general correspond to the parameters typically used for a given distribution; e.g., for a multivariate Gaussian the natural parameters are (up to scaling) the precision matrix $\Sigma^{-1}$ and the precision-scaled mean $\Sigma^{-1} m$, as opposed to the more common parameterization in terms of the mean $m$ and covariance $\Sigma$ (although these correspond closely to the mean parameters, discussed below).
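To make the Gaussian case concrete, here is a sketch of the conversion between standard and natural parameters for a one-dimensional Gaussian, the scalar analogue of the precision / precision-scaled-mean statement above (helper names are hypothetical; with $T(x) = (x, x^2)$ the natural parameters are $\eta_1 = m/\sigma^2$, $\eta_2 = -1/(2\sigma^2)$):

```python
def gaussian_natural_params(mu, sigma2):
    # eta1 = mu / sigma^2 (precision-scaled mean), eta2 = -1 / (2 sigma^2)
    return mu / sigma2, -1.0 / (2.0 * sigma2)

def gaussian_standard_params(eta1, eta2):
    # Invert: sigma^2 = -1 / (2 eta2), mu = eta1 * sigma^2
    sigma2 = -1.0 / (2.0 * eta2)
    return eta1 * sigma2, sigma2

# Round-trip check.
eta1, eta2 = gaussian_natural_params(2.0, 4.0)
mu, sigma2 = gaussian_standard_params(eta1, eta2)
assert abs(mu - 2.0) < 1e-12 and abs(sigma2 - 4.0) < 1e-12
```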
- Base measure: the term $h(x)$ represents a base measure on $x$. The important condition is that this does not vary with the parameters $\eta$; otherwise we could write any distribution family as an exponential family by absorbing its density into $h$. In particular, any single distribution is an exponential family in the trivial sense that we can smuggle it in as a base measure; this reemphasizes that exponential-family machinery is about parameterized families of distributions, e.g., the family of all Normal distributions. The base measure is not necessarily a probability measure, i.e., it need not normalize. Sometimes it is factored into a literal "base measure" representing something like Lebesgue or counting measure, plus a term $h(x)$ representing a density atop that measure, but these can always be combined into a single measure.
- Log normalizer $A(\eta)$: also called the cumulant function. This is defined to normalize the distribution, i.e., we always have
$$A(\eta) = \log \int h(x)\, \exp\!\left(\eta^\top T(x)\right) dx.$$
A "tractable" exponential family is, informally, one for which this integral has a closed-form solution. Note that, for a given exponential family defined by sufficient statistics $T$ and base measure $h$, this integral may be finite only for some values of $\eta$; these values define the natural parameter space.
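To illustrate the natural parameter space, consider the family with $T(x) = x$ and Lebesgue base measure on $[0, \infty)$ (the exponential distributions): the normalizing integral is finite only for $\eta < 0$, where $A(\eta) = -\log(-\eta)$. A quick numerical check (helper name is mine; the midpoint rule stands in for the integral):

```python
import math

def exponential_log_normalizer(eta, upper=50.0, n=100000):
    # Midpoint-rule approximation to A(eta) = log ∫_0^∞ exp(eta * x) dx,
    # the log-normalizer of the family with T(x) = x and Lebesgue base
    # measure on [0, ∞). Closed form: A(eta) = -log(-eta), finite only
    # for eta < 0 -- those eta form the natural parameter space.
    dx = upper / n
    total = sum(math.exp(eta * (i + 0.5) * dx) for i in range(n)) * dx
    return math.log(total)

eta = -2.0
assert abs(exponential_log_normalizer(eta) - (-math.log(-eta))) < 1e-3
```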
An exponential family is regular if the natural parameter space is a nonempty open set. (TODO: when does this not happen?) It is minimal if there are no linear dependencies among the natural parameters or among the sufficient statistics (with respect to the base measure). For example, $T(x) = (x, 2x)$ would not be the sufficient statistic of a minimal family. Non-minimal families can always be reduced to minimal families by suitable reparameterization.
Exponential Family representations of common distributions
Properties of Exponential Families
The Mean Parameterization
We've introduced exponential families in their natural parameterization, i.e., written as a function of the natural parameter $\eta$. The natural parameter is "the coefficient on the sufficient statistics in the log-density", i.e., it's the thing you'd get by taking gradients of the log-density wrt the sufficient statistics. This fact will become useful in automatic conjugacy detection. For many distributions, it's more common to use the mean parameterization, given by the expected sufficient statistics:
$$\mu(\eta) = \mathbb{E}_{p(x;\, \eta)}\!\left[T(x)\right].$$
For example, the family of Gaussian distributions of varying mean and variance has sufficient statistics $T(x) = (x, x^2)$, so its expected sufficient statistics are
$$\mu = \left(\mathbb{E}[x],\ \mathbb{E}[x^2]\right) = \left(m,\ m^2 + \sigma^2\right),$$
i.e., the standard mean-and-variance parameterization, with the minor difference that the second moment is uncentered. We write an exponential family in its mean parameterization as
$$p(x; \mu) = h(x)\,\exp\!\left(\eta(\mu)^\top T(x) - A(\eta(\mu))\right),$$
where $\eta(\mu)$ denotes the mapping from mean parameters to natural parameters (it's not yet clear what this mapping is; we will explore further below) and $A(\eta(\mu))$
is the log normalizer expressed in terms of mean parameters.
The mapping from natural to mean parameters is 1-1 for any minimal exponential family. It's easy to see that for any natural parameters $\eta$, the corresponding mean parameters are well-defined by the expectation $\mu = \mathbb{E}_\eta[T(x)]$. How would we go the other way? Note that the mapping from mean to natural parameters cannot be a well-defined function if we're in a non-minimal family, i.e., if there are multiple natural-parameter settings that define the same distribution (and hence the same expectations, and in particular, the same expected sufficient statistics, i.e., mean parameters). We'll need some more machinery to understand this mapping; luckily the machinery is also generally useful.
First we present a general approach for computing the expected sufficient statistics $\mathbb{E}_\eta[T(x)]$, i.e., the mean parameters $\mu$. Computing expectations is hard in general, because integrals are hard, but for exponential families it turns out this expectation can be computed as a derivative, specifically, the gradient of the log-normalizer:
$$\mu = \mathbb{E}_\eta[T(x)] = \nabla_\eta A(\eta).$$
This is nice because derivatives are easy to compute.
For any tractable exponential family (one for which we have $A$ in closed form), this gives us an easy recipe for converting natural parameters to mean parameters. This relationship goes surprisingly deep, e.g., into connections between exponential families and graphical models. The log-normalizer of a tree-structured graphical model can be computed by one-way message passing, e.g., the forward algorithm in Kalman filters or HMMs. Taking gradients of this log-normalizer with respect to the natural parameters gives us the mean parameters, i.e., the node and edge marginals. This implies that autodiff implements backwards message passing!
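A quick numerical sanity check of this recipe for the Bernoulli family, where $A(\eta) = \log(1 + e^\eta)$ and the expected sufficient statistic is the sigmoid of the log-odds (finite differences stand in for autodiff here; names are mine):

```python
import math

def A(eta):
    # Bernoulli log-normalizer
    return math.log(1.0 + math.exp(eta))

def grad_A(eta, h=1e-6):
    # Central finite difference: stands in for autodiff of A.
    return (A(eta + h) - A(eta - h)) / (2.0 * h)

eta = 0.7
mean_param = 1.0 / (1.0 + math.exp(-eta))  # E[T(x)] = sigmoid(eta)
assert abs(grad_A(eta) - mean_param) < 1e-6
```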
How can we go the other way? We'll need a bit more machinery.
- Convexity of $A$. This guarantees the gradient is monotonic, thus that we have a 1-1 map.
- Conjugate duality. What is the dual? What does it mean?
- The dual of $A$ is $A^*(\mu) = \sup_\eta\, \langle \eta, \mu \rangle - A(\eta)$.
- Gradients of the dual map back to natural parameters: $\nabla A^*(\mu) = \eta(\mu)$. What does this mean? $A^*$ already includes the map $\eta(\mu)$, so how can we compute it or its gradients if we don't already know the map?
Convexity
First we see that $A$ is convex. This follows immediately from its cumulant-generating properties. We saw above that the gradient of $A$ gives the expected sufficient statistics. It is straightforward to verify that the second derivative (in general, the Hessian) of $A$ is the variance (in general, covariance) of the sufficient statistics, which is nonnegative (positive semidefinite) by definition. This establishes that $A$ is convex.
We additionally see that $A$ is strictly convex iff the family is minimal. Suppose that $A$ is not strictly convex, i.e., the covariance admits some nonzero vector $v$ such that $v^\top \nabla^2 A(\eta)\, v = 0$. Note that $v^\top \operatorname{Cov}_\eta(T(x))\, v = \operatorname{Var}_\eta\!\left(v^\top T(x)\right)$; for this variance to be zero implies that $v^\top T(x)$ is deterministic (with respect to the base measure), but this contradicts the definition of minimality! Meanwhile, if $A$ is strictly convex, then no such $v$ can exist, which implies the family is minimal.
We immediately see that the mapping from natural to mean parameters is 1-1 if and only if we are in a minimal family: if the Hessian is always positive definite, the gradient is always increasing. (This can be made more precise in a multidimensional sense; see Wainwright and Jordan.) How do we get this mapping? The answer runs through conjugate duality.
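As a concrete one-dimensional illustration, we can invert the mean map numerically by exploiting exactly this monotonicity: bisection on the Bernoulli family, where $\nabla A$ is the sigmoid (helper names are mine):

```python
import math

def grad_A(eta):
    # Bernoulli: grad A(eta) = sigmoid(eta), the mean parameter.
    return 1.0 / (1.0 + math.exp(-eta))

def natural_from_mean(mu, lo=-50.0, hi=50.0, iters=200):
    # Invert the mean map by bisection: find eta with grad_A(eta) = mu.
    # This works because grad_A is strictly increasing -- which is exactly
    # the strict convexity of A for a minimal family.
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if grad_A(mid) < mu:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)

mu = 0.3
eta = natural_from_mean(mu)
assert abs(eta - math.log(mu / (1.0 - mu))) < 1e-8  # logit(mu)
```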
Conjugate duality
The conjugate dual of a function $f$ is defined as
$$f^*(\mu) = \sup_\eta\, \langle \eta, \mu \rangle - f(\eta).$$
Conjugate duals are defined even on non-convex functions, but the conjugate dual is always convex (see convex duality).
In the context of exponential families, the right side has probabilistic interpretations. It can be interpreted as computing a maximum-likelihood estimate of $\eta$ given observed sufficient statistics $\mu$. The optimal value $A^*(\mu)$ is the (negative) entropy of the distribution with mean parameter $\mu$.
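A small numerical check of this claim for the Bernoulli family: maximizing $\eta \mu - A(\eta)$ recovers the negative entropy $\mu \log \mu + (1 - \mu)\log(1 - \mu)$. (Grid search is a crude stand-in for the sup; names are mine.)

```python
import math

def A(eta):
    # Bernoulli log-normalizer
    return math.log(1.0 + math.exp(eta))

def A_star(mu, lo=-20.0, hi=20.0, n=100000):
    # Conjugate dual A*(mu) = sup_eta (eta * mu - A(eta)), via grid search.
    best = -float("inf")
    for i in range(n + 1):
        eta = lo + i * (hi - lo) / n
        best = max(best, eta * mu - A(eta))
    return best

mu = 0.25
neg_entropy = mu * math.log(mu) + (1.0 - mu) * math.log(1.0 - mu)
assert abs(A_star(mu) - neg_entropy) < 1e-6
```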
Conjugacy
Bayesian inference is the generally correct means of updating beliefs represented by probability distributions. We begin with a prior belief about some aspect of the world's state encoded in the variable $z$, and want to update this belief in light of observed evidence $x$. To do this we specify a likelihood $p(x \mid z)$ giving the probability of the observed $x$, given explanation $z$, and combine this with the prior $p(z)$ via Bayes' rule
$$p(z \mid x) = \frac{p(x \mid z)\, p(z)}{p(x)}$$
to produce a posterior distribution representing our new belief given the observed evidence. In general the posterior density will not have any 'nice' closed form, even if the prior and likelihood do, because the implicit normalizing constant given by the integral
$$p(x) = \int p(x \mid z)\, p(z)\, dz$$
may not be solvable analytically. This is a problem when implementing Bayesian inference on computers, where a system must commit to a particular representation of beliefs---some concrete data structure---and incorporate evidence into that fixed belief representation. Typical developments, including this one, focus on the case where our data structure is a closed-form probability density function. However, this is quite artificial; more generally we could consider belief distributions represented by implicit models defined by sampling processes, families of algorithms for computing densities (closed-form densities are a special case of such algorithms, but we could consider, e.g., all polynomial-time algorithms) or for directly computing measures (not necessarily by integrating over a density), or ??? (interesting research question? the brain does not pass around closed-form densities). Bayesian updating only works on real computers when the posterior admits the same representation as the prior. That is, exact Bayesian updating only works in this circumstance; human brains almost certainly do not perform exact Bayesian updating or maintain correct Bayesian posterior beliefs. In general, there is no reason to expect that exact Bayesian updating in a conjugate model will perform better on a task of interest than approximate inference in a more appropriate model.
For a given likelihood $p(x \mid z)$, a conjugate prior is a family of distributions $p(z; \lambda)$, which we will take to be parameterized by $\lambda$, such that posterior beliefs under that likelihood remain in the same family as the prior; that is, we have $p(z \mid x) = p(z; \lambda')$ for some parameter $\lambda'$. There may be many conjugate priors for a given likelihood. For example, the family of all probability distributions is conjugate to every likelihood, but this is not particularly useful, since we cannot represent arbitrary probability distributions on computers.
In the case where the likelihood is in the exponential family, with natural parameters determined by some link function $\eta(z)$ of the conditioning information,
$$p(x \mid z) = h(x)\,\exp\!\left(\eta(z)^\top T(x) - A(\eta(z))\right),$$
we can define an exponential-family prior whose sufficient statistics are, by construction, exactly the information about $z$ used to parameterize the likelihood:
$$p(z; \lambda, \nu) \propto \exp\!\left(\lambda^\top \eta(z) - \nu\, A(\eta(z))\right).$$
This turns out to be a conjugate prior, as we can check:
$$p(z \mid x) \propto p(x \mid z)\, p(z; \lambda, \nu) \propto \exp\!\left((\lambda + T(x))^\top \eta(z) - (\nu + 1)\, A(\eta(z))\right),$$
with new parameter $(\lambda + T(x),\ \nu + 1)$; the likelihood's base measure $h(x)$ is constant in $z$ and is absorbed into the normalizer. As we incorporate additional points $x_1, \ldots, x_N$, we get the posterior
$$p(z \mid x_1, \ldots, x_N) \propto \exp\!\Big(\big(\lambda + \textstyle\sum_i T(x_i)\big)^\top \eta(z) - (\nu + N)\, A(\eta(z))\Big),$$
where $\lambda + \sum_i T(x_i)$ aggregates the sufficient statistics of all data points with the prior parameter $\lambda$, which we can view as effectively representing the sufficient statistics of "hallucinated" prior data, and $\nu + N$ enforces concentration as we observe additional data.
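The Beta-Bernoulli model is the classic instance of this update: the prior parameters act as hallucinated prior heads and tails, and the posterior just adds the observed sufficient statistics. A minimal sketch (function name is mine):

```python
def beta_bernoulli_update(alpha, beta, flips):
    # Conjugate update: the Beta(alpha, beta) prior's parameters act as
    # "hallucinated" prior heads/tails; the posterior simply adds the
    # observed sufficient statistics (number of heads, number of tails).
    heads = sum(flips)
    tails = len(flips) - heads
    return alpha + heads, beta + tails

post = beta_bernoulli_update(2.0, 2.0, [1, 1, 0, 1])
assert post == (5.0, 3.0)
```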
Note that the conjugate families produced by this generic procedure may not be tractable. For example, let $x \sim \mathcal{N}(f_\theta(z), I)$, where the link function $f_\theta$ is a neural decoder. Then we have something like (TODO: be more careful here)
$$p(x \mid z) \propto \exp\!\left(f_\theta(z)^\top x - \tfrac{1}{2}\lVert f_\theta(z) \rVert^2\right),$$
and the conjugate prior on $z$ has sufficient statistics given by the decoder output $f_\theta(z)$ (and its squared norm),
$$p(z; \lambda, \nu) \propto \exp\!\left(\lambda^\top f_\theta(z) - \tfrac{\nu}{2}\lVert f_\theta(z) \rVert^2\right).$$
Even if this normalizes (TODO: does this always normalize?), the normalization constant involves integrating through a neural net, and this will not in general have any nice closed form.
In general, a tractable conjugate prior requires the normalization constant to be easily evaluated, which occurs only when the link function takes certain special forms. *TODO: what are the special cases where conjugate priors do have nice forms?*
Characterization of conjugacy
I believe there is a true result that any likelihood function that admits a finite-dimensional conjugate prior, and whose support does not depend on the parameter, must be an exponential family. Possible citation: Diaconis and Ylvisaker, 1979.
Conditionally conjugate models
A directed model of many variables is considered {\em conditionally conjugate} if all of the complete conditionals are in the exponential family. This will occur if, for each variable $x_i$: (these conditions are sufficient but maybe not necessary; I worked them out myself, but someone must have written about this. Is there a broader characterization?)
- its parent-conditional distribution $p(x_i \mid \text{pa}(x_i))$ is in the exponential family, i.e., $p(x_i \mid \text{pa}(x_i)) = h(x_i)\,\exp\!\left(\eta_i^\top T_i(x_i) - A(\eta_i)\right)$ for some sufficient statistic $T_i$ and parameter $\eta_i$, and
- all edges to children are exponential-family likelihoods to which the parent-conditional distribution is conjugate, i.e., are of the form $p(x_j \mid x_i, \ldots) \propto \exp\!\left((\eta(x_i) + \eta_{-i})^\top T_j(x_j)\right)$, where $\eta_{-i}$ represents contributions to the natural parameter from other parents of $x_j$; treating these as fixed, we are left with an exponential-family likelihood for $x_j \mid x_i$. (In general we just need that the likelihood is conditionally expfam given the other parents. It's sufficient to have this form, where the likelihood sums natural parameters across parents, but is it necessary? In general we could have arbitrary likelihoods that become conditionally expfam, e.g., one parent is a discrete variable that switches whether the other conditional is Gaussian/gamma/whatever. How does this relate to condition 1? In this case we'd have parent-conditionals in the exponential family, but the parents could determine not just the parameter but also the sufficient statistic function.)
In this case we get many nice properties:
- Gibbs sampling updates are straightforward.
- We can perform variational message passing, which updates an exponential-family factor at each variable as a function of the exponential-family factors at neighboring variables. (Does this require additional conditions on the link functions to be tractable?)
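As a tiny concrete instance of such a Gibbs update, here is the complete conditional of a Gaussian mean with known noise variance, computed by adding natural parameters (precision and precision-scaled mean) across the prior and likelihood terms. A sketch; function names are mine:

```python
def normal_mean_conditional(mu0, tau0_sq, sigma_sq, xs):
    # Complete conditional of a Gaussian mean z with prior z ~ N(mu0, tau0_sq)
    # and likelihood x_i ~ N(z, sigma_sq). Natural parameters (precision,
    # precision-scaled mean) simply add across prior and likelihood terms --
    # conditional conjugacy at work.
    prec = 1.0 / tau0_sq + len(xs) / sigma_sq
    scaled_mean = mu0 / tau0_sq + sum(xs) / sigma_sq
    return scaled_mean / prec, 1.0 / prec  # posterior mean, variance

mean, var = normal_mean_conditional(0.0, 1.0, 1.0, [1.0, 1.0])
assert abs(mean - 2.0 / 3.0) < 1e-12 and abs(var - 1.0 / 3.0) < 1e-12
```

In a Gibbs sweep one would sample from this conditional (e.g., `random.gauss(mean, var ** 0.5)`) and move on to the next variable.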
Multi-affine representation
For simplicity consider a chain-structured model
$$x_1 \to x_2 \to \cdots \to x_n$$
that is conditionally conjugate, i.e., the parent-conditional of each node is a conjugate prior for its outgoing edge. Then, suppressing the base measure, we can write the model in the form
$$p(x_1, \ldots, x_n) = \prod_i \exp\!\left(\eta_i(x_{i-1})^\top T_i(x_i) - A_i(\eta_i(x_{i-1}))\right), \qquad \eta_i(x_{i-1}) = T_{i-1}(x_{i-1}),$$
where the identity $\eta_i(x_{i-1}) = T_{i-1}(x_{i-1})$ encodes our assumption of conditional conjugacy, i.e., the natural parameters of a child are exactly the sufficient statistics of its parent (with the $-A_i(\eta_i(x_{i-1}))$ term itself appearing among the sufficient statistics of $x_{i-1}$, as in the conjugate construction above). Pushing the product inside the exponential, we find a {\em multi-affine} function
$$\log p(x_1, \ldots, x_n) = \sum_i T_{i-1}(x_{i-1})^\top T_i(x_i) + \text{const},$$
i.e., although the quantity inside the exponential is not jointly affine in all the sufficient statistics (because there are second-order interaction terms), it is individually affine in any one of them, in the same sense that $f(x, y, z) = xyz + x + y$ is an affine function of any one of its arguments when the others are held fixed. Any multi-affine function whose arguments live in vector spaces can be written as an affine function of the tensor product of those vector spaces, so we have
$$\log p(x_1, \ldots, x_n) = \left\langle W,\ \textstyle\bigotimes_i T_i(x_i) \right\rangle + \text{const},$$
where $W$ is the coefficient tensor.
Detecting conjugacy
Suppose we have a model in which the complete conditional for a particular variable $x_i$ is in the exponential family. We might want to detect this, and construct the complete conditional, by examining the computation graph for the model's log joint density. Recall that conditionals (and in particular, the complete conditional for $x_i$) are proportional to the joint density. So, examining the log joint density, we want to identify a set of sufficient statistics---functions of $x_i$---such that the only interactions between $x_i$ and other model variables are linear in the sufficient statistics. That is, the natural parameters of the complete conditional will be given by the coefficients, in the log-joint, of the sufficient statistics we identify.
For example, if the log-joint density contains a term $z x^2$, we can write this as the sufficient statistic $T(x) = x^2$ with natural parameter $z$. On the other hand, if the log-joint contains a term like $\sin(z x)$, then we have a nonlinear interaction and the complete conditional is not in the exponential family. More commonly, we might encounter a log-joint term of the form $-\frac{1}{2}\lVert x - f_\theta(z) \rVert^2$ (arising from a Gaussian likelihood for $x$ with mean defined by a neural-net link function $f_\theta$ of $z$), in which (taking $z$ as the variable of interest) we can identify sufficient statistics involving the decoder output $f_\theta(z)$, but these do not correspond to any tractable exponential family, i.e., an exponential family that we know how to normalize.
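The coefficient-extraction idea can be sketched numerically: write the log-joint as a function of the sufficient statistics themselves, and read off the natural parameters as its gradient, which works exactly because the log-joint is linear in those statistics. (Finite differences below; the example term $a\,t_1 + b\,t_2$ and all names are hypothetical.)

```python
def log_joint_in_stats(t1, t2, a=3.0, b=-0.5):
    # A log-joint written as a function of the sufficient statistics
    # t1 = x and t2 = x^2; linear in both, with coefficients a and b.
    return a * t1 + b * t2

def natural_params(f, t1, t2, h=1e-6):
    # Natural parameters = gradient of the log-joint wrt the sufficient
    # statistics. Linearity makes the finite differences exact (up to
    # floating-point rounding).
    d1 = (f(t1 + h, t2) - f(t1 - h, t2)) / (2.0 * h)
    d2 = (f(t1, t2 + h) - f(t1, t2 - h)) / (2.0 * h)
    return d1, d2

eta1, eta2 = natural_params(log_joint_in_stats, 0.7, 0.49)
assert abs(eta1 - 3.0) < 1e-5 and abs(eta2 + 0.5) < 1e-5
```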
Edward implements a version of this approach for computation graphs defined in TensorFlow. Specifically, it
- extracts the computation graph for the log-joint with respect to the Markov blanket of the variable $x_i$
- inspects the graph and does some symbolic algebra voodoo (TODO: understand the voodoo; can we improve it?) to identify sufficient statistics of $x_i$. The process might fail at this point if the algebra system fails to find the appropriate rewrite that exposes sufficient statistics.
- looks up these sufficient statistics in a dictionary of tractable exponential families, to get the form of the complete conditional (e.g., normal, beta, Dirichlet, etc.). The process might fail at this point if the relevant exponential family is not registered with the system.
- extracts natural parameters by evaluating the gradient of the log-joint with respect to the sufficient statistics. This works because the log-joint is (by assumption) linear in sufficient statistics, so taking the gradient gives the coefficients, i.e., natural parameters. I believe this could in principle have been done by the symbolic algebra system (trivially in some sense, gradients are symbolic algebra) but I guess it's convenient to exploit automatic differentiation when available.
- Constructs the appropriate distribution using its natural parameters. This requires a constructor that accepts natural parameters.
Thoughts on bringing this directly into TF:
- we would need to add to each distribution (or perhaps somewhere else), some exponential family properties:
- sufficient statistics, in a form available to a symbolic algebra system. (also note some distributions might have multiple exponential family representations with different parameterizations, e.g., normal with fixed vs unknown variance)
- constructors that accept natural parameters.
- some sort of symbolic algebra to extract sufficient statistics.
SAVI
If the energy function is multi-affine, we get some nice properties.
- Tractable expectations of the energy under mean-field (factorized) distributions: the expectation of a multi-affine function of independent variables is that function applied to the variables' expectations.
- "Conjugate proximal operators" TODO UNDERSTAND THIS
Natural Gradient VI
Natural Gradient
Gradient ascent doesn't really type-check. The standard update
$$\theta_{t+1} = \theta_t + \varepsilon\, \nabla_\theta f(\theta_t)$$
adds a gradient to a point in parameter space. But gradients are not members of the original vector space; they're members of the tangent space (really its dual, the cotangent space). \todo{better understand what goes wrong here. is it that we're choosing a basis?}
The gradient-ascent update is equivalent to the maximization problem
$$\theta_{t+1} = \arg\max_\theta\ \langle \nabla f(\theta_t), \theta \rangle - \frac{1}{2\varepsilon} \lVert \theta - \theta_t \rVert_2^2;$$
this can be verified by taking the derivative and setting it to zero. Viewing $\frac{1}{2\varepsilon}$ as a Lagrange multiplier, we interpret this update as finding the step that best aligns with the gradient while remaining within a $\delta$-ball of the original point, measured in Euclidean ($\ell_2$) distance.
However, Euclidean distance is not always the most relevant metric. When can be viewed as parameterizing a family of probability measures -- that is, our objective is really defined in terms of the distribution -- it makes sense to think about optimizing directly over the space of distributions, rather than their parameters. Some examples of when this comes up:
- Regression: most regression models define a distribution over predictions, sometimes implicitly (e.g., squared loss is equivalent to a Gaussian log-density, and cross-entropy loss is equivalent to a categorical log-density).
- Variational inference: explicit optimization over a probability distribution that describes a Bayesian posterior.
We can do this by minimizing distance traveled in the Fisher metric,
$$\theta_{t+1} = \arg\max_\theta\ \langle \nabla f(\theta_t), \theta \rangle - \frac{1}{2\varepsilon} (\theta - \theta_t)^\top F(\theta_t)\, (\theta - \theta_t),$$
where
$$F(\theta) = \mathbb{E}_{x \sim p_\theta}\!\left[\nabla_\theta \log p_\theta(x)\, \nabla_\theta \log p_\theta(x)^\top\right]$$
is the Fisher information matrix. The Fisher is the covariance of the score function $\nabla_\theta \log p_\theta(x)$ under the distribution $p_\theta$; intuitively, it measures the sensitivity of the distribution to movement in parameter space. The Fisher metric is a second-order Taylor approximation of the KL divergence $\mathrm{KL}(p_{\theta_t} \,\|\, p_\theta)$ (it actually doesn't matter which way the KL goes here; it's locally symmetric).
Solving the above maximization problem gives the natural gradient ascent step
$$\theta_{t+1} = \theta_t + \varepsilon\, F(\theta_t)^{-1}\, \nabla f(\theta_t).$$
We refer to the scaled gradient $F^{-1} \nabla f$ as the natural gradient. Note that if we have a stochastic gradient, i.e., a random variable whose expectation is the true gradient, then we can still compute a stochastic natural gradient, since scaling by the Fisher falls outside of the expectation. If we also have a stochastic Fisher matrix, our natural gradient will still be unbiased as long as the Fisher estimate is an independent RV from the original gradient itself (e.g., computed using an independent sample of $x$).
Fisher matrices
There's some stuff to say about Fisher matrices:
- Instantaneous KL divergence.
- Alternative form as negative expected Hessian.
- Various stochastic approximations.
- Riemannian geometry and interpretation as a metric tensor. (maybe this is more general machinery, but it should be said somewhere in this document).
Natural gradient in exponential families
Suppose our objective is a KL divergence between two exponential-family terms that are conjugate to each other, e.g., a conditional distribution in a probability model (parent-conditional 'prior') $p$ with natural parameter $\eta_p$ and a variational posterior $q$ with natural parameter $\eta_q$. The natural (stochastic) gradient
$$F(\eta_q)^{-1}\, \nabla_{\eta_q} \mathrm{KL}(q \,\|\, p)$$
has a nice form. First let's work out the Fisher matrix: for an exponential family in its natural parameterization, the score is $\nabla_\eta \log p(x; \eta) = T(x) - \nabla A(\eta)$, so its covariance is
$$F(\eta) = \operatorname{Cov}_\eta\!\left(T(x)\right) = \nabla^2 A(\eta).$$
Next, we'll work out the gradient. Our general strategy will be to pull out the $\eta$ terms so we can evaluate the expectation:
$$\mathrm{KL}(q \,\|\, p) = \mathbb{E}_q\!\left[(\eta_q - \eta_p)^\top T(x)\right] - A(\eta_q) + A(\eta_p) = (\eta_q - \eta_p)^\top \nabla A(\eta_q) - A(\eta_q) + A(\eta_p).$$
We'll find that the gradient contains a hidden factor corresponding to the Fisher matrix:
$$\nabla_{\eta_q} \mathrm{KL}(q \,\|\, p) = \nabla A(\eta_q) + \nabla^2 A(\eta_q)\,(\eta_q - \eta_p) - \nabla A(\eta_q) = F(\eta_q)\,(\eta_q - \eta_p).$$
Thus we see that the natural gradient is simply the difference in natural parameters, $F(\eta_q)^{-1} \nabla_{\eta_q} \mathrm{KL} = \eta_q - \eta_p$! This implies that a natural gradient step with step-size $\varepsilon = 1$ will cause $q$ to match $p$ in a single step.
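A numerical check of the one-step claim for a one-dimensional (Bernoulli-shaped) family, computing the KL gradient by finite differences and preconditioning by the Fisher $A''(\eta_q) = \sigma(\eta_q)(1 - \sigma(\eta_q))$ (names are mine):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def kl_bernoulli(eta_q, eta_p):
    # KL(q || p) between Bernoullis given in natural (log-odds) parameters.
    mq, mp = sigmoid(eta_q), sigmoid(eta_p)
    return mq * math.log(mq / mp) + (1.0 - mq) * math.log((1.0 - mq) / (1.0 - mp))

def natural_gradient_step(eta_q, eta_p, h=1e-5):
    # One natural-gradient step of size 1: precondition the (finite-difference)
    # KL gradient by the Fisher, which for this family is Var[T(x)] = s(1 - s).
    grad = (kl_bernoulli(eta_q + h, eta_p) - kl_bernoulli(eta_q - h, eta_p)) / (2.0 * h)
    fisher = sigmoid(eta_q) * (1.0 - sigmoid(eta_q))
    return eta_q - grad / fisher

# A single step with step-size 1 jumps q directly onto p.
eta_q, eta_p = -1.0, 2.0
assert abs(natural_gradient_step(eta_q, eta_p) - eta_p) < 1e-6
```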
A few caveats: this is the natural gradient wrt $\eta_q$, the natural parameters of the variational distribution. If we take small steps (to account for discretization error in the Fisher metric), natural gradient is invariant to parameterization, so we should have that the variational distribution described by the step wrt $\eta$ in the $\eta$-parameterization is the same as that described by the step wrt $\mu$ in the $\mu$-parameterization.
Thoughts from NameRedacted:
- Natural gradient wrt the variational parameterization is not necessarily the best idea. Who says we want to take small steps in the Fisher metric? When we optimize a MAP estimate, our variational distribution is a delta and every step is infinitely large in Fisher space, but sometimes this is actually a good start to variational optimization.
Natural gradient in the mean parameterization
Gradients wrt the mean parameters of an exponential family density "look like" natural gradients in some sense? Maybe?
I think I need a math interlude to work out some useful pieces.
Here we know $\frac{d\mu}{d\eta} = A''(\eta) = F(\eta)$ (working in one dimension for simplicity). What about $\frac{d\eta}{d\mu}$? By the inverse function theorem, this is the reciprocal of $\frac{d\mu}{d\eta}$, where $\mu$ as a function of $\eta$ is the gradient of the log-normalizer, so $\frac{d\eta}{d\mu} = 1 / A''(\eta) = F(\eta)^{-1}$. Plugging this back into the original expression,
$$\frac{df}{d\mu} = \frac{df}{d\eta} \cdot \frac{d\eta}{d\mu} = F(\eta)^{-1}\, \frac{df}{d\eta}.$$
In general (waves hands) I think we expect this to be
$$\nabla_\mu f = F(\eta)^{-1}\, \nabla_\eta f,$$
i.e., the gradient wrt the mean parameters is the gradient wrt the natural parameters, preconditioned by the inverse Fisher: a plain gradient step in the mean parameterization is a natural gradient step in the natural parameterization. What does this mean?