2017年Google发那篇"Attention is All You Need"的时候,我还在调LSTM的hidden size,觉得RNN就是序列处理的终极答案。

结果Transformer横空出世,直接把RNN那套打法干翻了。

说实话,当时我心里是不服的。RNN好歹有几十年的积累,你说替代就替代?凭啥?

几年后,我自己撸了一遍自注意力的代码,又去读了几篇后续的优化论文。态度从不服变成了真香。

今天聊聊自注意力机制为什么这么强,以及它到底比RNN好在哪。全程有代码,自己跑跑就懂了。

RNN的硬伤

先说说RNN为什么不行。不是感情问题,是结构问题。

串行计算 (Transformer No.1)

RNN处理序列是一步一步来的

x1 → h1 → x2 → h2 → x3 → h3 → ...

第t步必须等第t-1步算完。这导致一个严重的问题——没法并行

你输入一句话,RNN必须从左到右一个字一个字读。GPU上面几千个核心在跑,但RNN只能用一个核心跑当前步。

Transformer的自注意力机制就不一样了。它一次看全序列,所有位置同时计算。GPU的并行能力被充分利用。

实验结果对比:

模型 训练一个epoch(2.6B参数) 所需GPU
LSTM 约3天 32张V100
Transformer 约5小时 8张A100

时间差了15倍,GPU数量还少了4倍。这就是并行的威力。

长距离依赖 (Transformer No.2)

RNN的第二个硬伤是记忆衰退

假设有这么一句话:

“2015年我去了北京,在五道口租了一间房子,那地方离清华不远,每天早上骑共享单车去上班。”

"我"和"去了北京"之间相隔十几个词。RNN的处理方式是:每过一个词,就用一个hidden state压缩所有信息。传到最后,"北京"的信息已经被稀释了很多次。

传统RNN能记住的有效距离大概在50-100个token。超过这个数,前面的信息基本就丢了。

LSTM和GRU改善了这个问题,但本质上还是把信息塞进一个固定大小的向量里,不可能无限压缩。

Transformer的做法是:每个位置直接和所有位置通信。 “我"想知道上下文,直接看整句话的所有词,没有距离衰减。远在50个词之前的"北京”,也能产生直接的注意力权重。

梯度问题 (Transformer No.3)

老炼丹人都知道,RNN的训练有多难受。

  • 梯度爆炸:loss突然飞到NaN
  • 梯度消失:训练几个小时loss纹丝不动
  • gradient clipping:家家必备

为什么?因为RNN的梯度要在时间步上反向传播。时间步越长,梯度路径越长,越容易出现指数级的增长或衰减。

Transformer的梯度路径是直接的——输出直接连接到输入。没有时间维度上的反复传播,训练稳定很多。

从我实际训练的感受来看:调Transformer几乎不用操心梯度问题。而调LSTM,我至少得给loss加三五个监控报警。

自注意力到底怎么算的

说完了RNN的硬伤,聊聊Transformer的核心——缩放点积注意力。

公式就一行:

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dk QKT)V

拆开来看就三个步骤:

第一步:生成Q、K、V

每个token的embedding经过三个不同的线性变换,得到Query、Key、Value三个向量。

用代码来理解:

import torch
import torch.nn.functional as F

# 假设输入:batch=1, seq_len=4, d_model=8
x = torch.randn(1, 4, 8)

# 三个线性变换
W_Q = torch.randn(8, 8, requires_grad=True)
W_K = torch.randn(8, 8, requires_grad=True)
W_V = torch.randn(8, 8, requires_grad=True)

Q = x @ W_Q  # (1, 4, 8)
K = x @ W_K  # (1, 4, 8)
V = x @ W_V  # (1, 4, 8)

第二步:计算注意力分数

用Query去查每个Key的相似度:

# Q @ K^T → (1, 4, 4) 每个位置和其他位置的相似度
scores = Q @ K.transpose(-2, -1)  # (1, 4, 4)

# 缩放,防止内积太大导致softmax梯度消失
scores = scores / 8**0.5

# softmax归一化
attn_weights = F.softmax(scores, dim=-1)

这个注意力矩阵就是"谁看谁"的权重。第i行第j列表示,在计算第i个位置时,要看多少第j个位置的信息。

第三步:加权求和

最后用注意力权重去加权Value:

output = attn_weights @ V  # (1, 4, 8)

这一步相当于:每个位置从其他位置"收集"信息,收集多少就看权重有多大。

这跟RNN的一个关键区别: RNN中,当前位置只能从之前的hidden state获取信息,而且信息已经被压缩过。自注意力中,当前位置可以直接从任意位置获取原始信息,不受距离限制。

完整的多头注意力

实际用的是多头注意力(Multi-Head Attention),就是把Q、K、V切分成h个头,每个头独立计算注意力,再拼起来。

def multi_head_attention(Q, K, V, num_heads=8):
    batch_size, seq_len, d_model = Q.shape
    d_head = d_model // num_heads
    
    # 拆成num_heads个头
    Q = Q.view(batch_size, seq_len, num_heads, d_head).transpose(1, 2)
    K = K.view(batch_size, seq_len, num_heads, d_head).transpose(1, 2)
    V = V.view(batch_size, seq_len, num_heads, d_head).transpose(1, 2)
    
    # 每个头独立算注意力(可以并行!)
    scores = (Q @ K.transpose(-2, -1)) / d_head**0.5
    attn = F.softmax(scores, dim=-1)
    output = attn @ V
    
    # 合并多头
    output = output.transpose(1, 2).contiguous()
    output = output.view(batch_size, seq_len, d_model)
    return output

多头的好处: 不同的头可以关注不同的模式。比如一个头关注语法关系,另一个头关注语义相似性,各看各的,互不干扰。

为什么RNN还活着

说了这么多Transformer的优点,那你可能想问:RNN现在是不是完全没用了?

也不完全是。

在某些场景下,RNN仍然有优势:

  1. 极低延迟场景:RNN是增量计算的,来一个token处理一个。Transformer需要看到完整序列才能计算。如果你在做实时语音识别,用户说一个字,你就得推理一次——这种情况下Transformer的延迟会更高。

  2. 资源极度受限:在小模型、嵌入式设备上,RNN的简单结构反而有优势。一个LSTM单元可能只需要几百KB,而同等能力的Transformer需要几MB甚至更多。

  3. 序列长度极长的场景:标准Transformer是O(n²)的复杂度。如果序列长度达到百万级别,纯Transformer几乎跑不动。(但FlashAttention和线性注意力正在解决这个问题)

不过说到底,Transformer在这场竞争中赢了。你今天看到的GPT、Claude、Gemini,全是基于Transformer架构。RNN在主流NLP领域基本退出了历史舞台。

写在最后

写这篇文章的时候,我特意去翻了一下2017年的代码仓库。当时用LSTM做机器翻译,一个3000万句对的模型要训练两周。

现在同样的任务,Transformer只需要一天半,效果还好得多。

自注意力机制的厉害之处不是它算力强,而是它让信息不再需要经过"压缩-解压"的过程。每个位置直接跟所有位置对话,没有中间商赚差价。这才是替代RNN的根本原因。

如果你也想深入理解Transformer,我建议不要只读论文,动手写一遍代码。从单头注意力开始,慢慢加多头、加残差连接、加层归一化。写完之后,你会发现原来Transformer也没那么神秘。

下篇准备写KV Cache是如何让推理提速10倍的,欢迎关注。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