gradient clipping: Nonlinear Function
Created: April 29, 2023
Modified: April 29, 2023

gradient clipping

This page is from my personal notes, and has not been specifically reviewed for public consumption. It might be incomplete, wrong, outdated, or stupid. Caveat lector.

Why do we clip gradients in deep learning? When is it important and what is the right way to do it?

It seems like the standard recipe used in many transformer models is to clip the global gradient norm to 1.0 (or some constant). That is, letting $g$ be the gradient vector, we compute $\mathrm{clip}(g) = g / \max(\|g\|, 1)$ so that the resulting gradient norm is at most 1. Some systems distinguish between clipping the global (concatenated) gradient vector (clip_by_global_norm in Optax) versus separately clipping the various components of the gradient (clip_by_block_rms in Optax). Usually it makes more sense to clip by the global norm (see the sketch after the list below) because:

  1. Ultimately we are running algorithms like SGD that are formulated and analyzed in terms of a single global parameter vector.
  2. Global gradient normalization is invariant to implementation choices about how to split up the parameters into different blocks (eg, holding weights and biases in separate arrays versus combining them in a single big array), so it's a more 'fundamental' operation in this sense.
  3. Clipping globally preserves the direction of the gradient vector, while blockwise clipping can change the direction significantly. This can be useful (todo track down the paper with a learned optimizer that uses only the signs?) but all else equal it's discarding information that you could use to do adaptive optimization.
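
As a concrete illustration, here is a minimal sketch of global-norm clipping in JAX/Optax. The helper name clip_global, the threshold of 1.0, and the Adam learning rate are illustrative choices, not taken from any particular codebase.

```python
import jax
import jax.numpy as jnp
import optax

def clip_global(grads, max_norm=1.0):
    # Global norm over all parameter blocks, as if they were concatenated
    # into one vector: sqrt of the sum of squared entries across the pytree.
    global_norm = jnp.sqrt(
        sum(jnp.sum(jnp.square(g)) for g in jax.tree_util.tree_leaves(grads))
    )
    # clip(g) = g * max_norm / max(||g||, max_norm); with max_norm = 1 this
    # is exactly g / max(||g||, 1), so the direction is preserved.
    scale = max_norm / jnp.maximum(global_norm, max_norm)
    return jax.tree_util.tree_map(lambda g: g * scale, grads)

# The same thing via Optax, chained in front of the optimizer so clipping
# is applied to the raw gradients before the Adam update is computed.
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(1e-3),
)
```

Swapping optax.clip_by_global_norm for optax.clip_by_block_rms in the chain gives the blockwise behavior described above.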

There's a nice theoretical analysis of gradient clipping in Zhang et al (2019), Why gradient clipping accelerates training. They generalize the assumption of Lipschitz-smooth gradients, requiring that $\|\nabla^2 f(x)\| \le L$ for all $x$, to a finer-grained criterion in which $f$ is $(L_0, L_1)$-smooth if

$$\|\nabla^2 f(x)\| \le L_0 + L_1 \|\nabla f(x)\|,$$

where the bound on the Hessian norm is allowed to vary with the gradient norm. This condition captures a range of well-behaved functions whose gradients do not have finite Lipschitz constants, for example, polynomials with degree $> 2$. It seems to encode the intuition that the first-order gradient term is 'more important' than the second-order term? The analysis shows that clipped gradient descent converges under certain conditions in a number of steps proportional to $L_0 + L_1^2 / L_0$, while non-clipped gradient descent requires steps roughly proportional to $L_1 M$ where $M = \max_{x\,\text{ s.t. }\,f(x) \le f(x_0)} \|\nabla f(x)\|$ is the largest gradient norm that the optimizer might encounter.
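
As a sanity check on the degree-$> 2$ claim (my own worked example, not from the paper), take $f(x) = x^3$:

$$\nabla f(x) = 3x^2, \qquad \nabla^2 f(x) = 6x,$$

$$|6x| \le 3 + 9x^2 = 3 + 3\,\|\nabla f(x)\|,$$

since $9x^2 - 6|x| + 3 = 9\left(|x| - \tfrac{1}{3}\right)^2 + 2 \ge 0$. So $f$ is $(3, 3)$-smooth, even though $\|\nabla^2 f(x)\| = 6|x|$ is unbounded and therefore no finite Lipschitz constant $L$ works.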