一 整体方案

整个项目的实现路径可以分为以下几个关键阶段:

  1. 架构基座 (Architecture) :手写基于 PyTorch 的 Decoder-only Transformer,并集成 Flash Attention 以优化显存和计算效率。

  2. 专家引入 (MoE) :在基础前馈网络中加入混合专家路由(Routing)机制,并设计负载均衡损失函数(Load Balancing Loss)。

二 模型结构

相比于传统的 Transformer 架构,因为我们要完成的是生成式语言模型,去掉Encoder能够让模型更有性价比,只需要让模型关注下一词的生成,所以我们采用的是Llama-style Decoder-only 架构设计。

三 具体实现

3.1 注意力模块

我们将在这个阶段重点打磨两个核心组件:

  • 旋转位置编码 (RoPE) :抛弃传统的绝对位置编码,让模型更好地理解 Token 之间的相对距离。

  • 因果注意力机制 (Causal Attention) :实现带掩码 (Mask) 的自注意力,并为后续接入 Flash Attention 做好张量维度的准备

3.1.1 旋转位置编码

在我们动手编写 Attention 类的代码之前,先来看看位置信息是如何注入的。传统的 Transformer 会在最开始把位置编码直接加到输入的词向量 (Embedding) 上。但是,RoPE (Rotary Position Embedding) 改变了这个做法,它是在 Attention 机制内部,且在点积计算发生之前应用的。

回顾一下标准的注意力计算核心公式:

Attention(Q,K,V)=softmax(\frac{QK^{T}}{\sqrt{d_{k}}})V

  • Q (Query) 和 K (Key) 负责“匹配”与“打分”:在 Transformer 中,Q 和 K 的点积用于计算两个 Token 之间的相关性得分(Attention Score)。我们希望模型能知道“这两个词离得多远”,所以给 Q 和 K 加上 RoPE 旋转。旋转后,Q 和 K 的点积结果会自然地带上它们之间的相对位置信息

  • V (Value) 负责“提供内容”:V 包含了 Token 的实际语义。前面的 $softmax$ 步骤算出的 Attention Score 已经包含了位置信息(即“我应该放多少注意力在这个 Token 上”)。最后这一步只是用算好的得分去对 V 进行加权求和。因此,V 本身只需要纯粹地提供它的内容即可,不需要再去扭曲它的向量空间。

打个比方:Q 是你想找什么样的人,K 是别人胸前的名牌,RoPE 决定了你们两人在会场里的物理距离。距离会影响你对这个人的“关注度”(打分),但一旦你决定听他说话,他嘴里说出的具体内容(V)是客观不变的,不需要因为你们的距离而重新编码。

RoPE 的核心在于对向量进行旋转操作,这需要依赖正弦 (Sine) 和余弦 (Cosine) 函数。为了提高训练效率,我们通常会在模型初始化时预先计算 (Precompute) 好这些角度频率,而不是每次前向传播都重新算一遍。

实际代码如下:

