一种分类模型的训练方法和眼底图像分类方法
未命名
09-20
阅读:68
评论:0

1.本发明涉及增量学习领域,具体来说,涉及应用增量学习的医学成像领域,更具体地说,涉及一种分类模型的训练方法和一种眼底图像分类方法。
背景技术:
2.深度神经网络(dnns)在许多机器学习分类任务中表现出色,例如,深度神经网络在医学成像应用中显示出的性能达到了人类水平,以及,最近的研究表明,基于深度学习的模型能够成功应用于视网膜疾病(如糖尿病视网膜病变dr和青光眼)筛查。其中,视网膜疾病筛查中应用的模型是假设在训练前所有类别都是已知的条件下才可实现的。然而,这一假设在医学领域经常被违背。在实际训练过程中,预先准备所有类别的数据是极难实现的,例如,由于一些疾病在不同的发病阶段,疾病的病变程度也会发生细微的变化,因此,疾病的病变程度发生变化的数据进行预先准备是很难实现的。此外,如果一个模型只学会了识别输入样本(例如疾病对应的样本)所属的泛化类别,而不能对所属的具化类别(例如,发病阶段和发病程度不同时的类别)进行准确识别,这会产生极其严重的后果。因为训练后模型的针对输入样本的输出结果会影响后续的应用效果。例如,同一疾病在不同的发病阶段所使用的治疗策略是不同的,而启动正确的治疗方案是取得良好治疗结果的关键。为此,构造机器学习分类任务中的类别增量学习方法,可以识别所属的具化类别(例如,疾病的病变程度与发病阶段对应的类别),从而在后续的应用中,可以在为患者提供持续有效、及时准确的健康检测以及疾病预警方面提供极大帮助。增量学习方法中需要解决的主要挑战是灾难性遗忘。直观地说,灾难性遗忘是由特征空间中新旧类的表示之间的重叠或混淆引起的,在学习新类时,以前类的决策边界可能会发生巨大变化,统一的分类器会有严重的偏差。
3.为解决这一挑战(灾难性遗忘),现有方法中包括以下两个不同的突破方向:第一种是将参数偏向于在旧类上学习的方向;第二种是保持一个来自以前任务的小数据缓冲区(这也被称为经验回放)的方向。针对第一种的采用的正则化的策略,在使用多头分类器以及推理时利用可用的任务标识符的场景中是有效的。针对第二种采用的经验回放方法,最为普遍的做法是以保存少数用于后续模型训练的真实数据来实现经验回放。此外,也有一些方法采用了额外的生成模型,如对抗生成网络(gnn)来生成数据实现经验回放。
4.然而,正如一些文献所注意到的,解决灾难性遗忘的方法在类别增量学习(cil)场景下性能较差。目前较为有效的正则化策略是基于知识蒸馏,强制学生模型完全模仿教师模型。具体来说,经过蒸馏的学生模型旨在模仿教师模型在训练样本上模型全连接层输出的logits,以获得与教师模型类似的泛化性能。然而,完全模仿教师模型的输出可能不是最优的,因为教师模型可能会自信地错误预测一些类,这会增加增量过程中错误信息传递的风险。对于经验回放的方法,常常需要大量的内存来重放之前看到的或建模的数据,以避免灾难性的遗忘问题。然而在某些实际场景(例如,物联网应用的设备上或隐私问题)中,由于内存限制,数据存储可能会受到限制。这样使人们专注于增量地合并新信息,而不存储旧知识,这被称为非基于保存样例的增量学习(non-exemplar-based incremental learning)。
5.在非基于保存样例的增量学习中,为了进一步从教师模型中学到更多有用的知识,以确定教师模型的哪些知识有助于建立一个更好的学生模型,除了传统的基于logit的蒸馏方法外,基于特征的蒸馏方法也受到了很多关注。这是因为教师模型的特征比基于logit的模型具有更多的信息,使用特征蒸馏可以使学生模型学习到更丰富的信息。然而,大多数基于特征蒸馏的研究都是手动链接教师模型和学生模型的特征,并通过单独的链接进行蒸馏,存在将不正确的中间过程强加给学生的风险。此外,在所有可能的环节中选择少数环节(换而言之,已有的方法中,是在学生模型蒸馏的步骤中使用人为的方式,选择一些人为认为有代表性的环节进行步骤选择操作),也会限制教师模型充分利用自己的全部知识;而且,在大多数知识蒸馏的情况下,学生模型和教师模型的特征具有不同的宽度、高度和渠道,通常是应用卷积层或全连接层来匹配它们的大小。这样使得在特征调整的过程中,教师模型的一些有用信息可能会丢失。
6.基于上述问题,将增量学习应用于医学成像领域,其中特别是视网膜异常筛查时,专利cn106022368a伴随的由于类别差异导致的灾难性遗忘问题。
技术实现要素:
7.因此,本发明的目的在于克服上述现有技术的缺陷,提供一种分类模型的训练方法和一种眼底图像分类方法。
8.本发明的目的是通过以下技术方案实现的:
9.根据本发明第一方面,提供一种分类模型的训练方法,所述分类模型用于眼底图像分类,所述方法包括:获取预训练的分类模型作为教师模型,其包括用于对输入眼底图像提取图像特征的特征提取网络以及用于根据图像特征识别该图像特征对应眼底图像所属的眼底类别的分类器,所述分类器包括全连接层和softmax层,所述教师模型能识别的眼底类别归为旧类;获取学生模型,其特征提取网络的可训练参数用教师模型中特征提取网络的可训练参数初始化,并且其分类器设置为能对旧类和新类对应的眼底类别进行识别,所述新类是所述旧类之外的眼底类别;针对所述预训练中用到的每个旧类,获取该旧类对应的伪样本,所述伪样本是利用教师模型的特征提取网络对属于该旧类的多个眼底图像提取的图像特征生成的;利用生成的旧类的伪样本和属于新类的眼底图像对所述学生模型进行增量训练,训练时基于预设的总损失函数确定的总损失更新学生模型的参数,得到经增量训练的学生模型,所述总损失根据以下损失加权求和确定:旧类分类损失、新类分类损失、第一特征蒸馏损失和第二特征蒸馏损失;其中,所述第一特征蒸馏损失根据旧类的伪样本在以下两者上的输出之间的差异进行确定:所述教师模型的全连接层的输出、所述学生模型的全连接层在旧类上的输出;所述第二特征蒸馏损失根据教师模型中特征提取网络的提取到的所有教师特征和经增量训练的学生模型中特征提取网络的提取到的所有学生特征之间的总的差异进行确定,其中,每个教师特征为一个伪样本经教师模型中特征提取网络提取得到的图像特征,每个学生特征为伪样本和属于新类的眼底图像分别经学生模型中特征提取网络提取得到的图像特征。
10.在本发明的一些实施例中,通过如下步骤计算所述总的差异:基于注意力机制,确定每个教师特征对各个学生特征的注意力值,以及利用该教师特征对各个学生特征的注意力值构成该教师特征对所有学生特征的注意力向量并进行归一化得到归一化后的注意力
向量;基于每个教师特征和每个学生特征的确定空间距离;计算空间距离和归一化后的注意力向量中对应元素的乘积,并将得到的所有乘积进行求和得到总的差异。
11.在本发明的一些实施例中,基于注意力机制,每个教师特征对各个学生特征的注意力值通过如下步骤计算:将学生特征进行数据转换,得到该学生特征在注意力机制中的一个key;将教师特征进行数据转换,得到该教师特征在注意力机制中的一个query;计算每个query对各个key的注意力值。
12.在本发明的一些实施例中,通过如下规则进行数据转换:
[0013][0014][0015]
其中,q
t
表示在注意力机制中的query,表示第t个教师特征,p
hw
(
·
)表示全局平均池化,表示的线性变换参数,表示的线性变换参数空间矩阵,fq(
·
)表示第一激活函数,ks表示在注意力机制中的key,表示第s个学生特征,表示的线性变换参数,表示的线性变换参数空间矩阵,表示的线性变换参数的空间矩阵,fk(
·
)表示第二激活函数,d表示线性变换参数空间矩阵的维度。
[0016]
在本发明的一些实施例中,通过以下规则计算每个query对所有key的注意力向量并进行归一化:
[0017][0018]
其中,softmax(.)表示归一化函数,表示q
t
的转置,表示双线性权值,k
t,1
表示对对应的key值,表示第t个教师特征的位置编码,表示第s个学生特征的位置编码,k
t,s
表示对对应的key值,表示和的乘积,表示的转置。
[0019]
在本发明的一些实施例中,通过以下方法计算第二特征蒸馏损失:
[0020][0021]
其中,α
t,s
表示第t个教师特征对第s个学生特征的归一化后的注意力值,表示空间距离,||.||2表示求l2范数,表示通道平均池化层与l2归一化的组合函数v/||v||2,v表示对进行平均池化得到的向量,表示对使用下采样或上采样得到的特征。
[0022]
在本发明的一些实施例中,按照如下步骤获得旧类的伪样本:
[0023]
t1、基于预训练中用到的旧类对应的多个图像特征计算该旧类对应的类均值向量;t2、从高斯分布中随机采样得到至少一个噪声向量,所述噪声向量与类均值向量的维度相同;t3、基于预先定义的增强尺度、从高斯分布中随机采样得到噪声向量以及旧类对应的类均值向量按照预设的增强规则计算得到该旧类的伪样本。
[0024]
在本发明的一些实施例中,所述方法包括对学生模型进行多次增量训练,其中,教师模型为当前次增量训练的上一次增量训练后的学生模型,通过以下规则计算旧类对应的类均值向量:
[0025][0026]
其中,表示旧类k
old
对应的类均值向量,表示旧类k
old
的样本数量,f(x
b-1,n
;θ
b-1
)表示旧类k
old
中第n个样本的样本特征,θ
b-1
表示在第b次增量训练中教师模型的特征提取网络的参数。
[0027]
在本发明的一些实施例中,所述预设的增强规则为:
[0028][0029]
其中,表示第b次增量训练中使用到的旧类k
old
对应伪样本,rb表示第b次增量训练中预先定义的增强尺度,e表示采样至高斯分布的一个随机噪声向量,e和的维度相同。
[0030]
在本发明的一些实施例中,通过如下规则设置预先定义增强尺度:
[0031][0032]
其中,k
old
和k
mew
分别表示第b次增量训练中旧类和新类的数量,r
b-1
表示第b-1次增量训练中预先定义的增强尺度,∑
b,k
表示第b次增量训练中第k个新类的协方差矩阵,tr(∑
b,k
)表示∑
b,k
的秩,m表示深度特征空间的维数,当b=1时,k1表示第1次预训练中类别数量。
[0033]
在本发明的一些实施例中,按照如下规则计算总损失:
[0034][0035]
其中,表示总损失,表示新类分类损失,表示旧类分类损失,λ表示第一超参数,表示第一特征蒸馏损失,β第二超参数,表示第二特征蒸馏损失,γ表示第三超参数。
[0036]
在本发明的一些实施例中,通过以下规则计算第一特征蒸馏损失:
[0037][0038]
其中,表示在增量训练中使用到的伪样本数据,表示增量训练中使用到的伪样本经学生模型中的特征提取网络提取后的样本特征,表示增量训练中学生模型的分类器中全连接层的参数,gb表示经学生模型的分类器处理中对应的全连接层在旧类上的输出,表示增量训练中使用到的伪样本数据经教师模型中的特征提取网络提取后的样本特征,表示增量训练中教师模型的分类器中全连接层的参数,增量训练中教师模型的分类器中全连接层的参数,表
示经教师模型处理时全连接层的输出,经教师模型处理时全连接层的输出,表示表示与之间对应元素之间的差异值之和,||.||2表示求l2范数。
[0039]
在本发明的一些实施例中,旧类包括以下类别中任意两种,剩余类别为新类:无糖尿病视网膜病变、轻度糖尿病视网膜病变、中度糖尿病视网膜病变、重度糖尿病视网膜病变、增发性糖尿病视网膜病变。
[0040]
根据本发明第二方面,提供一种眼底图像分类方法,该方法包括:
[0041]
s1、获取待分类的眼底图像;s2、采用如本发明第一方面提供的方法得到的经增量训练的学生模型对步骤s1中获得待分类的眼底图像进行图像分类,其中,学生模型的特征提取网络用于根据待分类的眼底图像提取其图像特征,学生模型的分类器用于根据该图像特征识别所述待分类的眼底图像所属的眼底类别。
[0042]
在本发明的一些实施例中,图像分类的结果包括:无糖尿病视网膜病变、轻度糖尿病视网膜病变、中度糖尿病视网膜病变、重度糖尿病视网膜病变、增发性糖尿病视网膜病变。
[0043]
与现有技术相比,本发明的优点在于:本发明利用生成的旧类的伪样本和属于新类的眼底图像对所述学生模型进行增量训练,以及引入旧类分类损失、第二特征蒸馏损失、新类分类损失可以很好的指导分类模型进行增量学习,使得增量训练后的分类模型对眼底图像或者细粒度眼底图像的所有类别的整体分类能力(精确度)有明显的提升,并缓解了增量学习中的遗忘问题。
附图说明
[0044]
以下参照附图对本发明实施例作进一步说明,其中:
[0045]
图1为根据本发明实施例的一种分类模型的结构及增量训练过程示意图;
[0046]
图2为根据本发明实施例的一种实验结果示意图;
[0047]
图3为根据本发明实施例的又一种实验结果示意图。
具体实施方式
[0048]
为了使本发明的目的,技术方案及优点更加清楚明白,以下通过具体实施例对本发明进一步详细说明。应当理解,此处所描述的具体实施例仅用以解释本发明,并不用于限定本发明。
[0049]
正如背景技术中指出的,将增量学习应用于医学成像领域,其中特别是视网膜异常筛查时,眼底图像增量获得过程中伴随的由于类别差异导致的灾难性遗忘问题。为解决此问题,本技术的发明人首先对现有技术中针对数据增量的眼底图像疾病识别方案进行分析,例如专利公开号为cn113066025a的专利文献,提出了一种基于增量学习与特征、注意力传递的图像去雾方法,该方法并没有很好的解决增量中灾难性遗忘的问题;专利公开号为cn112990280-a的专利文献在图像分类问题上采用了类别增量的方法,但没有更进一步的关注类别增量过程中的细粒度分类问题。由此,发现大多都没有充分利用增量过程中新旧
数据之间的特征关系与语义关系(换而言之,现有的类别增量过程没有充分利用基于特征蒸馏的增量学习和基于关系蒸馏的增量学习),仍没有很好地解决灾难性遗忘的问题。此外,目前解决灾难性遗忘的问题,广泛采用的是基于蒸馏损失与交叉熵损失共同作为损失函数指导模型更新的方法,该方法具有以下缺点:1.没有探究增量过程中深度学习模型提取到的特征中哪些是关键特征;2.没有充分利用增量过程中模型学到的表征特征知识,来实现模型中关键特征的迁移。因此,通过对增量过程中模型得到具有哪些属性的表征特征更能缓解增量过程中的灾难性遗忘是一个很关键的解决方向,以及进一步框定模型学到表征特征知识的边界范围,可以减少增量过程中错误信息传递的风险,从而缓解灾难性遗忘。
[0050]
由此,本发明针对眼底疾病图像增量获得过程中伴随的由于类别差异导致的灾难性遗忘问题,提出了一种分类模型的训练方法,所述分类模型用于眼底图像分类,以及基于该分类模型的眼底图像的分类方法。训练时,首先通过构造基于注意力的神经网络模型(该神经网络模型为现有的基础网络结构,用于模拟注意机制中的计算过程)来学习教师特征与学生特征之间的相对相似性(换而言之,采用注意力机制来学习教师特征与学生特征之间的相对相似性),并应用识别到的相似性来控制所有特征对(一组教师特征和一组学生特征组成一个特征对)之间的蒸馏强度,并且通过为每个旧类在表征空间中保存代表类的伪样本集和,来维持之前类别的决策边界,缓解了增量过程中数据受到的内存或隐私问题的限制(换而言之,模型学到表征特征知识是利用伪样本来维持之前类别的决策边界,也降低了增量过程中数据对内存或隐私的需求,从而很好地处理了由于类别差异导致的灾难性遗忘的问题).其中,每个教师特征为一个伪样本经教师模型中特征提取网络提取得到的图像特征,每个学生特征为伪样本和属于新类的眼底图像分别经学生模型中特征提取网络提取得到的图像特征。此外,利用伪样本来维持之前类别的决策边界,缓解增量过程中模型的遗忘效果,也提升模型的精确度。需要说明的是,增量过程中模型的遗忘是由于随着类别增量步骤的增加,类别也不断增加,模型对所有类别的整体分类能力(精确度)也随着增量步骤的增加逐步下降。采用伪样本参与后续增量学习,可以对此遗忘进行缓解,以及降低内存或隐私问题对增量过程中数据的限制。
[0051]
根据本发明的一个实施例,提供一种分类模型的训练方法,所述分类模型用于眼底图像分类,所述方法包括:获取预训练的分类模型作为教师模型,其包括用于对输入眼底图像提取图像特征的特征提取网络以及用于根据图像特征识别该图像特征对应眼底图像所属的眼底类别的分类器,所述分类器包括全连接层和softmax层,所述教师模型能识别的眼底类别归为旧类;获取学生模型,其特征提取网络的可训练参数用教师模型中特征提取网络的可训练参数初始化,并且其分类器设置为能对旧类和新类对应的眼底类别进行识别,所述新类是所述旧类之外的眼底类别;针对所述预训练中用到的每个旧类,获取该旧类对应的伪样本,所述伪样本是利用教师模型的特征提取网络对属于该旧类的多个眼底图像提取的图像特征生成的;利用生成的旧类的伪样本和属于新类的眼底图像对所述学生模型进行增量训练,训练时基于预设的总损失函数确定的总损失更新学生模型的参数,得到经增量训练的学生模型,所述总损失根据以下损失加权求和确定:旧类分类损失、新类分类损失、第一特征蒸馏损失和第二特征蒸馏损失;其中,所述第一特征蒸馏损失根据旧类的伪样本在以下两者上的输出之间的差异进行确定:所述教师模型的全连接层的输出、所述学生模型的全连接层在旧类上的输出;所述第二特征蒸馏损失根据教师模型中特征提取网络的
提取到的所有教师特征和经增量训练的学生模型中特征提取网络的提取到的所有学生特征之间的总的差异进行确定,其中,本发明实施中,每个教师特征为一个伪样本经教师模型中特征提取网络提取得到的图像特征,每个学生特征为伪样本和属于新类的眼底图像分别经学生模型中特征提取网络提取得到的图像特征。其中,利用生成的旧类的伪样本和属于新类的眼底图像对所述学生模型进行增量训练,以及引入旧类分类损失、第二特征蒸馏损失、新类分类损失可以很好的指导分类模型进行增量学习,使得增量训练后的分类模型对眼底图像或者细粒度眼底图像的所有类别的整体分类能力(精确度)有明显的提升,并缓解了增量学习中的遗忘问题。为了更好地说明本发明的实施例,首先介绍下增量训练用到的数据集和分类模型的结构。
[0052]
对于数据集而言,本发明实施例采用的是眼底图像及眼底图像对应的标签构成数据集,也称原始数据。优选的,本发明实施例中根据糖尿病视网膜病变(dr)的严重程度眼底图像对应的标签包括无糖尿病视网膜病变(无dr)、轻度糖尿病视网膜病变(轻度dr)、中度糖尿病视网膜病变(中度dr)、重度糖尿病视网膜病变(重度dr)、增发性糖尿病视网膜病变(增发性dr)。原始数据的数据流由不同的任务组成,原始数据的数据流可以描述为:其中是系统在步骤b(本发明实施例步骤b也称第b次增量训练,其他地方不在赘述)收到的数据,xb表示步骤b中的所有眼底图像样本,yb表示步骤b中的所有眼底图像样本对应的类别标签。数据集db包含nb个有标签数据用来训练,且yb,j∈cb,其中cb是步骤b中的类别集合(也称类别标签集合),且不同步骤中的类别集合互不相交。需要说明的是,不同步骤中的类别集合互不相交是增量训练中最难的训练场景。本发明实施例以最难的训练场景为例来讲解分类模型的训练方法。其他较容易得训练场景,如不同步骤中的类别集合部分相交的场景,本发明也同样适用。其中,最难的训练场景是指,类别增量学习的增量过程中,每个增量步骤内的类别各不相同的增量,例如,第一个增量步骤使模型学习分类a类和b类,第二个增量步骤新增数据为c类和d类的训练数据,但希望模型在经过训练后,能同时对a类、b类、c类和d类同时进行分类。每个类别数据都有各自的最本质关键的特征,在增量过程中,新增量步骤学到的模型可能会对之前类别(猫和狗类)的本质关键特征产生遗忘,从而影响分类精度。为了广泛评估了本发明提供的方法的有效性,本发明实施例中可以采用2个数据集。其中包括公共医学图像数据集eyepacs和另一个私有数据集来自爱尔医院提供的内部数据集,eyepacs包含35125张眼底图像,并根据糖尿病视网膜病变(dr)的严重程度进行分级,眼底图像分为无dr、轻度dr、中度dr、重度dr、增发性dr共5类;另一个私有数据集来自爱尔医院提供的内部数据集,也包括以上5类,共包含5479张眼底图像。
[0053]
以步骤b为例讲解本发明实施例达到的目标,在步骤b中目标是在新的数据集db上最小化损失函数并准确地分类属于新的类集cb的例子,且不影响之前步骤中学到的知识,并可以在能对当前步骤b中数据流(例如,训练数据)进行准确分类的前提下从一定程度上提升之前学过的知识,从而缓解增量过程中的遗忘问题。
[0054]
对于分类模型而言,在类别增量学习(cil)中,cil的目标是依次学习一个统一的分类模型,对训练结束时已经学习的所有类的测试样本进行分类。分类模型包括特征提取器f
θ
(特征提取器也称特征提取网络)和一个统一的分类器g
φ
。为了方便描述,在增量训练
的过程,在当前增量训练的步骤中,将训练的分类模型称为学生模型,当前增量训练的前一增量训练中的分类模型称为教师模型。如图1所示,在步骤b-1中,教师模型m
b-1
包括特征提取网络和分类器,分类器包括全连接层和softmax层,此时教师模型的分类器所能识别的眼底类别也即是此时教师模型所能识别的眼底类别归为旧类;在步骤b中,学生模型mb包括特征提取网络和分类器,分类器包括全连接层和softmax层,此时学生模型的分类器设置为能对旧类和新类对应的眼底类别进行识别,所述新类是所述旧类之外的眼底类别。学生模型的特征提取网络用于根据眼底图像提取图像特征,学生模型的分类器用于根据图像特征识别该图像特征对应眼底图像所属的眼底类别。学生模型中特征提取网络的可训练参数用教师模型中特征提取网络的可训练参数初始化。在步骤b中增量训练学生模型时,将教师模型的参数进行冻结,不会对教师模型参数进行更新。学生模型的分类器能识别的新类也可以理解为:对于步骤b-1而言,步骤b中新增的类别。根据本发明的一个实施例,旧类包括以下类别中任意两种,剩余类别为新类:无糖尿病视网膜病变、轻度糖尿病视网膜病变、中度糖尿病视网膜病变、重度糖尿病视网膜病变、增发性糖尿病视网膜病变。其中,新类和旧类中的类别数量可以根据需求设置为其他数量,此处不再赘述。
[0055]
为了更好地讲解对本发明实施例进行理解,以下结合附图主要从伪样本的生成、增量训练过程、实验验证几个方面进行讲解。
[0056]
一、伪样本的生成
[0057]
正如前面提到的,没有充分利用增量过程中模型学到的表征特征知识,来实现模型中关键特征的迁移,会导致灾难性遗忘的问题。换而言之,当用旧模型(例如教师模型)训练新任务时,就会发生知识遗忘,继而导致了旧类决策边界的明显转移以及导致顺序学习的特征提取网络很容易偏向新类。为缓解增量学习中的存在灾难性遗忘的问题,与现有的基于记忆的方法不同,本发明实施专注于潜在特征空间,并生成原始数据对应的伪样本参与增量训练来缓解灾难性遗忘。在本发明实施例中也是称之为伪样本增强(psa)。psa只是从旧特征提取器(相对于学生模型来说,教师模型中的特征提取网络作为旧特征提取器)的输出中增强特征向量,并用这些增强后的特征向量(也称伪样本)来表示以前任务数据的分布。新数据(其中包括伪样本和增量步骤中新增类别对应眼底图像数据)在潜在空间的联合训练有助于减轻灾难性遗忘。根据本发明的一个实施例,按照如下步骤获得旧类的伪样本:t1、基于预训练中用到的旧类对应的多个图像特征计算该旧类对应的类均值向量;t2、从高斯分布中随机采样得到至少一个噪声向量,所述噪声向量与类均值向量的维度相同;t3、基于预先定义的增强尺度、从高斯分布中随机采样得到噪声向量以及旧类对应的类均值向量按照预设的增强规则计算得到该旧类的伪样本。
[0058]
需要说明的是,本发明实施例提供的分类模型的训练方法可以对学生模型进行多次增量训练,其中,所述预训练的分类模型为当前次增量训练的上一次增量训练后的学生模型,此处对学生模型的增量训练可以理解为对分类模型一种预训练。根据本发明的一个实施例,通过以下规则计算旧类对应的类均值向量:
[0059][0060]
其中,表示旧类k
old
对应的类均值向量,表示旧类k
old
的样本数
量,表示旧类k
old
中第n个样本的样本特征,θ
b-1
表示在第b次增量训练中教师模型的特征提取网络的参数。需要说明的是旧类对应的多个图像特征可以是旧类对应的全部图像特征,也可以是代表能代表旧类的图像特征,即旧类k
old
的样本数量小于等于旧类对应的全部样本的总数量。此外,公式(1)中,也可以增加第一校正系数等数学变形来增加类均值向量代表旧类的能力,例如类均值向量代表旧类的能力,例如其中为第一校正系数。
[0061]
当学习一个新任务(第b-1次增量训练而言,第b次增量训练称为新任务)时,每个旧类的原型(其中,原型可以理解为旧类的类均值特征向量)被增强。根据本发明的一个实施例,所述预设的增强规则为:
[0062][0063]
其中,表示第b次增量训练中使用到的旧类k
old
对应伪样本,rb表示第b次增量训练中预先定义的增强尺度,e表示采样至高斯分布的一个随机噪声向量,e和的维度相同。
[0064]
需要说明的是,psa假设从旧特征提取器获得的特征嵌入是从多元高斯分布中提取的。在步骤b中,只有数据db可以进行训练,为了减轻学习新任务时特征空间的变形,本发明实施例,优选通过每个类的嵌入输出均值来计算和记忆类嵌入,对于以前任务(对于第b次增量训练而言,第b-1次增量训练称为以前任务)中的每个类,在新任务(第b-1次增量训练而言,第b次增量训练称为新任务)上训练模型之前,生成服从高斯分布的每个旧类的随机特征嵌入以进行伪增强。具体来说,给定一个来自标准正态分布机特征嵌入以进行伪增强。具体来说,给定一个来自标准正态分布的随机噪声e,它与每个类嵌入具有相同的维数,本发明实施例应用了线性变换(公式2)来增强每个旧类的特征嵌入,该旧类的类别标签作为对应伪样本的标签。
[0065]
其中,r(rb表示第b次增量训练中预先定义的增强尺度)可以控制旧类的类均值特征向量的增强尺度。具体来说,增强尺度r可以预先定义,也可以计算为类表示的平均方差。根据本发明的一个实施例,通过如下规则设置预先定义增强尺度:
[0066][0067]
其中,k
old
和k
mew
分别表示第b次增量训练中旧类和新类的数量,r
b-1
表示第b-1次增量训练中预先定义的增强尺度,∑
b,k
表示第b次增量训练中第k个新类的协方差矩阵,tr(∑
b,k
)表示∑
b,k
的秩,m表示深度特征空间的维数,当b=1时,k1表示第1次预训练中类别数量,为类表示的平均方差。本发明实施例中,psa组件旨在利用嵌入空间中每个类的统计信息生成更多的伪样本,考虑协方差矩阵∑
b,k
表示k类在步骤b的静态表示,为特征嵌入中任意一对元素之间的方差;tr运算是计算协方差矩阵的迹。在类别增量学习cil实验的过程中,r在不同的步骤中略有变化。因此,第一次增量训练或者预训练中,类表示的平均方差为:此外,公式(3)中,也可以增加第二校正系数等数学变形来更好地控制旧类的类均值特征向量的增强尺度,例如,
表示第二校正系数。
[0068]
二、增量训练过程
[0069]
当开始对新任务进行训练时,利用生成的旧类的伪样本和属于新类的眼底图像对学生模型进行增量训练。训练时基于预设的总损失函数确定的总损失更新学生模型的参数,得到经增量训练的学生模型,所述总损失根据以下损失加权求和确定:旧类分类损失、新类分类损失、第一特征蒸馏损失和第二特征蒸馏损失。其中,对新任务进行训练,也就是开始新的增量训练,采用从新类对应的眼底图像和(从旧类中)增强的伪样本进行联合训练,通过学生模型中的特征提取网络进行特征提取,然后将提取到特征输入学生模型的分类器(是一个用于跨所有类(包括旧类和新类)的统一分类器),学生模型的分类器根据所述特征以获得相应的分类概率(其中,包括伪样本在旧类上的分类概率以及新类对应的眼底图像在新类上的分类概率)以及旧类的伪样本在学生模型的全连接层上对应旧类上的输出logit。此外,将伪样本输入教师模型,获得旧类的伪样本在教师模型中的特征提取后的图像特征以及旧类的伪样本在学生模型的全连接层上的输出logit。所述第一特征蒸馏损失根据旧类的伪样本在以下两者上的输出之间的差异进行确定:所述教师模型的全连接层的输出、所述学生模型的全连接层在旧类上的输出。为了探究增量过程中深度学习模型提取到的特征中哪些是关键特征,教师模型中特征提取网络的提取到的所有教师特征和经增量训练的学生模型中特征提取网络的提取到的所有学生特征之间的总的差异进行确定第二特征蒸馏损失来参与指导分类模型(学生模型)进行更新。
[0070]
根据本发明的一个实施例,按照如下规则计算总损失:
[0071][0072]
其中,表示总损失,表示新类分类损失,表示旧类分类损失,λ表示第一超参数,优选λ=0.03,表示第一特征蒸馏损失,β第二超参数,优选β=0.6,表示第二特征蒸馏损失,γ表示第三超参数,优选γ=0.5。需要说明是,公式(4)中b表示第b次增量训练,以第b次增量训练的总损失进行示意性的表示总损失的计算过程。此外,公式(4)中,也可以将超参数的进行数学变形后的值来更好的指导分类模型进行训练。例如,其中,λ
′
、β
′
和γ
′
表示将超参数进行数学变形后的值。
[0073]
为了更好的说明本发明实施例,以下继续结合附图1以第b次增量训练为例,分别讲解新类分类损失、旧类分类损失、第一特征蒸馏损失、第二特征蒸馏损失的具体获得过程。
[0074]
如图1所示,在步骤b-1中,也就是第b-1次增量训练时,训练教师模型用到的数据包括从原始数据中获取对应的眼底图像样本及眼底图像样本对应的眼底类别,训练教师模型用到眼底类别称为旧类,例如图1中示出的旧类1和旧类2;伪样本1表示对旧类1对应的类均值特征向量被增强后的数据;伪样本2表示对旧类2对应的类均值特征向量被增强后的数据。在步骤b中,也就是第b次增量训练时,训练学生模型用到的数据包括从原始数据中获取对应的眼底图像样本及眼底图像样本对应的眼底类别,此时训练学生模型用到眼底类别称为新类,例如图1中示出的新类1和新类2,优选新类1和新类2与旧类1和旧类2均不相同。在本发明实施例中,第b次增量训练时,训练学生模型用到的数据还用到了旧类的伪样本,例
如伪样本1和伪样本2来缓解增量过程中灾难性遗忘的问题。其中,根据教师模型中特征提取网络的提取到的所有教师特征和经增量训练的学生模型中特征提取网络的提取到的所有学生特征之间的总的差异确定第二特征蒸馏;根据旧类的伪样本在以下两者上的输出之间的差异确定第一特征蒸馏损失:所述教师模型的全连接层的输出、所述学生模型的全连接层在旧类上的输出;利用生成的旧类的伪样本和属于新类的眼底图像经学生模型处理后的分类概率,生成分类损失,分类损失包括旧类分类损失和新类分类损失。
[0075]
根据本发明的一个实施例,旧类分类损失通过如下规则计算:
[0076][0077]
其中,fi表示旧类集ci(也就是第i个旧类的伪样本组成的集合)的被学生模型的特征提取网络提提取后的特征,表示交叉熵损失,φb表示学生模型的分类器的参数,yi表示第i个旧类的伪样本对应的标签。直观来看,在特征空间中,旧类的类均值特征向量被软方差增强,其代表了对生成的特征的真实性的置信度。对于当前任务(第b次增量训练)的训练,将伪样本被提取后的特征馈送到分类器中,以保持到目前为止已经学习的所有类之间的区分和平衡。
[0078]
根据本发明的一个实施例,新类分类损失通过如下规则计算:
[0079][0080]
其中,f(xb;θb)表示新类集cb(也就是新类的眼底图像样本组成的集合)的被学生模型的特征提取网络提提取后的特征,表示交叉熵损失,φb表示学生模型的分类器的参数,yb表示新类的眼底图像样本对应的标签。对于当前任务(第b次增量训练)的训练,将新类对应眼底图像样本被提取后的特征馈送到分类器中,以获得对所有新类之间的区分和平衡。
[0081]
根据本发明的一个实施例,通过以下规则计算第一特征蒸馏损失:
[0082][0083]
其中,表示在增量训练中使用到的伪样本数据,表示增量训练中使用到的伪样本经学生模型中的特征提取网络提取后的样本特征,表示增量训练中学生模型的分类器中全连接层的参数,gb表示经学生模型的分类器处理中对应的全连接层在旧类上的输出,表示增量训练中使用到的伪样本数据经教师模型中的特征提取网络提取后的样本特征,表示增量训练中教师模型的分类器中全连接层的参数,增量训练中教师模型的分类器中全连接层的参数,表示经教师模型处理时全连接层的输出,经教师模型处理时全连接层的输出,表示表示与之间对应元素之间的差异值之和,||.||2表示求l2范数。需要说明的是,当学习新的类(新类)时,学生模型中的特征提取网络会逐步更新。为了缓解保存的伪样本与特征提取网络之间的不匹配,还采用基于逻辑的知识蒸馏(第一特征蒸馏损失)对学生模型中特征提取网络进行正则化;采用全连接层输出logit进行第一蒸
馏损失的计算,可以充分利用logit构建损失指导分类模型的学习,使得学习后的分类模型具有更多的信息。
[0084]
本发明实施例中的第二特征蒸馏损失是通过基于注意力的特征蒸馏进行确定的。根据本发明的一个实施例,通过如下步骤计算所述总的差异:基于注意力机制,确定每个教师特征对各个学生特征的注意力值,以及利用该教师特征对各个学生特征的注意力值构成该教师特征对所有学生特征的注意力向量并进行归一化得到归一化后的注意力向量;基于每个教师特征和每个学生特征的确定空间距离;计算空间距离和归一化后的注意力向量中对应元素的乘积,并将得到的所有乘积进行求和得到总的差异。
[0085]
根据本发明的一个实施例,基于注意力机制,每个教师特征对各个学生特征的注意力值通过如下步骤计算:将学生特征进行数据转换,得到该学生特征在注意力机制中的一个key;将教师特征进行数据转换,得到该教师特征在注意力机制中的一个query;计算每个query对各个key的注意力值。
[0086]
以下以公式的形式讲解第二特征蒸馏损失的计算过程。
[0087]
对于当前步骤b(第b次增量训练)的训练,本发明实施例,使用当前模型作为学生模型,使用步骤b-1模型作为教师模型。设为来自教师模型的候选特征集合(来自教师模型的候选特征集合是由所有伪样本经教师模型中特征提取网络提取得到的图像特征组成的集合),的图像特征组成的集合),是来自学生的候选特征集合(来自学生的候选特征集合是所有伪样本经学生模型中特征提取网络提取得到的图像特征以及属于新类的眼底图像经学生模型中特征提取网络提取得到的图像特征共同组成的集合),其中t和s分别表示来自教师模型和学生模型的提取到的图像特征的数量。每个候选的特征(候选的特征也称图像特征)映射大小和通道尺寸为征也称图像特征)映射大小和通道尺寸为其中h、w和d分别表示高度、宽度和通道尺寸。当给定两组候选项时,afd(注意力元网络)旨在识别所有可能组合(t
×
s对)的相似性,并将教师模型候选的知识传递给具有识别相似性的学生模型。为识别和间的相似度,afd使用了注意力机制中的query-key概念。如图1所示,每个教师特征生成一个query(如图1所示,多个query表示为queries),q
t
,每个学生特征作为一个key(如图1所示,多个key表示为keys),ks。根据本发明的一个实施例,通过如下规则进行数据转换:
[0088][0089][0090]
其中,q
t
表示在注意力机制中的query,表示第t哥个教师特征,p
hw
(
·
)表示全局平均池化,表示的线性变换参数,表示的线性变换参数空间矩阵,fq(
·
)表示第一激活函数,ks表示在注意力机制中的key,表示第s个学生特征,表示的线性变换参数,表示的线性变换参数空间矩阵,表示的线性变换参数的空间矩阵,fk(
·
)表示第二激活函数d表示线性变换参数空间矩阵的维度。需要注意的是,这些特征的过渡权值(例如,w
tq
、w
sk
)是不同的,因为它们通过不同的级别具有不同的属性,即低级视觉
特征可以表示一条线,高级视觉特征可以表示一个对象。因此,本发明实施例对每个特征(或者)使用不同的变换权重。
[0091]
通过使用queries和keys,表示教师和学生表征关系的注意力值可以使用“softmax”函数进行归一化计算。根据本发明的一个实施例,通过以下规则计算每个query对所有key的注意力向量并进行归一化:
[0092][0093]
其中,soffmax(.)表示归一化函数,表示q
t
的转置,表示双线性权值,k
t,1
表示对对应的key值,表示第t个教师特征的位置编码,表示第s个学生特征的位置编码,k
t,s
表示对对应的key值,表示和的乘积,表示的转置。在注意力机制中,由于查询和关键字是从不同维度特征中识别的,因此双线性权值用于概括来自不同来源级别的关注值(换而言之,由于q和k维度不一样,所以不能直接相乘,双线性权值的两个维度分别与q和k相同,这样q和k就能相乘了)。位置编码用于在不同实例上共享公共信息。注意力向量α
t
捕获了第t个教师特征和整个学生特征之间的关系。通过利用α
t
,教师特征能够有选择地将其知识转移到学生特征上。
[0094]
根据本发明的一个实施例,通过以下方法计算第二特征蒸馏损失:
[0095][0096]
其中,α
t,s
表示第t个教师特征对第s个学生特征的归一化后的注意力值,表示空间距离,||.||2表示求l2范数,表示通道平均池化层与l2归一化的组合函数v/||v||2,v表示对进行平均池化得到的向量,表示对使用上采样或下采样得到的特征,其中,使用上采样或下采样是用来匹配特征映射大小与教师特征的大小。
[0097]
总的来说,通过旧类分类损失、新类分类损失、第一特征蒸馏损失和第二特征蒸馏损失确定的总损失,可以指导分类模型进行训练,一方面使用伪样本增强方法构造伪样本集来保存旧类的关键特征,另一方面使用基于注意力机制的特征蒸馏强化增量过程中学生模型从教师模型处获得的关键特征知识(需要说明的是,关键特征知识是表征特征知识中更关键的那部分),使得分类模型在增量学习过程中的遗忘效果更小,精确度方面有明显的提升。换而言之,通过缓解增量过程中模型的遗忘效果,就可以提升模型的精确度。本发明实施例训练得到分类模型(也即是经增量训练的学生模型)是一种高精度强鲁棒,能够在缓解灾难性遗忘的情况下学习新知识的分类模型,并同时对新知识有较好的分类效果。此外,增量过程中使深度学习模型提取具有关键属性的特征、利用教师和学生模型间的相似特征控制特征对之间的蒸馏强度是有重要意义和应用价值。本发明实施例还提供一种眼底图像分类方法,所述方法包括:s1、获取待分类的眼底图像;s2、采用本发明实施例提供的一种分类模型的训练方法得到的经增量训练的学生模型对步骤s1中获得待分类的眼底图像进行图像分类,其中,学生模型的特征提取网络用于根据待分类的眼底图像提取其图像特征,学生模型的分类器用于根据该图像特征识别所述待分类的眼底图像所属的眼底类别。图像分
类结果包括:无糖尿病视网膜病变、轻度糖尿病视网膜病变、中度糖尿病视网膜病变、重度糖尿病视网膜病变、增发性糖尿病视网膜病变。本发明实施例提供一种眼底图像分类方法可以识别出待分类的眼底图像所属的具化类别和泛化类别,并具有很好的精确度。
[0098]
三、实验验证
[0099]
为了更好地说明本发明实施例的技术效果。以下采用对比实验进行验证。其中,实验数据集采用在爱尔医院提供的数据集和eyepacs数据集;分类模型采用cnn+分类器(全连层+softmax层);对比的训练方法包括lucir、lwf、wa、mas、der、muc、cwd、our(代表本发明提供的方法)、coda、upperbound;评价分类模型的指标为accuracy(准确率)。其中,lucir、lwf、wa、mas、der、muc、cwd、coda为论文方法中对应的对比的训练方法的简写。论文方法的具体全称与简写的对应关系为:
[0100]
lwf:learning without forgetting
[0101]
mas:memory aware synapses:learning what(not)to forget
[0102]
muc:more classifiers,less forgetting:a generic multi-classifier paradigm for incremental learning
[0103]
coda:coda-prompt:continual decomposed attention-based prompting for rehearsal-free continual learning
[0104]
lucir:learning a unified classifier incrementally via rebalancing
[0105]
wa:maintaining discrimination and fairness in class incremental learning
[0106]
der:der:dynamically expandable representation for class incremental learning.
[0107]
cwd:mimicking the oracle:an initial phase decorrelation approach for class incremental learning。
[0108]
此外,upperbound表示增量学习实验的上界,使用数据集全部数据按训练集与测试集划分后进行分类的结果。
[0109]
需要说明的是,本实验中,将爱尔医院提供的数据集和eyepacs数据集分别按照以下实验设置方式进行训练:
[0110]
将数据集中任意2个类(例如无dr、轻度dr)并从这2个类开始预训练分类模型,然后将剩下的3个类(例如,中度dr、重度dr、增发性dr)被分为三次增量训练。
[0111]
爱尔医院提供的数据集经过实验设置方式的设置并采用上述对比的训练方法进行训练后的评价结果如图2所示,accuracy代表准确率,number of classes代表类别数量,2_3steps表示按照实验设置方式的实验结果。
[0112]
eyepacs数据集经过实验设置方式的设置并采用上述对比的训练方法进行训练后的评价结果如图3所示,accuracy代表准确率,number of classes代表类别数量,2_3steps表示按照实验设置方式的实验结果。
[0113]
由图2和图3可知,本发明提供的方法(用our代表)明显优于非基于样本的方法,并且在分类精度曲线趋势和平均增量精度方面都大大优于大多数基于样本的方法,这证实了本发明提供的方法可以有效地解决ci l中的灾难性遗忘问题,而无需存储旧的训练样本,实现了更好的稳定性-可塑性平衡。特别是,在爱尔医院提供的数据集和eyepacs数据集上
的实验结果发现,增量学习结束时,总共5个类的整体性能提高了4.27%,平均精度提高了3.33%,由此,本发明提供的方法也始终优于非基于样本的sota方法coda。
[0114]
本发明针对眼底疾病图像增量获得过程中伴随的由于类别差异导致的灾难性遗忘问题,提出了一种分类模型的训练方法以及一种眼底图像分类方法。其属于一种基于注意力特征蒸馏的类别增量细粒度眼底图像识别方法,该方法首先通过构造基于注意力的神经网络模型,来学习教师与学生特征之间的相对相似性,并应用识别到的相似性来控制所有特征对之间的蒸馏强度,并通过为每个旧类在表征空间中保存代表类的伪样本集和,来维持之前类别的决策边界,缓解了增量过程中数据受到的内存或隐私问题的限制。基于注意力特征蒸馏的类别增量细粒度眼底图像识别方法在增量过程中精确度与鲁棒性方面对比现有最先进方法有明显的提升。
[0115]
需要说明的是,虽然上文按照特定顺序描述了各个步骤,但是并不意味着必须按照上述特定顺序来执行各个步骤,实际上,这些步骤中的一些可以并发执行,甚至改变顺序,只要能够实现所需要的功能即可。
[0116]
本发明可以是系统、方法和/或计算机程序产品。计算机程序产品可以包括计算机可读存储介质,其上载有用于使处理器实现本发明的各个方面的计算机可读程序指令。
[0117]
计算机可读存储介质可以是保持和存储由指令执行设备使用的指令的有形设备。计算机可读存储介质例如可以包括但不限于电存储设备、磁存储设备、光存储设备、电磁存储设备、半导体存储设备或者上述的任意合适的组合。计算机可读存储介质的更具体的例子(非穷举的列表)包括:便携式计算机盘、硬盘、随机存取存储器(ram)、只读存储器(rom)、可擦式可编程只读存储器(eprom或闪存)、静态随机存取存储器(sram)、便携式压缩盘只读存储器(cd-rom)、数字多功能盘(dvd)、记忆棒、软盘、机械编码设备、例如其上存储有指令的打孔卡或凹槽内凸起结构、以及上述的任意合适的组合。
[0118]
以上已经描述了本发明的各实施例,上述说明是示例性的,并非穷尽性的,并且也不限于所披露的各实施例。在不偏离所说明的各实施例的范围和精神的情况下,对于本技术领域的普通技术人员来说许多修改和变更都是显而易见的。本文中所用术语的选择,旨在最好地解释各实施例的原理、实际应用或对市场中的技术改进,或者使本技术领域的其它普通技术人员能理解本文披露的各实施例。
技术特征:
1.一种分类模型的训练方法,所述分类模型用于眼底图像分类,所述方法包括:获取预训练的分类模型作为教师模型,其包括用于对输入眼底图像提取图像特征的特征提取网络以及用于根据图像特征识别该图像特征对应眼底图像所属的眼底类别的分类器,所述分类器包括全连接层和softmax层,所述教师模型能识别的眼底类别归为旧类;获取学生模型,其特征提取网络的可训练参数用教师模型中特征提取网络的可训练参数初始化,并且其分类器设置为能对旧类和新类对应的眼底类别进行识别,所述新类是所述旧类之外的眼底类别;针对所述预训练中用到的每个旧类,获取该旧类对应的伪样本,所述伪样本是利用教师模型的特征提取网络对属于该旧类的多个眼底图像提取的图像特征生成的;利用生成的旧类的伪样本和属于新类的眼底图像对所述学生模型进行增量训练,训练时基于预设的总损失函数确定的总损失更新学生模型的参数,得到经增量训练的学生模型,所述总损失根据以下损失加权求和确定:旧类分类损失、新类分类损失、第一特征蒸馏损失和第二特征蒸馏损失;其中,所述第一特征蒸馏损失根据旧类的伪样本在以下两者上的输出之间的差异进行确定:所述教师模型的全连接层的输出、所述学生模型的全连接层在旧类上的输出;所述第二特征蒸馏损失根据教师模型中特征提取网络的提取到的所有教师特征和经增量训练的学生模型中特征提取网络的提取到的所有学生特征之间的总的差异进行确定,其中,每个教师特征为一个伪样本经教师模型中特征提取网络提取得到的图像特征,每个学生特征为伪样本和属于新类的眼底图像分别经学生模型中特征提取网络提取得到的图像特征。2.根据权利要求1所述的方法,其特征在于,通过如下步骤计算所述总的差异:基于注意力机制,确定每个教师特征对各个学生特征的注意力值,以及利用该教师特征对各个学生特征的注意力值构成该教师特征对所有学生特征的注意力向量并进行归一化得到归一化后的注意力向量;基于每个教师特征和每个学生特征的确定空间距离;计算空间距离和归一化后的注意力向量中对应元素的乘积,并将得到的所有乘积进行求和得到总的差异。3.根据权利要求2所述的方法,其特征在于,基于注意力机制,每个教师特征对各个学生特征的注意力值通过如下步骤计算:将学生特征进行数据转换,得到该学生特征在注意力机制中的一个key;将教师特征进行数据转换,得到该教师特征在注意力机制中的一个query;计算每个query对各个key的注意力值。4.根据权利要求3所述的方法,其特征在于,通过如下规则进行数据转换:4.根据权利要求3所述的方法,其特征在于,通过如下规则进行数据转换:其中,q
t
表示在注意力机制中的query,表示第t个教师特征,p
hw
(
·
)表示全局平
均池化,表示的线性变换参数,表示的线性变换参数空间矩阵,f
q
(
·
)表示第一激活函数,k
s
表示在注意力机制中的key,表示第s个学生特征,表示的线性变换参数,表示的线性变换参数空间矩阵,表示的线性变换参数的空间矩阵,f
k
(
·
)表示第二激活函数,d表示线性变换参数空间矩阵的维度。5.根据权利要求4所述的方法,其特征在于,通过以下规则计算每个query对所有key的注意力向量并进行归一化:其中,softmax(.)表示归一化函数,表示q
t
的转置,表示双线性权值,k
t,1
表示对对应的key值,表示第t个教师特征的位置编码,表示第s个学生特征的位置编码,k
t,s
表示对对应的key值,表示和的乘积,表示的转置。6.根据权利要求5所述的方法,其特征在于,通过以下方法计算第二特征蒸馏损失:其中,α
t,s
表示第t个教师特征对第s个学生特征的归一化后的注意力值,表示空间距离,||.||2表示求l2范数,表示通道平均池化层与l2归一化的组合函数v/||v||2,v表示对进行平均池化得到的向量,表示对使用上采样或下采样得到的特征。7.根据权利要求1所述的方法,其特征在于,按照如下步骤获得旧类的伪样本:t1、基于预训练中用到的旧类对应的多个图像特征计算该旧类对应的类均值向量;t2、从高斯分布中随机采样得到至少一个噪声向量,所述噪声向量与类均值向量的维度相同;t3、基于预先定义的增强尺度、从高斯分布中随机采样得到噪声向量以及旧类对应的类均值向量按照预设的增强规则计算得到该旧类的伪样本。8.根据权利要求7所述的方法,其特征在于,所述方法包括对学生模型进行多次增量训练,其中,教师模型为当前次增量训练的上一次增量训练后的学生模型,通过以下规则计算旧类对应的类均值向量:其中,表示旧类k
old
对应的类均值向量,表示旧类k
old
的样本数量,f(x
b-1,n
;θ
b-1
)表示旧类k
old
中第m个样本的样本特征,θ
b-1
表示在第b次增量训练中教师模型的特征提取网络的参数。
9.根据权利要求8所述的方法,其特征在于,所述预设的增强规则为:其中,表示第b次增量训练中使用到的旧类k
old
对应伪样本,r
b
表示第b次增量训练中预先定义的增强尺度,e表示采样至高斯分布的一个随机噪声向量,e和的维度相同。10.根据权利要求9所述的方法,其特征在于,通过如下规则设置预先定义增强尺度:其中,k
old
和k
mew
分别表示第b次增量训练中旧类和新类的数量,r
b-1
表示第b-1次增量训练中预先定义的增强尺度,σ
b,k
表示第b次增量训练中第k个新类的协方差矩阵,tr(σ
b,k
)表示σ
b,k
的秩,m表示深度特征空间的维数,当b=1时,k1表示第1次预训练中类别数量。11.根据权利要求1所述的方法,其特征在于,按照如下规则计算总损失:其中,表示总损失,表示新类分类损失,表示旧类分类损失,λ表示第一超参数,表示第一特征蒸馏损失,β第二超参数,表示第二特征蒸馏损失,γ表示第三超参数。12.根据权利要求10所述的方法,其特征在于,通过以下规则计算第一特征蒸馏损失:其中,表示在增量训练中使用到的伪样本数据,表示增量训练中使用到的伪样本经学生模型中的特征提取网络提取后的样本特征,表示增量训练中学生模型的分类器中全连接层的参数,g
b
表示经学生模型的分类器处理中对应的全连接层在旧类上的输出,表示增量训练中使用到的伪样本数据经教师模型中的特征提取网络提取后的样本特征,表示增量训练中教师模型的分类器中全连接层的参数,量训练中教师模型的分类器中全连接层的参数,表示经教师模型处理时全连接层的输出,经教师模型处理时全连接层的输出,表示g
b
与之间对应元素之间的差异值之和,||.||2表示求l2范数。13.根据权利要求1-12任一项所述的方法,其特征在于,旧类包括以下类别中任意两
种,剩余类别为新类:无糖尿病视网膜病变、轻度糖尿病视网膜病变、中度糖尿病视网膜病变、重度糖尿病视网膜病变、增发性糖尿病视网膜病变。14.一种眼底图像分类方法,其特征在于,所述方法包括:s1、获取待分类的眼底图像;s2、采用如权利要求1-13任一所述方法得到的经增量训练的学生模型对步骤s1中获得待分类的眼底图像进行图像分类,其中,学生模型的特征提取网络用于根据待分类的眼底图像提取其图像特征,学生模型的分类器用于根据该图像特征识别所述待分类的眼底图像所属的眼底类别。15.根据权利要求14所述的方法,其特征在于,图像分类的结果包括:无糖尿病视网膜病变、轻度糖尿病视网膜病变、中度糖尿病视网膜病变、重度糖尿病视网膜病变、增发性糖尿病视网膜病变。16.一种计算机可读存储介质,其特征在于,其上存储有计算机程序,所述计算机程序可被处理器执行以实现权利要求1至15任一所述方法的步骤。17.一种电子设备,其特征在于,包括:一个或多个处理器;存储装置,用于存储一个或多个程序,当所述一个或多个程序被所述一个或多个处理器执行时,使得所述电子设备实现如权利要求1至15中任一项所述方法的步骤。
技术总结
本发明提供一种分类模型的训练方法和眼底图像分类方法,属于增量学习领域。一种分类模型的训练方法,所述分类模型用于眼底图像分类,所述方法包括:获取预训练的分类模型作为教师模型,其包括特征提取网络以及分类器,所述教师模型能识别的眼底类别归为旧类;获取学生模型,其特征提取网络用教师模型初始化,并且其分类器设置为能对旧类和新类对应的眼底类别进行识别,所述新类是所述旧类之外的眼底类别;针对用到的每个旧类,获取该旧类对应的伪样本;利用生成的旧类的伪样本和属于新类的眼底图像对所述学生模型进行增量训练,训练时基于预设的总损失函数确定的总损失更新学生模型的参数,得到经增量训练的学生模型。本发明可以缓解灾难性遗忘。明可以缓解灾难性遗忘。明可以缓解灾难性遗忘。
技术研发人员:谷洋 郭帅 文世杰 马媛 杨昭华 翁伟宁 陈益强
受保护的技术使用者:中国科学院计算技术研究所
技术研发日:2023.07.19
技术公布日:2023/9/19
版权声明
本文仅代表作者观点,不代表航家之家立场。
本文系作者授权航家号发表,未经原创作者书面授权,任何单位或个人不得引用、复制、转载、摘编、链接或以其他任何方式复制发表。任何单位或个人在获得书面授权使用航空之家内容时,须注明作者及来源 “航空之家”。如非法使用航空之家的部分或全部内容的,航空之家将依法追究其法律责任。(航空之家官方QQ:2926969996)
航空之家 https://www.aerohome.com.cn/
飞机超市 https://mall.aerohome.com.cn/
航空资讯 https://news.aerohome.com.cn/