【机器学习精通】第14章 | 图神经网络:处理非欧几里得数据
环境声明
- Python版本:Python 3.12+
- 深度学习框架:PyTorch 2.0+
- 图神经网络库:PyTorch Geometric 2.4+
- 开发工具:PyCharm 或 VS Code
- 硬件要求:GPU(推荐)或 CPU
学习目标
完成本章学习后,你将能够:
- 理解图数据的数学表示与非欧几里得特性
- 掌握GCN、GAT、GraphSAGE等经典GNN模型的原理
- 实现图卷积网络的数学推导与代码实现
- 应用GNN解决社交网络分析、分子分类等实际问题
- 了解Graph Transformer、GraphGPT等前沿研究方向
1. 图数据与图表示学习
1.1 什么是图数据
图(Graph)是一种非欧几里得数据结构,由节点(Node/Vertex)和边(Edge)组成。与图像、文本等规则数据不同,图数据没有固定的网格结构。
一句话总结:图就像一张社交网络——每个人是一个节点,朋友关系是边,没有固定的排列顺序。
1.2 图的基本定义
一个图 G 可以形式化定义为:
G=(V,E) G = (V, E) G=(V,E)
其中:
- V 是节点集合,|V| = N 表示节点数量
- E 是边集合,E ⊆ V × V
- A 是邻接矩阵,A ∈ {0, 1}^{N×N}
1.3 邻接矩阵与度矩阵
邻接矩阵 A:描述节点之间的连接关系
A B C D
A [ 0 1 1 0 ]
B [ 1 0 1 1 ]
C [ 1 1 0 0 ]
D [ 0 1 0 0 ]
度矩阵 D:对角矩阵,D_{ii} = Σ_j A_{ij}
D=diag(d1,d2,...,dN) D = \text{diag}(d_1, d_2, ..., d_N) D=diag(d1,d2,...,dN)
归一化邻接矩阵:
A~=D−1/2AD−1/2 \tilde{A} = D^{-1/2} A D^{-1/2} A~=D−1/2AD−1/2
1.4 图拉普拉斯矩阵
图拉普拉斯矩阵是图信号处理的核心工具:
未归一化拉普拉斯矩阵:
L=D−A L = D - A L=D−A
对称归一化拉普拉斯矩阵:
Lsym=D−1/2LD−1/2=I−D−1/2AD−1/2 L_{sym} = D^{-1/2} L D^{-1/2} = I - D^{-1/2} A D^{-1/2} Lsym=D−1/2LD−1/2=I−D−1/2AD−1/2
随机游走归一化拉普拉斯矩阵:
Lrw=D−1L=I−D−1A L_{rw} = D^{-1} L = I - D^{-1} A Lrw=D−1L=I−D−1A
1.5 图信号处理基础
图信号是定义在图节点上的函数 f: V → R^d。图傅里叶变换基于拉普拉斯矩阵的特征分解:
L=UΛUT L = U \Lambda U^T L=UΛUT
其中 U 是特征向量矩阵,Λ = diag(λ_1, …, λ_N) 是特征值对角矩阵。
图傅里叶变换:
f^=UTf \hat{f} = U^T f f^=UTf
逆图傅里叶变换:
f=Uf^ f = U \hat{f} f=Uf^
1.6 图数据的类型
| 图类型 | 描述 | 应用场景 |
|---|---|---|
| 同质图 | 单一类型节点和边 | 社交网络、引文网络 |
| 异质图 | 多种类型节点和边 | 知识图谱、推荐系统 |
| 动态图 | 随时间变化的图 | 交通网络、交易网络 |
| 超图 | 一条边可连接多个节点 | 团队关系、化合物 |
| 属性图 | 节点和边带有特征 | 分子图、知识图谱 |
2. GCN图卷积网络
2.1 从CNN到GCN
传统卷积神经网络(CNN)在图像上滑动卷积核,利用平移不变性提取局部特征。但图数据没有规则的网格结构,无法直接应用标准卷积。
核心问题:如何在不规则的图结构上定义卷积操作?
2.2 谱域图卷积
谱域方法基于图傅里叶变换,在频域定义卷积:
f∗g=U((UTf)⊙(UTg)) f * g = U((U^T f) \odot (U^T g)) f∗g=U((UTf)⊙(UTg))
其中 ⊙ 表示逐元素乘法。
谱卷积网络(Spectral CNN):
y=U⋅gθ(Λ)⋅UTx y = U \cdot g_\theta(\Lambda) \cdot U^T x y=U⋅gθ(Λ)⋅UTx
其中 g_θ(Λ) 是对角矩阵,表示可学习的滤波器。
2.3 切比雪夫多项式近似
直接计算谱卷积需要 O(N^3) 的特征分解,计算代价高昂。Defferrard等人提出使用切比雪夫多项式近似:
gθ(Λ)=∑k=0KθkTk(Λ~) g_\theta(\Lambda) = \sum_{k=0}^{K} \theta_k T_k(\tilde{\Lambda}) gθ(Λ)=k=0∑KθkTk(Λ~)
其中:
- T_k 是k阶切比雪夫多项式
- \tilde{\Lambda} = 2Λ/λ_{max} - I 是归一化特征值
- K 是多项式阶数,控制局部感受野大小
递推公式:
T0(x)=1 T_0(x) = 1 T0(x)=1
T1(x)=x T_1(x) = x T1(x)=x
Tk(x)=2xTk−1(x)−Tk−2(x) T_k(x) = 2x T_{k-1}(x) - T_{k-2}(x) Tk(x)=2xTk−1(x)−Tk−2(x)
2.4 GCN的数学推导
Kipf和Welling在2016年提出简化的图卷积网络(GCN),取 K=1 并做进一步近似:
第一层传播规则:
H(l+1)=σ(D~−1/2A~D~−1/2H(l)W(l)) H^{(l+1)} = \sigma(\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H^{(l)} W^{(l)}) H(l+1)=σ(D~−1/2A~D~−1/2H(l)W(l))
其中:
- \tilde{A} = A + I_N(添加自环)
- \tilde{D}{ii} = Σ_j \tilde{A}{ij}
- H^{(l)} 是第l层的节点特征矩阵
- W^{(l)} 是可学习的权重矩阵
- σ 是非线性激活函数
完整推导过程:
- 从切比雪夫展开出发,设 K=1
- 假设 λ_{max} ≈ 2,简化特征值缩放
- 添加自环连接 A → A + I
- 对权重矩阵进行参数共享
2.5 GCN的空间域解释
GCN可以解释为邻居特征的加权聚合:
hi(l+1)=σ(∑j∈N(i)∪{i}1d~id~jhj(l)W(l)) h_i^{(l+1)} = \sigma\left(\sum_{j \in \mathcal{N}(i) \cup \{i\}} \frac{1}{\sqrt{\tilde{d}_i \tilde{d}_j}} h_j^{(l)} W^{(l)}\right) hi(l+1)=σ j∈N(i)∪{i}∑d~id~j1hj(l)W(l)
其中归一化系数 1/√(\tilde{d}_i \tilde{d}_j) 起到平衡不同度数节点的作用。
2.6 GCN的局限性
| 局限性 | 说明 | 解决方案 |
|---|---|---|
| 过平滑 | 层数增加导致节点表示趋同 | 残差连接、DropEdge |
| 归纳能力 | 无法处理未见过的图 | GraphSAGE |
| 边权重 | 无法学习边的权重 | GAT |
| 大规模图 | 内存消耗大 | 采样方法 |
3. GAT图注意力网络
3.1 注意力机制回顾
注意力机制允许模型动态地关注输入的不同部分。在序列模型中,注意力计算为:
Attention(Q,K,V)=softmax(QKTdk)V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dkQKT)V
3.2 图注意力层
Velickovic等人(2018)将注意力机制引入图神经网络,提出GAT(Graph Attention Network)。
注意力系数计算:
eij=LeakyReLU(aT[Whi∥Whj]) e_{ij} = \text{LeakyReLU}(a^T [W h_i \| W h_j]) eij=LeakyReLU(aT[Whi∥Whj])
其中:
- W 是共享的线性变换矩阵
- a 是注意力权重向量
- || 表示向量拼接
- 只计算邻居节点 j ∈ N(i)
归一化注意力系数:
αij=exp(eij)∑k∈N(i)exp(eik) \alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})} αij=∑k∈N(i)exp(eik)exp(eij)
输出特征计算:
hi′=σ(∑j∈N(i)αijWhj) h_i' = \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij} W h_j\right) hi′=σ j∈N(i)∑αijWhj
3.3 多头注意力机制
为增加模型表达能力,GAT使用多头注意力:
hi′=∥k=1Kσ(∑j∈N(i)αij(k)W(k)hj) h_i' = \|_{k=1}^{K} \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}^{(k)} W^{(k)} h_j\right) hi′=∥k=1Kσ j∈N(i)∑αij(k)W(k)hj
在最后一层,使用平均代替拼接:
hi′=σ(1K∑k=1K∑j∈N(i)αij(k)W(k)hj) h_i' = \sigma\left(\frac{1}{K} \sum_{k=1}^{K} \sum_{j \in \mathcal{N}(i)} \alpha_{ij}^{(k)} W^{(k)} h_j\right) hi′=σ K1k=1∑Kj∈N(i)∑αij(k)W(k)hj
3.4 GAT的优势
| 特性 | GCN | GAT |
|---|---|---|
| 权重学习 | 固定(基于度) | 自适应(注意力) |
| 计算复杂度 | O( | E |
| 归纳学习 | 否 | 是 |
| 可解释性 | 低 | 高(注意力权重) |
4. GraphSAGE与采样
4.1 归纳式学习的需求
GCN和GAT都是直推式(Transductive)方法,需要所有节点在训练时已知。GraphSAGE(SAmple and agGreGatE)提出归纳式(Inductive)学习框架。
一句话总结:GraphSAGE像一位聪明的学生,学会了"如何聚合邻居信息"的方法,而不是死记硬背每个节点的表示。
4.2 GraphSAGE算法
采样阶段:对每个节点,采样固定数量的邻居
聚合阶段:使用聚合函数合并邻居信息
算法流程:
输入:图 G=(V,E),节点特征 {x_v, ∀v∈V},深度 K,邻居采样数 S_k
输出:所有节点的向量表示 z_v
for k = 1 to K do
for v ∈ V do
// 采样邻居
N_k(v) ← Sample(N(v), S_k)
// 聚合邻居表示
h_{N(v)}^{(k)} ← Aggregate({h_u^{(k-1)}, ∀u ∈ N_k(v)})
// 更新节点表示
h_v^{(k)} ← σ(W^{(k)} · CONCAT(h_v^{(k-1)}, h_{N(v)}^{(k)}))
end
// 归一化
h_v^{(k)} ← h_v^{(k)} / ||h_v^{(k)}||_2
end
z_v ← h_v^{(K)}
4.3 聚合函数
| 聚合函数 | 公式 | 特点 |
|---|---|---|
| Mean | h_{N(v)} = mean({h_u, ∀u∈N(v)}) | 简单高效 |
| LSTM | h_{N(v)} = LSTM(permute({h_u})) | 表达力强,非对称 |
| Pooling | h_{N(v)} = max({σ(W_{pool}h_u+b)}) | 适合离散特征 |
4.4 邻居采样策略
固定数量采样:每层采样固定数量的邻居,控制计算复杂度
重要性采样:根据节点重要性加权采样
层间采样复杂度:
O(∏k=1KSk) O\left(\prod_{k=1}^{K} S_k\right) O(k=1∏KSk)
当 S_k 为常数时,复杂度与图大小无关。
5. 图嵌入与链路预测
5.1 节点嵌入
节点嵌入将节点映射到低维向量空间,保留图的结构信息。
目标函数:
L=∑(u,v)∈Elogσ(zuTzv)+∑(u,v)∉Elogσ(−zuTzv) \mathcal{L} = \sum_{(u,v) \in E} \log \sigma(z_u^T z_v) + \sum_{(u,v) \notin E} \log \sigma(-z_u^T z_v) L=(u,v)∈E∑logσ(zuTzv)+(u,v)∈/E∑logσ(−zuTzv)
5.2 图级表示
对于图分类任务,需要将节点表示聚合为图级表示:
全局平均池化:
hG=1N∑i=1Nhi h_G = \frac{1}{N} \sum_{i=1}^{N} h_i hG=N1i=1∑Nhi
全局最大池化:
hG=maxi=1Nhi h_G = \max_{i=1}^{N} h_i hG=i=1maxNhi
层次化池化(DiffPool):
S(l)=softmax(GNNl,pool(A(l),X(l))) S^{(l)} = \text{softmax}(GNN_{l,pool}(A^{(l)}, X^{(l)})) S(l)=softmax(GNNl,pool(A(l),X(l)))
X(l+1)=S(l)TZ(l) X^{(l+1)} = S^{(l)T} Z^{(l)} X(l+1)=S(l)TZ(l)
A(l+1)=S(l)TA(l)S(l) A^{(l+1)} = S^{(l)T} A^{(l)} S^{(l)} A(l+1)=S(l)TA(l)S(l)
5.3 链路预测
链路预测任务预测两个节点之间是否存在边。
解码器设计:
| 方法 | 公式 | 适用场景 |
|---|---|---|
| 内积 | score(u,v) = z_u^T z_v | 无向图 |
| 拼接MLP | score(u,v) = MLP([z_u | z_v]) | 复杂关系 |
| 双线性 | score(u,v) = z_u^T W z_v | 关系预测 |
负采样策略:
对于大规模图,对所有非边计算损失代价高昂。采用负采样:
L=−logσ(zuTzv)−∑i=1kEvi∼Pn(v)[logσ(−zuTzvi)] \mathcal{L} = -\log \sigma(z_u^T z_v) - \sum_{i=1}^{k} \mathbb{E}_{v_i \sim P_n(v)}[\log \sigma(-z_u^T z_{v_i})] L=−logσ(zuTzv)−i=1∑kEvi∼Pn(v)[logσ(−zuTzvi)]
6. GNN前沿进展
6.1 Graph Transformer
2024-2025年,Graph Transformer成为GNN研究的热点方向。传统Transformer通过自注意力机制捕获全局依赖,与GNN结合后兼具局部结构和全局语义建模能力。
核心思想:将图结构信息融入Transformer架构
代表性工作:
| 模型 | 发表 | 核心创新 |
|---|---|---|
| Graphormer | NeurIPS 2021 | 空间编码、边编码 |
| SAN | ICLR 2022 | 拉普拉斯特征位置编码 |
| GraphGPS | ICLR 2022 | 通用、可扩展、等变 |
| POLYNORMER | ICLR 2024 | 多项式表达能力 |
| Ring-Enhanced GT | AAAI 2025 | 环增强结构编码 |
6.2 GraphGPT
大语言模型(LLM)与图神经网络的融合是2024年的重要趋势。
GraphGPT框架:
- 图-文本对齐:将图结构编码与文本语义空间对齐
- 指令微调:在图任务上微调LLM
- 多模态推理:结合图结构和自然语言进行推理
应用场景:
- 知识图谱问答
- 分子性质预测
- 推荐系统解释
6.3 图基础模型
2024年,图基础模型(Graph Foundation Models)成为研究前沿。
核心挑战:
- 跨域迁移学习
- 零样本图理解
- 大规模预训练
技术路线:
| 方向 | 方法 | 代表工作 |
|---|---|---|
| 预训练GNN | 自监督预训练 + 微调 | GCC, GraphCL |
| GNN+LLM | 图编码器 + 大语言模型 | GraphGPT, InstructGLM |
| 统一架构 | Transformer统一建模 | Graphormer, TokenGT |
6.4 2024-2025研究趋势
- 多模态图学习:融合文本、图像、图结构
- 动态图神经网络:处理时序变化的图
- 可解释GNN:注意力可视化、因果推理
- 高效GNN:模型压缩、量化、蒸馏
- 科学发现:药物发现、材料设计、蛋白质结构预测
7. 实战案例:社交网络节点分类
7.1 案例背景
使用Cora引文网络数据集进行节点分类。Cora包含2708篇机器学习论文,分为7个类别,引文关系构成图结构。
7.2 环境配置
# 安装依赖
# pip install torch torchvision torchaudio
# pip install torch-geometric
# pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.0+cpu.html
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, SAGEConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
import matplotlib.pyplot as plt
import numpy as np
# 设置随机种子
torch.manual_seed(42)
7.3 数据加载与探索
# 加载Cora数据集
dataset = Planetoid(root='data/Cora', name='Cora', transform=NormalizeFeatures())
data = dataset[0]
print(f"数据集名称: {dataset.name}")
print(f"节点数量: {data.num_nodes}")
print(f"边数量: {data.num_edges}")
print(f"特征维度: {dataset.num_features}")
print(f"类别数量: {dataset.num_classes}")
print(f"训练集大小: {data.train_mask.sum().item()}")
print(f"验证集大小: {data.val_mask.sum().item()}")
print(f"测试集大小: {data.test_mask.sum().item()}")
# 数据对象结构
print(f"\n数据对象属性:")
print(f"x (节点特征): {data.x.shape}")
print(f"edge_index (边索引): {data.edge_index.shape}")
print(f"y (节点标签): {data.y.shape}")
7.4 GCN模型实现
class GCN(nn.Module):
"""两层GCN模型"""
def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.5):
super(GCN, self).__init__()
self.conv1 = GCNConv(in_channels, hidden_channels)
self.conv2 = GCNConv(hidden_channels, out_channels)
self.dropout = dropout
def forward(self, x, edge_index):
# 第一层GCN + ReLU + Dropout
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, p=self.dropout, training=self.training)
# 第二层GCN
x = self.conv2(x, edge_index)
return x
# 初始化模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_gcn = GCN(
in_channels=dataset.num_features,
hidden_channels=16,
out_channels=dataset.num_classes,
dropout=0.5
).to(device)
data = data.to(device)
print(f"\nGCN模型结构:\n{model_gcn}")
7.5 GAT模型实现
class GAT(nn.Module):
"""两层GAT模型,使用多头注意力"""
def __init__(self, in_channels, hidden_channels, out_channels,
heads=8, dropout=0.6):
super(GAT, self).__init__()
self.conv1 = GATConv(in_channels, hidden_channels, heads=heads,
dropout=dropout)
# 第二层使用单头注意力,拼接改为平均
self.conv2 = GATConv(hidden_channels * heads, out_channels,
heads=1, concat=False, dropout=dropout)
self.dropout = dropout
def forward(self, x, edge_index):
# 第一层GAT + ELU + Dropout
x = self.conv1(x, edge_index)
x = F.elu(x)
x = F.dropout(x, p=self.dropout, training=self.training)
# 第二层GAT
x = self.conv2(x, edge_index)
return x
# 初始化GAT模型
model_gat = GAT(
in_channels=dataset.num_features,
hidden_channels=8,
out_channels=dataset.num_classes,
heads=8,
dropout=0.6
).to(device)
print(f"GAT模型结构:\n{model_gat}")
7.6 GraphSAGE模型实现
class GraphSAGE(nn.Module):
"""GraphSAGE模型,使用均值聚合"""
def __init__(self, in_channels, hidden_channels, out_channels,
num_layers=2, dropout=0.5):
super(GraphSAGE, self).__init__()
self.convs = nn.ModuleList()
self.convs.append(SAGEConv(in_channels, hidden_channels))
for _ in range(num_layers - 2):
self.convs.append(SAGEConv(hidden_channels, hidden_channels))
self.convs.append(SAGEConv(hidden_channels, out_channels))
self.dropout = dropout
def forward(self, x, edge_index):
for i, conv in enumerate(self.convs[:-1]):
x = conv(x, edge_index)
x = F.relu(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.convs[-1](x, edge_index)
return x
# 初始化GraphSAGE模型
model_sage = GraphSAGE(
in_channels=dataset.num_features,
hidden_channels=16,
out_channels=dataset.num_classes,
num_layers=2,
dropout=0.5
).to(device)
print(f"GraphSAGE模型结构:\n{model_sage}")
7.7 训练与评估函数
def train(model, data, optimizer, criterion):
"""训练一个epoch"""
model.train()
optimizer.zero_grad()
# 前向传播
out = model(data.x, data.edge_index)
# 计算训练集损失
loss = criterion(out[data.train_mask], data.y[data.train_mask])
# 反向传播
loss.backward()
optimizer.step()
return loss.item()
@torch.no_grad()
def evaluate(model, data):
"""评估模型性能"""
model.eval()
out = model(data.x, data.edge_index)
pred = out.argmax(dim=1)
accs = []
for mask in [data.train_mask, data.val_mask, data.test_mask]:
correct = pred[mask].eq(data.y[mask]).sum().item()
acc = correct / mask.sum().item()
accs.append(acc)
return accs
def train_model(model, data, epochs=200, lr=0.01, weight_decay=5e-4):
"""完整训练流程"""
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
best_val_acc = 0
test_acc_at_best = 0
history = {'train': [], 'val': [], 'test': []}
for epoch in range(1, epochs + 1):
loss = train(model, data, optimizer, criterion)
train_acc, val_acc, test_acc = evaluate(model, data)
history['train'].append(train_acc)
history['val'].append(val_acc)
history['test'].append(test_acc)
# 保存最佳模型
if val_acc > best_val_acc:
best_val_acc = val_acc
test_acc_at_best = test_acc
if epoch % 20 == 0:
print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}, "
f"Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}")
return history, best_val_acc, test_acc_at_best
7.8 模型对比实验
print("=" * 60)
print("开始训练GCN模型")
print("=" * 60)
history_gcn, val_acc_gcn, test_acc_gcn = train_model(model_gcn, data)
print("\n" + "=" * 60)
print("开始训练GAT模型")
print("=" * 60)
history_gat, val_acc_gat, test_acc_gat = train_model(model_gat, data, lr=0.005)
print("\n" + "=" * 60)
print("开始训练GraphSAGE模型")
print("=" * 60)
history_sage, val_acc_sage, test_acc_sage = train_model(model_sage, data)
print("\n" + "=" * 60)
print("实验结果对比")
print("=" * 60)
print(f"GCN - 最佳验证准确率: {val_acc_gcn:.4f}, 测试准确率: {test_acc_gcn:.4f}")
print(f"GAT - 最佳验证准确率: {val_acc_gat:.4f}, 测试准确率: {test_acc_gat:.4f}")
print(f"GraphSAGE - 最佳验证准确率: {val_acc_sage:.4f}, 测试准确率: {test_acc_sage:.4f}")
7.9 可视化训练过程
def plot_training_history(histories, labels):
"""绘制训练曲线"""
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for i, split in enumerate(['train', 'val', 'test']):
for history, label in zip(histories, labels):
axes[i].plot(history[split], label=label)
axes[i].set_xlabel('Epoch')
axes[i].set_ylabel('Accuracy')
axes[i].set_title(f'{split.capitalize()} Accuracy')
axes[i].legend()
axes[i].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('gnn_training_comparison.png', dpi=150)
plt.show()
# 绘制对比图
plot_training_history(
[history_gcn, history_gat, history_sage],
['GCN', 'GAT', 'GraphSAGE']
)
7.10 节点嵌入可视化
from sklearn.manifold import TSNE
def visualize_embeddings(model, data, title):
"""使用t-SNE可视化节点嵌入"""
model.eval()
with torch.no_grad():
# 获取最后一层之前的特征
if isinstance(model, GCN):
x = model.conv1(data.x, data.edge_index)
x = F.relu(x)
elif isinstance(model, GAT):
x = model.conv1(data.x, data.edge_index)
x = F.elu(x)
else:
x = model.convs[0](data.x, data.edge_index)
x = F.relu(x)
# t-SNE降维
embeddings = x.cpu().numpy()
tsne = TSNE(n_components=2, random_state=42)
embeddings_2d = tsne.fit_transform(embeddings)
# 绘制散点图
plt.figure(figsize=(10, 8))
colors = plt.cm.tab10(np.linspace(0, 1, dataset.num_classes))
for i in range(dataset.num_classes):
mask = data.y.cpu() == i
plt.scatter(embeddings_2d[mask, 0], embeddings_2d[mask, 1],
c=[colors[i]], label=f'Class {i}', alpha=0.6, s=20)
plt.title(f'{title} - Node Embeddings Visualization')
plt.legend()
plt.savefig(f'{title.lower()}_embeddings.png', dpi=150)
plt.show()
# 可视化各模型的嵌入
visualize_embeddings(model_gcn, data, 'GCN')
visualize_embeddings(model_gat, data, 'GAT')
7.11 注意力权重可视化(GAT)
def visualize_attention(model, data, num_nodes=100):
"""可视化GAT的注意力权重"""
model.eval()
# 获取注意力权重
with torch.no_grad():
# 第一层注意力
_, attn_weights = model.conv1(data.x, data.edge_index, return_attention_weights=True)
edge_index, attn = attn_weights
# 选择部分节点进行可视化
plt.figure(figsize=(12, 8))
# 绘制注意力分布
plt.subplot(1, 2, 1)
plt.hist(attn.cpu().numpy().flatten(), bins=50, alpha=0.7)
plt.xlabel('Attention Weight')
plt.ylabel('Frequency')
plt.title('Distribution of Attention Weights')
# 绘制注意力热力图(前num_nodes个节点)
plt.subplot(1, 2, 2)
adj_attn = torch.zeros(num_nodes, num_nodes)
mask = (edge_index[0] < num_nodes) & (edge_index[1] < num_nodes)
for i in range(mask.sum()):
src = edge_index[0][mask][i].item()
dst = edge_index[1][mask][i].item()
adj_attn[src, dst] = attn[mask][i].mean().item()
plt.imshow(adj_attn.numpy(), cmap='hot', interpolation='nearest')
plt.colorbar(label='Attention Weight')
plt.title(f'Attention Heatmap (First {num_nodes} Nodes)')
plt.xlabel('Target Node')
plt.ylabel('Source Node')
plt.tight_layout()
plt.savefig('gat_attention_visualization.png', dpi=150)
plt.show()
# 可视化GAT注意力
visualize_attention(model_gat, data)
8. 避坑小贴士
8.1 数据预处理陷阱
| 问题 | 现象 | 解决方案 |
|---|---|---|
| 未添加自环 | 节点无法保留自身信息 | 使用 add_self_loops 或 GCNConv 的 improved 参数 |
| 特征未归一化 | 梯度爆炸/消失 | 使用 NormalizeFeatures 或手动标准化 |
| 忽略图的方向性 | 有向图被当作无向图处理 | 检查 edge_index 是否需要反转添加 |
8.2 模型设计陷阱
| 问题 | 现象 | 解决方案 |
|---|---|---|
| 层数过深 | 过平滑,性能下降 | 限制层数在2-3层,使用残差连接 |
| 隐藏维度不当 | 欠拟合或过拟合 | 从16/32开始,根据数据量调整 |
| 忽略Dropout | 训练集过拟合 | 设置 dropout=0.5-0.6 |
8.3 训练过程陷阱
| 问题 | 现象 | 解决方案 |
|---|---|---|
| 学习率过大 | 损失震荡不收敛 | 从0.01或0.005开始,使用学习率衰减 |
| 权重衰减过小 | 模型复杂度过高 | 设置 weight_decay=5e-4 |
| 未使用验证集 | 无法判断最佳模型 | 早停策略,保存验证集最优模型 |
8.4 大规模图处理
对于大规模图(百万级节点),直接加载全图会OOM:
# 使用NeighborLoader进行采样训练
from torch_geometric.loader import NeighborLoader
loader = NeighborLoader(
data,
num_neighbors=[10, 10], # 每层采样10个邻居
batch_size=64,
input_nodes=data.train_mask,
)
for batch in loader:
# batch 包含采样的子图
out = model(batch.x, batch.edge_index)
loss = criterion(out[:batch.batch_size], batch.y[:batch.batch_size])
9. 本章小结
本章系统介绍了图神经网络的核心概念与前沿进展:
核心知识点回顾:
- 图数据表示:邻接矩阵、度矩阵、拉普拉斯矩阵的数学定义
- GCN:谱域卷积到空间域卷积的简化,归一化邻接矩阵的作用
- GAT:注意力机制在图上的应用,多头注意力的实现
- GraphSAGE:归纳式学习框架,邻居采样与聚合策略
- 图嵌入:节点嵌入、图级表示、链路预测任务
前沿进展:
- Graph Transformer:融合全局注意力与图结构
- GraphGPT:大语言模型与图的结合
- 图基础模型:跨域迁移与零样本学习
实战技能:
- 使用PyTorch Geometric构建GNN模型
- 实现GCN、GAT、GraphSAGE三种经典架构
- 社交网络节点分类的完整流程
进一步学习方向:
- 异质图神经网络(HAN, RGCN)
- 动态图神经网络(EvolveGCN, TGAT)
- 图生成模型(VGAE, GraphVAE)
- 图强化学习
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)