# 生成旋转矩阵所需的正弦和余弦值
def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6, rope_scaling: dict = None):
    freqs, attn_factor = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)), 1.0
    if rope_scaling is not None: # YaRN: f'(i) = f(i)((1-γ) + γ/s), where γ∈[0,1] is linear ramp
        orig_max, factor, beta_fast, beta_slow, attn_factor = (
            rope_scaling.get("original_max_position_embeddings", 2048), rope_scaling.get("factor", 16),
            rope_scaling.get("beta_fast", 32.0), rope_scaling.get("beta_slow", 1.0), rope_scaling.get("attention_factor", 1.0)
        )
        if end / orig_max > 1.0:
            inv_dim = lambda b: (dim * math.log(orig_max / (b * 2 * math.pi))) / (2 * math.log(rope_base))
            low, high = max(math.floor(inv_dim(beta_fast)), 0), min(math.ceil(inv_dim(beta_slow)), dim // 2 - 1)
            ramp = torch.clamp((torch.arange(dim // 2, device=freqs.device).float() - low) / max(high - low, 0.001), 0, 1)
            freqs = freqs * (1 - ramp + ramp / factor)
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1) * attn_factor
    freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1) * attn_factor
    return freqs_cos, freqs_sin

# 执行旋转变换
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    def rotate_half(x): return torch.cat((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), dim=-1)
    q_embed = ((q * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(q) * sin.unsqueeze(unsqueeze_dim))).to(q.dtype)
    k_embed = ((k * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(k) * sin.unsqueeze(unsqueeze_dim))).to(k.dtype)
    return q_embed, k_embed

3.1.2 因果注意力机制

第一步我们就需要把刚刚实现的RoPE工具添加到我们的注意力机制中,特别注意,我们只需要为q,k添加位置信息。

cos, sin = position_embeddings  # 获取旋转位置嵌入
xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin)  # 应用旋转位置嵌入

在传统的多头注意力(MHA)中,Query (Q)、Key (K) 和 Value (V) 的头数量是1:1:1完全相等的。每个 Q 头都有自己专属的 K 和 V 头。

但在分组查询注意力(GQA )中,为了节省显存和计算量,Query 的头数量是多于 Key/Value 的头数量的。相当于把 Query 分成了几个“小组”,每个小组内的 Q 头共享同一个 K 和 V 头。

回到代码中的 self.n_rep = self.n_local_heads // self.n_local_kv_heads

  • 比例关系: 这里的除法算出的就是 Q 头数量和 KV 头数量的倍数。比如模型配置了 32 个 Q 头,8 个 KV 头,那么 n_rep 就是 4。意思是每 4 个 Q 头组成一个小组,共享 1 组 K 和 V。

在进行注意力分数计算时,矩阵的维度必须是对齐的。既然 K 和 V 的头数变少了,为了能和 Q 进行运算,我们就必须把 K 和 V 复制(Repeat)相应的倍数,让它们的头数量在计算瞬间“膨胀”到和 Q 一样多。代码最上方的 repeat_kv 函数就是专门做这件事的!

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
        KV 头重复函数,将 KV 头的数量扩展到与 Query 头数量一致
    """
    bs, slen, num_key_value_heads, head_dim = x.shape
    if n_rep == 1: return x
    return (x[:, :, :, None, :].expand(bs, slen, num_key_value_heads, n_rep, head_dim).reshape(bs, slen, num_key_value_heads * n_rep, head_dim))

为了加速推理,大模型会把历史已经计算过的词的 K 和 V 保存在显存里,这就是大名鼎鼎的 KV Cache(键值缓存)

但随着上下文越来越长,这些保存下来的 K 和 V 张量会吃掉极大的显存空间。 这就是 GQA 发挥的地方: 因为 Key 和 Value 的头数量比 Query 少(比如 Q 有 32 个头,KV 只有 8 个头),我们需要保存在显存里的 KV Cache 体积直接被压缩了数倍! 这样不仅节省了显存,也加快了数据读取速度。所以我们需要完成KV Cache 的拼接与更新。

if past_key_value is not None:  # 如果提供了有历史KV,则将其与当前计算的 KV 拼接起来
    xk = torch.cat([past_key_value[0], xk], dim=1) 
    xv = torch.cat([past_key_value[1], xv], dim=1)
past_kv = (xk, xv) if use_cache else None  # 将past_key_value中的KV拼接到当前KV中

在 PyTorch 2.0+ 环境下,官方提供了原生的 Flash Attention 接口torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)。这个函数不仅会在底层自动启用 Flash Attention 优化显存,is_causal=True 参数还会自动生成下因果注意力掩码,确保模型不会“偷看”到未来的 Token。

当然,仅仅靠Flash Attention是不安全的,我们还需要普通的Attention计算。以下是整个注意力模块的代码:

class Attention(nn.Module):
    def __init__(self, config: PocketLLMConfig):
        super().__init__()
        self.num_key_value_heads = config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads # Key/Value 头数量可以独立于 Query 头数量配置,以支持 GQA
        self.n_local_heads = config.num_attention_heads # 实际使用的 Query 头数量
        self.n_local_kv_heads = self.num_key_value_heads # 实际使用的 Key/Value 头数量
        self.n_rep = self.n_local_heads // self.n_local_kv_heads # 每个 Key/Value 头需要被重复多少次以匹配 Query 头数量,必须是整数,否则会在 forward 中报错
        self.head_dim = config.head_dim # 每个注意力头的维度,通常是 hidden_size // num_attention_heads
        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
        self.is_causal = True
        self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.dropout = config.dropout
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and config.flash_attn

    def forward(self, x, position_embeddings, past_key_value=None, use_cache=False, attention_mask=None):
        bsz, seq_len, _ = x.shape  # 批量大小,序列长度,嵌入维度
        xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)  # 计算Q、K、V

        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)  # 拆成多个注意力头
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)

        xq, xk = self.q_norm(xq), self.k_norm(xk)  # 应用RMSNorm归一化
        cos, sin = position_embeddings  # 获取旋转位置嵌入
        xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin)  # 应用旋转位置嵌入

        if past_key_value is not None:  # 如果提供了有历史KV,则将其与当前计算的 KV 拼接起来
            xk = torch.cat([past_key_value[0], xk], dim=1)
            xv = torch.cat([past_key_value[1], xv], dim=1)
        past_kv = (xk, xv) if use_cache else None  # 将past_key_value中的KV拼接到当前KV中
        xq, xk, xv = (xq.transpose(1, 2), repeat_kv(xk, self.n_rep).transpose(1, 2),
                      repeat_kv(xv, self.n_rep).transpose(1, 2))  # 调整维度以适配注意力计算,重复KV以匹配Query头数量(GQA)

        if self.flash and (seq_len > 1) and (not self.is_causal or past_key_value is None) and (attention_mask is None or torch.all(attention_mask == 1)):
            # 使用 PyTorch 的 Flash Attention 计算注意力输出,自动处理因果掩码和 dropout
            output = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=self.dropout if self.training else 0.0, is_causal=self.is_causal)
        else:
            # 公式:(Q * K ^ T) / sqrt(d_k)
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)  # 计算注意力分数,缩放因子为头维度的平方根
            if self.is_causal: scores[:, :, :, -seq_len:] += torch.full((seq_len, seq_len), float("-inf"), device=scores.device).triu(1) # 添加因果掩码 防止偷看未来
            if attention_mask is not None: scores += (1.0 - attention_mask.unsqueeze(1).unsqueeze(2)) * -1e9 # 添加注意力掩码 让模型不关注无用的padding
            output = self.attn_dropout(F.softmax(scores.float(), dim=-1).type_as(xq)) @ xv # 计算注意力权重并应用于V,得到注意力输出

        output = output.transpose(1, 2).reshape(bsz, seq_len, -1) # 将多头注意力输出重新组合成原始的嵌入维度
        output = self.resid_dropout(self.o_proj(output)) # 最后通过输出投影层,并应用残差连接前的 dropout
        return output, past_kv

3.2 FFN模块

普通的前馈层的实现比较简单,通过一个门控机制和两个线性变换,包括一个可选的激活函数来实现,这里我们默认使用的是 SiLU 激活函数,相对于ReLU, SiLU ,优势在于

  • 更强的表达能力:通过增加一个额外的线性层,模型学习参数的自由度更高。

  • 平滑性:SiLU 激活函数比 ReLU 更平滑,有助于梯度流动和模型收敛。

  • 实践证明:在 Llama 等大模型中,这种结构表现出了明显的性能提升。

所以随着网络加深,ReLU 的硬截断会导致信息丢失严重,SiLU 的平滑特性有助于深层信号的传播。

以下是完整代码:

class FeedForward(nn.Module):
# 公式:FFN(x) = W_down * (act(W_gate * x) ⊙ W_up * x),其中 ⊙ 表示逐元素乘法,W_gate、W_up 和 W_down 分别是三个线性变换矩阵,act 是激活函数
    def __init__(self, config: PocketLLMConfig, intermediate_size: int = None):
        super().__init__()
        intermediate_size = intermediate_size or config.intermediate_size 
        self.gate_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, config.hidden_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, intermediate_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

3.2 MoE模块

我们的架构的一大亮点是引入了 MoE(混合专家网络)。在 MoE 中,我们将一个庞大的 FFN 替换为多个小型的 FFN(即“专家”),并且需要一个“路由器(Router)”来决定当前的 Token 应该交给哪几个专家处理。

我们来拆解一下路由器  的工作流:

  1. 输入:每个 Token 此时都是一个特征向量(维度大小为 hidden_dim)。

  2. 打分:我们用一个 nn.Linear(hidden_dim, num_experts) 对这个向量进行映射,直接输出该 Token 对应每个专家的原始得分(Logits)。

  3. 概率化:将这些得分通过 Softmax 函数,转换成加和为 1 的概率值。

实现代码非常简单

class MoERouter(nn.Module):
    def __init__(self, hidden_dim: int, num_experts: int):
        super().__init__()
        # 路由器核心:将 Token 维度映射到专家数量维度
        self.gate = nn.Linear(hidden_dim, num_experts, bias=False)

    def forward(self, x: torch.Tensor):
        # x 形状: [batch_size, seq_len, hidden_dim]
        logits = self.gate(x) 
        
        # 将原始打分转化为路由概率
        routing_probs = F.softmax(logits, dim=-1) 
        return routing_probs, logits

为了实现真正的“轻量化”和“稀疏计算”,我们不能让一个 Token 把所有专家都跑一遍。通常,我们只会选择概率最高的前 K 个专家(比如在你的项目中,可能每次只激活 2 个专家)。

基于我们刚刚算出的 routing_probs(或者 logits),我们需要挑出得分最高的前 2 个值以及它们对应的专家编号(索引)。

既然我们有了路由器,下一步就是构建“专家”本身,并将它们组合起来。每个专家其实就是一个前馈神经网络 (FFN)。

for i, expert in enumerate(self.experts):
    mask = (topk_idx == i)
    if mask.any():
        token_idx = mask.any(dim=-1).nonzero().flatten()
        weight = topk_weight[mask].view(-1, 1)
        y.index_add_(0, token_idx, (expert(x_flat[token_idx]) * weight).to(y.dtype))
    elif self.training:
        y[0, 0] += 0 * sum(p.sum() for p in expert.parameters())

 MoE 的“马太效应”:

在训练初期,路由器的权重是随机初始化的。假设某个专家(比如“专家 A”)运气好,稍微比其他专家表现得好一点点。

  • 路由器为了降低当前的预测误差,就会倾向于把更多的 Token 扔给“专家 A”。

  • “专家 A”得到的训练数据越多,它的参数更新就越充分,能力就越强。

  • 下一次路由时,路由器发现“专家 A”更靠谱了,于是把所有的 Token 都给了它。

最终结果就是:旱的旱死,涝的涝死。专家 A 过拟合了,而其他专家根本没有被激活过,相当于参数被白白浪费了,这就违背了我们引入 MoE 来增加模型容量的初衷。

为了防止路由器偷懒,我们在总的损失函数(Loss)旁边,额外加上一个负载均衡损失。它的核心思想是:惩罚那些把 Token 集中分配给少数专家的行为。

我们可以单独来计算这个 Loss。

if self.training and self.config.router_aux_loss_coef > 0:
    load = F.one_hot(topk_idx, self.config.num_experts).float().mean(0)
    self.aux_loss = (load * scores.mean(0)).sum() * self.config.num_experts * self.config.router_aux_loss_coef
else:
    self.aux_loss = scores.new_zeros(1).squeeze()

以下是整体代码:

# MOE专家类
class MOEFeedForward(nn.Module):
    def __init__(self, config: PocketLLMConfig):
        super().__init__()
        self.config = config
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.experts = nn.ModuleList([FeedForward(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.num_experts)])
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        batch_size, seq_len, hidden_dim = x.shape
        x_flat = x.view(-1, hidden_dim) # 将输入展平为 (batch_size * seq_len, hidden_dim),以便对每个 token 进行独立的专家路由和计算
        scores = F.softmax(self.gate(x_flat), dim=-1) # 计算每个 token 路由到各个专家的概率分布,得到 (batch_size * seq_len, num_experts) 的分数矩阵
        topk_weight, topk_idx = torch.topk(scores, k=self.config.num_experts_per_tok, dim=-1, sorted=False) # 对每个 token 选择 top-k 个专家,得到对应的权重和索引,形状为 (batch_size * seq_len, num_experts_per_tok)

        # 可选地对 top-k 权重进行重新归一化,使其和为 1,增加数值稳定性
        if self.config.norm_topk_prob: topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20) 
        y = torch.zeros_like(x_flat)

        # 遍历每个专家,检查哪些 token 被路由到该专家,并将其计算结果加权累加到输出 y 中
        for i, expert in enumerate(self.experts):
            mask = (topk_idx == i)
            if mask.any():
                token_idx = mask.any(dim=-1).nonzero().flatten()
                weight = topk_weight[mask].view(-1, 1)
                y.index_add_(0, token_idx, (expert(x_flat[token_idx]) * weight).to(y.dtype))
            elif self.training:
                y[0, 0] += 0 * sum(p.sum() for p in expert.parameters())

        # 计算路由均衡的辅助损失,鼓励模型在专家之间分配负载,防止某些专家过载而其他专家闲置
        if self.training and self.config.router_aux_loss_coef > 0:
            load = F.one_hot(topk_idx, self.config.num_experts).float().mean(0)
            self.aux_loss = (load * scores.mean(0)).sum() * self.config.num_experts * self.config.router_aux_loss_coef
        else:
            self.aux_loss = scores.new_zeros(1).squeeze()

        return y.view(batch_size, seq_len, hidden_dim)

3.3 TransformerBlock

在早期的 Transformer(Post-Norm 架构)中,残差连接是这样做的:输出 = Norm(输入 + 子层(输入))。这就导致信号在通过每一层后都会被重新归一化,网络一深,靠近底层的梯度就很容易消失或者爆炸,通常需要很小心的“学习率预热 (Warm-up)”才能训练起来。

而现代大模型普遍采用的 Pre-Norm 架构是这样的:输出 = 输入 + 子层(Norm(输入))。 主干道上的“输入”信号没有经过任何阻拦(不被 Norm 截断),直接一路流向深层。这就好比修建了一条畅通无阻的高速公路,梯度回传非常顺畅,极大地提升了训练的稳定性。

同时,为了极致的计算效率,我们还会把传统的 LayerNorm 替换为 RMSNorm (均方根归一化)。它去掉了均值计算,只计算方差,效果一样好但跑得更快。

class PocketLLMBlock(nn.Module):
    def __init__(self, layer_id: int, config: PocketLLMConfig):
        super().__init__()
        self.self_attn = Attention(config) 
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) # 在注意力输入前添加一个 LayerNorm
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) # 在注意力输出后添加一个 LayerNorm
        self.mlp = FeedForward(config) if not config.use_moe else MOEFeedForward(config) # 根据配置选择使用普通前馈层还是 MoE 前馈层

    def forward(self, hidden_states, position_embeddings, past_key_value=None, use_cache=False, attention_mask=None):
        residual = hidden_states
        # 执行自注意力计算,输入经过 LayerNorm 归一化,并传入位置嵌入、历史 KV 和注意力掩码等信息,得到新的隐藏状态和当前的 KV 用于缓存
        hidden_states, present_key_value = self.self_attn(
            self.input_layernorm(hidden_states), position_embeddings,
            past_key_value, use_cache, attention_mask
        )
        hidden_states += residual # 添加残差连接,将注意力输出与输入相加
        hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states)) # 经过另一个 LayerNorm 归一化后,输入前馈层计算,并添加残差连接
        return hidden_states, present_key_value

3.5 模型总体架构

总模型 PocketLLM 的职责就像是一个车间主任:它负责接收最初的输入(Token ID),将其转化为向量,然后依次让这些向量穿过所有的 Transformer 层,最后通过一个线性分类器(LM Head)预测出词表中每个词的概率。同时,它还要负责收集各个 MoE 层的“路由打分”,以便计算我们刚才提到的负载均衡损失。

# PocketLLM 主干模型
class PocketLLMModel(nn.Module):
    def __init__(self, config: PocketLLMConfig):
        super().__init__()
        self.config = config
        self.vocab_size, self.num_hidden_layers = config.vocab_size, config.num_hidden_layers
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)
        self.layers = nn.ModuleList([PocketLLMBlock(l, config) for l in range(self.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        freqs_cos, freqs_sin = precompute_freqs_cis(dim=config.head_dim, end=config.max_position_embeddings, rope_base=config.rope_theta, rope_scaling=config.rope_scaling)
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(self, input_ids, attention_mask=None, past_key_values=None, use_cache=False, **kwargs):
        batch_size, seq_length = input_ids.shape
        if hasattr(past_key_values, 'layers'): past_key_values = None
        past_key_values = past_key_values or [None] * len(self.layers) # 初始化 past_key_values 为 None,用于存储历史 KV
        start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0 # 计算起始位置,即历史 KV 中最后一个 token 的位置
        hidden_states = self.dropout(self.embed_tokens(input_ids)) # 输入经过嵌入层,得到隐藏状态
        position_embeddings = (self.freqs_cos[start_pos:start_pos + seq_length], self.freqs_sin[start_pos:start_pos + seq_length]) # 获取位置嵌入,用于计算 Rope
        presents = [] # 初始化 presents 列表,用于存储每个 layer 的当前 KV
        # 遍历每个 layer,执行自注意力计算和前馈计算,并添加残差连接
        for layer, past_key_value in zip(self.layers, past_key_values):
            hidden_states, present = layer(
                hidden_states,
                position_embeddings,
                past_key_value=past_key_value,
                use_cache=use_cache,
                attention_mask=attention_mask
            )
            presents.append(present)
        hidden_states = self.norm(hidden_states) # 添加 LayerNorm 归一化
        aux_loss = sum([l.mlp.aux_loss for l in self.layers if isinstance(l.mlp, MOEFeedForward)], hidden_states.new_zeros(1).squeeze()) # 计算 MoE 的辅助损失
        return hidden_states, presents, aux_loss # 返回隐藏状态、当前 KV 和 MoE 的辅助损失

讲解一些其中的我自己的理解:

1. start_pos

这个字段的意义是通过检查 KV Cache 的长度,确定当前输入在序列中的偏移量。

我们知道,大模型生成文本是一个 Token 一个 Token 往后跳的。

  • 没有 start_pos:如果你已经写了 100 个字,生成第 101 个字时,模型需要把前 100 个字重新算一遍,算力浪费随字数增加呈几何倍增长。

  • 有了 start_pos:模型会查一下缓存(past_key_values)。如果缓存里已经存了 100 个向量,start_pos 就是 100。模型这次只需要计算当前的第 101 个词,然后把它拼接到旧缓存后面。

特别是对于旋转位置编码 (RoPE) 的精确对齐是非常重要的。

位置编码必须是绝对唯一的。如果模型正在生成第 101 个词,它必须使用“位置 101”对应的正弦/余弦频率。如果没有 start_pos 指明当前偏移量,模型会错误地从“位置 0”开始计算编码,导致模型分不清“我是谁”和“谁是我”(因为词语的顺序乱了)。

2. 为什么将离散的 input_ids 映射为稠密向量 hidden_states

input_ids 本质上是整数索引(例如:[101, 234, 55]),就像是查字典时的页码。如果不做映射(Embedding),会有以下三个致命问题:

A. 语义空间缺失(语义相关性)

  • 离散编码(One-hot):在数字层面,"猫" (ID: 10) 和 "狗" (ID: 11) 的距离,跟 "猫" (ID: 10) 和 "手机" (ID: 999) 的距离没什么区别。数字大小不代表语义远近。

  • 稠密向量(Embedding):通过映射,"猫" 和 "狗" 在高维空间中的**向量距离(余弦相似度)**会非常近,而与 "手机" 非常远。这种“位置即语义”的特性让神经网络能够理解词与词之间的联系。

B. 维度的“死亡”与“重生”

  • 如果使用 One-hot 编码,词表有 10 万个词,每个词就要用 10 万维的向量表示,其中 99,999 维都是 0。这极其浪费显存,且无法进行有效的矩阵运算。

  • 稠密向量通常只有 4096 维(Llama 2-7B 级别),每个维度都是一个浮点数。这不仅压缩了空间,还让每个维度都能承载信息(比如某一维可能暗含“词性”,另一维暗含“情感”)。

C. 梯度回传与可学习性

  • 整数(ID)是不可导的。神经网络通过反向传播来学习,需要所有参数都是连续、可微的。

  • 映射后的 hidden_states 是浮点数矩阵。在训练过程中,模型可以根据损失函数微调这些向量的数值。随着训练进行,原本随机的向量会逐渐在空间中自动归类,形成逻辑严密的知识图谱。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