基于全局特征共享的个性化联邦学习方法、装置及系统
未命名
09-20
阅读:110
评论:0

1.本发明涉及一种人工智能技术领域,具体涉及一种基于全局特征共享的个性化联邦学习方法、装置及系统。
背景技术:
2.联邦学习是一种分布式机器学习方法,在多个本地节点上训练模型,无需将原始数据集传输到中心服务器,通常用于隐私任务,例如医疗保健、金融领域。标准的联邦学习的目标为训练一个泛化性能较好的全局模型,在此过程中节点和服务器共享相同的全局模型。但是,由于每个节点的数据分布不同,导致全局模型无法较好得适应所有节点的异构数据,从而导致其泛化性能较差。
3.由此,提出个性化联邦学习,在个性化联邦学习中,为每个节点提供一个在其本地数据上表现最佳的个性化模型,具体而言,个性化联邦学习需要为每个节点单独训练私有模型来拟合本地数据集,但是,由于本地数据的异构性和样本数量的限制,节点的本地模型容易出现过拟合的问题。
技术实现要素:
4.针对现有技术的缺陷,本公开的目的是提供一种基于全局特征共享的个性化联邦学习方法、装置及系统。
5.为了实现上述目的,根据本公开的第一方面,提供一种基于全局特征共享的个性化联邦学习方法,应用于客户端,所述客户端包括本地模型,包括:
6.接收服务器发送的全局特征提取器模型和全局特征;
7.根据所述全局特征提取器模型和本地分类器模型,初始化本地模型;
8.将本地图像数据输入经过所述初始化的本地模型进行模型训练,确定所述本地模型的损失函数,所述损失函数包括所述本地图像数据的训练标签和真实标签之间的交叉熵损失、条件互信息正则项;
9.根据所述本地模型的损失函数,基于反向传播对所述本地模型进行第一更新处理;
10.当所述本地模型收敛时,确定目标本地模型。
11.可选地,所述本地模型包括本地特征提取器模型和本地分类器模型;
12.所述根据所述本地模型的损失函数,基于反向传播对所述本地模型进行第一更新处理,包括:
13.根据所述本地模型的损失函数,基于反向传播对本地特征提取器模型进行所述第二更新处理以及对所述本地分类器模型进行所述第三更新处理。
14.可选地,所述方法还包括:
15.将所述本地图像数据输入经过所述第二更新处理后的本地特征提取器模型中,确定所述本地图像数据中的每一图像的末次更新特征;
16.在所述本地模型的第一更新处理次数达到预设阈值时,根据图像类别,将所述具有相同图像类别的所述本地图像数据的末次更新特征进行第一乘积处理,确定本地特征。
17.可选地,所述方法还包括:
18.在所述本地模型的第一更新处理次数达到预设阈值时,确定末次经过所述第二更新处理后的所述本地特征提取器模型;
19.将所述末次经过所述第二更新处理后的所述本地特征提取器模型和所述本地特征发送至所述服务器。
20.根据本公开的第二方面,提供一种基于全局特征共享的个性化联邦学习方法,应用于服务器端,所述服务器端包括全局特征提取器模型,包括:
21.应用于服务器端,所述服务器端包括全局特征提取器模型,包括:
22.初始化全局特征提取器模型和全局特征;
23.将经过所述初始化的全局特征提取器模型和所述全局特征发送至所述客户端;
24.接收所述客户端发送的末次经过第二更新处理后的本地特征提取器模型和本地特征;
25.将所述本地特征提取器模型进行全局聚合处理,确定全局特征提取器模型;
26.根据图像类别将所述本地特征进行第二乘积处理,确定全局特征。
27.可选地,所述方法还包括:
28.将所述全局特征提取器模型和所述全局特征发送至所述客户端,所述客户端执行接收服务器发送的所述全局特征提取器模型和所述全局特征的步骤。
29.根据本公开的第三方面,提供一种基于全局特征共享的个性化联邦学习装置,应用于客户端,所述客户端包括本地模型,包括:
30.客户端第一接收模块,用于接收服务器发送的全局特征提取器模型和全局特征;;
31.客户端初始化模块,用于根据所述全局特征提取器模型、所述全局特征以及本地分类器,初始化本地模型;
32.客户端第一确定模块,用于将本地图像数据输入经过所述初始化的本地模型进行模型训练,确定所述本地模型的损失函数,所述损失函数包括所述本地图像数据的训练标签和真实标签之间的交叉熵损失、条件互信息正则项;
33.客户端第一更新模块,用于根据所述本地模型的损失函数,基于反向传播对所述本地模型进行第一更新处理;
34.客户端第二确定模块,用于当所述本地模型收敛时,确定目标本地模型。
35.根据本公开的第四方面,提供一种基于全局特征共享的个性化联邦学习装置,应用于服务器端,所述服务器端包括全局特征提取器模型,包括:
36.服务器初始化模块,用于初始化全局特征提取器模型和全局特征;
37.服务器发送模块,用于将经过所述初始化的全局特征提取器模型和所述全局特征发送至所述客户端;
38.服务器接收模块,用于接收所述客户端发送的末次经过第二更新处理后的本地特征提取器模型和本地特征;
39.服务器第一确定模块,用于将所述本地特征提取器模型进行全局聚合处理,确定全局特征提取器模型;
40.服务器第二确定模块,用于根据图像类别将所述本地特征进行第二乘积处理,确定全局特征。
41.根据本公开的第五方面,提供一种基于全局特征共享的个性化联邦学习系统,包括:
42.本地模型更新模块,用于在本地训练节点基于反向传播对本地模型进行第一更新处理,所述对本地模型进行第一更新处理包括对本地特征提取器模型进行第二更新处理以及对本地分类器模型进行第三更新处理;
43.本地特征提取模块,用于当所述本地模型更新模块在所述本地训练节点对所述本地模型进行第一更新处理时,提取本地图像数据的更新特征并确定本地特征;
44.全局特征提取器聚合模块,用于将经过所述第二更新处理后的本地特征提取器模型进行全局聚合处理,确定全局特征提取器模型;
45.全局特征更新模块,用于将所述本地特征提取模块确定的所述本地特征根据图像类别进行第二乘积处理,确定全局特征;
46.通信模块,用于将所述客户端的经过所述第二更新处理后的所述本地特征提取器模型和所述本地特征传输至所述服务器端,并且将所述服务器端的所述全局特征提取器模型和所述全局特征传输至所述客户端。
47.可选地,所述客户端包括所述本地模型更新模块和所述本地特征提取模块,所述服务器端包括所述全局特征提取器聚合模块和所述全局特征更新模块,所述通信模块还用于连接所述客户端和所述服务器端。
48.与现有技术相比,本公开实施例具有如下至少一种有益效果:
49.通过上述技术方案,通过引入全局特征和条件互信息正则项,将客户端的本地训练节点的数据共享,采用数据分布广泛和数据特征全面的全局特征,向本地训练节点提供更全局更泛化的数据信息,本地训练节点能够采用其他节点的节点数据,从而防止本地模型的过拟合。
附图说明
50.通过阅读参照以下附图对非限制性实施例所作的详细描述,本发明的其它特征、目的和优点将会变得更明显:
51.图1是根据一示例性实施例示出的一种应用于客户端的基于全局特征共享的个性化联邦学习方法的流程图。
52.图2是根据另一示例性实施例示出的一种应用于客户端的基于全局特征共享的个性化联邦学习方法的流程图。
53.图3是根据一示例性实施例示出的一种应用于服务器的基于全局特征共享的个性化联邦学习方法的流程图。
54.图4是根据一示例性实施例示出的客户端、通信模块和服务器端之间的信令交互示意图。
55.图5是根据一示例性实施例示出的一种应用于客户端的基于全局特征共享的个性化联邦学习装置的框图。
56.图6是根据一示例性实施例示出的一种应用于服务器的基于全局特征共享的个性
化联邦学习装置的框图。
57.图7是根据一示例性实施例示出的一种基于全局特征共享的个性化联邦学习系统的框图。
具体实施方式
58.下面结合具体实施例对本发明进行详细说明。以下实施例将有助于本领域的技术人员进一步理解本发明,但不以任何形式限制本发明。应当指出的是,对本领域的普通技术人员来说,在不脱离本发明构思的前提下,还可以做出若干变形和改进。这些都属于本发明的保护范围。
59.图1是根据一示例性实施例示出的一种应用于客户端的基于全局特征共享的个性化联邦学习方法的流程图。如图1所示,一种基于全局特征共享的个性化联邦学习方法,应用于客户端,客户端包括本地模块,包括:
60.s11,接收服务器发送的全局特征提取器模型和全局特征。
61.其中,客户端的本地包括多个本地训练节点,本地训练节点基于本地隐私数据训练本地模型,本地隐私数据即为本地图像数据。由于通信连接的稳定性较差,第一轮参与模型训练的本地训练节点为随机节点。
62.本轮参与本地模型训练的随机本地训练节点接收服务器发送的全局特征提取器模型和全局特征。
63.s12,根据全局特征提取器模型、全局特征和本地分类器模型,初始化本地模型。
64.其中,本地模型包括本地特征提取器模型和本地分类器模型,在随机本地训练节点接收服务器发送的全局特征提取器模型和全局特征之后,根据全局特征提取器和全局特征以及本地分类器,对全局特征进行特征分类,从而对本地特征提取器模型进行初始化。
65.s13,将本地图像数据输入经过初始化的本地模型进行模型训练,确定本地模型的损失函数,损失函数包括本地图像数据的训练标签和真实标签之间的交叉熵损失、条件互信息正则项。
66.在一种可能的实施例中,将本地图像数据输入本轮的随机本地训练节点的本地模型中,进行正向模型训练。
67.其中,将本地图像数据输入本地模型的本地特征提取器模型中,输出本地图像特征,将本地图像数据输入本地分类器模型中,输出预测的训练标签。
68.损失函数包括:
[0069][0070]
其中,表示损失函数,fi(wi)表示本地模型wi的交叉熵损失函数,即本地图像数据的训练标签和真实标签之间的交叉熵损失,β||ii(z;xi|yi)-i(z;x|yi)||表示条件互信息正则项,β表示拉格朗日乘数,xi表示输入的本地图像数据,yi表示本地图像数据对应的训练标签,z表示生成的图像特征,ii(z;xi|yi)表示本地条件互信息,i(z;x|yi)表示全局条件互信息,p表示概率分布。
[0071]
在本公开中,条件互信息正则项表示本地条件互信息和全局条件互信息的差,本地条件互信息表示本地图像数据与本地特征在给定标签时的互信息;全局条件互信息表示
全局图像数据与全局特征在给定标签时的互信息。采用条件互信息正则项,促使本地条件互信息和全局条件互信息相互接近。
[0072]
s14,根据本地模型的损失函数,基于反向传播对本地模型进行第一更新处理。
[0073]
其中,第一更新处理包括第二更新处理和第三更新处理。
[0074]
在一些可能的实施例中,根据所述本地模型的损失函数,基于反向传播对所述本地模型进行第一更新处理,包括s141。
[0075]
s141,根据本地模型的损失函数,基于反向传播对本地特征提取器模型进行第二更新处理以及对本地分类器模型进行第三更新处理。
[0076]
根据所确定的本地模型的损失函数,通过反向传播和梯度计算,确定本地模型的梯度,并可以采用随机梯度下降算法优化本地模型的参数,即对本地特征提取器模型和本地分类器模型的参数进行更新处理。
[0077]
本领域技术人员应当理解,还可以采用其他算法优化本地模型的参数,均落入本公开的保护范围。
[0078]
s15,当本地模型收敛时,确定目标本地模型。
[0079]
当本地模型收敛时,停止对本地模型进行模型模型训练,将末次模型训练的本地模型确定为目标本地模型,从而确定本地训练节点上的个性化模型,防止本地模型出现过拟合的现象。
[0080]
通过上述技术方案,通过引入全局特征和条件互信息正则项,将客户端的本地训练节点的数据共享,采用数据分布广泛和数据特征全面的全局特征,向本地训练节点提供更全局更泛化的数据信息,本地训练节点能够采用其他节点的节点数据,从而防止本地模型的过拟合。
[0081]
图2是根据另一示例性实施例示出的一种应用于客户端的基于全局特征共享的个性化联邦学习方法的流程图。
[0082]
在一些可能的实施例中,如图2所示,基于全局特征共享的个性化联邦学习方法,应用于客户端,还包括s16至s17。
[0083]
s16,将本地图像数据输入经过第二更新处理后的本地特征提取器模型中,确定本地图像数据中的每一图像的末次更新特征。
[0084]
在对本地模型进行第一更新处理时,保存每一经过第二更新处理后的本地特征提起模型所输出的本地图像数据的图像特征。
[0085]
在重复s11至s14步骤的次数达到预设阈值k时,(k≥1),确定最后一次对本地特征提取器模型进行第二更新处理所输出的本地图像特征,即本地图像数据中的每一图像的末次更新特征。
[0086]
s17,在本地模型的第一更新处理次数达到预设阈值时,根据图像类别,将具有相同图像类别的本地图像数据的末次更新特征进行第一乘积处理,确定本地特征。
[0087]
其中,第一乘积处理可以采用专家积聚合方式。
[0088]
接上述示例,将最后一次对本地特征提取器模型进行第二更新处理所输出的本地图像特征,根据图像类别,将具有相同图像类别的本地图像特征相乘,生成不同类别的特征,并将其确定为本地特征。
[0089]
在一些可能的实施例中,如图2所示,基于全局特征共享的个性化联邦学习方法,
应用于客户端,还包括s18至s19。
[0090]
s18,在本地模型的第一更新处理次数达到预设阈值时,确定末次经过第二更新处理后的本地特征提取器模型。
[0091]
s19,将末次经过第二更新处理后的本地特征提取器模型和本地特征发送至服务器。
[0092]
接上述示例,在重复s11至s14步骤的次数达到预设阈值k时,(k≥1),将本轮训练的每一本地训练节点的最后一次经过第二更新处理的本地特征提取器模型和本地特征通过通信模块发送至服务器。
[0093]
其中,通信模块用于连接客户端的各个节点和服务器。
[0094]
通过上述技术方案,为本地训练节点引入正则化项,将本地条件互信息和全局条件互信息之间的差异最小化,以鼓励本地训练节点学习并共享特征表示,以缓解本地模型的过拟合现象。
[0095]
图3是根据一示例性实施例示出的一种应用于服务器的基于全局特征共享的个性化联邦学习方法的流程图。如图3所示,一种基于全局特征共享的个性化联邦学习方法,应用于服务器端,包括s21至s25。
[0096]
s21,初始化全局特征提取器模型和全局特征。
[0097]
s22,将经过初始化的全局特征提取器模型和全局特征发送至客户端。
[0098]
服务器首先初始化全局特征提取器模型和全局特征,由于服务器和客户端之间的通信连接不稳定,服务器向客户端的参与本轮模型训练的本地训练节点发送统一的全局特征提取器模型和统一的全局特征。
[0099]
s23,接收客户端发送的末次经过第二更新处理后的本地特征提取器模型和本地特征。
[0100]
s24,将本地特征提取器模型进行全局聚合处理,确定全局特征提取器模型。
[0101]
其中,全局聚合处理可以采用算数平均的方法,其权重系数为本轮参与模型训练的本地训练节点数量占本地训练节点总数量的比例。
[0102]
服务器接收本轮参与模型训练的本地训练节点发送的本地特征提取器模型,并将所接收的本地特征提取器模型进行全局聚合处理,以确定全局特征提取器模型。
[0103]
s25,根据图像类别将本地特征进行第二乘积处理,确定全局特征。
[0104]
其中,第二乘积处理可以采用专家积的方式。
[0105]
服务器将接收到的本轮参与模型训练的本地训练节点的各个本地特征按照本地图像数据的图像类别相乘,输出不同类别的全局特征。
[0106]
在一些可能的实施例中,如图3所示,一种基于全局特征共享的个性化联邦学习方法,应用于服务器端,还包括s26。
[0107]
s26,将全局特征提取器模型和全局特征发送至客户端,客户端执行接收服务器发送的全局特征提取器模型和全局特征的步骤。
[0108]
接上述示例,将所确定的全局特征提取器模型和全局特征发送至客户端,客户端执行s12至s19,服务器端执行s21至s26,直至基于反向传播对本地模型进行第一更新处理后,本地模型收敛,停止本地模型训练。
[0109]
通过上述技术方案,可以训练出泛化性能较好的全局特征提取器模型,并且在此
过程中,客户端可以训练处具有优越泛化性能的局部模型,并将本地图像数据隐私存储至本地机构中,防止数据泄露。
[0110]
图4是根据一示例性实施例示出的客户端、通信模块和服务器端之间的信令交互示意图。
[0111]
如图4所示,s31,服务器初始化全局特征提取器模型和全局特征。
[0112]
s32,服务器将经过初始化的全局特征提取器模型和全局特征发送至客户端。
[0113]
s33,客户端接收服务器发送的全局特征提取器模型和全局特征。
[0114]
s34,客户端根据全局特征提取器模型、全局特征以及本地分类器模型,初始化本地模型。
[0115]
s35,客户端将本地图像数据输入经过初始化的本地模型进行模型训练,确定本地模型的损失函数。
[0116]
s36,客户端根据本地模型的损失函数,基于反向传播对本地特征提取器模型进行第二更新处理以及对本地分类器模型进行第三更新处理。
[0117]
s37,客户端将本地图像数据输入经过第二更新处理后的本地特征提取器模型中,确定本地图像数据中的每一图像的末次更新特征。
[0118]
s38,客户端在本地模型的第一更新处理次数达到预设阈值时,根据图像类别,将具有相同图像类别的本地图像数据的末次更新特征进行第一乘积处理,确定本地特征。
[0119]
s39,客户端在本地模型的第一更新处理次数达到预设阈值时,确定末次经过第二更新处理后的本地特征提取模型。
[0120]
s40,客户端将末次经过第二更新处理后的本地特征提取器模型和本地特征发送至服务器。
[0121]
s41,服务器端接收客户端发送的末次经过第二更新处理后的本地特征提取器模型和本地特征。
[0122]
s42,服务器端将本地特征提取器模型进行全局聚合处理,确定全局特征提取器模型。
[0123]
s43,服务器端根据图像类别将本地特征进行第二乘积处理,确定全局特征。
[0124]
s44,服务器端将全局特征提取器模型和全局特征发送至客户端。
[0125]
s45,当客户端本地模型收敛时,确定目标本地模型。
[0126]
其中,通信模块作为客户端与服务器端信令传输的媒介。
[0127]
当s31至s44执行后,s33至s44循环执行,当确定本地模型的本地特征提取器模型收敛时,执行s45。
[0128]
在一些可能的实施例中,可以为每一客户端训练具有优越泛化性能的局部模型,并对本地数据进行保密存储,基于全局特征共享的个性化联邦学习方法为连接分散的医疗数据源和用户数据隐私保护提供前景。
[0129]
作为一种示例,在胸内淋巴结诊断过程中,通过超声引导支气管内凸探头进行成像,采用深度学习模型进行图像分类从而诊断患者的病情。为保护患者隐私,医院无法公开患者的病情检测图像数据集,深度学习模型无法获取足够的训练数据以进行有效训练。若各个医院基于本医院的患者的病情监测图像数据集进行模型训练,由于训练数据量过少容易导致模型过拟合的现象。
[0130]
基于可信赖的第三方机构,作为服务器端,各个医院作为客户端,构建服务器端与各个医院之间的通信网络,即通信模块,其中,各个医院之间无需直接进行通信,第三方机构与医院之间不传输原始的病情监测图像数据集。
[0131]
基于上文所述的基于全局特征共享的个性化联邦学习方法的s31至s45,实现为每一医院训练一个泛化性能良好的深度学习模型。
[0132]
(1)第三方机构(服务器端)初始化全局特征提取器模型和全局特征。
[0133]
(2)第三方机构(服务器端)将统一的初始化的全局特征提取器模型和全局特征发送至本轮参与模型训练的医院(客户端),由于通信连接的不稳定性,本轮参与模型训练的医院可以为随机参与。
[0134]
(3)医院(客户端)接收第三方机构(服务器端)发送的全局特征提取器模型和全局特征。
[0135]
(4)医院(客户端)根据全局特征提取器模型、全局特征以及本地分类器模型,初始化本地模型。
[0136]
(5)医院(客户端)将本院内的已存储的患者的病情检测图像数据集输入经过初始化的本地模型进行模型训练,确定本地模型的损失函数,即本地图像数据的训练标签和真实标签之间的交叉熵损失、条件互信息正则项的和:
[0137][0138]
(6)医院(客户端)根据本地模型的损失函数,基于反向传播对本地特征提取器模型进行第二更新出来以及对本地分类器模型进行第三更新处理。
[0139]
(7)医院(客户端)将本院内的已存储的患者的病情检测图像数据集输入经过第二更新处理后的本地特征提取器模型中,确定本院内的已存储的患者的病情检测图像数据集中的每一图像的末次更新特征。
[0140]
(8)医院(客户端)在本地模型的第一更新处理次数达到预设阈值时,根据图像类别,将将具有相同图像类别的患者的病情检测图像数据集中的图像的末次更新特征进行第一乘积处理,确定本地特征。
[0141]
(9)医院(客户端)在本地模型的第一更新处理次数达到预设阈值时,确定末次经过第二更新处理后的本地特征提取模型。
[0142]
(10)医院(客户端)将将末次经过第二更新处理后的本地特征提取器模型和本地特征发送至服务器。
[0143]
(11)第三方机构(服务器端)接收医院(客户端)发送的末次经过第二更新处理后的本地特征提取器模型和本地特征。
[0144]
(12)第三方机构(服务器端)将本地特征提取器模型进行全局聚合处理,确定全局特征提取器模型,其中全局聚合处理可以采用算数平均的方法。
[0145]
(13)第三方机构(服务器端)根据图像类别将本地特征进行第二乘积处理,确定全局特征。
[0146]
(14)第三方机构(服务器端)将将全局特征提取器模型和全局特征发送至参与模型训练的医院(客户端),对医院(客户端)的本地模型继续进行模型训练。
[0147]
(15)当参与模型训练医院的(客户端)的本地模型收敛时,停止模型训练。
[0148]
上述过程中的数据通过医院与第三方机构之间的通信网络进行传输,各个医院之间无需进行数据传输。
[0149]
通过上述技术方案,可以为每一医院训练具有优越泛化性能的局部模型,并同时将其隐私数据存储在本地机构内不外泄,为分散的医疗数据源和医疗隐私保护提供了较好的应用意义。
[0150]
在一些可能的实施例中,如表1所示,基于emnist-l、fashion-mnist、cifra-10和cifar-100图像数据集,采用非独立数据划分方式non-iid-1和non-iid-2,验证本公开的基于全局特征共享的个性化联邦学习方法以及fedavg、fedavg-ft、fedper、lg-fedavg、fedrep、fedbabu、ditto、fedsr-ft、fedpac的方法的泛化表现。
[0151]
其中,在non-iid-1划分方式中,每一本地训练节点从上述emnist-l、fashion-mnist、cifra-10和cifar-100四个不同类别的图像数据集中均匀采样,四个类别的采样数据量相同,在non-iid-2划分方式中,每一本地训练节点根据迪利克雷分布进行采样,每一本地训练节点的训练数据集包括不固定数量的每一类别的图像数据集的图像数据,并且,每一类别的图像数据集的数据量均不相同。
[0152][0153][0154]
表1
[0155]
其中,ours表示本公开的全局特征共享的个性化联邦学习方法。
[0156]
通过表1可得,本公开的全局特征共享的个性化联邦学习方法在emnist-l、fashion-mnist、cifra-10和cifar-100图像数据集中均表现了最优泛化性。
[0157]
基于同一构思,本公开还提供一种基于全局特征共享的个性化联邦学习装置,应用于客户端,图5是根据一示例性实施例示出的一种应用于客户端的基于全局特征共享的个性化联邦学习装置的框图。参照图5,该基于全局特征共享的个性化联邦学习装置100包括:客户端第一接收模块110、客户端初始化模块120、客户端第一确定模块130、客户端第一更新模块140、客户端第二确定模块150。
[0158]
客户端第一接收模块110,用于接收服务器发送的全局特征提取器模型和全局特征;;
[0159]
客户端初始化模块120,用于根据所述全局特征提取器模型、所述全局特征以及本
地分类器,初始化本地模型;
[0160]
客户端第一确定模块130,用于将本地图像数据输入经过所述初始化的本地模型进行模型训练,确定所述本地模型的损失函数,所述损失函数包括所述本地图像数据的训练标签和真实标签之间的交叉熵损失、条件互信息正则项;
[0161]
客户端第一更新模块140,用于根据所述本地模型的损失函数,基于反向传播对所述本地模型进行第一更新处理;
[0162]
客户端第二确定模块150,用于当所述本地模型收敛时,确定目标本地模型。
[0163]
通过上述技术方案,通过引入全局特征和条件互信息正则项,将客户端的本地训练节点的数据共享,采用数据分布广泛和数据特征全面的全局特征,向本地训练节点提供更全局更泛化的数据信息,本地训练节点能够采用其他节点的节点数据,从而防止本地模型的过拟合。
[0164]
可选地,所述本地模型包括本地特征提取器模型和本地分类器模型。
[0165]
可选地,客户端第一更新模块140,还用于根据所述本地模型的损失函数,基于反向传播对本地特征提取器模型进行所述第二更新处理以及对所述本地分类器模型进行所述第三更新处理。
[0166]
可选地,所述装置100还包括:
[0167]
客户端第三确定模块,用于将所述本地图像数据输入经过所述第二更新处理后的本地特征提取器模型中,确定所述本地图像数据中的每一图像的末次更新特征;
[0168]
客户端第四确定模块,用于在所述本地模型的第一更新处理次数达到预设阈值时,根据图像类别,将所述具有相同图像类别的所述本地图像数据的末次更新特征进行第一乘积处理,确定本地特征。
[0169]
可选地,所述装置100还包括:
[0170]
客户端第五确定模块,用于在所述本地模型的第一更新处理次数达到预设阈值时,确定末次经过所述第二更新处理后的所述本地特征提取器模型;
[0171]
客户端发送模块,用于将所述末次经过所述第二更新处理后的所述本地特征提取器模型和所述本地特征发送至所述服务器。
[0172]
本公开还提供一种基于全局特征共享的个性化联邦学习装置,应用于服务器端,图6是根据另一示例性实施例示出的一种应用于客户端的基于全局特征共享的个性化联邦学习装置的框图。参照图6,该基于全局特征共享的个性化联邦学习装置200包括:服务器初始化模块210、服务器发送模块220、服务器接收模块230、服务器第一确定模块240、服务器第二确定模块250。
[0173]
服务器初始化模块210,用于初始化全局特征提取器模型和全局特征;
[0174]
服务器发送模块220,用于将经过所述初始化的全局特征提取器模型和所述全局特征发送至所述客户端;
[0175]
服务器接收模块230,用于接收所述客户端发送的末次经过第二更新处理后的本地特征提取器模型和本地特征;
[0176]
服务器第一确定模块240,用于将所述本地特征提取器模型进行全局聚合处理,确定全局特征提取器模型;
[0177]
服务器第二确定模块250,用于根据图像类别将所述本地特征进行第二乘积处理,
确定全局特征。
[0178]
可选地,所述服务器发送模块220还用于将所述全局特征提取器模型和所述全局特征发送至所述客户端。
[0179]
关于上述实施例中的装置,其中各个模块执行操作的具体方式已经在有关该方法的实施例中进行了详细描述,此处将不做详细阐述说明。
[0180]
图7是根据一示例性实施例示出的一种基于全局特征共享的个性化联邦学习系统的框图。如图7所示,基于全局特征共享的个性化联邦学习系统,包括:
[0181]
本地模型更新模块,用于在本地训练节点基于反向传播对本地模型进行第一更新处理,所述对本地模型进行第一更新处理包括对本地特征提取器模型进行第二更新处理以及对本地分类器模型进行第三更新处理;
[0182]
本地特征提取模块,用于当所述本地模型更新模块在所述本地训练节点对所述本地模型进行第一更新处理时,提取本地图像数据的更新特征并确定本地特征;
[0183]
全局特征提取器聚合模块,用于将经过所述第二更新处理后的本地特征提取器模型进行全局聚合处理,确定全局特征提取器模型;
[0184]
全局特征更新模块,用于将所述本地特征提取模块确定的所述本地特征根据图像类别进行第二乘积处理,确定全局特征;
[0185]
通信模块,用于将所述客户端的经过所述第二更新处理后的所述本地特征提取器模型和所述本地特征传输至所述服务器端,并且将所述服务器端的所述全局特征提取器模型和所述全局特征传输至所述客户端。
[0186]
可选地,所述客户端包括所述本地模型更新模块和所述本地特征提取模块,所述服务器端包括所述全局特征提取器聚合模块和所述全局特征更新模块,所述通信模块还用于连接所述客户端和所述服务器端。
[0187]
以上对本发明的具体实施例进行了描述。需要理解的是,本发明并不局限于上述特定实施方式,本领域技术人员可以在权利要求的范围内做出各种变形或修改,这并不影响本发明的实质内容。
技术特征:
1.一种基于全局特征共享的个性化联邦学习方法,其特征在于,应用于客户端,所述客户端包括本地模型,包括:接收服务器发送的全局特征提取器模型和全局特征;根据所述全局特征提取器模型和本地分类器模型,初始化本地模型;将本地图像数据输入经过所述初始化的本地模型进行模型训练,确定所述本地模型的损失函数,所述损失函数包括所述本地图像数据的训练标签和真实标签之间的交叉熵损失、条件互信息正则项;根据所述本地模型的损失函数,基于反向传播对所述本地模型进行第一更新处理;当所述本地模型收敛时,确定目标本地模型。2.根据权利要求1所述的方法,其特征在于,所述本地模型包括本地特征提取器模型和本地分类器模型;所述根据所述本地模型的损失函数,基于反向传播对所述本地模型进行第一更新处理,包括:根据所述本地模型的损失函数,基于反向传播对本地特征提取器模型进行所述第二更新处理以及对所述本地分类器模型进行所述第三更新处理。3.根据权利要求2所述的方法,其特征在于,所述方法还包括:将所述本地图像数据输入经过所述第二更新处理后的本地特征提取器模型中,确定所述本地图像数据中的每一图像的末次更新特征;在所述本地模型的第一更新处理次数达到预设阈值时,根据图像类别,将所述具有相同图像类别的所述本地图像数据的末次更新特征进行第一乘积处理,确定本地特征。4.根据权利要求3所述的方法,其特征在于,所述方法还包括:在所述本地模型的第一更新处理次数达到预设阈值时,确定末次经过所述第二更新处理后的所述本地特征提取器模型;将所述末次经过所述第二更新处理后的所述本地特征提取器模型和所述本地特征发送至所述服务器。5.一种基于全局特征共享的个性化联邦学习方法,其特征在于,应用于服务器端,所述服务器端包括全局特征提取器模型,包括:初始化全局特征提取器模型和全局特征;将经过所述初始化的全局特征提取器模型和所述全局特征发送至所述客户端;接收所述客户端发送的末次经过第二更新处理后的本地特征提取器模型和本地特征;将所述本地特征提取器模型进行全局聚合处理,确定全局特征提取器模型;根据图像类别将所述本地特征进行第二乘积处理,确定全局特征。6.根据权利要求5所述的方法,其特征在于,所述方法还包括:将所述全局特征提取器模型和所述全局特征发送至所述客户端,所述客户端执行接收服务器发送的所述全局特征提取器模型和所述全局特征的步骤。7.一种基于全局特征共享的个性化联邦学习装置,其特征在于,应用于客户端,所述客户端包括本地模型,包括:客户端第一接收模块,用于接收服务器发送的全局特征提取器模型和全局特征;;客户端初始化模块,用于根据所述全局特征提取器模型、所述全局特征以及本地分类
器,初始化本地模型;客户端第一确定模块,用于将本地图像数据输入经过所述初始化的本地模型进行模型训练,确定所述本地模型的损失函数,所述损失函数包括所述本地图像数据的训练标签和真实标签之间的交叉熵损失、条件互信息正则项;客户端第一更新模块,用于根据所述本地模型的损失函数,基于反向传播对所述本地模型进行第一更新处理;客户端第二确定模块,用于当所述本地模型收敛时,确定目标本地模型。8.一种基于全局特征共享的个性化联邦学习装置,其特征在于,应用于服务器端,所述服务器端包括全局特征提取器模型,包括:服务器初始化模块,用于初始化全局特征提取器模型和全局特征;服务器发送模块,用于将经过所述初始化的全局特征提取器模型和所述全局特征发送至所述客户端;服务器接收模块,用于接收所述客户端发送的末次经过第二更新处理后的本地特征提取器模型和本地特征;服务器第一确定模块,用于将所述本地特征提取器模型进行全局聚合处理,确定全局特征提取器模型;服务器第二确定模块,用于根据图像类别将所述本地特征进行第二乘积处理,确定全局特征。9.一种基于全局特征共享的个性化联邦学习系统,包括:本地模型更新模块,用于在本地训练节点基于反向传播对本地模型进行第一更新处理,所述对本地模型进行第一更新处理包括对本地特征提取器模型进行第二更新处理以及对本地分类器模型进行第三更新处理;本地特征提取模块,用于当所述本地模型更新模块在所述本地训练节点对所述本地模型进行第一更新处理时,提取本地图像数据的更新特征并确定本地特征;全局特征提取器聚合模块,用于将经过所述第二更新处理后的本地特征提取器模型进行全局聚合处理,确定全局特征提取器模型;全局特征更新模块,用于将所述本地特征提取模块确定的所述本地特征根据图像类别进行第二乘积处理,确定全局特征;通信模块,用于将所述客户端的经过所述第二更新处理后的所述本地特征提取器模型和所述本地特征传输至所述服务器端,并且将所述服务器端的所述全局特征提取器模型和所述全局特征传输至所述客户端。10.根据权利要求9所述的系统,其特征在于,所述客户端包括所述本地模型更新模块和所述本地特征提取模块,所述服务器端包括所述全局特征提取器聚合模块和所述全局特征更新模块,所述通信模块还用于连接所述客户端和所述服务器端。
技术总结
本发明涉及一种基于全局特征共享的个性化联邦学习方法、装置及系统。基于全局特征共享的个性化联邦学习方法应用于客户端,包括:接收服务器发送的全局特征提取器模型和全局特征;根据全局特征提取器模型和本地分类器模型,初始化本地模型;将本地图像数据输入经过初始化的本地模型进行模型训练,确定本地模型的损失函数,损失函数包括本地图像数据的训练标签和真实标签之间的交叉熵损失、条件互信息正则项;根据本地模型的损失函数,基于反向传播对本地模型进行第一更新处理;当本地模型收敛时,确定目标本地模型。本公开通过引入全局特征和条件互信息正则项,共享全局特征,提高本地模型的泛化表现,并防止本地模型的过拟合。合。合。
技术研发人员:李成林 张豪 戴文睿 邹君妮 熊红凯
受保护的技术使用者:上海交通大学
技术研发日:2023.06.30
技术公布日:2023/9/19
版权声明
本文仅代表作者观点,不代表航家之家立场。
本文系作者授权航家号发表,未经原创作者书面授权,任何单位或个人不得引用、复制、转载、摘编、链接或以其他任何方式复制发表。任何单位或个人在获得书面授权使用航空之家内容时,须注明作者及来源 “航空之家”。如非法使用航空之家的部分或全部内容的,航空之家将依法追究其法律责任。(航空之家官方QQ:2926969996)
航空之家 https://www.aerohome.com.cn/
飞机超市 https://mall.aerohome.com.cn/
航空资讯 https://news.aerohome.com.cn/
上一篇:一种模拟化学地雷的制作方法 下一篇:一种热熔胶涂布机牵引机构的制作方法