多类别目标检测方法及其模型训练方法、装置与流程
未命名
08-13
阅读:105
评论:0

1.本技术涉及目标检测技术领域,尤其涉及一种多类别目标检测方法及其模型训练方法、装置。
背景技术:
2.当项目中的目标检测模型需要能够同时检测和识别多种物体对象,例如,可同时检测出画面中的所有水果,并且可以识别出这些水果的所属种类,通常地,需要先对相应的样本数据集进行整理,然后进行目标检测模型的训练。其中,在收集数据集进行训练的过程中,一般希望能够收集完整的一个数据集,即数据集中每一张图片里包含所有待检测物体对象的标注。
3.然而,在实际项目中,比较常见的情况是:已有一个数据集a,里面带有物体a的标注;已有一个数据集b,里面带有物体b的标注,但是在数据集a中,即使图片中出现了物体b,也没有对物体b进行标注,数据集b中也是同理。要是想利用这样两个数据集a和b来直接合并进行多个物体的目标检测模型训练是不可行的。
技术实现要素:
4.有鉴于此,本技术实施例提供一种多类别目标检测方法及其模型训练方法、装置。
5.第一方面,本技术实施例提供一种将训练图片分别输入至预先训练的不同老师网络中进行目标检测,输出每个老师网络中不同检测阶段的参考特征图;其中,不同老师网络用于对不同类别的目标进行检测;所述训练图片中包含多个目标类别且带有部分目标类别标注;
6.基于不同检测阶段的所述参考特征图,得到所述训练图片在对应类别目标检测中的正负样本数据;
7.将所述训练图片输入至学生网络中进行多类别目标检测,输出在所述各个检测阶段的预测特征图;
8.基于所述正负样本数据、所述参考特征图和所述预测特征图,计算通过所述学生网络进行多类别目标检测时的损失值;
9.利用所述损失值进行反向传播以继续训练所述学生网络,直至满足预设训练停止条件,得到多类别目标检测模型。
10.在一些实施例中,所述输出每个老师网络中不同检测阶段的参考特征图,包括:按照预设尺寸输出每个老师网络在不同检测阶段得到的参考特征图,其中,每个所述参考特征图中的每个点均包含预设数量的锚框,所述锚框用于通过解析得到所述正负样本数据。
11.在一些实施例中,所述基于不同检测阶段的所述参考特征图,得到所述训练图片在对应类别目标检测中的正负样本数据,包括:
12.在对应类别目标检测过程中,按照预设分配规则将各个检测阶段中的所述参考特征图中相应位置的锚框分配为正样本还是负样本;
13.将区分出的正样本和负样本通过矩阵形式进行描述,以得到所述训练图片在对应类别目标检测中的正负样本数据。
14.在一些实施例中,利用每个检测阶段的所述正负样本数据,分别计算每个检测阶段的所述参考特征图与所述预测特征图之间的知识蒸馏损失值;
15.将所有检测阶段的所述知识蒸馏损失值进行加权计算,得到知识蒸馏总损失值;
16.计算所述预测特征图与基于所述训练图片包含的所述目标类别标注得到的真实特征图之间的目标检测损失值;
17.利用所述知识蒸馏总损失值与所述目标检测损失值,计算得到通过所述学生网络进行多类别目标检测时的损失值。
18.在一些实施例中,所述不同检测阶段包括三个阶段,分别为回归检测阶段、是否为目标判定阶段和目标类别判定阶段;
19.其中,每个老师网络在所述目标类别判定阶段中仅输出对其中一种目标类别判定的参考特征图;所述学生网络在所述目标类别判定阶段中输出对所有目标类别判定的相应预测特征图。
20.在一些实施例中,所述老师网络的数量不超过需要检测的目标类别的总数量;所述不同老师网络通过预先训练得到,包括:
21.将包含多种类别目标的每个样本图片,按照不同的部分目标标注划分得到带不同目标类别标注的若干个数据集;
22.利用不同的所述数据集分别对构建的若干个神经网络进行不同目标检测训练,以得到与所述数据集数量相等的用于检测不同类别目标的若干个老师网络。
23.第二方面,本技术实施例还提供一种多类别目标检测方法,包括:
24.将目标图像输入至所述的多类别目标检测模型训练方法得到的多类别目标检测模型中进行目标检测,得到所述目标图像中存在的所有类别目标的预测结果。
25.第三方面,本技术实施例还提供一种多类别目标检测模型训练装置,包括:
26.老师网络推理模块,用于将训练图片分别输入至预先训练的不同老师网络中进行目标检测,输出每个老师网络中不同检测阶段的参考特征图;其中,不同老师网络用于对不同类别的目标进行检测;所述训练图片中包含多个目标类别且带有部分目标类别标注;
27.正负样本获取模块,用于基于不同检测阶段的所述参考特征图,得到所述训练图片在对应类别目标检测中的正负样本数据;
28.学生网络训练模块,用于将所述训练图片输入至学生网络中进行多类别目标检测,输出在所述各个检测阶段的预测特征图;
29.损失计算模块,用于基于所述正负样本数据、所述参考特征图和所述预测特征图,计算通过所述学生网络进行多类别目标检测时的损失值;
30.所述学生网络训练模块,还用于利用所述损失值进行反向传播以继续训练所述学生网络,直至满足预设训练停止条件,得到多类别目标检测模型。
31.第四方面,本技术实施例还提供一种终端设备,所述终端设备包括处理器和存储器,所述存储器存储有计算机程序,所述处理器用于执行所述计算机程序以实施所述的多类别目标检测模型训练方法或多类别目标检测方法。
32.第五方面,本技术实施例还提供一种可读存储介质,其存储有计算机程序,所述计
算机程序在处理器上执行时,实施所述的多类别目标检测模型训练方法或多类别目标检测方法。
33.本技术的实施例具有如下有益效果:
34.本技术的多类别目标检测模型训练方法通过预先训练得到不同的老师网络,并利用每个老师网络对仅带有部分目标类别标注的训练图片进行对应类别目标检测,以得到各自在不同检测阶段的相应参考特征图;之后,利用得到的这些参考特征图来对需要训练的学生网络所输出的相应预测特征图计算损失值,最后利用损失值进行反向传播训练,从而得到一个多类别目标检测模型。该训练方法不仅不需要手动对已有数据集进行标注补充和调整,还通过知识蒸馏充分利用了所有的数据集信息来进行模型训练,达到了同时提升训练效率和效果的目的。
附图说明
35.为了更清楚地说明本技术实施例的技术方案,下面将对实施例中所需要使用的附图作简单地介绍,应当理解,以下附图仅示出了本技术的某些实施例,因此不应被看作是对范围的限定,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他相关的附图。
36.图1示出了一个包含二分类器的典型多目标检测网络的结构示意图;
37.图2示出了本技术实施例多类别目标检测模型训练方法的第一流程图;
38.图3示出了本技术实施例多类别目标检测模型训练方法的第二流程图;
39.图4示出了一种采用本技术的训练方法进行两类目标检测的示意图;
40.图5示出了本技术实施例多类别目标检测模型训练方法的第三流程图;
41.图6示出了本技术实施例多类别目标检测方法的流程图;
42.图7示出了本技术实施例多类别目标检测模型训练方法的结构示意图。
具体实施方式
43.下面将结合本技术实施例中附图,对本技术实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本技术一部分实施例,而不是全部的实施例。
44.通常在此处附图中描述和示出的本技术实施例的组件可以以各种不同的配置来布置和设计。因此,以下对在附图中提供的本技术的实施例的详细描述并非旨在限制要求保护的本技术的范围,而是仅仅表示本技术的选定实施例。基于本技术的实施例,本领域技术人员在没有做出创造性劳动的前提下所获得的所有其他实施例,都属于本技术保护的范围。
45.在下文中,可在本技术的各种实施例中使用的术语“包括”、“具有”及其同源词仅意在表示特定特征、数字、步骤、操作、元件、组件或前述项的组合,并且不应被理解为首先排除一个或更多个其它特征、数字、步骤、操作、元件、组件或前述项的组合的存在或增加一个或更多个特征、数字、步骤、操作、元件、组件或前述项的组合的可能性。此外,术语“第一”、“第二”、“第三”等仅用于区分描述,而不能理解为指示或暗示相对重要性。
46.除非另有限定,否则这里使用的所有术语(包括技术术语和科学术语)具有与本技术的各种实施例所属领域普通技术人员通常理解的含义相同的含义。所述术语(诸如在一
般使用的词典中限定的术语)将被解释为具有与在相关技术领域中的语境含义相同的含义并且将不被解释为具有理想化的含义或过于正式的含义,除非在本技术的各种实施例中被清楚地限定。
47.下面结合附图,对本技术的一些实施方式作详细说明。在不冲突的情况下,下述的实施例及实施例中的特征可以相互结合。
48.通常地,对于仅带有部分物体标注的两个数据集a和b,若想用于实现对两种目标的检测,最简单的方法是使用数据集a训练得到针对物体a的检测模型,再使用数据集b训练得到物体b的检测模型。然而在实际项目中因终端设备等资源有限,通常只能使用一个检测模型来同时对物体a和物体b进行检测,但是由于数据集a中可能有大量未标注的物体b,同样数据集b中有大量未标注的物体a,为了实现对多个物体的同时检测,目前可采用的做法主要有以下几种:
49.第一种是,手动在数据集a中增加对物体b的标注,且在数据集b中增加对物体a的标注,然后再进行一个检测模型的训练,可想而知,对这些数据集的补充标注,将需要耗费大量的标注成本。
50.第二种是,分别利用数据集a和数据集b训练得到两个目标检测网络,然后利用训练得到的目标检测网络在两个数据集中推理得到缺少的类别的伪标签,最后在补全的数据集上进行完整的训练。该方案会对伪标签的质量非常敏感,将依然需要耗费时间对伪标签进行调整。
51.第三种则是,按照如图1所示的网络结构方式进行目标检测网络模型设计,其中,bockbone(骨干网络)负责特征提取,heads(头部)中的reg head负责计算回归检测框,obj head负责判断是否为目标,cls head负责判断目标的类别,需要注意的是,cls head使用多个二分类分类器(而不使用常见的softmax分类器)。在训练过程中,当加载数据集a训练时,只训练cls head a,关闭cls head b;同理,加载数据集b训练时,只训练cls head b,关闭cls head a。该方案虽然基本是可用的,但是数据集a中出现的物体b的信息,完全没有得到有效的利用。
52.为了解决多类别目标检测训练任务中存在多个数据集,且每个数据集所包含的物体类别标注不完整时,而使用经典方案进行模型训练所存在的问题,本技术提出了一种新的使用多个仅带有部分类别标注的数据集基于知识蒸馏思想来进行多类别目标检测模型的训练方法,以提升模型训练的效率和效果。
53.下面将结合一些具体的实施例对该多类别目标检测模型训练方法进行详细说明。
54.示范性地,如图2所示,该多类别目标检测模型训练方法包括:
55.s110,将训练图片分别输入至预先训练的不同老师网络中进行目标检测,输出每个老师网络中的不同检测阶段的参考特征图;其中,不同老师网络用于对不同类别的目标进行检测。
56.基于知识蒸馏原理可知,其核心思路是先训练一个复杂网络模型,然后使用这个复杂网络的输出和数据的真实标签去训练一个更小的网络,其中,复杂网络模型称为老师(teacher)网络,而该更小的网络模型则称为学生(student)网络。
57.其中,该训练图片中包含多个目标类别且带有部分目标类别标注。本技术中,通过先利用带有部分目标类别标注的多个数据集分别预先训练得到多个老师网络,再利用这些
老师网络对每一张输入的训练图片进行对应的目标类别检测,从而得到各自在相应检测阶段的特征图。可以理解,由于已训练的各个老师网络输出的特征图后续将用作训练学生网络时的参考,故这里称为参考特征图。
58.在一种实施方式中,如图3所示,在步骤s110之前,该方法还包括:
59.s100,将包含多种类别目标的每个样本图片,按照不同的部分目标标注划分得到带不同目标类别标注的若干个数据集;进而,利用不同的数据集分别对构建的若干个神经网络进行不同目标检测训练,以得到与数据集数量相等的用于检测不同类别目标的若干个老师网络。
60.其中,用作老师网络的各个神经网络的网络层结构可以选取相同,只是由于对不同类别的目标进行检测,其内部的网络参数可能会有所差异。
61.值得注意的是,老师网络的数量不会超过需要检测的目标类别的总数量,可取决于现有数据集的数量。例如,可以是每一个老师网络用于检测一个目标类别,也可以是一个老师网络用于检测两个及以上的目标类别,由于已有数据集中的标注信息并不完整,故无法直接检测出所有目标类别。
62.例如,为了训练一个能够检测目标oa、ob和oc三类物体的模型,假设各样本图片中存在的目标类别共有3种,但每张中标注的只有其中一种或两种目标类别(即标注信息不完整),那么可以将具有同一类别目标标注的所有样本图片划分到一个数据集中,可以得到具有不同目标类别标注的数据集da、db和dc。进而,利用数据集da去训练神经网络sa进行目标oa的检测,利用数据集db去训练神经网络sb对目标ob的检测,以及利用数据集dc去训练神经网络sc对目标oc的检测,从而得到能够用于分别检测三个目标类别a、b和c的三个老师网络ta、tb、tc。
63.在另一种情况下,若每张样本图片中同时标注有两种目标类别,即一个数据集d1中对两种目标(如oa、ob)同时进行了标注,而另一个数据集d2对目标oc进行标注,那么,此时可使用数据集d1训练一个可以同时检测目标oa、ob的老师网络t1,使用数据集d2训练一个可以检测目标oc的老师网络t2。
64.可以理解,通过设置不超过需要检测的目标类别总数量相等的多个老师网络,可以输出分别用于检测对应目标类别时得到的相应特征图,由于这些老师网络已经经过训练,故输出的这些特征图满足预测准确性要求,因此将其作为训练学生网络时的参考,可以让学生网络学习到所有老师网络的目标识别检测能力,同时还具有更高的精准率和更少的资源占用等。
65.其中,在目标检测过程中,往往需要执行多个任务,如回归、分类等。对于本技术所述的不同检测阶段,在一种实施方式中,主要包括三个阶段,分别为:回归检测框计算阶段、是否为目标判定阶段和目标类别判定阶段。其中,回归检测框计算阶段主要用于计算逼近目标的回归检测框的位置信息;是否为目标判定阶段主要用于判断该点所在网格是属于目标(前景)还是属于背景;目标类别判定阶段主要用于判断该目标是哪一种目标类别。
66.可以理解,一个老师网络在目标类别判定阶段中只能输出部分目标类别判定的参考特征图。与之不同的是,学生网络在目标类别判定阶段中可以输出对所有目标类别判定的相应预测特征图,即存在几个目标类别,则会输出几种目标类别的预测特征图。
67.进一步地,在输出参考特征图时,可按照预设尺寸来输出每个老师网络在不同检
测阶段得到的参考特征图。值得注意的是,输出的每个参考特征图中的每个点(网格)均包含预设数量的锚框(anchor)。
68.如图4所示,以两类目标检测(目标a和目标b)为例,将一训练图片输入至老师网络a进行目标a检测时,可输出老师网络a在三个阶段得到的相应特征图,分别记为reg
a'
、obj
a'
、cls
a'
;同理,将该训练图片输入至老师网络b进行目标b检测时,可输出老师网络a在三个阶段得到的相应特征图,分别记为reg
b'
、obj
b'
、cls
b'
,其中,这些参考特征图的大小为s
×
s,且图中每一个点均包含有k个anchor。
69.可以理解,通过进一步解析这些参考特征图中的anchor,可以得到该训练图片进行对应目标类别检测时的正负样本分布,以便计算对应检测阶段的目标类别、前背景区分等的损失。
70.s120,基于不同检测阶段的参考特征图,得到该训练图片在对应类别目标检测过程中的正负样本数据。
71.示范性地,在对应类别目标检测过程中,可按照预设分配规则将各个检测阶段中的参考特征图中相应位置的锚框分配为正样本还是负样本;进而,将区分出的正样本和负样本通过矩阵形式进行描述,以得到训练图片在对应类别目标检测中的正负样本数据。
72.其中,针对预设分配规则,将这些参考特征图中的锚框分配给正负样本的方式有多种,这里不作限定。例如,在一种实施方式中,对于每个anchor,可通过计算与anchor最匹配的gt(ground truth)的最大iou,并分别设定正负样本的阈值,以实现正负样本的匹配。具体地,若该最大iou大于设定的正样本阈值,则认为该anchor box为正样本;若该最大iou小于设定的负样本阈值,则认为该anchor为负样本。其中,iou用于测量真实值与预测值之间的相关度,是两个区域重叠的部分除以两个区域的集合部分得出的结果。
73.例如,在另一种实施方式中,也可以通过计算各个anchor与对应的gt的iou,并求均值和标准差的和作为正样本筛选阈值,进而找出与gt的iou大于筛选阈值的anchor,将anchor中心在gt内部且与gt的iou大于筛选阈值的anchor作为正样本。
74.可选地,在又一种实施方式中,还可以通过概率分布方式来区分正负样本。例如,通过计算在相应检测阶段中的该参考特征图中相应位置的anchor的得分,以得到对应的概率分布;其中,该得分可以反映出检测框分类和定位的质量。然后,利用该概率分布区分这个anchor为正样本还是负样本。可以理解,通过对每个anchor进行打分,可以得到模型认为这个anchor是否对于检测目标物体非常有用的评价,以此来确定该anchor是属于正样本还是负样本。
75.由此,可将这些anchor进行正负样本分配,为方便后续的损失计算,可将区分出的正样本和负样本通过矩阵形式进行描述,例如,对于为正样本的anchor,可用“1”表示;对于为负样本的anchor,可用“0”表示,以得到训练图片在对应类别目标检测过程中的正负样本数据。
76.s130,将该训练图片输入至学生网络中进行多类别目标检测,输出在各个检测阶段的预测特征图。
77.如图4所示,与老师网络仅包含部分分类判定分支的结构不同,本技术中的学生网络中同时包含用于不同目标类别判定的所有分类判定分支。
78.示范性地,将上述的训练图片输入到学生网络中,可以得到对应检测阶段的特征
图。例如,以上述的三个阶段为例,可预测输出对应的特征图reg、obj、clsa和clsb。进而,根据这些预测特征图来计算训练损失。
79.s140,基于所述正负样本数据、所述参考特征图和所述预测特征图,计算所述学生网络进行多类别目标检测时的损失值。
80.在一种实施方式中,如图5所示,上述步骤s140包括如下子步骤:
81.s141,利用所述正负样本数据,分别计算每个检测阶段的参考特征图与所述预测特征图之间的知识蒸馏损失值。
82.s142,将所有检测阶段的知识蒸馏损失值进行加权计算,得到知识蒸馏总损失值。
83.例如,以上述的两个目标类别的检测和三个检测阶段为例,知识蒸馏损失值的计算表达式如下:
[0084][0085][0086][0087]
式中,l
kd_reg
、l
kd_obj
、l
kd_cls
分别表示回归检测框计算阶段、是否为目标判定阶段和目标类别判定阶段的知识蒸馏损失值;i
a'
、i
b'
表示两个老师网络各自推理得到的正负样本矩阵;
[0088]
l
oss
(reg
a'_ij-reg
ij
)表示对于目标a在回归检测框计算阶段的参考特征图reg
a'_ij
与预测特征图reg
ij
之间的损失值,l
oss
(reg
b'_ij-reg
ij
)表示对于目标b在回归检测框计算阶段的参考特征图reg
b'_ij
与预测特征图reg
ij
之间的损失值;其中,ij表示第i个点的第j个anchor的取值,0代表该点为负样本,1代表该点为正样本;
[0089]
同理,l
oss
(obj
a'_ij-obj
ij
)表示对于目标a在是否为目标判定阶段的参考特征图obj
a'_ij
与预测特征图obj
ij
之间的损失值,l
oss
(obj
b'_ij-obj
ij
)表示对于目标b在是否为目标判定阶段的参考特征图obj
b'_ij
与预测特征图obj
ij
之间的损失值;
[0090]
l
oss
(cls
a'_ij-cls
a_ij
)表示对于目标a在目标类别判定阶段的参考特征图cls
a'_ij
与预测特征图cls
a_ij
之间的损失值,l
oss
(cls
b'_ij-cls
b_ij
)表示对于目标b在目标类别判定阶段的参考特征图cls
b'_ij
与预测特征图cls
b_ij
之间的损失值。
[0091]
可以理解,知识蒸馏损失值的计算与需要检测的目标类别数量、检测阶段有关,其中,在计算某一个检测阶段的损失值l
oss
时,可以采用l1、l2、smooth l1、smooth l2等常用的损失函数计算,这里不作限定。
[0092]
进而,结合预设权重,该知识蒸馏总损失值l
kd
可按照如下公式计算:
[0093]
l
kd
=λ1l
kd_reg
+λ2l
kd_obj
+λ3l
kd_cls
;
[0094]
式中,λ1、λ2和λ3分别为三个阶段各自的预设权重。
[0095]
s143,计算所述预测特征图与基于该训练图片包含的目标类别标注得到的真实特
征图之间的目标检测损失值。
[0096]
除了考虑知识蒸馏损失外,同时,由于已知该训练图片属于哪个数据集,因此,可以利用带有的部分目标类别标注信息来计算预测该类目标的目标检测损失,其中可包括分类损失和回归损失。
[0097]
以上述的两类目标检测为例,若该训练图片属于数据集a,则基于对目标a的标注对应的真实特征图(ground truth),可以计算出目标a在回归检测框计算阶段的预测特征图与真实特征图之间的回归损失l
reg
,例如,可通过l1、l2、smooth l1、smooth l2等损失函数计算。同时,在目标类别判定阶段,可计算出目标a的类别损失l
cls
,例如,可采用交叉熵损失函数等来计算。最后,将分类损失和回归损失相加或加权,可得到目标检测损失值l。
[0098]
s144,利用所述知识蒸馏总损失值与所述目标检测损失值,计算得到通过学生网络进行多类别目标检测时的损失值。
[0099]
于是,将所有类型的损失值相加可得到总损失值。例如,可表示如下:
[0100]
l
total
=γl
kd
+l;
[0101]
式中,l
total
表示总损失值;λ表示预设的权重值,其可根据需求设定。
[0102]
s150,利用所述损失值进行反向传播以继续训练所述学生网络,直至满足预设训练停止条件,得到多类别目标检测模型。
[0103]
在训练过程中,将根据计算到的总损失值进行梯度反向传输,通过反向传播,例如,可采用梯度下降法来迭代该学生网络中的权值,并继续训练,直至满足预设条件为止,可得到训练好的学生网络,以作为多类别目标检测模型。关于神经网络的反向传输原理,由于不属于本技术的重点,具体可参见已公开的相关文献,这里不展开描述。
[0104]
其中,上述的预设训练停止条件用于判断何时该停止训练,例如,可以包括但不限于为,总损失值小于预设阈值(即总损失值足够小,如趋近于0或位于某范围内等),达到迭代次数阈值,又或是总损失值和迭代次数同时满足要求等,这里不作具体限定。
[0105]
本实施例的多类别目标检测模型训练方法利用知识蒸馏思想,通过利用带部分目标类别标注的数据集来预先训练得到对应的多个老师网络,进而利用各老师网络在不同检测阶段的相应参考特征图作为训练学生网络时的参考,并结合正负样本数据来计算训练损失,最后进行反向传播训练以得到多类别目标检测模型。该训练方法融合了分类器、伪标签、知识蒸馏的方法,不仅不需要手动对已有数据集进行标注补充和调整,还通过知识蒸馏充分利用了所有的数据集信息来进行模型训练,达到了同时提升训练效率和效果的目的。
[0106]
图6示出为本技术实施例提出的一种多类别目标检测方法。示范性地,该多类别目标检测方法包括:
[0107]
s210,将目标图像输入至通过上述的多类别目标检测模型训练方法所得到的多类别目标检测模型中进行目标检测,得到该目标图像中存在的所有类别目标的预测结果。
[0108]
示范性地,在得到多类别目标检测模型后,可将需要检测的目标图像输入,可以预测输出其存在的所有目标类别及其位置信息等。例如,若目标图像中存在3个种不同的水果,如苹果、香蕉和火龙果,那么,可以得到这3种水果的类别信息、在图像中的具体位置信息等。
[0109]
其中,关于该多类别目标检测模型的训练方法,具体可参见上述实施例中的相关描述,故在此不再重复描述。
[0110]
此外,图7还示出了本技术实施例提出的一种多类别目标检测模型训练装置。示范性地,该多类别目标检测模型训练装置包括:
[0111]
老师网络推理模块110,用于将训练图片分别输入至预先训练的不同老师网络中进行目标检测,输出每个老师网络中不同检测阶段的参考特征图;其中,不同老师网络用于对不同类别的目标进行检测;所述训练图片中包含多个目标类别且带有部分目标类别标注。
[0112]
正负样本获取模块120,用于基于不同检测阶段的所述参考特征图,得到所述训练图片在对应类别目标检测中的正负样本数据。
[0113]
学生网络训练模块130,用于将所述训练图片输入至学生网络中进行多类别目标检测,输出在所述各个检测阶段的预测特征图。
[0114]
损失计算模块140,用于基于所述正负样本数据、所述参考特征图和所述预测特征图,计算通过所述学生网络进行多类别目标检测时的损失值。
[0115]
学生网络训练模块130,还用于利用所述损失值进行反向传播以继续训练所述学生网络,直至满足预设训练停止条件,得到多类别目标检测模型。
[0116]
可以理解,本实施例的装置对应于上述实施例的方法,上述实施例中的方法的可选项同样适用于本实施例,故在此不再重复描述。
[0117]
本技术还提供了一种终端设备,例如,该终端设备可包括但不限于为机器人、具有多类别目标识别及检测功能的摄像装置等,进一步地,若为机器人,其具体形状并不作限定。示范性地,该终端设备包括处理器和存储器,其中,存储器存储有计算机程序,处理器通过运行所述计算机程序,从而使终端设备执行上述的多类别目标检测模型训练方法或多类别目标检测方法。
[0118]
其中,处理器可以是一种具有信号的处理能力的集成电路芯片。处理器可以是通用处理器,包括中央处理器(central processing unit,cpu)、图形处理器(graphics processing unit,gpu)及网络处理器(network processor,np)、数字信号处理器(dsp)、专用集成电路(asic)、现成可编程门阵列(fpga)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件中的至少一种。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等,可以实现或者执行本技术实施例中的公开的各方法、步骤及逻辑框图。
[0119]
存储器可以是,但不限于,随机存取存储器(random access memory,ram),只读存储器(read only memory,rom),可编程只读存储器(programmable read-only memory,prom),可擦除只读存储器(erasable programmable read-only memory,eprom),电可擦除只读存储器(electric erasable programmable read-only memory,eeprom)等。其中,存储器用于存储计算机程序,处理器在接收到执行指令后,可相应地执行所述计算机程序。
[0120]
本技术还提供了一种可读存储介质,用于储存上述终端设备中使用的所述计算机程序,其中,所述计算机程序在处理器上执行时,实施上述实施例的多类别目标检测模型训练方法或多类别目标检测方法。
[0121]
其中,上述的可读存储介质可以是非易失性存储介质,也可以是易失性存储介质。例如,该可读存储介质可包括但不限于为:u盘、移动硬盘、只读存储器(rom,read-only memory)、随机存取存储器(ram,random access memory)、磁碟或者光盘等各种可以存储程
序代码的介质。
[0122]
在本技术所提供的几个实施例中,应该理解到,所揭露的装置和方法,也可以通过其它的方式实现。以上所描述的装置实施例仅仅是示意性的,例如,附图中的流程图和结构图显示了根据本技术的多个实施例的装置、方法和计算机程序产品的可能实现的体系架构、功能和操作。在这点上,流程图或框图中的每个方框可以代表一个模块、程序段或代码的一部分,所述模块、程序段或代码的一部分包含一个或多个用于实现规定的逻辑功能的可执行指令。也应当注意,在作为替换的实现方式中,方框中所标注的功能也可以以不同于附图中所标注的顺序发生。例如,两个连续的方框实际上可以基本并行地执行,它们有时也可以按相反的顺序执行,这依所涉及的功能而定。也要注意的是,结构图和/或流程图中的每个方框、以及结构图和/或流程图中的方框的组合,可以用执行规定的功能或动作的专用的基于硬件的系统来实现,或者可以用专用硬件与计算机指令的组合来实现。
[0123]
另外,在本技术各个实施例中的各功能模块或单元可以集成在一起形成一个独立的部分,也可以是各个模块单独存在,也可以两个或更多个模块集成形成一个独立的部分。
[0124]
所述功能如果以软件功能模块的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本技术的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是智能手机、个人计算机、服务器、或者网络设备等)执行本技术各个实施例所述方法的全部或部分步骤。
[0125]
以上所述,仅为本技术的具体实施方式,但本技术的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本技术揭露的技术范围内,可轻易想到变化或替换,都应涵盖在本技术的保护范围之内。
技术特征:
1.一种多类别目标检测模型训练方法,其特征在于,包括:将训练图片分别输入至预先训练的不同老师网络中进行目标检测,输出每个老师网络中不同检测阶段的参考特征图;其中,不同老师网络用于对不同类别的目标进行检测;所述训练图片中包含多个目标类别且带有部分目标类别标注;基于不同检测阶段的所述参考特征图,得到所述训练图片在对应类别目标检测中的正负样本数据;将所述训练图片输入至学生网络中进行多类别目标检测,输出在各个所述检测阶段的预测特征图;基于所述正负样本数据、所述参考特征图和所述预测特征图,计算通过所述学生网络进行多类别目标检测时的损失值;利用所述损失值进行反向传播以继续训练所述学生网络,直至满足预设训练停止条件,得到多类别目标检测模型。2.根据权利要求1所述的多类别目标检测模型训练方法,其特征在于,所述输出每个老师网络中不同检测阶段的参考特征图,包括:按照预设尺寸输出每个老师网络在不同检测阶段得到的参考特征图,其中,每个所述参考特征图中的每个点均包含预设数量的锚框,所述锚框用于通过解析得到所述正负样本数据。3.根据权利要求2所述的多类别目标检测模型训练方法,其特征在于,所述基于不同检测阶段的所述参考特征图,得到所述训练图片在对应类别目标检测中的正负样本数据,包括:在对应类别目标检测过程中,按照预设分配规则将各个检测阶段中的所述参考特征图中相应位置的锚框分配为正样本还是负样本;将区分出的正样本和负样本通过矩阵形式进行描述,以得到所述训练图片在对应类别目标检测中的正负样本数据。4.根据权利要求1所述的多类别目标检测模型训练方法,其特征在于,所述基于所述正负样本数据、所述参考特征图和所述预测特征图,计算通过所述学生网络进行多类别目标检测时的损失值,包括:利用每个检测阶段的所述正负样本数据,分别计算每个检测阶段的所述参考特征图与所述预测特征图之间的知识蒸馏损失值;将所有检测阶段的所述知识蒸馏损失值进行加权计算,得到知识蒸馏总损失值;计算所述预测特征图与基于所述训练图片包含的所述目标类别标注得到的真实特征图之间的目标检测损失值;利用所述知识蒸馏总损失值与所述目标检测损失值,计算得到通过所述学生网络进行多类别目标检测时的损失值。5.根据权利要求1至4中任一项所述的多类别目标检测模型训练方法,其特征在于,所述不同检测阶段包括三个阶段,分别为回归检测阶段、是否为目标判定阶段和目标类别判定阶段;其中,每个老师网络在所述目标类别判定阶段中仅输出对其中一种目标类别判定的参考特征图;
所述学生网络在所述目标类别判定阶段中输出对所有目标类别判定的相应预测特征图。6.根据权利要求1所述的多类别目标检测模型训练方法,其特征在于,所述老师网络的数量不超过需要检测的目标类别的总数量;所述不同老师网络通过预先训练得到,包括:将包含多种类别目标的每个样本图片,按照不同的部分目标标注划分得到带不同目标类别标注的若干个数据集;利用不同的所述数据集分别对构建的若干个神经网络进行不同目标检测训练,以得到与所述数据集数量相等的用于检测不同类别目标的若干个老师网络。7.一种多类别目标检测方法,其特征在于,包括:将目标图像输入至通过权利要求1至6中任一项所述的训练方法得到的多类别目标检测模型中进行目标检测,得到所述目标图像中存在的所有类别目标的预测结果。8.一种多类别目标检测模型训练装置,其特征在于,包括:老师网络推理模块,用于将训练图片分别输入至预先训练的不同老师网络中进行目标检测,输出每个老师网络中不同检测阶段的参考特征图;其中,不同老师网络用于对不同类别的目标进行检测;所述训练图片中包含多个目标类别且带有部分目标类别标注;正负样本获取模块,用于基于不同检测阶段的所述参考特征图,得到所述训练图片在对应类别目标检测中的正负样本数据;学生网络训练模块,用于将所述训练图片输入至学生网络中进行多类别目标检测,输出在各个所述检测阶段的预测特征图;损失计算模块,用于基于所述正负样本数据、所述参考特征图和所述预测特征图,计算通过所述学生网络进行多类别目标检测时的损失值;所述学生网络训练模块,还用于利用所述损失值进行反向传播以继续训练所述学生网络,直至满足预设训练停止条件,得到多类别目标检测模型。9.一种终端设备,其特征在于,所述终端设备包括处理器和存储器,所述存储器存储有计算机程序,所述处理器用于执行所述计算机程序以实施权利要求1-6中任一项所述的多类别目标检测模型训练方法或权利要求7所述的多类别目标检测方法。10.一种可读存储介质,其特征在于,其存储有计算机程序,所述计算机程序在处理器上执行时,实施根据权利要求1-6中任一项所述的多类别目标检测模型训练方法或权利要求7所述的多类别目标检测方法。
技术总结
本申请涉及目标检测技术领域,提供了一种多类别目标检测方法及其模型训练方法、装置,该方法通过预先训练得到不同的老师网络,并利用每个老师网络对包含多个目标类别且仅带有部分目标类别标注的训练图片中的对应类别目标进行检测,以得到不同检测阶段的参考特征图;之后,利用所有老师网络得到的这些参考特征图,对需要训练的学生网络所输出的相应预测特征图计算损失值,最后利用损失值进行反向传播训练,从而得到多类别目标检测模型。该方法可以利用标注不完整的数据集来进行多类别目标检测模型的训练,可提升训练效果和效率等。可提升训练效果和效率等。可提升训练效果和效率等。
技术研发人员:胡淑萍 王侃 董培 庞建新 谭欢
受保护的技术使用者:深圳市优必选科技股份有限公司
技术研发日:2023.05.15
技术公布日:2023/8/9
版权声明
本文仅代表作者观点,不代表航家之家立场。
本文系作者授权航家号发表,未经原创作者书面授权,任何单位或个人不得引用、复制、转载、摘编、链接或以其他任何方式复制发表。任何单位或个人在获得书面授权使用航空之家内容时,须注明作者及来源 “航空之家”。如非法使用航空之家的部分或全部内容的,航空之家将依法追究其法律责任。(航空之家官方QQ:2926969996)
航空之家 https://www.aerohome.com.cn/
飞机超市 https://mall.aerohome.com.cn/
航空资讯 https://news.aerohome.com.cn/