环境声明

  • Python版本:Python 3.12+
  • 深度学习框架:PyTorch 2.0+
  • 图神经网络库:PyTorch Geometric 2.4+
  • 开发工具:PyCharm 或 VS Code
  • 硬件要求:GPU(推荐)或 CPU

学习目标

完成本章学习后,你将能够:

  • 理解图数据的数学表示与非欧几里得特性
  • 掌握GCN、GAT、GraphSAGE等经典GNN模型的原理
  • 实现图卷积网络的数学推导与代码实现
  • 应用GNN解决社交网络分析、分子分类等实际问题
  • 了解Graph Transformer、GraphGPT等前沿研究方向

1. 图数据与图表示学习

1.1 什么是图数据

图(Graph)是一种非欧几里得数据结构,由节点(Node/Vertex)和边(Edge)组成。与图像、文本等规则数据不同,图数据没有固定的网格结构。

一句话总结:图就像一张社交网络——每个人是一个节点,朋友关系是边,没有固定的排列顺序。

1.2 图的基本定义

一个图 G 可以形式化定义为:

G=(V,E) G = (V, E) G=(V,E)

其中:

  • V 是节点集合,|V| = N 表示节点数量
  • E 是边集合,E ⊆ V × V
  • A 是邻接矩阵,A ∈ {0, 1}^{N×N}

1.3 邻接矩阵与度矩阵

邻接矩阵 A:描述节点之间的连接关系

    A  B  C  D
A [ 0  1  1  0 ]
B [ 1  0  1  1 ]
C [ 1  1  0  0 ]
D [ 0  1  0  0 ]

度矩阵 D:对角矩阵,D_{ii} = Σ_j A_{ij}

D=diag(d1,d2,...,dN) D = \text{diag}(d_1, d_2, ..., d_N) D=diag(d1,d2,...,dN)

归一化邻接矩阵

A~=D−1/2AD−1/2 \tilde{A} = D^{-1/2} A D^{-1/2} A~=D1/2AD1/2

1.4 图拉普拉斯矩阵

图拉普拉斯矩阵是图信号处理的核心工具:

未归一化拉普拉斯矩阵

L=D−A L = D - A L=DA

对称归一化拉普拉斯矩阵

Lsym=D−1/2LD−1/2=I−D−1/2AD−1/2 L_{sym} = D^{-1/2} L D^{-1/2} = I - D^{-1/2} A D^{-1/2} Lsym=D1/2LD1/2=ID1/2AD1/2

随机游走归一化拉普拉斯矩阵

Lrw=D−1L=I−D−1A L_{rw} = D^{-1} L = I - D^{-1} A Lrw=D1L=ID1A

1.5 图信号处理基础

图信号是定义在图节点上的函数 f: V → R^d。图傅里叶变换基于拉普拉斯矩阵的特征分解:

L=UΛUT L = U \Lambda U^T L=UΛUT

其中 U 是特征向量矩阵,Λ = diag(λ_1, …, λ_N) 是特征值对角矩阵。

图傅里叶变换

f^=UTf \hat{f} = U^T f f^=UTf

逆图傅里叶变换

f=Uf^ f = U \hat{f} f=Uf^

1.6 图数据的类型

图类型 描述 应用场景
同质图 单一类型节点和边 社交网络、引文网络
异质图 多种类型节点和边 知识图谱、推荐系统
动态图 随时间变化的图 交通网络、交易网络
超图 一条边可连接多个节点 团队关系、化合物
属性图 节点和边带有特征 分子图、知识图谱

2. GCN图卷积网络

2.1 从CNN到GCN

传统卷积神经网络(CNN)在图像上滑动卷积核,利用平移不变性提取局部特征。但图数据没有规则的网格结构,无法直接应用标准卷积。

核心问题:如何在不规则的图结构上定义卷积操作?

2.2 谱域图卷积

谱域方法基于图傅里叶变换,在频域定义卷积:

f∗g=U((UTf)⊙(UTg)) f * g = U((U^T f) \odot (U^T g)) fg=U((UTf)(UTg))

