环境声明

  • Python版本:Python 3.10+
  • PyTorch版本:PyTorch 2.0+
  • 推荐开发工具:PyCharm / VS Code / Jupyter Notebook
  • 操作系统:Windows / macOS / Linux(通用)

学习目标和摘要

摘要:本章将深入探讨注意力机制的发展历程,从2014年Bahdanau等人开创性的工作开始,到2017年Transformer中Self-Attention的横空出世,再到近年来各种高效注意力变体的涌现。你将理解注意力机制的数学本质,掌握Query-Key-Value的计算范式,并学会实现多头注意力机制。

学习目标

  1. 理解注意力机制的生物学启发和直观意义
  2. 掌握Seq2Seq架构中注意力机制的工作原理
  3. 深入理解Self-Attention的数学形式和计算过程
  4. 学会实现多头注意力机制和位置编码
  5. 了解高效注意力变体(Linear Attention、Sparse Attention等)
  6. 掌握注意力可视化的方法

1. 注意力机制的生物学启发

1.1 人类视觉注意力系统

想象你正在一个拥挤的火车站寻找你的朋友。你的眼睛不会同时清晰地看到所有人和物体,而是会快速扫视,将注意力集中在可能与朋友相关的特征上——比如相似的身高、穿着的颜色、发型等。这种选择性关注的能力就是注意力的本质。

核心比喻:注意力机制就像是给神经网络装上了一个"聚光灯",让它能够在处理大量信息时,动态地选择关注最重要的部分。

人类大脑处理视觉信息时,存在两种注意力机制:

  • 自下而上的注意力:由外界刺激驱动,例如突然的响声或明亮的闪光会自动吸引注意
  • 自上而下的注意力:由目标和任务驱动,例如主动寻找特定物体

深度学习中的注意力机制主要模拟的是自上而下的注意力——根据当前任务目标,动态调整对不同输入部分的关注程度。

1.2 从RNN到注意力:为什么需要变革

在注意力机制出现之前,序列到序列(Seq2Seq)模型使用固定长度的上下文向量来编码整个输入序列。这就像试图用一张小纸条记录整本书的内容——信息损失不可避免。

关键问题:当输入序列很长时,编码器必须将所有信息压缩到一个固定维度的向量中,导致信息瓶颈(Information Bottleneck)。

注意力机制通过允许解码器在生成每个输出时"回看"输入序列的不同部分,彻底解决了这个问题。


2. Seq2Seq与编码器-解码器架构

2.1 基础Seq2Seq架构

Seq2Seq(Sequence to Sequence)模型是深度学习中处理序列转换任务的基石,典型应用包括机器翻译、文本摘要、语音识别等。

架构组成

import torch
import torch.nn as nn

class Encoder(nn.Module):
    """
    基础编码器:将输入序列编码为上下文向量
    """
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        
    def forward(self, x):
        # x: (batch_size, seq_len)
        embedded = self.embedding(x)  # (batch_size, seq_len, embed_dim)
        outputs, (hidden, cell) = self.lstm(embedded)
        # outputs: (batch_size, seq_len, hidden_dim)
        # hidden: (1, batch_size, hidden_dim)
        return outputs, hidden, cell

class Decoder(nn.Module):
    """
    基础解码器:从上下文向量生成输出序列
    """
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        
    def forward(self, x, hidden, cell):
        # x: (batch_size, 1) - 单个词
        embedded = self.embedding(x)  # (batch_size, 1, embed_dim)
        output, (hidden, cell) = self.lstm(embedded, (hidden, cell))
        prediction = self.fc(output.squeeze(1))
        return prediction, hidden, cell

2.2 信息瓶颈问题

在基础Seq2Seq中,编码器将变长输入序列压缩为固定长度的上下文向量。对于长序列,这种压缩会导致严重的信息损失。

一句话总结:注意力机制让模型学会"该看哪里",而不是被迫记住所有信息。


3. 注意力机制的数学形式

3.1 Query、Key、Value范式

注意力机制的核心思想可以抽象为三个概念:

  • Query(查询):当前需要关注什么,代表"我要找什么"
  • Key(键):输入序列中各位置的标识,代表"我是什么"
  • Value(值):输入序列中各位置的实际内容,代表"我有什么信息"

计算过程类比:想象你在图书馆找书(Query),每本书都有书名标签(Key)和实际内容(Value)。你通过比较Query和Key的相似度,决定从哪些Value中获取信息。

