从0完成轻量级大模型全链路训练与对齐框架——模型搭建
一 整体方案
整个项目的实现路径可以分为以下几个关键阶段:
-
架构基座 (Architecture) :手写基于 PyTorch 的 Decoder-only Transformer,并集成 Flash Attention 以优化显存和计算效率。
-
专家引入 (MoE) :在基础前馈网络中加入混合专家路由(Routing)机制,并设计负载均衡损失函数(Load Balancing Loss)。
二 模型结构
相比于传统的 Transformer 架构,因为我们要完成的是生成式语言模型,去掉Encoder能够让模型更有性价比,只需要让模型关注下一词的生成,所以我们采用的是Llama-style Decoder-only 架构设计。
三 具体实现
3.1 注意力模块
我们将在这个阶段重点打磨两个核心组件:
-
旋转位置编码 (RoPE) :抛弃传统的绝对位置编码,让模型更好地理解 Token 之间的相对距离。
-
因果注意力机制 (Causal Attention) :实现带掩码 (Mask) 的自注意力,并为后续接入 Flash Attention 做好张量维度的准备
3.1.1 旋转位置编码
在我们动手编写 Attention 类的代码之前,先来看看位置信息是如何注入的。传统的 Transformer 会在最开始把位置编码直接加到输入的词向量 (Embedding) 上。但是,RoPE (Rotary Position Embedding) 改变了这个做法,它是在 Attention 机制内部,且在点积计算发生之前应用的。
回顾一下标准的注意力计算核心公式:
Attention(Q,K,V)=softmax(\frac{QK^{T}}{\sqrt{d_{k}}})V
-
Q (Query) 和 K (Key) 负责“匹配”与“打分”:在 Transformer 中,Q 和 K 的点积用于计算两个 Token 之间的相关性得分(Attention Score)。我们希望模型能知道“这两个词离得多远”,所以给 Q 和 K 加上 RoPE 旋转。旋转后,Q 和 K 的点积结果会自然地带上它们之间的相对位置信息。
-
V (Value) 负责“提供内容”:V 包含了 Token 的实际语义。前面的 $softmax$ 步骤算出的 Attention Score 已经包含了位置信息(即“我应该放多少注意力在这个 Token 上”)。最后这一步只是用算好的得分去对 V 进行加权求和。因此,V 本身只需要纯粹地提供它的内容即可,不需要再去扭曲它的向量空间。
打个比方:Q 是你想找什么样的人,K 是别人胸前的名牌,RoPE 决定了你们两人在会场里的物理距离。距离会影响你对这个人的“关注度”(打分),但一旦你决定听他说话,他嘴里说出的具体内容(V)是客观不变的,不需要因为你们的距离而重新编码。
RoPE 的核心在于对向量进行旋转操作,这需要依赖正弦 (Sine) 和余弦 (Cosine) 函数。为了提高训练效率,我们通常会在模型初始化时预先计算 (Precompute) 好这些角度频率,而不是每次前向传播都重新算一遍。
实际代码如下:
# 生成旋转矩阵所需的正弦和余弦值
def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6, rope_scaling: dict = None):
freqs, attn_factor = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)), 1.0
if rope_scaling is not None: # YaRN: f'(i) = f(i)((1-γ) + γ/s), where γ∈[0,1] is linear ramp
orig_max, factor, beta_fast, beta_slow, attn_factor = (
rope_scaling.get("original_max_position_embeddings", 2048), rope_scaling.get("factor", 16),
rope_scaling.get("beta_fast", 32.0), rope_scaling.get("beta_slow", 1.0), rope_scaling.get("attention_factor", 1.0)
)
if end / orig_max > 1.0:
inv_dim = lambda b: (dim * math.log(orig_max / (b * 2 * math.pi))) / (2 * math.log(rope_base))
low, high = max(math.floor(inv_dim(beta_fast)), 0), min(math.ceil(inv_dim(beta_slow)), dim // 2 - 1)
ramp = torch.clamp((torch.arange(dim // 2, device=freqs.device).float() - low) / max(high - low, 0.001), 0, 1)
freqs = freqs * (1 - ramp + ramp / factor)
t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs).float()
freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1) * attn_factor
freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1) * attn_factor
return freqs_cos, freqs_sin
# 执行旋转变换
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
def rotate_half(x): return torch.cat((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), dim=-1)
q_embed = ((q * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(q) * sin.unsqueeze(unsqueeze_dim))).to(q.dtype)
k_embed = ((k * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(k) * sin.unsqueeze(unsqueeze_dim))).to(k.dtype)
return q_embed, k_embed
3.1.2 因果注意力机制
第一步我们就需要把刚刚实现的RoPE工具添加到我们的注意力机制中,特别注意,我们只需要为q,k添加位置信息。
cos, sin = position_embeddings # 获取旋转位置嵌入
xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin) # 应用旋转位置嵌入
在传统的多头注意力(MHA)中,Query (Q)、Key (K) 和 Value (V) 的头数量是1:1:1完全相等的。每个 Q 头都有自己专属的 K 和 V 头。
但在分组查询注意力(GQA )中,为了节省显存和计算量,Query 的头数量是多于 Key/Value 的头数量的。相当于把 Query 分成了几个“小组”,每个小组内的 Q 头共享同一个 K 和 V 头。
回到代码中的 self.n_rep = self.n_local_heads // self.n_local_kv_heads:
-
比例关系: 这里的除法算出的就是 Q 头数量和 KV 头数量的倍数。比如模型配置了 32 个 Q 头,8 个 KV 头,那么
n_rep就是 4。意思是每 4 个 Q 头组成一个小组,共享 1 组 K 和 V。
在进行注意力分数计算时,矩阵的维度必须是对齐的。既然 K 和 V 的头数变少了,为了能和 Q 进行运算,我们就必须把 K 和 V 复制(Repeat)相应的倍数,让它们的头数量在计算瞬间“膨胀”到和 Q 一样多。代码最上方的 repeat_kv 函数就是专门做这件事的!
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
KV 头重复函数,将 KV 头的数量扩展到与 Query 头数量一致
"""
bs, slen, num_key_value_heads, head_dim = x.shape
if n_rep == 1: return x
return (x[:, :, :, None, :].expand(bs, slen, num_key_value_heads, n_rep, head_dim).reshape(bs, slen, num_key_value_heads * n_rep, head_dim))
为了加速推理,大模型会把历史已经计算过的词的 K 和 V 保存在显存里,这就是大名鼎鼎的 KV Cache(键值缓存)。
但随着上下文越来越长,这些保存下来的 K 和 V 张量会吃掉极大的显存空间。 这就是 GQA 发挥的地方: 因为 Key 和 Value 的头数量比 Query 少(比如 Q 有 32 个头,KV 只有 8 个头),我们需要保存在显存里的 KV Cache 体积直接被压缩了数倍! 这样不仅节省了显存,也加快了数据读取速度。所以我们需要完成KV Cache 的拼接与更新。
if past_key_value is not None: # 如果提供了有历史KV,则将其与当前计算的 KV 拼接起来
xk = torch.cat([past_key_value[0], xk], dim=1)
xv = torch.cat([past_key_value[1], xv], dim=1)
past_kv = (xk, xv) if use_cache else None # 将past_key_value中的KV拼接到当前KV中
在 PyTorch 2.0+ 环境下,官方提供了原生的 Flash Attention 接口torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)。这个函数不仅会在底层自动启用 Flash Attention 优化显存,is_causal=True 参数还会自动生成下因果注意力掩码,确保模型不会“偷看”到未来的 Token。
当然,仅仅靠Flash Attention是不安全的,我们还需要普通的Attention计算。以下是整个注意力模块的代码:
class Attention(nn.Module):
def __init__(self, config: PocketLLMConfig):
super().__init__()
self.num_key_value_heads = config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads # Key/Value 头数量可以独立于 Query 头数量配置,以支持 GQA
self.n_local_heads = config.num_attention_heads # 实际使用的 Query 头数量
self.n_local_kv_heads = self.num_key_value_heads # 实际使用的 Key/Value 头数量
self.n_rep = self.n_local_heads // self.n_local_kv_heads # 每个 Key/Value 头需要被重复多少次以匹配 Query 头数量,必须是整数,否则会在 forward 中报错
self.head_dim = config.head_dim # 每个注意力头的维度,通常是 hidden_size // num_attention_heads
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
self.is_causal = True
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.attn_dropout = nn.Dropout(config.dropout)
self.resid_dropout = nn.Dropout(config.dropout)
self.dropout = config.dropout
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and config.flash_attn
def forward(self, x, position_embeddings, past_key_value=None, use_cache=False, attention_mask=None):
bsz, seq_len, _ = x.shape # 批量大小,序列长度,嵌入维度
xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x) # 计算Q、K、V
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim) # 拆成多个注意力头
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
xq, xk = self.q_norm(xq), self.k_norm(xk) # 应用RMSNorm归一化
cos, sin = position_embeddings # 获取旋转位置嵌入
xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin) # 应用旋转位置嵌入
if past_key_value is not None: # 如果提供了有历史KV,则将其与当前计算的 KV 拼接起来
xk = torch.cat([past_key_value[0], xk], dim=1)
xv = torch.cat([past_key_value[1], xv], dim=1)
past_kv = (xk, xv) if use_cache else None # 将past_key_value中的KV拼接到当前KV中
xq, xk, xv = (xq.transpose(1, 2), repeat_kv(xk, self.n_rep).transpose(1, 2),
repeat_kv(xv, self.n_rep).transpose(1, 2)) # 调整维度以适配注意力计算,重复KV以匹配Query头数量(GQA)
if self.flash and (seq_len > 1) and (not self.is_causal or past_key_value is None) and (attention_mask is None or torch.all(attention_mask == 1)):
# 使用 PyTorch 的 Flash Attention 计算注意力输出,自动处理因果掩码和 dropout
output = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=self.dropout if self.training else 0.0, is_causal=self.is_causal)
else:
# 公式:(Q * K ^ T) / sqrt(d_k)
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim) # 计算注意力分数,缩放因子为头维度的平方根
if self.is_causal: scores[:, :, :, -seq_len:] += torch.full((seq_len, seq_len), float("-inf"), device=scores.device).triu(1) # 添加因果掩码 防止偷看未来
if attention_mask is not None: scores += (1.0 - attention_mask.unsqueeze(1).unsqueeze(2)) * -1e9 # 添加注意力掩码 让模型不关注无用的padding
output = self.attn_dropout(F.softmax(scores.float(), dim=-1).type_as(xq)) @ xv # 计算注意力权重并应用于V,得到注意力输出
output = output.transpose(1, 2).reshape(bsz, seq_len, -1) # 将多头注意力输出重新组合成原始的嵌入维度
output = self.resid_dropout(self.o_proj(output)) # 最后通过输出投影层,并应用残差连接前的 dropout
return output, past_kv
3.2 FFN模块
普通的前馈层的实现比较简单,通过一个门控机制和两个线性变换,包括一个可选的激活函数来实现,这里我们默认使用的是 SiLU 激活函数,相对于ReLU, SiLU ,优势在于:
-
更强的表达能力:通过增加一个额外的线性层,模型学习参数的自由度更高。
-
平滑性:SiLU 激活函数比 ReLU 更平滑,有助于梯度流动和模型收敛。
-
实践证明:在 Llama 等大模型中,这种结构表现出了明显的性能提升。
所以随着网络加深,ReLU 的硬截断会导致信息丢失严重,SiLU 的平滑特性有助于深层信号的传播。
以下是完整代码:
class FeedForward(nn.Module):
# 公式:FFN(x) = W_down * (act(W_gate * x) ⊙ W_up * x),其中 ⊙ 表示逐元素乘法,W_gate、W_up 和 W_down 分别是三个线性变换矩阵,act 是激活函数
def __init__(self, config: PocketLLMConfig, intermediate_size: int = None):
super().__init__()
intermediate_size = intermediate_size or config.intermediate_size
self.gate_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
self.down_proj = nn.Linear(intermediate_size, config.hidden_size, bias=False)
self.up_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
3.2 MoE模块
我们的架构的一大亮点是引入了 MoE(混合专家网络)。在 MoE 中,我们将一个庞大的 FFN 替换为多个小型的 FFN(即“专家”),并且需要一个“路由器(Router)”来决定当前的 Token 应该交给哪几个专家处理。
我们来拆解一下路由器 的工作流:
-
输入:每个 Token 此时都是一个特征向量(维度大小为
hidden_dim)。 -
打分:我们用一个
nn.Linear(hidden_dim, num_experts)对这个向量进行映射,直接输出该 Token 对应每个专家的原始得分(Logits)。 -
概率化:将这些得分通过
Softmax函数,转换成加和为 1 的概率值。
实现代码非常简单
class MoERouter(nn.Module):
def __init__(self, hidden_dim: int, num_experts: int):
super().__init__()
# 路由器核心:将 Token 维度映射到专家数量维度
self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
def forward(self, x: torch.Tensor):
# x 形状: [batch_size, seq_len, hidden_dim]
logits = self.gate(x)
# 将原始打分转化为路由概率
routing_probs = F.softmax(logits, dim=-1)
return routing_probs, logits
为了实现真正的“轻量化”和“稀疏计算”,我们不能让一个 Token 把所有专家都跑一遍。通常,我们只会选择概率最高的前 K 个专家(比如在你的项目中,可能每次只激活 2 个专家)。
基于我们刚刚算出的 routing_probs(或者 logits),我们需要挑出得分最高的前 2 个值以及它们对应的专家编号(索引)。
既然我们有了路由器,下一步就是构建“专家”本身,并将它们组合起来。每个专家其实就是一个前馈神经网络 (FFN)。
for i, expert in enumerate(self.experts):
mask = (topk_idx == i)
if mask.any():
token_idx = mask.any(dim=-1).nonzero().flatten()
weight = topk_weight[mask].view(-1, 1)
y.index_add_(0, token_idx, (expert(x_flat[token_idx]) * weight).to(y.dtype))
elif self.training:
y[0, 0] += 0 * sum(p.sum() for p in expert.parameters())
MoE 的“马太效应”:
在训练初期,路由器的权重是随机初始化的。假设某个专家(比如“专家 A”)运气好,稍微比其他专家表现得好一点点。
-
路由器为了降低当前的预测误差,就会倾向于把更多的 Token 扔给“专家 A”。
-
“专家 A”得到的训练数据越多,它的参数更新就越充分,能力就越强。
-
下一次路由时,路由器发现“专家 A”更靠谱了,于是把所有的 Token 都给了它。
最终结果就是:旱的旱死,涝的涝死。专家 A 过拟合了,而其他专家根本没有被激活过,相当于参数被白白浪费了,这就违背了我们引入 MoE 来增加模型容量的初衷。
为了防止路由器偷懒,我们在总的损失函数(Loss)旁边,额外加上一个负载均衡损失。它的核心思想是:惩罚那些把 Token 集中分配给少数专家的行为。
我们可以单独来计算这个 Loss。
if self.training and self.config.router_aux_loss_coef > 0:
load = F.one_hot(topk_idx, self.config.num_experts).float().mean(0)
self.aux_loss = (load * scores.mean(0)).sum() * self.config.num_experts * self.config.router_aux_loss_coef
else:
self.aux_loss = scores.new_zeros(1).squeeze()
以下是整体代码:
# MOE专家类
class MOEFeedForward(nn.Module):
def __init__(self, config: PocketLLMConfig):
super().__init__()
self.config = config
self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
self.experts = nn.ModuleList([FeedForward(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.num_experts)])
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
batch_size, seq_len, hidden_dim = x.shape
x_flat = x.view(-1, hidden_dim) # 将输入展平为 (batch_size * seq_len, hidden_dim),以便对每个 token 进行独立的专家路由和计算
scores = F.softmax(self.gate(x_flat), dim=-1) # 计算每个 token 路由到各个专家的概率分布,得到 (batch_size * seq_len, num_experts) 的分数矩阵
topk_weight, topk_idx = torch.topk(scores, k=self.config.num_experts_per_tok, dim=-1, sorted=False) # 对每个 token 选择 top-k 个专家,得到对应的权重和索引,形状为 (batch_size * seq_len, num_experts_per_tok)
# 可选地对 top-k 权重进行重新归一化,使其和为 1,增加数值稳定性
if self.config.norm_topk_prob: topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)
y = torch.zeros_like(x_flat)
# 遍历每个专家,检查哪些 token 被路由到该专家,并将其计算结果加权累加到输出 y 中
for i, expert in enumerate(self.experts):
mask = (topk_idx == i)
if mask.any():
token_idx = mask.any(dim=-1).nonzero().flatten()
weight = topk_weight[mask].view(-1, 1)
y.index_add_(0, token_idx, (expert(x_flat[token_idx]) * weight).to(y.dtype))
elif self.training:
y[0, 0] += 0 * sum(p.sum() for p in expert.parameters())
# 计算路由均衡的辅助损失,鼓励模型在专家之间分配负载,防止某些专家过载而其他专家闲置
if self.training and self.config.router_aux_loss_coef > 0:
load = F.one_hot(topk_idx, self.config.num_experts).float().mean(0)
self.aux_loss = (load * scores.mean(0)).sum() * self.config.num_experts * self.config.router_aux_loss_coef
else:
self.aux_loss = scores.new_zeros(1).squeeze()
return y.view(batch_size, seq_len, hidden_dim)
3.3 TransformerBlock
在早期的 Transformer(Post-Norm 架构)中,残差连接是这样做的:输出 = Norm(输入 + 子层(输入))。这就导致信号在通过每一层后都会被重新归一化,网络一深,靠近底层的梯度就很容易消失或者爆炸,通常需要很小心的“学习率预热 (Warm-up)”才能训练起来。
而现代大模型普遍采用的 Pre-Norm 架构是这样的:输出 = 输入 + 子层(Norm(输入))。 主干道上的“输入”信号没有经过任何阻拦(不被 Norm 截断),直接一路流向深层。这就好比修建了一条畅通无阻的高速公路,梯度回传非常顺畅,极大地提升了训练的稳定性。
同时,为了极致的计算效率,我们还会把传统的 LayerNorm 替换为 RMSNorm (均方根归一化)。它去掉了均值计算,只计算方差,效果一样好但跑得更快。
class PocketLLMBlock(nn.Module):
def __init__(self, layer_id: int, config: PocketLLMConfig):
super().__init__()
self.self_attn = Attention(config)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) # 在注意力输入前添加一个 LayerNorm
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) # 在注意力输出后添加一个 LayerNorm
self.mlp = FeedForward(config) if not config.use_moe else MOEFeedForward(config) # 根据配置选择使用普通前馈层还是 MoE 前馈层
def forward(self, hidden_states, position_embeddings, past_key_value=None, use_cache=False, attention_mask=None):
residual = hidden_states
# 执行自注意力计算,输入经过 LayerNorm 归一化,并传入位置嵌入、历史 KV 和注意力掩码等信息,得到新的隐藏状态和当前的 KV 用于缓存
hidden_states, present_key_value = self.self_attn(
self.input_layernorm(hidden_states), position_embeddings,
past_key_value, use_cache, attention_mask
)
hidden_states += residual # 添加残差连接,将注意力输出与输入相加
hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states)) # 经过另一个 LayerNorm 归一化后,输入前馈层计算,并添加残差连接
return hidden_states, present_key_value
3.5 模型总体架构
总模型 PocketLLM 的职责就像是一个车间主任:它负责接收最初的输入(Token ID),将其转化为向量,然后依次让这些向量穿过所有的 Transformer 层,最后通过一个线性分类器(LM Head)预测出词表中每个词的概率。同时,它还要负责收集各个 MoE 层的“路由打分”,以便计算我们刚才提到的负载均衡损失。
# PocketLLM 主干模型
class PocketLLMModel(nn.Module):
def __init__(self, config: PocketLLMConfig):
super().__init__()
self.config = config
self.vocab_size, self.num_hidden_layers = config.vocab_size, config.num_hidden_layers
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.dropout = nn.Dropout(config.dropout)
self.layers = nn.ModuleList([PocketLLMBlock(l, config) for l in range(self.num_hidden_layers)])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
freqs_cos, freqs_sin = precompute_freqs_cis(dim=config.head_dim, end=config.max_position_embeddings, rope_base=config.rope_theta, rope_scaling=config.rope_scaling)
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
self.register_buffer("freqs_sin", freqs_sin, persistent=False)
def forward(self, input_ids, attention_mask=None, past_key_values=None, use_cache=False, **kwargs):
batch_size, seq_length = input_ids.shape
if hasattr(past_key_values, 'layers'): past_key_values = None
past_key_values = past_key_values or [None] * len(self.layers) # 初始化 past_key_values 为 None,用于存储历史 KV
start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0 # 计算起始位置,即历史 KV 中最后一个 token 的位置
hidden_states = self.dropout(self.embed_tokens(input_ids)) # 输入经过嵌入层,得到隐藏状态
position_embeddings = (self.freqs_cos[start_pos:start_pos + seq_length], self.freqs_sin[start_pos:start_pos + seq_length]) # 获取位置嵌入,用于计算 Rope
presents = [] # 初始化 presents 列表,用于存储每个 layer 的当前 KV
# 遍历每个 layer,执行自注意力计算和前馈计算,并添加残差连接
for layer, past_key_value in zip(self.layers, past_key_values):
hidden_states, present = layer(
hidden_states,
position_embeddings,
past_key_value=past_key_value,
use_cache=use_cache,
attention_mask=attention_mask
)
presents.append(present)
hidden_states = self.norm(hidden_states) # 添加 LayerNorm 归一化
aux_loss = sum([l.mlp.aux_loss for l in self.layers if isinstance(l.mlp, MOEFeedForward)], hidden_states.new_zeros(1).squeeze()) # 计算 MoE 的辅助损失
return hidden_states, presents, aux_loss # 返回隐藏状态、当前 KV 和 MoE 的辅助损失
讲解一些其中的我自己的理解:
1. start_pos
这个字段的意义是通过检查 KV Cache 的长度,确定当前输入在序列中的偏移量。
我们知道,大模型生成文本是一个 Token 一个 Token 往后跳的。
-
没有
start_pos时:如果你已经写了 100 个字,生成第 101 个字时,模型需要把前 100 个字重新算一遍,算力浪费随字数增加呈几何倍增长。 -
有了
start_pos:模型会查一下缓存(past_key_values)。如果缓存里已经存了 100 个向量,start_pos就是 100。模型这次只需要计算当前的第 101 个词,然后把它拼接到旧缓存后面。
特别是对于旋转位置编码 (RoPE) 的精确对齐是非常重要的。
位置编码必须是绝对唯一的。如果模型正在生成第 101 个词,它必须使用“位置 101”对应的正弦/余弦频率。如果没有 start_pos 指明当前偏移量,模型会错误地从“位置 0”开始计算编码,导致模型分不清“我是谁”和“谁是我”(因为词语的顺序乱了)。
2. 为什么将离散的 input_ids 映射为稠密向量 hidden_states
input_ids 本质上是整数索引(例如:[101, 234, 55]),就像是查字典时的页码。如果不做映射(Embedding),会有以下三个致命问题:
A. 语义空间缺失(语义相关性)
-
离散编码(One-hot):在数字层面,"猫" (ID: 10) 和 "狗" (ID: 11) 的距离,跟 "猫" (ID: 10) 和 "手机" (ID: 999) 的距离没什么区别。数字大小不代表语义远近。
-
稠密向量(Embedding):通过映射,"猫" 和 "狗" 在高维空间中的**向量距离(余弦相似度)**会非常近,而与 "手机" 非常远。这种“位置即语义”的特性让神经网络能够理解词与词之间的联系。
B. 维度的“死亡”与“重生”
-
如果使用 One-hot 编码,词表有 10 万个词,每个词就要用 10 万维的向量表示,其中 99,999 维都是 0。这极其浪费显存,且无法进行有效的矩阵运算。
-
稠密向量通常只有 4096 维(Llama 2-7B 级别),每个维度都是一个浮点数。这不仅压缩了空间,还让每个维度都能承载信息(比如某一维可能暗含“词性”,另一维暗含“情感”)。
C. 梯度回传与可学习性
-
整数(ID)是不可导的。神经网络通过反向传播来学习,需要所有参数都是连续、可微的。
-
映射后的
hidden_states是浮点数矩阵。在训练过程中,模型可以根据损失函数微调这些向量的数值。随着训练进行,原本随机的向量会逐渐在空间中自动归类,形成逻辑严密的知识图谱。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)