Training Stability
Training a transformer from scratch is a chain of small details, any of which can quietly break the run. The build-nanogpt training loop and the makemore lectures together form the most honest list I've seen of "things that go wrong, and what to do about them." This page collects the safety net.
The basic stability stack
In rough order of "most essential":
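- Pre-norm in the transformer block: without it, residual stream variance blows up with depth and training is fragile to LR.
- LayerNorm (or RMSNorm) itself: cleans up activation magnitudes layer by layer.
- Linear LR warmup for several hundred to several thousand steps: AdamW's second-moment estimate v needs time to populate.
- Gradient clipping: cap the global gradient norm at 1.0.
- Scaled init for residual output projections: shrinks the init std of each projection that writes into the residual stream, so the stream's variance stays roughly constant with depth instead of growing.
- bf16 mixed precision: avoids the underflow/loss-scaling complications of fp16.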
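The warmup item can be made concrete with a small schedule function. This is a minimal sketch in the spirit of build-nanogpt's linear-warmup-then-cosine-decay schedule. The constants `MAX_LR`, `MIN_LR`, `WARMUP_STEPS`, and `MAX_STEPS` are illustrative assumptions, not values from any specific run.

```python
import math

# Illustrative constants (assumed, not taken from a particular run).
MAX_LR = 6e-4
MIN_LR = MAX_LR * 0.1
WARMUP_STEPS = 700
MAX_STEPS = 20_000


def get_lr(step: int) -> float:
    """Linear warmup up to MAX_LR, then cosine decay down to MIN_LR."""
    if step < WARMUP_STEPS:
        # (step + 1) so that step 0 already gets a small nonzero LR.
        return MAX_LR * (step + 1) / WARMUP_STEPS
    if step >= MAX_STEPS:
        return MIN_LR
    progress = (step - WARMUP_STEPS) / (MAX_STEPS - WARMUP_STEPS)  # goes 0 -> 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * progress))  # goes 1 -> 0
    return MIN_LR + coeff * (MAX_LR - MIN_LR)
```

In a PyTorch loop, the usual way to apply it is to set `param_group["lr"] = get_lr(step)` for each entry in `optimizer.param_groups` before calling `optimizer.step()`.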
Gradient clipping
# from build-nanogpt/train_gpt2.py
norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
This computes the global gradient norm ||g||_2 across all parameters, then if it exceeds the threshold (here, 1.0), rescales the entire gradient vector to have norm 1.0:
g_clipped = g * min(1.0, 1.0 / ||g||_2)
The direction of the update is preserved; only the magnitude is capped. Without clipping, a single huge gradient (from an outlier batch or a numerical instability) can push parameters into a bad region that takes hundreds of normal steps to recover from. Clipping turns "rare catastrophic step" into "rare slightly-truncated step."
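The rescaling rule can be written out without any framework. This is a plain-Python sketch that assumes each parameter's gradient is a flat list of floats. The small epsilon in the denominator guards against division by zero, similar to what `torch.nn.utils.clip_grad_norm_` does. The function name `clip_grad_norm` is hypothetical.

```python
import math
from typing import List


def clip_grad_norm(grads: List[List[float]], max_norm: float = 1.0) -> float:
    """Rescale all gradients in place so their global L2 norm is at most max_norm.

    `grads` holds one flat list of floats per parameter. Returns the global norm
    measured *before* clipping, like torch.nn.utils.clip_grad_norm_ does.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    # One shared scale factor for every parameter preserves the update direction.
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    if scale < 1.0:
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= scale
    return total_norm
```

For example, gradients `[[3.0, 0.0], [0.0, 4.0]]` have a global norm of 5.0. After clipping they become roughly `[[0.6, 0.0], [0.0, 0.8]]`, which has the same direction and a norm of about 1.0.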
The norm that clip_grad_norm_ returns (measured before clipping) is a free per-step health check. In a healthy run it stays in a tight range (often 0.1-0.3 once warmup is done) and spikes are rare. If it regularly exceeds the cap, something is off.
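One way to turn "regularly hitting the cap" into an automatic alert is to track a sliding window of recent norms. This is a sketch only. The class `GradNormMonitor`, its 100-step window, and its 10% threshold are hypothetical choices, not anything from build-nanogpt.

```python
from collections import deque


class GradNormMonitor:
    """Hypothetical helper: flag runs whose pre-clip gradient norm keeps exceeding the cap."""

    def __init__(self, clip_at: float = 1.0, window: int = 100, max_clipped_frac: float = 0.1):
        self.clip_at = clip_at
        self.max_clipped_frac = max_clipped_frac
        self.recent = deque(maxlen=window)  # oldest norms fall off automatically

    def update(self, norm: float) -> bool:
        """Record one step's norm; return True once too many recent steps were clipped."""
        self.recent.append(norm)
        if len(self.recent) < self.recent.maxlen:
            return False  # wait for a full window before judging
        clipped = sum(1 for n in self.recent if n > self.clip_at)
        return clipped / len(self.recent) > self.max_clipped_frac
```

In a PyTorch loop, you would feed it `norm.item()` from the `clip_grad_norm_` call each step and log a warning when it returns True.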
What the makemore activations lecture teaches
Lecture 4, "Activations & Gradients, BatchNorm," is the closest Karpathy gets to a complete tour of failure modes. The lecture trains a deep MLP on names and shows what happens at each iteration of the design.
The takeaway isn't BatchNorm specifically (which transformers don't use) — it's that the shape of the activation distribution across layers tells you whether your model can learn. If you're debugging a transformer that won't train, dump the activation statistics per layer; if any layer is saturated or near-zero, you've found your problem.
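To make "dump the activation statistics per layer" concrete, here is a dependency-free sketch. A real model would more likely collect these statistics with PyTorch forward hooks. The sketch pushes a random batch through a toy bias-free tanh MLP under two weight scales and reports each layer's mean, std, and saturated fraction. The 0.97 saturation threshold, the layer sizes, and the helper names `layer_stats` and `tanh_mlp_stats` are assumptions for illustration.

```python
import math
import random
from typing import List, Tuple


def layer_stats(acts: List[List[float]], sat_threshold: float = 0.97) -> Tuple[float, float, float]:
    """Mean, std, and fraction of saturated units (|a| > sat_threshold) over a batch."""
    flat = [a for row in acts for a in row]
    mean = sum(flat) / len(flat)
    std = math.sqrt(sum((a - mean) ** 2 for a in flat) / len(flat))
    saturated = sum(1 for a in flat if abs(a) > sat_threshold) / len(flat)
    return mean, std, saturated


def tanh_mlp_stats(weight_std: float, n_layers: int = 5, width: int = 64,
                   batch: int = 32, seed: int = 1337) -> List[Tuple[float, float, float]]:
    """Forward a random batch through a bias-free tanh MLP; collect stats per layer."""
    rng = random.Random(seed)
    x = [[rng.gauss(0.0, 1.0) for _ in range(width)] for _ in range(batch)]
    stats = []
    for _ in range(n_layers):
        # Each inner list of w holds the incoming weights of one output unit.
        w = [[rng.gauss(0.0, weight_std) for _ in range(width)] for _ in range(width)]
        x = [[math.tanh(sum(a * b for a, b in zip(w_row, row))) for w_row in w] for row in x]
        stats.append(layer_stats(x))
    return stats


if __name__ == "__main__":
    fan_in = 64
    for name, std in [("naive std=1.0", 1.0),
                      ("gain 5/3 / sqrt(fan_in)", (5 / 3) / math.sqrt(fan_in))]:
        print(name)
        for layer, (mean, s, sat) in enumerate(tanh_mlp_stats(std)):
            print(f"  layer {layer}: mean {mean:+.3f} std {s:.3f} saturated {sat:.1%}")
```

With unit-variance weights, the preactivations have a std of roughly sqrt(fan_in), so most tanh units sit in the flat region near ±1 and pass back almost no gradient. Scaling by gain/sqrt(fan_in) keeps most units out of that region. 5/3 is the gain PyTorch's `calculate_gain` recommends for tanh.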
Loss spikes
A common pattern in large-scale LLM training: a single batch causes the loss to jump 10× and then slowly recover over the next thousand steps. Causes:
- Bad data: a single web-scraped HTML blob that tokenizes to a wall of identical tokens, producing extreme attention patterns.
- fp16 overflow somewhere in the forward pass (avoided by using bf16 instead).
- Numerical instability in the attention softmax (avoided by computing softmax in fp32, which torch.autocast does automatically).
The defenses: bf16 instead of fp16, gradient clipping (so the spike doesn't break the optimizer state), and skipping batches whose loss is anomalously high.
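The "skip anomalous batches" defense needs some definition of "anomalously high." Here is a minimal sketch that compares each loss against an exponential moving average. The class `LossSpikeGuard`, the 3x factor, the EMA decay, and the warmup length are all assumed values, not a recipe from any particular codebase.

```python
import math
from typing import Optional


class LossSpikeGuard:
    """Hypothetical helper: skip steps whose loss is non-finite or far above a running average."""

    def __init__(self, factor: float = 3.0, beta: float = 0.99, warmup: int = 100):
        self.factor = factor      # how many times the running average counts as a spike
        self.beta = beta          # EMA decay
        self.warmup = warmup      # early losses fall fast, so never skip during warmup
        self.ema: Optional[float] = None
        self.steps = 0

    def should_skip(self, loss: float) -> bool:
        self.steps += 1
        if not math.isfinite(loss):
            return True  # NaN/inf: never step on this batch
        if self.ema is None:
            self.ema = loss
            return False
        if self.steps > self.warmup and loss > self.factor * self.ema:
            return True  # keep the spike out of the running average
        self.ema = self.beta * self.ema + (1 - self.beta) * loss
        return False
```

In a PyTorch loop, when `should_skip(loss.item())` returns True, you would zero the gradients and move on instead of calling `optimizer.step()`. That keeps the spike out of AdamW's moment estimates.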
Reproducibility
import torch

torch.manual_seed(1337)
if torch.cuda.is_available():
    torch.cuda.manual_seed(1337)
Karpathy uses seed 1337 everywhere. Seeding once and recording it makes runs reproducible enough to debug.
Call torch.use_deterministic_algorithms(True) as well if you really need bitwise reproducibility.
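If a project uses several RNG sources, it can help to seed them all in one place. The sketch below defines a hypothetical `seed_everything` helper. It seeds Python's `random` module and, if they are installed, NumPy and PyTorch. The `deterministic` flag is an assumed convenience that opts into `torch.use_deterministic_algorithms(True)`.

```python
import random


def seed_everything(seed: int = 1337, deterministic: bool = False) -> None:
    """Hypothetical helper: seed Python, NumPy, and PyTorch RNGs when available."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed; nothing to seed
    try:
        import torch
    except ImportError:
        return  # PyTorch not installed; nothing else to seed
    torch.manual_seed(seed)  # recent PyTorch seeds the CPU and all CUDA devices here
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # explicit, in case of older versions
    if deterministic:
        # Ops without a deterministic implementation will raise a RuntimeError.
        torch.use_deterministic_algorithms(True)
```

Calling it once at the top of a script, and logging the seed alongside the run, is enough for "reproducible enough to debug."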
"Becoming a Backprop Ninja"
Lecture 5 is the deep-end version of the same material. Karpathy reimplements every backward pass by hand and shows where it's easy to get gradients silently wrong: the BatchNorm backward, the cross-entropy backward, the softmax backward. The payoff is being able to read a backward formula and tell whether it matches the forward. Autograd usually saves you from that check, but occasionally it hides a mistake.
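The cross-entropy backward is the easiest of these to verify by hand. For a loss averaged over n examples, dL/dlogits is (softmax(logits) minus the one-hot target) divided by n. The plain-Python sketch below derives that gradient and compares it against central finite differences. The lecture instead checks its manual gradients against PyTorch's autograd. All function names here are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]


def cross_entropy(logits: Matrix, targets: List[int]) -> float:
    """Mean negative log-likelihood over the batch, with a max-shift for stability."""
    total = 0.0
    for row, y in zip(logits, targets):
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[y]
    return total / len(logits)


def cross_entropy_backward(logits: Matrix, targets: List[int]) -> Matrix:
    """Hand-derived gradient: dL/dlogits = (softmax(logits) - onehot(target)) / n."""
    n = len(logits)
    grads = []
    for row, y in zip(logits, targets):
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        z = sum(exps)
        g = [e / z / n for e in exps]
        g[y] -= 1.0 / n
        grads.append(g)
    return grads


def numerical_grad(logits: Matrix, targets: List[int], eps: float = 1e-6) -> Matrix:
    """Central finite differences, one logit at a time (slow; for checking only)."""
    grads = []
    for row in logits:
        g_row = []
        for j in range(len(row)):
            orig = row[j]
            row[j] = orig + eps
            up = cross_entropy(logits, targets)
            row[j] = orig - eps
            down = cross_entropy(logits, targets)
            row[j] = orig  # restore the perturbed logit
            g_row.append((up - down) / (2 * eps))
        grads.append(g_row)
    return grads
```

A quick check is to run both gradient functions on something like `logits = [[2.0, -1.0, 0.5], [0.1, 0.2, 0.3]]` with `targets = [0, 2]`. The results should agree to many decimal places. Each row of the analytic gradient should also sum to zero.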
Related
- learning-rate-schedules — warmup is part of the stability stack
- weight-init — init is part of the stability stack
- layernorm-vs-rmsnorm — normalization is part of the stability stack
- mixed-precision-and-mfu — bf16 vs fp16 choice
- zero-to-hero-arc — the relevant lectures