1 要点

标题: Multimodal brain age estimation using interpretable adaptive population-graph learning
中文: 自适应人口图学习用于多模态脑龄估计
作者: Kyriaki-Margarita Bintsi, Vasileios Baltatzis, Rolandos Alexandros Potamias, Alexander Hammers, Daniel Rueckert
机构: Imperial College London / King’s College London / Technical University of Munich
会议: MICCAI 2023
代码:https://github.com/bintsi/adaptive-graph-learning

研究背景
人口图(Population Graph)将每个被试作为节点、被试间的相似关系作为边,配合GCN在多模态医学影像任务中展现了优势。然而,人口图通常依赖非影像信息(如年龄、性别、认知分数),通过固定的相似度度量手动、静态地构建。GCN的性能高度依赖图结构,如果图的同配性(homophily,即相似节点倾向于相互连接的程度)很低,GCN可能被简单MLP反超。现有自适应图学习方法存在三类典型缺陷:仅用影像特征构建边(丢弃非影像信息)、边已预先修剪仅能微调权重、或构建过程不可解释。

研究目标:构建一个端到端的自适应人口图学习框架,其利用注意力机制自动学习多模态(影像 + 非影像)表型的权重、基于加权表型动态生成图结构,同时通过可视化注意力权重提供完整的可解释性。

关键技术

  1. 注意力权重提取:MLP以每个被试的非影像表型和影像表型为输入,输出每个表型的注意力权重,随后跨被试平均并归一化为全局权重,并对所有节点统一适用。
  2. 边提取:先用注意力权重对表型加权,再计算边概率,最后通过Gumbel-Top-k技巧为每个节点随机采样 k k k条边,维持稀疏 k k k度图。
  3. 图损失函数(回归):图是稀疏的、边具有离散性质,无法直接通过反向传播训练。因此需要奖励导致正确预测的边,惩罚导致错误预测的边

数据集UK Biobank(UKBB,英国生物银行),约6,500例(47–81 岁),含68个神经影像表型(MRI+ DTI)和20个非影像表型(来源:Cole 2020年确定的90个脑龄相关特征),按75 / 5 / 20划分训练/验证/测试集。

2 引言

2.1 脑龄估计与人口图

健康脑衰老遵循特定的演变模式,但AD(阿尔茨海默病)、帕金森病、精神分裂症等疾病则伴随异常加速的衰老。脑龄差(Brain Age Gap,BAG=估计脑龄−实际年龄)反映个体偏离健康衰老轨迹的程度,已成为神经退行性疾病的重要生物标志物。

人口图天然适合整合多模态信息:将神经影像作为节点特征,通过相似度度量在不同被试之间建立边。但在医学应用中,图的构建并非易事,因为定义两个被试相似的方式可以多种多样。GCN的性能高度依赖图结构,这与其同配性直接相关:同配性刻画的是节点倾向于与特征相似的节点连接的程度。同配性若极低,一个完全抛弃图结构的MLP就能胜过GCN。

2.2 现有自适应图学习的局限

自适应图学习通过在训练过程中同时学习图结构来解决上述问题。但在医学领域,现有方法各自存在缺陷:

  • [13, 14, 7]:边仅基于影像特征连接,完全放弃非影像信息
  • [11]:使用非影像信息决定连接,但边已预先修剪,仅能调整权重,无法改变拓扑
  • [26]:图是动态的,但并非从数据中学出来的
  • 更根本的共性问题:上述方法的边连接均不可解释,无从得知哪些特征驱动了节点之间的连接

2.3 贡献

  1. 将影像与非影像信息统一纳入基于注意力的框架,自适应学习优化的图结构
  2. 提出新型图损失函数,首次实现对脑龄回归(而非常用的分类)的端到端训练
  3. 注意力机制天然提供可解释性,可对所有表型按任务重要性排序

3 方法

3.1 总体框架

给定 N N N个被试、 M M M个影像特征 X = [ x 1 , … , x N ] ∈ R N × M X = [x_1, \ldots, x_N] \in \mathbb{R}^{N \times M} X=[x1,,xN]RN×M和标签 y ∈ R N y \in \mathbb{R}^N yRN。人口图 G = { V , E } G = \{V, E\} G={V,E}中每个被试对应一个节点,目标是为脑龄估计任务学习一组最优的边集 E E E

