FSDP tutorial for JAX NNX

If you are facing the daunting task of implementing production-grade FSDP (Fully Sharded Data Parallel) training in JAX NNX, this tutorial is for you. This notebook guides you through the process step by step.

You will learn how to implement fully working FSDP on TPU, with all critical operations JIT-compiled, that shards all weights evenly across devices in combination with DDP (distributed data parallelism). You will also see how to use distributed checkpointing with Orbax to save to and restore from local disk or a Google Cloud Storage bucket, how to set up reproducible nnx.Rngs for noise generation and dropout, and how to maintain an EMA (exponential moving average) model that is also sharded. You can find the full codebase and executable notebook here.

Let’s begin.