文本分类模型的训练方法及文本分类方法
未命名
08-07
阅读:80
评论:0

1.本发明涉及语义分析技术领域,尤其涉及一种文本分类模型的训练方法及文本分类方法。
背景技术:
2.文本分类是自然语言处理中的重要任务之一,其主要目标是将输入的文本归类到事先定义的一组预定义类别中。这个过程涉及到对文本的特征提取与选择,以及对分类器的训练和优化,从而实现对文本的自动分类和标注。
3.目前,传统的分类方法已经在自然语言处理中得到广泛应用,然而它们的简单计算模型并不能很好地处理高随机性和大数据背景下的分类任务,从而难以保证文本分类的精度和效率。
4.基于此,如何有效提高文本分类的精度成为亟需解决的问题。
技术实现要素:
5.本发明实施例提供了一种文本分类模型的训练方法及文本分类方法,以解决现有技术中文本分类精度不高的问题。
6.第一方面,本发明实施例提供了一种文本分类模型的训练方法,包括:
7.获取训练集数据;其中,所述训练集数据包括文本数据以及与文本数据对应的文本分类标签;
8.基于训练集数据对预设transformer模型进行训练,得到中间transformer模型;
9.基于中间transformer模型的文本分类预测结果以及文本分类标签确定中间transformer模型的误差函数的梯度;
10.当所述梯度不满足预设梯度要求时,利用改进lm算法对中间transformer模型中编码器的残差模块网络参数进行迭代修正,直到梯度满足所述预设梯度要求或者模型训练达到最大迭代次数时,获得目标transformer模型。
11.在一种可能的实现方式中,所述利用改进lm算法对中间transformer模型中编码器的残差模块网络参数进行迭代修正,包括:
12.基于所述梯度和预设修正系数计算雅可比矩阵;
13.基于雅可比矩阵进行残差模块网络参数的权值修正,得到权值变化量;
14.基于权值变化量对中间transformer模型中编码器的残差模块网络参数的权值进行迭代修正。
15.在一种可能的实现方式中,所述基于所述梯度和预设修正系数计算雅可比矩阵,包括:
16.基于j=g
t
g+λdiag(g
t
g),计算雅可比矩阵;
17.其中,j表示所述雅可比矩阵,g表示所述梯度,g
t
表示所述梯度的转置矩阵,λ表示所述预设修正系数,diag(g
t
g)表示g
t
g的对角矩阵。
18.在一种可能的实现方式中,所述基于雅可比矩阵进行残差模块网络参数的权值修正,得到权值变化量,包括:
19.基于[j
t
(ωk)j(ωk)+μkdiag(j
t
(ωk)j(ωk))]δω=-j
t
(ωk)e(ωk),进行残差模块网络参数的权值修正,得到权值变化量;
[0020]
其中,j表示雅可比矩阵,j
t
表示雅克比矩阵的转置矩阵,ωk表示修正前残差模块网络参数的权值,μk表示控制梯度下降步长的阻尼因子,e(ωk)表示残差向量,δω表示所述权值变化量,diag(j
t
(ωk)j(ωk))表示j
t
(ωk)j(ωk)的对角矩阵,k表示第k步权值修正。
[0021]
在一种可能的实现方式中,所述基于权值变化量对中间transformer模型中编码器的残差模块网络参数的权值进行迭代修正,包括:
[0022]
基于ω
k+1
=ωk+δω,对中间transformer模型中编码器的残差模块网络参数的权值进行迭代修正;
[0023]
其中,ω
k+1
表示修正后残差模块网络参数的权值,ωk表示修正前残差模块网络参数的权值,δω表示权值变化量,k表示第k步权值修正。
[0024]
在一种可能的实现方式中,在获取训练集数据之后,还包括:
[0025]
对文本数据进行分词处理,得到词组集合;
[0026]
为所述词组集合中的每个单词添加相应的词性标签;
[0027]
将文本分类标签和添加词性标签后的文本数据转换为数值化的特征向量,得到特征向量集;
[0028]
对所述特征向量集进行归一化处理,得到归一化向量集;
[0029]
对归一化向量集进行聚类分析,得到文本特征数据集;
[0030]
所述基于训练集数据对预设transformer模型进行训练,得到中间transformer模型,包括:
[0031]
基于所述文本特征数据集对预设transformer模型进行训练,得到中间transformer模型。
[0032]
第二方面,本发明实施例提供了一种文本分类方法,包括:
[0033]
获取待分类文本数据;
[0034]
将所述待分类文本数据输入至目标transformer模型中,输出所述待分类文本数据对应的文本分类预测结果;其中,所述目标transformer模型基于如上第一方面或第一方面的任一种可能的实现方式所述的文本分类模型的训练方法训练得到。
[0035]
第三方面,本发明实施例提供了一种训练设备,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现如上第一方面或第一方面的任一种可能的实现方式所述方法的步骤。
[0036]
第四方面,本发明实施例提供了一种文本分类设备,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现如上第二方面所述方法的步骤。
[0037]
第五方面,本发明实施例提供了一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现如上第一方面、第二方面或第一方面的任一种可能的实现方式所述方法的步骤。
[0038]
本发明实施例提供一种文本分类模型的训练方法及文本分类方法,通过获取包括
文本数据以及与文本数据对应的文本分类标签的训练集数据,然后基于训练集数据对预设transformer模型进行训练,得到中间transformer模型;之后再进一步确定当前中间transformer模型的误差函数的梯度;当梯度不满足预设梯度要求时,即可确定当前中间transformer模型的预测精度还未达到要求,此时利用改进lm算法对中间transformer模型中编码器的残差模块网络参数进行迭代修正,然后通过不断修正残差模块网络参数,进而对中间transformer模型进行迭代训练,从而得到文本分类预测精度更高的模型;直到模型的梯度满足预设梯度要求或者模型训练达到预先设置的最大迭代次数时,即可获得目标transformer模型,从而可基于该目标transformer模型实现对文本数据的高精度分类。
附图说明
[0039]
为了更清楚地说明本发明实施例中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
[0040]
图1是本发明实施例提供的文本分类模型的训练方法的实现流程图;
[0041]
图2是本发明实施例提供的transformer模型的整体架构图;
[0042]
图3是本发明实施例提供的残差块的结构示意图;
[0043]
图4是本发明实施例提供的文本分类方法的实现流程图;
[0044]
图5是本发明实施例提供的文本分类模型的训练装置的结构示意图;
[0045]
图6是本发明实施例提供的文本分类装置的结构示意图;
[0046]
图7是本发明实施例提供的训练设备的示意图;
[0047]
图8是本发明实施例提供的文本分类设备的示意图。
具体实施方式
[0048]
以下描述中,为了说明而不是为了限定,提出了诸如特定系统结构、技术之类的具体细节,以便透彻理解本发明实施例。然而,本领域的技术人员应当清楚,在没有这些具体细节的其它实施例中也可以实现本发明。在其它情况中,省略对众所周知的系统、装置、电路以及方法的详细说明,以免不必要的细节妨碍本发明的描述。
[0049]
为使本发明的目的、技术方案和优点更加清楚,下面将结合附图通过具体实施例来进行说明。
[0050]
相比于已被广泛应用于自然语言处理中的传统的分类方法,神经网络模型是一种基于人工神经元和神经层的方法,常用于文本序列的预测和分类。在文本分类中,可以利用神经网络模型构建分层神经网络,以对词汇位置或文本序列等进行分类。与传统的分类方法相比,神经网络模型方法具有更高的分类精度,并且对于各种外部影响因素具有更好的鲁棒性。基于此,本发明实施例提供了一种文本分类模型的训练方法,图1为本发明实施例提供的文本分类模型的训练方法的实现流程图。如图1所示,该训练方法包括:
[0051]
步骤101:获取训练集数据;其中,训练集数据包括文本数据以及与文本数据对应的文本分类标签。
[0052]
在步骤101中,在神经网络模型训练之前,需要获取训练集数据。获取的训练集数
据可以包括:文本数据以及与文本数据对应的文本分类标签。示例性的,文本数据可以是短信文本、邮件文本等。例如,文本数据可以是spambase data set垃圾邮件数据集、用于异常短信识别的sms spam collection数据集等,本技术对此不作限定。而对于训练集数据中与文本数据对应的文本分类标签,示例性的,对于邮件文本数据而言,对应的文本分类标签可以为:1或0;其中,“1”用于表征垃圾邮件,“0”用于表征非垃圾邮件。对于短信文本数据而言,对应的文本分类标签可以为:abnormal(异常的)或normal(正常的);其中,“abnormal”用于表征异常短信,“normal”用于表征正常短信。可选的,文本分类标签可以视具体情况具体设定,本技术仅以上述两种表征方式对本技术实施例作出解释说明,而非限定。
[0053]
可选的,文本分类标签可以添加在每一条文本数据的开头、结尾或某一特定位置,本技术对此不作限定。对于文本数据的样本数据量和所包含的特征数量,可以在实际进行神经网络模型训练时,根据模型的具体训练情况确定。
[0054]
在一种可能的实现方式中,在获取训练集数据之后,还包括:
[0055]
对文本数据进行分词处理,得到词组集合。
[0056]
为词组集合中的每个单词添加相应的词性标签。
[0057]
将文本分类标签和添加词性标签后的文本数据转换为数值化的特征向量,得到特征向量集。
[0058]
对特征向量集进行归一化处理,得到归一化向量集。
[0059]
对归一化向量集进行聚类分析,得到文本特征数据集。
[0060]
基于训练集数据对预设transformer模型进行训练,得到中间transformer模型,包括:
[0061]
基于文本特征数据集对预设transformer模型进行训练,得到中间transformer模型。
[0062]
本实施例中,在获取到训练集数据之后,可以首先对训练集数据进行一系列预处理,以便使得神经网络模型可以更好地进行训练。对于连续的文本数据(例如,包含两个以上中文词语单元或英文词语单元等的文本数据),可以对这种类型的文本数据进行分词处理。也即,将其分割成单个的词语单元,得到对应的词组集合,以便于进行后续的特征提取和计算。示例性的,对于中文文本数据,可以利用jieba分词、pkuseg分词等中文分词库进行分词处理。
[0063]
可选的,在分词处理之后,可以对词组集合进行一次去噪声处理。进行去噪声处理的目的是为了清除文本中的无意义信息,如标点符号、停用词等,从而可以有效减少文本数据的特征量(也即,去除文本数据中一些不必要的特征量),进而有利于提高后续文本分类效果。
[0064]
可选的,还可以为词组集合中的每个单词添加相应的词性标签(例如:人称代词、动词、形容词或副词等),以便于后续神经网络模型可以更好地理解和处理文本,进行准确的特征选择和特征提取。
[0065]
可选的,在为文本数据中的各个单词添加了相应的词性标签之后,可以对此时的文本数据和文本分类标签进行特征提取,以将此时的文本数据和文本分类标签转换为数值化的特征向量进行表示,得到特征向量集,从而便于后续神经网络模型能够更好地理解处理文本数据。示例性的,可以基于tf-idf、词袋模型、word2vec等方法进行特征提取,所提取
的文本数据的特征数量可以视实际情况确定,本技术对此不作限定。
[0066]
可选的,在得到特征向量集之后,可以进一步对其进行数据清洗处理。进行数据清洗处理的目的是为了清理掉无效、错误、重复或不符合要求的文本数据,从而能够保证最终训练得到的神经网络模型的文本分类准确性和可信度。
[0067]
可选的,对文本数据进行数据清洗之后,还可以对此时的文本数据进行归一化处理,以将文本数据的值域缩放到一个统一的范围内,从而便于后续神经网络模型的高效训练。
[0068]
可选的,在对文本数据进行归一化处理之后,可以进一步对此时的文本数据进行聚类分析,以降低其数据维度。如此一来,既能大幅降低后续模型处理数据的运算量,又能有效解决模型可能因各类文本数据的数目不一致而导致的过拟合问题。
[0069]
可选的,可以采用k-means聚类方法对文本数据进行聚类分析,详述如下:
[0070]
从当前的文本数据中选出k
p
个文本特征,并将这k
p
个文本特征作为初始聚类中心。
[0071]
将每个文本特征数据划分到距离其最近的初始聚类中心点处,采用欧氏距离计算各个文本特征数据到达初始聚类中心点的距离。示例性的,空间中各文本特征数据与初始聚类中心点之间的欧氏距离的计算公式为:
[0072][0073]
其中,dis表示空间中各文本特征数据与初始聚类中心点之间的欧氏距离,m表示文本特征数据的个数,cj表示第j个文本特征数据,di表示第i个初始聚类中心点,1≤i≤k
p
且i为整数。
[0074]
计算每类文本特征数据到聚类中心的距离的平均值,再次分配每个文本特征数据到距离其最近的聚类中心。不断重复该过程,直到所有的文本特征数据都不再被分配或是已达到最大的迭代次数(例如,设定一个最大迭代次数,100、150等),则迭代结束,获得聚类分析后的文本特征数据集。
[0075]
可选的,在得到文本特征数据集之后,可以基于该文本特征数据集对预设transformer模型进行训练,得到中间transformer模型。
[0076]
可选的,在得到文本特征数据集之后,且在模型训练之前,还可以从该文本特征数据集中取出一定比例的数据作为后续的验证集数据和测试集数据,以便于后续对模型进行训练时,可以直接利用处理划分好的验证集数据和测试集数据来实现模型的参数调节以及误差的测试验证等。
[0077]
本实施例中,通过对原始的训练集数据进行分词、去噪声、词性标注、特征提取、数据清洗、归一化处理以及聚类分析等多个预处理步骤之后,可以得到质量较好且轻量化的数据集。如此一来,不仅便于后续模型的训练,还有利于提高模型的训练效率和文本分类精度。
[0078]
步骤102:基于训练集数据对预设transformer模型进行训练,得到中间transformer模型。
[0079]
在步骤102中,可以基于训练集数据对预设transformer模型进行训练,得到中间transformer模型。可选的,可以基于上述经一系列预处理之后得到的文本特征数据集对预
设transformer模型进行训练,得到中间transformer模型。如此一来,有利于提升中间transformer模型的训练效率和后续的分类精度。
[0080]
可选的,对于transformer模型,图2为本发明实施例提供的transformer模型的整体架构图。如图2所示,transformer模型的内部结构主要包含四个部分:输入部分、编码部分、解码部分以及输出部分。
[0081]
对于模型的输入部分,通过为文本数据中每个单词的词向量添加相应的位置编码,从而为模型提供当前时间步的前后出现顺序的信息。对于生成位置编码,可以采用不同频率的正弦和余弦函数,生成一个与单词的词向量维度一致的位置向量,最后将其加到对应单词的词向量上即可得到对应单词的输入表示。其中,位置编码(positional encoding,pe)的计算公式可参考下式:
[0082][0083][0084]
其中,pos表示单词在句子中的位置,d
model
表示词向量的维度,i表示词向量中的第i维,2i表示偶数维度,2i+1表示奇数维度。
[0085]
而对于编码部分,由6个encoder block堆叠而成,每个encoder block的输入都是由词向量组成的矩阵,且每个encoder block是由一个多头注意力模块(multi-head attention)和全连接神经网络(feed forward)构成的。其中,multi-head attention计算按照不同的头数分割输入,然后进行注意力计算,最后将各头的输出合并作为该层的输出。具体地,词向量矩阵首先经过变换,得到计算注意力值所需的q(query)、k(keys)、v(values)矩阵,然后即可基于这些矩阵进行注意力值的计算。
[0086]
对于计算注意力值:
[0087]
利用q矩阵和k矩阵计算文本数据中各个单词之间的相关性得分。可以采用点积法,即用q中的每一个向量与k中的每一个向量计算点积,从而得到每条文本数据中各个单词之间的相关性得分。这样就可以以一种有效的方式对输入序列进行编码,使其适用于各种自然语言处理任务。具体到矩阵的形式如下:
[0088]
score=q
·kt
。
[0089]
其中,score表示文本数据中各个单词之间的相关性得分,k
t
表示k矩阵的转置矩阵。
[0090]
然后,将各个相关性得分进行归一化处理,得到各个单词之间的得分向量。如此一来,能够保证模型训练时梯度的稳定。归一化处理具体可如下:
[0091][0092]
其中,score'表示归一化处理之后文本数据中各个单词之间的相关性得分,也即得分向量;dk表示k矩阵的维度。
[0093]
然后,再通过softmax函数,将各个单词之间的得分向量转换成[0,1]之间的概率分布,得到各个单词之间的概率分布情况。
[0094]
然后,根据各个单词之间的概率分布情况,乘以对应的v矩阵,即可得到z维的注意力矩阵。
[0095]
z=softmax(score')
·
v。
[0096]
其中,z表示注意力矩阵。
[0097]
而对于经过多头注意力模块得到的多头注意力矩阵,是注意力机制对输入序列编码之后的结果,该矩阵是通过将多组注意力矩阵z进行拼接得到的。
[0098]
在得到多头注意力矩阵之后,为了加强模型的表达能力和减轻模型的训练难度,可以在多头注意力矩阵的基础上加入残差神经网络。图3为本发明实施例提供的残差块的结构示意图,图3中,x表示残差块的输入向量,其包含了输入残差块的特征映射。w(也即w1和w2)为残差块的权重矩阵,是可学习的参数。b(b1和b2)均为残差块的偏移量向量,也是可学习的参数。f1(x)表示第一层卷积层,经过该卷积层后得到的是:经过第一层线性变化并激活后的输出向量。f2(x)表示第二层卷积层,经过该卷积层后得到的是:经过第二层线性变化后的输出向量。y表示残差块的输出向量,该输出向量是输入向量x和网络权重参数w加上偏移量b所计算得到的。
[0099]
在残差块模块中,首先传递输入向量x,然后将x输入到两个卷积层中,通过这些卷积层,网络会学习到一种特征表示。然后将输入残差块的x与卷积层的输出进行加和,得到自适应残差向量。然后将自适应残差向量传递到残差块的下一层卷积层中,再次执行残差块计算流程,最后,最终得到的残差块的输出y即可在整个神经网络中被用来做文本分类。如图3所示,示例性的,使用add操作将残差块x(也即矩阵x)与多头注意力矩阵相加,从而得到残差向量。这个残差向量的作用在于:既能保留输入序列的信息,同时又能够避免在深度神经网络训练中出现梯度消失或爆炸的问题。然后,采用ln(layer normalization)对多头注意力矩阵进行归一化处理,得到归一化的注意力矩阵。与bn(batch normalization)不同的是,ln是在同一个样本中的不同神经元之间进行归一化。这样做能够更好地保留每个神经元的特征信息,并且能够减少数据批次之间的差异,提高模型的泛化能力。最终,经过add&normalize操作后,会得到高质量的特征表示,使得模型能够更加准确地对输入序列进行文本分类。
[0100]
对于全连接神经网络,全连接层公式采用如下:
[0101]
ffn(x)=max(0,xw1+b1)w2+b2。
[0102]
其中,x表示全连接层的输入向量,w1表示输入到隐藏层的权重矩阵,b1表示隐藏层的偏置向量,w2表示隐藏层到输出层的权重矩阵,b2表示输出层的偏置向量,max()表示relu激活函数。max(0,xw1+b1)是指将输入向量x加权之后加上偏置向量b1,然后使用relu函数进行激活。
[0103]
经过add&normalize处理后,输入到下一个encoder中,连续经过6个encoder block的处理后,输入序列被编码成为一系列高维特征向量。这些高维特征向量有助于模型在后续训练过程中更好地理解和推断输入序列的信息。
[0104]
相应地,在经过6个encoder block的处理后,处理结果会输入至decoder block中。而对于解码部分,由6个decoder block堆叠而成。每个decoder block同样是由multi-head attention和feed forward两个子层构成的。不同的是,decoder block的第一个multi-head attention是基于encoder block的输出计算的,因此,其输入是该层先前输出
的结果,keys和values则是最后一层encoder block的输出。而第二个multi-head attention则是基于masked multi-head attention计算的,其输入中的query来自于上一层的输出,keys和values仍然来自于最后一层encoder的输出。通过这种方式,decoder block可以比较好地理解输入序列和输出序列之间的关系,并对下一步的输出进行预测。
[0105]
解码部分中,在经过第二个multi-head attention后,再一次使用与encoder block中相同的feed forward network对向量进行处理。接着,将处理后的向量传递到下一个decoder block中进行下一步的处理。这个过程同样会重复进行6次,直到最后的输出层。
[0106]
在输出层,也即输出部分,会首先进行一次线性变换操作,然后使用softmax函数将得到的输出向量转换为概率分布。最后,通过与预先定义好的词典进行比对,可以得到概率最大的对应单词,从而作为分类结果进行输出。如此一来,训练集数据经过transformer网络的输入、编码、解码以及输出部分的一系列处理训练,即可得到中间transformer模型。
[0107]
步骤103:基于中间transformer模型的文本分类预测结果以及文本分类标签确定中间transformer模型的误差函数的梯度。
[0108]
在步骤103中,在得到中间transformer模型后,可以基于中间transformer模型的文本分类预测结果和文本分类标签来确定该模型的误差函数的梯度。可选的,梯度的计算原理可详述为:
[0109]
在计算误差函数的梯度时,模型中的每个模块(例如,编码部分的每个encoder block层、解码部分的每个decoder block层等)都是由多个子层组成的。而对于每个子层(例如,多头注意力模块、前馈神经网络层等),它都有其特定的运算过程。在计算误差函数对模型中每个参数的梯度时,可以使用反向传播算法来自动计算每个模块中各个子层对误差函数的贡献,将所有子层对误差函数的贡献累加起来,即可得到整个模块对每个参数的梯度。在前向传播中,模型根据输入数据进行一系列运算,最终输出文本分类结果。在反向传播中,模型根据误差函数对文本分类结果的差异进行反向传播,计算模型中每个参数对误差函数的梯度,并更新参数。对于每个参数,反向传播算法会计算误差函数对其输出值的梯度、输入值的梯度以及参数的梯度。然后,通过链式法则将这些梯度相乘得到误差函数对参数的梯度,从而得到误差函数的梯度。如此一来,后续便可以使用这个梯度更新模型中的参数,由此优化模型性能。
[0110]
步骤104:当梯度不满足预设梯度要求时,利用改进lm算法对中间transformer模型中编码器的残差模块网络参数进行迭代修正,直到梯度满足预设梯度要求或者模型训练达到最大迭代次数时,获得目标transformer模型。
[0111]
在步骤104中,当误差函数的梯度不满足预设梯度要求时,也即基于当前模型的文本分类性能还不足以高质量、高精度地完成文本分类任务时,可以利用改进lm(levenberg-marquardt,拉文伯格-马夸特)算法对中间transformer模型中编码器的残差模块网络参数进行迭代修正。对于模型的每一轮训练,都基于改进lm算法对残差模块网络参数进行修正,以确保每一轮模型训练时的参数都是当前阶段的最佳参数。改进lm算法是一种非线性优化算法,主要用于解决神经网络训练过程中震荡幅度较大的问题。本实施例中,通过采用改进lm算法对中间transformer模型中编码器的残差模块网络参数进行调整以优化模型的误差函数。优化残差模块网络参数主要是为了更好地解决梯度消失的问题以及权重矩阵的退化问题,以使得编码器能够更好地适应给定的自然语言处理任务,从而提高最终目标
transformer模型泛化能力和分类精度。
[0112]
可选的,当满足模型训练的停止条件时,参数修正结束。也即当梯度满足预设梯度要求(也即误差函数收敛)或者模型训练达到最大迭代次数时,获得目标transformer模型。
[0113]
在一种可能的实现方式中,利用改进lm算法对中间transformer模型中编码器的残差模块网络参数进行迭代修正,包括:
[0114]
基于梯度和预设修正系数计算雅可比矩阵。
[0115]
基于雅可比矩阵进行残差模块网络参数的权值修正,得到权值变化量。
[0116]
基于权值变化量对中间transformer模型中编码器的残差模块网络参数的权值进行迭代修正。
[0117]
本实施例中,可以基于预设修正系数和误差函数的梯度计算得到雅可比矩阵。然后,基于该雅可比矩阵对残差模块网络参数的权值进行修正,从而得到权值变化量。然后,再根据该权值变化量修正残差模块网络参数的权值,得到优化后的新的权值。将新的权值应用到神经网络模型中进行训练,经过多次权值修正迭代即可得到满足模型误差精度要求的最优权值。如此一来,将最优权值再应用到神经网络模型中进行训练,即可得到高精度的文本分类模型。
[0118]
在一种可能的实现方式中,基于梯度和预设修正系数计算雅可比矩阵,包括:
[0119]
基于j=g
t
g+λdiag(g
t
g),计算雅可比矩阵。
[0120]
其中,j表示雅可比矩阵,g表示梯度,g
t
表示梯度的转置矩阵,λ表示预设修正系数,diag(g
t
g)表示g
t
g的对角矩阵。
[0121]
本实施例中,基于上式可得出:当λ较小时,g
t
g的贡献较大,此时雅可比矩阵j的主要成分为g
t
g,算法的收敛速度较快,模型的泛化能力较高;而当λ较大时,diag(g
t
g)的贡献较大,此时雅可比矩阵j的主要成分为diag(g
t
g),算法能够在误差函数的平稳区域内寻找误差最小值,避免出现网络震荡。可选的,λ的取值范围为[0,1]。可选的,本实施例中,为了在提高神经网络模型的泛化能力以及加快算法收敛速度的前提下,兼顾避免出现网络震荡的问题,因此,可以将λ取值为0.5。
[0122]
在一种可能的实现方式中,基于雅可比矩阵进行残差模块网络参数的权值修正,得到权值变化量,包括:
[0123]
基于[j
t
(ωk)j(ωk)+μkdiag(j
t
(ωk)j(ωk))]δω=-j
t
(ωk)e(ωk),进行残差模块网络参数的权值修正,得到权值变化量。
[0124]
其中,j表示雅可比矩阵,j
t
表示雅克比矩阵的转置矩阵,ωk表示修正前残差模块网络参数的权值,μk表示控制梯度下降步长的阻尼因子,e(ωk)表示残差向量,δω表示权值变化量,diag(j
t
(ωk)j(ωk))表示j
t
(ωk)j(ωk)的对角矩阵,k表示第k步权值修正。
[0125]
本实施例中,在计算出雅可比矩阵之后,根据雅可比矩阵对残差模块网络参数的权值进行修正,计算得到权值变化量。其中,j(ωk)表示基于修正前残差模块网络参数的权值所确定的雅可比矩阵。j
t
(ωk)表示基于修正前残差模块网络参数的权值所确定的雅可比矩阵的转置矩阵。
[0126]
在一种可能的实现方式中,基于权值变化量对中间transformer模型中编码器的残差模块网络参数的权值进行迭代修正,包括:
[0127]
基于ω
k+1
=ωk+δω,对中间transformer模型中编码器的残差模块网络参数的
权值进行迭代修正。
[0128]
其中,ω
k+1
表示修正后残差模块网络参数的权值,ωk表示修正前残差模块网络参数的权值,δω表示权值变化量,k表示第k步权值修正。
[0129]
本实施例中,在计算得到权值变化量之后,可以基于该权值变化量修正残差模块网络参数的权值。然后,根据修正后的权值,继续训练中间transformer模型,并计算训练得到的模型的文本分类预测结果(模型实际输出值)和文本分类标签(理论输出值)之间的误差,若误差在最大可允许累计误差范围内,则认为当前模型的训练精度已满足文本分类精度要求,此时即可得到目标transformer模型。
[0130]
可选的,为评估模型预测的精度,可以选取精确率、召回率以及f1测试度等多项评估指标,来分析评估模型的文本分类效果。
[0131]
本发明实施例提供一种文本分类模型的训练方法,通过获取包括文本数据以及与文本数据对应的文本分类标签的训练集数据,然后基于训练集数据对预设transformer模型进行训练,得到中间transformer模型;之后再进一步确定当前中间transformer模型的误差函数的梯度;当梯度不满足预设梯度要求时,即可确定当前中间transformer模型的预测精度还未达到要求,此时利用改进lm算法对中间transformer模型中编码器的残差模块网络参数进行迭代修正,然后通过不断修正残差模块网络参数,进而对中间transformer模型进行迭代训练,从而得到文本分类预测精度更高的模型;直到模型的梯度满足预设梯度要求或者模型训练达到预先设置的最大迭代次数时,即可获得目标transformer模型,从而可基于该目标transformer模型实现对文本数据的高精度分类。
[0132]
图4为本发明实施例提供的文本分类方法的实现流程图。如图4所示,本发明实施例提供了一种文本分类方法,包括:
[0133]
步骤401:获取待分类文本数据。
[0134]
在步骤401中,获取待分类文本数据。待分类文本数据可以是短信文本数据、邮件数据等,本技术对此不作限定。
[0135]
步骤402:将待分类文本数据输入至目标transformer模型中,输出待分类文本数据对应的文本分类预测结果;其中,目标transformer模型基于如上第一方面或第一方面的任一种可能的实现方式的文本分类模型的训练方法训练得到。
[0136]
在步骤402中,将待分类文本数据输入至目标中,经过模型的一系列处理,输出待分类文本数据所对应的文本分类预测结果。该目标transformer模型是基于上述文本分类模型的训练方法训练得到。
[0137]
本发明实施例提供一种文本分类方法,通过将获取的待分类文本数据输入至目标transformer模型中,即可得到模型输出的高精度的文本分类预测结果。
[0138]
应理解,上述实施例中各步骤的序号的大小并不意味着执行顺序的先后,各过程的执行顺序应以其功能和内在逻辑确定,而不应对本发明实施例的实施过程构成任何限定。
[0139]
以下为本发明的装置实施例,对于其中未详尽描述的细节,可以参考上述对应的方法实施例。
[0140]
图5为本发明实施例提供的文本分类模型的训练装置的结构示意图,为了便于说明,仅示出了与本发明实施例相关的部分,详述如下:
[0141]
如图5所示,文本分类模型的训练装置5包括:
[0142]
数据获取模块501,用于获取训练集数据;其中,训练集数据包括文本数据以及与文本数据对应的文本分类标签。
[0143]
中间模型训练模块502,用于基于训练集数据对预设transformer模型进行训练,得到中间transformer模型。
[0144]
梯度计算模块503,用于基于中间transformer模型的文本分类预测结果以及文本分类标签确定中间transformer模型的误差函数的梯度。
[0145]
目标模型训练模块504,用于当梯度不满足预设梯度要求时,利用改进lm算法对中间transformer模型中编码器的残差模块网络参数进行迭代修正,直到梯度满足预设梯度要求或者模型训练达到最大迭代次数时,获得目标transformer模型。
[0146]
本发明实施例提供一种文本分类模型的训练装置,该装置包括:数据获取模块501、中间模型训练模块502、梯度计算模块503以及目标模型训练模块504。通过获取包括文本数据以及与文本数据对应的文本分类标签的训练集数据,然后基于训练集数据对预设transformer模型进行训练,得到中间transformer模型;之后再进一步确定当前中间transformer模型的误差函数的梯度;当梯度不满足预设梯度要求时,即可确定当前中间transformer模型的预测精度还未达到要求,此时利用改进lm算法对中间transformer模型中编码器的残差模块网络参数进行迭代修正,然后通过不断修正残差模块网络参数,进而对中间transformer模型进行迭代训练,从而得到文本分类预测精度更高的模型;直到模型的梯度满足预设梯度要求或者模型训练达到预先设置的最大迭代次数时,即可获得目标transformer模型,从而可基于该目标transformer模型实现对文本数据的高精度分类。
[0147]
在一种可能的实现方式中,中间模型训练模块502具体用于:
[0148]
对文本数据进行分词处理,得到词组集合。
[0149]
为词组集合中的每个单词添加相应的词性标签。
[0150]
将文本分类标签和添加词性标签后的文本数据转换为数值化的特征向量,得到特征向量集。
[0151]
对特征向量集进行归一化处理,得到归一化向量集。
[0152]
对归一化向量集进行聚类分析,得到文本特征数据集。
[0153]
基于文本特征数据集对预设transformer模型进行训练,得到中间transformer模型。
[0154]
在一种可能的实现方式中,目标模型训练模块504具体用于:
[0155]
基于梯度和预设修正系数计算雅可比矩阵。
[0156]
基于雅可比矩阵进行残差模块网络参数的权值修正,得到权值变化量。
[0157]
基于权值变化量对中间transformer模型中编码器的残差模块网络参数的权值进行迭代修正。
[0158]
在一种可能的实现方式中,目标模型训练模块504还具体用于:
[0159]
基于j=g
t
g+λdiag(g
t
g),计算雅可比矩阵。
[0160]
其中,j表示雅可比矩阵,g表示梯度,g
t
表示梯度的转置矩阵,λ表示预设修正系数,diag(g
t
g)表示g
t
g的对角矩阵。
[0161]
在一种可能的实现方式中,目标模型训练模块504还具体用于:
programmable gate array,fpga)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
[0176]
所述存储器71可以是所述训练设备7的内部存储单元,例如训练设备7的硬盘或内存。所述存储器71也可以是所述训练设备7的外部存储设备,例如所述训练设备7上配备的插接式硬盘,智能存储卡(smart media card,smc),安全数字(secure digital,sd)卡,闪存卡(flash card)等。进一步地,所述存储器71还可以既包括所述训练设备7的内部存储单元也包括外部存储设备。所述存储器71用于存储所述计算机程序以及所述训练设备所需的其他程序和数据。所述存储器71还可以用于暂时地存储已经输出或者将要输出的数据。
[0177]
图8为本发明实施例提供的文本分类设备的示意图。如图8所示,该实施例的文本分类设备8包括:处理器80、存储器81以及存储在所述存储器81中并可在所述处理器80上运行的计算机程序82。所述处理器80执行所述计算机程序82时实现上述各个文本分类方法实施例中的步骤,例如图4所示的步骤401至步骤402。或者,所述处理器80执行所述计算机程序82时实现上述各装置实施例中各模块的功能,例如图6所示模块601至602的功能。
[0178]
示例性的,所述计算机程序82可以被分割成一个或多个模块/单元,所述一个或者多个模块/单元被存储在所述存储器81中,并由所述处理器80执行,以完成本发明。所述一个或多个模块/单元可以是能够完成特定功能的一系列计算机程序指令段,该指令段用于描述所述计算机程序82在所述文本分类设备8中的执行过程。例如,所述计算机程序82可以被分割成图6所示的模块601至602。
[0179]
所述文本分类设备8可以是桌上型计算机、笔记本、掌上电脑及云端服务器等计算设备。所述文本分类设备8可包括,但不仅限于,处理器80、存储器81。本领域技术人员可以理解,图8仅仅是文本分类设备8的示例,并不构成对文本分类设备8的限定,可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件,例如所述训练设备还可以包括输入输出设备、网络接入设备、总线等。
[0180]
所称处理器80可以是中央处理单元(central processing unit,cpu),还可以是其他通用处理器、数字信号处理器(digital signal processor,dsp)、专用集成电路(application specific integrated circuit,asic)、现场可编程门阵列(field-programmable gate array,fpga)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
[0181]
所述存储器81可以是所述文本分类设备8的内部存储单元,例如文本分类设备8的硬盘或内存。所述存储器81也可以是所述文本分类设备8的外部存储设备,例如所述文本分类设备8上配备的插接式硬盘,智能存储卡(smart media card,smc),安全数字(secure digital,sd)卡,闪存卡(flash card)等。进一步地,所述存储器81还可以既包括所述文本分类设备8的内部存储单元也包括外部存储设备。所述存储器81用于存储所述计算机程序以及所述训练设备所需的其他程序和数据。所述存储器81还可以用于暂时地存储已经输出或者将要输出的数据。
[0182]
所属领域的技术人员可以清楚地了解到,为了描述的方便和简洁,仅以上述各功能单元、模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的
功能单元、模块完成,即将所述装置的内部结构划分成不同的功能单元或模块,以完成以上描述的全部或者部分功能。实施例中的各功能单元、模块可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中,上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。另外,各功能单元、模块的具体名称也只是为了便于相互区分,并不用于限制本技术的保护范围。上述系统中单元、模块的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
[0183]
在上述实施例中,对各个实施例的描述都各有侧重,某个实施例中没有详述或记载的部分,可以参见其它实施例的相关描述。
[0184]
本领域普通技术人员可以意识到,结合本文中所公开的实施例描述的各示例的单元及算法步骤,能够以电子硬件、或者计算机软件和电子硬件的结合来实现。这些功能究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本发明的范围。
[0185]
在本发明所提供的实施例中,应该理解到,所揭露的装置/训练设备/文本分类设备和方法,可以通过其它的方式实现。例如,以上所描述的装置/训练设备/文本分类设备实施例仅仅是示意性的,例如,所述模块或单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通讯连接可以是通过一些接口,装置或单元的间接耦合或通讯连接,可以是电性,机械或其它的形式。
[0186]
所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
[0187]
另外,在本发明各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中。上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。
[0188]
所述集成的模块/单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本发明实现上述实施例方法中的全部或部分流程,也可以通过计算机程序来指令相关的硬件来完成,所述的计算机程序可存储于一计算机可读存储介质中,该计算机程序在被处理器执行时,可实现上述各个文本分类模型的训练方法或文本分类方法实施例的步骤。其中,所述计算机程序包括计算机程序代码,所述计算机程序代码可以为源代码形式、对象代码形式、可执行文件或某些中间形式等。所述计算机可读介质可以包括:能够携带所述计算机程序代码的任何实体或装置、记录介质、u盘、移动硬盘、磁碟、光盘、计算机存储器、只读存储器(read-only memory,rom)、随机存取存储器(random access memory,ram)、电载波信号、电信信号以及软件分发介质等。需要说明的是,所述计算机可读介质包含的内容可以根据司法管辖区内立法和专利实践的要求进行适当的增减,例如在某些司法管辖区,根据立法和专利实践,计算机可读介质不包括是电载波信号和电信信号。
[0189]
以上所述实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的精神和范围,均应包含在本发明的保护范围之内。
技术特征:
1.一种文本分类模型的训练方法,其特征在于,包括:获取训练集数据;其中,所述训练集数据包括文本数据以及与文本数据对应的文本分类标签;基于训练集数据对预设transformer模型进行训练,得到中间transformer模型;基于中间transformer模型的文本分类预测结果以及文本分类标签确定中间transformer模型的误差函数的梯度;当所述梯度不满足预设梯度要求时,利用改进lm算法对中间transformer模型中编码器的残差模块网络参数进行迭代修正,直到梯度满足所述预设梯度要求或者模型训练达到最大迭代次数时,获得目标transformer模型。2.根据权利要求1所述的文本分类模型的训练方法,其特征在于,所述利用改进lm算法对中间transformer模型中编码器的残差模块网络参数进行迭代修正,包括:基于所述梯度和预设修正系数计算雅可比矩阵;基于雅可比矩阵进行残差模块网络参数的权值修正,得到权值变化量;基于权值变化量对中间transformer模型中编码器的残差模块网络参数的权值进行迭代修正。3.根据权利要求2所述的文本分类模型的训练方法,其特征在于,所述基于所述梯度和预设修正系数计算雅可比矩阵,包括:基于j=g
t
g+λdiag(g
t
g),计算雅可比矩阵;其中,j表示所述雅可比矩阵,g表示所述梯度,g
t
表示所述梯度的转置矩阵,λ表示所述预设修正系数,diag(g
t
g)表示g
t
g的对角矩阵。4.根据权利要求2所述的文本分类模型的训练方法,其特征在于,所述基于雅可比矩阵进行残差模块网络参数的权值修正,得到权值变化量,包括:基于[j
t
(ω
k
)j(ω
k
)+μ
k
diag(j
t
(ω
k
)j(ω
k
))]δω=-j
t
(ω
k
)e(ω
k
),进行残差模块网络参数的权值修正,得到权值变化量;其中,j表示雅可比矩阵,j
t
表示雅克比矩阵的转置矩阵,ω
k
表示修正前残差模块网络参数的权值,μ
k
表示控制梯度下降步长的阻尼因子,e(ω
k
)表示残差向量,δω表示所述权值变化量,diag(j
t
(ω
k
)j(ω
k
))表示j
t
(ω
k
)j(ω
k
)的对角矩阵,k表示第k步权值修正。5.根据权利要求2至4中任一项所述的文本分类模型的训练方法,其特征在于,所述基于权值变化量对中间transformer模型中编码器的残差模块网络参数的权值进行迭代修正,包括:基于ω
k+1
=ω
k
+δω,对中间transformer模型中编码器的残差模块网络参数的权值进行迭代修正;其中,ω
k+1
表示修正后残差模块网络参数的权值,ω
k
表示修正前残差模块网络参数的权值,δω表示权值变化量,k表示第k步权值修正。6.根据权利要求1所述的文本分类模型的训练方法,其特征在于,在获取训练集数据之后,还包括:对文本数据进行分词处理,得到词组集合;为所述词组集合中的每个单词添加相应的词性标签;将文本分类标签和添加词性标签后的文本数据转换为数值化的特征向量,得到特征向
量集;对所述特征向量集进行归一化处理,得到归一化向量集;对归一化向量集进行聚类分析,得到文本特征数据集;所述基于训练集数据对预设transformer模型进行训练,得到中间transformer模型,包括:基于所述文本特征数据集对预设transformer模型进行训练,得到中间transformer模型。7.一种文本分类方法,其特征在于,包括:获取待分类文本数据;将所述待分类文本数据输入至目标transformer模型中,输出所述待分类文本数据对应的文本分类预测结果;其中,所述目标transformer模型基于权利要求1-6中任一项所述的文本分类模型的训练方法训练得到。8.一种训练设备,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时实现如上的权利要求1至6中任一项所述方法的步骤。9.一种文本分类设备,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时实现如上的权利要求7所述方法的步骤。10.一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如上的权利要求1至7中任一项所述方法的步骤。
技术总结
本发明提供一种文本分类模型的训练方法及文本分类方法。该方法包括:获取训练集数据;其中,训练集数据包括文本数据以及与文本数据对应的文本分类标签;基于训练集数据对预设模型进行训练,得到中间Transformer模型;基于中间Transformer模型的文本分类预测结果以及文本分类标签确定其误差函数的梯度;当梯度不满足预设梯度要求时,利用改进LM算法对中间Transformer模型中编码器的残差模块网络参数进行迭代修正,直到梯度满足预设梯度要求或者模型训练达到最大迭代次数时,获得目标Transformer模型。基于本发明提供的文本分类模型能够实现对文本数据的高精度分类。模型能够实现对文本数据的高精度分类。模型能够实现对文本数据的高精度分类。
技术研发人员:雷宇 屈可帅 王旭光 赵一凡 韩庆
受保护的技术使用者:石家庄铁道大学
技术研发日:2023.05.06
技术公布日:2023/8/6
版权声明
本文仅代表作者观点,不代表航家之家立场。
本文系作者授权航家号发表,未经原创作者书面授权,任何单位或个人不得引用、复制、转载、摘编、链接或以其他任何方式复制发表。任何单位或个人在获得书面授权使用航空之家内容时,须注明作者及来源 “航空之家”。如非法使用航空之家的部分或全部内容的,航空之家将依法追究其法律责任。(航空之家官方QQ:2926969996)
航空之家 https://www.aerohome.com.cn/
飞机超市 https://mall.aerohome.com.cn/
航空资讯 https://news.aerohome.com.cn/
上一篇:信息处理方法、装置、终端和存储介质与流程 下一篇:轮辐轮圈的成型方法与流程