边构建使用两组表型特征

  • 非影像表型 q i ∈ R Q q_i \in \mathbb{R}^Q qiRQ Q = 20 Q = 20 Q=20:覆盖人口学、认知、生活方式、生物医学指标)
  • 影像表型 s i ⊆ x i ∈ R S s_i \subseteq x_i \in \mathbb{R}^S sixiRS S = 68 S= 68 S=68:来自结构MRI和DTI的衍生测量,从Cole 2020[5]确定的90个脑龄相关特征中选取)

设计选择:完整的影像特征 X X X用作 G C N GCN GCN的节点特征,但仅取其子集 S S S(影像表型)参与边构建。非影像特征 q i q_i qi仅用于边构建,不进入 G C N GCN GCN节点特征,这确保了两类信息各司其职。

3.2 注意力权重提取

一个基本假设:并非所有表型对图构建同等重要。为此,使用MLP g θ g_\theta gθ,以每个被试 i i i的非影像特征 q i q_i qi和影像特征 s i s_i si为输入,输出注意力权重向量:
a ∈ R Q + S a \in \mathbb{R}^{Q+S} aRQ+S每个元素对应一个特定表型的重要性。与脑龄估计密切相关的特征获得接近1的权重,不相关的趋近于0。

全局化:跨所有被试平均注意力权重并归一化至 [ 0 , 1 ] [0, 1] [0,1],权重是全局的,对所有节点统一适用。

3.3 边提取

Step 1 — 加权表型
f i = a ⊙ ( q i ∥ s i ) , f i ∈ R Q + S f_i = a \odot (q_i \| s_i), \quad f_i \in \mathbb{R}^{Q+S} fi=a(qisi),fiRQ+S其中 ∥ \| 为拼接, ⊙ \odot 为Hadamard积(逐元素乘法)。

Step 2 — 计算边概率
p i j ( f i , f j ; θ , t ) = e − t ⋅ d ( f i , f j ) 2 (1) p_{ij}(f_i, f_j; \theta, t) = e^{-t \cdot d(f_i, f_j)^2} \tag{1} pij(fi,fj;θ,t)=etd(fi,fj)2(1)其中 t t t为可学习温度参数,控制距离对概率的敏感度; d ( ⋅ , ⋅ ) d(\cdot,\cdot) d(,)为距离函数(默认使用欧氏距离,消融实验确认其为最优选择)

概率解释:两个被试的加权表型越相似,则距离越小、 连接概率越高、图自然趋向同配性(年龄相似的被试被连接在一起)。

Step 3 — Gumbel-Top-k 采样
使用Gumbel-Top-k 技巧[17](kNN规则的随机松弛版本)按概率矩阵 P P P为每个节点随机采样 k k k条边,构建稀疏 k k k度图。推理时执行多次采样取平均预测,以充分利用随机性获得更稳健的结果。

3.4 端到端训练与优化

