本文参考以下英文教程撰写:https://pub.towardsai.net/build-your-own-llama-3-architecture-from-scratch-using-pytorch-2ce1ecaa901c

第一次看到有人把 Llama 3 从零实现一遍,我就知道这件事值得认真做一次。因为只有真正写出来,才能体会到每一个设计选择背后的逻辑——为什么 Norm 放在前面而不是后面,为什么 KV Cache 只缓存 K 和 V 而不包括 Q,为什么 RoPE 要转成复数域再做乘法……这些问题,看论文能得到答案,但只有自己写代码才能真正把它们变成直觉。

这篇笔记的目标是:把 Llama 3 的每一个模块,从设计动机到公式推导到代码实现,都讲清楚。所有技术细节来自 Meta 官方论文,所有代码都可以独立运行。我们用 Tiny Shakespeare 数据集来做演示训练,因为它足够小,能快速看到结果,同时又足够有趣,让你感受到语言模型在学什么。


先看全局

Llama 3 本质上是一个标准的 decoder-only transformer,但 Meta 在几个关键位置做了精准的替换。架构本身不复杂,复杂的是每个替换背后的权衡。

整个前向传播可以分成三段:输入块 → 解码器块(×N层)→ 输出块。

一、Llama 3 整体架构:三个关键杠杆

在深入每个模块之前,先把 Llama 3 的全貌说清楚。

Llama 3 是一个标准的稠密 Transformer 解码器(dense Transformer decoder)架构,论文明确说没有采用混合专家模型(MoE),主要原因是为了最大化训练稳定性、降低复杂度。

论文指出了驱动 Llama 3 性能的三个核心杠杆:

数据(Data):预训练语料从 Llama 2 的 1.8T tokens 提升到 15T tokens,增幅超过 8 倍。数据质量也大幅提升,引入了多轮清洗和领域分类策略。最终的数据配比是:约 50% 通用知识、25% 数学与推理、17% 代码、8% 多语言内容。

规模(Scale):最大模型 405B 参数,用 3.8×10²⁵ FLOPs 训练,是 Llama 2 最大版本的约 50 倍算力。论文按照 Chinchilla 缩放定律推算出计算最优点在 402B 参数和 16.55T tokens,最终选择了 405B。

复杂度管理(Managing Complexity):Post-training 阶段采用 SFT + 拒绝采样(RS)+ DPO 的组合,而非更复杂的 RL 算法,理由是复杂算法更难稳定扩展。

整个模型在架构层面继承了 Llama 2 的基本骨架,但有四处关键修改:GQA 的 KV head 数量固定为 8、词表从 32K 扩充到 128K、RoPE base frequency 从 10000 提升到 500000、以及引入跨文档注意力 mask。下面逐一讲。


二、输入块(Input Block)

输入块包含三个组件:文本/提示词、分词器、嵌入层。

1.1 分词器:从字符到 token 的映射

Llama 3 在生产环境中使用 TikToken 作为分词器,这是一个子词(subword)级分词器,词表大小 128,000,由 100,000 个来自 tiktoken 的基础 token 加上 28,000 个额外的非英语语言 token 组成。

相比 Llama 2 的 SentencePiece 分词器(词表 32K),这次扩充带来的直接收益是压缩率提升——同样一段英文文本,Llama 3 平均每个 token 能表示 3.94 个字符,而 Llama 2 只有 3.17 个字符。这意味着在相同计算量下,Llama 3 能"读到"更多文本。

为什么词表变大能提升压缩率?本质是因为更大的词表可以把更多常见词组和词缀直接存成一个 token,而不是拆成多个子词。比如 "generating" 如果本身在词表里,就是 1 个 token;如果不在,可能被拆成 "generat" 和 "ing" 两个 token。词表越大,整体需要的 token 数越少,序列越短,注意力计算成本越低。

特殊 token 方面,Llama 3 定义了 <|begin_of_text|><|end_of_text|><|eot_id|>(turn 结束)、<|start_header_id|><|end_header_id|> 等,这些在对话场景下有重要的结构化作用。

在我们的从零实现中,用字符级分词器来代替 TikToken,目的是让整个 encode/decode 流程完全透明可控:

with open('tiny_shakespeare.txt', 'r') as f:
    data = f.read()