其中 ⊙ 表示逐元素乘法。

谱卷积网络(Spectral CNN)

y=U⋅gθ(Λ)⋅UTx y = U \cdot g_\theta(\Lambda) \cdot U^T x y=Ugθ(Λ)UTx

其中 g_θ(Λ) 是对角矩阵,表示可学习的滤波器。

2.3 切比雪夫多项式近似

直接计算谱卷积需要 O(N^3) 的特征分解,计算代价高昂。Defferrard等人提出使用切比雪夫多项式近似:

gθ(Λ)=∑k=0KθkTk(Λ~) g_\theta(\Lambda) = \sum_{k=0}^{K} \theta_k T_k(\tilde{\Lambda}) gθ(Λ)=k=0KθkTk(Λ~)

其中:

  • T_k 是k阶切比雪夫多项式
  • \tilde{\Lambda} = 2Λ/λ_{max} - I 是归一化特征值
  • K 是多项式阶数,控制局部感受野大小

递推公式

T0(x)=1 T_0(x) = 1 T0(x)=1
T1(x)=x T_1(x) = x T1(x)=x
Tk(x)=2xTk−1(x)−Tk−2(x) T_k(x) = 2x T_{k-1}(x) - T_{k-2}(x) Tk(x)=2xTk1(x)Tk2(x)

2.4 GCN的数学推导

Kipf和Welling在2016年提出简化的图卷积网络(GCN),取 K=1 并做进一步近似:

第一层传播规则

H(l+1)=σ(D~−1/2A~D~−1/2H(l)W(l)) H^{(l+1)} = \sigma(\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H^{(l)} W^{(l)}) H(l+1)=σ(D~1/2A~D~1/2H(l)W(l))

其中:

  • \tilde{A} = A + I_N(添加自环)
  • \tilde{D}{ii} = Σ_j \tilde{A}{ij}
  • H^{(l)} 是第l层的节点特征矩阵
  • W^{(l)} 是可学习的权重矩阵
  • σ 是非线性激活函数

完整推导过程

  1. 从切比雪夫展开出发,设 K=1
  2. 假设 λ_{max} ≈ 2,简化特征值缩放
  3. 添加自环连接 A → A + I
  4. 对权重矩阵进行参数共享

2.5 GCN的空间域解释

GCN可以解释为邻居特征的加权聚合:

hi(l+1)=σ(∑j∈N(i)∪{i}1d~id~jhj(l)W(l)) h_i^{(l+1)} = \sigma\left(\sum_{j \in \mathcal{N}(i) \cup \{i\}} \frac{1}{\sqrt{\tilde{d}_i \tilde{d}_j}} h_j^{(l)} W^{(l)}\right) hi(l+1)=σ jN(i){i}d~id~j 1hj(l)W(l)

其中归一化系数 1/√(\tilde{d}_i \tilde{d}_j) 起到平衡不同度数节点的作用。

2.6 GCN的局限性

局限性 说明 解决方案
过平滑 层数增加导致节点表示趋同 残差连接、DropEdge
归纳能力 无法处理未见过的图 GraphSAGE
边权重 无法学习边的权重 GAT
大规模图 内存消耗大 采样方法

3. GAT图注意力网络

3.1 注意力机制回顾

注意力机制允许模型动态地关注输入的不同部分。在序列模型中,注意力计算为:

Attention(Q,K,V)=softmax(QKTdk)V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V

3.2 图注意力层

Velickovic等人(2018)将注意力机制引入图神经网络,提出GAT(Graph Attention Network)。

注意力系数计算

eij=LeakyReLU(aT[Whi∥Whj]) e_{ij} = \text{LeakyReLU}(a^T [W h_i \| W h_j]) eij=LeakyReLU(aT[WhiWhj])

其中:

  • W 是共享的线性变换矩阵
  • a 是注意力权重向量
  • || 表示向量拼接
  • 只计算邻居节点 j ∈ N(i)

归一化注意力系数

