【图神经网络】Graph Neural Network详解:处理非欧几里得数据


一、引言

Graph Neural Network (GNN) 是专门用于处理图结构数据的神经网络。社交网络、分子结构、知识图谱等都是典型的图数据。GNN的出现让我们能够对这类复杂关系进行深度学习建模。

本文将详细介绍GNN的核心原理、消息传递机制以及主流变体。

nn_blogs_imgs%2Fgnn_architecture.png


二、GNN核心原理

2.1 消息传递机制

GNN的核心是**消息传递(Message Passing)**机制:

h v ( l + 1 ) = U P D A T E ( h v ( l ) , A G G ( { h u ( l ) : u ∈ N ( v ) } ) ) h_v^{(l+1)} = UPDATE\left(h_v^{(l)}, AGG\left(\{h_u^{(l)} : u \in \mathcal{N}(v)\}\right)\right) hv(l+1)=UPDATE(hv(l),AGG({hu(l):uN(v)}))

其中:

  • h v ( l ) h_v^{(l)} hv(l):节点 v v v 在第 l l l 层的嵌入
  • N ( v ) \mathcal{N}(v) N(v):节点 v v v 的邻居节点集合
  • U P D A T E UPDATE UPDATE:更新函数
  • A G G AGG AGG:聚合函数(通常使用SUM/MEAN/MAX)

2.2 图卷积操作

Graph Convolutional Network (GCN) 的核心公式:

H ( l + 1 ) = σ ( D ~ − 1 2 A ~ D ~ − 1 2 H ( l ) W ( l ) ) H^{(l+1)} = \sigma\left(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} H^{(l)} W^{(l)}\right) H(l+1)=σ(D~21A~D~21H(l)W(l))

其中:

  • A ~ = A + I \tilde{A} = A + I A~=A+I:带自环的邻接矩阵
  • D ~ \tilde{D} D~:度矩阵
  • W ( l ) W^{(l)} W(l):可学习的权重矩阵

2.3 聚合操作对比

聚合方式 公式 特点
SUM ∑ u ∈ N ( v ) h u \sum_{u \in \mathcal{N}(v)} h_u uN(v)hu 保留全部信息
MEAN $\frac{1}{ \mathcal{N}(v)
MAX max ⁡ u ∈ N ( v ) h u \max_{u \in \mathcal{N}(v)} h_u maxuN(v)hu 保留最显著特征
Attention ∑ u ∈ N ( v ) α v u h u \sum_{u \in \mathcal{N}(v)} \alpha_{vu} h_u uN(v)αvuhu 加权聚合

三、实验结果

我们在多个图数据集上进行了节点分类实验:

nn_blogs_imgs%2Fgnn_accuracy.png

数据集 Cora Citeseer Pubmed ogbn-arxiv
GCN 81.5% 70.3% 79.0% 71.7%
GAT 83.0% 72.5% 79.0% 72.1%
GraphSAGE 80.2% 71.8% 78.8% 71.8%
MoNet 81.7% 71.4% 78.8% -

nn_blogs_imgs%2Fgnn_training.png


四、代码实现

4.1 消息传递GNN层

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MessagePassing(nn.Module):
    """Message Passing Neural Network Layer"""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.lin = nn.Linear(in_ch, out_ch)
    
    def propagate(self, edge_index, size=None, **kwargs):
        """Message passing step"""
        pass
    
    def message(self, x_j):
        """Construct messages from neighbor nodes"""
        return x_j

class GCNConv(MessagePassing):
    """Graph Convolutional Network Layer"""
    def __init__(self, in_ch, out_ch, bias=True):
        super().__init__(in_ch, out_ch)
        self.lin = nn.Linear(in_ch, out_ch, bias=False)
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_ch))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
    
    def reset_parameters(self):
        nn.init.xavier_uniform_(self.lin.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)
    
    def forward(self, x, edge_index):
        # Compute degree matrix
        row, col = edge_index
        deg = torch.zeros(row.size(0), device=edge_index.device)
        deg.scatter_add_(0, row, torch.ones_like(row))
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        
        # Normalize
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        
        # Message passing
        x = self.lin(x)
        out = self.propagate(edge_index, x=x, norm=norm)
        
        if self.bias is not None:
            out = out + self.bias
        return out
    
    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j
    
    def propagate(self, edge_index, x, norm):
        out = torch.zeros(x.size(0), x.size(1), device=x.device)
        row, col = edge_index
        out.index_add_(0, row, norm.view(-1, 1) * x[col])
        return out

