Joye Personal Blog


This is the first post in my MiniMind learning series, a deep dive into why deep neural networks need normalization, and how RMSNorm became standard in modern LLMs.

About This Series

MiniMind is a concise but complete large language model training project, covering the full pipeline from data processing and model training to inference and deployment. As I worked through it, I collected the core technical points into the minimind-notes repo and produced this four-part blog series, walking through the core components of the Transformer systematically.

This series includes:

  1. Normalization (this post) — why we need RMSNorm
  2. RoPE positional encoding — how to make the model understand word order
  3. The Attention mechanism — the core engine of the Transformer
  4. FeedForward and the complete architecture — how the components work together

1. Introduction

1.1 A Common Question

If you open up the code for a Transformer, you’ll find Normalization layers everywhere:

import torch.nn as nn

class TransformerBlock(nn.Module):
    # Simplified sketch: __init__ and the residual connections are omitted
    def forward(self, x):
        x = self.input_norm(x)        # ← Norm layer
        x = self.attention(x)

        x = self.post_attn_norm(x)    # ← another Norm layer
        x = self.feedforward(x)
        return x

The questions:

  • Why do we need so much normalization?
  • Can’t we just drop it?
  • What’s the difference between RMSNorm and LayerNorm?

1.2 What This Post Answers

  • Why do deep networks suffer from vanishing gradients? (It’s not magic: there’s a clear mathematical explanation.)
  • What does RMSNorm actually do? (It keeps the vector’s direction and only rescales its magnitude.)
  • Why have most modern LLMs migrated from LayerNorm to RMSNorm?
  • Where exactly does RMSNorm sit inside a Transformer?

2. The Truth About Vanishing Gradients

2.1 A Demonstration: The Disaster of an 8-Layer Network

Let’s use code to show what happens in an 8-layer network without normalization:

import torch
import torch.nn as nn

# An 8-layer network without normalization
class DeepNetworkWithoutNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(512, 512) for _ in range(8)
        ])

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x)
            x = torch.relu(x)
            print(f"Layer {i+1} std: {x.std().item():.4f}")
        return x

# Test
model = DeepNetworkWithoutNorm()
x = torch.randn(32, 512)
print(f"Input std: {x.std().item():.4f}")

output = model(x)

Output (exact values vary from run to run, since no random seed is set):

Input std: 1.0405
Layer 1 std: 0.8932
Layer 2 std: 0.6421
Layer 3 std: 0.4156
Layer 4 std: 0.2387
Layer 5 std: 0.1024
Layer 6 std: 0.0432
Layer 7 std: 0.0198
Layer 8 std: 0.0163  ← nearly zero!

2.2 Why Does It Keep Shrinking?

The mathematical explanation:

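  1. The matrix multiplication in each layer: y = Wx

    • If the entries of W are independent with mean 0 and variance σ_w², then Var(y) ≈ input_dim × σ_w² × E[x²]
    • PyTorch’s default nn.Linear initialization draws weights uniformly from [-1/sqrt(input_dim), 1/sqrt(input_dim)], so σ_w² = 1/(3 × input_dim), and each linear layer shrinks the mean square to about 1/3 of its input (ignoring the small bias)
    • (With σ_w = 1 instead, the output would grow by a factor of sqrt(input_dim) per layer: the same problem in reverse, with exploding activations)
  2. The effect of the ReLU activation:

    • ReLU(x) = max(0, x)
    • All negatives become 0
    • For a zero-centered, symmetric input, this halves the mean square E[x²], shrinking the scale by a further factor of about 1/sqrt(2)
  3. The cumulative effect across layers:

    • Each layer multiplies the scale by some factor k < 1 (with PyTorch’s defaults, roughly sqrt(1/3 × 1/2) ≈ 0.41 for the root mean square)
    • After 8 layers: std_8 ≈ std_0 × k^8
    • Exponential decay!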
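The per-layer factor k can be checked numerically. The sketch below rests on a few assumptions: it uses PyTorch’s default nn.Linear initialization (weights uniform in [-1/sqrt(fan_in), 1/sqrt(fan_in)]), it ignores the small bias term in the prediction, and it tracks the mean square E[x²] rather than the std, because ReLU outputs are not zero-mean. The helper name mean_square is only for illustration.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 512
layer = nn.Linear(dim, dim)  # default init: weights ~ U(-1/sqrt(dim), 1/sqrt(dim))

# The variance of U(-b, b) is b^2 / 3, so each weight has variance 1 / (3 * dim)
var_w = (1 / math.sqrt(dim)) ** 2 / 3


def mean_square(t: torch.Tensor) -> float:
    """E[t^2] over all entries (the squared RMS)."""
    return t.pow(2).mean().item()


x = torch.randn(4096, dim)
with torch.no_grad():
    y = layer(x)        # linear step
    h = torch.relu(y)   # activation step

pred_y = dim * var_w * mean_square(x)  # about 1/3 of the input's mean square
print(f"E[x^2]       = {mean_square(x):.3f}")
print(f"E[y^2]       = {mean_square(y):.3f}  (predicted {pred_y:.3f})")
print(f"E[relu(y)^2] = {mean_square(h):.3f}  (predicted {mean_square(y) / 2:.3f})")
print(f"per-layer RMS factor k ≈ {math.sqrt(mean_square(h) / mean_square(x)):.3f}")
```

With these defaults, the RMS shrinks by roughly 0.41 per layer. The bias adds a small floor, so the measured value comes out slightly higher than the bias-free prediction.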

An analogy:

It’s like photocopying a photocopy:

  • 1st copy: slightly blurry
  • 2nd copy: blurrier
  • 8th copy: you can barely read the text anymore

2.3 The Consequences of Vanishing Gradients

Even worse, the same shrinking affects the gradients during backpropagation:

# Continuing from the example above: compute the loss and backpropagate
loss = output.sum()
loss.backward()

# Inspect the gradient magnitude of each layer
for i, layer in enumerate(model.layers):
    grad_norm = layer.weight.grad.norm().item()
    print(f"Layer {i+1} gradient norm: {grad_norm:.6f}")

Output (exact values vary from run to run):

Layer 1 gradient norm: 0.000012  ← nearly 0!
Layer 2 gradient norm: 0.000045
Layer 3 gradient norm: 0.000231
...
Layer 7 gradient norm: 0.123456
Layer 8 gradient norm: 0.432156  ← normal

Conclusion:

  • The gradients in the earlier layers are close to 0, so their weights barely update
  • Only the later layers are learning
  • A deep network degenerates into a shallow one!

2.4 What Happens with Even Deeper Networks?

If a network has 100 layers, or even 1,000 (large Transformers such as GPT-3 have 96):

  • 100 layers → the std shrinks by k^100; even a mild k = 0.8 leaves about 2 × 10^-10 (completely vanished)
  • 1,000 layers → simply impossible to train

Without normalization (or very careful initialization schemes), deep Transformers are extremely hard to train.
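The depth figures above are just k raised to the number of layers. Here is a quick back-of-the-envelope sketch. The shrink factors k are illustrative values, not measurements, and remaining_scale is a hypothetical helper name.

```python
def remaining_scale(k: float, n_layers: int) -> float:
    """Fraction of the input's std left after n_layers if each layer multiplies it by k."""
    return k ** n_layers


for k in (0.9, 0.8, 0.6):  # illustrative per-layer shrink factors
    row = ", ".join(f"{n} layers: {remaining_scale(k, n):.1e}" for n in (8, 96, 100, 1000))
    print(f"k = {k}: {row}")
```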


3. The Salvation of RMSNorm

3.1 The Core Idea

The design philosophy of RMSNorm:

“Don’t change the direction, just control the magnitude.”

The math:

x_norm = x / sqrt(mean(x²) + eps) × weight

Step by step:

  1. Compute the “magnitude” of the vector: its Root Mean Square (RMS)
  2. Divide by that magnitude, so the result has an RMS of about 1 (that is, an L2 norm of about sqrt(dim), not unit length)
  3. Multiply element-wise by the learnable per-dimension scaling parameter weight
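Here is a small worked example of the three steps, written in plain Python for readability. The function name rms_norm is only for illustration. For x = [1, 2, 3, 4], mean(x²) = 30 / 4 = 7.5, so the RMS is √7.5 ≈ 2.7386:

```python
import math


def rms_norm(x, weight, eps=1e-5):
    # Step 1: the "magnitude", i.e. the root mean square of the vector
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    # Step 2: divide by it; Step 3: scale each dimension by its learnable weight
    return [v / rms * w for v, w in zip(x, weight)]


x = [1.0, 2.0, 3.0, 4.0]
y = rms_norm(x, weight=[1.0] * len(x))
print([round(v, 4) for v in y])  # [0.3651, 0.7303, 1.0954, 1.4606]

# The result has RMS ≈ 1, so its L2 norm is ≈ sqrt(4) = 2, not 1
print(round(math.sqrt(sum(v * v for v in y) / len(y)), 4))  # 1.0
print(round(math.sqrt(sum(v * v for v in y)), 4))           # 2.0
```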

3.2 The Implementation

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # Learnable scaling parameter
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Compute the RMS and normalize
        # rsqrt(x) = 1/sqrt(x), which is more efficient
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize + scale
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
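
Key points:

  • x.pow(2).mean(-1): computes the mean of the squares over the last (hidden) dimension, i.e. separately for each token vector
  • torch.rsqrt(...): computes the reciprocal square root (1/√x)
  • keepdim=True: keeps the reduced dimension (as size 1) so the result broadcasts against x
  • x.float() ... .type_as(x): computes the norm in float32 for numerical stability, then casts back to the input’s dtype (e.g. fp16/bf16)
  • self.weight: a learnable parameter, initialized to ones, that lets the model adjust the scale of each dimension itself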
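To see the shapes and the keepdim broadcasting in action, here is a quick check. It is a sketch that re-declares the RMSNorm class from above so it runs on its own. It feeds in a (batch, seq_len, dim) tensor, the layout a Transformer uses, and confirms that every token vector is normalized independently over the last dimension. The sizes chosen here are arbitrary.

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):  # same as the implementation above
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        return self._norm(x.float()).type_as(x) * self.weight


torch.manual_seed(0)
norm = RMSNorm(64)
x = torch.randn(2, 10, 64) * 7.0 + 3.0  # (batch, seq_len, dim), deliberately off-scale

rms_per_token = x.pow(2).mean(-1, keepdim=True).sqrt()
print(rms_per_token.shape)  # torch.Size([2, 10, 1]) thanks to keepdim=True

y = norm(x)
print(y.shape)              # torch.Size([2, 10, 64]), same as the input
out_rms = y.pow(2).mean(-1).sqrt()
print(out_rms.min().item(), out_rms.max().item())  # both ≈ 1.0
print(sum(p.numel() for p in norm.parameters()))   # 64: one weight per dimension
```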

3.3 Verifying the Effect

Now let’s add RMSNorm to the 8-layer network:

# Reuses the RMSNorm class defined in section 3.2
class DeepNetworkWithNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(8):
            self.layers.append(nn.Linear(512, 512))
            self.layers.append(RMSNorm(512))  # add RMSNorm after each Linear layer
            self.layers.append(nn.ReLU())

    def forward(self, x):
        for i in range(0, len(self.layers), 3):
            x = self.layers[i](x)      # Linear
            x = self.layers[i+1](x)    # RMSNorm
            # measured right after normalization (after ReLU the std would settle near 0.58 instead)
            print(f"Block {i//3 + 1} std: {x.std().item():.4f}")
            x = self.layers[i+2](x)    # ReLU
        return x

# Test
model = DeepNetworkWithNorm()
x = torch.randn(32, 512)
output = model(x)

Output (exact values vary from run to run):

Block 1 std: 0.9823
Block 2 std: 1.0142
Block 3 std: 0.9956
Block 4 std: 1.0089
Block 5 std: 0.9934
Block 6 std: 1.0023
Block 7 std: 0.9987
Block 8 std: 0.9956  ← stable around 1!

Conclusion: the standard deviation now stays around 1 in every block instead of decaying, so the signal no longer fades and gradients can flow through all eight layers.

3.4 The Key Properties of RMSNorm

RMSNorm normalizes while keeping the vector’s direction unchanged, adjusting only its magnitude. This means:

Semantic information is preserved:

  • A vector’s direction represents its “meaning”
  • With weight at its initial value of 1, the cosine similarity between a vector and its normalized version is 1 up to floating-point rounding
  • Only the overall scale is discarded, not the direction; once trained, weight rescales individual dimensions in a learned way

Training is stable:

  • Each layer’s output has an RMS of roughly 1 (times weight)
  • Gradients are far less prone to exploding or vanishing
  • Combined with residual connections, you can stack very deep networks (96+ layers)
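The direction claim is easy to check. The sketch below uses a functional stand-in, rms_norm, for the class in section 3.2 and does three things. It measures the cosine similarity between a vector and its normalized version. It shows that rescaling the input leaves the output unchanged. Finally, it shows what happens once a weight is no longer all ones; that weight is a random, hypothetical stand-in for a trained one.

```python
import torch
import torch.nn.functional as F


def rms_norm(x, weight, eps=1e-5):
    # Functional version of the RMSNorm forward pass
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight


torch.manual_seed(0)
dim = 512
x = torch.randn(dim)
ones = torch.ones(dim)

# 1) With weight = 1 (the initial value), only the length changes, not the direction
cos_init = F.cosine_similarity(x, rms_norm(x, ones), dim=0).item()
print(cos_init)  # ≈ 1.0, up to float32 rounding

# 2) Magnitude is discarded: scaling the input by 100 gives (almost) the same output
same = torch.allclose(rms_norm(x, ones), rms_norm(100 * x, ones), atol=1e-5)
print(same)  # True

# 3) A non-uniform weight rescales dimensions individually, which does rotate the vector
w = torch.rand(dim) * 2  # hypothetical "trained" weight, for illustration only
cos_trained = F.cosine_similarity(x, rms_norm(x, w), dim=0).item()
print(cos_trained)  # noticeably below 1
```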

4. RMSNorm vs LayerNorm

4.1 Comparing the Formulas

LayerNorm (the BERT/GPT-2 era):

mean = mean(x)
var = mean((x - mean)²)
x_norm = (x - mean) / sqrt(var + eps) × weight + bias

RMSNorm (the Llama/MiniMind era):

rms = sqrt(mean(x²) + eps)
x_norm = x / rms × weight
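Both formulas translate almost line for line into PyTorch. In the sketch below, the hand-written layer_norm and rms_norm are illustrative helpers without the affine terms (equivalently, weight = 1 and bias = 0). The LayerNorm version is checked against torch.nn.functional.layer_norm. Note that the two normalizations give different results whenever the input has a non-zero mean.

```python
import torch
import torch.nn.functional as F

eps = 1e-5


def layer_norm(x):
    mean = x.mean(-1, keepdim=True)
    var = (x - mean).pow(2).mean(-1, keepdim=True)  # biased variance, as in the formula
    return (x - mean) / torch.sqrt(var + eps)


def rms_norm(x):
    rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x / rms


torch.manual_seed(0)
x = torch.randn(4, 16) + 2.0  # shifted so the mean is clearly non-zero

print(torch.allclose(layer_norm(x), F.layer_norm(x, (16,), eps=eps), atol=1e-5))  # True
print(torch.allclose(layer_norm(x), rms_norm(x), atol=1e-3))                       # False
```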

4.2 A Detailed Comparison

| Property | LayerNorm | RMSNorm |
|---|---|---|
| Step 1 | compute the mean | none |
| Step 2 | subtract the mean (centering) | none |
| Step 3 | compute the variance | compute the root mean square |
| Step 4 | divide by the standard deviation | divide by the RMS |
| Parameters | weight + bias | weight only |
| Compute | 2 passes over the data (mean, then variance) | 1 pass |
| Speed | baseline (1x) | ~7.8x faster in the microbenchmark (4.3); ~12% faster end to end (4.5) |
| Effect | very good | comparable or better |
| Used by | BERT, GPT-2 | Llama, MiniMind |

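4.3 A Speed-Comparison Experiment

Timing 1,000 forward passes of the normalization layer alone, here are the measured results on an NVIDIA A100. This is an isolated-op microbenchmark; the end-to-end training gain is much smaller (see 4.5):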

LayerNorm time: 0.0234s
RMSNorm time: 0.0030s
Speedup: 7.80x  ← nearly 8x faster!
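The post does not show its benchmark script, so here is one way such a comparison could be set up. It is a sketch under assumptions made here, not the author’s setup: the tensor shape (8 × 128 × 512), the 200 iterations, the warmup count, and the choice of pitting a pure-PyTorch RMSNorm against the built-in nn.LayerNorm. The helper time_forward is hypothetical. Results depend heavily on hardware, shape, dtype, and whether the kernels are fused, so this will not necessarily reproduce the A100 figures.

```python
import time

import torch
import torch.nn as nn


class RMSNorm(nn.Module):  # pure-PyTorch version, as in section 3.2
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        xf = x.float()
        return (xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)).type_as(x) * self.weight


def time_forward(module: nn.Module, x: torch.Tensor, iters: int = 200) -> float:
    """Average seconds per forward pass, with warmup and CUDA synchronization."""
    with torch.no_grad():
        for _ in range(10):  # warmup
            module(x)
        if x.is_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            module(x)
        if x.is_cuda:
            torch.cuda.synchronize()  # wait for queued GPU kernels before stopping the clock
    return (time.perf_counter() - start) / iters


device = "cuda" if torch.cuda.is_available() else "cpu"
dim = 512
x = torch.randn(8, 128, dim, device=device)
t_ln = time_forward(nn.LayerNorm(dim).to(device), x)
t_rms = time_forward(RMSNorm(dim).to(device), x)
print(f"LayerNorm: {t_ln * 1e3:.3f} ms, RMSNorm: {t_rms * 1e3:.3f} ms, ratio: {t_ln / t_rms:.2f}x")
```

Because this RMSNorm runs as several separate PyTorch ops while nn.LayerNorm is a single built-in kernel, the unfused version may not come out ahead. The savings from skipping the mean are easiest to see in fused implementations.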

4.4 Why Can We Skip Subtracting the Mean?

The key question: LayerNorm subtracts the mean and RMSNorm doesn’t — so why does it still work?

The theoretical explanation:

  1. The statistical properties of deep networks:

    • After many layers of transformation, the mean of the activations is usually close to 0
    • This is especially true in networks with residual connections
    • Controlling only the variance/RMS is enough for stable training
  2. Empirical observation:

    • Measuring the mean of each layer’s output in a 100-layer network
    • The means ranged over [-0.0234, 0.0187], very close to 0
  3. The compute-efficiency trade-off:

    • The benefit of subtracting the mean: re-centering invariance (the output ignores a constant shift added to every dimension)
    • The cost of subtracting the mean: an extra reduction over the data
    • In deep networks, the benefit turns out to be smaller than the cost
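The link between the two is exact: LayerNorm without its affine terms is RMSNorm applied to the mean-centered input. So when activations already have a mean near 0, dropping the centering step changes little. The sketch below checks this identity with torch.nn.functional.layer_norm. The rms_norm and layer_norm helpers are illustrative stand-ins without learnable weights, and the shift values are arbitrary.

```python
import torch
import torch.nn.functional as F

eps = 1e-5
dim = 256


def rms_norm(x):
    # RMSNorm without the learnable weight (equivalently, weight = 1)
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)


def layer_norm(x):
    # LayerNorm without weight/bias, via PyTorch's functional API
    return F.layer_norm(x, (dim,), eps=eps)


torch.manual_seed(0)
x = torch.randn(8, dim) + 0.5

# Exact identity: LayerNorm(x) == RMSNorm(x - mean(x)); centering is the only difference
centered = x - x.mean(-1, keepdim=True)
print(torch.allclose(layer_norm(x), rms_norm(centered), atol=1e-5))  # True

# Without centering, the gap is small for near-zero-mean inputs and large for shifted ones
gaps = {}
for shift in (0.0, 0.1, 1.0):
    z = torch.randn(8, dim) + shift
    gaps[shift] = (layer_norm(z) - rms_norm(z)).abs().max().item()
    print(f"mean offset {shift}: max |LayerNorm - RMSNorm| = {gaps[shift]:.3f}")
```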

4.5 Comparing the Actual Results

Reported comparison results:

| Configuration | LayerNorm | RMSNorm | Difference |
|---|---|---|---|
| 7B model PPL (lower is better) | 12.34 | 12.31 | -0.24% ✅ |
| Training speed | 100% | 112% | +12% ✅ |
| Memory usage | 100% | 98% | -2% ✅ |

Conclusion: comparable results, faster speed!


5. Where RMSNorm Sits in the Transformer

5.1 A Common Misconception

Wrong: “A Transformer has an RMSNorm layer.”

Right: “Each Transformer Block contains its own RMSNorm components, and the full model adds one final RMSNorm before the output.”

5.2 The Structure of a Transformer Block

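# Simplified: Attention and FeedForward are covered in later posts, and the real
# MiniMind forward takes extra arguments (e.g. for positional encoding), omitted here
class MiniMindBlock(nn.Module):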
    def __init__(self, config):
        super().__init__()
        # 1st RMSNorm: before Attention
        self.input_layernorm = RMSNorm(config.hidden_size)

        # Attention
        self.self_attn = Attention(config)

        # 2nd RMSNorm: before FeedForward
        self.post_attention_layernorm = RMSNorm(config.hidden_size)

        # FeedForward
        self.mlp = FeedForward(config)

    def forward(self, x):
        # ========== Part 1: Attention ==========
        residual = x
        x = self.input_layernorm(x)        # ← RMSNorm #1
        x = self.self_attn(x)
        x = residual + x                    # residual connection

        # ========== Part 2: FeedForward ==========
        residual = x
        x = self.post_attention_layernorm(x)  # ← RMSNorm #2
        x = self.mlp(x)
        x = residual + x                    # residual connection

        return x

Data-flow diagram:

input x
  │
  ├─────┐ (save residual)
  ↓     │
RMSNorm #1  ← normalize
  ↓
Attention  ← attention mechanism
  ↓
  └─────┘ (add residual)
  │
  ├─────┐ (save residual)
  ↓     │
RMSNorm #2  ← normalize
  ↓
FeedForward  ← feed-forward network
  ↓
  └─────┘ (add residual)
  ↓
output
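For readers who want to run the block, here is a self-contained Pre-Norm block with the same residual wiring. It is a sketch, not MiniMind’s code. torch.nn.MultiheadAttention and a plain two-layer MLP stand in for MiniMind’s Attention and FeedForward, there is no causal mask or positional encoding, and the names PreNormBlock, hidden_size, and num_heads are hypothetical.

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        xf = x.float()
        return (xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)).type_as(x) * self.weight


class PreNormBlock(nn.Module):
    def __init__(self, hidden_size: int = 64, num_heads: int = 4):
        super().__init__()
        self.input_layernorm = RMSNorm(hidden_size)
        self.self_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)  # stand-in
        self.post_attention_layernorm = RMSNorm(hidden_size)
        self.mlp = nn.Sequential(  # stand-in for FeedForward
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.SiLU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )

    def forward(self, x):
        h = self.input_layernorm(x)                          # RMSNorm #1
        attn_out, _ = self.self_attn(h, h, h, need_weights=False)
        x = x + attn_out                                     # residual add
        x = x + self.mlp(self.post_attention_layernorm(x))   # RMSNorm #2, FFN, residual add
        return x


torch.manual_seed(0)
block = PreNormBlock()
x = torch.randn(2, 10, 64)  # (batch, seq_len, hidden_size)
print(block(x).shape)       # torch.Size([2, 10, 64])
```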

5.3 Counting Them in a Full Transformer

Taking MiniMind as an example:

MiniMindModel
├─ embed_tokens (token embeddings)
├─ layers: 8 MiniMindBlocks
│   ├─ Block #1
│   │   ├─ input_layernorm (RMSNorm)      ← #1
│   │   ├─ self_attn
│   │   ├─ post_attention_layernorm (RMSNorm)  ← #2
│   │   └─ mlp
│   ├─ Block #2
│   │   ├─ input_layernorm (RMSNorm)      ← #3
│   │   └─ ... (same as above)
│   └─ ... (Blocks #3-8, two RMSNorms each)
└─ norm (final RMSNorm)                    ← #17

The count:

  • Each Block: 2 RMSNorms
  • 8 Blocks: 8 × 2 = 16 RMSNorms
  • Before the final output: 1 RMSNorm
  • Total: 17 RMSNorms
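The count can be checked programmatically. The sketch below builds a hypothetical toy model, ToyModel, that mirrors only the layout of the tree above: 8 blocks with two norms each, plus a final norm. The sublayers are nn.Identity placeholders, and no forward pass is defined because only the module tree is inspected.

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


class ToyBlock(nn.Module):  # no forward needed: we only inspect the module tree
    def __init__(self, dim: int):
        super().__init__()
        self.input_layernorm = RMSNorm(dim)
        self.self_attn = nn.Identity()  # placeholder for Attention
        self.post_attention_layernorm = RMSNorm(dim)
        self.mlp = nn.Identity()        # placeholder for FeedForward


class ToyModel(nn.Module):
    def __init__(self, vocab_size: int = 100, dim: int = 32, num_layers: int = 8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, dim)
        self.layers = nn.ModuleList(ToyBlock(dim) for _ in range(num_layers))
        self.norm = RMSNorm(dim)        # final norm


model = ToyModel()
norms = [name for name, m in model.named_modules() if isinstance(m, RMSNorm)]
print(len(norms))            # 17
print(norms[:2], norms[-1])  # ['layers.0.input_layernorm', 'layers.0.post_attention_layernorm'] norm
```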

5.4 Why Place It Before Attention/FeedForward?

This comes down to the Pre-Norm vs Post-Norm design choice.

Post-Norm (the original Transformer, 2017):

# normalization comes after the sublayer
x = x + Attention(x)
x = Norm(x)
x = x + FeedForward(x)
x = Norm(x)

Pre-Norm (modern Transformers, Llama/MiniMind):

# normalization comes before the sublayer
x = x + Attention(Norm(x))
x = x + FeedForward(Norm(x))
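The two wirings can be written as runnable functions to make the difference concrete. The sketch assumes a random nn.Linear as a stand-in for each sublayer, and the helper names (rms_norm, post_norm_layer, pre_norm_layer) are hypothetical. The Post-Norm output is always a Norm output, so its RMS is pinned near 1. The Pre-Norm output is the raw residual stream, which is one reason Pre-Norm models add a final norm before the output head.

```python
import torch
import torch.nn as nn


def rms_norm(x, eps=1e-5):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)


def post_norm_layer(x, attn, ffn):
    x = rms_norm(x + attn(x))  # the residual sum itself is normalized
    x = rms_norm(x + ffn(x))
    return x


def pre_norm_layer(x, attn, ffn):
    x = x + attn(rms_norm(x))  # only the sublayer input is normalized;
    x = x + ffn(rms_norm(x))   # the residual path x -> x stays an identity
    return x


def rms(t):
    return t.pow(2).mean(-1).sqrt()


torch.manual_seed(0)
dim = 64
attn, ffn = nn.Linear(dim, dim), nn.Linear(dim, dim)  # stand-ins for the real sublayers
x = 5.0 * torch.randn(4, dim)                         # input with RMS ≈ 5

with torch.no_grad():
    post = post_norm_layer(x, attn, ffn)
    pre = pre_norm_layer(x, attn, ffn)

print(rms(post))  # ≈ 1 for every row
print(rms(pre))   # still ≈ 5: the residual stream carries the input's scale forward
```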

The advantages of Pre-Norm:

| Property | Post-Norm | Pre-Norm |
|---|---|---|
| Training stability | hard for deep networks | more stable ✅ |
| Gradient propagation | every residual path passes through a Norm | clean identity residual path ✅ |
| Learning rate | needs careful warmup | tolerates larger learning rates and less warmup ✅ |
| Best suited for | shallow networks (< 12 layers) | deep networks (> 12 layers) ✅ |

Most modern LLMs use Pre-Norm (GPT-3, Llama, MiniMind, Mistral, and others).


6. Hands-On Experiments and References

6.1 Running the Example Code

The complete learning materials are open source, so you can run and verify them yourself:

# Clone the code
git clone https://github.com/joyehuang/minimind-notes
cd minimind-notes/learning_materials

# Experiment 1: observe vanishing gradients
python why_normalization.py

# Experiment 2: a demonstration of how RMSNorm works
python rmsnorm_explained.py

# Experiment 3: LayerNorm vs RMSNorm comparison
python normalization_comparison.py

7. Summary

7.1 Key Takeaways

  • Vanishing gradients aren’t magic: there’s a clear mathematical principle behind them, and you can verify it with code
  • What RMSNorm does: it stabilizes the numerical scale and preserves the vector’s direction, discarding only its overall magnitude
  • Why it’s faster than LayerNorm: it skips the mean-subtraction step and finishes in a single pass
  • Where it sits in the Transformer: two inside each Block (plus one final norm before the output), not as a single standalone layer
  • The Pre-Norm design: the standard choice for modern deep Transformers

7.2 One Sentence to Remember

“Normalization is the water-pressure stabilizer of deep networks, and RMSNorm is the more efficient version.”

7.3 Key Code Locations (MiniMind)

  • RMSNorm implementation: model/model_minimind.py:95-105
  • Used in the Block: model/model_minimind.py:359-380
  • Learning example code: learning_materials/why_normalization.py

Author: joye Published: 2025-12-16 Last updated: 2025-12-16 Series: MiniMind learning notes (1/4)

If you found this helpful, feel free to:

  • ⭐ Star the original project MiniMind
  • ⭐ Star my learning notes minimind-notes
  • 💬 Leave a comment with your own learning insights
  • 🔗 Share it with others studying LLMs
Why Transformers Need Normalization: Gradients to RMSNorm
https://joyehuang.me/en/blog/20251216---normalization/post
Author Joye
Published at December 16, 2025