αij=exp⁡(eij)∑k∈N(i)exp⁡(eik) \alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})} αij=kN(i)exp(eik)exp(eij)

输出特征计算

hi′=σ(∑j∈N(i)αijWhj) h_i' = \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij} W h_j\right) hi=σ jN(i)αijWhj

3.3 多头注意力机制

为增加模型表达能力,GAT使用多头注意力:

hi′=∥k=1Kσ(∑j∈N(i)αij(k)W(k)hj) h_i' = \|_{k=1}^{K} \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}^{(k)} W^{(k)} h_j\right) hi=k=1Kσ jN(i)αij(k)W(k)hj

在最后一层,使用平均代替拼接:

hi′=σ(1K∑k=1K∑j∈N(i)αij(k)W(k)hj) h_i' = \sigma\left(\frac{1}{K} \sum_{k=1}^{K} \sum_{j \in \mathcal{N}(i)} \alpha_{ij}^{(k)} W^{(k)} h_j\right) hi=σ K1k=1KjN(i)αij(k)W(k)hj

3.4 GAT的优势

特性 GCN GAT
权重学习 固定(基于度) 自适应(注意力)
计算复杂度 O( E
归纳学习
可解释性 高(注意力权重)

4. GraphSAGE与采样

4.1 归纳式学习的需求

GCN和GAT都是直推式(Transductive)方法,需要所有节点在训练时已知。GraphSAGE(SAmple and agGreGatE)提出归纳式(Inductive)学习框架。

一句话总结:GraphSAGE像一位聪明的学生,学会了"如何聚合邻居信息"的方法,而不是死记硬背每个节点的表示。

4.2 GraphSAGE算法

采样阶段:对每个节点,采样固定数量的邻居

聚合阶段:使用聚合函数合并邻居信息

算法流程

输入:图 G=(V,E),节点特征 {x_v, ∀v∈V},深度 K,邻居采样数 S_k
输出:所有节点的向量表示 z_v

for k = 1 to K do
    for v ∈ V do
        // 采样邻居
        N_k(v) ← Sample(N(v), S_k)
        // 聚合邻居表示
        h_{N(v)}^{(k)} ← Aggregate({h_u^{(k-1)}, ∀u ∈ N_k(v)})
        // 更新节点表示
        h_v^{(k)} ← σ(W^{(k)} · CONCAT(h_v^{(k-1)}, h_{N(v)}^{(k)}))
    end
    // 归一化
    h_v^{(k)} ← h_v^{(k)} / ||h_v^{(k)}||_2
end

z_v ← h_v^{(K)}

4.3 聚合函数

聚合函数 公式 特点
Mean h_{N(v)} = mean({h_u, ∀u∈N(v)}) 简单高效
LSTM h_{N(v)} = LSTM(permute({h_u})) 表达力强,非对称
Pooling h_{N(v)} = max({σ(W_{pool}h_u+b)}) 适合离散特征

4.4 邻居采样策略

固定数量采样:每层采样固定数量的邻居,控制计算复杂度

重要性采样:根据节点重要性加权采样

层间采样复杂度

O(∏k=1KSk) O\left(\prod_{k=1}^{K} S_k\right) O(k=1KSk)

当 S_k 为常数时,复杂度与图大小无关。


5. 图嵌入与链路预测

5.1 节点嵌入

节点嵌入将节点映射到低维向量空间,保留图的结构信息。

目标函数

L=∑(u,v)∈Elog⁡σ(zuTzv)+∑(u,v)∉Elog⁡σ(−zuTzv) \mathcal{L} = \sum_{(u,v) \in E} \log \sigma(z_u^T z_v) + \sum_{(u,v) \notin E} \log \sigma(-z_u^T z_v) L=(u,v)Elogσ(zuTzv)+(u,v)/Elogσ(zuTzv)

5.2 图级表示

对于图分类任务,需要将节点表示聚合为图级表示:

全局平均池化

hG=1N∑i=1Nhi h_G = \frac{1}{N} \sum_{i=1}^{N} h_i hG=N1i=1Nhi

全局最大池化

hG=max⁡i=1Nhi h_G = \max_{i=1}^{N} h_i hG=i=1maxNhi

层次化池化(DiffPool)

S(l)=softmax(GNNl,pool(A(l),X(l))) S^{(l)} = \text{softmax}(GNN_{l,pool}(A^{(l)}, X^{(l)})) S(l)=softmax(GNNl,pool(A(l),X(l)))
X(l+1)=S(l)TZ(l) X^{(l+1)} = S^{(l)T} Z^{(l)} X(l+1)=S(l)TZ(l)
A(l+1)=S(l)TA(l)S(l) A^{(l+1)} = S^{(l)T} A^{(l)} S^{(l)} A(l+1)=S(l)TA(l)S(l)

5.3 链路预测

链路预测任务预测两个节点之间是否存在边。

解码器设计

方法 公式 适用场景
内积 score(u,v) = z_u^T z_v 无向图
拼接MLP score(u,v) = MLP([z_u | z_v]) 复杂关系
双线性 score(u,v) = z_u^T W z_v 关系预测

负采样策略

对于大规模图,对所有非边计算损失代价高昂。采用负采样:

L=−log⁡σ(zuTzv)−∑i=1kEvi∼Pn(v)[log⁡σ(−zuTzvi)] \mathcal{L} = -\log \sigma(z_u^T z_v) - \sum_{i=1}^{k} \mathbb{E}_{v_i \sim P_n(v)}[\log \sigma(-z_u^T z_{v_i})] L=logσ(zuTzv)i=1kEviPn(v)[logσ(zuTzvi)]


6. GNN前沿进展

6.1 Graph Transformer

2024-2025年,Graph Transformer成为GNN研究的热点方向。传统Transformer通过自注意力机制捕获全局依赖,与GNN结合后兼具局部结构和全局语义建模能力。

核心思想:将图结构信息融入Transformer架构

代表性工作

模型 发表 核心创新
Graphormer NeurIPS 2021 空间编码、边编码
SAN ICLR 2022 拉普拉斯特征位置编码
GraphGPS ICLR 2022 通用、可扩展、等变
POLYNORMER ICLR 2024 多项式表达能力
Ring-Enhanced GT AAAI 2025 环增强结构编码

6.2 GraphGPT

大语言模型(LLM)与图神经网络的融合是2024年的重要趋势。

GraphGPT框架

  1. 图-文本对齐:将图结构编码与文本语义空间对齐
  2. 指令微调:在图任务上微调LLM
  3. 多模态推理:结合图结构和自然语言进行推理

应用场景

  • 知识图谱问答
  • 分子性质预测
  • 推荐系统解释

6.3 图基础模型

2024年,图基础模型(Graph Foundation Models)成为研究前沿。

核心挑战

  • 跨域迁移学习
  • 零样本图理解
  • 大规模预训练

技术路线

方向 方法 代表工作
预训练GNN 自监督预训练 + 微调 GCC, GraphCL
GNN+LLM 图编码器 + 大语言模型 GraphGPT, InstructGLM
统一架构 Transformer统一建模 Graphormer, TokenGT

6.4 2024-2025研究趋势

  1. 多模态图学习:融合文本、图像、图结构
  2. 动态图神经网络:处理时序变化的图
  3. 可解释GNN:注意力可视化、因果推理
  4. 高效GNN:模型压缩、量化、蒸馏
  5. 科学发现:药物发现、材料设计、蛋白质结构预测

7. 实战案例:社交网络节点分类

7.1 案例背景

使用Cora引文网络数据集进行节点分类。Cora包含2708篇机器学习论文,分为7个类别,引文关系构成图结构。

7.2 环境配置

# 安装依赖
# pip install torch torchvision torchaudio
# pip install torch-geometric
# pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.0+cpu.html

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, SAGEConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
import matplotlib.pyplot as plt
import numpy as np

# 设置随机种子
torch.manual_seed(42)

7.3 数据加载与探索

# 加载Cora数据集
dataset = Planetoid(root='data/Cora', name='Cora', transform=NormalizeFeatures())

data = dataset[0]

print(f"数据集名称: {dataset.name}")
print(f"节点数量: {data.num_nodes}")
print(f"边数量: {data.num_edges}")
print(f"特征维度: {dataset.num_features}")
print(f"类别数量: {dataset.num_classes}")
print(f"训练集大小: {data.train_mask.sum().item()}")
print(f"验证集大小: {data.val_mask.sum().item()}")
print(f"测试集大小: {data.test_mask.sum().item()}")

# 数据对象结构
print(f"\n数据对象属性:")
print(f"x (节点特征): {data.x.shape}")
print(f"edge_index (边索引): {data.edge_index.shape}")
print(f"y (节点标签): {data.y.shape}")

7.4 GCN模型实现

class GCN(nn.Module):
    """两层GCN模型"""
    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.5):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        self.dropout = dropout
        
    def forward(self, x, edge_index):
        # 第一层GCN + ReLU + Dropout
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        
        # 第二层GCN
        x = self.conv2(x, edge_index)
        
        return x