vocab = sorted(list(set(data)))
vocab.extend(['<|begin_of_text|>', '<|end_of_text|>', '<|pad_id|>'])
vocab_size = len(vocab)

itos = {i: ch for i, ch in enumerate(vocab)}
stoi = {ch: i for i, ch in enumerate(vocab)}

def encode(s):
    return [stoi[ch] for ch in s]

def decode(l):
    return ''.join(itos[i] for i in l)

token_bos = torch.tensor([stoi['<|begin_of_text|>']], dtype=torch.int, device=device)
token_eos = torch.tensor([stoi['<|end_of_text|>']], dtype=torch.int, device=device)
token_pad = torch.tensor([stoi['<|pad_id|>']], dtype=torch.int, device=device)

1.2 模型超参数

在进入解码器之前,先把本次实现用到的所有参数集中定义好。注意这里为了让训练快速出结果,把 dim 调小到 512,实际 Llama 3 8B 的 dim 是 4096:

@dataclass
class ModelArgs:
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8
    n_kv_heads: int = 4          # 对应论文里 GQA 的 KV heads = 8(这里按比例缩小)
    vocab_size: int = len(vocab)
    multiple_of: int = 256
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5
    rope_theta: float = 500000.0  # 论文把这个值从 10000 提升到了 500000

    max_batch_size: int = 10
    max_seq_len: int = 256

    epochs: int = 2500
    log_interval: int = 10
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'

这里有一个细节值得单独说:rope_theta = 500000.0。这是 Llama 3 相比 Llama 2 的四大架构改动之一。原论文指出,将 RoPE 的 base frequency 从默认的 10000 提升到 500000,能让模型有效处理更长的上下文——论文引用的研究表明这个值对 32768 长度的上下文有效,而在后续的长上下文继续预训练阶段,上下文长度进一步扩展到了 128K tokens。


三、解码器块(Decoder Block)

解码器块是 Llama 3 的核心。每个解码器块包含六个子组件:RMSNorm、RoPE、KV Cache、GQA、FeedForward Network,以及把它们组装在一起的 TransformerBlock。我们逐一深入。

3.1 RMSNorm:比 LayerNorm 更高效的归一化

为什么需要归一化? embedding 向量在各个维度上的数值范围差异很大,直接送入后续计算会导致梯度爆炸或消失,训练不稳定。归一化把这些值拉到一个合适的范围,让梯度的量级更一致,训练更稳。

为什么用 RMSNorm 而不是 LayerNorm? 这是 Llama 系列从第一代就继承下来的选择。LayerNorm 需要计算均值(mean)和方差(variance):

\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \varepsilon}} \times \gamma_i + \beta_i

而 RMSNorm 完全省掉了均值的计算,只保留 RMS(Root Mean Square,均方根)这一步:

\hat{x}_i = \frac{x_i}{\sqrt{\mathrm{mean}(x^2) + \varepsilon}} \times \gamma_i

这意味着:第一,少了均值计算,计算开销降低;第二,没有偏移参数 β,参数更少;第三,论文作者的实验表明性能相当甚至更好。直觉上,归一化的关键作用是控制量级(scale),而均值中心化对这个目标的贡献有限,省掉它代价不大。

同时要注意,Llama 3 采用 Pre-Norm 结构,即在注意力和前馈网络之前做归一化,而不是之后。Pre-Norm 相比 Post-Norm 训练更稳定,这一点已经被大量工作证明。

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim).to(device))

    def _norm(self, x):
        # x.pow(2).mean(dim=-1, keepdim=True) 是对最后一个维度(embedding dim)求均方
        # rsqrt = 1 / sqrt,整体就是 x / RMS(x)
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps).to(device)

    def forward(self, x):
        # Shape: x[bs, seq, dim] -> output[bs, seq, dim]
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

3.2 旋转位置编码(RoPE):用旋转矩阵编码绝对位置和相对位置

问题:Transformer 的自注意力机制本质上是置换不变的(permutation invariant)——把输入序列的顺序打乱,注意力得分的模式不会变。但语言显然是有顺序的,"我爱你"和"你爱我"意思截然不同。所以必须想办法把位置信息注入 embedding。