提取的图与影像特征 X X X一并输入GCN g ψ g_\psi gψ。总损失为: L = L G C N + L g r a p h (2) L = L_{GCN} + L_{graph} \tag{2} L=LGCN+Lgraph(2)其中 L G C N L_{GCN} LGCN在回归任务使用Huber损失,其对异常值比MSE更稳健),分类任务则使用交叉熵; L g r a p h L_{graph} Lgraph回归图损失 L g r a p h = ∑ ( i , j ) ∈ E ρ ( y i , g ψ ( x i ) ) ⋅ log ⁡ ( p i j ) (3) L_{graph} = \sum_{(i,j) \in E} \rho(y_i, g_\psi(x_i)) \cdot \log(p_{ij}) \tag{3} Lgraph=(i,j)Eρ(yi,gψ(xi))log(pij)(3)其中奖励函数: ρ ( y i , g ψ ( x i ) ) = ∣ y i − g ψ ( x i ) ∣ − ε (4) \rho(y_i, g_\psi(x_i)) = |y_i - g_\psi(x_i)| - \varepsilon \tag{4} ρ(yi,gψ(xi))=yigψ(xi)ε(4) ε = 6 \varepsilon=6 ε=6为空模型的预测误差(即训练集的平均脑龄), ∣ y i − g ψ ( x i ) ∣ |y_i - g_\psi(x_i)| yigψ(xi)为GCN的预测误差。

  • 当预测误差 < 空模型误差 → ρ < 0 \rho < 0 ρ<0(负奖励,即好)→ 最小化 L g r a p h L_{graph} Lgraph要求最大化 log ⁡ ( p i j ) \log(p_{ij}) log(pij)奖励有用边,增大其概率
  • 当预测误差 > 空模型误差 → ρ > 0 \rho > 0 ρ>0(正惩罚,即差)→ 最小化 ( L g r a p h (L_{graph} (Lgraph要求最小化 log ⁡ ( p i j ) \log(p_{ij}) log(pij)惩罚有害边,减小其概率

3.5 GCN架构

  • 1层图卷积层(512单元)→ ReLU
  • 1层全连接层(128单元)→ ReLU
  • 回归 / 分类输出层
  • 层数和维度通过超参数搜索(基于验证集性能)确定

训练AdamW,初始学习率0.005,300个epoch,early stopping,NVIDIA Titan RTX GPU,PyTorch 实现。报告 10 次不同随机初始化的平均结果。

4 实验

4.1 数据集与实验设置

UK Biobank:约6,500例被试(47–81 岁),仅纳入具有全部必要表型的个体。训练/验证/测试 = 75 / 5 / 20。

表型来源:从Cole[5]确定的90个脑龄相关特征中选取。含68个神经影像表型(结构MRI + DTI衍生测量)和20个非影像表型(覆盖人口学、认知、生活方式、生物医学指标)。所有表型归一化至 [0, 1]。

多模态利用方式

  • 68个影像特征 X X X完整用作 GCN 节点特征
  • 68个影像表型 s i s_i si + 20个非影像表型 q i q_i qi:仅输入注意力MLP用于边构建

主要任务:脑龄回归(MAE 和 Pearson r)。次要任务:4类年龄分类(Accuracy、AUC、Macro F1)。

Baseline

  • Linear / Logistic Regression:若图结构毫无信息量,GCN可能比简单线性模型更差
  • Static (Node features):基于影像节点特征的余弦相似度kNN静态图 + GCN
  • Static (Phenotypes):基于表型的余弦相似度kNN静态图 + GCN
  • DGM [13]:医学图学习的SOTA方法,原文仅支持分类,本文将其损失函数扩展至回归以便公平比较

所有静态基线使用 k = 5 k = 5 k=5的kNN。欧氏距离作为相似度度量在静态图上表现更差,因此基线仅报告余弦相似度的结果。

4.2 回归结果(Table1左)

对测试集中每个被试,用训练好的模型估计脑龄,计算MAE(|估计年龄−真实年龄|的均值,单位年)和Pearson r(估计值与真实值的线性相关强度),具体如表1。结果表明,GCN基于静态图的表现不如线性回归。DGM优于静态基线但逊于本文方法。

4.3 分类结果(Table1右)

将连续年龄按分位数离散化为 4 4 4个平衡的年龄组(每组人数大致相等),训练模型做 4 4 4分类预测,如表1右侧所示。结果表明,本文方法达虽然优于其他方法但并非高到可以临床使用。这恰恰是本文将主推方向放在回归任务上的动因:将连续年龄离散化会丢失信息。

4.4 消融实验(Table 2)

消融实验分两组。第一组为固定距离度量(欧氏距离),改变参与边构建的表型数量(20 个非影像 / 35 个混合 / 50 个混合 / 68 个影像 only),观察MAE变化。第二组为固定35个表型,改变距离函数(随机边 / 欧氏距离 / 余弦距离 / 双曲距离),观察哪个度量最优,如表2所示。

4.5 可解释性:注意力权重(Fig. 2)

训练完成后,取出注意力MLP输出的全局权重向量 a a a,每个值对应一个表型在边构建中的重要性。按权重从高到低排序,粉色标注非影像表型 Q = 20 Q=20 Q=20,蓝色标注影像表型 S = 68 S=68 S=68。图2为完整排序的柱状图。

Top非影像表型:数字连线任务(执行功能)、字母-数字连线任务(认知灵活性)、收缩压(心血管健康)、卒中史。Top影像表型:丘脑前辐射束 FA(白质微结构完整性)、WMH 体积(小血管疾病标志)、灰质体积、FA 骨架测量。这些发现与Cole (2020)[5]高度一致:连线任务反映执行功能衰退,收缩压反映心血管-脑衰老关联,WMH 反映小血管疾病负荷,DTI 指标反映白质微结构完整性。

一个重要的概念区分:注意力权重揭示的不是哪些表型预测脑龄能力最强,而是哪些表型最适合用来判断两个被试在脑衰老模式上是否相似,即表型对边构建的价值。例如某个表型虽然本身预测脑龄不强,但如果它在被试间方差大且与年龄结构相关,它就非常适合用来决定连边。

4.6 人口图可视化(Fig. 3)

为直观对比自适应图与静态图的质量,取同一批被试,分别用两种方式构建图:1)静态图:所有表型等权重,计算其余弦相似度kNN图;2)加权图:将注意力权重乘以表型后再算余弦相似度kNN。将被试按其实际年龄着色,用t-SNE 将图嵌入到二维空间进行可视化。图3展示两种图的节点分布:

  • 静态图:节点按年龄着色后无明显聚类模式,不同年龄的被试混杂分布——说明在等权表型空间中,表型相似并不等价于脑衰老模式相似",图缺乏同配性。
  • 加权图:相同年龄的节点形成清晰聚类,注意力权重成功压制了不相关表型的噪声,使表型距离与年龄距离对齐,图同配性显著增强。

这直接解释了本文方法为何优于静态基线:在高同配性图上,GCN 每次邻域聚合都在汇聚年龄相似节点的特征,信号逐层增强;在低同配性静态图上,邻域聚合反而混合了不同年龄的特征,噪声覆盖了信号,GCN 退化为不如线性模型的水平。

5 讨论

5.1 方法价值

本文在医学自适应图学习领域做出了两个独特贡献。

  • 将此前仅用于分类的图损失函数推广到回归任务
  • 通过注意力权重的可视化使图构建过程变得可解释,这在需要建立信任的临床应用中不可或缺。

5.2 问题与局限

5.2.1 全局注意力权重的一刀切假设

注意力权重是跨所有被试平均的全局权重。这隐含了一个前提:所有被试在图构建中使用同一组表型重要性排序。但表型与脑龄的关联很可能因年龄、性别和健康状况而异,年轻被试中DTI指标的重要性未必与老年被试相同。若采用样本自适应的注意力权重(即每个被试独立计算权重,或按年龄/性别分层计算),有可能进一步提升精度。

5.2.2 GCN的浅层架构

仅1层图卷积层:理论上加深图卷积可扩大感受野、聚合更远邻居的特征。而论文未对GCN深度做消融,1层是否最优尚无实证支持。

5.2.3 边概率函数的指数衰减假设

p i j = e − t ⋅ d ( f i , f j ) 2 p_{ij} = e^{-t \cdot d(f_i, f_j)^2} pij=etd(fi,fj)2的指数衰减形式隐含假设:表型相似度与脑衰老模式相似度之间为单调递减关系,即表型越相似,衰老模式越可能相同。但在医学数据中,两个表型差异极大的人可能因互补的健康模式(如一人心血管差但认知好,另一人恰好相反)而在衰老轨迹上相似,欧氏距离无法捕捉这种非线性互补关系。

5.2.4 空模型的选择

ε \varepsilon ε 被设为训练集平均脑龄(约 6 年误差)。若训练集与测试集的年龄分布不同(UKBB存在已知的健康志愿者偏倚),空模型误差在不同子群上的校准将不一致,导致图损失在年轻/年长被试上的奖励与惩罚强度不对称。

5.2.5 仅用Cole 2020预选表型,削弱了框架的"自适应"潜力

68 + 20个表型均从Cole (2020) 的90个特征中预选。若框架真的自适应,理应能从全部可用的UKBB表型中自动筛选最重要的特征,而非在人工预选的子集上分配权重。当前设计相当于施加了强先验过滤,在一定程度上削弱了自适应学习这一主张的力度。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