Automatic Structured Variational Inference: Nonlinear Function
Created: May 29, 2020
Modified: June 03, 2020

Automatic Structured Variational Inference

This page is from my personal notes, and has not been specifically reviewed for public consumption. It might be incomplete, wrong, outdated, or stupid. Caveat lector.
  • research ideas
  • The ASVI paper proposes adding a learnable constant to the mean (a canonical parameter) of each Gaussian distribution. They say the aim is to exploit conjugacy, which operates on natural parameters and works more generally for exponential families, but IIRC the only equation they write uses the mean of a Gaussian.
  • I guess their general claim is that they don't assume the prior or posterior is an exponential family. Instead, we just add a learnable value to whatever the prior parameters are. If the prior is conjugate, then there is some setting of its parameters that gives the correct posterior, so of course there exists some offset to the prior parameters that we can learn. Fine. But I think we need proper exponential-family support. Why?
    • Think of Gaussians. Marginalization is naturally represented in mean parameters (covariances), but conditioning is naturally represented in natural parameters (precision).
    • Consider a Gaussian chain w ~ N(0, 1); x ~ N(w, 1); y ~ N(x, 1). Suppose we observe y. The posterior on (w, x) is jointly Gaussian, and conditioning on y is an additive update in natural (precision) parameters: y's evidence adds to x's precision and precision-adjusted mean while the prior's coupling between w and x stays untouched (worked example below).
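    • A small worked example of this (mine, not from the paper), in the Gaussian information form:

import numpy as np

y_obs = 2.0

# Joint precision (information) matrix of (w, x, y) under
# w ~ N(0, 1); x ~ N(w, 1); y ~ N(x, 1). All prior means are zero, so the
# precision-adjusted mean h is zero too.
J = np.array([[2., -1., 0.],
              [-1., 2., -1.],
              [0., -1., 1.]])
h = np.zeros(3)

# Conditioning on y in information form: the posterior precision over (w, x)
# is the corresponding block of J, and y's evidence enters additively through
# the precision-adjusted mean of x.
J_post = J[:2, :2]
h_post = h[:2] - J[:2, 2] * y_obs  # adds y_obs to the x component

post_cov = np.linalg.inv(J_post)   # [[2/3, 1/3], [1/3, 2/3]]
post_mean = post_cov @ h_post      # [y/3, 2y/3] = [0.667, 1.333]
print(post_mean, post_cov)
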
  • It seems like actually working in natural param space should be better, because it helps exploit conjugacy. See this email (a small sketch of the construction it describes follows the quoted thoughts):
    • I finally had a chance to read this paper over the weekend and wrote up some thoughts below. tl;dr I think there's valuable stuff in it, but I was also a bit confused and it might(?) be somewhat broken as written (I'm not an exponential family expert though). Thoughts welcome.
    • The core idea is that, given a model
    • p(z, x) = p(z[1]; θ[1]) … p(z[n] | z[1], …, z[n-1]; θ[n]) p(x | z)
    • in which the latents z are factored into conditional distributions (i.e., a directed graphical model, or 'probabilistic program' as they put it), we can build a variational family by just taking the same conditionals as in the model
    • q(z) = p(z[1]; θ[1]) … p(z[n] | z[1], …, z[n-1]; θ[n])
    • but for each one we mix the prior parameters with a variational perturbation:
    •   θ*[i]  = λ[i] θ[i] + (1 - λ[i]) α[i]
    • where (λ,  α) are variational parameters. Each λ[i] is constrained to the unit interval, so the posterior interpolates between the prior (when λ=1) and a mean-field distribution defined by the α's (when λ=0). Note that the prior has structure, while a fully mean-field posterior does not, so one way to think of λ is as representing the amount of structure in the posterior (or equivalently, 1 -  λ represents the amount of independent evidence that informs each latent in the posterior).
    • The motivation as I understand it is that if these were natural parameters of conditionally conjugate exponential families, then the posterior would be in the same family as the prior, but with updated parameter
    •   new_param = old_param + sufficient_statistics
    • so we might as well just assume we're always in this regime and see what happens. This yields a variational family that contains the true posterior as long as:
    •  (1) the model is composed of conjugate conditionals, and
 (2) the model contains no colliders (i.e., the posterior has no dependence structure that wasn't already present in the prior),
    • and otherwise gives a posterior approximation that is at least no worse than mean-field (since that's a special case). And the nice thing is that we built it by adding only a linear number of parameters and didn't have to do any serious model analysis (the q is just a locally transformed version of p, which one could implement as an effect handler, or we could manually implement for JDs).
    • What confuses me is the move from the natural parameterization to a mean parameterization, which is what they actually use in the paper. They seem to claim that the mean parameter of a conjugate posterior can always be represented as a (fixed) convex combination of the prior and MLE mean parameters:
    • new_mean_param =  λ old_mean_param + (1 -  λ) sufficient_statistics
    • Maybe I'm missing something obvious, but that doesn't seem to follow at all from the natural parameter case? And they don't cite any source. All of their experiments are on univariate normal distributions, where this does happen to be true (though still suboptimal I think, in that they need λ to learn the implicit ratio of precisions that the natural parameters would have incorporated directly), but in general I think it might only work to do this in natural parameterization---in which case, the paper as written is a bit broken. Does anyone know if there's some standard result I'm missing?
    • It does seem like at least a version of this approach in which we work in natural parameters and just optimize over 'variational sufficient statistics' added to each conditional, would be easy to implement and potentially pretty powerful.
    • There are also interesting connections to the automatic noncentering stuff Maria worked on a couple summers ago: this family seems like it can do qualitatively similar things (essentially, use prior structure as a preconditioner when evidence is weak) and is also more powerful in at least some cases (e.g., it contains the true posterior for linear Gaussian SSMs, whereas our form of partial noncentering did not).
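  • A minimal sketch (mine, not the paper's code) of the convex-combination construction described in the email, applied to the two-latent Gaussian chain above; for brevity only the conditional means are perturbed and the scales are left at their prior values.

import numpy as np

def asvi_q_sample(lam, alpha, rng):
  """Sample (w, x) from the ASVI-style q for the prior w ~ N(0, 1); x ~ N(w, 1)."""
  # Each conditional keeps the prior's form; its mean is mixed with a
  # learnable alpha via a learnable lambda in [0, 1].
  mean_w = lam[0] * 0.0 + (1.0 - lam[0]) * alpha[0]  # prior mean of w is 0
  w = rng.normal(mean_w, 1.0)
  # The prior conditional mean of x is w, so q keeps that dependence,
  # attenuated by lambda: lam[1] = 1 recovers the prior conditional,
  # lam[1] = 0 gives a mean-field factor centered at alpha[1].
  mean_x = lam[1] * w + (1.0 - lam[1]) * alpha[1]
  x = rng.normal(mean_x, 1.0)
  return w, x

rng = np.random.default_rng(0)
samples = [asvi_q_sample([0.5, 0.5], [0.0, 1.0], rng) for _ in range(1000)]
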
  • Relationship to non-centering:
  • Exercises:
    • Step 1: take a simple Gaussian SSM JD with fixed variances. Write a trainable JD expressing a mean-field VI Gaussian posterior. Then write a JD expressing a structured posterior.
      • This would be something like

# Sketch of the structured variational posterior, written as the body of a
# TFP JointDistributionCoroutine. The prior constants (prior_loc0,
# prior_scale0, prior_scale1), the learnable natural-parameter offsets
# (learnable_pm*, learnable_prec*), and the mean_to_natural /
# natural_to_mean helpers (one possible implementation is sketched below)
# are assumed to be defined in the enclosing scope.
import tensorflow_probability as tfp
tfd = tfp.distributions

def model():
  # First latent: shift the prior's natural parameters by learnable offsets.
  prior_pm0, prior_prec0 = mean_to_natural(prior_loc0, prior_scale0)
  posterior_prec0 = prior_prec0 + learnable_prec0
  posterior_pm0 = prior_pm0 + learnable_pm0
  loc, scale = natural_to_mean(posterior_pm0, posterior_prec0)
  x0 = yield tfd.Normal(loc=loc, scale=scale)

  # Now instead of
  #   x1 = yield tfd.Normal(x0, prior_scale1)
  # shift the conditional's natural parameters the same way:
  prior_pm1, prior_prec1 = mean_to_natural(x0, prior_scale1)
  posterior_prec1 = prior_prec1 + learnable_prec1
  posterior_pm1 = prior_pm1 + learnable_pm1
  loc, scale = natural_to_mean(posterior_pm1, posterior_prec1)
  x1 = yield tfd.Normal(loc=loc, scale=scale)

  # etc.
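
One plausible implementation of the helpers assumed above (my convention, nothing from the paper): treat the "natural" parameters of Normal(loc, scale) as the precision-adjusted mean loc / scale**2 and the precision 1 / scale**2, the parameterization in which conditioning/conjugate updates are additive.

import tensorflow as tf

def mean_to_natural(loc, scale):
  # (precision-adjusted mean, precision) of a Normal(loc, scale).
  prec = 1.0 / tf.square(scale)
  return loc * prec, prec

def natural_to_mean(pm, prec):
  # Inverse map back to (loc, scale).
  return pm / prec, 1.0 / tf.sqrt(prec)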

Plotting this as a 2D posterior (compared to what we get from MCMC) should make clear that the structured fit is a lot better than the mean-field one.
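
One quick way to eyeball that (my own sketch): since the toy chain's exact posterior is known, scatter samples from it (as a stand-in for MCMC output) against a moment-matched mean-field approximation.

import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
mean = np.array([2. / 3., 4. / 3.])        # exact posterior mean of (w, x) for y = 2
cov = np.array([[2., 1.], [1., 2.]]) / 3.  # exact posterior covariance
exact = rng.multivariate_normal(mean, cov, size=2000)
mean_field = rng.normal(mean, np.sqrt(np.diag(cov)), size=(2000, 2))

plt.scatter(exact[:, 0], exact[:, 1], s=3, alpha=0.3, label='structured / exact')
plt.scatter(mean_field[:, 0], mean_field[:, 1], s=3, alpha=0.3, label='mean-field')
plt.xlabel('w'); plt.ylabel('x'); plt.legend(); plt.show()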

  • Step 2: Write the model function that takes in a JD containing just Gaussians with conditional mean dependencies and automatically transforms it to produce a structured variational model.
  • If we go down the MCMC-preconditioning road, we'd want this to be a bijector. How does that work? We want the bijector that takes the mean-field VI Gaussian dist and transforms it to the structured dist (or vice versa). For Gaussians and MVNs this seems easy enough: we can just run the model forward, compute the posterior loc/scale at every sample point, and call into the relevant shift/scale bijector. What about other distributions?
  • Say we have something like a Gamma. Assume we can evaluate the CDF but not the inverse CDF. We want the bijector that transforms from a standard normal to this Gamma.
  • We can transform from the gamma to a uniform by computing the CDF. That would let us implement the inverse of an inference bijector. But the forward direction relies on the inverse CDF, which is presumably hard. There's probably a way to use implicit-reparameterization ideas to get some sort of inference working here, but it might not fit naturally in existing APIs.
  • Of course we don't really want a uniform anyway. And it's inference, so we only care about approximate reparameterization. We could certainly use a bijector with appropriate tails (Softplus?) to pull the gamma into an unconstrained space and then at least normalize its mean and stddev in that space (sketch below).
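
A minimal sketch of that last idea (mine), assuming TFP and crude sample-based moment matching in the unconstrained space; the Gamma parameters are made up for illustration.

import tensorflow as tf
import tensorflow_probability as tfp
tfd, tfb = tfp.distributions, tfp.bijectors

gamma = tfd.Gamma(concentration=3., rate=2.)

# Pull the Gamma onto the real line with a bijector whose inverse has roughly
# the right tails (inverse-softplus here).
unconstrained = tfd.TransformedDistribution(
    distribution=gamma, bijector=tfb.Invert(tfb.Softplus()))

# Estimate its mean and stddev in the unconstrained space from samples.
z = unconstrained.sample(10000)
m, s = tf.reduce_mean(z), tf.math.reduce_std(z)

# Bijector mapping an (approximately) standard normal variable to the Gamma:
# scale and shift to match moments in unconstrained space, then softplus back.
to_gamma = tfb.Chain([tfb.Softplus(), tfb.Shift(m), tfb.Scale(s)])
approx_gamma_samples = to_gamma.forward(tf.random.normal([5]))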