# 初始化模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_gcn = GCN(
    in_channels=dataset.num_features,
    hidden_channels=16,
    out_channels=dataset.num_classes,
    dropout=0.5
).to(device)

data = data.to(device)
print(f"\nGCN模型结构:\n{model_gcn}")

7.5 GAT模型实现

class GAT(nn.Module):
    """两层GAT模型,使用多头注意力"""
    def __init__(self, in_channels, hidden_channels, out_channels, 
                 heads=8, dropout=0.6):
        super(GAT, self).__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, 
                             dropout=dropout)
        # 第二层使用单头注意力,拼接改为平均
        self.conv2 = GATConv(hidden_channels * heads, out_channels, 
                             heads=1, concat=False, dropout=dropout)
        self.dropout = dropout
        
    def forward(self, x, edge_index):
        # 第一层GAT + ELU + Dropout
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        
        # 第二层GAT
        x = self.conv2(x, edge_index)
        
        return x

# 初始化GAT模型
model_gat = GAT(
    in_channels=dataset.num_features,
    hidden_channels=8,
    out_channels=dataset.num_classes,
    heads=8,
    dropout=0.6
).to(device)

print(f"GAT模型结构:\n{model_gat}")

7.6 GraphSAGE模型实现

class GraphSAGE(nn.Module):
    """GraphSAGE模型,使用均值聚合"""
    def __init__(self, in_channels, hidden_channels, out_channels, 
                 num_layers=2, dropout=0.5):
        super(GraphSAGE, self).__init__()
        self.convs = nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))
        self.dropout = dropout
        
    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, edge_index)
        return x

