software lessons from TFP
Created: April 10, 2021
Modified: April 10, 2021

This page is from my personal notes, and has not been specifically reviewed for public consumption. It might be incomplete, wrong, outdated, or stupid. Caveat lector.
  • I've been really unhappy about how TFP is developed. It's felt pedantic. I waste a lot of effort thinking about things I don't want to think about. It sometimes feels like wearing a straitjacket.
  • And it hasn't necessarily led to good software. There's a lot about TFP that just doesn't make sense.
    • Non-autobatched JointDistributions are a mess.
    • The MCMC API pretends to be compositional, but isn't really. It's also missing some pretty basic pieces, like mixture kernels.
  • On the other hand, I've become a much better engineer, and I do think TFP is generally decently written code.
  • Some software design choices that I think have panned out well:
    • Distributions and Bijectors are static and immutable.
    • Classes save their parameters, so they can be copied, sliced, reconstructed.
    • Joint distributions in general, and multipart bijectors. Manipulating structured samples is so nice (first sketch at the end of these notes).
    • Priorities: numerical stability first.
    • Methods convert everything to Tensor, so we can assume consistent types.
    • Name scopes are generally good, but we fuck them up a lot.
    • Public-calls-private pattern. Subclasses can override behavior but benefit from boilerplate being done only once in the public class (schematic sketch at the end).
    • The JAX backend. It complicates writing code, but it has greatly extended TFP's useful life. I would find it harder to justify investing effort if TFP were TF-only.
  • Software design choices that have not worked out, or not justified the effort:
    • Explicit tape safety. It would have been so much easier just to do copy-on-method-invocation; we would still get lightweight construction, plus automatic tape safety.
    • validate_args should have been True by default (example at the end).
    • Bijector caching is useful, but not worth the headaches. sample_and_log_prob is much easier to implement and to reason about, and it doesn't just make gradients disappear for no reason the way caching can (sketched at the end).
    • Separate implementations of event_shape and event_shape_tensor, etc., and in general distinct codepaths for static and dynamic shape. All shapes should just be prefer-static (sketch at the end).
  • Some coding style choices that have panned out well (a combined sketch is at the end):
    • Extensive docstrings, with clear documentation of shapes.
    • Math as Python code.
    • Long names with no Greek.
    • Named args from the second argument onward.
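
A few sketches of the points above, as code. These are minimal illustrations, assuming TF2 eager mode and a reasonably recent TFP; parameter values are arbitrary, and anything not named in the notes above is my own stand-in.

First, immutability, saved parameters, and joint distributions. `copy`, `.parameters`, batch slicing, and `JointDistributionNamedAutoBatched` are the pieces of TFP API I mean above; the toy model is just for show:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Distributions are static and immutable: "changing" a parameter means
# building a new object. Constructor kwargs are saved in `.parameters`,
# so copying and reconstruction are cheap and explicit.
base = tfd.Normal(loc=0., scale=1.)
wider = base.copy(scale=10.)             # new object; `base` is untouched
rebuilt = type(base)(**base.parameters)  # rebuild from the saved kwargs

# Saved parameters are also what make batch slicing possible.
batched = tfd.Normal(loc=tf.zeros([5]), scale=tf.ones([5]))
first_two = batched[:2]                  # a Normal with batch_shape [2]

# Joint distributions: structured samples stay structured.
model = tfd.JointDistributionNamedAutoBatched(dict(
    scale=tfd.HalfNormal(scale=1.),
    x=lambda scale: tfd.Normal(loc=0., scale=scale),
))
draw = model.sample(seed=42)             # a dict: {'scale': ..., 'x': ...}
lp = model.log_prob(draw)                # consumes the same structure back
```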
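The public-calls-private pattern, plus the convert-to-Tensor and name-scope points. This is not TFP's actual base class, just a schematic of the pattern; in real TFP the private hooks are `_log_prob`, `_sample_n`, `_event_shape`, and friends:

```python
import math

import tensorflow as tf


class Distribution:
  """Schematic base class: boilerplate lives in the public method, once."""

  def __init__(self, name):
    self._name = name

  def log_prob(self, value, name='log_prob'):
    # Name scoping and input conversion happen here, so every subclass
    # (and every caller) gets them for free, and consistently.
    with tf.name_scope(self._name):
      with tf.name_scope(name):
        value = tf.convert_to_tensor(value, dtype_hint=tf.float32)
        return self._log_prob(value)

  def _log_prob(self, value):
    raise NotImplementedError('Subclasses implement `_log_prob`.')


class StdNormal(Distribution):
  """Toy subclass: overrides only the private hook and just does math."""

  def __init__(self):
    super().__init__(name='StdNormal')

  def _log_prob(self, value):
    # `value` is already a Tensor of a known dtype by the time we get here.
    return -0.5 * tf.square(value) - 0.5 * math.log(2. * math.pi)


print(StdNormal().log_prob(0.))  # the public wrapper handles the Python float
```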
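validate_args, minimally. The exact exception type, and whether it fires at construction or at first use, varies by version:

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

# With validation off (the current default), a bad parameter just flows
# through, and the failure shows up as NaNs far away from the bug:
quiet = tfd.Normal(loc=0., scale=-1.)
print(quiet.log_prob(0.))   # nan, with no hint about why

# With validation on, the bad `scale` is caught up front:
try:
  loud = tfd.Normal(loc=0., scale=-1., validate_args=True)
  loud.log_prob(0.)
except Exception as e:  # raised at construction or first use, version-dependent
  print(type(e).__name__, e)
```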
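sample_and_log_prob versus bijector caching. In the TFP versions I've used, the method is spelled `experimental_sample_and_log_prob`:

```python
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

dist = tfd.TransformedDistribution(
    distribution=tfd.Normal(loc=0., scale=1.),
    bijector=tfb.Exp())

# The caching path: `sample` stashes the pre-transform value, and the
# `log_prob` call below silently reuses it instead of inverting Exp.
# Convenient, but invisible, and it can interact badly with gradients.
y = dist.sample(seed=42)
lp_cached = dist.log_prob(y)

# The explicit path: one call, no hidden cache to reason about.
y2, lp2 = dist.experimental_sample_and_log_prob(seed=42)
```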
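What prefer-static means in practice. The helper at the bottom is my stand-in, not the internal one TFP actually uses:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

dist = tfd.MultivariateNormalDiag(loc=tf.zeros([3]))

# Today: two parallel APIs for the same information.
print(dist.event_shape)           # TensorShape([3])  -- static
print(dist.event_shape_tensor())  # tf.Tensor([3])    -- dynamic

# "Prefer-static" means one code path: return concrete values whenever
# the shape is statically known, and a Tensor only when it is not.
def prefer_static_shape(x):
  if x.shape.is_fully_defined():
    return x.shape.as_list()    # plain Python ints when we know them
  return tf.shape(x)            # fall back to a Tensor only when we must
```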
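And the coding style points together, in a hypothetical helper (not a real TFP function): shapes documented in the docstring, math written as readable Python, descriptive names instead of Greek, and everything after the first argument passed by keyword:

```python
import math

import tensorflow as tf


def gaussian_log_prob(value, loc, scale):
  """Computes the elementwise Normal log-density.

  Args:
    value: float Tensor; the point(s) at which to evaluate.
    loc: float Tensor of means, broadcastable with `value`.
    scale: float Tensor of standard deviations, broadcastable with `value`.

  Returns:
    log_prob: float Tensor with the broadcast shape of the inputs.
  """
  # Long names, no Greek: `scale`, not `sigma`; the math reads as math.
  standardized = (value - loc) / scale
  log_normalizer = tf.math.log(scale) + 0.5 * math.log(2. * math.pi)
  return -0.5 * tf.square(standardized) - log_normalizer


# Named args from the second argument onward, so call sites stay readable.
lp = gaussian_log_prob(0.5, loc=0., scale=1.)
```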