Llama 1/2/3 都用 RoPE(Rotary Positional Encoding),而不是原始 Transformer 里的正弦绝对位置编码,也不是可学习的绝对位置编码。RoPE 的核心思路是:用旋转矩阵对 Q 和 K 的 embedding 进行旋转,使得旋转角度正比于 token 的绝对位置

这样做的妙处是:注意力分数 Q·K^T 在经过旋转之后,只跟两个 token 的相对位置有关,与它们的绝对位置无关——因为旋转是线性变换,两个旋转矩阵相乘的结果只取决于它们的旋转角度之差,即 m-n(m 和 n 分别是两个 token 的绝对位置)。这就同时实现了绝对位置编码和相对位置感知。

RoPE 的数学实现:

对于位置 m 的 token,每对相邻维度 (2i, 2i+1) 按角度 m × θᵢ 旋转,其中:

\theta_i = \frac{1}{\mathrm{base}^{(2i / d)}}

这里 base 就是 rope_theta。Llama 3 把 base 从 10000 改成了 500000,使得高频分量(小 θ)变化更缓慢,能更好地处理长序列中遥远位置之间的相对关系。

旋转操作在实数域是矩阵乘法,但在复数域只是点乘,所以实现上会先把 embedding 转成复数,乘以旋转因子,再转回实数:

def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 500000.0):
    device = ModelArgs.device
    # 计算每对维度的 theta 值:θᵢ = 1 / (theta^(2i/dim))
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device)[:(dim // 2)].float() / dim))
    
    # 计算序列中每个位置 m 的值
    t = torch.arange(seq_len, dtype=torch.float32, device=device)
    
    # outer product 得到每个位置每个维度对的旋转角度:m × θᵢ
    freqs = torch.outer(t, freqs).to(device)
    
    # 转成极坐标形式(模=1,角度=freqs),即 e^(i·m·θ) 的复数表示
    freqs_cis = torch.polar(torch.ones_like(freqs).to(device), freqs).to(device)
    return freqs_cis

def reshape_for_broadcast(freqs_cis, x):
    ndim = x.ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    device = ModelArgs.device
    # 把最后一维两两配对,视作复数:[bsz, seq_len, n_heads, head_dim/2]
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)).to(device)
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)).to(device)

    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)

    # 复数乘法 = 旋转,然后转回实数并展平最后两维
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).to(device)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).to(device)
    return xq_out.type_as(xq), xk_out.type_as(xk)

特别注意:RoPE 只施加在 Q 和 K 上,不施加在 V 上。这是因为 RoPE 的目的是让注意力分数(Q·K^T)感知相对位置关系,而 V 是被注意力加权聚合的值,不需要位置旋转。

3.3 KV Cache:推理时空间换时间的关键优化

KV Cache 只在推理阶段启用,训练时不需要。理解它的必要性需要先理解自回归生成的过程。

在推理时,模型每次只生成一个 token,但每次生成都要做完整的注意力计算。假设当前已经生成了 t 个 token,要生成第 t+1 个:

没有 KV Cache 的情况: 对 t+1 个 token 做完整的 QKV 计算,每次都要对所有历史 token 重新计算 K 和 V。这些历史 token 的 K 和 V 在上一步已经算过了,完全是重复计算。矩阵乘法的规模是 (t+1) × (t+1),随着序列变长计算量呈平方增长。

有 KV Cache 的情况: 把每一步计算出来的 K 和 V 存下来,下一步直接复用。当前步只需要用最新的一个 Q token,与缓存里所有历史的 K、V 做注意力计算,矩阵乘法变成 1 × (t+1),大幅降低计算量。

Q 不需要缓存的原因:每一步我们只用当前位置的 Q 来查询,它不会被未来的步骤复用。

# 在 Attention.__init__ 里初始化 KV Cache
self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim), device=args.device)
self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self

3.4 分组查询注意力(GQA):在精度与效率之间找到最优平衡点

Llama 3 把 GQA 从 Llama 2 仅在 70B 模型上使用,扩展到了所有规模(8B、70B、405B)都使用,且不管模型大小,KV heads 数量统一固定为 8

理解 GQA 需要先理解三种注意力机制的区别:

Multi-Head Attention(MHA):Q、K、V 各有相同数量的 head(比如 32 个)。每个 Q head 都有自己对应的独立 K head 和 V head。KV Cache 大小正比于 head 数量。

