Weight Tying
Weight tying is a small architectural trick that saves parameters and
(usually) improves quality: share the weights between the input token
embedding and the output unembedding (lm_head). GPT-2 uses it, and so does
Karpathy's llama2.c (Meta's released Llama 2 checkpoints, by contrast, keep
the two matrices separate). The line that does it is unassuming:
# from nanoGPT/model.py
self.transformer.wte.weight = self.lm_head.weight
That's the whole thing.
What's happening structurally
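The model has two weight matrices:
- transformer.wte, the token embedding (an nn.Embedding)
- lm_head, the output unembedding (an nn.Linear)
Both have shape (vocab_size, n_embd). Weight tying forces them to be the same matrix: one set of parameters, used in two places.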
In build-nanogpt/train_gpt2.py:
self.transformer = nn.ModuleDict(dict(
wte = nn.Embedding(config.vocab_size, config.n_embd),
# ...
))
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
# weight sharing scheme
self.transformer.wte.weight = self.lm_head.weight
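The direction of the assignment decides which Parameter object survives:
wte.weight = lm_head.weight rebinds the embedding to the Parameter that
nn.Linear allocated, and the one nn.Embedding allocated is dropped. No
transpose is needed, because nn.Linear stores its weight as
(out_features, in_features) = (vocab_size, n_embd), the same layout as the
embedding table. After that line, wte.weight and lm_head.weight are the same
tensor (same storage, same data_ptr()), not two copies.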
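A minimal sketch of what that assignment does, using standalone modules named after the nanoGPT attributes rather than the full model. The sizes (vocab_size=10, n_embd=4) are arbitrary toy stand-ins, not GPT-2's. It checks that the shapes already agree, that after tying both names refer to one Parameter with one storage, and that a write through lm_head is visible through the embedding lookup.

```python
import torch
import torch.nn as nn

# Toy sizes for illustration only (GPT-2 uses 50257 and 768)
vocab_size, n_embd = 10, 4

wte = nn.Embedding(vocab_size, n_embd)               # weight: (vocab_size, n_embd)
lm_head = nn.Linear(n_embd, vocab_size, bias=False)  # weight: (out, in) = (vocab_size, n_embd)
assert wte.weight.shape == lm_head.weight.shape      # no transpose needed

# Tie: rebind the embedding's Parameter to the one nn.Linear allocated
wte.weight = lm_head.weight

# One Parameter object, one storage
assert wte.weight is lm_head.weight
assert wte.weight.data_ptr() == lm_head.weight.data_ptr()

# An in-place write through lm_head shows up in the embedding lookup
with torch.no_grad():
    lm_head.weight[3].fill_(1.0)
assert torch.equal(wte(torch.tensor([3])), torch.ones(1, n_embd))
```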
Why share them
Think about what each matrix is doing:
- The embedding wte maps token i to a vector v_i. After training, this vector represents "what does token i mean."
- The unembedding lm_head takes a final hidden state h and computes h @ W^T. The logit for token i is h @ v_i, the dot product of the hidden state with that token's "meaning vector."
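The two bullets can be checked numerically. This sketch uses arbitrary toy sizes and a random stand-in W for the shared (vocab_size, n_embd) matrix; torch.nn.functional.embedding plays the lookup side and torch.nn.functional.linear (what nn.Linear with bias=False computes) plays the unembedding side.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, n_embd = 10, 4            # toy sizes, not GPT-2's
W = torch.randn(vocab_size, n_embd)   # stand-in for the single shared matrix

# Embedding side: token i -> row v_i of W
i = 7
v_i = F.embedding(torch.tensor([i]), W)[0]
assert torch.equal(v_i, W[i])

# Unembedding side: hidden state h -> logits h @ W^T
h = torch.randn(n_embd)
logits = F.linear(h, W)
assert torch.allclose(logits, h @ W.T)

# The logit for token i is the dot product of h with that token's embedding row
assert torch.allclose(logits[i], torch.dot(h, v_i))
```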
These are the same thing! Both representations live in the same space — the embedding side asks "what dense vector represents this token?" and the unembedding side asks "how aligned is this hidden state with that same dense vector?" Sharing the matrix:
- Saves parameters. For GPT-2 (124M), the (50257, 768) matrix is ~38.6M parameters, about 30% of the model. Without tying, you'd carry two copies.
- Improves quality. Empirically demonstrated in Press & Wolf 2017 ("Using the Output Embedding to Improve Language Models") and Inan et al. 2017. The shared matrix receives gradient signal from both ends: from the input lookup and from the output projection.
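The parameter figures in the first bullet can be reproduced with plain arithmetic. This sketch assumes the GPT-2 small layout used by nanoGPT (vocab_size=50257, block_size=1024, n_embd=768, n_layer=12, biases on every Linear except lm_head, LayerNorms with weight and bias); the helper names are hypothetical.

```python
# GPT-2 small configuration (assumed, matching nanoGPT's GPT-2 124M layout)
vocab_size, block_size, n_embd, n_layer = 50257, 1024, 768, 12

def linear_params(n_in, n_out, bias=True):
    # weight (n_out x n_in) plus optional bias (n_out)
    return n_in * n_out + (n_out if bias else 0)

layernorm_params = 2 * n_embd  # weight + bias

block_params = (
    layernorm_params                       # ln_1
    + linear_params(n_embd, 3 * n_embd)    # attn.c_attn (q, k, v together)
    + linear_params(n_embd, n_embd)        # attn.c_proj
    + layernorm_params                     # ln_2
    + linear_params(n_embd, 4 * n_embd)    # mlp.c_fc
    + linear_params(4 * n_embd, n_embd)    # mlp.c_proj
)

wte_params = vocab_size * n_embd   # token embedding, shared with lm_head when tied
wpe_params = block_size * n_embd   # position embedding, never tied

tied_total = wte_params + wpe_params + n_layer * block_params + layernorm_params  # last term: ln_f
untied_total = tied_total + linear_params(n_embd, vocab_size, bias=False)        # separate lm_head

print(f"shared matrix: {wte_params:,}")                       # 38,597,376
print(f"tied total:    {tied_total:,}")                       # 124,439,808
print(f"untied total:  {untied_total:,}")                     # 163,037,184
print(f"share of tied model: {wte_params / tied_total:.1%}")  # 31.0%
```

The tied total of about 124.44M is where the "124M" label comes from; untying would add another ~38.6M.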
nanoGPT/model.py has a wry comment:
# with weight tying when using torch.compile() some warnings get generated:
# "UserWarning: functional_call was passed multiple values for tied weights.
# This behavior is deprecated and will be an error in future versions"
# not 100% sure what this is, so far seems to be harmless. TODO investigate
Weight tying had been in routine production use for about six years by then, and PyTorch's tooling was still mildly confused by it. Such is open source.
The "non-embedding" param count
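The "124M" in "GPT-2 (124M)" counts the tied matrix only once (the total is
about 124.44M). nanoGPT's get_num_params goes a step further: by default it
reports a "non-embedding" count that excludes the position embeddings: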
def get_num_params(self, non_embedding=True):
    n_params = sum(p.numel() for p in self.parameters())
    if non_embedding:
        n_params -= self.transformer.wpe.weight.numel()
    return n_params
He explicitly notes: token embeddings aren't subtracted, because due to weight tying they ARE the unembedding — they're being used as weights in the final layer. So they should count. Position embeddings are subtracted because they're not weight-tied to anything.
It is a small detail, but conventions like this are what make reported parameter counts line up between papers and implementations.
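A sketch of why tying and get_num_params interact correctly. TinyGPT is a hypothetical, stripped-down module that keeps only wte, wpe and lm_head (toy sizes: vocab_size=100, block_size=16, n_embd=8), with get_num_params copied from above. Module.parameters() yields each Parameter object once, so the tied matrix is counted a single time; state_dict(), by contrast, lists the shared tensor under both names, so a naive count over it double-counts.

```python
import torch.nn as nn

class TinyGPT(nn.Module):
    """Hypothetical toy model: only the pieces that matter for counting."""

    def __init__(self, vocab_size=100, block_size=16, n_embd=8):
        super().__init__()
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(vocab_size, n_embd),
            wpe=nn.Embedding(block_size, n_embd),
        ))
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        self.transformer.wte.weight = self.lm_head.weight  # weight tying

    def get_num_params(self, non_embedding=True):
        # parameters() skips duplicate Parameter objects, so the tied matrix counts once
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

model = TinyGPT()
print(model.get_num_params(non_embedding=False))  # 928: 800 (tied wte/lm_head) + 128 (wpe)
print(model.get_num_params())                     # 800: position embeddings subtracted

# state_dict() does not deduplicate: the shared tensor appears under both names
sd = model.state_dict()
print(sorted(sd))  # ['lm_head.weight', 'transformer.wpe.weight', 'transformer.wte.weight']
print(sum(t.numel() for t in sd.values()))        # 1728: the tied matrix counted twice
```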
llama2.c also ties
llama2.c/model.py:
# share the unembedding parameters with the embedding parameters
self.tok_embeddings.weight = self.output.weight
Same pattern, same reasoning. It is worth noting because at inference time in
llama2.c/run.c, a tied model's file can store just one copy of the
matrix and use it both for the embedding lookup and for the final logits
matmul. Compact representation.
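A stdlib-only sketch of the "store one copy" idea. It is loosely modeled on llama2.c's legacy checkpoint convention, where (as I understand export.py and run.c) a positive vocab_size in the header means the classifier is shared with the embedding table and a negative one means a separate matrix follows. The layout below (a two-int header, no transformer layers) and the function names are hypothetical simplifications, not run.c's actual format.

```python
import io
import struct
from array import array

# Hypothetical, simplified checkpoint layout (NOT run.c's real format):
#   header: vocab_size, dim as native int32; vocab_size > 0 means "classifier is tied"
#   body:   embedding table as float32, then the classifier table only if untied

def write_checkpoint(f, vocab_size, dim, emb, cls=None):
    shared = cls is None  # cls=None means "tied to emb"
    f.write(struct.pack("ii", vocab_size if shared else -vocab_size, dim))
    f.write(array("f", emb).tobytes())
    if not shared:
        f.write(array("f", cls).tobytes())

def read_checkpoint(f):
    vocab_size, dim = struct.unpack("ii", f.read(struct.calcsize("ii")))
    shared = vocab_size > 0
    vocab_size = abs(vocab_size)
    emb = array("f")
    emb.frombytes(f.read(emb.itemsize * vocab_size * dim))
    if shared:
        cls = emb  # one buffer serves both the embedding lookup and the logits matmul
    else:
        cls = array("f")
        cls.frombytes(f.read(cls.itemsize * vocab_size * dim))
    return vocab_size, dim, emb, cls

vocab_size, dim = 5, 3
emb = [float(k) for k in range(vocab_size * dim)]

tied_buf, untied_buf = io.BytesIO(), io.BytesIO()
write_checkpoint(tied_buf, vocab_size, dim, emb)
write_checkpoint(untied_buf, vocab_size, dim, emb, cls=[0.0] * (vocab_size * dim))

tied_buf.seek(0)
_, _, e, c = read_checkpoint(tied_buf)
print(c is e)  # True: the tied checkpoint yields one shared buffer
print(len(tied_buf.getvalue()), len(untied_buf.getvalue()))
```

The untied file is larger by exactly one (vocab_size, dim) float32 table, which is the saving the tied format buys.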