# 初始化GraphSAGE模型
model_sage = GraphSAGE(
    in_channels=dataset.num_features,
    hidden_channels=16,
    out_channels=dataset.num_classes,
    num_layers=2,
    dropout=0.5
).to(device)

print(f"GraphSAGE模型结构:\n{model_sage}")

7.7 训练与评估函数

def train(model, data, optimizer, criterion):
    """训练一个epoch"""
    model.train()
    optimizer.zero_grad()
    
    # 前向传播
    out = model(data.x, data.edge_index)
    
    # 计算训练集损失
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    
    # 反向传播
    loss.backward()
    optimizer.step()
    
    return loss.item()

@torch.no_grad()
def evaluate(model, data):
    """评估模型性能"""
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)
    
    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        correct = pred[mask].eq(data.y[mask]).sum().item()
        acc = correct / mask.sum().item()
        accs.append(acc)
    
    return accs

def train_model(model, data, epochs=200, lr=0.01, weight_decay=5e-4):
    """完整训练流程"""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()
    
    best_val_acc = 0
    test_acc_at_best = 0
    history = {'train': [], 'val': [], 'test': []}
    
    for epoch in range(1, epochs + 1):
        loss = train(model, data, optimizer, criterion)
        train_acc, val_acc, test_acc = evaluate(model, data)
        
        history['train'].append(train_acc)
        history['val'].append(val_acc)
        history['test'].append(test_acc)
        
        # 保存最佳模型
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            test_acc_at_best = test_acc
        
        if epoch % 20 == 0:
            print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}, "
                  f"Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}")
    
    return history, best_val_acc, test_acc_at_best