Multi-Query Attention(MQA):Q 有多个 head,但所有 Q head 共享同一组 K 和 V(只有 1 个 KV head 对)。KV Cache 大幅缩减,但不同 Q head 之间的表示多样性受限,可能损失模型质量。

Grouped Query Attention(GQA):介于两者之间——Q head 分成若干组,同一组内的 Q head 共享一对 K/V head。Llama 3 8B 有 32 个 Q heads、8 个 KV heads,分组数 = 32/8 = 4,每 4 个 Q head 共享一对 KV head。

这样做的好处是:KV Cache 从 MHA 的 32× 降到 8×,显存占用大幅减少;但保留了 32 个独立的 Q heads,注意力表达能力不受影响。实验数据表明这个设计在质量和效率之间取得了很好的平衡。

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.dim = args.dim
        self.n_heads = args.n_heads
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        self.head_dim = args.dim // args.n_heads
        # 每个 KV head 需要被几个 Q head 共享
        self.n_rep = args.n_heads // args.n_kv_heads

        # Q 的输出维度:n_heads × head_dim
        # K/V 的输出维度:n_kv_heads × head_dim(更小)
        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False, device=device)
        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False, device=device)
        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False, device=device)
        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False, device=device)

        self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim), device=args.device)
        self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim), device=args.device)

    def forward(self, x: torch.Tensor, start_pos, inference):
        bsz, seq_len, _ = x.shape
        mask = None

        xq = self.wq(x)
        xk = self.wk(x)
        xv = self.wv(x)

        # reshape 到 [bsz, seq_len, n_heads, head_dim]
        xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim)

        if inference:
            # 推理模式:启用 KV Cache,rope_theta 用论文值 500000
            freqs_cis = precompute_freqs_cis(dim=self.head_dim, seq_len=self.args.max_seq_len * 2,
                                              theta=self.args.rope_theta)
            freqs_cis = freqs_cis[start_pos: start_pos + seq_len]
            xq, xk = apply_rotary_emb(xq, xk, freqs_cis)

            self.cache_k = self.cache_k.to(xq)
            self.cache_v = self.cache_v.to(xq)
            self.cache_k[:bsz, start_pos:start_pos + seq_len] = xk
            self.cache_v[:bsz, start_pos:start_pos + seq_len] = xv

            keys = self.cache_k[:bsz, :start_pos + seq_len]
            values = self.cache_v[:bsz, :start_pos + seq_len]

            # 把 KV heads 扩展到和 Q heads 一样多,以便做矩阵乘法
            keys = repeat_kv(keys, self.n_rep)
            values = repeat_kv(values, self.n_rep)
        else:
            # 训练模式:不用 KV Cache,直接对整个序列做注意力
            freqs_cis = precompute_freqs_cis(dim=self.head_dim, seq_len=self.args.max_seq_len,
                                              theta=self.args.rope_theta)
            xq, xk = apply_rotary_emb(xq, xk, freqs_cis)

            keys = repeat_kv(xk, self.n_rep)
            values = repeat_kv(xv, self.n_rep)

            # 因果 mask:上三角填 -inf,防止当前 token 看到未来 token
            mask = torch.full((seq_len, seq_len), float("-inf"), device=self.args.device)
            mask = torch.triu(mask, diagonal=1).to(self.args.device)

        # Transpose to [bsz, n_heads, seq_len, head_dim]
        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # 注意力分数 = Q·K^T / √d_k,然后 softmax,然后加权 V
        scores = torch.matmul(xq, keys.transpose(2, 3)).to(self.args.device) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values).to(self.args.device)

        # 把所有 head 的输出拼回来:[bsz, n_heads, seq_len, head_dim] -> [bsz, seq_len, dim]
        output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        return self.wo(output)


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """把 KV head 的数量扩展到和 Q head 一样多"""
    bsz, seq_len, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bsz, seq_len, n_kv_heads, n_rep, head_dim)
        .reshape(bsz, seq_len, n_kv_heads * n_rep, head_dim)
    )

这里有一个很容易忽视的细节:因果 mask 在训练时是必须的,因为训练时整个序列一次性并行处理,模型必须被阻止看到未来的 token。而在推理时,因为 KV Cache 的存在,每次只处理一个新 token,天然没有"未来信息泄露"的问题,所以 mask 不需要了。

