
Mixed Precision and MFU

Training big models on GPUs calls for lower-than-fp32 precision wherever you can get away with it. The two relevant 16-bit types are fp16 (older, with a tricky underflow problem) and bfloat16 (bf16; newer, simpler). Karpathy's build-nanogpt follows the standard recipe: the matmul-heavy forward and backward computation runs in bf16, while the parameters and optimizer state stay in fp32.

Why mixed precision

A 124M-parameter GPT-2 needs 496MB for its weights in fp32, or 248MB in bf16. The gradients take the same space again, and AdamW's optimizer state adds two more fp32 tensors per parameter (the first and second moments), twice the fp32 weight size. Halving the working precision roughly halves activation memory and memory traffic, though fp32 master weights and optimizer state don't shrink. The bigger win is speed: GPU tensor cores are far faster at bf16/fp16 than at fp32.
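A back-of-envelope sketch of the memory that scales with parameter count. It assumes exactly 124M parameters, decimal megabytes (1 MB = 10^6 bytes), fp32 gradients (as when parameters stay fp32 under autocast), and AdamW's two fp32 moment tensors. Activations are ignored entirely.

```python
# Parameter-proportional memory for a ~124M-parameter GPT-2.
# Assumption: activations, buffers and allocator overhead are not counted.
N = 124_000_000  # approximate parameter count

BYTES = {"fp32": 4, "bf16": 2}

def mb(num_bytes: int) -> float:
    """Convert bytes to decimal megabytes."""
    return num_bytes / 1e6

weights_fp32 = N * BYTES["fp32"]
weights_bf16 = N * BYTES["bf16"]
grads_fp32 = N * BYTES["fp32"]       # gradients share the fp32 parameter dtype
adamw_state = 2 * N * BYTES["fp32"]  # first and second moments (m, v), both fp32

print(f"weights fp32: {mb(weights_fp32):.0f} MB")  # 496 MB
print(f"weights bf16: {mb(weights_bf16):.0f} MB")  # 248 MB
print(f"AdamW m + v:  {mb(adamw_state):.0f} MB")   # 992 MB

total = weights_fp32 + grads_fp32 + adamw_state
print(f"fp32 weights + grads + AdamW: {mb(total):.0f} MB")  # 1984 MB
```

Most of this roughly 2GB stays fp32 in an autocast setup, which is why the savings from bf16 show up mainly in activations rather than here.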

On an A100, peak bf16 matmul throughput is 312 TFLOPS versus 19.5 TFLOPS for fp32, a 16× difference. On an H100 both figures are higher and the absolute gap is wider still. Using fp32 leaves an order of magnitude of performance on the table.
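To put the 16× in wall-clock terms, here is the ideal time for one optimizer step at 100% utilization. The 6·N FLOPs-per-token training estimate, the round 124M parameter count and the batch of 2^19 tokens per step are assumptions for illustration; real steps are slower because no run reaches peak.

```python
# Ideal step time at 100% utilization. Assumptions: ~124M parameters, the 6*N
# training FLOPs-per-token rule (attention term ignored), 2**19 tokens per step.
N = 124e6
TOKENS_PER_STEP = 2**19
flops_per_step = 6 * N * TOKENS_PER_STEP

PEAK_FLOPS = {"bf16": 312e12, "fp32": 19.5e12}  # A100 peak matmul throughput, FLOP/s

for dtype, peak in PEAK_FLOPS.items():
    print(f"{dtype}: {flops_per_step / peak:.2f} s per step at peak")

print(f"bf16 / fp32: {PEAK_FLOPS['bf16'] / PEAK_FLOPS['fp32']:.0f}x")  # 16x
```

Under these assumptions a step costs about 1.25 s at the bf16 peak and about 20 s at the fp32 peak.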


bf16 vs fp16

Both are 16-bit, but the bits are split differently.

fp16

S E E E E E M M M M M M M M M M
Sign: 1 bit, exponent: 5 bits, mantissa: 10 bits

bf16

S E E E E E E E E M M M M M M M
Sign: 1 bit, exponent: 8 bits, mantissa: 7 bits
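The bit widths alone determine each format's range and precision. This sketch assumes IEEE-754-style conventions (exponent bias 2^(e-1) - 1, an implicit leading 1 for normal numbers, the all-ones exponent reserved for inf/NaN), which fp16, bf16 and fp32 all follow; `format_props` is a hypothetical helper, not a library function.

```python
def format_props(exp_bits: int, man_bits: int) -> dict:
    """Largest finite value, smallest normal value and machine epsilon of a binary
    float format, assuming IEEE-754-style conventions."""
    bias = 2 ** (exp_bits - 1) - 1
    max_exp = (2 ** exp_bits - 2) - bias  # all-ones exponent is reserved for inf/NaN
    return {
        "max": (2 - 2.0 ** -man_bits) * 2.0 ** max_exp,
        "min_normal": 2.0 ** (1 - bias),
        "eps": 2.0 ** -man_bits,  # gap between 1.0 and the next representable value
    }

FORMATS = {"fp32": (8, 23), "fp16": (5, 10), "bf16": (8, 7)}

for name, (e, m) in FORMATS.items():
    p = format_props(e, m)
    print(f"{name}: max={p['max']:.6g}  min_normal={p['min_normal']:.6g}  eps={p['eps']:.6g}")
```

fp16 tops out at 65504, while bf16 reaches about 3.39e38 with the same smallest normal value as fp32; in exchange, bf16's epsilon (2^-7) is eight times coarser than fp16's (2^-10).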

bf16 keeps fp32's 8-bit exponent, so it has the same range, but only 7 mantissa bits of precision (fp32 has 23). fp16 has more precision than bf16 (10 mantissa bits) but a much smaller range: its 5-bit exponent caps the largest value at 65504, and the smallest positive value is about 6e-8. For deep learning, range matters more than precision: activations and gradients can span many orders of magnitude, but you don't need 10 bits of mantissa to learn well. bf16 just works. fp16 requires "loss scaling": multiply the loss by a large constant before backprop so that small gradients don't underflow to zero, then divide the gradients by the same constant before the optimizer step.
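A stdlib-only sketch of these trade-offs, emulating both formats with Python's `struct` module: the `'e'` format code packs IEEE fp16, and bf16 is emulated by rounding an fp32 bit pattern to its top 16 bits. The helper names are hypothetical, the bf16 rounding (nearest, ties to even) ignores NaN, and the loss-scaling lines simulate a single gradient value rather than a real backward pass.

```python
import struct

def to_fp16(x: float) -> float:
    """Round to IEEE fp16 and back using struct's 'e' format.
    Tiny values underflow to 0.0; values too large for fp16 raise an error."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_bf16(x: float) -> float:
    """Round to bf16 (round-to-nearest-even) by keeping the top 16 bits of the
    fp32 bit pattern. NaN handling is omitted in this sketch."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    return struct.unpack("<f", struct.pack("<I", ((bits >> 16) << 16) & 0xFFFFFFFF))[0]

# Precision: 1 + 2**-8 fits in fp16's 10-bit mantissa but not in bf16's 7 bits.
print(to_fp16(1 + 2**-8), to_bf16(1 + 2**-8))  # 1.00390625 1.0

# Range: 1e5 overflows fp16 (max 65504) but is fine, if coarse, in bf16.
try:
    to_fp16(1e5)
except (OverflowError, struct.error):
    print("1e5 overflows fp16")
print(to_bf16(1e5))  # 99840.0

# Underflow: a tiny gradient becomes 0 in fp16 but survives in bf16.
g = 1e-8
print(to_fp16(g), to_bf16(g) != 0.0)  # 0.0 True

# Loss scaling (fp16): scale up before the cast, unscale afterwards in fp32.
scale = 2.0**16
g_scaled = to_fp16(g * scale)  # about 6.55e-4, inside fp16's normal range
print(g_scaled / scale)        # close to 1e-8 again
```

The scaled value survives the fp16 cast with about three significant digits, which is all the optimizer needs; without the scale factor the same gradient would arrive as exactly zero.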

build-nanogpt uses bf16 via torch.autocast:

```python
# device_type: "cuda" or "cpu"; model, x, y come from the surrounding training loop
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
    logits, loss = model(x, y)
```

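Inside that context, eligible ops (matmuls and linear layers, convolutions) run in bf16, while precision-sensitive ops (softmax, layer norm, cross-entropy) run in fp32. The parameters themselves stay fp32: autocast casts them to bf16 on the fly, and each backward op runs in the same dtype autocast chose for its forward op. PyTorch handles all the casting automatically.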
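A toy model of what happens inside one autocast-eligible op: a dot product whose inputs are rounded to bf16, with products accumulated in fp32 and the result stored as bf16. This is a simulation under assumptions (round-to-nearest-even, sequential accumulation, hypothetical helper names), not PyTorch's or the tensor cores' actual implementation; it only shows the size of the error a single bf16 matmul introduces.

```python
import random
import struct

def to_fp32(x: float) -> float:
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_bf16(x: float) -> float:
    """Round to the nearest bf16 value (ties to even) via the fp32 bit pattern; NaN not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    return struct.unpack("<f", struct.pack("<I", ((bits >> 16) << 16) & 0xFFFFFFFF))[0]

def dot_fp32(a, b):
    """Reference: fp32 inputs, fp32 accumulation."""
    acc = 0.0
    for x, y in zip(a, b):
        acc = to_fp32(acc + to_fp32(x) * to_fp32(y))
    return acc

def dot_bf16_autocast_like(a, b):
    """Toy bf16 op: bf16 inputs, fp32 accumulation, bf16 output."""
    acc = 0.0
    for x, y in zip(a, b):
        acc = to_fp32(acc + to_bf16(x) * to_bf16(y))  # a bf16*bf16 product is exact in fp32
    return to_bf16(acc)

rng = random.Random(0)
a = [rng.uniform(0.0, 1.0) for _ in range(768)]
b = [rng.uniform(0.0, 1.0) for _ in range(768)]

ref = dot_fp32(a, b)
low = dot_bf16_autocast_like(a, b)
print(ref, low, abs(low - ref) / ref)
```

Because the product of two 8-bit bf16 significands fits exactly in fp32's 24-bit significand, the error in this model comes from rounding the inputs and the output, not from the accumulation.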

build-nanogpt also lets ordinary fp32 matmuls use TF32 (TensorFloat-32), the A100's 19-bit tensor-core format with fp32's 8-bit exponent and fp16's 10-bit mantissa:

```python
torch.set_float32_matmul_precision('high')
```

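On an A100 this raises the peak for fp32 matmuls from 19.5 to 156 TFLOPS, an 8× speedup, without changing model code. The cost is precision: matmul inputs are rounded to TF32's 10-bit mantissa, while accumulation stays in fp32.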
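A sketch of TF32's rounding, emulated by keeping the top 19 bits of an fp32 bit pattern. Round-to-nearest-even is an assumption here; the hardware's exact conversion may differ, and `to_tf32` is a hypothetical helper. The point is that TF32 behaves like fp16 for precision and like fp32 for range.

```python
import math
import struct

def to_tf32(x: float) -> float:
    """Emulate TF32: keep fp32's sign, 8 exponent bits and top 10 mantissa bits
    (19 bits), rounding to nearest even on the 13 dropped bits. NaN not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0xFFF + ((bits >> 13) & 1)
    bits = ((bits >> 13) << 13) & 0xFFFFFFFF
    return struct.unpack("<f", struct.pack("<I", bits))[0]

# Precision: TF32 keeps only 10 mantissa bits, like fp16.
print(to_tf32(1 + 2**-10))  # 1.0009765625 (representable)
print(to_tf32(1 + 2**-12))  # 1.0 (less than half a TF32 step above 1.0)

# Range: the 8-bit exponent matches fp32, so values fp16 cannot hold survive.
big = to_tf32(3e38)
print(big, math.isfinite(big))

# A100 peaks: 156 TFLOPS (TF32 tensor cores) vs 19.5 TFLOPS (plain fp32).
print(156e12 / 19.5e12)  # 8.0
```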

fp32 master params

In llm.c, the explicit pattern is:

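- Working params: bf16 (or fp16). This is what the forward and backward passes operate on.
- Master params: fp32. These are what AdamW updates.
- Recompute working: after each update, the bf16 working copy is recomputed from the master via stochastic rounding.

```cuda
// from llm.c/llmc/adamw.cuh
// m and v hold this parameter's AdamW moment estimates, computed earlier in the kernel.
// Read the fp32 master weight if one is kept, otherwise upcast the low-precision working weight.
float old_param = (master_params_memory != NULL) ? master_params_memory[idx] : (float)params_memory[idx];
// AdamW step with decoupled weight decay, computed in fp32.
float param = old_param - (learning_rate * (m / (sqrtf(v) + eps) + weight_decay * old_param));
// Write the low-precision working copy using stochastic rounding.
stochastic_rounding(param, &params_memory[idx], seed);
// Store the full-precision result in the master copy, if there is one.
if (master_params_memory != NULL) { master_params_memory[idx] = param; }
```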

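The fp32 master copy preserves numerical fidelity across thousands of small updates. Without it, every update is rounded to bf16's 7-bit mantissa: with round-to-nearest, an update smaller than half the gap between adjacent bf16 values is lost entirely, and rounding error accumulates over training. Stochastic rounding, which llm.c applies to the working copy whether or not a master copy exists, makes that rounding unbiased on average.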
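A stdlib-only simulation of three cases: bf16-only updates with round-to-nearest, an fp32 master copy, and bf16-only updates with stochastic rounding. Everything here is illustrative: the update size, step count and helper names are hypothetical, the fixed +1e-4 step stands in for a real AdamW update, and the stochastic rounding shown (add 16 random bits below the cut, then truncate) is one common way to implement it, not necessarily llm.c's exact kernel.

```python
import random
import struct

def to_fp32(x: float) -> float:
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_bf16(x: float) -> float:
    """Round-to-nearest-even bf16: keep the top 16 bits of the fp32 pattern."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    return struct.unpack("<f", struct.pack("<I", ((bits >> 16) << 16) & 0xFFFFFFFF))[0]

def to_bf16_stochastic(x: float, rng: random.Random) -> float:
    """Stochastic rounding to bf16: add 16 random bits below the cut, then truncate,
    so x rounds away from zero with probability equal to the discarded fraction."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += rng.getrandbits(16)
    return struct.unpack("<f", struct.pack("<I", ((bits >> 16) << 16) & 0xFFFFFFFF))[0]

STEPS, UPDATE = 1000, 1e-4  # hypothetical: 1000 small updates of +1e-4 to a weight at 1.0

# (a) bf16 only, round-to-nearest write-back: 1.0 + 1e-4 rounds back to 1.0 every time
w_rn = 1.0
for _ in range(STEPS):
    w_rn = to_bf16(w_rn + UPDATE)

# (b) fp32 master weight; the bf16 working copy is derived from it
master = 1.0
for _ in range(STEPS):
    master = to_fp32(master + UPDATE)
working = to_bf16(master)

# (c) bf16 only, but stochastic rounding on write-back
rng = random.Random(0)
w_sr = 1.0
for _ in range(STEPS):
    w_sr = to_bf16_stochastic(w_sr + UPDATE, rng)

print(f"bf16 + round-to-nearest: {w_rn}")                  # 1.0: every update lost
print(f"fp32 master: {master:.6f}, bf16 copy: {working}")  # about 1.1, and 1.1015625
print(f"bf16 + stochastic rounding: {w_sr}")               # close to 1.1 in expectation
```

Case (a) never moves because 1e-4 is far below half the bf16 spacing near 1.0 (2^-8, about 0.0039). Case (c) moves in whole bf16 steps of 2^-7, but only about 1.3% of the time, so it tracks the fp32 result in expectation while staying noisy.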

MFU: Model Flops Utilization

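How fast are you really training? Tokens per second is one measure; MFU is more meaningful because it compares the FLOPs you actually achieve with the hardware's peak, which makes it comparable across model sizes and GPUs. From nanoGPT/model.py: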

```python
def estimate_mfu(self, fwdbwd_per_iter, dt):
    """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
    N = self.get_num_params()
    cfg = self.config
    L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
    flops_per_token = 6*N + 12*L*H*Q*T
    flops_per_fwdbwd = flops_per_token * T
    flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
    flops_achieved = flops_per_iter * (1.0/dt)   # per second
    flops_promised = 312e12   # A100 bf16 peak
    mfu = flops_achieved / flops_promised
    return mfu
```
- 6 * N: the standard "6 FLOPs per parameter per token" (2 for the forward pass, 4 for the backward pass), covering every matmul against a weight.
- 12 * L * H * Q * T: attention's own matmuls (QKᵀ and the weighted sum over V). They involve no parameters, so the per-parameter term misses them; their cost grows with sequence length T per token, hence T² per sequence.

Multiply the FLOPs per token by the number of tokens processed, divide by wall-clock seconds, then divide by the GPU's peak rated FLOPS. A well-optimized training run reaches roughly 40-50% MFU on an A100. The PaLM paper reported 46.2% MFU training on 6144 TPU v4 chips, considered state of the art at the time.
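The same arithmetic as a standalone sketch, rewritten in terms of tokens per second to show how the two throughput measures relate: MFU = tokens/sec × FLOPs per token ÷ peak FLOPS. The GPT-2 124M shape (12 layers, 12 heads, 768 embedding, 1024 context) is standard, but `n_params=124_000_000` is a round-number assumption rather than the exact count `get_num_params` would return, and the 3.0 s step time is hypothetical.

```python
def flops_per_token(n_params, n_layer, n_head, n_embd, block_size):
    """Forward + backward FLOPs per token, following the estimate_mfu formula."""
    L, H, Q, T = n_layer, n_head, n_embd // n_head, block_size
    return 6 * n_params + 12 * L * H * Q * T

def mfu(tokens_per_sec, fpt, peak_flops=312e12):
    """Achieved FLOP/s divided by peak (default: A100 bf16, 312 TFLOPS)."""
    return tokens_per_sec * fpt / peak_flops

# Hypothetical run: 512 sequences x 1024 tokens per step, 3.0 s per step.
fpt = flops_per_token(n_params=124_000_000, n_layer=12, n_head=12, n_embd=768, block_size=1024)
tokens_per_sec = 512 * 1024 / 3.0

print(f"FLOPs per token: {fpt:,}")  # 857,246,208
print(f"{tokens_per_sec:,.0f} tok/sec -> MFU {mfu(tokens_per_sec, fpt):.1%}")  # 174,763 tok/sec -> MFU 48.0%
```

Multiplying out `flops_per_token * T * fwdbwd_per_iter / dt` as in `estimate_mfu` gives the same number, since T × fwdbwd_per_iter ÷ dt is just tokens per second.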

An MFU below 10% usually means something is wrong: often a CPU bottleneck (the dataloader can't keep up), a slow attention kernel, or many small unfused ops.