7.8 模型对比实验

print("=" * 60)
print("开始训练GCN模型")
print("=" * 60)
history_gcn, val_acc_gcn, test_acc_gcn = train_model(model_gcn, data)

print("\n" + "=" * 60)
print("开始训练GAT模型")
print("=" * 60)
history_gat, val_acc_gat, test_acc_gat = train_model(model_gat, data, lr=0.005)

print("\n" + "=" * 60)
print("开始训练GraphSAGE模型")
print("=" * 60)
history_sage, val_acc_sage, test_acc_sage = train_model(model_sage, data)

print("\n" + "=" * 60)
print("实验结果对比")
print("=" * 60)
print(f"GCN       - 最佳验证准确率: {val_acc_gcn:.4f}, 测试准确率: {test_acc_gcn:.4f}")
print(f"GAT       - 最佳验证准确率: {val_acc_gat:.4f}, 测试准确率: {test_acc_gat:.4f}")
print(f"GraphSAGE - 最佳验证准确率: {val_acc_sage:.4f}, 测试准确率: {test_acc_sage:.4f}")

7.9 可视化训练过程

def plot_training_history(histories, labels):
    """绘制训练曲线"""
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    
    for i, split in enumerate(['train', 'val', 'test']):
        for history, label in zip(histories, labels):
            axes[i].plot(history[split], label=label)
        axes[i].set_xlabel('Epoch')
        axes[i].set_ylabel('Accuracy')
        axes[i].set_title(f'{split.capitalize()} Accuracy')
        axes[i].legend()
        axes[i].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig('gnn_training_comparison.png', dpi=150)
    plt.show()

# 绘制对比图
plot_training_history(
    [history_gcn, history_gat, history_sage],
    ['GCN', 'GAT', 'GraphSAGE']
)

7.10 节点嵌入可视化

from sklearn.manifold import TSNE

def visualize_embeddings(model, data, title):
    """使用t-SNE可视化节点嵌入"""
    model.eval()
    with torch.no_grad():
        # 获取最后一层之前的特征
        if isinstance(model, GCN):
            x = model.conv1(data.x, data.edge_index)
            x = F.relu(x)
        elif isinstance(model, GAT):
            x = model.conv1(data.x, data.edge_index)
            x = F.elu(x)
        else:
            x = model.convs[0](data.x, data.edge_index)
            x = F.relu(x)
    
    # t-SNE降维
    embeddings = x.cpu().numpy()
    tsne = TSNE(n_components=2, random_state=42)
    embeddings_2d = tsne.fit_transform(embeddings)
    
    # 绘制散点图
    plt.figure(figsize=(10, 8))
    colors = plt.cm.tab10(np.linspace(0, 1, dataset.num_classes))
    
    for i in range(dataset.num_classes):
        mask = data.y.cpu() == i
        plt.scatter(embeddings_2d[mask, 0], embeddings_2d[mask, 1], 
                   c=[colors[i]], label=f'Class {i}', alpha=0.6, s=20)
    
    plt.title(f'{title} - Node Embeddings Visualization')
    plt.legend()
    plt.savefig(f'{title.lower()}_embeddings.png', dpi=150)
    plt.show()

# 可视化各模型的嵌入
visualize_embeddings(model_gcn, data, 'GCN')
visualize_embeddings(model_gat, data, 'GAT')

7.11 注意力权重可视化(GAT)

