深入理解 RMSNorm:大模型时代的归一化新宠
深入理解 RMSNorm:大模型时代的归一化新宠
一、前言
如果你关注过 LLaMA、Qwen、Gemma、Mistral 等当下主流大语言模型的架构设计,你一定会注意到一个共同的选择——它们都不约而同地抛弃了经典的 LayerNorm,转而使用 RMSNorm。
def norm(x):
# Purely functional rmsnorm with no learnable params
return F.rms_norm(x, (x.size(-1),))
上面这段精简的代码就是一个不带可学习参数的纯函数式 RMSNorm。短短两行,却蕴含着深刻的设计哲学。本文将从原理、公式、对比、代码实现等多个维度,带你彻底搞懂 RMSNorm。
二、从 LayerNorm 说起
2.1 LayerNorm 回顾
LayerNorm(层归一化)由 Ba et al. 在 2016 年提出,是 Transformer 架构的标配组件。其核心操作如下:
LayerNorm(x)=γ⊙x−μσ2+ϵ+β \text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta LayerNorm(x)=γ⊙σ2+ϵx−μ+β
其中:
| 符号 | 含义 |
|---|---|
| μ=1n∑i=1nxi\mu = \frac{1}{n}\sum_{i=1}^{n}x_iμ=n1∑i=1nxi | 均值(mean) |
| σ2=1n∑i=1n(xi−μ)2\sigma^2 = \frac{1}{n}\sum_{i=1}^{n}(x_i - \mu)^2σ2=n1∑i=1n(xi−μ)2 | 方差(variance) |
| γ,β\gamma, \betaγ,β | 可学习的缩放和偏移参数 |
| ϵ\epsilonϵ | 防止除零的小常数 |
LayerNorm 做了两件事:
- Re-centering(重新中心化):减去均值 μ\muμ
- Re-scaling(重新缩放):除以标准差 σ\sigmaσ
2.2 一个自然的问题
这两步操作都是必要的吗?减去均值这一步真的不可或缺吗?
这正是 RMSNorm 论文提出的核心质疑。
三、RMSNorm 原理详解
3.1 论文出处
RMSNorm 由 Biao Zhang 和 Rico Sennrich 在 2019 年的论文中提出:
“Root Mean Square Layer Normalization” (NeurIPS 2019)
3.2 核心思想
论文的核心假设是:
LayerNorm 的成功主要归功于「重新缩放」(re-scaling)操作,而非「重新中心化」(re-centering)操作。
基于此,RMSNorm 大胆去掉了减均值的步骤,仅保留基于均方根(Root Mean Square, RMS)的缩放:
RMSNorm(x)=γ⊙xRMS(x)+ϵ \text{RMSNorm}(x) = \gamma \odot \frac{x}{\text{RMS}(x) + \epsilon} RMSNorm(x)=γ⊙RMS(x)+ϵx
其中 RMS 的计算为:
RMS(x)=1n∑i=1nxi2 \text{RMS}(x) = \sqrt{\frac{1}{n}\sum_{i=1}^{n}x_i^2} RMS(x)=n1i=1∑nxi2
3.3 直观理解
用一张图来对比两者的流程:
LayerNorm 流程:
┌─────────┐ ┌──────────────┐ ┌──────────────┐ ┌─────────┐
│ 输入 x │───▶│ 减去均值 (μ) │───▶│ 除以标准差 (σ)│───▶│ γ·x + β │
└─────────┘ └──────────────┘ └──────────────┘ └─────────┘
re-centering re-scaling affine
RMSNorm 流程:
┌─────────┐ ┌──────────────────┐ ┌─────────┐
│ 输入 x │───▶│ 除以 RMS 值 │───▶│ γ·x │
└─────────┘ └──────────────────┘ └─────────┘
re-scaling only scale only
关键区别:
- ❌ 不计算均值,不做中心化
- ❌ 通常不使用偏置项 β\betaβ(很多实现中也省略 γ\gammaγ)
- ✅ 仅通过 RMS 进行缩放归一化
四、为什么 RMSNorm 更受青睐?
4.1 计算效率更高
| 操作 | LayerNorm | RMSNorm |
|---|---|---|
| 求均值 μ\muμ | ✅ 需要 | ❌ 不需要 |
| 求方差 σ2\sigma^2σ2(需要先求 μ\muμ) | ✅ 需要 | ❌ 不需要 |
| 求平方均值 1n∑xi2\frac{1}{n}\sum x_i^2n1∑xi2 | ❌ | ✅ 需要 |
| 减去均值 | ✅ 需要 | ❌ 不需要 |
| 归约操作次数 | 2 次 | 1 次 |
RMSNorm 只需要一次归约操作(reduce),而 LayerNorm 需要两次(先求均值,再求方差),在 GPU 上归约操作是主要的性能瓶颈之一。
论文实验表明,RMSNorm 相比 LayerNorm 可以节省约 7%~64% 的计算时间。
4.2 效果不降反升
在多项任务的实验中,RMSNorm 的表现与 LayerNorm 持平甚至更优。这验证了「re-centering 并非必要」的假设。
4.3 大模型的一致选择
以下主流大模型全部使用 RMSNorm:
| 模型 | 机构 | 归一化方式 |
|---|---|---|
| LLaMA / LLaMA 2 / LLaMA 3 | Meta | RMSNorm |
| Qwen / Qwen2 | 阿里 | RMSNorm |
| Mistral / Mixtral | Mistral AI | RMSNorm |
| Gemma | RMSNorm | |
| DeepSeek | DeepSeek | RMSNorm |
| ChatGLM | 清华/智谱 | RMSNorm |
五、代码实现详解
5.1 文章开头的极简实现
import torch.nn.functional as F
def norm(x):
# Purely functional rmsnorm with no learnable params
return F.rms_norm(x, (x.size(-1),))
解析:
F.rms_norm:PyTorch 提供的函数式 RMSNorm(torch >= 2.4)(x.size(-1),):归一化的维度形状,即对最后一个维度做归一化- 无
weight和bias参数:这是一个纯粹的数学操作,没有可学习参数 - “Purely functional”:纯函数式风格,没有副作用,不持有状态
5.2 手动实现(教学版)
为了深入理解原理,我们从零实现一个 RMSNorm:
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
"""
Args:
dim: 特征维度(最后一维的大小)
eps: 防止除零的小常数
"""
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim)) # 可学习的缩放参数 γ
def _norm(self, x: torch.Tensor) -> torch.Tensor:
"""核心归一化逻辑(不含可学习参数)"""
# x.pow(2) : 每个元素平方
# .mean(-1, ...) : 对最后一维求均值 → 得到均方值
# + self.eps : 加上 epsilon 防止除零
# .rsqrt() : 求倒数平方根 = 1/sqrt(x)
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# float() 确保以 float32 精度计算归一化,然后转回原精度
output = self._norm(x.float()).type_as(x)
return output * self.weight
5.3 逐步计算示例
import torch
# 假设一个简单的输入向量
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
print(f"输入 x: {x}")
# Step 1: 计算每个元素的平方
x_sq = x.pow(2)
print(f"平方: {x_sq}") # [1, 4, 9, 16]
# Step 2: 求均值(均方值, Mean Square)
ms = x_sq.mean(-1, keepdim=True)
print(f"均方值: {ms}") # (1+4+9+16)/4 = 7.5
# Step 3: 开方(均方根, Root Mean Square)
rms = ms.sqrt()
print(f"RMS: {rms}") # sqrt(7.5) ≈ 2.7386
# Step 4: 除以 RMS(归一化)
x_norm = x / rms
print(f"归一化: {x_norm}") # [0.3651, 0.7303, 1.0954, 1.4606]
# 验证:归一化后的向量,其平方均值应接近 1
print(f"验证 - 归一化后平方均值: {x_norm.pow(2).mean()}") # ≈ 1.0
输出:
输入 x: tensor([[1., 2., 3., 4.]])
平方: tensor([[ 1., 4., 9., 16.]])
均方值: tensor([[7.5000]])
RMS: tensor([[2.7386]])
归一化: tensor([[0.3651, 0.7303, 1.0954, 1.4606]])
验证 - 归一化后平方均值: 1.0
5.4 对比 LLaMA 官方实现
来看 Meta LLaMA 的官方实现,几乎完全一致:
# 来自 Meta LLaMA 官方代码
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
注意
.float()和.type_as(x)的技巧:即使模型用bfloat16/float16训练,归一化的中间计算始终在float32下进行,避免精度损失,最后再转回原始精度。
六、数学层面的深入理解
6.1 RMSNorm vs LayerNorm 的数学关系
LayerNorm 的方差可以展开为:
σ2=1n∑i=1n(xi−μ)2=1n∑i=1nxi2⏟RMS2−μ2 \sigma^2 = \frac{1}{n}\sum_{i=1}^{n}(x_i - \mu)^2 = \underbrace{\frac{1}{n}\sum_{i=1}^{n}x_i^2}_{\text{RMS}^2} - \mu^2 σ2=n1i=1∑n(xi−μ)2=RMS2 n1i=1∑nxi2−μ2
所以:
RMS2(x)=σ2+μ2 \text{RMS}^2(x) = \sigma^2 + \mu^2 RMS2(x)=σ2+μ2
当均值 μ≈0\mu \approx 0μ≈0 时,RMS(x)≈σ\text{RMS}(x) \approx \sigmaRMS(x)≈σ,两者几乎等价。而在深度网络中,经过多层变换后,激活值的均值往往确实接近于零(特别是使用了 SwiGLU、GeGLU 等激活函数时),这为 RMSNorm 的有效性提供了理论支撑。
6.2 几何解释
RMSNorm 的本质操作可以理解为:
将向量投影到一个以原点为中心、半径为 n\sqrt{n}n 的超球面上。
归一化前: 向量长度不一,方向各异
↓ RMSNorm
归一化后: 所有向量的「RMS 长度」统一为 1
(即每个分量的平方均值为 1)
七、Pre-Norm vs Post-Norm
在现代大模型中,RMSNorm 通常搭配 Pre-Norm 架构使用:
Post-Norm(GPT-1/原始 Transformer):
x → Attention → Add(x, ·) → LayerNorm → FFN → Add(·, ·) → LayerNorm
Pre-Norm(GPT-2/LLaMA 等现代架构):
x → RMSNorm → Attention → Add(x, ·) → RMSNorm → FFN → Add(x, ·)
# Pre-RMSNorm 的典型用法(伪代码)
class TransformerBlock(nn.Module):
def forward(self, x):
# 注意力子层:Pre-Norm + 残差连接
x = x + self.attention(self.norm1(x))
# FFN 子层:Pre-Norm + 残差连接
x = x + self.ffn(self.norm2(x))
return x
Pre-Norm + RMSNorm 的组合优势:
- ✅ 训练更稳定,梯度流更顺畅
- ✅ 可以支持更大的学习率
- ✅ 允许更深的网络层数
八、性能对比实测
import torch
import time
def benchmark(fn, x, name, runs=1000):
# 预热
for _ in range(100):
fn(x)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(runs):
fn(x)
torch.cuda.synchronize()
elapsed = (time.perf_counter() - start) / runs * 1e6 # 微秒
print(f"{name}: {elapsed:.1f} μs")
# 设置
x = torch.randn(32, 2048, 4096, device='cuda', dtype=torch.bfloat16)
ln = torch.nn.LayerNorm(4096, device='cuda', dtype=torch.bfloat16)
# 对比测试
benchmark(lambda x: F.rms_norm(x, (4096,)), x, "RMSNorm (F.rms_norm)")
benchmark(lambda x: ln(x), x, "LayerNorm (nn.LayerNorm)")
在典型的 LLM 推理场景中,RMSNorm 比 LayerNorm 快 10%~30%。
九、总结
| 维度 | LayerNorm | RMSNorm |
|---|---|---|
| 提出时间 | 2016 | 2019 |
| 中心化(减均值) | ✅ 有 | ❌ 无 |
| 缩放归一化 | ✅ 有 | ✅ 有 |
| 可学习参数 | γ,β\gamma, \betaγ,β | 通常仅 γ\gammaγ |
| 归约操作次数 | 2 次 | 1 次 |
| 计算速度 | 较慢 | 更快 |
| 模型效果 | 好 | 相当甚至更好 |
| 大模型采用率 | 较少(旧模型) | 主流选择 |
一句话总结:
RMSNorm 去掉了 LayerNorm 中「多余的」均值中心化操作,只保留了基于均方根的缩放归一化,在保持效果的同时显著提升了计算效率,已成为当代大语言模型的标准组件。
参考文献
- Zhang, B. & Sennrich, R. (2019). “Root Mean Square Layer Normalization”. NeurIPS 2019.
- Touvron, H. et al. (2023). “LLaMA: Open and Efficient Foundation Language Models”. Meta AI.
- Ba, J. L. et al. (2016). “Layer Normalization”. arXiv:1607.06450.
后记
2026年5月18日于上海,在claude opus 4.6辅助下完成。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐
所有评论(0)