proof of the policy gradient theorem: Nonlinear Function
Created: April 01, 2022 Modified: April 02, 2022
This page is from my personal notes, and has not been specifically reviewed for public consumption. It might be incomplete, wrong, outdated, or stupid. Caveat lector.
The policy gradient theorem says that
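$$\nabla_\theta V^{\pi_\theta} \propto \mathbb{E}_{s\sim\pi_\theta}\left[\sum_a \big(\nabla_\theta \pi_\theta(a\mid s)\big)\, Q^{\pi_\theta}(s,a)\right]$$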
For simplicity we'll assume a fixed initial state s0 and fixed-length finite trajectories, but the result can be generalized to discounted-reward or average-reward notions of value in the continuing setting.
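As a numerical sanity check of the statement, here is a minimal sketch that compares the visit-weighted sum against finite differences of $V^{\pi_\theta}(s_0)$. Everything in it is an assumption for illustration: the hypothetical two-state, two-action MDP, the horizon $T=3$, the tabular softmax policy, and all function names. Because trajectories have fixed length, Q depends on the number of steps remaining, so the sketch indexes it as $Q_t(s,a)$ and checks the unnormalized form $\nabla_\theta V^{\pi_\theta}(s_0) = \sum_t \sum_s \Pr(s_t=s)\sum_a \nabla_\theta\pi_\theta(a\mid s)\,Q_t(s,a)$; normalizing the visit weights into a state distribution gives the expectation in the theorem.

```python
import math

# Hypothetical toy MDP (numbers are arbitrary, for illustration only).
# P[s][a][s2] = probability of moving s -> s2 under action a; R[s][a] = reward.
P = [[[0.9, 0.1], [0.2, 0.8]],
     [[0.5, 0.5], [0.0, 1.0]]]
R = [[1.0, 0.0],
     [0.0, 2.0]]
S, A, T, S0 = 2, 2, 3, 0  # states, actions, horizon, fixed start state


def policy(theta):
    """Tabular softmax policy: pi[s][a] = exp(theta[s][a]) / sum_b exp(theta[s][b])."""
    pi = []
    for s in range(S):
        m = max(theta[s])
        w = [math.exp(x - m) for x in theta[s]]
        z = sum(w)
        pi.append([x / z for x in w])
    return pi


def q_and_value(theta):
    """Backward recursion; Q[t][s][a] is the value of taking a in s with T - t steps left."""
    pi = policy(theta)
    V_next = [0.0] * S  # V_T = 0
    Q = [None] * T
    for t in reversed(range(T)):
        Q[t] = [[R[s][a] + sum(P[s][a][s2] * V_next[s2] for s2 in range(S))
                 for a in range(A)] for s in range(S)]
        V_next = [sum(pi[s][a] * Q[t][s][a] for a in range(A)) for s in range(S)]
    return Q, V_next[S0]  # V_next now holds V_0


def visitation(theta):
    """d[t][s] = Pr(s_t = s) when starting from S0 and following the policy."""
    pi = policy(theta)
    d = [[1.0 if s == S0 else 0.0 for s in range(S)]]
    for _ in range(T - 1):
        prev = d[-1]
        d.append([sum(prev[s] * pi[s][a] * P[s][a][s2]
                      for s in range(S) for a in range(A)) for s2 in range(S)])
    return d


def pg_theorem_grad(theta):
    """sum_t sum_s d_t(s) sum_a grad pi(a|s) Q_t(s,a), using the softmax derivative."""
    pi, (Q, _), d = policy(theta), q_and_value(theta), visitation(theta)
    grad = [[0.0] * A for _ in range(S)]
    for t in range(T):
        for s in range(S):
            for b in range(A):
                # d pi(a|s) / d theta[s][b] = pi(a|s) * (1[a == b] - pi(b|s))
                grad[s][b] += d[t][s] * sum(
                    pi[s][a] * ((a == b) - pi[s][b]) * Q[t][s][a] for a in range(A))
    return grad


def finite_diff_grad(theta, eps=1e-6):
    """Central differences of V(s0) with respect to each theta[s][b]."""
    grad = [[0.0] * A for _ in range(S)]
    for s in range(S):
        for b in range(A):
            plus = [row[:] for row in theta]
            minus = [row[:] for row in theta]
            plus[s][b] += eps
            minus[s][b] -= eps
            grad[s][b] = (q_and_value(plus)[1] - q_and_value(minus)[1]) / (2 * eps)
    return grad


theta = [[0.3, -0.2], [0.1, 0.5]]
exact, numeric = pg_theorem_grad(theta), finite_diff_grad(theta)
max_err = max(abs(exact[s][b] - numeric[s][b]) for s in range(S) for b in range(A))
print(f"max |analytic - finite difference| = {max_err:.2e}")
```

The softmax parameterization is chosen only because $\nabla_\theta\pi_\theta$ has a simple closed form; any differentiable policy would do for the check.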
It seems like there are two proof approaches:
We can either recursively expand the value as a sequence of Q-values, beginning at the start state, or
We can work with the reward of an entire trajectory, $R(\tau)$, which leads to a product of two sums over time steps with $T^2$ cross terms between policy choices and rewards, and then remove the 'acausal' terms (rewards received before the action was chosen) because they have expectation zero.
Sutton and Barto take the first approach; John Schulman's derivation, as recorded in my deep RL notes, takes the second.
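To see the acausal terms of the trajectory approach vanish concretely, the sketch below enumerates every trajectory of a hypothetical toy MDP (all numbers and names are assumptions, and rewards are indexed so that $r_t = R(s_t, a_t)$ is received at the step where $a_t$ is taken). It computes exact expectations of the full product $\big(\sum_t \nabla_\theta\log\pi_\theta(a_t\mid s_t)\big)\big(\sum_{t'} r_{t'}\big)$, its causal part ($t' \ge t$), and its acausal part ($t' < t$). The acausal part has expectation zero because, given the history up to $s_t$, $\sum_a \pi_\theta(a\mid s_t)\,\nabla_\theta\log\pi_\theta(a\mid s_t) = \nabla_\theta \sum_a \pi_\theta(a\mid s_t) = 0$, while earlier rewards are already fixed.

```python
import itertools
import math

# Hypothetical toy MDP (numbers are arbitrary): P[s][a][s2] transitions, R[s][a] rewards.
P = [[[0.9, 0.1], [0.2, 0.8]],
     [[0.5, 0.5], [0.0, 1.0]]]
R = [[1.0, 0.0],
     [0.0, 2.0]]
S, A, T, S0 = 2, 2, 3, 0
theta = [[0.3, -0.2], [0.1, 0.5]]  # tabular softmax logits (assumed parameterization)


def softmax(xs):
    m = max(xs)
    w = [math.exp(x - m) for x in xs]
    z = sum(w)
    return [x / z for x in w]


pi = [softmax(row) for row in theta]  # pi[s][a]


def score(s, a):
    """grad_theta log pi(a|s), flattened over theta[s2][b] at index s2 * A + b."""
    g = [0.0] * (S * A)
    for b in range(A):
        g[s * A + b] = (a == b) - pi[s][b]
    return g


def estimator_expectations():
    """Exact expectations over all trajectories of the full, causal, and acausal estimators."""
    full = [0.0] * (S * A)
    causal = [0.0] * (S * A)
    acausal = [0.0] * (S * A)
    # A trajectory is fixed by actions a_0..a_{T-1} and next states s_1..s_{T-1}.
    for acts in itertools.product(range(A), repeat=T):
        for nexts in itertools.product(range(S), repeat=T - 1):
            states = [S0, *nexts]
            prob = 1.0
            for t in range(T):
                prob *= pi[states[t]][acts[t]]
                if t + 1 < T:
                    prob *= P[states[t]][acts[t]][states[t + 1]]
            rewards = [R[states[t]][acts[t]] for t in range(T)]  # r_t = R(s_t, a_t)
            total = sum(rewards)
            for t in range(T):
                g = score(states[t], acts[t])
                to_go = sum(rewards[t:])   # rewards at or after t: causal
                before = sum(rewards[:t])  # rewards strictly before t: acausal
                for i in range(S * A):
                    full[i] += prob * g[i] * total
                    causal[i] += prob * g[i] * to_go
                    acausal[i] += prob * g[i] * before
    return full, causal, acausal


full, causal, acausal = estimator_expectations()
print(max(abs(x) for x in acausal))                    # floating-point noise level
print(max(abs(f - c) for f, c in zip(full, causal)))  # same, since full = causal + acausal
print(causal)  # the gradient of V(s0) with respect to the flattened logits
```

Exact enumeration is used instead of sampling so the zero shows up to floating-point precision rather than up to Monte Carlo noise; it is only feasible because the toy MDP is tiny.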
Recursive expansion proof
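Reproducing the Sutton and Barto approach for completeness: they expand the value $V^{\pi_\theta}(s_0)$ as an expectation over Q-values, $\sum_{a_0}\pi_\theta(a_0\mid s_0)\,Q^{\pi_\theta}(s_0,a_0)$, push the gradient inside, and apply the product rule. The term involving $\nabla_\theta Q^{\pi_\theta}(s_0,a_0)$ reduces to a gradient of the value at the next state, so the same expansion applies recursively.
Plugging this all back in, we get the policy gradient as a sum in which each term $\nabla_\theta\pi_\theta(a\mid s)\,Q^{\pi_\theta}(s,a)$ is weighted by the expected number of times that state $s$ is visited under $\pi_\theta$. This differs from the expression at the top (which takes an expectation over states) only by a constant factor of $T$, since the expected visit counts sum to $T$.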
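Written out, the recursive step looks roughly like this. This is a sketch in Sutton and Barto's style that assumes the reward $r(s,a)$ does not depend on $\theta$, writes $p(s'\mid s,a)$ for the transition probabilities, and suppresses the time index on $Q$ and $V$:

$$
\begin{aligned}
\nabla_\theta V^{\pi_\theta}(s)
&= \nabla_\theta \sum_a \pi_\theta(a\mid s)\,Q^{\pi_\theta}(s,a) \\
&= \sum_a \Big[\nabla_\theta \pi_\theta(a\mid s)\,Q^{\pi_\theta}(s,a) + \pi_\theta(a\mid s)\,\nabla_\theta Q^{\pi_\theta}(s,a)\Big] \\
&= \sum_a \Big[\nabla_\theta \pi_\theta(a\mid s)\,Q^{\pi_\theta}(s,a) + \pi_\theta(a\mid s) \sum_{s'} p(s'\mid s,a)\,\nabla_\theta V^{\pi_\theta}(s')\Big],
\end{aligned}
$$

where the last line uses $Q^{\pi_\theta}(s,a) = r(s,a) + \sum_{s'} p(s'\mid s,a)\,V^{\pi_\theta}(s')$. Unrolling from $s_0$ until the trajectory ends gives

$$
\nabla_\theta V^{\pi_\theta}(s_0) = \sum_s \eta(s) \sum_a \nabla_\theta \pi_\theta(a\mid s)\,Q^{\pi_\theta}(s,a),
\qquad \eta(s) = \sum_{t=0}^{T-1} \Pr(s_t = s \mid s_0, \pi_\theta),
$$

where $\eta(s)$ (notation introduced here) is the expected number of visits to $s$.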
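Equivalently, writing the visit-weighted sum as an expectation over trajectories and using $\nabla_\theta\pi_\theta = \pi_\theta\,\nabla_\theta\log\pi_\theta$, the gradient becomes
$$\nabla_\theta V^{\pi_\theta}(s_0) = \mathbb{E}_{\tau\sim\pi_\theta}\left[\sum_{t=0}^{T-1}\nabla_\theta\log\pi_\theta(a_t\mid s_t)\,Q^{\pi_\theta}(s_t,a_t)\right],$$
which we can see as a variant of the theorem above in which the expectation over state visitations is split into an expectation over trajectories and an empirical sum over the states within each trajectory.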
Note that it's valid to move between sums of empirical rewards $\sum_{t'>t}\gamma^{t'-t-1} r_{t'}$ and their expectations (the Q-values $Q^{\pi_\theta}(s_t,a_t)$) because this all happens inside the expectation over $(s_t,a_t)$. I was worried that in general $r_{t'}$ for $t'>t$ is not independent of $\nabla_\theta\log\pi_\theta(a_t\mid s_t)$, but that doesn't matter, because the latter is a fixed function of $(s_t,a_t)$ and hence just a constant once we condition on $(s_t,a_t)$.
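The conditioning argument is a single application of the tower property. Here $G_t = \sum_{t'>t}\gamma^{t'-t-1} r_{t'}$ is shorthand introduced only for this sketch, with rewards indexed as in Sutton and Barto so that $r_{t+1}$ follows $a_t$ and $\mathbb{E}[G_t \mid s_t, a_t] = Q^{\pi_\theta}(s_t,a_t)$:

$$
\mathbb{E}\big[\nabla_\theta\log\pi_\theta(a_t\mid s_t)\,G_t\big]
= \mathbb{E}_{s_t,a_t}\Big[\nabla_\theta\log\pi_\theta(a_t\mid s_t)\,\mathbb{E}\big[G_t\mid s_t,a_t\big]\Big]
= \mathbb{E}_{s_t,a_t}\Big[\nabla_\theta\log\pi_\theta(a_t\mid s_t)\,Q^{\pi_\theta}(s_t,a_t)\Big].
$$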