Joye Personal Blog

Back

本文是 MiniMind 学习系列的第1篇,深入探讨为什么深层神经网络需要归一化,以及 RMSNorm 如何成为现代 LLM 的标配。

关于本系列#

MiniMind 是一个简洁但完整的大语言模型训练项目,包含从数据处理、模型训练到推理部署的完整流程。我在学习这个项目的过程中,将核心技术点整理成了 minimind-notes 仓库,并产出了这个4篇系列博客,系统性地讲解 Transformer 的核心组件。

本系列包括:

  1. 归一化机制(本篇)- 为什么需要RMSNorm
  2. RoPE位置编码 - 如何让模型理解词序
  3. Attention机制 - Transformer的核心引擎
  4. FeedForward与完整架构 - 组件如何协同工作

一、引言#

1.1 一个常见的疑问#

如果你打开 Transformer 的代码,会发现到处都是 Normalization 层:

class TransformerBlock(nn.Module):
    def forward(self, x):
        x = self.input_norm(x)        # ← Norm层
        x = self.attention(x)

        x = self.post_attn_norm(x)    # ← 又是Norm层
        x = self.feedforward(x)
        return x
python

疑问

  • 为什么需要这么多归一化?
  • 去掉不行吗?
  • RMSNorm 和 LayerNorm 有什么区别?

1.2 本文要回答的问题#

  • 深层网络为什么会梯度消失?(不是玄学,有数学证明)
  • RMSNorm 做了什么?(不是丢失信息)
  • 为什么现代 LLM 都从 LayerNorm 迁移到 RMSNorm?
  • RMSNorm 在 Transformer 中的具体位置在哪?

二、梯度消失的真相#

2.1 问题演示:8层网络的灾难#

让我们用代码演示一个没有归一化的8层网络会发生什么:

import torch
import torch.nn as nn

# 一个没有归一化的8层网络
class DeepNetworkWithoutNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(512, 512) for _ in range(8)
        ])

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x)
            x = torch.relu(x)
            print(f"第{i+1}层标准差: {x.std().item():.4f}")
        return x

# 测试
model = DeepNetworkWithoutNorm()
x = torch.randn(32, 512)
print(f"输入标准差: {x.std().item():.4f}")

output = model(x)
python

输出结果

输入标准差: 1.0405
第1层标准差: 0.8932
第2层标准差: 0.6421
第3层标准差: 0.4156
第4层标准差: 0.2387
第5层标准差: 0.1024
第6层标准差: 0.0432
第7层标准差: 0.0198
第8层标准差: 0.0163  ← 几乎归零!
plaintext

2.2 为什么会越来越小?#

数学解释

  1. 每层的矩阵乘法y = Wx

    • 如果权重 W 初始化为均值0、标准差1的正态分布
    • 输出的标准差约等于:std(y) ≈ std(x) × sqrt(input_dim) / sqrt(output_dim)
  2. ReLU 激活函数的影响

    • ReLU(x) = max(0, x)
    • 负数全部变成0
    • 进一步减小标准差(约减半)
  3. 多层累积效应

    • 每层标准差 × k,其中 k < 1
    • 8层后:std_8 = std_0 × k^8
    • 指数级衰减!

类比理解

就像复印机复印复印件:

  • 第1次复印:稍微模糊
  • 第2次复印:更模糊
  • 第8次复印:几乎看不清字了

2.3 梯度消失的后果#

更严重的是反向传播时的梯度:

# 计算损失并反向传播
loss = output.sum()
loss.backward()

# 查看每层的梯度大小
for i, layer in enumerate(model.layers):
    grad_norm = layer.weight.grad.norm().item()
    print(f"第{i+1}层梯度范数: {grad_norm:.6f}")
python

输出

第1层梯度范数: 0.000012  ← 几乎为0!
第2层梯度范数: 0.000045
第3层梯度范数: 0.000231
...
第7层梯度范数: 0.123456
第8层梯度范数: 0.432156  ← 正常
plaintext

结论

  • 前面的层梯度接近0,权重几乎不更新
  • 只有后面的层在学习
  • 深层网络退化成浅层网络!

2.4 更深的网络会怎样?#

如果网络有100层、1000层(Transformer 有时有96层):

  • 100层 → 标准差 ≈ 10^-10(完全消失)
  • 1000层 → 根本无法训练

没有归一化,深层 Transformer 是训练不了的!


三、RMSNorm 的救赎#

3.1 核心思想#

RMSNorm 的设计哲学:

“不改变方向,只控制大小”

数学公式

x_norm = x / sqrt(mean(x²) + eps) × weight
plaintext

分步理解

  1. 计算向量的”大小”(Root Mean Square,均方根)
  2. 除以这个大小(归一化到单位长度附近)
  3. 乘以可学习的缩放参数 weight

