FedCache论文阅读笔记
- 论文下载链接: https://ieeexplore.ieee.org/document/10420495
- FedCache: A Knowledge Cache-driven Federated Learning Architecture for Personalized Edge Intelligence
- Fedcache:一种面向个性化边缘智能的知识缓存驱动的联邦学习架构
- 期刊:IEEE Transactions on Mobile Computing (TMC)(CCF A类)
摘要
提出一种名为FedCache的知识缓存驱动的PFL架构:该架构在服务器上保留了一个知识缓存,用于从与每个给定的设备上样本具有相似哈希的样本中获取个性化知识。在训练阶段,集成蒸馏被应用于设备上模型,以利用从服务器端知识缓存传输的个性化知识进行建设性优化。
对四个数据集的实证实验表明,FedCache的性能与最先进的PFL方法相当,通信效率提高了两个多数量级。
1 引言
-
边缘智能(EI):允许人工智能(AI)应用程序在边缘运行,可以在接近数据源的地方实时执行数据分析和决策。但传统的集中式EI范式需要上传原始数据用于训练普适人工智能模型,这增加了对敏感数据泄露的隐私担忧。
-
联邦学习(FL):用于在不损害数据隐私的情况下跨多个设备协作训练共享AI模型,但主流的FL方法要求所有参与的设备共享一个统一的模型,无法保证模型在异构客户端上的泛化和自适应。
-
个性化联邦学习(PFL):可以在客户端本地个性化训练需求和全局模型泛化目标之间取得平衡,然而现有的PFL方法大多基于以FedAvg为代表的基于参数交互的架构(PIA),需要进行进行大规模参数传输。
-
基于Logits交互的架构(LIA):将知识蒸馏技术应用于PFL,通过Logits传输来更新模型参数,与PIA相比具有通信轻量级和允许异构设备模型的优势。
【logits(knowledge):最终的全连接层的输出(在未经过softmax之前的、未归一化的概率)】
-
基于知识缓存的联邦学习架构(FedCache):一种新颖的客户端-服务器交互范式,它在服务器上维护一个知识缓存,用于存储与每个私有样本相关的最新知识,并应用一种定制的、基于知识缓存的个性化蒸馏技术来训练设备端模型。
2 相关工作
2.1 面向个性化边缘智能的联邦学习
-
核心问题:联邦学习训练出的"通用模型"无法满足不同设备的需求,因为各设备的数据分布、计算能力、任务目标都存在差异。
-
现有解决方案:
- 多任务学习:允许每个设备训练自己的个性化神经网络,以适应本地特有的数据分布。
- 历史知识蒸馏:利用设备上保存的历史模型,将旧知识迁移到新模型中,实现个性化。
- 域自适应:将在大规模源域数据上训练好的模型,通过特定方法适配到设备上的目标数据,同时考虑存储开销。
- 动态退出:允许训练效果好的设备提前结束训练,从而节约资源。
-
存在的问题:大多没有专门针对通信效率和异步训练进行架构设计。
2.2 联邦边缘学习中的知识蒸馏
- 核心问题:如何利用知识蒸馏技术解决联邦学习中的通信和异构难题。
- 现有解决方案:
- 作为通信协议:将知识蒸馏用作一种新的通信方式,在设备和服务器之间交换模型输出(logits),而不是模型参数,从而实现高效通信和异构模型训练。
- 知识迁移:在设备算力差异大的场景下,将复杂模型的知识迁移到简单模型上。
- 个性化优化:利用蒸馏技术对设备上的模型进行个性化优化。
- 存在的问题:
- 部分方法需要传输中间层的特征,通信量依然很大。
- 部分方法需要依赖公共数据集,这在现实世界中很难获取。
3 预备知识与研究动机
3.1 背景与符号定义
分布式设备(称为客户端)在边缘服务器的协调下,协同训练用于C部分类(C-class)的分类模型,同时将私有数据保留在设备上。
- 客户端数量:假设有 K K K 个客户端参与 PFL, k ∈ { 1 , 2 , . . . , K } k \in \{1,2,...,K\} k∈{1,2,...,K} 。
- 客户端的数据集:每个客户端 k k k 拥有一个私有数据集 D k : = ⋃ i = 1 N k { ( X i k , y i k ) } \mathcal{D}^k := \bigcup_{i=1}^{N^k} \{ (X_i^k, y_i^k) \} Dk:=⋃i=1Nk{(Xik,yik)},其中 N k N^k Nk 是数据集 D k \mathcal{D}^k Dk 中的样本数量, X i k X_i^k Xik 和 y i k y_i^k yik 分别是 D k \mathcal{D}^k Dk 中的第 i i i 个数据和标签。
- 个性化模型:每个设备 k k k 拥有一个个性化模型 M k : = ( W k , f k ) M^k := (W^k, f^k) Mk:=(Wk,fk),其模型参数或架构可能不同,其中 W k W^k Wk 是模型 M k M^k Mk 的参数(保留在本地), f k ( ⋅ ) f^k(\cdot) fk(⋅) 是由 M k M^k Mk 决定的非线性映射。 f k ( X i k ) f^k(X_i^k) fk(Xik)的意思就是把输入 X i k X_i^k Xik喂给模型,得到输出。
- 用户模型准确性(UA):每个设备的目标是尽可能提高其个性化模型在私有数据上的UA。
- 系统优化目标:PFL系统的目标是最大化所有客户端的平均UA(MAUA)。
3.2 边缘智能的实际局限性
PFL 架构在 EI 中部署时需要克服的主要实际限制:
- 设备异构性。考虑到终端设备的不同硬件配置,如中央处理器、内存资源和能量状态,设备之间需要采用个性化模型以适应其特定特性。
- 通信效率。由于边缘服务器和终端设备之间的无线网络带宽有限,它们无法进行大规模通信 。
- 数据隐私。出于隐私问题或数据保护法规,设备不愿意与边缘服务器共享其本地数据。因此,很难获取有关用户本地数据的信息。
- 异步优化。不同设备的不同计算任务、能力和通信延迟导致的高同步开销阻碍了模型更新。
3.3 PFL 架构概述
3.3.1 基于参数交互的PFL架构
在采用 PIA 的 PFL 中,客户端倾向于只上传其部分模型参数,以保留本地自适应能力。因此,服务器端执行经过筛选的、按本地样本数量加权的参数聚合,即:
W ∗ = N k ∑ i = 1 K N i ⋅ filter ( W k ) , (1) W^* = \frac{N^k}{\sum_{i=1}^{K} N^i} \cdot \text{filter}(W^k), \tag{1} W∗=∑i=1KNiNk⋅filter(Wk),(1)
其中 filter ( ⋅ ) \text{filter}(\cdot) filter(⋅) 筛选出将要上传到服务器的部分设备端模型参数,而 W ∗ W^* W∗ 表示服务器上聚合后的模型参数。
尽管 PIA 能够通过筛选模型参数来保留设备端模型的个性化能力,但对于通信资源有限的设备来说,传输大规模参数进行聚合仍然成本过高。此外,PIA 要求在聚合过程中设备端模型架构具有高度同质性,而在边缘智能场景中,普遍存在具有各种硬件相关约束的异构设备,这一要求很难实现。
3.3.2 基于Logits交互的PFL架构
基于 Logits 交互的架构(Logits Interaction-based Architecture, LIA):通过交换 logits(通常称为知识)而非交互模型参数节省数个数量级通信开销以及支持异构模型训练的好处。
-
基于类粒度 Logits 交互的架构(CLIA):在训练过程中,客户端之间交换按类别聚合后的 logits。每个客户端只能学习到其他客户端对某个类别的“平均看法”,难以精确匹配本地数据的细微分布。CLIA 适用于对通信要求极高、但个性化要求不十分严苛的场景。FedCache 则通过更细粒度的样本级知识交互,在保持通信高效的同时大幅提升性能。尽管 CLIA 支持模型异构且通信轻量,但它只允许每个客户端学习 C C C 类 logits。由于客户端相比独立训练仅能获得极少的服务器端信息,这种 PFL 设计容易达到性能上限。
-
基于样本粒度 Logits 交互的架构(SLIA):在训练过程中,客户端之间交换的是每个具体样本的 logits。但现有的基于 SLIA 的方法要么依赖于在不切实际的公共数据集上进行额外的客户端训练,要么除了 logits 之外还需要传输不可忽略大小的嵌入特征。
1) 基于类粒度 Logits 交互的架构(CLIA)
对于 CLIA,每个客户端 k k k 的样本 X i k X_i^k Xik 的输出需要逼近由所有其他客户端中具有相同标签 y i k y_i^k yik 的样本计算得到的全局平均 logits,即:
arg min W k ∑ ( X i k , y i k ) ∈ D k [ L C E ( σ 0 ( f k ( X i k ) ) , y i k ) + γ ⋅ L C E ( σ 0 ( f k ( X i k ) ) , σ 0 ( ∑ l = 1 K F l , k − F k , k K − 1 ) ) ] , (2) \begin{aligned} \underset{W^k}{\arg\min} \sum_{(X_i^k,y_i^k)\in \mathcal{D}^k} \Big[ & L_{CE}\big(\sigma_0(f^k(X_i^k)), y_i^k\big) \\ & + \gamma \cdot L_{CE}\Big(\sigma_0(f^k(X_i^k)), \sigma_0\big(\frac{\sum_{l=1}^K F^{l,k} - F^{k,k}}{K-1}\big)\Big) \Big], \end{aligned} \tag{2} Wkargmin(Xik,yik)∈Dk∑[LCE(σ0(fk(Xik)),yik)+γ⋅LCE(σ0(fk(Xik)),σ0(K−1∑l=1KFl,k−Fk,k))],(2)
其中 σ 0 ( ⋅ ) \sigma_0(\cdot) σ0(⋅) 是 softmax 映射, γ \gamma γ 是蒸馏加权因子, L C E ( ⋅ ) L_{CE}(\cdot) LCE(⋅) 表示交叉熵损失。
- 第一项:标准分类损失
L C E ( σ 0 ( f k ( X i k ) ) , y i k ) L_{CE}\big(\sigma_0(f^k(X_i^k)), y_i^k\big) LCE(σ0(fk(Xik)),yik) 是交叉熵损失,迫使本地模型对样本 X i k X_i^k Xik 的预测概率 σ 0 ( f k ( X i k ) ) \sigma_0(f^k(X_i^k)) σ0(fk(Xik)) 尽可能接近真实标签 y i k y_i^k yik。这是常规的监督学习任务。 - 第二项:蒸馏损失
γ ⋅ L C E ( σ 0 ( f k ( X i k ) ) , σ 0 ( ∑ l = 1 K F l , k − F k , k K − 1 ) ) \gamma \cdot L_{CE}\Big(\sigma_0(f^k(X_i^k)), \sigma_0\big(\frac{\sum_{l=1}^K F^{l,k} - F^{k,k}}{K-1}\big)\Big) γ⋅LCE(σ0(fk(Xik)),σ0(K−1∑l=1KFl,k−Fk,k)) 是知识蒸馏项。它让本地模型的预测概率逼近一个目标概率,这个目标概率是所有其他客户端在相同类别上的平均 logits 经过 softmax 后的结果。具体来说:- F l , k F^{l,k} Fl,k 表示客户端 l l l 上所有标签为 y i k y_i^k yik 的样本的平均 logits(见公式 3)。
- ∑ l = 1 K F l , k − F k , k K − 1 \frac{\sum_{l=1}^K F^{l,k} - F^{k,k}}{K-1} K−1∑l=1KFl,k−Fk,k 是去掉客户端 k k k 自身后,其他 K − 1 K-1 K−1 个客户端在类别 y i k y_i^k yik 上的平均 logits。(避免本地模型过拟合自己的噪声或偏见,只从别人那里学习通用知识。如果包含自己,可能会让模型固守本地观点,失去泛化能力。)
- 对这个平均 logits 应用 softmax σ 0 ( ⋅ ) \sigma_0(\cdot) σ0(⋅) 得到目标概率分布。
- 然后计算本地预测与这个目标分布的交叉熵,作为蒸馏损失。
F l , k F^{l,k} Fl,k 是由客户端 l l l 中具有相同标签 y i l y_i^l yil 的样本计算的平均 logits,即:
F l , k = E ( X i l , y i l ) ∈ D l y i l = y i k f l ( X i l ) . (3) F^{l,k} = \mathop{\mathbb{E}}_{\substack{(X_i^l,y_i^l)\in \mathcal{D}^l \\ y_i^l = y_i^k}} f^l(X_i^l). \tag{3} Fl,k=E(Xil,yil)∈Dlyil=yikfl(Xil).(3)
-
f l ( X i l ) f^l(X_i^l) fl(Xil) 是客户端 l l l 的模型对样本 X i l X_i^l Xil 输出的 logits(未经过 softmax 的原始分数)。
-
条件 y i l = y i k y_i^l = y_i^k yil=yik 表示只考虑那些标签与当前样本 X i k X_i^k Xik 的标签相同的样本。
-
然后对这些样本的 logits 求平均,得到 F l , k F^{l,k} Fl,k。
2) 基于样本粒度 Logits 交互的架构(SLIA)
对于 SLIA,设备端模型学习的 logits 数量与样本数量相关。此类架构通常需要引入公共数据集或增加通信开销作为不可避免的折衷,可分为两种形式。
①带特征交换的 SLIA(SLIA-FE)
在 SLIA-FE 中,客户端 k k k 的模型参数分为特征提取器部分 W e k W_e^k Wek 和预测器部分 W p k W_p^k Wpk,其中特征提取器的预测映射记为 f e k ( ⋅ ) f_e^k(\cdot) fek(⋅)。
- 特征提取器 f e k f_e^k fek:将输入样本映射为中间特征(embedding)。
- 预测器 f p k f_p^k fpk:基于特征输出最终的 logits。
服务器仅保留一个大规模分类器 W S W^S WS,其对应的预测映射为 f S ( ⋅ ) f^S(\cdot) fS(⋅)。通常,服务器上的模型通过结合交叉熵损失 L C E ( ⋅ ) L_{CE}(\cdot) LCE(⋅) 和 Kullback-Leibler 散度损失 K L ( ⋅ ) KL(\cdot) KL(⋅) 进行更新,该更新依赖于客户端上传的特征 f e k ( X i k ) f_e^k(X_i^k) fek(Xik)和 logits f p k ( X i k ) f_p^k(X_i^k) fpk(Xik),可表示如下:
arg min W S ∑ ( X i k , y i k ) ∈ D k [ L C E ( σ 0 ( f S ( f e k ( X i k ) ⏟ uploaded features ) ) , y i k ) + λ ⋅ K L ( σ 0 ( f S ( f e k ( X i k ) ⏟ uploaded features ) ) ∥ σ 1 ( f p k ( X i k ) ⏟ uploaded logits ) ) ] , (4) \begin{aligned} \underset{W^S}{\arg\min} \sum_{(X_i^k,y_i^k)\in \mathcal{D}^k} \Big[ & L_{CE}\big(\sigma_0(f^S(\underbrace{f_e^k(X_i^k)}_{\text{uploaded features}})), y_i^k\big) \\ & + \lambda \cdot KL\big(\sigma_0(f^S(\underbrace{f_e^k(X_i^k)}_{\text{uploaded features}})) \;\|\; \sigma_1(\underbrace{f_p^k(X_i^k)}_{\text{uploaded logits}})\big) \Big], \end{aligned} \tag{4} WSargmin(Xik,yik)∈Dk∑[LCE(σ0(fS(uploaded features
fek(Xik))),yik)+λ⋅KL(σ0(fS(uploaded features
fek(Xik)))∥σ1(uploaded logits
fpk(Xik)))],(4)
其中 σ 1 ( ⋅ ) \sigma_1(\cdot) σ1(⋅) 是本地 logits 的变换映射, λ \lambda λ 是蒸馏加权因子。
- 第一项:服务器分类器 f S f^S fS 对客户端上传的特征 f e k ( X i k ) f_e^k(X_i^k) fek(Xik) 进行预测,计算与真实标签 y i k y_i^k yik 的交叉熵损失。这相当于让服务器分类器学会正确分类客户端的特征。
- 第二项:KL 散度损失,迫使服务器分类器的预测概率 σ 0 ( f S ( feature ) ) \sigma_0(f^S(\text{feature})) σ0(fS(feature)) 尽量接近客户端本地预测器输出的 logits(经过变换 σ 1 \sigma_1 σ1 后的概率分布)。这相当于让服务器从客户端的本地知识中蒸馏,吸收客户端的个性化信息。
- 目的:服务器通过聚合所有客户端的特征和 logits,训练出一个能兼容不同客户端特征的全局分类器。
相对地,客户端 k k k 使用从服务器下载的全局 logits 更新本地模型参数,优化以下损失函数:
arg min W k ∑ ( X i k , y i k ) ∈ D k [ L C E ( σ 0 ( f k ( X i k ) ) , y i k ) + μ ⋅ K L ( σ 0 ( f k ( X i k ) ) ∥ σ 2 ( f S ( f e k ( X i k ) ⏟ uploaded features ) ) ) ] , (5) \begin{aligned} \underset{W^k}{\arg\min} \sum_{(X_i^k,y_i^k)\in \mathcal{D}^k} \Big[ & L_{CE}\big(\sigma_0(f^k(X_i^k)), y_i^k\big) \\ & + \mu \cdot KL\big(\sigma_0(f^k(X_i^k)) \;\|\; \sigma_2(f^S(\underbrace{f_e^k(X_i^k)}_{\text{uploaded features}}))\big) \Big], \end{aligned} \tag{5} Wkargmin(Xik,yik)∈Dk∑[LCE(σ0(fk(Xik)),yik)+μ⋅KL(σ0(fk(Xik))∥σ2(fS(uploaded features
fek(Xik))))],(5)
其中 σ 2 ( ⋅ ) \sigma_2(\cdot) σ2(⋅) 是全局 logits 的变换映射, μ \mu μ 是蒸馏加权因子。
- 第一项:本地模型 f k f^k fk 的常规分类损失。
- 第二项:KL 散度损失,让本地模型的预测概率 σ 0 ( f k ( X i k ) ) \sigma_0(f^k(X_i^k)) σ0(fk(Xik)) 尽量接近服务器全局分类器对本地特征的输出(经过变换 σ 2 \sigma_2 σ2 后的概率分布)。这样,客户端就能从服务器获取其他客户端的知识。
- 目的:客户端在保持自己数据拟合的同时,吸收来自其他设备的“全局观点”。
尽管 SLIA-FE 允许异构设备端模型且无需参数传输,但参与者需要就特征维度达成一致。此外,由于高分辨率图像和长序列数据的特征维度通常较高,特征传输的开销对设备而言仍然显著。
②基于公共数据集的 SLIA(SLIA-PD)
SLIA-PD 引入一个公共数据集 D O \mathcal{D}^O DO,该数据集分布应尽量接近所有客户端的私有数据。客户端 k k k 的目标是在公共数据集 D O \mathcal{D}^O DO 的给定样本 ( X i O , y i O ) (X_i^O, y_i^O) (XiO,yiO) 上逼近所有客户端的平均 logits ,即:
arg min W k ∑ ( X i O , y i O ) ∈ D O L C E ( σ 0 ( f k ( X i O ) ) , σ 0 ( 1 K ∑ l = 1 K f l ( X i O ) U l ) ) , (6) \underset{W^k}{\arg\min} \sum_{(X_i^O,y_i^O)\in \mathcal{D}^O} L_{CE}\Big(\sigma_0(f^k(X_i^O)),\; \sigma_0\big(\frac{1}{K}\sum_{l=1}^K \frac{f^l(X_i^O)}{U^l}\big)\Big), \tag{6} Wkargmin(XiO,yiO)∈DO∑LCE(σ0(fk(XiO)),σ0(K1l=1∑KUlfl(XiO))),(6)
其中 U U U 是一个超参数,用于控制集成 logits 的分布。
- 左边:客户端 k k k 的模型对公共样本 X i O X_i^O XiO 的预测概率。
- 右边:所有客户端对该公共样本的 logits 的加权平均( U l U^l Ul 是控制分布的参数),再经过 softmax 得到的目标概率。
- 损失函数:交叉熵,让本地预测接近所有客户端的“共识”。
- 训练过程:每个客户端在本地私有数据上正常训练,同时定期对公共数据集进行前向传播,上传 logits;服务器聚合后下发平均 logits,客户端通过蒸馏吸收全局知识。
SLIA-PD 不仅进一步放松了客户端间模型架构的约束,而且训练过程中仅交换尺寸极小的 logits,相比前述架构显著降低了通信开销。然而,SLIA-PD 依赖于一个公共数据集,且该数据集的分布应接近客户端的私有数据。由于在未知客户端数据分布的情况下收集满意的公共数据几乎不可能,这种架构在实际中并不可行。
1. 交叉熵损失 L C E L_{CE} LCE
作用:衡量模型预测的概率分布与真实标签之间的差距,是分类任务的标准损失函数。
数学形式:
L C E ( p , y ) = − ∑ c = 1 C 1 ( y = c ) log p c L_{CE}(p, y) = -\sum_{c=1}^{C} \mathbb{1}(y=c) \log p_c LCE(p,y)=−c=1∑C1(y=c)logpc
其中 p = σ 0 ( f ( X ) ) p = \sigma_0(f(X)) p=σ0(f(X)) 是模型输出的概率分布(经过 softmax), y y y 是真实类别(one-hot 形式), C C C 是类别总数。
如果模型对正确类别的预测概率接近 1,损失就小;否则损失大。
2. KL 散度损失 K L KL KL
作用:衡量两个概率分布之间的差异,常用于知识蒸馏中,让学生模型的输出分布逼近教师模型的输出分布。
数学形式:
K L ( p ∥ q ) = ∑ c = 1 C p c log p c q c KL(p \| q) = \sum_{c=1}^{C} p_c \log \frac{p_c}{q_c} KL(p∥q)=c=1∑Cpclogqcpc
其中 p p p 和 q q q 是两个概率分布。注意 KL 散度不对称,即 K L ( p ∥ q ) ≠ K L ( q ∥ p ) KL(p\|q) \neq KL(q\|p) KL(p∥q)=KL(q∥p)。
如果两个分布完全相同,KL 散度为 0;差异越大,值越大。在蒸馏中,通常让学生分布 q q q 去拟合教师分布 p p p。
3. 变换函数 σ 0 , σ 1 , σ 2 \sigma_0, \sigma_1, \sigma_2 σ0,σ1,σ2
这些函数用于将模型的原始输出(logits)转换为概率分布,通常采用 softmax 或带温度的 softmax。
- σ 0 \sigma_0 σ0:标准 softmax,它用于常规分类任务的输出。
σ 0 ( z ) c = exp ( z c ) ∑ j = 1 C exp ( z j ) \sigma_0(z)_c = \frac{\exp(z_c)}{\sum_{j=1}^{C} \exp(z_j)} σ0(z)c=∑j=1Cexp(zj)exp(zc)
- σ 1 \sigma_1 σ1 和 σ 2 \sigma_2 σ2:带温度的 softmax,在知识蒸馏中,常使用带温度的 softmax 来软化概率分布,温度参数 T T T 控制分布的平滑程度。
- T > 1 T > 1 T>1:分布更均匀(软化),放大负标签的信息,有助于学生模型学习教师的知识。
- T < 1 T < 1 T<1:分布更尖锐,突出主要类别。
3.4 动机
现有方法的不足:现有的个性化联邦学习架构无法在系统性能、资源效率和不依赖公共数据集之间实现令人满意的平衡。
FedCache核心思想:在服务器端维护一个知识缓存,存储所有客户端上传的最新logits,作为无需公共数据集的样本粒度知识来源,用于个性化蒸馏。
- 哈希编码:客户端将样本转换为哈希值(隐私保护指纹),上传建立索引。
- 邻居检索:服务器为每个样本找到最相似的R个邻居(基于哈希相似度)。
- 知识集成与蒸馏:客户端下载邻居的集成知识,用于本地模型优化。
优势:样本级细粒度交互(优于CLIA)、无需传输特征(优于SLIA-FE)、无需公共数据集(优于SLIA-PD)、支持完全异构模型、支持异步训练(设备可独立进行知识交互)
| 架构 | 模型异构支持性 | 通信效率 | 不依赖公共数据 | 异步优化 | 通信协议 |
|---|---|---|---|---|---|
| PIA | 部分异构(要求部分参数结构一致) | 低 | 是 | 否 | 模型参数 |
| CLIA | 完全异构 | 高 | 是 | 否 | 类粒度 Logits |
| SLIA-FE | 完全异构(要求所有客户端的特征维度一致) | 中 | 是 | 是 | 样本粒度特征和 Logits |
| SLIA-PD | 完全异构 | 高 | 否 | 否 | 样本粒度 Logits |
| FedCache | 完全异构 | 高 | 是 | 是 | 样本粒度 Logits |
4 知识缓存驱动的个性化联邦学习
4.1 系统设计
FedCache的功能模块图如图2所示,它包含一个带有三个功能模块的服务器:(服务器-客户端)通信模块、(知识)集成模块、知识缓存模块;以及K个带有五个功能模块的客户端:(客户端-服务器)通信模块、(知识)蒸馏模块、数据模块、模型模块和(样本)编码器模块。
- 集成模块:将从知识缓存中获取的知识进行组合,得到待蒸馏到客户端的个性化知识。
- 知识缓存模块:这是我们设计的自组织知识存储结构,便于在服务器端获取每个客户端的相关知识。
- 模型模块:从本地数据中提取知识,并在蒸馏模块的指导下进行模型更新。
- 编码器模块:将私有数据编码为哈希值,用于初始化知识缓存。编码器应具备高效、鲁棒和判别性,确保能够快速计算本地样本的哈希值,并可靠地反映样本间的语义相似度。
在初始化阶段,客户端生成的哈希码一次性上传到服务器。然后,在服务器端知识缓存中执行HNSW算法,旨在为每个样本检索R个最相关的样本,匹配依据是哈希值的余弦相似度。
图3展示了在FashionMNIST数据集上的样本匹配结果。如图所示,匹配到的样本与原始样本非常相似,使得从中提取的知识有助于客户端对原始样本的蒸馏。
在训练阶段,每个通信轮次中,私有样本的logits及其索引被上传到服务器。然后,基于预先建立的相似关系,从知识缓存中获取每个样本的R个具有最高哈希相似度的最佳匹配知识,接着进行知识集成,再将知识传输给对应的客户端用于本地蒸馏。由于蒸馏阶段仅依赖于客户端各自私有数据的高度相关知识,所得的模型具有良好的本地适应性,能够很好地完成个性化任务。
4.2 知识缓存
服务器上的知识缓存旨在以可控的计算复杂度为任意本地样本异步获取相关知识,其中提取相关知识的样本的哈希值应为原始样本哈希值的 (R) 个最近邻之一。基于上述设计,我们在知识缓存中维护多个映射对,包括标签到索引的映射对(LI)、索引到知识的映射对(IK)、索引到哈希的映射对(IH)和索引关系映射对(IR),每个映射对实现从第一个元素到第二个元素的映射。在此基础上,知识缓存包含两个主要阶段:初始化和训练。
初始化过程包括以下步骤:
-
映射对初始化。与每个样本索引 ( k , i ) (k, i) (k,i) 对应的上传哈希值 $ h_i^k$ 存储在 IH 中。此外,根据样本的标签类别将索引添加到 LI 中,并且 IK 中每个给定索引对应的知识初始化为零。
公式 (7):将客户端 k k k 的第 i i i 个样本的哈希值 h i k h_i^k hik 存储到索引-哈希映射 I H IH IH 中。
I H ( k , i ) ← h i k (7) IH(k, i) \leftarrow h_i^k \tag{7} IH(k,i)←hik(7)- I H IH IH 是一个映射表,键是样本索引 ( k , i ) (k, i) (k,i),值是该样本对应的哈希值。
- 这一步在初始化阶段完成,服务器收集所有客户端上传的哈希值,建立索引与哈希值的对应关系,为后续检索相似样本做准备。
公式 (8):将样本索引 ( k , i ) (k, i) (k,i) 添加到标签-索引映射 L I LI LI 中对应标签 y i k y_i^k yik 的集合里。
L I ( y i k ) ← L I ( y i k ) ∪ { ( k , i ) } (8) LI(y_i^k) \leftarrow LI(y_i^k) \cup \{(k, i)\} \tag{8} LI(yik)←LI(yik)∪{(k,i)}(8)- L I LI LI 是一个以标签类别为键、以该类所有样本索引的集合为值的映射。例如,若样本标签为“猫”,则将该样本索引加入 L I ( “猫” ) LI(\text{“猫”}) LI(“猫”) 中。
- 这样做是为了后续建立关系时,只考虑同一标签类别内的样本,减少候选范围,提高检索效率。
公式 (9):将索引-知识映射 I K IK IK 中对应样本索引 ( k , i ) (k, i) (k,i) 的知识初始化为长度为 C C C 的全零向量。
I K ( k , i ) ← ( 0 , … , 0 ⏟ C 个零 ) (9) IK(k, i) \leftarrow (\underbrace{0, \ldots, 0}_{C\text{个零}}) \tag{9} IK(k,i)←(C个零 0,…,0)(9)- I K IK IK 存储每个样本最新的知识(即模型输出的 logits,维度为类别数 C C C)。
- 初始化时知识未知,因此先填充零向量。随着训练的进行,客户端会不断上传最新的知识, I K IK IK 中的条目会被逐步更新(见公式 13)。
-
建立关系。对于给定的样本索引 ( k , i ) (k, i) (k,i),从所有候选样本中选出 R R R 个索引 { ( l 1 , j 1 ) , … , ( l R , j R ) } \{(l_1,j_1),\ldots,(l_R,j_R)\} {(l1,j1),…,(lR,jR)},使得这些样本的哈希值与 ( k , i ) (k, i) (k,i) 的哈希值的余弦相似度之和最大。
arg max ( l 1 , j 1 ) , ( l 2 , j 2 ) , … , ( l R , j R ) ∑ m = 1 R cos ( I H ( k , i ) , I H ( l m , j m ) ) , s.t. { l n 1 ≠ l n 2 ∨ j n 1 ≠ j n 2 , ∀ n 1 , n 2 且 n 1 ≠ n 2 , ( k , i ) ∈ L I ( y ∗ ) ∧ ( l m , j m ) ∈ L I ( y ∗ ) , ∃ y ∗ , n 1 , n 2 , m ∈ { 1 , 2 , … , R } , y ∗ ∈ { 1 , 2 , … , C } , (10) \begin{aligned} &\arg \max_{(l_1,j_1),(l_2,j_2),\ldots,(l_R,j_R)} \sum_{m=1}^{R} \cos\big(IH(k, i), IH(l_m, j_m)\big), \\ &\text{s.t. } \begin{cases} l_{n_1} \neq l_{n_2} \lor j_{n_1} \neq j_{n_2}, & \forall n_1, n_2 \text{ 且 } n_1 \neq n_2,\\ (k, i) \in LI(y^*) \land (l_m, j_m) \in LI(y^*), & \exists y^*,\\ n_1, n_2, m \in \{1, 2, \ldots, R\},\\ y^* \in \{1, 2, \ldots, C\}, \end{cases} \end{aligned} \tag{10} arg(l1,j1),(l2,j2),…,(lR,jR)maxm=1∑Rcos(IH(k,i),IH(lm,jm)),s.t. ⎩ ⎨ ⎧ln1=ln2∨jn1=jn2,(k,i)∈LI(y∗)∧(lm,jm)∈LI(y∗),n1,n2,m∈{1,2,…,R},y∗∈{1,2,…,C},∀n1,n2 且 n1=n2,∃y∗,(10)- 互异性:选出的 R R R 个索引不能有重复(即任意两个不能同时相同)。
- 同标签约束: ( k , i ) (k, i) (k,i) 和选出的每个邻居 ( l m , j m ) (l_m, j_m) (lm,jm) 必须属于同一个标签类别 y ∗ y^* y∗(即 L I ( y ∗ ) LI(y^*) LI(y∗) 中包含它们)。
- 取值范围:索引的客户端编号和样本编号在合理范围内,且 y ∗ y^* y∗ 是某个有效类别。
在此过程中,采用 HNSW 来实现 (R) 近邻检索。然后,将与每个样本索引相关的检索结果保存在 IR 中,以供后续访问,即:
I R ( k , i ) ← { ( l 1 , j 1 ) , ( l 2 , j 2 ) , … , ( l R , j R ) } . (11) IR(k, i) \leftarrow \{(l_1, j_1), (l_2, j_2), \ldots, (l_R, j_R)\}. \tag{11} IR(k,i)←{(l1,j1),(l2,j2),…,(lR,jR)}.(11)
在训练过程中,对于每个给定的样本索引,应执行以下步骤:
-
知识获取。基于提供的样本索引,可以从知识缓存 (KC) 中获取最相关的知识:对于新上传的样本索引 ( k , i ) (k, i) (k,i) ,根据 I R IR IR存储的 I R ( k , i ) IR(k, i) IR(k,i) 的相关样本索引找到该样本的 R R R 个邻居索引,通过 I K IK IK(索引-知识映射)取出这些邻居当前存储的最新知识,获取并返回相应的知识(通常是一个包含 R R R 个 logits 向量的列表),即:
K C ( h i k ; k , i ) = I K ( I R ( k , i ) ) . (12) KC(h_i^k; k, i) = IK(IR(k, i)). \tag{12} KC(hik;k,i)=IK(IR(k,i)).(12)知识获取只需要请求知识的客户端在线,客户端可以异步执行基于获取知识的优化。
-
知识更新。在每一轮训练中,客户端会将自己对样本 ( k , i ) (k, i) (k,i) 的模型输出(logits) z i k z_i^k zik 上传给服务器。 服务器收到后,执行此更新操作,将 I K ( k , i ) IK(k, i) IK(k,i) 中的旧知识替换为最新的 z i k z_i^k zik,以便下次访问时可以获取到最新的知识,即:
I K ( k , i ) ← z i k . (13) IK(k, i) \leftarrow z_i^k. \tag{13} IK(k,i)←zik.(13)
4.3 知识缓存驱动的个性化蒸馏
我们通过个性化联邦蒸馏来优化设备端模型,其中与每个客户端私有数据相似的样本知识从知识缓存中获取。在此基础上,每个客户端对获取的知识执行集成蒸馏,以实现设备端模型的建设性优化。
具体来说,在初始化阶段,采用一个预训练的深度神经网络 f h ( ⋅ ) f^{h}(\cdot) fh(⋅) 作为编码器,在客户端上生成样本的哈希值,即:
h i k = f h ( X i k ) , (14) h_i^k = f^{h}(X_i^k), \quad \tag{14} hik=fh(Xik),(14)
这些哈希值连同相应的样本索引和标签被上传到服务器,用于根据公式 (7)、(8)、(9)、(10)、(11) 初始化知识缓存。
在训练过程中,对于每个给定样本 ( X i k , y i k ) (X_i^k, y_i^k) (Xik,yik),客户端 k k k 首先在 X i k X_i^k Xik 上提取知识 z i k z_i^k zik,然后将 z i k z_i^k zik 连同相应的样本索引 ( k , i ) (k, i) (k,i) 上传到服务器,其中:
z i k = f k ( X i k ) 。 (15) z_i^k = f^{k}(X_i^k)。 \quad \tag{15} zik=fk(Xik)。(15)
然后,根据公式 (12) 从知识缓存 K C KC KC 中获取与样本索引 ( k , i ) (k, i) (k,i) 相关的 R R R 个知识,即:
( z r i k ) 1 , ( z r i k ) 2 , … , ( z r i k ) R = K C ( h i k ; k , i ) , (16) (zr_i^k)_1, (zr_i^k)_2, \dots, (zr_i^k)_R = KC(h_i^k; k, i), \quad \tag{16} (zrik)1,(zrik)2,…,(zrik)R=KC(hik;k,i),(16)
其中 ( z r i k ) s (zr_i^k)_s (zrik)s 是为给定样本索引 ( k , i ) (k, i) (k,i) 获取的第 s s s 个知识。获取的知识以平均方式集成,可以表示为:
z r i k ‾ = 1 R ∑ s = 1 R ( z r i k ) s 。 (17) \overline{zr_i^k} = \frac{1}{R} \sum_{s=1}^{R} (zr_i^k)_s。 \quad \tag{17} zrik=R1s=1∑R(zrik)s。(17)
随后,集成的知识被分发给客户端 k k k,用于执行基于蒸馏的本地模型优化,其权重由因子 β \beta β 控制,定义如下:
arg min W k J k ( W k ) = arg min W k ∑ ( X i k , y i k ) ∈ D k [ L C E ( τ ( f k ( X i k ) ) , y i k ) + β ⋅ K L ( τ ( f k ( X i k ) ) ∥ τ ( z r i k ‾ ) ) ] 。 (18) \begin{aligned} \arg\min_{W^k} J^k(W^k) = \arg\min_{W^k} \sum_{(X_i^k, y_i^k) \in \mathcal{D}^k} \Big[ & L_{CE}\big(\tau(f^k(X_i^k)), y_i^k\big) \\ & + \beta \cdot KL\big(\tau(f^k(X_i^k)) \big\| \tau(\overline{zr_i^k})\big) \Big]。 \end{aligned} \quad \tag{18} argWkminJk(Wk)=argWkmin(Xik,yik)∈Dk∑[LCE(τ(fk(Xik)),yik)+β⋅KL(τ(fk(Xik)) τ(zrik))]。(18)
- 左边: arg min W k J k ( W k ) \arg\min_{W^k} J^k(W^k) argminWkJk(Wk) 表示我们要寻找一组模型参数 W k W^k Wk,使得损失函数 J k ( W k ) J^k(W^k) Jk(Wk) 的值最小化。
- 右边:损失函数 J k ( W k ) J^k(W^k) Jk(Wk) 是对客户端 k k k 的本地数据集 D k \mathcal{D}^k Dk 中所有样本的损失求和。每个样本的损失由两部分组成:交叉熵损失(监督学习)和 KL 散度损失(知识蒸馏),两者通过权重因子 β \beta β 平衡。
4.4 FedCache 的形式化描述
FedCache 的执行流程概览如图 4 所示,FedCache 在客户端 k k k 和服务器上的执行过程分别在算法 1 和算法 2 中形式化给出。从整体上看,我们允许设备上的个性化本地模型在服务器端知识缓存的辅助下,对与私有数据相似的样本的集成知识进行蒸馏。
具体来说,FedCache 包含以下步骤:
- 哈希编码与上传。对于给定客户端的每个样本,根据公式 (14) 基于预训练的本地编码器编码得到哈希值(算法 1 第 3 行)。该哈希值连同相应的标签和样本索引一起上传到服务器(算法 1 第 4 行)。由于编码器是一个深度预训练神经网络,具有大量叠加的非线性映射,且输出码的维度远小于数据本身的维度,因此与服务器共享哈希值是保护隐私的。
- 知识缓存初始化。服务器接受来自客户端的上传信息(算法 2 第 3 行),并根据公式 (7)、(8)、(9)、(10)、(11) 在知识缓存中建立样本索引之间的关系,使得每个样本可以被索引到 RR 个相关样本(算法 2 第 4-6 行)。
- 知识提取与上传。客户端提取 logits(算法 1 第 9 行),并将 logits 连同相应的样本索引上传到服务器(算法 1 第 10 行)。这一步替代了 PIA 和 SLIA-FE 中的参数/特征上传步骤。由于 logits 和样本索引的大小比模型参数或特征小几个数量级,通信负担可以显著降低。
- 知识获取。服务器接受客户端上传的样本索引(算法 2 第 10 行),并基于预先建立的样本索引关系,从知识缓存中获取 R R R 个最匹配的知识(算法 2 第 11 行)。这一步使得设备端模型能够获得样本粒度的知识,而不受类别数量的限制。
- 知识集成与分发。获取的知识在服务器上根据公式 (17) 进行集成(算法 2 第 12 行),随后分发给对应的客户端(算法 2 第 13 行)。这一步同样通信高效,因为服务器和客户端之间只传输 logits。
- 知识更新。知识缓存中存储的知识根据新上传的知识(算法 2 第 10 行)按照公式 (13) 进行更新(算法 2 第 14 行)。
- 知识接收与蒸馏。客户端接收从服务器分发下来的集成知识(算法 1 第 11 行),并根据公式 (18) 优化客户端本地模型(算法 1 第 12-13 行)。这一步可以在每个客户端上异步执行,无需等待其他客户端完成之前的步骤。

算法 1:客户端上的 FedCache

- 初始化阶段(第 1-5 行):客户端利用预训练编码器 f h f^h fh 为每个本地样本生成哈希值 h i k h_i^k hik,并将哈希值、样本索引和标签上传给服务器。这一步只需执行一次,用于在服务器端建立样本相似性索引。
- 训练阶段(第 6-14 行):在每一轮通信中,客户端对每个本地样本计算模型输出 z i k z_i^k zik(logits),上传给服务器;然后等待服务器返回该样本的集成知识 z i k ‾ \overline{z_i^k} zik;最后利用本地损失函数 J k J^k Jk(公式 18)进行梯度下降更新模型参数 W k W^k Wk。整个过程可以异步进行,客户端无需等待其他客户端。
算法 2:服务器上的 FedCache

- 初始化阶段(第 1-6 行):服务器收集所有客户端上传的哈希值、样本索引和标签,初始化索引-哈希映射 I H IH IH、标签-索引映射 L I LI LI 和索引-知识映射 I K IK IK(初始化为零)。然后利用 HNSW 算法为每个样本找到 R R R 个哈希最相似的邻居,建立索引关系映射 I R IR IR。这一步只需执行一次。
- 训练阶段(第 7-16 行):服务器持续监听客户端上传的知识。当收到某个样本的 logits z i k z_i^k zik 和索引 ( k , i ) (k,i) (k,i) 时,首先根据 I R IR IR 找到该样本的 R R R 个邻居索引,然后从 I K IK IK 中取出这些邻居当前存储的知识;将这些知识集成为 z r i k ‾ \overline{zr_i^k} zrik(平均),下发给对应的客户端;最后用新上传的 z r i k zr_i^k zrik 更新 I K IK IK 中该样本对应的条目,以便后续查询使用。所有操作都是按需触发的,支持异步处理。
5 实验
5.1 实验设置
5.1.1 数据集与预处理
为了在个性化数据上评估 FedCache,对完整的训练集和测试集采用相同的数据划分策略,确保每个设备上训练和测试本地数据的标签分布一致。每个数据集划分为 300 份非独立同分布的数据,用于在 K = 300 K = 300 K=300 个不同客户端上进行训练和测试,超参数 α \alpha α 设置为 1.0。每个客户端在本地运行一个 epoch,然后进行模型聚合或特征/知识传输。
- 数据集:MNIST、FashionMNIST、CIFAR-10、CINIC-10。
- 数据预处理:采用FedML中的数据划分方案,使用超参数 α \alpha α( α > 0 \alpha > 0 α>0)来控制设备间本地数据分布的差异程度。随着 α \alpha α 降低,设备之间的数据分布显示出更大程度的异构性。
5.1.2 基准方法与评价标准
- 基准算法:与多种架构的最先进 PFL 方法进行比较,包括基于 PIA 的 FMTL和 pFedMe,基于 SLIA-FE 的 FedDKC和 FedICT,以及基于 CLIA 的 FD。
- 评价指标:使用最大平均用户模型准确性(MAUA)作为主要的精度评价指标。此外,还计算了达到特定平均UA精度的通信开销acc@以衡量系统的通信效率。
- 训练时间:论文中的MAUA 结果是在合理的训练时间内获得的,当算法达到收敛或总通信开销达到给定限制时(例如 CIFAR-10 和 CINIC-10 数据集分别为 55G 和 19G),记录结果。
5.1.3 模型
- 深度预训练编码器:使用在 ImageNet 上预训练的 MobileNetV3并删除了最后一个全连接层。
- 模型架构:考虑了4 种不同的模型架构,其中 A 1 C , A 2 C , A 3 C {A_1^C, A_2^C, A_3^C} A1C,A2C,A3C 用于客户端, A S A^S AS 用于服务器(服务器上的模型不包含最前面的 Conv+Batch+ReLU 层),四种采用模型的主要配置如表 4 所示。
表 4. 四种采用模型的主要配置。输入图像的高度和宽度分别记为 H H H 和 W W W。
| 模型 | 符号 | 特征形状 | 参数量 |
|---|---|---|---|
| ResNet-small | A 1 C A_1^C A1C | 76.2K | |
| ResNet-medium | A 2 C A_2^C A2C | 171.2K | |
| ResNet-large | A 3 C A_3^C A3C | H × W × 16 H \times W \times 16 H×W×16 | 266.1K |
| ResNet-server | A 4 S A_4^S A4S | 588.2K |
实验同时考虑了客户端模型的同质性和异质性。在同构模型实验中,将 FedCache 与上述所有基准算法进行比较,所有客户端均采用模型架构 A 3 C A_3^C A3C。在异构模型实验中,FedCache 仅与支持客户端间模型异构性的基准算法进行比较,包括 FedDKC、FedICT 和 FD,其中索引模 3 余数为 0、1 和 2 的客户端分别分配模型架构 A 1 C A_1^C A1C、 A 2 C A_2^C A2C 和 A 3 C A_3^C A3C。
5.1.4 超参数设置
在所有实验中采用随机梯度下降法,学习率 l r = 0.01 lr = 0.01 lr=0.01,批大小为 8。
| 算法 | 超参数设置 |
|---|---|
| pFedMe | η = 0.005 \eta = 0.005 η=0.005, λ = 15 \lambda = 15 λ=15, β = 1 \beta = 1 β=1 |
| MTFL (FMTL) | 采用 FedAvg 优化策略,其余超参数遵循 [51] 默认设置 |
| FedDKC | KKR 知识细化策略, β = 1.5 \beta = 1.5 β=1.5, T = 0.12 T = 0.12 T=0.12 |
| FedICT | 基于相似性的 LKA 策略, β = λ = μ = 1.5 \beta = \lambda = \mu = 1.5 β=λ=μ=1.5, T = 3.0 T = 3.0 T=3.0 |
| FD | 无需单独的个性化超参数 |
| FedCache (ours) | β = 1.5 \beta = 1.5 β=1.5, R = 16 R = 16 R=16 |
5.2 实验结果
5.2.1 同构模型上的性能
MAUA:FedCache 在 FashionMNIST、CIFAR-10 和 CINIC-10 数据集上分别达到了 77.71%、44.42% 和 40.45% 的 MAUA,与所考虑的基准算法性能相当。
通信开销:FedCache 在上述三个数据集上的总通信开销均小于 0.20G,加速比均超过 78 倍,通信效率远高于以往架构的现有方法。(图 5 FedCache 收敛曲线比 FedDKC、FedICT、pFedMe 和 MTFL 陡峭得多)
表 5. 同构设备端模型上的 MAUA (%)、通信开销和通信效率加速比。部分方法在给定实验设置下无法达到表 3 中的 MAUA,因此无法计算其通信开销和相应的加速比,对应项用“-”表示。下表同。
| 数据集 | 方法 | 模型 | 模型 | MAUA (%) | 通信开销 (G) | 加速比 |
|---|---|---|---|---|---|---|
| 客户端 | 服务器 | |||||
| MNIST | pFedMe | A 3 C A_3^C A3C | A 3 C A_3^C A3C | 94.89 | 13.25 | ×1.0 |
| MTFL | A 3 C A_3^C A3C | A 3 C A_3^C A3C | 95.59 | 7.77 | ×1.7 | |
| FedDKC | A 3 C A_3^C A3C | A 3 C A_3^C A3C | 89.62 | 9.13 | - | |
| FedICT | A 3 C A_3^C A3C | A 3 C A_3^C A3C | 84.62 | - | - | |
| FD | A 3 C A_3^C A3C | - | 84.19 | - | - | |
| FedCache | A 3 C A_3^C A3C | - | - | - | - | |
| FashionMNIST | pFedMe | A 3 C A_3^C A3C | A 3 C A_3^C A3C | 81.57 | 20.71 | ×1.0 |
| MTFL | A 3 C A_3^C A3C | A 3 C A_3^C A3C | 83.92 | 12.33 | ×1.7 | |
| FedDKC | A 3 C A_3^C A3C | A 3 C A_3^C A3C | 78.24 | 8.43 | - | |
| FedICT | A 3 C A_3^C A3C | A 3 C A_3^C A3C | 76.90 | 13.34 | ×1.6 | |
| FD | A 3 C A_3^C A3C | - | 76.32 | - | - | |
| FedCache | A 3 C A_3^C A3C | - | 77.71 | 0.08 | ×258.9 | |
| CIFAR-10 | pFedMe | A 3 C A_3^C A3C | A 3 C A_3^C A3C | 37.49 | - | - |
| MTFL | A 3 C A_3^C A3C | A 3 C A_3^C A3C | 43.43 | 52.99 | ×1.0 | |
| FedDKC | A 3 C A_3^C A3C | A 3 C A_3^C A3C | 45.87 | 11.46 | - | |
| FedICT | A 3 C A_3^C A3C | A 3 C A_3^C A3C | 43.61 | 10.69 | ×5.0 | |
| FD | A 3 C A_3^C A3C | - | 42.77 | - | - | |
| FedCache | A 3 C A_3^C A3C | - | 44.42 | 0.19 | ×278.9 | |
| CINIC-10 | pFedMe | A 3 C A_3^C A3C | A 3 C A_3^C A3C | 31.65 | - | - |
| MTFL | A 3 C A_3^C A3C | A 3 C A_3^C A3C | 34.09 | - | - | |
| FedDKC | A 3 C A_3^C A3C | A 3 C A_3^C A3C | 43.95 | 4.12 | - | |
| FedICT | A 3 C A_3^C A3C | A 3 C A_3^C A3C | 42.79 | 5.50 | - | |
| FD | A 3 C A_3^C A3C | - | 39.36 | - | - | |
| FedCache | A 3 C A_3^C A3C | - | 40.45 | 0.07 | ×78.6 |
5.2.2 异构模型上的性能
表 6 显示了 FedCache 与支持客户端模型异构的基准算法的比较。FedCache 实现了与所考虑基准相当的 MAUA,但由于消除了 FedDKC 和 FedICT 所需的特征传输,通信效率极高。
表 6. 异构设备端模型上的 MAUA (%)、通信开销和通信效率加速比
| 数据集 | 方法 | 模型 | 模型 | MAUA (%) | 通信开销 (G) | 加速比 |
|---|---|---|---|---|---|---|
| 客户端 | 服务器 | |||||
| MNIST | FedDKC | A 1 C , A 2 C , A 3 C A_1^C, A_2^C, A_3^C A1C,A2C,A3C | A 4 S A_4^S A4S | 85.38 | 10.53 | ×1.0 |
| FedICT | A 1 C , A 2 C , A 3 C A_1^C, A_2^C, A_3^C A1C,A2C,A3C | A 4 S A_4^S A4S | 80.53 | - | - | |
| FD | A 1 C , A 2 C , A 3 C A_1^C, A_2^C, A_3^C A1C,A2C,A3C | - | 79.90 | - | - | |
| FedCache | A 1 C , A 2 C , A 3 C A_1^C, A_2^C, A_3^C A1C,A2C,A3C | - | 83.94 | 0.10 | ×105.3 | |
| FashionMNIST | FedDKC | A 1 C , A 2 C , A 3 C A_1^C, A_2^C, A_3^C A1C,A2C,A3C | A 4 S A_4^S A4S | 77.96 | 12.64 | ×1.0 |
| FedICT | A 1 C , A 2 C , A 3 C A_1^C, A_2^C, A_3^C A1C,A2C,A3C | A 4 S A_4^S A4S | 76.11 | - | - | |
| FD | A 1 C , A 2 C , A 3 C A_1^C, A_2^C, A_3^C A1C,A2C,A3C | - | 75.57 | - | - | |
| FedCache | A 1 C , A 2 C , A 3 C A_1^C, A_2^C, A_3^C A1C,A2C,A3C | - | 77.26 | 0.08 | ×158.0 | |
| CIFAR-10 | FedDKC | A 1 C , A 2 C , A 3 C A_1^C, A_2^C, A_3^C A1C,A2C,A3C | A 4 S A_4^S A4S | 44.53 | 4.58 | ×1.2 |
| FedICT | A 1 C , A 2 C , A 3 C A_1^C, A_2^C, A_3^C A1C,A2C,A3C | A 4 S A_4^S A4S | 43.96 | 5.35 | ×1.0 | |
| FD | A 1 C , A 2 C , A 3 C A_1^C, A_2^C, A_3^C A1C,A2C,A3C | - | 40.40 | - | - | |
| FedCache | A 1 C , A 2 C , A 3 C A_1^C, A_2^C, A_3^C A1C,A2C,A3C | - | 41.59 | 0.05 | ×107.0 | |
| CINIC-10 | FedDKC | A 1 C , A 2 C , A 3 C A_1^C, A_2^C, A_3^C A1C,A2C,A3C | A 4 S A_4^S A4S | 44.80 | 4.12 | ×1.3 |
| FedICT | A 1 C , A 2 C , A 3 C A_1^C, A_2^C, A_3^C A1C,A2C,A3C | A 4 S A_4^S A4S | 43.40 | 5.50 | ×1.0 | |
| FD | A 1 C , A 2 C , A 3 C A_1^C, A_2^C, A_3^C A1C,A2C,A3C | - | 40.76 | - | - | |
| FedCache | A 1 C , A 2 C , A 3 C A_1^C, A_2^C, A_3^C A1C,A2C,A3C | - | 41.71 |
6 消融研究
6.1 消融设置
消融实验均在 FashionMNIST 数据集上进行评估,探究三个因素对 FedCache 性能的影响:数据异质性程度、本地样本比例和相关样本数量。算法的性能在以下小节中均以 MAUA (%) 来衡量。
6.2 结果
6.2.1 数据异质性程度的影响
为了探究数据异质性对 FedCache 性能的影响,将超参数 α \alpha α 设置为不同的值 α ∈ 1.0 , 3.0 , 10.0 \alpha \in {1.0, 3.0, 10.0} α∈1.0,3.0,10.0 来控制数据异质性程度,并比较 FedCache 与 FD 的性能。
尽管本地数据分布存在偏斜,FedCache 始终优于 FD,反映了FedCache对不同数据环境的适应性。
6.2.2 本地数据比例的影响
为了研究 FedCache 在不同本地数据占整体数据百分比下的性能,将本地样本数量控制在整个数据集的 {0.33%, 1%, 5%, 20%},并比较 FedCache 与 FD 的性能。
随着客户端本地样本占比的增加,FD 和 FedCache 的性能都有所提升。尽管如此,FedCache 始终优于 FD,这证实了FedCache 在单个客户端的本地数据比例变化时仍具有优越性能。
6.2.3 相关样本数量的影响
为了评估 FedCache 在不同相关样本数量下的性能,设置 R ∈ 1 , 4 , 16 , 64 , 256 R \in {1,4,16,64,256} R∈1,4,16,64,256,并比较 FedCache 与 FD 在上述 R R R 设置下的性能。
FedCache在不同的 R R R 设置下始终优于 FD,这表明 FedCache 对相关样本的选择具有鲁棒性。
7 讨论
7.1 计算复杂度分析
- 设备端
- FedCache 与 PIA、SLIA-FE、CLIA 计算复杂度相同,都集中在本地模型的前向传播上。相比 SLIA-PD,FedCache 更低(因为 SLIA-PD 需要在公共数据集上额外训练)。
- 服务器端
- 与 PIA 对比:FedCache 的复杂度取决于样本量和参数规模。实验中(R=16,单客户端样本数百个),FedCache 的服务器端计算复杂度远小于 PIA,且收敛性更好。
- 与 SLIA-FE/SLIA-PD 对比:FedCache 的服务器端复杂度远小于两者。因为 FedCache 不需要像 SLIA-FE 那样对每个样本进行服务器模型前向传播,也不像 SLIA-PD 那样需要处理大规模的公共数据集。
- 与 CLIA 对比:CLIA 的服务器端复杂度虽然更低,但这是以牺牲知识丰富性和性能为代价的。FedCache 在平衡计算开销和系统性能方面具有不可替代的优势。
7.2 局限性
- 仅适用于个性化任务:FedCache 的蒸馏过程只关注与本地样本相关的知识,忽略了全局泛化能力的学习。因此,它更适合个性化场景,而不适用于需要强全局泛化能力的通用任务。未来可以通过引入部分全局参数或公共数据来改进。
- 实验范围有限:目前的实验仅验证了传统图像分类问题。对于其他类型的数据(如序列化数据、非图像结构化数据),需要进一步研究相应的数据编码策略和哈希相关性度量方法。
- 不支持动态/连续数据:FedCache 的架构假设数据是静态的。但在实际边缘智能场景中,终端设备会持续产生新的数据,需要实时处理和分析。目前的 FedCache 无法很好地支持这种动态性和连续性数据流。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)