3.5 前馈网络(FeedForward with SwiGLU)

前馈网络在每个注意力块之后,负责对每个 token 的表示做非线性变换,让模型能学到更复杂的特征。

Llama 3 使用 SwiGLU 激活函数,而非原始 Transformer 里的 ReLU 或 GeLU。SwiGLU 全称是 Swish-Gated Linear Unit,它的前馈计算公式是:

\text{FFN}(x) = W_2 \cdot \left( \text{SiLU}(W_1 x) \odot W_3 x \right)

其中 SiLU(x) = x · σ(x)(σ 是 sigmoid),⊙ 是逐元素乘法。和标准的两层 FFN 不同,SwiGLU 用了三个线性变换矩阵(W₁、W₂、W₃),其中 W₁ 的输出经过 SiLU 激活后,作为"门"对 W₃ 的输出进行过滤。

为什么 SwiGLU 比 ReLU 好? 关键在于 ReLU 在负数区域输出全为 0(hard gate),而 SwiGLU 在负数区域有平滑的非零输出,梯度更连续,也保留了一定的负数信息。这种"软门控"机制让模型的表达能力更强,同时训练更稳定。

由于多了 W₃ 这个矩阵,FFN 的参数量相比 ReLU 版本多了约 50%。为了保持总参数量不变,Llama 3 对隐层维度做了调整,论文里用的隐层维度计算公式是 int(2 * hidden_dim / 3) 并取 256 的倍数:

class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, multiple_of: int, ffn_dim_multiplier: Optional[float]):
        super().__init__()
        self.dim = dim
        # 按 Meta 的隐层维度计算公式:先缩到 2/3,再取 256 的整数倍
        hidden_dim = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(self.dim, hidden_dim, bias=False, device=device)  # gate 路径
        self.w2 = nn.Linear(hidden_dim, self.dim, bias=False, device=device)  # 输出投影
        self.w3 = nn.Linear(self.dim, hidden_dim, bias=False, device=device)  # value 路径

    def forward(self, x):
        # SwiGLU: SiLU(W₁x) ⊙ W₃x,然后过 W₂
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

从论文给出的具体数字来看,Llama 3 8B 的 FFN 隐层维度是 14,336,70B 是 28,672,405B 是 53,248。

3.6 解码器块(TransformerBlock):把所有子模块组装起来

一个完整的 TransformerBlock 按照以下顺序执行:

  1. 输入 x 先过 RMSNorm(Pre-Norm),再进入注意力层
  2. 注意力输出和原始 x 做残差连接(Residual Connection)
  3. 残差连接结果再过 RMSNorm,进入 FFN
  4. FFN 输出再次做残差连接

用公式写就是:

h = x + \text{Attention}(\text{RMSNorm}(x))

\text{out} = h + \text{FFN}(\text{RMSNorm}(h))

残差连接是 Transformer 稳定深度训练的关键。没有残差连接,梯度在穿过几十层之后会彻底消失,完全无法训练。残差连接提供了一条高速路,让梯度可以直接从输出层流回输入层。

class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.attention_norm = RMSNorm(dim=args.dim, eps=args.norm_eps)
        self.attention = Attention(args)
        self.ff_norm = RMSNorm(dim=args.dim, eps=args.norm_eps)
        self.feedforward = FeedForward(args.dim, 4 * args.dim, args.multiple_of, args.ffn_dim_multiplier)

    def forward(self, x, start_pos, inference):
        # Pre-Norm + Attention + Residual
        h = x + self.attention(self.attention_norm(x), start_pos, inference)
        # Pre-Norm + FFN + Residual
        out = h + self.feedforward(self.ff_norm(h))
        return out

Llama 3 的三个规模对应不同的解码器层数:8B 模型有 32 层,70B 模型有 80 层,405B 模型有 126 层。每一层的结构完全相同,只是宽度(dim)不同。


四、输出块(Output Block)与完整模型

所有解码器块处理完之后,最后的 hidden states 流入输出块:先过一次 RMSNorm,再过一个线性层,把 embedding 维度映射到词表大小,输出 logits。

logits 的每个维度对应词表里的一个 token,softmax 之后就是模型预测下一个 token 的概率分布。

训练时:把 logits 和真实 target labels 传入交叉熵损失函数,反向传播更新所有参数。

