系列文章目录



前言

在这里插入图片描述


一、RMSNorm

1.1 核心概念

  • 在深度神经网络中,数据在层层传递时,数值可能会变得特别大(爆炸) 或者特别小(消失),这会导致模型训练不稳定。
  • Normalization(归一化) 的作用,就像是一个 “自动音量调节器”。不管输入的声音(数据)是大是小,它都把它调整到一个合适的音量范围,让后面的模块听起来更舒服。
  • 传统的 LayerNorm:先减去平均值(让中心归零),再除以标准差(让方差归一)。
    y = x − mean std × γ + β y = \frac{x - \text{mean}}{\text{std}} \times \gamma + \beta y=stdxmean×γ+β
  • Llama 用的 RMSNorm不减去平均值,只除以均方根 (Root Mean Square)
    y = x RMS × γ y = \frac{x}{\text{RMS}} \times \gamma y=RMSx×γ
    • 好处:少了一步减法运算,速度更快,且在 LLM 中效果往往更好。

1.2 数学公式拆解

RMSNorm ( x ) = x 1 n ∑ i = 1 n x i 2 + ϵ ⋅ γ \text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{n}\sum_{i=1}^{n} x_i^2 + \epsilon}} \cdot \gamma RMSNorm(x)=n1i=1nxi2+ϵ xγ

  • 符号含义:
  1. x x x:输入的数据(向量)。
  2. n n n:向量的维度(比如隐藏层大小 hidden_size)。
  3. ∑ x i 2 \sum x_i^2 xi2:把向量里每个数字平方后加起来。
  4. … \sqrt{\dots} :开根号,这就是 RMS (均方根)
  5. ϵ \epsilon ϵ (epsilon):一个极小的数(比如 1 e − 6 1e-6 1e6),防止分母为 0 导致程序崩溃。
  6. γ \gamma γ (gamma/weight):缩放参数,可学习的权重。归一化后,模型需要自己决定要不要把数据再放大或缩小一点,这个 γ \gamma γ 就是用来学的。

1.3 代码实践

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        """
        dim: 归一化的维度,通常是隐藏层的大小 (hidden_size)
        eps: 防止除以零的小数
        """
        super().__init__()
        # 1. 定义可学习的权重 g (gamma)
        # 初始化为 1,形状是 (dim,)
        self.weight = nn.Parameter(torch.ones(dim))
        
        # 2. 保存 epsilon
        self.eps = eps

    def forward(self, x):
        """
        x: 输入张量,形状通常是 (batch_size, seq_len, dim)
        """
        # --- 核心逻辑开始 ---
        
        # 3. 计算平方和
        # x.pow(2) 把所有元素平方
        # .mean(-1, keepdim=True) 在最后一个维度求平均,keepdim 保证维度不变,方便广播
        variance = x.pow(2).mean(-1, keepdim=True)
        
        # 4. 计算 RMS (均方根)
        # torch.rsqrt 是 1 / sqrt(x) 的优化版本,比先 sqrt 再除法更快
        # 加上 self.eps 保证数值稳定
        rms = torch.rsqrt(variance + self.eps)
        
        # 5. 归一化并缩放
        # x * rms 相当于 x / sqrt(variance)
        # * self.weight 是乘以可学习参数
        output = x * rms * self.weight
        
        # --- 核心逻辑结束 ---
        
        return output



def test_rms_norm():
    '''
        我们需要验证三件事:
    1.  **形状对不对?** (输入输出维度一致)
    2.  **数值对不对?** (RMS 是否接近 1)
    3.  **梯度能不能传?** (确保能训练)
    '''
    print("--- 开始测试 RMSNorm ---")
    
    # 1. 准备数据
    batch_size = 2
    seq_len = 10
    dim = 512  # 隐藏层大小
    
    # 创建一个随机输入,模拟神经网络的中间层输出
    # 我们故意让数值大一点,看看归一化效果
    x = torch.randn(batch_size, seq_len, dim) * 10 
    
    # 2. 实例化模型
    norm = RMSNorm(dim=dim)
    
    # 3. 前向传播
    output = norm(x)
    
    # --- 验证点 1: 形状检查 ---
    assert x.shape == output.shape, f"形状错误!输入 {x.shape}, 输出 {output.shape}"
    print(f"形状检查通过:{x.shape}")
    
    # --- 验证点 2: 统计特性检查 ---
    # RMSNorm 的特性是:输出的均方根 (RMS) 应该接近 1 (在乘以 weight 之前)
    # 因为 output = x / rms * weight
    # 所以 output / weight 的 rms 应该约等于 1
    
    # 为了简单验证,我们暂时假设 weight 都是 1 (初始化就是 1)
    output_rms = torch.sqrt(output.pow(2).mean(-1))
    
    # 允许一点误差,比如 0.99 到 1.01 之间
    assert torch.allclose(output_rms, torch.ones_like(output_rms), atol=1e-3), "RMS 值偏离 1 太多"
    print(f"统计特性检查通过:输出 RMS 均值约为 {output_rms.mean().item():.4f}")
    
    # --- 验证点 3: 梯度检查 (确保可训练) ---
    # 创建一个假损失
    loss = output.sum()
    # 反向传播
    loss.backward()
    
    # 检查 weight 是否有梯度
    assert norm.weight.grad is not None, "权重没有梯度!"
    print(f"梯度检查通过:weight 梯度形状 {norm.weight.grad.shape}")
    
    print("--- 所有测试通过!RMSNorm 构建成功 ---")

# 运行测试
if __name__ == "__main__":
    test_rms_norm()

总结

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