diffusion model: Nonlinear Function
Created: August 25, 2022
Modified: August 31, 2022


This page is from my personal notes, and has not been specifically reviewed for public consumption. It might be incomplete, wrong, outdated, or stupid. Caveat lector.

Diffusion models for image generation were independently invented at least twice:

  • in a discrete-time variational inference framework (developed by Jascha Sohl-Dickstein and others)
  • in a continuous-time SDE 'score-matching' framework (developed by Stefano Ermon, Yang Song, etc.)

and these are now considered to be different perspectives on the same model family. I'll start here by trying to explain the Bayesian/VI take.

Denoising diffusion probabilistic models

(reference: Ho, Jain, Abbeel 2020)

Consider the following generative model for images:

  1. Start by sampling Gaussian noise $x_T \sim N(0, (1 + \beta T) \mathbf{I})$, for some positive $\beta$ (which will turn out to be a per-step noise variance). We could simply use a standard normal, but taking variance $1 + \beta T$ here avoids the need for the rescaling, further down, that makes the expressions in the original paper more complicated.
  2. For $t \in \{T, \ldots, 1\}$, let
     $$x_{t-1} \sim N(\mu_\theta(x_t, t), \Sigma_\theta(x_t, t))$$
     for some parameterized functions $\mu_\theta, \Sigma_\theta$ (we'll discuss some natural choices for these below).

This defines a joint distribution

$$p_\theta(x_0, \ldots, x_T) = p(x_T) \prod_{t=1}^T p_\theta(x_{t-1} \mid x_t)$$

and in particular a marginal distribution $p_\theta(\mathbf{x}_0)$ on the final image $\mathbf{x}_0$, where $\mathbf{x}_1, \ldots, \mathbf{x}_T$ are taken as latent variables. We can view this as a deep latent Gaussian model (or hierarchical VAE). Given parameters $\theta$ it is straightforward to sample from this process to generate images. But how can we train such a system?
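
To make the generative (reverse) direction concrete, here is a minimal sketch of the ancestral sampling loop in PyTorch. The functions `mu_theta(x, t)` and `sigma_theta(x, t)` are hypothetical stand-ins for the learned mean and standard deviation discussed above:

```python
import torch

def sample(mu_theta, sigma_theta, shape, T, beta):
    """Ancestral sampling from the reverse process (sketch).

    mu_theta(x, t)    -> learned mean of x_{t-1} given x_t      (assumed interface)
    sigma_theta(x, t) -> learned std-dev of x_{t-1} given x_t   (assumed interface)
    """
    # Step 1: start from Gaussian noise with variance (1 + beta * T).
    x = torch.randn(shape) * (1.0 + beta * T) ** 0.5
    # Step 2: walk backwards t = T, ..., 1, sampling x_{t-1} ~ N(mu, sigma^2 I).
    for t in range(T, 0, -1):
        noise = torch.randn(shape) if t > 1 else torch.zeros(shape)  # no noise at the final step
        x = mu_theta(x, t) + sigma_theta(x, t) * noise
    return x  # a sample of x_0
```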

Generally we'd think of using variational inference with a multilevel 'encoder' or approximating posterior $q_\phi(x_1, \ldots, x_T \mid x_0)$. The key trick is that we can simplify things by just fixing this to be the iterative application of Gaussian noise:

$$q_\beta(x_1, \ldots, x_T \mid x_0) = \prod_{t=1}^T q_\beta(x_t \mid x_{t-1})$$

where each $q_\beta(x_t \mid x_{t-1}) = N(x_{t-1}, \beta \mathbf{I})$ simply adds noise to the previous step. This gives the ELBO

$$\begin{align*} \mathcal{L}(\theta) &= \mathbb{E}_{x_{1:T}\sim q \mid x_0}\left[\log p_\theta(x_0, \ldots, x_T) - \log q(x_1, \ldots, x_T \mid x_0)\right]\\ &= \mathbb{E}_{x_{1:T}\sim q \mid x_0}\Big[\underbrace{\log p(x_T) - \log q(x_T \mid x_0)}_{L_T} + \sum_{t=2}^T \underbrace{\left(\log p_\theta(x_{t-1} \mid x_t) - \log q(x_{t-1} \mid x_t , x_0)\right)}_{L_{t-1}} + \underbrace{\log p_\theta(x_0 \mid x_1)}_{L_0}\Big] \end{align*}$$
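
(One way to see where this decomposition comes from: by the Markov property of the forward process, each factor $q(x_t \mid x_{t-1}) = q(x_t \mid x_{t-1}, x_0)$ can be rewritten with Bayes' rule as

$$q(x_t \mid x_{t-1}) = \frac{q(x_{t-1} \mid x_t, x_0)\, q(x_t \mid x_0)}{q(x_{t-1} \mid x_0)},$$

and the marginal factors $q(x_t \mid x_0)$ telescope across the sum over $t$, leaving exactly the grouping above.)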

Assuming a fixed variance $\beta$, the term $L_T$ is constant and can be ignored, leaving us with a KL divergence term for each timestep

$$\begin{align*} L_{t-1} &= \mathbb{E}_{x_t, x_{t-1} \sim q \mid x_0} \left[\log p_\theta(x_{t-1} \mid x_t) - \log q(x_{t-1} \mid x_t, x_0)\right]\\ &= -\mathbb{E}_{x_t \sim q \mid x_0} \left[\mathcal{D}_\text{KL}\left( q(x_{t-1} \mid x_t, x_0) \,\|\, p_\theta(x_{t-1} \mid x_t)\right)\right]\end{align*}$$

comparing the generative distribution $p_\theta(x_{t-1} \mid x_t)$ with the 'forward posterior' $q(x_{t-1} \mid x_t, x_0)$, in expectation over $x_t \sim q(x_t \mid x_0)$. Since the forward process is a simple Gaussian random walk, this marginal has the closed form

$$q(x_t \mid x_0) = \mathcal{N}\left(x_0, \beta t \mathbf{I}\right)$$

and the posterior can be derived in closed form (by observing that it is proportional to $\mathcal{N}(x_{t-1}; x_0, \beta (t-1) \mathbf{I})\,\mathcal{N}(x_t; x_{t-1}, \beta \mathbf{I})$ and applying a multivariate Gaussian identity) as

$$q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\left(\tilde{\mu}(x_t, x_0), \beta \cdot \frac{t-1}{t} \mathbf{I}\right)$$

with mean

$$\tilde{\mu}(x_t, x_0) = \frac{1}{t} \mathbf{x}_0 + \frac{t-1}{t}\mathbf{x}_t.$$
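
Spelling out that parenthetical: the product of those two Gaussian densities is itself Gaussian in $x_{t-1}$, and completing the square gives the precision and mean

$$\frac{1}{\tilde{\sigma}_t^2} = \frac{1}{\beta(t-1)} + \frac{1}{\beta} = \frac{t}{\beta(t-1)}, \qquad \tilde{\mu}(x_t, x_0) = \tilde{\sigma}_t^2 \left(\frac{\mathbf{x}_0}{\beta(t-1)} + \frac{\mathbf{x}_t}{\beta}\right) = \frac{1}{t}\mathbf{x}_0 + \frac{t-1}{t}\mathbf{x}_t,$$

matching the expressions above.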

This suggests that we should take our generative variance to equal that of the target we're comparing against; that is, take $\Sigma_\theta(\mathbf{x}_t, t)$ to be the constant $\beta \frac{t-1}{t} \mathbf{I}$. This choice simplifies the KL divergence between multivariate normals, so that we derive

$$L_{t-1} = -\mathbb{E}_q\left[\frac{t}{2\beta(t-1)}\left\|\mu_\theta(x_t, t) - \tilde{\mu}(x_t, x_0)\right\|^2\right].$$
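
(Here we use the fact that for two Gaussians sharing an isotropic covariance the KL divergence is just a scaled squared distance between means,

$$\mathcal{D}_\text{KL}\left(\mathcal{N}(\mu_1, \sigma^2 \mathbf{I}) \,\|\, \mathcal{N}(\mu_2, \sigma^2 \mathbf{I})\right) = \frac{1}{2\sigma^2}\left\|\mu_1 - \mu_2\right\|^2,$$

with $\sigma^2 = \beta\frac{t-1}{t}$ giving the coefficient $\frac{t}{2\beta(t-1)}$.)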

How should we choose $\mu_\theta$? If we reparameterize to write $\mathbf{x}_t = \mathbf{x}_0 + \sqrt{\beta t}\, \epsilon$ for $\epsilon \sim \mathcal{N}(0, \mathbf{I})$, then our target mean becomes

$$\tilde{\mu}(x_t, x_t - \sqrt{\beta t}\, \epsilon) = \mathbf{x}_t - \sqrt{\beta / t}\, \epsilon,$$

which suggests a generative model of the form

$$\mu_\theta(x_t, t) = \mathbf{x}_t - \sqrt{\beta / t}\, \epsilon_\theta(x_t, t).$$

Under this choice (and using the reparameterization just discussed) the divergence reduces to

$$L_{t-1} = -\mathbb{E}_\epsilon\left[\frac{1}{2(t-1)}\left\|\epsilon_\theta(\mathbf{x}_0 + \sqrt{\beta t}\, \epsilon, t) - \epsilon \right\|^2\right],$$

where $\epsilon_\theta(\mathbf{x}_t, t)$ is now clearly being trained as a 'denoising' model, modeling a step towards the noise-free image $\mathbf{x}_0$ by predicting the added noise; note also that we pulled a factor of $\beta/t$ out of the squared error. It turns out that we can also write the final $L_0$ term in a form very similar to the other terms,

$$\begin{align*} L_0 &= \mathbb{E}_{\mathbf{x}_1 \sim q} \left[\log p_\theta(\mathbf{x}_0 \mid \mathbf{x}_1)\right]\\ &= \mathbb{E}_{\mathbf{x}_1 \sim q} \left[-\frac{1}{2\beta} \left\|\mathbf{x}_0 - \mu_\theta(\mathbf{x}_1, 1)\right\|^2\right] + \text{constant}\\ &= -\mathbb{E}_{\epsilon} \left[\frac{1}{2} \left\| \epsilon_\theta(\mathbf{x}_0 + \sqrt{\beta}\, \epsilon, 1) - \epsilon\right\|^2\right] + \text{constant} \end{align*}$$

(taking the decoder variance at the final step to be $\beta \mathbf{I}$, since the posterior variance $\beta\frac{t-1}{t}$ degenerates at $t = 1$).

Training. Now we have a practical training procedure. For each input image $\mathbf{x}_0$, we

  1. Sample a target timestep $t \in \{1, \ldots, T\}$ uniformly at random.
  2. Sample Gaussian white noise $\epsilon \sim \mathcal{N}(0, \mathbf{I})$.
  3. Compute the stochastic loss term
     $$\ell(\theta) = \left\|\epsilon_\theta(\mathbf{x}_0 + \sqrt{\beta t}\, \epsilon, t) - \epsilon\right\|^2$$
     comparing the predicted noise to the noise that was actually added (dropping the per-timestep weights from $L_{t-1}$; a code sketch of this loop follows below).
  4. Take a gradient descent step $\theta \leftarrow \theta - \alpha \nabla_\theta \ell(\theta)$, or the equivalent using your gradient-based optimizer of choice (Adam, etc.).
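
A minimal PyTorch sketch of this loop, assuming a hypothetical noise-prediction network `eps_model(x, t)` (e.g. a U-Net) and a standard optimizer:

```python
import torch

def training_step(eps_model, optimizer, x0, T, beta):
    """One stochastic gradient step of the denoising objective (sketch).

    eps_model(x, t) -> predicted noise epsilon_theta(x_t, t)   (assumed interface)
    x0: a batch of training images, shape (B, C, H, W)
    """
    B = x0.shape[0]
    # 1. Sample a timestep t in {1, ..., T} uniformly for each example.
    t = torch.randint(1, T + 1, (B,))
    # 2. Sample Gaussian white noise.
    eps = torch.randn_like(x0)
    # 3. Form x_t = x_0 + sqrt(beta * t) * eps and compute the squared error.
    scale = (beta * t.float()).sqrt().view(B, 1, 1, 1)
    loss = ((eps_model(x0 + scale * eps, t) - eps) ** 2).mean()
    # 4. Gradient step.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```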

Scaling. In the above we've defined the forward process as a simple random walk, so the scale of $x_t$ grows as $\sqrt{\beta t}$. As a practical matter (e.g., for numerical stability), we may wish to work with the normalized iterates $\tilde{\mathbf{x}}_t = \mathbf{x}_t / \sqrt{1 + \beta t}$, which have unit scale (assuming that $\mathbf{x}_0$ has unit scale), as the paper (Ho, Jain, Abbeel 2020) does. This introduces some scaling factors in the loss, but apparently it works better to ignore them anyway?
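
To connect with the paper's parameterization: under this constant-$\beta$ random walk, the normalized iterate can be written

$$\tilde{\mathbf{x}}_t = \sqrt{\bar\alpha_t}\, \mathbf{x}_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon, \qquad \bar\alpha_t := \frac{1}{1 + \beta t},$$

which has the same functional form as the $\sqrt{\bar\alpha_t}\, \mathbf{x}_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon$ forward marginal in Ho et al. (though the implied schedule for $\bar\alpha_t$ is not the same as theirs).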

Per-timestep variance: Ho et al. use a variance schedule $\beta_t$ that can differ across timesteps. This makes the math a bit uglier, but is probably helpful? See the paper for details. Later papers like the classifier-free guidance paper (linked below) formulate this instead in continuous time as a nonuniform distribution over the timestep $t$, which seems cleaner to me.

Conditional diffusion

Suppose we want to generate images conditioned on a label or caption $y$. The obvious thing to me would be to learn a gradient model $\epsilon_\theta(x_t, t, y)$ that incorporates the conditioning information. But it seems that people actually do various other things that might or might not be equivalent/related?

Classifier guidance (Dhariwal and Nichol, 2021): Given an image classifier $f_\phi(y \mid x)$, we define the joint distribution $p(x, y) = p_\theta(x) f_\phi(y \mid x)$; fixing $y$, its log is an unnormalized conditional log-density $\log p(x \mid y) + \log Z$. The gradient of this with respect to $x$ is just the gradient of $\log p_\theta(x)$ (which we are estimating, up to sign and scale, as $\epsilon_\theta(x)$) plus the classifier gradient $\nabla_x \log f_\phi(y \mid x)$. Incorporating the latter term in the generation process (with a tunable weight $w$) pushes the model towards the appropriate conditional slice of image space.
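
As a rough sketch of how the guidance term enters: assuming a hypothetical unconditional noise model `eps_model(x, t)` and a differentiable classifier `log_classifier(x, y)` returning $\log f_\phi(y \mid x)$, the guided noise prediction at each sampling step might look like

```python
import torch

def guided_eps(eps_model, log_classifier, x_t, t, y, beta, w):
    """Classifier-guided noise prediction (sketch).

    eps_model(x, t)      -> unconditional noise prediction           (assumed interface)
    log_classifier(x, y) -> scalar log f_phi(y | x), differentiable  (assumed interface)
    w: guidance weight.
    """
    # Gradient of the classifier log-probability wrt the current iterate.
    x = x_t.detach().requires_grad_(True)
    grad = torch.autograd.grad(log_classifier(x, y).sum(), x)[0]
    # epsilon_theta estimates (minus) the score scaled by the noise std
    # sqrt(beta * t) in this note's parameterization, so the classifier
    # gradient is added with that same scale (and a minus sign).
    sigma_t = (beta * t) ** 0.5
    return eps_model(x_t, t) - w * sigma_t * grad
```

(The $\sqrt{\beta t}$ scaling is my own bookkeeping for this note's unnormalized forward process; Dhariwal and Nichol express the same combination in their normalized parameterization.)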

Classifier-free guidance (Ho and Salimans, 2021): instead of training a separate classifier model, we train an unconditional diffusion model $\log p_\theta(x)$ and a conditional model $\log p_\theta(x \mid y)$; these can in fact be modeled using the same network $\epsilon_\theta$, trained by randomly dropping out the conditioning information to ensure that it also learns the unconditional model. Then we generate using the learned 'gradient'

$$\begin{align*} (1 + w)\, \epsilon_\theta(x, y) - w\, \epsilon_\theta(x) &= \epsilon_\theta(x, y) + w\left(\epsilon_\theta(x, y) - \epsilon_\theta(x)\right)\\ &\approx \nabla_x \log p(x \mid y) + w\, \nabla_x \log \frac{p(x \mid y)}{p(x)}\\ &= \nabla_x \log p(x \mid y) + w\, \nabla_x \log p(y \mid x) \end{align*}$$

which we interpret as classifier guidance using the classifier $p(y \mid x) = \frac{p(x \mid y)\, p(y)}{p(x)}$ implicitly defined by the ratio of the conditional and unconditional models (note we can ignore the $p(y)$ factor since it is constant with respect to $x$ and so does not contribute to the gradient).
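
In code this is just a weighted combination of two forward passes through the same network; a minimal sketch, assuming a hypothetical `eps_model(x, t, y)` that treats `y=None` (or a learned null token) as the unconditional case:

```python
def cfg_eps(eps_model, x_t, t, y, w):
    """Classifier-free guided noise prediction (sketch).

    eps_model(x, t, y) -> conditional noise prediction; passing y=None
    is assumed to give the unconditional prediction (same network,
    trained with random dropout of the conditioning signal).
    """
    eps_cond = eps_model(x_t, t, y)       # epsilon_theta(x, y)
    eps_uncond = eps_model(x_t, t, None)  # epsilon_theta(x)
    return (1.0 + w) * eps_cond - w * eps_uncond
```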

Text encoding: for a text2image model, we represent the conditioning text $y$ by an embedding vector. This can come from a generic pretrained language model (Google's Imagen found that using a large, capable language model gives more coherent, logical generations) or from a CLIP model trained to co-embed images and text.

High resolution

latent diffusion super-resolution