【Datawhale2603】happy-llm task07 Llama2:实现关键模块
·
系列文章目录
前言

一、RMSNorm
1.1 核心概念
- 在深度神经网络中,数据在层层传递时,数值可能会变得特别大(爆炸) 或者特别小(消失),这会导致模型训练不稳定。
- Normalization(归一化) 的作用,就像是一个 “自动音量调节器”。不管输入的声音(数据)是大是小,它都把它调整到一个合适的音量范围,让后面的模块听起来更舒服。
- 传统的 LayerNorm:先减去平均值(让中心归零),再除以标准差(让方差归一)。
y = x − mean std × γ + β y = \frac{x - \text{mean}}{\text{std}} \times \gamma + \beta y=stdx−mean×γ+β - Llama 用的 RMSNorm:不减去平均值,只除以均方根 (Root Mean Square)。
y = x RMS × γ y = \frac{x}{\text{RMS}} \times \gamma y=RMSx×γ- 好处:少了一步减法运算,速度更快,且在 LLM 中效果往往更好。
1.2 数学公式拆解
RMSNorm ( x ) = x 1 n ∑ i = 1 n x i 2 + ϵ ⋅ γ \text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{n}\sum_{i=1}^{n} x_i^2 + \epsilon}} \cdot \gamma RMSNorm(x)=n1∑i=1nxi2+ϵx⋅γ
- 符号含义:
- x x x:输入的数据(向量)。
- n n n:向量的维度(比如隐藏层大小 hidden_size)。
- ∑ x i 2 \sum x_i^2 ∑xi2:把向量里每个数字平方后加起来。
- … \sqrt{\dots} …:开根号,这就是 RMS (均方根)。
- ϵ \epsilon ϵ (epsilon):一个极小的数(比如 1 e − 6 1e-6 1e−6),防止分母为 0 导致程序崩溃。
- γ \gamma γ (gamma/weight):缩放参数,可学习的权重。归一化后,模型需要自己决定要不要把数据再放大或缩小一点,这个 γ \gamma γ 就是用来学的。
1.3 代码实践
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
"""
dim: 归一化的维度,通常是隐藏层的大小 (hidden_size)
eps: 防止除以零的小数
"""
super().__init__()
# 1. 定义可学习的权重 g (gamma)
# 初始化为 1,形状是 (dim,)
self.weight = nn.Parameter(torch.ones(dim))
# 2. 保存 epsilon
self.eps = eps
def forward(self, x):
"""
x: 输入张量,形状通常是 (batch_size, seq_len, dim)
"""
# --- 核心逻辑开始 ---
# 3. 计算平方和
# x.pow(2) 把所有元素平方
# .mean(-1, keepdim=True) 在最后一个维度求平均,keepdim 保证维度不变,方便广播
variance = x.pow(2).mean(-1, keepdim=True)
# 4. 计算 RMS (均方根)
# torch.rsqrt 是 1 / sqrt(x) 的优化版本,比先 sqrt 再除法更快
# 加上 self.eps 保证数值稳定
rms = torch.rsqrt(variance + self.eps)
# 5. 归一化并缩放
# x * rms 相当于 x / sqrt(variance)
# * self.weight 是乘以可学习参数
output = x * rms * self.weight
# --- 核心逻辑结束 ---
return output
def test_rms_norm():
'''
我们需要验证三件事:
1. **形状对不对?** (输入输出维度一致)
2. **数值对不对?** (RMS 是否接近 1)
3. **梯度能不能传?** (确保能训练)
'''
print("--- 开始测试 RMSNorm ---")
# 1. 准备数据
batch_size = 2
seq_len = 10
dim = 512 # 隐藏层大小
# 创建一个随机输入,模拟神经网络的中间层输出
# 我们故意让数值大一点,看看归一化效果
x = torch.randn(batch_size, seq_len, dim) * 10
# 2. 实例化模型
norm = RMSNorm(dim=dim)
# 3. 前向传播
output = norm(x)
# --- 验证点 1: 形状检查 ---
assert x.shape == output.shape, f"形状错误!输入 {x.shape}, 输出 {output.shape}"
print(f"形状检查通过:{x.shape}")
# --- 验证点 2: 统计特性检查 ---
# RMSNorm 的特性是:输出的均方根 (RMS) 应该接近 1 (在乘以 weight 之前)
# 因为 output = x / rms * weight
# 所以 output / weight 的 rms 应该约等于 1
# 为了简单验证,我们暂时假设 weight 都是 1 (初始化就是 1)
output_rms = torch.sqrt(output.pow(2).mean(-1))
# 允许一点误差,比如 0.99 到 1.01 之间
assert torch.allclose(output_rms, torch.ones_like(output_rms), atol=1e-3), "RMS 值偏离 1 太多"
print(f"统计特性检查通过:输出 RMS 均值约为 {output_rms.mean().item():.4f}")
# --- 验证点 3: 梯度检查 (确保可训练) ---
# 创建一个假损失
loss = output.sum()
# 反向传播
loss.backward()
# 检查 weight 是否有梯度
assert norm.weight.grad is not None, "权重没有梯度!"
print(f"梯度检查通过:weight 梯度形状 {norm.weight.grad.shape}")
print("--- 所有测试通过!RMSNorm 构建成功 ---")
# 运行测试
if __name__ == "__main__":
test_rms_norm()
总结
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)