3.2 注意力计算的一般形式

注意力函数可以描述为将一个Query和一组Key-Value对映射到输出,其中输出是Value的加权和,权重由Query与对应Key的相似度计算得到。

Attention(Q, K, V) = softmax(similarity(Q, K)) * V

其中similarity函数可以有多种形式:

注意力类型 相似度计算方式 特点
加性注意力 v^T * tanh(W_qQ + W_kK) 灵活性强,可学习参数多
点积注意力 Q * K^T 计算简单,速度快
缩放点积注意力 (Q * K^T) / sqrt(d_k) 防止softmax梯度消失

4. 加性注意力与点积注意力

4.1 加性注意力(Additive Attention)

Bahdanau等人在2014年提出的加性注意力使用一个前馈网络来计算相似度:

class AdditiveAttention(nn.Module):
    """
    加性注意力机制(Bahdanau Attention)
    适用于Query和Key维度不同的情况
    """
    def __init__(self, query_dim, key_dim, hidden_dim):
        super(AdditiveAttention, self).__init__()
        self.W_query = nn.Linear(query_dim, hidden_dim, bias=False)
        self.W_key = nn.Linear(key_dim, hidden_dim, bias=False)
        self.v = nn.Linear(hidden_dim, 1, bias=False)
        
    def forward(self, query, keys, values, mask=None):
        """
        Args:
            query: (batch_size, query_dim) 或 (batch_size, num_queries, query_dim)
            keys: (batch_size, seq_len, key_dim)
            values: (batch_size, seq_len, value_dim)
            mask: (batch_size, seq_len) 可选的掩码
        Returns:
            context: (batch_size, value_dim) 加权后的上下文向量
            attention_weights: (batch_size, seq_len) 注意力权重
        """
        # 扩展query维度以便广播
        if query.dim() == 2:
            query = query.unsqueeze(1)  # (batch_size, 1, query_dim)
        
        # 计算加性分数: v^T * tanh(W_q*Q + W_k*K)
        # query_transformed: (batch_size, 1, hidden_dim)
        query_transformed = self.W_query(query)
        # keys_transformed: (batch_size, seq_len, hidden_dim)
        keys_transformed = self.W_key(keys)
        
        # 广播相加: (batch_size, seq_len, hidden_dim)
        combined = torch.tanh(query_transformed + keys_transformed)
        
        # 计算分数: (batch_size, seq_len, 1) -> (batch_size, seq_len)
        scores = self.v(combined).squeeze(-1)
        
        # 应用掩码(如果提供)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        # Softmax归一化
        attention_weights = torch.softmax(scores, dim=-1)  # (batch_size, seq_len)
        
        # 加权求和: (batch_size, 1, seq_len) @ (batch_size, seq_len, value_dim)
        context = torch.bmm(attention_weights.unsqueeze(1), values).squeeze(1)
        
        return context, attention_weights

4.2 点积注意力与缩放点积注意力

点积注意力计算更加直接高效:

class ScaledDotProductAttention(nn.Module):
    """
    缩放点积注意力机制(Scaled Dot-Product Attention)
    Transformer中使用的标准注意力机制
    """
    def __init__(self, dropout=0.1):
        super(ScaledDotProductAttention, self).__init__()
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: (batch_size, num_heads, seq_len_q, d_k)
            key: (batch_size, num_heads, seq_len_k, d_k)
            value: (batch_size, num_heads, seq_len_v, d_v)
            mask: (batch_size, 1, seq_len_q, seq_len_k) 可选
        Returns:
            output: (batch_size, num_heads, seq_len_q, d_v)
            attention_weights: (batch_size, num_heads, seq_len_q, seq_len_k)
        """
        d_k = query.size(-1)
        
        # 计算点积: Q @ K^T / sqrt(d_k)
        # (batch, heads, seq_q, d_k) @ (batch, heads, d_k, seq_k) 
        # = (batch, heads, seq_q, seq_k)
        scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
        
        # 应用掩码
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        # Softmax归一化
        attention_weights = torch.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        
        # 与Value相乘: (batch, heads, seq_q, seq_k) @ (batch, heads, seq_v, d_v)
        # = (batch, heads, seq_q, d_v)
        output = torch.matmul(attention_weights, value)
        
        return output, attention_weights

缩放因子的重要性:除以sqrt(d_k)可以防止点积结果过大导致softmax梯度消失。当d_k较大时,点积的方差会增大,导致softmax进入饱和区。


5. Self-Attention:注意力的新范式

5.1 从Cross-Attention到Self-Attention

传统的注意力机制(Cross-Attention)中,Query来自解码器,Key和Value来自编码器。而Self-Attention则让Query、Key、Value都来自同一个序列——序列中的每个位置都能"看到"其他所有位置。

核心思想:Self-Attention允许模型直接建模序列中任意两个位置之间的关系,无论它们相距多远。这克服了RNN中远距离依赖难以捕捉的问题。

5.2 Self-Attention的直观理解

考虑句子:“The animal didn’t cross the street because it was too tired.”

这里的"it"指代什么?人类读者很容易理解"it"指的是"animal"而不是"street"。Self-Attention通过计算"it"与句子中所有词的关联度,让模型学会这种指代消解。

5.3 Self-Attention的完整实现

class SelfAttention(nn.Module):
    """
    自注意力机制的实现
    """
    def __init__(self, d_model, dropout=0.1):
        super(SelfAttention, self).__init__()
        self.d_model = d_model
        self.W_query = nn.Linear(d_model, d_model)
        self.W_key = nn.Linear(d_model, d_model)
        self.W_value = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        """
        Args:
            x: (batch_size, seq_len, d_model)
            mask: (batch_size, seq_len, seq_len) 可选的注意力掩码
        Returns:
            output: (batch_size, seq_len, d_model)
            attention_weights: (batch_size, seq_len, seq_len)
        """
        batch_size, seq_len, _ = x.size()
        
        # 生成Q、K、V
        Q = self.W_query(x)  # (batch, seq, d_model)
        K = self.W_key(x)
        V = self.W_value(x)
        
        # 计算注意力分数
        scores = torch.bmm(Q, K.transpose(1, 2)) / torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32))
        # scores: (batch, seq, seq)
        
        # 应用掩码
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        # Softmax和加权
        attention_weights = torch.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        
        output = torch.bmm(attention_weights, V)  # (batch, seq, d_model)
        
        return output, attention_weights

6. 多头注意力机制(Multi-Head Attention)

6.1 为什么需要多头

单一的注意力机制只能捕捉一种类型的关系。但在自然语言中,词与词之间可能存在多种关系:语法关系、语义关系、指代关系等。

核心思想:多头注意力使用多组独立的Q、K、V投影,让模型在不同的"表示子空间"中并行学习不同类型的依赖关系。

6.2 多头注意力的数学表达

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) * W^O
where head_i = Attention(Q*W_i^Q, K*W_i^K, V*W_i^V)

6.3 多头注意力完整实现

class MultiHeadAttention(nn.Module):
    """
    多头注意力机制的完整实现
    """
    def __init__(self, d_model, num_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model必须能被num_heads整除"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # 每个头的维度
        
        # 线性投影层
        self.W_query = nn.Linear(d_model, d_model)
        self.W_key = nn.Linear(d_model, d_model)
        self.W_value = nn.Linear(d_model, d_model)
        self.W_output = nn.Linear(d_model, d_model)
        
        self.attention = ScaledDotProductAttention(dropout)
        self.dropout = nn.Dropout(dropout)
        
    def split_heads(self, x, batch_size):
        """
        将输入分割成多个头
        x: (batch_size, seq_len, d_model)
        return: (batch_size, num_heads, seq_len, d_k)
        """
        seq_len = x.size(1)
        x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
        return x.transpose(1, 2)  # (batch, heads, seq, d_k)
    
    def forward(self, query, key, value, mask=None):
        """
        Args:
            query, key, value: (batch_size, seq_len, d_model)
            mask: (batch_size, 1, seq_len, seq_len) 或兼容形状
        Returns:
            output: (batch_size, seq_len, d_model)
            attention_weights: (batch_size, num_heads, seq_len, seq_len)
        """
        batch_size = query.size(0)
        
        # 线性投影并分割多头
        Q = self.split_heads(self.W_query(query), batch_size)   # (batch, heads, seq_q, d_k)
        K = self.split_heads(self.W_key(key), batch_size)       # (batch, heads, seq_k, d_k)
        V = self.split_heads(self.W_value(value), batch_size)   # (batch, heads, seq_v, d_v)
        
        # 调整mask形状以适配多头
        if mask is not None:
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)  # (batch, 1, seq_q, seq_k)
        
        # 计算注意力
        attn_output, attention_weights = self.attention(Q, K, V, mask)
        # attn_output: (batch, heads, seq_q, d_k)
        
        # 合并多头
        attn_output = attn_output.transpose(1, 2).contiguous()  # (batch, seq_q, heads, d_k)
        attn_output = attn_output.view(batch_size, -1, self.d_model)  # (batch, seq_q, d_model)
        
        # 最终线性投影
        output = self.W_output(attn_output)
        output = self.dropout(output)
        
        return output, attention_weights

7. 注意力可视化与可解释性

7.1 注意力权重的意义

注意力权重矩阵直观地展示了模型在处理序列时"关注"了哪些位置。通过可视化这些权重,我们可以:

  • 理解模型的决策过程
  • 发现模型学到的语言规律
  • 诊断模型的问题

7.2 注意力热力图可视化代码

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

def visualize_attention(attention_weights, tokens=None, title="Attention Heatmap"):
    """
    可视化注意力权重热力图
    
    Args:
        attention_weights: numpy数组,形状为 (seq_len, seq_len) 或 (num_heads, seq_len, seq_len)
        tokens: 可选,token列表用于坐标轴标注
        title: 图表标题
    """
    if isinstance(attention_weights, torch.Tensor):
        attention_weights = attention_weights.detach().cpu().numpy()
    
    # 如果是多头注意力,取平均
    if attention_weights.ndim == 3:
        attention_weights = attention_weights.mean(axis=0)
    
    seq_len = attention_weights.shape[0]
    
    # 如果没有提供tokens,使用索引
    if tokens is None:
        tokens = [f"{i}" for i in range(seq_len)]
    
    plt.figure(figsize=(10, 8))
    sns.heatmap(
        attention_weights,
        xticklabels=tokens,
        yticklabels=tokens,
        cmap="YlOrRd",
        cbar_kws={'label': 'Attention Weight'},
        square=True,
        linewidths=0.5
    )
    plt.title(title, fontsize=14, fontweight='bold')
    plt.xlabel('Key Position', fontsize=12)
    plt.ylabel('Query Position', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()


def visualize_multihead_attention(attention_weights, tokens=None, num_heads_to_show=4):
    """
    可视化多头注意力中不同头的注意力模式
    
    Args:
        attention_weights: (num_heads, seq_len, seq_len)
        tokens: token列表
        num_heads_to_show: 要展示的头数
    """
    if isinstance(attention_weights, torch.Tensor):
        attention_weights = attention_weights.detach().cpu().numpy()
    
    num_heads = attention_weights.shape[0]
    seq_len = attention_weights.shape[1]
    
    if tokens is None:
        tokens = [f"{i}" for i in range(seq_len)]
    
    # 选择要展示的头
    heads_to_show = min(num_heads_to_show, num_heads)
    
    fig, axes = plt.subplots(1, heads_to_show, figsize=(4*heads_to_show, 4))
    if heads_to_show == 1:
        axes = [axes]
    
    for idx, ax in enumerate(axes):
        sns.heatmap(
            attention_weights[idx],
            xticklabels=tokens if idx == heads_to_show-1 else [],
            yticklabels=tokens if idx == 0 else [],
            cmap="YlOrRd",
            ax=ax,
            square=True,
            cbar=False
        )
        ax.set_title(f'Head {idx+1}', fontsize=10)
    
    plt.suptitle('Multi-Head Attention Visualization', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()


# 示例:创建演示数据并可视化
def demo_attention_visualization():
    """
    演示注意力可视化功能
    """
    # 模拟一个简单句子的注意力权重
    sentence = ["The", "cat", "sat", "on", "the", "mat", "."]
    seq_len = len(sentence)
    
    # 创建模拟的注意力权重(模拟对角线和局部关注)
    np.random.seed(42)
    attention = np.random.rand(seq_len, seq_len) * 0.1
    
    # 增强对角线(自注意力通常关注自身)
    for i in range(seq_len):
        attention[i, i] = 0.5
        # 增强相邻词的关注
        if i > 0:
            attention[i, i-1] = 0.2
        if i < seq_len - 1:
            attention[i, i+1] = 0.2
    
    # 归一化
    attention = attention / attention.sum(axis=1, keepdims=True)
    
    # 可视化
    visualize_attention(attention, sentence, "Self-Attention Pattern Example")
    
    # 多头注意力演示
    num_heads = 8
    multihead_attention = np.random.rand(num_heads, seq_len, seq_len)
    
    # 为不同头设置不同模式
    for h in range(num_heads):
        if h % 2 == 0:
            # 偶数头:关注局部
            for i in range(seq_len):
                for j in range(max(0, i-2), min(seq_len, i+3)):
                    multihead_attention[h, i, j] += 0.3
        else:
            # 奇数头:关注全局(模拟长距离依赖)
            multihead_attention[h, :, :] += 0.1
    
    # 归一化
    for h in range(num_heads):
        multihead_attention[h] = multihead_attention[h] / multihead_attention[h].sum(axis=1, keepdims=True)
    
    visualize_multihead_attention(multihead_attention, sentence, num_heads_to_show=4)


# 运行演示
if __name__ == "__main__":
    demo_attention_visualization()

7.3 注意力模式分析

通过观察注意力热力图,我们可以发现一些有趣的模式:

注意力模式 描述 示例
对角线模式 主要关注当前位置或相邻位置 局部特征提取
垂直/水平条纹 某些位置被广泛关注(如标点、特殊token) [CLS]、[SEP] token
块状结构 关注特定短语或句子片段 名词短语识别
稀疏分散 长距离依赖关系 指代消解

8. 高效注意力变体

8.1 标准注意力的计算复杂度问题

标准Self-Attention的计算复杂度为O(n^2),其中n是序列长度。对于长序列(如长文档、高分辨率图像),这成为严重的性能瓶颈。

8.2 Linear Attention

Linear Attention通过核技巧将复杂度从O(n^2)降低到O(n):

class LinearAttention(nn.Module):
    """
    Linear Attention实现
    通过核技巧将O(n^2)复杂度降低到O(n)
    参考: "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
    """
    def __init__(self, d_model, num_heads, dropout=0.1):
        super(LinearAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_query = nn.Linear(d_model, d_model)
        self.W_key = nn.Linear(d_model, d_model)
        self.W_value = nn.Linear(d_model, d_model)
        self.W_output = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        seq_len = query.size(1)
        
        # 投影
        Q = self.W_query(query).view(batch_size, seq_len, self.num_heads, self.d_k)
        K = self.W_key(key).view(batch_size, seq_len, self.num_heads, self.d_k)
        V = self.W_value(value).view(batch_size, seq_len, self.num_heads, self.d_k)
        
        # 应用核函数(elu+1)
        Q = torch.nn.functional.elu(Q) + 1
        K = torch.nn.functional.elu(K) + 1
        
        # 转置为 (batch, heads, seq, d_k)
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)
        
        # Linear Attention核心: (Q @ K^T) @ V = Q @ (K^T @ V)
        # 先计算 K^T @ V: (batch, heads, d_k, d_k)
        KV = torch.matmul(K.transpose(-2, -1), V)
        
        # 再计算 Q @ KV: (batch, heads, seq, d_k)
        Z = 1 / (torch.matmul(Q, K.sum(dim=2).unsqueeze(-1)) + 1e-6)
        output = torch.matmul(Q, KV) * Z
        
        # 合并头
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.W_output(output)
        
        return output, None  # Linear Attention不直接产生可解释的权重

8.3 Sparse Attention

Sparse Attention通过限制每个位置只能关注部分位置来降低复杂度:

class SparseAttention(nn.Module):
    """
    稀疏注意力实现 - Strided Pattern
    每个位置只关注固定间隔的位置
    """
    def __init__(self, d_model, num_heads, stride=4, dropout=0.1):
        super(SparseAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.stride = stride
        self.d_k = d_model // num_heads
        
        self.W_query = nn.Linear(d_model, d_model)
        self.W_key = nn.Linear(d_model, d_model)
        self.W_value = nn.Linear(d_model, d_model)
        self.W_output = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        
    def create_sparse_mask(self, seq_len, device):
        """创建稀疏注意力掩码"""
        mask = torch.zeros(seq_len, seq_len, device=device)
        
        # 每个位置关注:自身、局部窗口、固定间隔位置
        for i in range(seq_len):
            # 局部窗口(前后各2个)
            for j in range(max(0, i-2), min(seq_len, i+3)):
                mask[i, j] = 1
            
            # 固定间隔位置
            for j in range(0, seq_len, self.stride):
                mask[i, j] = 1
        
        return mask
    
    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        seq_len = query.size(1)
        
        # 投影
        Q = self.W_query(query).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_key(key).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_value(value).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        
        # 创建稀疏掩码
        sparse_mask = self.create_sparse_mask(seq_len, query.device)
        sparse_mask = sparse_mask.unsqueeze(0).unsqueeze(0)  # (1, 1, seq, seq)
        
        # 计算注意力分数
        scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32))
        
        # 应用稀疏掩码
        scores = scores.masked_fill(sparse_mask == 0, float('-inf'))
        
        # Softmax和dropout
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = attn_weights.masked_fill(torch.isnan(attn_weights), 0)
        attn_weights = self.dropout(attn_weights)
        
        # 加权求和
        output = torch.matmul(attn_weights, V)
        
        # 合并头
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.W_output(output)
        
        return output, attn_weights

8.4 高效注意力变体对比

变体名称 时间复杂度 空间复杂度 核心思想 适用场景
标准Attention O(n^2) O(n^2) 全连接注意力 短序列
Linear Attention O(n) O(n) 核技巧重排序 长序列生成
Sparse Attention O(n*sqrt(n)) O(n*sqrt(n)) 稀疏连接模式 长文档处理
Linformer O(n) O(n) 低秩近似 长序列分类
Performer O(n) O(n) 正交随机特征 超长序列
Flash Attention O(n^2) O(1) IO感知的分块计算 硬件优化

9. 完整Transformer编码器层实现

class TransformerEncoderLayer(nn.Module):
    """
    完整的Transformer编码器层
    包含:多头注意力 + 前馈网络 + 残差连接 + LayerNorm
    """
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(TransformerEncoderLayer, self).__init__()
        
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )
        
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        # 多头自注意力子层
        attn_output, attn_weights = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))  # 残差连接 + LayerNorm
        
        # 前馈网络子层
        ff_output = self.feed_forward(x)
        x = self.norm2(x + ff_output)  # 残差连接 + LayerNorm
        
        return x, attn_weights


class PositionalEncoding(nn.Module):
    """
    位置编码实现
    """
    def __init__(self, d_model, max_seq_length=5000):
        super(PositionalEncoding, self).__init__()
        
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * 
            (-torch.log(torch.tensor(10000.0)) / d_model)
        )
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        return x + self.pe[:, :x.size(1), :]

10. 避坑小贴士

10.1 常见错误与解决方案

问题1:注意力权重全为NaN

原因:输入数值过大导致softmax溢出,或mask使用不当。

解决方案:

# 确保使用缩放因子
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

# 检查mask是否正确应用
if mask is not None:
    scores = scores.masked_fill(mask == 0, float('-inf'))

问题2:多头维度不匹配

原因:d_model不能被num_heads整除。

解决方案:

assert d_model % num_heads == 0, f"d_model ({d_model}) 必须能被 num_heads ({num_heads}) 整除"
self.d_k = d_model // num_heads

问题3:注意力可视化时权重归一化错误

原因:对已经softmax过的权重再次归一化。

解决方案:注意力权重已经是概率分布(每行和为1),直接可视化即可,不需要额外归一化。

10.2 性能优化建议

  1. 使用Flash Attention:对于长序列,使用Flash Attention可以显著减少内存占用并加速计算
  2. 梯度检查点:对于深层Transformer,使用gradient checkpointing节省显存
  3. 混合精度训练:使用torch.cuda.amp进行FP16训练,加速并节省显存

11. 本章小结和知识点回顾

核心概念回顾

  1. 注意力机制的本质:动态加权机制,让模型学会"该看哪里"

  2. Q-K-V范式

    • Query:查询向量,代表当前要寻找的信息
    • Key:键向量,代表输入各位置的标识
    • Value:值向量,代表输入各位置的实际信息
  3. 主要注意力类型

    • 加性注意力:灵活,适合Q/K维度不同
    • 点积注意力:计算高效,是Transformer的标准选择
    • 缩放点积注意力:通过除以sqrt(d_k)防止梯度消失
  4. 多头注意力:在多个子空间并行学习不同类型的依赖关系

  5. 高效注意力变体:Linear Attention、Sparse Attention等解决长序列问题

关键公式总结

缩放点积注意力: Attention(Q,K,V) = softmax(QK^T / sqrt(d_k))V
多头注意力: MultiHead(Q,K,V) = Concat(head_1,...,head_h)W^O

一句话总结

注意力机制让神经网络拥有了"选择性关注"的能力,而Self-Attention和Multi-Head Attention的出现,让模型能够直接建模序列中任意位置之间的关系,彻底改变了深度学习处理序列数据的方式。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