深入理解 RMSNorm:大模型时代的归一化新宠


一、前言

如果你关注过 LLaMA、Qwen、Gemma、Mistral 等当下主流大语言模型的架构设计,你一定会注意到一个共同的选择——它们都不约而同地抛弃了经典的 LayerNorm,转而使用 RMSNorm

def norm(x):
    # Purely functional rmsnorm with no learnable params
    return F.rms_norm(x, (x.size(-1),))

上面这段精简的代码就是一个不带可学习参数的纯函数式 RMSNorm。短短两行,却蕴含着深刻的设计哲学。本文将从原理、公式、对比、代码实现等多个维度,带你彻底搞懂 RMSNorm。


二、从 LayerNorm 说起

2.1 LayerNorm 回顾

LayerNorm(层归一化)由 Ba et al. 在 2016 年提出,是 Transformer 架构的标配组件。其核心操作如下:

LayerNorm(x)=γ⊙x−μσ2+ϵ+β \text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta LayerNorm(x)=γσ2+ϵ xμ+β

其中:

符号 含义
μ=1n∑i=1nxi\mu = \frac{1}{n}\sum_{i=1}^{n}x_iμ=n1i=1nxi 均值(mean)
σ2=1n∑i=1n(xi−μ)2\sigma^2 = \frac{1}{n}\sum_{i=1}^{n}(x_i - \mu)^2σ2=n1i=1n(xiμ)2 方差(variance)
γ,β\gamma, \betaγ,β 可学习的缩放和偏移参数
ϵ\epsilonϵ 防止除零的小常数

LayerNorm 做了两件事:

  1. Re-centering(重新中心化):减去均值 μ\muμ
  2. Re-scaling(重新缩放):除以标准差 σ\sigmaσ

2.2 一个自然的问题

这两步操作都是必要的吗?减去均值这一步真的不可或缺吗?

这正是 RMSNorm 论文提出的核心质疑。


三、RMSNorm 原理详解

3.1 论文出处

RMSNorm 由 Biao Zhang 和 Rico Sennrich 在 2019 年的论文中提出:

“Root Mean Square Layer Normalization” (NeurIPS 2019)

3.2 核心思想

论文的核心假设是:

LayerNorm 的成功主要归功于「重新缩放」(re-scaling)操作,而非「重新中心化」(re-centering)操作。

基于此,RMSNorm 大胆去掉了减均值的步骤,仅保留基于均方根(Root Mean Square, RMS)的缩放:

RMSNorm(x)=γ⊙xRMS(x)+ϵ \text{RMSNorm}(x) = \gamma \odot \frac{x}{\text{RMS}(x) + \epsilon} RMSNorm(x)=γRMS(x)+ϵx

其中 RMS 的计算为:

RMS(x)=1n∑i=1nxi2 \text{RMS}(x) = \sqrt{\frac{1}{n}\sum_{i=1}^{n}x_i^2} RMS(x)=n1i=1nxi2

3.3 直观理解

用一张图来对比两者的流程:

LayerNorm 流程:
┌─────────┐    ┌──────────────┐    ┌──────────────┐    ┌─────────┐
│  输入 x  │───▶│ 减去均值 (μ) │───▶│ 除以标准差 (σ)│───▶│ γ·x + β │
└─────────┘    └──────────────┘    └──────────────┘    └─────────┘
                  re-centering        re-scaling          affine

RMSNorm 流程:
┌─────────┐    ┌──────────────────┐    ┌─────────┐
│  输入 x  │───▶│ 除以 RMS 值      │───▶│   γ·x   │
└─────────┘    └──────────────────┘    └─────────┘
                   re-scaling only       scale only