def visualize_attention(model, data, num_nodes=100):
    """可视化GAT的注意力权重"""
    model.eval()
    
    # 获取注意力权重
    with torch.no_grad():
        # 第一层注意力
        _, attn_weights = model.conv1(data.x, data.edge_index, return_attention_weights=True)
    
    edge_index, attn = attn_weights
    
    # 选择部分节点进行可视化
    plt.figure(figsize=(12, 8))
    
    # 绘制注意力分布
    plt.subplot(1, 2, 1)
    plt.hist(attn.cpu().numpy().flatten(), bins=50, alpha=0.7)
    plt.xlabel('Attention Weight')
    plt.ylabel('Frequency')
    plt.title('Distribution of Attention Weights')
    
    # 绘制注意力热力图(前num_nodes个节点)
    plt.subplot(1, 2, 2)
    adj_attn = torch.zeros(num_nodes, num_nodes)
    mask = (edge_index[0] < num_nodes) & (edge_index[1] < num_nodes)
    
    for i in range(mask.sum()):
        src = edge_index[0][mask][i].item()
        dst = edge_index[1][mask][i].item()
        adj_attn[src, dst] = attn[mask][i].mean().item()
    
    plt.imshow(adj_attn.numpy(), cmap='hot', interpolation='nearest')
    plt.colorbar(label='Attention Weight')
    plt.title(f'Attention Heatmap (First {num_nodes} Nodes)')
    plt.xlabel('Target Node')
    plt.ylabel('Source Node')
    
    plt.tight_layout()
    plt.savefig('gat_attention_visualization.png', dpi=150)
    plt.show()

# 可视化GAT注意力
visualize_attention(model_gat, data)

8. 避坑小贴士

8.1 数据预处理陷阱

问题 现象 解决方案
未添加自环 节点无法保留自身信息 使用 add_self_loops 或 GCNConv 的 improved 参数
特征未归一化 梯度爆炸/消失 使用 NormalizeFeatures 或手动标准化
忽略图的方向性 有向图被当作无向图处理 检查 edge_index 是否需要反转添加

8.2 模型设计陷阱

问题 现象 解决方案
层数过深 过平滑,性能下降 限制层数在2-3层,使用残差连接
隐藏维度不当 欠拟合或过拟合 从16/32开始,根据数据量调整
忽略Dropout 训练集过拟合 设置 dropout=0.5-0.6

8.3 训练过程陷阱

问题 现象 解决方案
学习率过大 损失震荡不收敛 从0.01或0.005开始,使用学习率衰减
权重衰减过小 模型复杂度过高 设置 weight_decay=5e-4
未使用验证集 无法判断最佳模型 早停策略,保存验证集最优模型

8.4 大规模图处理

对于大规模图(百万级节点),直接加载全图会OOM:

# 使用NeighborLoader进行采样训练
from torch_geometric.loader import NeighborLoader

loader = NeighborLoader(
    data,
    num_neighbors=[10, 10],  # 每层采样10个邻居
    batch_size=64,
    input_nodes=data.train_mask,
)

for batch in loader:
    # batch 包含采样的子图
    out = model(batch.x, batch.edge_index)
    loss = criterion(out[:batch.batch_size], batch.y[:batch.batch_size])

9. 本章小结

本章系统介绍了图神经网络的核心概念与前沿进展:

核心知识点回顾

  1. 图数据表示:邻接矩阵、度矩阵、拉普拉斯矩阵的数学定义
  2. GCN:谱域卷积到空间域卷积的简化,归一化邻接矩阵的作用
  3. GAT:注意力机制在图上的应用,多头注意力的实现
  4. GraphSAGE:归纳式学习框架,邻居采样与聚合策略
  5. 图嵌入:节点嵌入、图级表示、链路预测任务

前沿进展

  • Graph Transformer:融合全局注意力与图结构
  • GraphGPT:大语言模型与图的结合
  • 图基础模型:跨域迁移与零样本学习

实战技能

  • 使用PyTorch Geometric构建GNN模型
  • 实现GCN、GAT、GraphSAGE三种经典架构
  • 社交网络节点分类的完整流程

进一步学习方向

  • 异质图神经网络(HAN, RGCN)
  • 动态图神经网络(EvolveGCN, TGAT)
  • 图生成模型(VGAE, GraphVAE)
  • 图强化学习

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