推理时:从 logits 对应的概率分布中采样,得到下一个生成的 token。

class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        # 输入层:token id -> embedding vector
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)

        # 解码器堆叠:n_layers 个 TransformerBlock
        self.layers = nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(args=params))

        # 输出层
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)

    def forward(self, x, start_pos=0, targets=None):
        # x: [bsz, seq_len] -> h: [bsz, seq_len, dim]
        h = self.tok_embeddings(x)

        inference = targets is None

        for layer in self.layers:
            h = layer(h, start_pos, inference)

        h = self.norm(h)
        # h: [bsz, seq_len, dim] -> logits: [bsz, seq_len, vocab_size]
        logits = self.output(h).float()

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, self.params.vocab_size), targets.view(-1))

        return logits, loss

五、训练

在训练流程上,我们用 80% 的数据做训练,10% 做验证,10% 做测试。每次随机采样一个 batch,输入是从 <|begin_of_text|> 开始的序列,目标是把这个序列向右移动一位(即每个位置预测下一个 token)。

dataset = torch.tensor(encode(data), dtype=torch.int).to(ModelArgs.device)

def get_dataset_batch(data, split, args: ModelArgs):
    seq_len = args.max_seq_len
    batch_size = args.max_batch_size
    device = args.device

    train = data[:int(0.8 * len(data))]
    val = data[int(0.8 * len(data)): int(0.9 * len(data))]
    test = data[int(0.9 * len(data)):]

    batch_data = train
    if split == "val":
        batch_data = val
    if split == "test":
        batch_data = test

    ix = torch.randint(0, len(batch_data) - seq_len - 3, (batch_size,)).to(device)
    # x:<|begin_of_text|> + 正文前 seq_len-1 个字符
    x = torch.stack([torch.cat([token_bos, batch_data[i:i + seq_len - 1]]) for i in ix]).long().to(device)
    # y:正文后 seq_len-1 个字符 + <|end_of_text|>(即 x 右移一位)
    y = torch.stack([torch.cat([batch_data[i + 1:i + seq_len], token_eos]) for i in ix]).long().to(device)

    return x, y


@torch.no_grad()
def evaluate_loss(model, args: ModelArgs):
    out = {}
    model.eval()
    for split in ["train", "val"]:
        losses = []
        for _ in range(10):
            xb, yb = get_dataset_batch(dataset, split, args)
            _, loss = model(x=xb, targets=yb)
            losses.append(loss.item())
        out[split] = np.mean(losses)
    model.train()
    return out


def train(model, optimizer, args: ModelArgs):
    epochs = args.epochs
    log_interval = args.log_interval
    device = args.device
    losses = []
    start_time = time.time()

    for epoch in range(epochs):
        optimizer.zero_grad()

        xs, ys = get_dataset_batch(dataset, 'train', args)
        xs = xs.to(device)
        ys = ys.to(device)
        logits, loss = model(x=xs, targets=ys)
        loss.backward()
        optimizer.step()

        if epoch % log_interval == 0:
            batch_time = time.time() - start_time
            x = evaluate_loss(model, args)
            losses += [x]
            print(f"Epoch {epoch} | val loss {x['val']:.3f} | Time {batch_time:.3f}")
            start_time = time.time()

    print("validation loss: ", losses[-1]['val'])
    return pd.DataFrame(losses).plot()


model = Transformer(ModelArgs).to(ModelArgs.device)
optimizer = torch.optim.Adam(model.parameters())
train(model, optimizer, ModelArgs)

在 Google Colab 的免费 GPU 上,2500 个 epoch 大约需要 10 分钟,最终 validation loss 在 2.19 左右。这个数字并不算低,原因是我们只用了 Tiny Shakespeare 这个小数据集,且没有做任何超参数调优。真实的 Llama 3 在 15T tokens 上训练,差距是数量级的。


六、推理

推理的核心是自回归生成(autoregressive generation):每次预测一个 token,把这个 token 追加到输入序列,再继续预测下一个,直到生成了最大长度或遇到结束符。

采样策略使用 Top-p(Nucleus)Sampling:按概率降序排列所有 token,找到累积概率刚好超过 p 的最小集合,只从这个集合里随机采样。这比直接取最大概率(贪心解码)生成的文本更有多样性,比完全随机采样又更有质量保证。