class GATConv(MessagePassing):
    """Graph Attention Network Layer"""
    def __init__(self, in_ch, out_ch, heads=8, concat=True, negative_slope=0.2):
        super().__init__(in_ch, out_ch)
        self.heads = heads
        self.concat = concat
        self.out_ch = out_ch // heads if concat else out_ch
        
        self.lin = nn.Linear(in_ch, self.heads * self.out_ch, bias=False)
        self.att = nn.Parameter(torch.Tensor(1, heads, 2 * self.out_ch))
        self.bias = nn.Parameter(torch.Tensor(self.heads * self.out_ch if concat else self.out_ch))
        
        self.negative_slope = negative_slope
        self.reset_parameters()
    
    def reset_parameters(self):
        nn.init.xavier_uniform_(self.lin.weight)
        nn.init.xavier_uniform_(self.att)
        nn.init.zeros_(self.bias)
    
    def forward(self, x, edge_index):
        x = self.lin(x).view(-1, self.heads, self.out_ch)
        
        # Compute attention coefficients
        row, col = edge_index
        x_i = x[row]
        x_j = x[col]
        
        cat = torch.cat([x_i, x_j], dim=-1)
        att = (cat * self.att).sum(dim=-1)
        att = F.leaky_relu(att, self.negative_slope)
        
        # Mask attention coefficients
        mask = torch.full_like(row, -9e15, dtype=torch.float)
        mask.scatter_(0, row, att)
        
        # Softmax
        att = F.softmax(mask, dim=0)
        
        # Message passing
        out = att.view(-1, self.heads, 1) * x_j
        out = out.view(-1, self.heads * self.out_ch)
        
        if self.concat:
            out = out + self.bias
        else:
            out = out.mean(dim=1) + self.bias
        
        return out

4.2 完整GNN模型

class GraphNetwork(nn.Module):
    """Complete Graph Neural Network"""
    def __init__(self, in_ch, hidden_ch, out_ch, num_layers=3, dropout=0.5):
        super().__init__()
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        
        # Input layer
        self.convs.append(GCNConv(in_ch, hidden_ch))
        self.norms.append(nn.LayerNorm(hidden_ch))
        
        # Hidden layers
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_ch, hidden_ch))
            self.norms.append(nn.LayerNorm(hidden_ch))
        
        # Output layer
        self.convs.append(GCNConv(hidden_ch, out_ch))
        self.norms.append(nn.LayerNorm(out_ch))
        
        self.dropout = dropout
        self.reset_parameters()
    
    def reset_parameters(self):
        for conv in self.convs:
            if hasattr(conv, 'reset_parameters'):
                conv.reset_parameters()
    
    def forward(self, x, edge_index):
        for i, (conv, norm) in enumerate(zip(self.convs, self.norms)):
            x_prev = x
            x = conv(x, edge_index)
            x = norm(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            
            # Residual connection for hidden layers
            if i > 0 and i < len(self.convs) - 1:
                x = x + x_prev
        
        return F.log_softmax(x, dim=1)

五、GNN变体总结

5.1 空间域方法

模型 聚合方式 特点
GraphSAGE 采样+聚合 可归纳学习,支持minibatch
GAT Attention 自适应邻居权重
PINN 偏执不变 旋转不变性
GEN 边特征 支持边信息

5.2 频谱域方法

模型 滤波器 特点
GCN 切比雪夫多项式 一阶近似
ChebNet 高阶多项式 K-局部化
AGCN 自适应图 学习距离度量

六、总结与展望

GNN的优势

✅ 统一处理各种图结构数据
✅ 可扩展性强,支持mini-batch训练
✅ 表达能力强大,超越传统图算法

挑战与未来方向

  • 深层次GNN:如何避免过平滑
  • 动态图:处理时序变化的图
  • 大规模图:如何高效计算
  • 异构图:统一处理多类型节点和边

参考论文


💡 您的点赞和关注是我创作的动力!

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