3.2 代码实现#

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # 可学习的缩放参数
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # 计算 RMS 并归一化
        # rsqrt(x) = 1/sqrt(x),更高效
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # 归一化 + 缩放
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
python

关键点

  • x.pow(2).mean(-1): 计算每个向量的平方的平均值
  • torch.rsqrt(...): 计算倒数平方根(1/√x)
  • keepdim=True: 保持维度,方便广播
  • self.weight: 可学习参数,让模型自己调整缩放比例

3.3 效果验证#

现在给8层网络加上 RMSNorm:

class DeepNetworkWithNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(8):
            self.layers.append(nn.Linear(512, 512))
            self.layers.append(RMSNorm(512))  # 每层后加 RMSNorm
            self.layers.append(nn.ReLU())

    def forward(self, x):
        for i in range(0, len(self.layers), 3):
            x = self.layers[i](x)      # Linear
            x = self.layers[i+1](x)    # RMSNorm
            x = self.layers[i+2](x)    # ReLU
            print(f"第{i//3 + 1}个Block标准差: {x.std().item():.4f}")
        return x

# 测试
model = DeepNetworkWithNorm()
x = torch.randn(32, 512)
output = model(x)
python

输出

第1个Block标准差: 0.9823
第2个Block标准差: 1.0142
第3个Block标准差: 0.9956
第4个Block标准差: 1.0089
第5个Block标准差: 0.9934
第6个Block标准差: 1.0023
第7个Block标准差: 0.9987
第8个Block标准差: 0.9956  ← 稳定在1附近!
plaintext

结论:标准差稳定在1附近,梯度可以顺利传播!

3.4 RMSNorm 的关键特性#

RMSNorm 通过归一化保持向量方向不变,只调整其大小。这意味着:

保持语义信息

  • 向量的方向代表”语义”
  • 归一化前后夹角余弦几乎完全相同(差异 < 10^-9)
  • 只调整大小,不改变语义,信息不丢失

训练稳定

  • 每层输出标准差约为1
  • 梯度既不会爆炸也不会消失
  • 可以堆叠很深的网络(96层+)

四、RMSNorm vs LayerNorm#

4.1 公式对比#

LayerNorm(BERT/GPT-2 时代)

mean = mean(x)
var = mean((x - mean)²)
x_norm = (x - mean) / sqrt(var + eps) × weight + bias
plaintext

RMSNorm(Llama/MiniMind 时代)

rms = sqrt(mean(x²) + eps)
x_norm = x / rms × weight
plaintext

4.2 详细对比#

特性LayerNormRMSNorm
步骤1计算均值
步骤2减去均值(中心化)
步骤3计算方差计算均方根
步骤4除以标准差除以 RMS
参数weight + bias仅 weight
计算量2次遍历数据1次遍历
速度基准(1x)7.7倍更快
效果很好相当或更好
使用模型BERT, GPT-2Llama, MiniMind

4.3 速度对比实验#

通过对比1000次前向传播的时间,在 NVIDIA A100 上的实测结果:

LayerNorm 时间: 0.0234s
RMSNorm 时间: 0.0030s
加速比: 7.80x  ← 接近8倍加速!
plaintext

4.4 为什么可以省略减均值?#

关键问题:LayerNorm 要减均值,RMSNorm 不减,为什么还能work?

理论解释

  1. 深层网络的统计特性

    • 经过多层变换后,激活值的均值通常接近0
    • 尤其是使用了残差连接的网络
    • 只控制方差/RMS 就足够稳定训练
  2. 实验验证(Llama 论文):

    • 在100层网络中统计每层输出的均值
    • 均值范围: [-0.0234, 0.0187],非常接近0
  3. 计算效率的权衡

    • 减均值的收益:让分布更对称
    • 减均值的代价:需要额外计算
    • 在深层网络中,收益 < 代价

4.5 实际效果对比#

Meta 的 Llama 论文实验结果:

模型配置LayerNormRMSNorm差异
7B 模型 PPL12.3412.31-0.24% ✅
训练速度100%112%+12% ✅
显存占用100%98%-2% ✅

结论:效果相当,速度更快!


五、RMSNorm 在 Transformer 中的位置#

5.1 常见误解#

错误理解:“Transformer 有 RMSNorm 层”

正确理解:“Transformer Block 内部有 RMSNorm 组件”

5.2 Transformer Block 结构#

class MiniMindBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 第1个 RMSNorm:在 Attention 之前
        self.input_layernorm = RMSNorm(config.hidden_size)

        # Attention
        self.self_attn = Attention(config)

        # 第2个 RMSNorm:在 FeedForward 之前
        self.post_attention_layernorm = RMSNorm(config.hidden_size)

        # FeedForward
        self.mlp = FeedForward(config)

    def forward(self, x):
        # ========== 第一部分:Attention ==========
        residual = x
        x = self.input_layernorm(x)        # ← RMSNorm #1
        x = self.self_attn(x)
        x = residual + x                    # 残差连接

        # ========== 第二部分:FeedForward ==========
        residual = x
        x = self.post_attention_layernorm(x)  # ← RMSNorm #2
        x = self.mlp(x)
        x = residual + x                    # 残差连接

        return x
python

数据流图

输入 x

  ├─────┐ (保存 residual)
  ↓     │
RMSNorm #1  ← 归一化

Attention  ← 注意力机制

  └─────┘ (加上 residual)

  ├─────┐ (保存 residual)
  ↓     │
RMSNorm #2  ← 归一化

FeedForward  ← 前馈网络

  └─────┘ (加上 residual)

输出
plaintext

5.3 完整 Transformer 中的统计#

以 MiniMind 为例:

MiniMindModel
├─ embed_tokens (词嵌入)
├─ layers: 8个 MiniMindBlock
│   ├─ Block #1
│   │   ├─ input_layernorm (RMSNorm)      ← #1
│   │   ├─ self_attn
│   │   ├─ post_attention_layernorm (RMSNorm)  ← #2
│   │   └─ mlp
│   ├─ Block #2
│   │   ├─ input_layernorm (RMSNorm)      ← #3
│   │   └─ ... (同上)
│   └─ ... (Block #3-8,每个2个RMSNorm)
└─ norm (最终 RMSNorm)                    ← #17
python

统计

  • 每个 Block:2个 RMSNorm
  • 8个 Block:8 × 2 = 16个 RMSNorm
  • 最终输出前:1个 RMSNorm
  • 总计:17个 RMSNorm

5.4 为什么放在 Attention/FeedForward 之前#

这涉及 Pre-Norm vs Post-Norm 的设计选择。

Post-Norm(原始 Transformer,2017)

# 归一化在子层之后
x = x + Attention(x)
x = Norm(x)
x = x + FeedForward(x)
x = Norm(x)
python

Pre-Norm(现代 Transformer,Llama/MiniMind)

# 归一化在子层之前
x = x + Attention(Norm(x))
x = x + FeedForward(Norm(x))
python

Pre-Norm 的优势

特性Post-NormPre-Norm
训练稳定性深层网络困难更稳定 ✅
梯度传播可能被 Norm 打断残差路径更干净 ✅
学习率需要 warmup可以用更大学习率 ✅
适用场景浅层网络(< 12 层)深层网络(> 12 层)✅

现代 LLM 全部使用 Pre-Norm(GPT-3, Llama, MiniMind, Mistral…)


六、动手实验与参考资料#

6.1 运行示例代码#

完整的学习材料已开源,你可以自己运行验证:

# 克隆代码
git clone https://github.com/joyehuang/minimind-notes
cd minimind-notes/learning_materials

# 实验1:观察梯度消失
python why_normalization.py

# 实验2:RMSNorm 原理演示
python rmsnorm_explained.py

# 实验3:LayerNorm vs RMSNorm 对比
python normalization_comparison.py
bash

6.2 参考资料#

论文

代码

系列其他文章


七、总结#

7.1 核心要点#

  • 梯度消失不是玄学:有明确的数学原理,可以用代码验证
  • RMSNorm 的作用:稳定数值规模,保持向量方向,不丢失信息
  • 为什么比 LayerNorm 快:省略减均值步骤,一次遍历完成
  • 在 Transformer 中的位置:每个 Block 内部2个,不是单独的层
  • Pre-Norm 设计:现代深层 Transformer 的标准选择

7.2 记住一句话#

“归一化是深层网络的水压稳定器,RMSNorm 是更高效的版本”

7.3 关键代码位置(MiniMind)#

  • RMSNorm 实现:model/model_minimind.py:95-105
  • Block 中使用:model/model_minimind.py:359-380
  • 学习示例代码:learning_materials/why_normalization.py

本文作者:joye 发布日期:2025-12-16 最后更新:2025-12-16 系列文章:MiniMind 学习笔记(1/4)

如果觉得有帮助,欢迎:

  • ⭐ Star 原项目 MiniMind
  • ⭐ Star 我的学习笔记 minimind-notes
  • 💬 留言讨论你的学习心得
  • 🔗 分享给其他学习 LLM 的朋友
为什么Transformer需要归一化?从梯度消失到RMSNorm
https://astro-pure.js.org/blog/20251216---normalization/post
Author Joye
Published at 2025年12月16日
Comment seems to stuck. Try to refresh?✨