Temperature 参数控制分布的"尖锐程度":temperature < 1 会让分布更集中(更保守),temperature > 1 会让分布更平坦(更随机)。

def generate(model, prompts: str, params: ModelArgs, max_gen_len: int = 500,
             temperature: float = 0.6, top_p: float = 0.9):
    bsz = 1
    prompt_tokens = token_bos.tolist() + encode(prompts)
    assert len(prompt_tokens) <= params.max_seq_len

    total_len = min(len(prompt_tokens) + max_gen_len, params.max_seq_len)
    tokens = torch.full((bsz, total_len), fill_value=token_pad.item(), dtype=torch.long, device=params.device)
    tokens[:, :len(prompt_tokens)] = torch.tensor(prompt_tokens, dtype=torch.long, device=params.device)

    input_text_mask = tokens != token_pad.item()

    prev_pos = 0
    for cur_pos in range(1, total_len):
        with torch.no_grad():
            logits, _ = model(x=tokens[:, prev_pos:cur_pos], start_pos=prev_pos)

        if temperature > 0:
            probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
            next_token = sample_top_p(probs, top_p)
        else:
            next_token = torch.argmax(logits[:, -1], dim=-1)

        next_token = next_token.reshape(-1)
        # 如果当前位置是 prompt 的一部分,不要覆盖它
        next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
        tokens[:, cur_pos] = next_token

        prev_pos = cur_pos
        if tokens[:, cur_pos] == token_pad.item() and next_token == token_eos.item():
            break

    output_tokens, output_texts = [], []
    for i, toks in enumerate(tokens.tolist()):
        if token_eos.item() in toks:
            eos_idx = toks.index(token_eos.item())
            toks = toks[:eos_idx]
        output_tokens.append(toks)
        output_texts.append(decode(toks))
    return output_tokens, output_texts


def sample_top_p(probs, p):
    # 按概率降序排列
    probs_sort, prob_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # 找到累积概率超过 p 的截断点,把截断点之后的概率置 0
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    # 重归一化后采样
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(prob_idx, -1, next_token)
    return next_token


prompts = "Consider you what services he has done"
output_tokens, output_texts = generate(model, prompts, ModelArgs)
output_texts = output_texts[0].replace("<|begin_of_text|>", "")
print(output_texts)

附:论文里的那些工程细节

上面的代码实现覆盖了 Llama 3 的核心架构,但论文里还有几个工程细节值得单独讲一讲,它们是真实生产训练里性能的关键保障,也是理解为什么 Llama 3 能在这么大的规模上稳定训练的原因。

1. 文档级别的 attention mask:论文提到在标准预训练阶段效果有限,但在长上下文继续预训练时至关重要。它的作用是:当多个文档被拼接成一个长序列时,阻止不同文档之间的 token 互相 attend。否则,文档 A 的最后一个 token 会"看到"文档 B 的第一个 token,这会注入本不应该存在的上下文关系,污染长距离依赖的学习。

2. 长上下文预训练:Llama 3 的标准预训练上下文长度是 8K tokens,但 Llama 3.1 的 405B 模型最终支持 128K tokens 的上下文。这不是一步到位的,而是分六个阶段逐步扩展,从 8K 到 128K,用了约 8000 亿 token 的继续预训练来让模型适应。

3. 退火(Annealing):在预训练的最后阶段,把学习率线性退火到 0,同时把高质量的数学和代码数据的权重调高。论文发现退火让 8B 模型在 GSM8k 上的准确率提升了 24%,在 MATH 上提升了 6.4%。这提示我们:数据质量在预训练的最后阶段比数量更重要。

4. 缩放定律实验:Meta 在确定 405B 这个参数规模之前,跑了大量从 40M 到 16B 参数的小模型实验,建立了 IsoFLOPs 曲线,用幂律关系 N*(C) = A × C^α 来预测最优参数规模。拟合出 α=0.53,A=0.29,据此推算出 3.8×10²⁵ FLOPs 对应的计算最优点约在 402B 参数,最终选择了 405B。这是把科学方法引入工程决策的典型例子。


论文链接:https://arxiv.org/abs/2407.21783
Meta 官方代码:https://github.com/meta-llama/llama3

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