关键区别

  • ❌ 不计算均值,不做中心化
  • ❌ 通常不使用偏置项 β\betaβ(很多实现中也省略 γ\gammaγ
  • ✅ 仅通过 RMS 进行缩放归一化

四、为什么 RMSNorm 更受青睐?

4.1 计算效率更高

操作 LayerNorm RMSNorm
求均值 μ\muμ ✅ 需要 ❌ 不需要
求方差 σ2\sigma^2σ2(需要先求 μ\muμ ✅ 需要 ❌ 不需要
求平方均值 1n∑xi2\frac{1}{n}\sum x_i^2n1xi2 ✅ 需要
减去均值 ✅ 需要 ❌ 不需要
归约操作次数 2 次 1 次

RMSNorm 只需要一次归约操作(reduce),而 LayerNorm 需要两次(先求均值,再求方差),在 GPU 上归约操作是主要的性能瓶颈之一。

论文实验表明,RMSNorm 相比 LayerNorm 可以节省约 7%~64% 的计算时间。

4.2 效果不降反升

在多项任务的实验中,RMSNorm 的表现与 LayerNorm 持平甚至更优。这验证了「re-centering 并非必要」的假设。

4.3 大模型的一致选择

以下主流大模型全部使用 RMSNorm

模型 机构 归一化方式
LLaMA / LLaMA 2 / LLaMA 3 Meta RMSNorm
Qwen / Qwen2 阿里 RMSNorm
Mistral / Mixtral Mistral AI RMSNorm
Gemma Google RMSNorm
DeepSeek DeepSeek RMSNorm
ChatGLM 清华/智谱 RMSNorm

五、代码实现详解

5.1 文章开头的极简实现

import torch.nn.functional as F

def norm(x):
    # Purely functional rmsnorm with no learnable params
    return F.rms_norm(x, (x.size(-1),))

解析:

  • F.rms_norm:PyTorch 提供的函数式 RMSNorm(torch >= 2.4
  • (x.size(-1),):归一化的维度形状,即对最后一个维度做归一化
  • weightbias 参数:这是一个纯粹的数学操作,没有可学习参数
  • “Purely functional”:纯函数式风格,没有副作用,不持有状态

5.2 手动实现(教学版)

为了深入理解原理,我们从零实现一个 RMSNorm:

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Args:
            dim: 特征维度(最后一维的大小)
            eps: 防止除零的小常数
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # 可学习的缩放参数 γ
  
    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """核心归一化逻辑(不含可学习参数)"""
        # x.pow(2)       : 每个元素平方
        # .mean(-1, ...)  : 对最后一维求均值 → 得到均方值
        # + self.eps      : 加上 epsilon 防止除零
        # .rsqrt()        : 求倒数平方根 = 1/sqrt(x)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
  
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # float() 确保以 float32 精度计算归一化,然后转回原精度
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

5.3 逐步计算示例

import torch

# 假设一个简单的输入向量
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
print(f"输入 x: {x}")

# Step 1: 计算每个元素的平方
x_sq = x.pow(2)
print(f"平方:   {x_sq}")          # [1, 4, 9, 16]

# Step 2: 求均值(均方值, Mean Square)
ms = x_sq.mean(-1, keepdim=True)
print(f"均方值: {ms}")             # (1+4+9+16)/4 = 7.5

# Step 3: 开方(均方根, Root Mean Square)
rms = ms.sqrt()
print(f"RMS:    {rms}")            # sqrt(7.5) ≈ 2.7386

# Step 4: 除以 RMS(归一化)
x_norm = x / rms
print(f"归一化: {x_norm}")         # [0.3651, 0.7303, 1.0954, 1.4606]

# 验证:归一化后的向量,其平方均值应接近 1
print(f"验证 - 归一化后平方均值: {x_norm.pow(2).mean()}")  # ≈ 1.0

输出:

输入 x: tensor([[1., 2., 3., 4.]])
平方:   tensor([[ 1.,  4.,  9., 16.]])
均方值: tensor([[7.5000]])
RMS:    tensor([[2.7386]])
归一化: tensor([[0.3651, 0.7303, 1.0954, 1.4606]])
验证 - 归一化后平方均值: 1.0

5.4 对比 LLaMA 官方实现

来看 Meta LLaMA 的官方实现,几乎完全一致:

# 来自 Meta LLaMA 官方代码
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

注意 .float().type_as(x) 的技巧:即使模型用 bfloat16/float16 训练,归一化的中间计算始终在 float32 下进行,避免精度损失,最后再转回原始精度。


六、数学层面的深入理解

6.1 RMSNorm vs LayerNorm 的数学关系

LayerNorm 的方差可以展开为:

σ2=1n∑i=1n(xi−μ)2=1n∑i=1nxi2⏟RMS2−μ2 \sigma^2 = \frac{1}{n}\sum_{i=1}^{n}(x_i - \mu)^2 = \underbrace{\frac{1}{n}\sum_{i=1}^{n}x_i^2}_{\text{RMS}^2} - \mu^2 σ2=n1i=1n(xiμ)2=RMS2 n1i=1nxi2μ2

所以:
RMS2(x)=σ2+μ2 \text{RMS}^2(x) = \sigma^2 + \mu^2 RMS2(x)=σ2+μ2

当均值 μ≈0\mu \approx 0μ0 时,RMS(x)≈σ\text{RMS}(x) \approx \sigmaRMS(x)σ,两者几乎等价。而在深度网络中,经过多层变换后,激活值的均值往往确实接近于零(特别是使用了 SwiGLU、GeGLU 等激活函数时),这为 RMSNorm 的有效性提供了理论支撑。

6.2 几何解释

RMSNorm 的本质操作可以理解为:

将向量投影到一个以原点为中心、半径为 n\sqrt{n}n 的超球面上。

归一化前:  向量长度不一,方向各异
              ↓  RMSNorm
归一化后:  所有向量的「RMS 长度」统一为 1
           (即每个分量的平方均值为 1)

七、Pre-Norm vs Post-Norm

在现代大模型中,RMSNorm 通常搭配 Pre-Norm 架构使用:

Post-Norm(GPT-1/原始 Transformer):
x → Attention → Add(x, ·) → LayerNorm → FFN → Add(·, ·) → LayerNorm

Pre-Norm(GPT-2/LLaMA 等现代架构):
x → RMSNorm → Attention → Add(x, ·) → RMSNorm → FFN → Add(x, ·)
# Pre-RMSNorm 的典型用法(伪代码)
class TransformerBlock(nn.Module):
    def forward(self, x):
        # 注意力子层:Pre-Norm + 残差连接
        x = x + self.attention(self.norm1(x))
      
        # FFN 子层:Pre-Norm + 残差连接
        x = x + self.ffn(self.norm2(x))
      
        return x

Pre-Norm + RMSNorm 的组合优势:

  • ✅ 训练更稳定,梯度流更顺畅
  • ✅ 可以支持更大的学习率
  • ✅ 允许更深的网络层数

八、性能对比实测

import torch
import time

def benchmark(fn, x, name, runs=1000):
    # 预热
    for _ in range(100):
        fn(x)
    torch.cuda.synchronize()
  
    start = time.perf_counter()
    for _ in range(runs):
        fn(x)
    torch.cuda.synchronize()
    elapsed = (time.perf_counter() - start) / runs * 1e6  # 微秒
    print(f"{name}: {elapsed:.1f} μs")

# 设置
x = torch.randn(32, 2048, 4096, device='cuda', dtype=torch.bfloat16)
ln = torch.nn.LayerNorm(4096, device='cuda', dtype=torch.bfloat16)

# 对比测试
benchmark(lambda x: F.rms_norm(x, (4096,)), x, "RMSNorm (F.rms_norm)")
benchmark(lambda x: ln(x),                   x, "LayerNorm (nn.LayerNorm)")

在典型的 LLM 推理场景中,RMSNorm 比 LayerNorm 快 10%~30%


九、总结

维度 LayerNorm RMSNorm
提出时间 2016 2019
中心化(减均值) ✅ 有 ❌ 无
缩放归一化 ✅ 有 ✅ 有
可学习参数 γ,β\gamma, \betaγ,β 通常仅 γ\gammaγ
归约操作次数 2 次 1 次
计算速度 较慢 更快
模型效果 相当甚至更好
大模型采用率 较少(旧模型) 主流选择

一句话总结

RMSNorm 去掉了 LayerNorm 中「多余的」均值中心化操作,只保留了基于均方根的缩放归一化,在保持效果的同时显著提升了计算效率,已成为当代大语言模型的标准组件。


参考文献

  1. Zhang, B. & Sennrich, R. (2019). “Root Mean Square Layer Normalization”. NeurIPS 2019.
  2. Touvron, H. et al. (2023). “LLaMA: Open and Efficient Foundation Language Models”. Meta AI.
  3. Ba, J. L. et al. (2016). “Layer Normalization”. arXiv:1607.06450.

后记

2026年5月18日于上海,在claude opus 4.6辅助下完成。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