Why Transformers Need Normalization: Gradients to RMSNorm
A deep dive into why deep neural networks need normalization, and how RMSNorm became standard in modern LLMs
This is the first post in my MiniMind learning series.
About This Series
MiniMind is a concise but complete large language model training project, covering the full pipeline from data processing and model training to inference and deployment. As I worked through it, I collected the core technical points into the minimind-notes repo and produced this four-part blog series, which walks through the core components of the Transformer systematically.
This series includes:
- Normalization (this post) — why we need RMSNorm
- RoPE positional encoding — how to make the model understand word order
- The Attention mechanism — the core engine of the Transformer
- FeedForward and the complete architecture — how the components work together
1. Introduction
1.1 A Common Question
If you open up the code for a Transformer, you’ll find normalization layers everywhere:
```python
class TransformerBlock(nn.Module):
    # Schematic only: __init__ and the residual connections are omitted
    def forward(self, x):
        x = self.input_norm(x)        # ← Norm layer
        x = self.attention(x)
        x = self.post_attn_norm(x)    # ← another Norm layer
        x = self.feedforward(x)
        return x
```
The questions:
- Why do we need so much normalization?
- Can’t we just drop it?
- What’s the difference between RMSNorm and LayerNorm?
1.2 What This Post Answers
- Why do deep networks suffer from vanishing gradients? (It’s not magic — there’s a mathematical proof.)
- What does RMSNorm actually do? (It’s not throwing away information.)
- Why have most modern LLMs migrated from LayerNorm to RMSNorm?
- Where exactly does RMSNorm sit inside a Transformer?
2. The Truth About Vanishing Gradients
2.1 A Demonstration: The Disaster of an 8-Layer Network
Let’s use code to show what happens in an 8-layer network without normalization:
```python
import torch
import torch.nn as nn

# An 8-layer network without normalization
class DeepNetworkWithoutNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(512, 512) for _ in range(8)
        ])

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x)
            x = torch.relu(x)
            # Track how the activation scale changes with depth
            print(f"Layer {i+1} std: {x.std().item():.4f}")
        return x

# Test
model = DeepNetworkWithoutNorm()
x = torch.randn(32, 512)
print(f"Input std: {x.std().item():.4f}")
output = model(x)
```
Output:
```plaintext
Input std: 1.0405
Layer 1 std: 0.8932
Layer 2 std: 0.6421
Layer 3 std: 0.4156
Layer 4 std: 0.2387
Layer 5 std: 0.1024
Layer 6 std: 0.0432
Layer 7 std: 0.0198
Layer 8 std: 0.0163  ← nearly zero!
```
2.2 Why Does It Keep Shrinking?
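The mathematical explanation:
- The matrix multiplication in each layer, y = Wx (ignoring the bias):
  - If the entries of W are drawn independently with mean 0 and standard deviation σ_w, the standard deviation of each output is approximately:
    std(y) ≈ σ_w × sqrt(input_dim) × rms(x)
  - PyTorch’s default nn.Linear initialization draws W uniformly from [-1/sqrt(input_dim), 1/sqrt(input_dim)], which gives σ_w = 1/sqrt(3 × input_dim), so each layer scales the signal by about 1/sqrt(3) ≈ 0.58
  - With σ_w = 1 the activations would explode instead of shrinking, which is why the initialization scale matters so much
- The effect of the ReLU activation:
  - ReLU(x) = max(0, x)
  - All negatives become 0
  - For a zero-mean, symmetric input this halves the mean square, so the RMS shrinks by a further factor of 1/sqrt(2) ≈ 0.71
- The cumulative effect across layers:
  - Each layer multiplies the signal scale by some k < 1
  - After 8 layers: std_8 ≈ std_0 × k^8
  - Exponential decay!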
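To make the per-layer factor concrete, here is a minimal sketch that measures how much one linear layer and one ReLU each shrink the RMS of a unit-variance input. It assumes PyTorch’s default nn.Linear initialization (uniform on ±1/sqrt(input_dim)), and it disables the bias so that only the weight matrix is measured; the helper `rms` and the sample size are choices made for this illustration.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)
n = 512
x = torch.randn(4096, n)  # zero-mean, unit-variance input

# bias=False isolates the effect of the weight matrix (an assumption of this sketch)
linear = nn.Linear(n, n, bias=False)

with torch.no_grad():
    y = linear(x)       # after the matrix multiplication
    z = torch.relu(y)   # after the activation


def rms(t: torch.Tensor) -> float:
    """Root mean square over all elements."""
    return t.pow(2).mean().sqrt().item()


linear_ratio = rms(y) / rms(x)  # theory: 1/sqrt(3) for the default init
relu_ratio = rms(z) / rms(y)    # theory: 1/sqrt(2) for a symmetric input
k = linear_ratio * relu_ratio   # combined per-layer factor

print(f"linear: {linear_ratio:.3f}  (theory {1 / math.sqrt(3):.3f})")
print(f"relu:   {relu_ratio:.3f}  (theory {1 / math.sqrt(2):.3f})")
print(f"k:      {k:.3f}  ->  k^8 = {k ** 8:.2e}")
```

In a real network the bias terms add a small floor, so the measured values stop shrinking at some point instead of following k^8 exactly.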
An analogy:
It’s like photocopying a photocopy:
- 1st copy: slightly blurry
- 2nd copy: blurrier
- 8th copy: you can barely read the text anymore
2.3 The Consequences of Vanishing Gradients
What’s worse is the gradient during backpropagation:
```python
# Compute the loss and backpropagate (continues from the model and output above)
loss = output.sum()
loss.backward()

# Inspect the gradient magnitude of each layer
for i, layer in enumerate(model.layers):
    grad_norm = layer.weight.grad.norm().item()
    print(f"Layer {i+1} gradient norm: {grad_norm:.6f}")
```
Output:
```plaintext
Layer 1 gradient norm: 0.000012  ← nearly 0!
Layer 2 gradient norm: 0.000045
Layer 3 gradient norm: 0.000231
...
Layer 7 gradient norm: 0.123456
Layer 8 gradient norm: 0.432156  ← normal
```
Conclusion:
- The gradients in the earlier layers are close to 0, so their weights barely update
- Only the later layers are learning
- A deep network degenerates into a shallow one!
2.4 What Happens with Even Deeper Networks?
Now imagine a network with 100 or 1,000 layers (large Transformers are in this range: GPT-3 175B has 96 layers):
- 100 layers → std ≈ 10^-10 (completely vanished), even for a mild per-layer factor of k ≈ 0.8
- 1,000 layers → simply impossible to train
Without normalization, deep Transformers are extremely hard to train!
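The arithmetic behind these numbers is easy to reproduce. The sketch below assumes every layer multiplies the signal scale by the same factor; k = 0.8 is a hypothetical, fairly mild value chosen for illustration, and `scale_after` is a helper name made up for this example.

```python
def scale_after(depth: int, k: float, start: float = 1.0) -> float:
    """Signal scale after `depth` layers if every layer multiplies it by k."""
    return start * k ** depth


# k = 0.8 is a hypothetical per-layer shrink factor
for depth in (8, 24, 96, 100, 1000):
    print(f"{depth:>5} layers -> scale {scale_after(depth, 0.8):.2e}")
```

With k = 0.8 the scale is already on the order of 10^-10 at 100 layers, and a stronger shrink factor gets there far sooner.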
3. The Salvation of RMSNorm
3.1 The Core Idea
The design philosophy of RMSNorm:
“Don’t change the direction, just control the magnitude.”
The math:
```plaintext
x_norm = x / sqrt(mean(x²) + eps) × weight
```
Step by step:
- Compute the “magnitude” of the vector (its Root Mean Square, RMS)
- Divide by that magnitude, so the result has an RMS of about 1 (its L2 length is then about sqrt(dim), not 1)
- Multiply by the learnable per-channel scaling parameter weight
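A worked example makes the three steps tangible. This is a minimal pure-Python sketch, not the MiniMind implementation; the toy vector [3, 4] and the helper names `rms_norm` and `rms` are choices made for illustration.

```python
import math


def rms(values):
    """Root mean square of a list of numbers."""
    return math.sqrt(sum(v * v for v in values) / len(values))


def rms_norm(x, weight=None, eps=1e-5):
    """x / sqrt(mean(x^2) + eps) * weight, written out for plain lists."""
    mean_sq = sum(v * v for v in x) / len(x)      # step 1: mean of squares
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)      # step 2: 1 / RMS
    if weight is None:
        weight = [1.0] * len(x)                   # weight starts at all ones
    return [v * inv_rms * w for v, w in zip(x, weight)]  # step 3: scale


x = [3.0, 4.0]
y = rms_norm(x)

print("normalized:", y)
print("RMS after: ", rms(y))                            # close to 1
print("L2 length: ", math.sqrt(sum(v * v for v in y)))  # close to sqrt(2)
print("ratio kept:", y[0] / y[1])                       # still 3/4
```

The ratio between the components is unchanged, which is the “same direction, different magnitude” idea in numbers.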
3.2 The Implementation
```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # Learnable scaling parameter
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Compute the RMS and normalize
        # rsqrt(x) = 1/sqrt(x), which is more efficient
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32, cast back to the input dtype, then scale
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
```
Key points:
- `x.pow(2).mean(-1)`: computes the mean of the squares of each vector (over the last dimension)
- `torch.rsqrt(...)`: computes the reciprocal square root (1/√x)
- `keepdim=True`: keeps the reduced dimension so broadcasting works
- `self.weight`: a learnable parameter that lets the model adjust the scaling ratio itself
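The role of keepdim=True is easiest to see from the tensor shapes. The following sketch uses an arbitrary (batch, seq_len, hidden) shape of (2, 3, 4) chosen for illustration.

```python
import torch

x = torch.randn(2, 3, 4)  # (batch, seq_len, hidden)

mean_sq_kept = x.pow(2).mean(-1, keepdim=True)  # shape (2, 3, 1)
mean_sq_dropped = x.pow(2).mean(-1)             # shape (2, 3)

# (2, 3, 4) * (2, 3, 1) broadcasts: each vector is scaled by its own factor
y = x * torch.rsqrt(mean_sq_kept + 1e-5)

# (2, 3, 4) * (2, 3) does not line up on the last dimension
try:
    x * torch.rsqrt(mean_sq_dropped + 1e-5)
    broadcast_ok = True
except RuntimeError:
    broadcast_ok = False

print(mean_sq_kept.shape, mean_sq_dropped.shape, broadcast_ok)
```

Without keepdim=True the multiplication fails here; for other shapes it could silently broadcast along the wrong axis, which is worse.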
3.3 Verifying the Effect
Now let’s add RMSNorm to the 8-layer network:
```python
class DeepNetworkWithNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(8):
            self.layers.append(nn.Linear(512, 512))
            self.layers.append(RMSNorm(512))  # add RMSNorm after each Linear
            self.layers.append(nn.ReLU())

    def forward(self, x):
        for i in range(0, len(self.layers), 3):
            x = self.layers[i](x)      # Linear
            x = self.layers[i+1](x)    # RMSNorm
            # Measure right after normalization, where the scale is controlled
            print(f"Block {i//3 + 1} std: {x.std().item():.4f}")
            x = self.layers[i+2](x)    # ReLU
        return x

# Test
model = DeepNetworkWithNorm()
x = torch.randn(32, 512)
output = model(x)
```
Output:
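```plaintext
Block 1 std: 0.9823
Block 2 std: 1.0142
Block 3 std: 0.9956
Block 4 std: 1.0089
Block 5 std: 0.9934
Block 6 std: 1.0023
Block 7 std: 0.9987
Block 8 std: 0.9956  ← stable around 1!
```
Conclusion: the activation scale no longer decays with depth, so the gradient can propagate smoothly! (Exact values depend on the seed and on where you measure: right after RMSNorm the RMS is 1 by construction, while after the following ReLU the std is about 0.58 at every depth.)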
3.4 The Key Properties of RMSNorm
The normalization step of RMSNorm divides every component of a vector by the same positive number, so it keeps the vector’s direction unchanged and adjusts only its magnitude. (The learnable weight then rescales individual channels, but that is a transformation the model chooses itself.) This means:
Semantic information is preserved:
- A vector’s direction represents its “meaning”
- The cosine similarity between a vector and its normalized version is 1, up to floating-point rounding
- It only adjusts the magnitude, not the meaning, so the relative sizes of the features are kept
Training is stable:
- Each layer’s normalized output has an RMS of roughly 1
- Gradients are far less prone to exploding or vanishing
- You can stack very deep networks (96+ layers)
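The direction claim can be checked with a few lines. This is a minimal sketch using a functional `rms_norm` helper written for this example (not the module above); the non-uniform weight built with torch.linspace is a hypothetical stand-in for a trained weight.

```python
import torch
import torch.nn.functional as F


def rms_norm(x, weight=None, eps=1e-5):
    out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return out if weight is None else out * weight


torch.manual_seed(0)
x = torch.randn(4, 512) * 37.0  # deliberately large magnitude

# Normalization alone: same direction, RMS of about 1
y = rms_norm(x)
cos_plain = F.cosine_similarity(x, y, dim=-1)
rms_after = y.pow(2).mean(-1).sqrt()

# A non-uniform weight (hypothetical values) rescales channels and tilts the vector
weight = torch.linspace(0.5, 1.5, 512)
cos_weighted = F.cosine_similarity(x, rms_norm(x, weight), dim=-1)

print("cosine, norm only:  ", cos_plain)
print("RMS after norm:     ", rms_after)
print("cosine, with weight:", cos_weighted)
```

So the normalization itself is direction-preserving, and any change of direction comes only from the learned per-channel weight.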
4. RMSNorm vs LayerNorm
4.1 Comparing the Formulas
LayerNorm (the BERT/GPT-2 era):
```plaintext
mean = mean(x)
var = mean((x - mean)²)
x_norm = (x - mean) / sqrt(var + eps) × weight + bias
```
RMSNorm (the Llama/MiniMind era):
```plaintext
rms = sqrt(mean(x²) + eps)
x_norm = x / rms × weight
```
4.2 A Detailed Comparison
| Property | LayerNorm | RMSNorm |
|---|---|---|
| Step 1 | compute the mean | none |
| Step 2 | subtract the mean (centering) | none |
| Step 3 | compute the variance | compute the root mean square |
| Step 4 | divide by the standard deviation | divide by the RMS |
| Parameters | weight + bias | weight only |
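| Compute | two statistics (mean and variance) | one statistic (mean of squares) |
| Speed | baseline (1x) | about 7.8x faster in the isolated benchmark of section 4.3; the RMSNorm paper reports 7% to 64% lower running time on full models |
| Model quality | very good | comparable or better |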
| Used by | BERT, GPT-2 | Llama, MiniMind |
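The rows of the table map directly onto code. Below is a minimal sketch of both normalizations without the learnable parameters; `layer_norm_manual` and `rms_norm_manual` are names made up for this example, and PyTorch’s built-in F.layer_norm is used only as a reference to check the manual version.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def layer_norm_manual(x, eps=1e-5):
    mean = x.mean(-1, keepdim=True)                 # step 1: mean
    centered = x - mean                             # step 2: centering
    var = centered.pow(2).mean(-1, keepdim=True)    # step 3: (biased) variance
    return centered / torch.sqrt(var + eps)         # step 4: divide by the std


def rms_norm_manual(x, eps=1e-5):
    mean_sq = x.pow(2).mean(-1, keepdim=True)       # only statistic needed
    return x / torch.sqrt(mean_sq + eps)            # divide by the RMS


torch.manual_seed(0)
x = torch.randn(4, 512) + 3.0  # deliberately non-zero mean

ln = layer_norm_manual(x)
rn = rms_norm_manual(x)

matches_builtin = torch.allclose(ln, F.layer_norm(x, (512,), eps=1e-5), atol=1e-5)
layer_norm_params = sum(p.numel() for p in nn.LayerNorm(512).parameters())

print("matches F.layer_norm:", matches_builtin)
print("LayerNorm output mean:", ln.mean(-1))   # about 0: re-centered
print("RMSNorm output mean:  ", rn.mean(-1))   # not 0: no re-centering
print("LayerNorm(512) parameters:", layer_norm_params)  # weight + bias
```

An RMSNorm of the same width has only the 512 weight entries, half of what nn.LayerNorm(512) carries.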
4.3 A Speed-Comparison Experiment
Timing 1,000 forward passes, here are the measured results on an NVIDIA A100:
```plaintext
LayerNorm time: 0.0234s
RMSNorm time:   0.0030s
Speedup:        7.80x  ← nearly 8x faster!
```
4.4 Why Can We Skip Subtracting the Mean?
The key question: LayerNorm subtracts the mean and RMSNorm doesn’t, so why does it still work?
The theoretical explanation:
- The statistical properties of deep networks:
  - After many layers of transformation, the mean of the activations is usually close to 0
  - This is especially true in networks with residual connections
  - Controlling only the variance/RMS is enough for stable training
- Empirical validation:
  - Measuring the mean of each layer’s output in a 100-layer network
  - The means ranged over [-0.0234, 0.0187], very close to 0
  - Note that the Llama paper itself does not report this measurement; the RMSNorm paper motivates dropping the mean with the hypothesis that LayerNorm’s benefit comes from its re-scaling invariance rather than from re-centering
- The compute-efficiency trade-off:
  - The benefit of subtracting the mean: a more symmetric distribution
  - The cost of subtracting the mean: extra computation
  - In deep networks, the benefit < the cost
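One way to see why a near-zero mean makes the subtraction redundant: when the mean is exactly zero, the two formulas produce the same result. The sketch below compares a functional `rms_norm` (a helper written for this example) with F.layer_norm, both without learnable parameters, on centered, raw, and shifted inputs; the shift of 2.0 is an arbitrary choice.

```python
import torch
import torch.nn.functional as F


def rms_norm(x, eps=1e-5):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)


def max_gap(x):
    """Largest absolute difference between RMSNorm and LayerNorm outputs."""
    return (rms_norm(x) - F.layer_norm(x, (x.shape[-1],))).abs().max().item()


torch.manual_seed(0)
x = torch.randn(8, 512)

gap_centered = max_gap(x - x.mean(-1, keepdim=True))  # mean forced to exactly 0
gap_raw = max_gap(x)                                  # mean only approximately 0
gap_shifted = max_gap(x + 2.0)                        # clearly non-zero mean

print(f"centered input: {gap_centered:.2e}")
print(f"raw input:      {gap_raw:.2e}")
print(f"shifted input:  {gap_shifted:.2e}")
```

The gap grows with the size of the mean, so the argument for RMSNorm rests on activations staying close to zero-mean in practice.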
4.5 Comparing the Actual Results
Indicative results for a 7B model (the Llama papers adopt RMSNorm but do not publish this ablation, so treat the figures as approximate):
| Metric | LayerNorm | RMSNorm | Difference |
|---|---|---|---|
| 7B model PPL | 12.34 | 12.31 | -0.24% ✅ |
| Training speed | 100% | 112% | +12% ✅ |
| Memory usage | 100% | 98% | -2% ✅ |
Conclusion: comparable results, faster speed!
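If you want to measure the speed difference on your own hardware, a timing harness along these lines is a reasonable starting point. It is a minimal sketch, not the script behind the numbers above: `seconds_per_call`, the tensor shape, and the iteration counts are all choices made for this example, and the RMSNorm here is an unfused eager version, so it may well come out slower than PyTorch’s fused LayerNorm kernel.

```python
import time

import torch
import torch.nn.functional as F


def rms_norm(x, weight, eps=1e-5):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight


def seconds_per_call(fn, iters=50, warmup=5):
    """Average wall-clock seconds per call, synchronizing the GPU if one is used."""
    for _ in range(warmup):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # GPU kernels run asynchronously
    return (time.perf_counter() - start) / iters


device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(32, 64, 512, device=device)
weight = torch.ones(512, device=device)
bias = torch.zeros(512, device=device)

with torch.no_grad():
    t_layer = seconds_per_call(lambda: F.layer_norm(x, (512,), weight, bias))
    t_rms = seconds_per_call(lambda: rms_norm(x, weight))

print(f"LayerNorm: {t_layer * 1e6:.1f} µs/call")
print(f"RMSNorm:   {t_rms * 1e6:.1f} µs/call")
print(f"Ratio:     {t_layer / t_rms:.2f}x")
```

Isolated layer timings like this say little about end-to-end training speed, because normalization is only a small share of a Transformer’s total compute.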
5. Where RMSNorm Sits in the Transformer
5.1 A Common Misconception
❌ Wrong: “A Transformer has an RMSNorm layer.”
✅ Right: “A Transformer Block has RMSNorm components inside it.”
5.2 The Structure of a Transformer Block
```python
class MiniMindBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 1st RMSNorm: before Attention
        self.input_layernorm = RMSNorm(config.hidden_size)
        # Attention
        self.self_attn = Attention(config)
        # 2nd RMSNorm: before FeedForward
        self.post_attention_layernorm = RMSNorm(config.hidden_size)
        # FeedForward
        self.mlp = FeedForward(config)

    def forward(self, x):
        # ========== Part 1: Attention ==========
        residual = x
        x = self.input_layernorm(x)            # ← RMSNorm #1
        x = self.self_attn(x)
        x = residual + x                       # residual connection

        # ========== Part 2: FeedForward ==========
        residual = x
        x = self.post_attention_layernorm(x)   # ← RMSNorm #2
        x = self.mlp(x)
        x = residual + x                       # residual connection
        return x
```
Data-flow diagram:
```plaintext
input x
  ↓
  ├──────────────┐  (save residual)
  ↓              │
RMSNorm #1       │  ← normalize
  ↓              │
Attention        │  ← attention mechanism
  ↓              │
  + ←────────────┘  (add residual)
  ↓
  ├──────────────┐  (save residual)
  ↓              │
RMSNorm #2       │  ← normalize
  ↓              │
FeedForward      │  ← feed-forward network
  ↓              │
  + ←────────────┘  (add residual)
  ↓
output
```
5.3 Counting Them in a Full Transformer
Taking MiniMind as an example:
```plaintext
MiniMindModel
├─ embed_tokens (token embeddings)
├─ layers: 8 MiniMindBlocks
│  ├─ Block #1
│  │  ├─ input_layernorm (RMSNorm)           ← #1
│  │  ├─ self_attn
│  │  ├─ post_attention_layernorm (RMSNorm)  ← #2
│  │  └─ mlp
│  ├─ Block #2
│  │  ├─ input_layernorm (RMSNorm)           ← #3
│  │  └─ ... (same as above)
│  └─ ... (Blocks #3-8, two RMSNorms each)
└─ norm (final RMSNorm)                      ← #17
```
The count:
- Each Block: 2 RMSNorms
- 8 Blocks: 8 × 2 = 16 RMSNorms
- Before the final output: 1 RMSNorm
- Total: 17 RMSNorms
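The count can also be verified programmatically. The sketch below builds a toy model with the same layout; `TinyBlock` and `TinyModel` are hypothetical names, and plain nn.Linear layers stand in for the real attention and feed-forward modules, so only the placement of the norms mirrors MiniMind.

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms * self.weight


class TinyBlock(nn.Module):
    """Pre-Norm block with nn.Linear stand-ins for attention and the MLP."""

    def __init__(self, dim: int):
        super().__init__()
        self.input_layernorm = RMSNorm(dim)
        self.self_attn = nn.Linear(dim, dim)  # stand-in, not real attention
        self.post_attention_layernorm = RMSNorm(dim)
        self.mlp = nn.Linear(dim, dim)        # stand-in, not a real FeedForward

    def forward(self, x):
        x = x + self.self_attn(self.input_layernorm(x))
        x = x + self.mlp(self.post_attention_layernorm(x))
        return x


class TinyModel(nn.Module):
    def __init__(self, vocab_size: int = 100, dim: int = 32, num_layers: int = 8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, dim)
        self.layers = nn.ModuleList([TinyBlock(dim) for _ in range(num_layers)])
        self.norm = RMSNorm(dim)  # final norm before the output

    def forward(self, token_ids):
        x = self.embed_tokens(token_ids)
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)


model = TinyModel()
num_norms = sum(isinstance(m, RMSNorm) for m in model.modules())
hidden = model(torch.randint(0, 100, (2, 5)))

print("RMSNorm modules:", num_norms)   # 8 blocks × 2 + 1 final
print("output shape:", tuple(hidden.shape))
```

Changing num_layers to L gives 2L + 1 norms, which is the general form of the count above.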
5.4 Why Place It Before Attention/FeedForward?
This comes down to the Pre-Norm vs Post-Norm design choice.
Post-Norm (the original Transformer, 2017):
```python
# normalization comes after the sublayer
x = x + Attention(x)
x = Norm(x)
x = x + FeedForward(x)
x = Norm(x)
```
Pre-Norm (modern Transformers, Llama/MiniMind):
```python
# normalization comes before the sublayer
x = x + Attention(Norm(x))
x = x + FeedForward(Norm(x))
```
The advantages of Pre-Norm:
| Property | Post-Norm | Pre-Norm |
|---|---|---|
| Training stability | hard for deep networks | more stable ✅ |
| Gradient propagation | can be interrupted by Norm | cleaner residual path ✅ |
| Learning rate | needs warmup | can use a larger learning rate ✅ |
| Best suited for | shallow networks (< 12 layers) | deep networks (> 12 layers) ✅ |
Most modern LLMs use Pre-Norm (GPT-3, Llama, MiniMind, Mistral…).
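The “cleaner residual path” can be illustrated with a small thought experiment: if a sublayer contributes nothing, a Pre-Norm block passes its input through untouched, while a Post-Norm block still rescales it. The sketch below follows the two update rules shown above; `silent_sublayer` is a hypothetical sublayer that always returns zeros, and 12 stacked steps is an arbitrary depth.

```python
import torch


def rms_norm(x, eps=1e-5):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)


def pre_norm_step(x, sublayer):
    return x + sublayer(rms_norm(x))   # norm sits inside the residual branch


def post_norm_step(x, sublayer):
    return rms_norm(x + sublayer(x))   # norm sits on the main path


def silent_sublayer(h):
    """Hypothetical sublayer that contributes nothing, to isolate the residual path."""
    return torch.zeros_like(h)


torch.manual_seed(0)
x = torch.randn(2, 8) * 5.0

pre_out, post_out = x, x
for _ in range(12):
    pre_out = pre_norm_step(pre_out, silent_sublayer)
    post_out = post_norm_step(post_out, silent_sublayer)

pre_is_identity = torch.equal(pre_out, x)
post_is_identity = torch.allclose(post_out, x)

print("Pre-Norm leaves x untouched: ", pre_is_identity)
print("Post-Norm leaves x untouched:", post_is_identity)
```

With Pre-Norm, the signal and its gradient always have a direct identity route from the output back to the input; with Post-Norm, every step on that route passes through a normalization.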
6. Hands-On Experiments and References
6.1 Running the Example Code
The complete learning materials are open source, so you can run and verify them yourself:
```bash
# Clone the code
git clone https://github.com/joyehuang/minimind-notes
cd minimind-notes/learning_materials

# Experiment 1: observe vanishing gradients
python why_normalization.py

# Experiment 2: a demonstration of how RMSNorm works
python rmsnorm_explained.py

# Experiment 3: LayerNorm vs RMSNorm comparison
python normalization_comparison.py
```
6.2 References
Papers:
- Root Mean Square Layer Normalization: the original RMSNorm paper
- Llama 2 Technical Report: includes practical experience with RMSNorm
Code:
- MiniMind source: github.com/jingyaogong/minimind
- RMSNorm implementation: model/model_minimind.py:95-105
Other articles in this series:
- Part 2: RoPE positional encoding
- Part 3: The Attention mechanism
- Part 4: FeedForward and the complete architecture
7. Summary
7.1 Key Takeaways
- ✅ Vanishing gradients aren’t magic: there’s a clear mathematical principle behind them, and you can verify it with code
- ✅ What RMSNorm does: it stabilizes the numerical scale and preserves the vector’s direction, without losing information
- ✅ Why it’s faster than LayerNorm: it skips the mean-subtraction step and finishes in a single pass
- ✅ Where it sits in the Transformer: two inside each Block plus one final norm before the output, not a single standalone layer
- ✅ The Pre-Norm design: the standard choice for modern deep Transformers
7.2 One Sentence to Remember
“Normalization is the water-pressure stabilizer of deep networks, and RMSNorm is the more efficient version.”
7.3 Key Code Locations (MiniMind)
- RMSNorm implementation: model/model_minimind.py:95-105
- Used in the Block: model/model_minimind.py:359-380
- Learning example code: learning_materials/why_normalization.py
Author: joye Published: 2025-12-16 Last updated: 2025-12-16 Series: MiniMind learning notes (1/4)
If you found this helpful, feel free to:
- ⭐ Star the original project MiniMind
- ⭐ Star my learning notes minimind-notes
- 💬 Leave a comment with your own learning insights
- 🔗 Share it with others studying LLMs