weighted importance sampling: Nonlinear Function
Created: April 23, 2022
Modified: April 23, 2022


This page is from my personal notes, and has not been specifically reviewed for public consumption. It might be incomplete, wrong, outdated, or stupid. Caveat lector.

Reference: Mahmood et al., 2014. Weighted importance sampling for off-policy learning with linear function approximation

Here's a situation where ordinary importance sampling performs badly. Suppose that $p$ and $q$ are very different distributions, so that the weights $w_i$ vary wildly, but that $f$ is constant, say $f(x) = 1$ for all $x$. Clearly the distribution of $x$ doesn't matter at all to the expectation of a constant function, but ordinary importance sampling in this setting will return the average weight, $\frac{1}{N}\sum_{i=1}^N w_i$, which is a high-variance random quantity. We would be better off using the weighted importance sampling (WIS) estimate

$$\hat{F}(x) = \frac{\sum_{i=1}^N w_i f(x_i)}{\sum_{i=1}^N w_i}$$

which 'divides out' the randomness in the weights. Note that the denominator has the same expectation as the OIS normalizer, $\mathbb{E}\left[\sum_{i=1}^N w_i\right] = N$. This estimator is biased, but is consistent in the limit $N\to\infty$ and can have much lower variance than the OIS estimator.
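A quick numerical illustration of this (my own toy setup, not from the paper): take $p = \mathcal{N}(0,1)$, $q = \mathcal{N}(2,1)$, and $f \equiv 1$, and compare the spread of the two estimators over repeated trials.

```python
import numpy as np

rng = np.random.default_rng(0)

# Target p = N(0, 1), proposal q = N(2, 1): very different, so the weights vary wildly.
# Both densities share the same normalizing constant, so it cancels in the weights.
def log_p(x):
    return -0.5 * x**2

def log_q(x):
    return -0.5 * (x - 2.0)**2

def ois_and_wis(n):
    x = rng.normal(2.0, 1.0, size=n)        # samples drawn from q
    w = np.exp(log_p(x) - log_q(x))         # importance weights w_i = p(x_i) / q(x_i)
    f = np.ones_like(x)                     # constant function f(x) = 1
    ois = np.mean(w * f)                    # ordinary IS estimate: just the average weight
    wis = np.sum(w * f) / np.sum(w)         # weighted IS estimate: exactly 1 here
    return ois, wis

estimates = np.array([ois_and_wis(100) for _ in range(1000)])
print("OIS mean / std:", estimates[:, 0].mean(), estimates[:, 0].std())   # noisy around 1
print("WIS mean / std:", estimates[:, 1].mean(), estimates[:, 1].std())   # exactly 1, zero spread
```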

In supervised learning: suppose we are minimizing squared loss

$$\begin{align*}
L(\theta) &= \sum_{i=1}^N\left(y_i - f_\theta(x_i)\right)^2\\
\theta^* &= \text{argmin}_\theta \; L(\theta)
\end{align*}$$

Here we have two different random quantities: given a sampled dataset $(x_i, y_i)_{i=1}^N$, we can define an empirical $L(\theta)$ for any $\theta$, and we can also define the optimal parameters $\theta^*$.

To simplify things enormously, let's use the 'constant predictor' $f_\theta(x) = \theta$ that just predicts a single value for all inputs. We can then solve the minimization problem trivially: the optimum is the mean, $\theta^* = \mathbb{E}[y]$.
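Spelling out the one-line calculation for the empirical loss: setting the derivative to zero gives the sample mean, which converges to $\mathbb{E}[y]$.

$$\frac{\partial L}{\partial \theta} = -2\sum_{i=1}^N (y_i - \theta) = 0 \quad\Longrightarrow\quad \theta = \frac{1}{N}\sum_{i=1}^N y_i \;\xrightarrow{\;N\to\infty\;}\; \mathbb{E}[y]$$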

Now suppose that we have labels $y$ sampled from a conditional distribution $q(y | x)$ different from our target distribution $p(y | x)$, so we have $w_i = \frac{p(y_i | x_i)}{q(y_i | x_i)}$ (for current purposes we'll assume that the marginal distribution of inputs $p(x)$ is the same in both cases and so cancels out). How should we estimate the quantity $\theta^*$?

One approach is to apply ordinary importance sampling directly to estimate $\theta^* \approx \frac{1}{N}\sum_{i=1}^N w_i y_i$. Working backwards, we can view this as implicitly minimizing a loss in which each target $y_i$ is scaled by the importance weight:

$$\begin{align*}
\tilde{\theta}(x) &= \text{argmin}_\theta \sum_{i=1}^N\left(w_i y_i - \theta\right)^2\\
&= \frac{1}{N}\sum_i w_i y_i
\end{align*}$$

Another approach is to define and minimize the importance-sampled loss:

$$\begin{align*}
\hat{\theta}(x) &= \text{argmin}_\theta \sum_{i=1}^N w_i \left(y_i - \theta\right)^2\\
&= \frac{1}{\sum_{i=1}^N w_i} \sum_{i=1}^N w_i y_i
\end{align*}$$

Surprisingly, we see that this recovers the weighted (WIS) solution for the parameter $\theta$!
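A small sanity check of the constant-predictor case (a sketch with made-up Gaussians, not from the reference): compute both closed-form solutions on labels drawn from a shifted distribution.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 500

# Labels are drawn from q(y) = N(1, 1), but the target distribution is p(y) = N(0, 1),
# so the quantity we want is E_p[y] = 0.
y = rng.normal(1.0, 1.0, size=n)
w = np.exp(-0.5 * y**2) / np.exp(-0.5 * (y - 1.0)**2)   # w_i = p(y_i) / q(y_i)

theta_tilde = np.mean(w * y)              # OIS: argmin_theta sum_i (w_i y_i - theta)^2
theta_hat = np.sum(w * y) / np.sum(w)     # WIS: argmin_theta sum_i w_i (y_i - theta)^2

print(theta_tilde, theta_hat)             # both near 0; theta_hat is typically less noisy
```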

Similar math goes through for the case of linear function approximation, where $f(x_i) = \theta^T \phi(x_i)$ (Mahmood et al., 2014), which also specializes to the tabular case $f(x_i) = \theta_i$. Since WIS solutions are often preferred, this suggests a general strategy of addressing shifts in the output distribution (e.g., in off-policy reinforcement learning) by minimizing an importance-weighted loss.
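In the linear case the weighted-loss minimizer is just weighted least squares; here's a minimal sketch (the feature matrix, labels, and weights are stand-ins I made up, not from the paper):

```python
import numpy as np

rng = np.random.default_rng(2)
N, d = 200, 3
Phi = rng.normal(size=(N, d))                    # rows are feature vectors phi(x_i)
y = Phi @ np.array([1.0, -2.0, 0.5]) + rng.normal(size=N)
w = rng.gamma(shape=1.0, scale=1.0, size=N)      # stand-in importance weights w_i

# Minimizing sum_i w_i (y_i - theta^T phi(x_i))^2 gives weighted least squares:
#   theta_hat = (Phi^T W Phi)^{-1} Phi^T W y,  with W = diag(w)
W = np.diag(w)
theta_hat = np.linalg.solve(Phi.T @ W @ Phi, Phi.T @ W @ y)
print(theta_hat)                                 # close to [1.0, -2.0, 0.5]
```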

Generalizing to nonlinear cases, one can still define the two objectives

$$\begin{align*}
\tilde{\theta}(x) &= \text{argmin}_\theta \sum_{i=1}^N \left(w_i y_i - f_\theta(x_i)\right)^2\\
\hat{\theta}(x) &= \text{argmin}_\theta \sum_{i=1}^N w_i \left(y_i - f_\theta(x_i)\right)^2
\end{align*}$$

and the argument here is that the latter should in general be preferred. I'm not sure I would have ever even thought of doing the former, but considering the simpler linear/constant cases at least gives some intuition for what's going on.
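For completeness, a toy sketch of the nonlinear version $\hat{\theta}$ (the model, data, and weights here are arbitrary choices of mine, not from the reference): minimize the importance-weighted squared loss by gradient descent.

```python
import numpy as np

rng = np.random.default_rng(3)
N = 200
x = rng.uniform(-2.0, 2.0, size=N)
y = np.sin(x) + rng.normal(scale=0.3, size=N)    # labels under the behaviour distribution
w = rng.gamma(shape=1.0, scale=1.0, size=N)      # stand-in importance weights w_i

def f(theta, x):
    a, b = theta
    return a * np.tanh(b * x)                    # a simple nonlinear predictor f_theta

theta = np.array([0.1, 0.1])
lr = 1e-3
for _ in range(5000):
    a, b = theta
    resid = f(theta, x) - y
    # gradient of sum_i w_i (y_i - f_theta(x_i))^2 with respect to (a, b)
    grad_a = 2.0 * np.sum(w * resid * np.tanh(b * x))
    grad_b = 2.0 * np.sum(w * resid * a * x / np.cosh(b * x) ** 2)
    theta = theta - lr * np.array([grad_a, grad_b])

print(theta)
```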