中文文本纠错_论文Spelling Error Correction with Soft-Masked BERT(ACL_2020)学习笔记与模型复现
最近在ACL 2020上看到一篇论文《Spelling Error Correction with Soft-Masked BERT》,论文的主题为中文文本纠错中的**Chinese spelling error correction (CSC)**任务,论文作者为来自字节跳动AI Lab与复旦大学的研究人员。
《Spelling Error Correction with Soft-Masked BERT》一文中主要提出了一种新的模型框架名为Soft-Masked BERT。Soft-Masked BERT模型框架中主要含有两部分模型, 一部分称之为错误探查网络Detection Network, 另一部分称之为纠错网络Correction Network。
错误探查网络Detection Network由一个双向的GRU模型组成,而纠错网络Correction Network则基于预训练的Bert模型构建。两种网络则通过一种名为Soft masking的方式连接, 即错误探查网络Detection Network的输出经过Soft masking之后再输入进纠错网络Correction Network进行计算。
接下来, 将对论文内容中从数据集构建、模型构建、训练过程、主要实验结果、讨论这五部分进行详细地阐述。
一、数据集构建
数据集的构建在整个《Spelling Error Correction with Soft-Masked BERT》论文框架中起着重要的作用。论文中一共构建了三种数据集,分别为SIGHAN数据集、News Title数据集与5 million news titles数据集。
(1) SIGHAN数据集
SIGHAN数据集为Chinese Spelling Check Task领域的一个benchmark数据集,数据集的链接如下:SIGHAN 2013 Bake-off: Chinese Spelling Check Task。
SIGHAN数据集中包含1100条文本,共有461种错误(spelling errors),这些文本都是从中文文章中收集,相对来说数据集的主题范围较窄。
SIGHAN数据集被分为了三部分,分别为:训练集(training set)、开发集(development set)、测试集(test set)。在这里训练集(training set)用于Soft-Masked BERT模型的fine-tuning,测试集(test set)用于检测模型的性能,而开发集(development set)则被用来对超参数进行调整(hyper-parameter tuning)。在这里,一种能够提升模型性能的方法是将SIGHAN数据集的训练集(training set)中可能存在的一些不包含错误的文本从中剔除(unchanged texts),这样模型进行fine-tuning的SIGHAN训练集(training set)中所有文本都是包含错误(spelling errors)的文本。
(2) News Title数据集
相比于SIGHAN数据集,News Title数据集是一个更大的数据集。News Title数据集中的文本都来自于今日头条app中的文章的标题部分,这些文本的内容涉及政治、娱乐、体育、教育等许多方面。作者为了确保News Title数据集中包含足够多的错误文本,特意从低质量的文本中抽样了15730条样本,所有样本中一共有5,423个样本的文本包含了拼写错误(spelling errors),错误的类型一共有3441种。
值得注意的是,这里News Title数据集仅被对半分成了两部分,分别为:开发集(development set)与测试集(test set)。News Title数据集的开发集也是用来对超参数进行调整(hyper-parameter tuning),News Title数据集的测试集也用于检测模型的性能。
(3) 5 million news titles数据集
5 million news titles数据集中的文本都是从一些中文新闻app中爬取下来的。
同时,作者在这里创建了同音字混淆表(confusion table)。在5 million news titles数据集的文本中,对15%的字符进行随机替换,这15%被随机替换的字符中,有80%的字符使用同音字混淆表confusion table中此字符的同音字符进行替换;而剩下20%的字符使用随机字符进行替换。以这种方式构建数据集再用于训练模型,能够让训练出的Soft-Masked BERT模型获得较强的同音字混淆错误的纠正能力。
需要特别注意的是,5 million news titles数据集只会被用来对模型进行fine-tuning。
例如实验中,在利用SIGHAN数据集的测试集(test set)来检测模型性能之前,模型会先在5 million news titles数据集上做一次fine-tuning,再在SIGHAN数据集的训练集(training set)上再做一次fine-tuning,最后才会使用SIGHAN数据集的测试集(test set)来检测模型性能。
再如,实验中在利用News Title数据集的测试集(test set)来检测模型性能之前,也会先在5 million news titles数据集上fine-tuning一次,才会在News Title数据集的测试集(test set)上检测模型性能。
因此可以看出,在5 million news titles数据集上进行fine-tuning在整个模型训练过程以及之后的性能检测中起到至关重要的作用。
二、模型构建
Soft-Masked BERT模型的机构如下图所示:
Soft-Masked BERT模型框架细分则可以被分为三部分:错误探查网络Detection Network、Soft Masking Connection、纠错网络Correction Network。
(1) 模型输入
整个模型框架的输入input embeddings是由文本句子中每一个字符的word embedding、position embedding、segment embedding三部分嵌入的加和embedding构成的。因此,可以看出Soft-Masked BERT模型框架的输入实际和Bert模型的一般输入形式相同。
在上式中, x i x_{i} xi表示一个文本序列中的第 i i i个字符, e i e_{i} ei表示第 i i i个字符经过三部分嵌入后的加和embedding表示(input embedding)。
(2) 错误探查网络Detection Network
Soft-Masked BERT模型框架中的错误探查网络Detection Network实质上为一个双向GRU模型(Bi-GRU)。Bi-GRU模型对每个文本序列进行正向与反向编码,再将最后一层隐藏层中文本序列的正向编码的隐藏状态与反向编码的隐藏状态横向合并,Bi-GRU模型的计算过程如下公式所示:
h i d h_{i}^{d} hid是文本序列中字符 i i i的嵌入 e i e_{i} ei在经过Bi-GRU模型计算后最后一层隐藏层中双向编码的隐藏状态。论文中Bi-GRU模型的隐藏层维度数设置为256,双向编码后的隐藏层输出维度数为512。
之后,Bi-GRU模型计算得出的 h i d h_{i}^{d} hid会被输入进两个全连接层中分别计算。
-
Detection Network二分类输出计算
Detection Network中Bi-GRU模型的输出 h i d h_{i}^{d} hid会被输入进一个全连接层中进行二分类学习。在计算整个Soft-Masked BERT模型的损失函数时,Detection Network与Correction Network各自的交叉熵损失值的带权加和,构成了Soft-Masked BERT模型损失函数的表示。
上式中, b b b为此全连接层的偏置项; W W W为此全连接层的权重矩阵, W W W将 h i d h_{i}^{d} hid映射到维度为2的空间中,再经过一层 s o f t m a x softmax softmax层之后,即可计算Detection Network的二分类输出的损失值。此处 P d ( y i = k │ X ) P_d (y_i=k│X) Pd(yi=k│X)表示错误探查网络Detection Network分类文本序列中每一个字符 x i x_{i} xi是否为拼写错误字符的二分类条件概率。 -
Soft Masking Connection系数计算
另一个全连接层用来计算Soft-Masked BERT模型中Soft Masking Connection处的系数 p i p_{i} pi,其计算过程如下公式所示:
其中, b d b_{d} bd表示此全连接层中的偏置项; W d W_{d} Wd表示此全连接层中的权重矩阵,其会将 h i d h_{i}^{d} hid映射到维度为1的空间中。此全连接层的输出会再被输入进Sigmoid层中,将值映射到(0,1)之间,这样经过Sigmoid层后输出的值 p i p_{i} pi就为Soft Masking Connection处的系数。
Soft Masking Connection为此篇论文的核心idea之一,其作用是利用计算得到的Soft Masking Connection的系数 p i p_{i} pi来对整个模型框架的输入input embeddings( e i e_{i} ei)与"mask特殊符"的嵌入mask embeddings( e m a s k e_{mask} emask)来做一个加权求和。Soft Masking Connection具体计算过程如下公式所示:
上式中, e i ′ e_{i}^{'} ei′即为模型框架的输入input embeddings( e i e_{i} ei)与"mask特殊符"的嵌入mask embeddings( e m a s k e_{mask} emask)通过Soft Masking Connection的系数 p i p_{i} pi进行加权求和后得到的soft-masked embedding。 e i ′ e_{i}^{'} ei′即表示文本序列中第 i i i个字符通过Soft Masking Connection计算得到的soft-masked embedding。
通过上方Soft Masking Connection的计算公式可以看出,如果文本序列中某字符 i i i的系数 p i p_{i} pi越接近于1,则表示经过错误探查网络Detection Network计算之后,此字符很可能会被Detection Network分类为拼写错误字符,因而其计算出的Soft Masking Connection的系数 p i p_{i} pi也会越接近于1;字符 i i i的系数 p i p_{i} pi越接近于1,则其计算得到的soft-masked embedding( e i ′ e_{i}^{'} ei′)也会越接近于"mask特殊符"的嵌入mask embedding( e m a s k e_{mask} emask)。
如果文本序列中某字符 i i i的系数 p i p_{i} pi越接近于0,则代表经过错误探查网络Detection Network的计算,此字符是拼写错误字符的可能性很小,因而其计算得到的soft-masked embedding( e i ′ e_{i}^{'} ei′)也会越接近于模型框架的输入input embedding( e i e_{i} ei)。
最后,计算出的每个字符 i i i的soft-masked embedding( e i ′ e_{i}^{'} ei′)会被输入进纠错网络Correction Network进行错误纠正。
(3) 纠错网络Correction Network
Soft-Masked BERT模型框架中的纠错网络Correction Network实际为Bert模型。传统的Bert模型中包含了12个Encoder层,每个Encoder层中都含有Multi-head Self Attention、LayerNormalization与Feed-forward Network;同时在Masked Language Model与Next Sentence Prediction两个任务上进行预训练。每个Encoder层的计算公式如下所示:
经过Soft Masking Connection过程计算后得到的文本序列中所有字符的soft-masked embedding( e i ′ e_{i}^{'} ei′)会被输入进纠错网络Correction Network的Bert中进行计算。
取Bert模型中最后一层Encoder的所有隐藏状态(hidden states) h i c h_i^c hic:
(4) Residual Connection残差连接与输出
将Bert模型中最后一层Encoder的所有隐藏状态的输出(hidden states) h i c h_i^c hic与模型框架的输入input embeddings( e i e_{i} ei)相加得到 h i ′ h_{i}^{'} hi′,这一步操作为Residual Connection(如下式所示)。
将经过残差连接(Residual Connection)之后得到的相加值 h i ′ h_{i}^{'} hi′输入进一层全连接层中,此全连接层会将 h i ′ h_{i}^{'} hi′由Bert模型中隐藏状态(hidden states)的768维映射到与候选词表(candidate list)中的词数相同维数数目的空间中,再将此全连接层映射后的输出输入进softmax函数中计算文本序列中字符 x i x_{i} xi被纠正为候选词表(candidate list)中的字符 j j j的条件概率;此计算过程如下公式所示:
上式中 W W W代表全连接层中的权重矩阵, b b b代表偏置项;而 P c ( y i = j ∣ X ) P_{c}(y_{i}=j|X) Pc(yi=j∣X)即代表文本序列中字符 x i x_{i} xi被纠正为候选词表(candidate list)中的字符 j j j的条件概率。
三、训练过程
(1) 损失函数
如上式所示, L d ℒ_{d} Ld表示错误探查网络Detection Network最后的输出值计算的交叉熵损失; L c ℒ_{c} Lc表示纠错网络Correction Network之后的输出值计算的交叉熵损失。
而整个Soft-Masked BERT模型的损失函数是由错误探查网络Detection Network的损失函数与纠错网络Correction Network的损失函数共同构成:
上式即为Soft-Masked BERT模型的损失函数表示。式子中 1 − λ 1-\lambda 1−λ与 λ \lambda λ为Detection Network的损失函数与Correction Network的损失函数的线性组合系数,即这两个网络各自损失函数的线性组合为最终Soft-Masked BERT模型的总损失函数。 λ \lambda λ系数为一个位于[0, 1]之间的数。
而在这里,系数 λ \lambda λ为一个超参数,对于每一个不同的数据集(如SIGHAN数据集或者News Title数据集),超参数 λ \lambda λ的最优值可能都是不同的。
一般来说,系数 λ \lambda λ取一个大于0.5的值更合适,而这意味着 1 − λ 1-\lambda 1−λ的值会更小。这么做的原因是系数 λ \lambda λ是 纠错网络Correction Network的损失函数前的线性组合系数,纠错网络Correction Network的损失函数为多分类损失函数,而错误探查网络Detection Network的损失函数仅为二分类损失函数;Correction Network的多分类学习任务明显更难,而Detection Network的二分类学习任务明显更简单,因此要在Soft-Masked BERT模型损失函数的线性组合表示中,给Correction Network的损失函数更大的权重,这样模型才能在更难的Correction Network的多分类学习任务上取得更好的效果,这样也会令整个Soft-Masked BERT模型的效果更好。
在上方的表格中可以看出,当 λ \lambda λ为0.8时,Detection Network与Correction Network的F1值达到了最高。因此,当 λ \lambda λ的值设定为0.8时,为一个较为理想的状态,此时Detection Network损失函数的线性组合系数为0.2,而Correction Network损失函数的线性组合系数为0.8,这也符合较为简单的二分类学习任务分配较小的权重,更难的多分类学习任务分配更大的权重的想法。
(2) 优化过程
在对Soft-Masked BERT模型进行fine-tuning时,使用Adam优化器对参数进行优化。作者在这里为了降低训练技巧对于模型效果的影响,并未使用诸如动态学习率调整(dynamic learning rate strategy)等其他训练技巧,而仅是将学习率(learning rate)固定为 2 e − 5 2e^{-5} 2e−5。
此外, b a t c h s i z e batch \space size batch size在训练过程中被设为了320。
四、主要实验结果
(1) Soft-Masked BERT模型与Baseline模型在两个数据的测试集上效果的对比结果
在实验结果的对比部分中,作者将论文中提出的Soft-Masked BERT模型与其他的一些Baseline模型一起做了效果的对比。
这些Baseline模型包括NTOU(a method of using an n-gram model and a rule-based classifier)、NCTU-NTUT(a method of utilizing word vectors and conditional random field)、HanSpeller++(an unified framework employing a hidden Markov model to generate candidates and a filter to re-rank candidates)、Hybrid(a BiLSTM-based model trained on a generated dataset)、Confusionset(a Seq2Seq model consisting of a pointer network and copy mechanism)、FASPell(adopts a Seq2Seq model for CSC employing BERT as a denoising autoencoder and a decoder)、BERT-Pretrain(the method of using a pre-trained BERT)、BERT-Finetune(the method of using a fine-tuned BERT)。
上表中是Soft-Masked BERT模型与其他的一些Baseline模型分别在SIGHAN数据集划分出的测试集与News Title数据集划分出的测试集上测试的结果。
可以看出,在SIGHAN数据集的测试集上,Soft-Masked BERT模型的效果比其他Baseline模型的效果要好不少。但这里在Detection部分,HanSpeller++模型的precision要高于Soft-Masked BERT模型的precision,且Correction部分HanSpeller++模型的precision也高于Soft-Masked BERT模型的precision,这是因为HanSpeller++模型中有许多人工添加的规则与特征,这些人工添加的规则与特征能在Detection部分很好地消去false detections,虽然人工规则与特征很有效,但这种方式的开发成本很高且可能泛化性能有欠缺,并且基于人工规则与特征的模型也无法与基于学习的模型(如Soft-Masked BERT模型)直接进行比较。因此,整体上看,在SIGHAN数据集的测试集上,Soft-Masked BERT模型的效果最优。
而在News Title数据集的测试集上,可以看出Soft-Masked BERT模型的效果也是最优的。
需要特别注意的是,在上表中,BERT-Pretrain模型效果很差,而BERT-Finetune模型的效果远好于BERT-Pretrain模型。因此在这里,可以看出fine-tuning对于进行CSC(Chinese spelling error correction)任务的模型的重要性,没有经过fine-tuning的BERT-Pretrain模型完全无法用于CSC任务中。
(2) 进行fine-tuning的5 million news titles数据集的数量逐渐增加后的模型效果对比
在上表中, T r a i n S e t Train Set TrainSet为只用来对模型进行fine-tuning的5 million news titles数据集。可以看出,随着fine-tuning的5 million news titles数据集的数量的逐渐增加,BERT-Finetune模型与Soft-Masked BERT模型的效果也在逐渐改善,而Soft-Masked BERT模型的效果依然优于BERT-Finetune模型的效果。此表不仅说明了fine-tuning对于进行CSC(Chinese spelling error correction)任务的模型的重要性,也说明了用来进行fine-tuning的数据集的规模越大,fine-tuning之后的模型在CSC任务中的效果也会越好。
(3) 消融对比研究(Ablation Study)
作者在这里,为了证明Soft-Masked BERT模型框架中所使用的Soft Masking Connection与 Residual Connection残差连接等方法的有效性,而进行了消融对比研究(Ablation Study),即从Soft-Masked BERT模型框架中单独去除某一方法,如去除Soft Masking Connection或者去除Residual Connection残差连接之后,再去验证模型的效果是否产生了明显的下降。若模型效果确实明显下降了,则证明此时从Soft-Masked BERT模型框架中去除的方法是能够显著改善Soft-Masked BERT模型框架的效果的,因此此方法的有效性得到了验证。
上表中,Soft-Masked BERT-R模型为将纠错网络Correction Network后的Residual Connection残差连接去除后的模型;而Rand-Masked BERT模型中则将Soft Masking Connection的系数由原先根据错误探查网络Detection Network的输出计算得到,变为随机从[0, 1]之间取值作为Soft Masking Connection的系数;而Hard-Masked BERT(0.7 / 0.9 / 0.95)模型中,在Soft Masking Connection部分,若文本序列中某个字符处计算得到的Soft Masking Connection系数大于0.7 / 0.9 / 0.95,则将此字符位置输入后方纠错网络Correction Network中进行计算的嵌入向量直接设为"mask特殊符"的嵌入mask embeddings( e m a s k e_{mask} emask),若此字符处计算得到的Soft Masking Connection系数小于等于0.7 / 0.9 / 0.95,则将此字符位置输入后方纠错网络Correction Network中进行计算的嵌入向量直接设为模型框架的输入input embeddings( e i e_{i} ei),因此此模型被命名为Hard-Masked BERT。
从上表中的结果可以看出,去除Soft-Masked BERT模型中的Residual Connection残差连接或者改变模型中的Soft Masking Connection方法,都会对模型的效果产生较为明显的改变,甚至显著降低了模型的效果。因此,Soft-Masked BERT模型要想取得较好的效果,其中的Residual Connection残差连接与Soft Masking Connection等方法都是必不可少的,这些方法的有效性得到了验证。
而表格中的BERT-Finetune+Force(Upper Bound)则为强制让BERT-Finetune模型只去针对文本序列中那些被错误探查网络Detection Network预测为拼写错误处的字符进行纠错的一种模型,其效果可以被看作是纠错效果所能达到的上界(upper bound)。因此,与BERT-Finetune+Force(Upper Bound)模型的效果一对比,可以看出Soft-Masked BERT模型的效果还有很大的改进空间。
五、讨论
经过实验,发现Soft-Masked BERT模型可以使用错误探查网络Detection Network更好地识别拼写错误的字符,而错误探查网络Detection Network也可以促使纠错网络Correction Network中的Bert模型更好地使用局部上下文信息(local context information)与全局上下文信息(global context information)来进行拼写错误纠正任务。相比于BERT-Finetune模型,Soft-Masked BERT模型可以更好地利用全局上下文信息(global context information)。
在结果中,发现数据集中有些拼写错误的数据样本需要模型有很强的推理能力(reasoning ability / inference ability)才能去纠正这些拼写错误,而Soft-Masked BERT模型以及其他的一些Baseline模型的推理能力则稍有欠缺。
除此之外,数据集中还有些拼写错误的数据样本需要模型具备一定的现实世界中的知识(world knowledge)才能去纠正这些拼写错误。如将‘青弋江’ (Qingyu River)拼写为‘青戈江’ (Qingge River)这一类的拼写错误,人或许可以一眼看出错误,但是模型需要结合一定的外部世界知识(world knowledge)才能更好地去纠正这一类的错误。
纵观整篇论文,个人认为论文中最突出的几个优点为:
(1) Soft Masking Connection这种方法可以促使纠错网络Correction Network中的Bert模型更好地使用局部上下文信息(local context information)与全局上下文信息(global context information)来进行拼写错误纠正任务。
(2) Soft-Masked BERT模型先在5 million news titles数据集(或者可能再加上SIGHAN数据集的训练集)上进行fine-tuning,再在两个数据集(SIGHAN与News Title)的测试集上测试模型效果。这种方法说明了fine-tuning对于进行CSC(Chinese spelling error correction)任务的模型的重要性,并且证明了用来进行fine-tuning的数据集的规模越大,fine-tuning之后的模型在CSC任务中的效果也会越好。
(3) 将Soft-Masked BERT模型框架中分为错误探查网络Detection Network与纠错网络Correction Network两部分,并且将两部分网络的损失函数进行线性合并联合训练的方法。
错误探查网络Detection Network的存在能大大促进纠错网络Correction Network的纠错效果。
最后,基于https://github.com/huggingface/transformers库中的Bert模型,复现的Soft-Masked BERT模型框架的代码如下所示:
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
import transformers
# 引入Bert的BertTokenizer与BertModel, 并单独取出BertModel中的词嵌入word_embeddings层
from transformers import BertConfig,BertModel, BertTokenizer
# 引入Bert模型的基础类BertEmbeddings, BertEncoder,BertPooler,BertPreTrainedModel
from transformers.modeling_bert import BertEmbeddings, BertEncoder,BertPooler,BertPreTrainedModel
'''Soft_Masked_BERT模型, Proposed in the Paper of ACL 2020: Spelling Error Correction with Soft-Masked BERT(2020_ACL)'''
class Soft_Masked_BERT(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
# self.config中包含了拼写错误纠正网络Correction_Network中的Bert模型的各种配置超参数.
self.config = config
'''一、构建错误探查网络Detection_Network中所需的网络层'''
# Bi-GRU网络作为错误探查网络Detection_Network的编码器
# 此处由于BertModel中的embeddings层中所有子嵌入模块的嵌入维度都为768, 所以此处Bi-GRU网络的input_size也为768,
# 而将Bi-GRU网络的hidden_size设为256,是为了保证Bi-GRU网络双向编码后双向隐藏层拼接到一块后隐藏层维度能保持在512.
# 此时enc_hid_size为512.
self.enc_bi_gru = torch.nn.GRU(input_size=768, hidden_size=256, dropout=0.2, bidirectional=True)
# 双向GRU编码层对于输入错误探查网络Detection_Network中的input_embeddings进行双向编码,
# 此时双向GRU编码层的输出为(seq_len, batch_size, enc_hid_size * 2),将其交换维度变形为(batch_size, seq_len, enc_hid_size * 2),
# 再将双向GRU编码层的变形后的输出输入self.detection_network_dense_out层中,映射为形状(batch_size, seq_len, 2)的张量,
# 这样方便后面进行判断句子序列中每一个字符是否为拼写错误字符的二分类任务的交叉熵损失值计算.
self.detection_network_dense_out = torch.nn.Linear(512, 2)
# 同时,将双向GRU编码层输出后经过变形的形状为(batch_size, seq_len, enc_hid_size * 2),的张量输入进soft_masking_coef_mapping层中,
# 将其形状映射为(batch_size, seq_len, 1)的张量,此张量再在后面输入进Sigmoid()激活函数中, 将此张量的值映射至(0,1)之间,
# 这样这个张量即变为了后面计算soft-masked embeddings时和mask_embeddings相乘的系数p (结果pi即可表示为文本序列中第i处的字符拼写错误的似然概率(likelihood)).
self.soft_masking_coef_mapping = torch.nn.Linear(512, 1)
'''二、构建的拼写错误纠正网络Correction_Network中BertModel中所用的个三种网络层'''
''' (1): 嵌入层BertEmbeddings(),其中包含了每个character的word embedding、segment embeddings、position embedding三种嵌入函数. (2): Bert模型的核心,多层(12层)多头自注意力(multi-head self attention)编码层BertEncoder. (3): Bert模型最后的池化层BertPooler. '''
# 嵌入层BertEmbeddings().
self.embeddings = BertEmbeddings(config)
# 多层(12层)多头自注意力(multi-head self attention)编码层BertEncoder.
self.encoder = BertEncoder(config)
# 池化层BertPooler。
self.pooler = BertPooler(config)
# 初始化权重矩阵,偏置等.
self.init_weights()
'''获取遮罩特殊符[MASK]在Bert模型的嵌入层BertEmbeddings()中的词嵌入层word_embeddings层中特殊符[MASK]所对应索引的嵌入向量(embeddins vector)'''
# 在Bert模型的tokenizer类BertTokenizer()的词表中,遮罩特殊符[MASK]会被编码为索引103(只要是BertTokenizer()类,无论其from_pretrained哪种
# 预训练的Bert模型词表,遮罩特殊符[MASK]在词表中的索引都为103; 除非换预训练模型如换成Albert模型,遮罩特殊符[MASK]在词表中的索引才会变, 否则
# 遮罩特殊符[MASK]在同一类预训练Bert模型的词表下索引不变).
# 在之后, 遮罩特殊符[MASK]的张量self.mask_embedding的形状要变为和Bert模型嵌入层BertEmbeddings()的输出input_embeddings张量的形状一样,
# 此时self.mask_embeddings张量的形状要为(batch_size, seq_len, embed_size)->(batch_size, seq_len, 768).
self.mask_embeddings = self.embeddings.word_embeddings.weight[103] # 此时,mask_embedding张量的形状为(768,)
# 注意!: 在soft_masked_embeddings输入拼写错误纠正网络correction network中的Bert模型后,其计算结果输入进最终的输出层与Softmax层之前,
# 拼写错误纠正网络correction network的结果需通过残差连接residual connection与输入模型一开始的input embeddings相加,
# 相加的结果才输入最终的输出层与Softmax层中做最终的正确字符预测。
'''self.soft_masked_bert_dense_out即为拼写错误纠正网络correction network之后的输出层, 其会将经过残差连接模块residual connection之后 的输出的维度由768投影到纠错词表的索引空间. (此处输出层self.soft_masked_bert_dense_out的输出即可被视为Soft_Masked_BERT模型的最终输出)'''
self.soft_masked_bert_dense_out = torch.nn.Linear(self.config.hidden_size, self.embeddings.word_embeddings.weight.shape[0])
'''此处可不写最后的Softmax()函数, 因为若之后在训练模型时使用CrossEntropyLoss()交叉熵函数来计算损失值的话, CrossEntropyLoss()函数 中默认会对输入进行Softmax()计算.'''
'''下方三个函数为BertModel类中自带的函数,放在此处是为了和源BertModel类保持一致防止出错. '''
def get_input_embeddings(self):
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel """
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
'''构建错误探查网络Detection_Network'''
def Detection_Network(self, input_embeddings: torch.Tensor, attention_mask: torch.Tensor):
# 此时输入错误探查网络Detection_Network中的input_embeddings张量形状为:(seq_len, batch_size, embed_size)->(seq_len, batch_size, 768)
# attention_mask张量形状为:(batch_size, seq_len)
# 输入模型起始处的嵌入张量input embedding由一句sentence中每个character的word embedding、position embedding、segment embeddings三者相加而成.
# 初始化错误探查网络Detection_Network的双向GRU的初始隐藏状态h_0.
h_0 = torch.zeros(2, input_embeddings.shape[1], 256)
# 此时双向GRU层self.enc_bi_gru的输出为一个元组,元组中第一个元素为最后的隐藏层输出张量,第二个元素为最后一个时间步的隐藏状态h_n,此处仅需最后的隐藏层输出张量;
# 此时错误探查网络Detection_Network的双向GRU编码层最后的隐藏层输出张量bi_gru_final_hidden_layer的形状为(seq_len, batch_size, enc_hid_size * 2).
bi_gru_final_hidden_layer = self.enc_bi_gru(input_embeddings, h_0)[0]
# 将隐藏层输出张量bi_gru_final_hidden_layer的第一第二维度互换,形状变为(batch_size, seq_len, enc_hid_size * 2)
bi_gru_final_hidden_layer = bi_gru_final_hidden_layer.permute(1,0,2)
# 双向GRU编码层对于输入错误探查网络Detection_Network中的input_embeddings进行双向编码,
# 此时双向GRU编码层的输出为(seq_len, batch_size, enc_hid_size * 2),将其交换维度变形为(batch_size, seq_len, enc_hid_size * 2),
# 再将双向GRU编码层的变形后的输出输入self.detection_network_dense_out层中,映射为形状(batch_size, seq_len, 2)的张量detection_network_output,
# 这样方便后面进行判断句子序列中每一个字符是否为拼写错误字符的二分类任务的交叉熵损失值计算.
detection_network_output = self.detection_network_dense_out(bi_gru_final_hidden_layer) # 形状为(batch_size, seq_len, 2)
# 同时,将双向GRU编码层输出后经过变形的形状为(batch_size, seq_len, enc_hid_size * 2),的张量输入进soft_masking_coef_mapping层中,
# 将其形状映射为(batch_size, seq_len, 1)的张量,此张量再在后面输入进Sigmoid()激活函数中, 将此张量的值映射至(0,1)之间,
# 这样这个张量即变为了后面计算soft-masked embeddings时和mask_embeddings相乘的系数p (结果pi即可表示为文本序列中第i处的字符拼写错误的似然概率(likelihood)).
# 此时soft_masking_coefs张量可被称为:soft-masking系数张量, 其形状为(batch_size, seq_len, 1).
soft_masking_coefs = torch.nn.functional.sigmoid( self.soft_masking_coef_mapping(bi_gru_final_hidden_layer) ) # (batch_size, seq_len, 1)
# 此时将attention_mask张量形状变为(batch_size, seq_len,1),即令此时attention_mask张量的形状与soft-masking系数张量soft_masking_coefs的形状保持一致.
attention_mask = attention_mask.unsqueeze(dim=2)
# 利用attention_mask填充符逻辑指示张量,将soft-masking系数张量soft_masking_coefs中,seq_len上为"填充特殊符[PAD]索引"的位置的
# soft-masking系数变为0, 这样soft-masking系数张量soft_masking_coefs中"填充特殊符[PAD]索引"的位置在后面生成soft-masked embeddings时,
# "填充特殊符[PAD]索引"位置处的mask_embeddings系数即为0, input_embeddings系数即为1,这样即令"填充特殊符[PAD]索引"位置处保持input_embeddings
# 的值不变。
# 由于此时attention_mask张量中, 非特殊填充符的位置指示值为1,特殊填充符的位置指示值为0,因此在此处要用反向选择操作:
# soft_masking_coefs[~attention_mask],来让特殊填充符的位置指示值反转为1,以达到选中特殊填充符的位置并给其赋值0的目的。
attention_mask = (attention_mask != 0) # 将attention_mask张量从1/0变为True/False,方便进行下一步的反向选择操作.
soft_masking_coefs[~attention_mask] = 0
return detection_network_output, soft_masking_coefs
'''构建Soft Masking Connection连接模块.'''
# 在错误探查网络error detection network输出一个句子中每个位置的字符为错误拼写字符的概率之后,利用此概率作为[MASK] embeddings的权重,
# 而1减去这个概率作为句子中每个字符character的input embeddings的权重,[MASK] embeddings乘以权重的结果再加上input embeddings乘以权重的结果后
# 所得到的嵌入结果soft-masked embeddings即为之后的错误纠正网络error correction network的输入。
def Soft_Masking_Connection(self,input_embeddings: torch.Tensor,
mask_embeddings: torch.Tensor,
soft_masking_coefs: torch.Tensor):
# 此时输入Soft_Masking_Connection模块中:
# input_embeddings张量形状为:(batch_size, seq_len, embed_size)->(batch_size, seq_len, 768);
# mask_embeddings为只包含"遮罩特殊符[MASK]"的embedding嵌入的张量,其形状也为:(batch_size, seq_len, embed_size)->(batch_size, seq_len, 768);
# soft_masking_coefs张量可被称为:soft-masking系数张量, 其为计算soft-masked embeddings时和mask_embeddings相乘的系数p的张量,形状为(batch_size, seq_len, 1);
# 输入模型起始处的嵌入张量input embedding由一句sentence中每个character的word embedding、position embedding、segment embeddings三者相加而成.
# 得到soft-masking系数张量:soft_masking_coefs张量之后,利用soft_masking_coefs张量作为[MASK] embeddings的权重,
# 而1减去这个概率作为句子中每个字符character的input embeddings的权重,[MASK] embeddings乘以权重的结果再加上
# input embeddings乘以权重的结果后所得到的嵌入结果soft-masked embeddings即为之后的错误纠正网络error correction network的输入.
# 此时soft_masked_embeddings形状也为(batch_size, seq_len, embed_size)->(batch_size, seq_len, 768),
soft_masked_embeddings = soft_masking_coefs * mask_embeddings + (1 - soft_masking_coefs) * input_embeddings
return soft_masked_embeddings
'''forward函数.'''
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None,
head_mask=None, inputs_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None,
output_attentions=None,):
# 利用张量的long()函数确保这些张量全为int型张量.
input_ids = input_ids.long()
attention_mask = attention_mask.long()
token_type_ids = token_type_ids.long()
position_ids = position_ids.long()
'''以下部分为transformers库中BertModel类中的forward()部门的一小部分源码, 放在此处是为了和源BertModel类保持一致防止出错.'''
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if attention_mask is None:
attention_mask = torch.ones(input_shape, device=device)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
# If a 2D ou 3D attention mask is provided for the cross-attention
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
'''以上部分为transformers库中BertModel类中的forward()部门的一小部分源码, 放在此处是为了和源BertModel类保持一致防止出错.'''
# 输入模型起始处的嵌入张量input embedding由一句sentence中每个character的word embedding、segment embeddings、position embedding三者相加而成。
# 此时input_embeddings张量的形状为(batch_size, seq_len, embed_size)->(batch_size, seq_len, 768),
# 应将input_embeddings张量的第一第二维度互换, 将其形状变为(seq_len, batch_size, embed_size)->(seq_len, batch_size, 768)才方便输入进
# 后方的错误探查网络Detection_Network中的Bi-GRU网络中(双向GRU).
input_embeddings = self.embeddings(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds)
# 形状变为(seq_len, batch_size, embed_size)->(seq_len, batch_size, 768).
input_embeddings = input_embeddings.permute(1,0,2)
# (1)错误探查网络Detection_Network中的双向GRU编码层的输出为(seq_len, batch_size, enc_hid_size * 2),
# 将其交换维度变形为(batch_size, seq_len, enc_hid_size * 2),再将双向GRU编码层的变形后的输出输入self.detection_network_dense_out层中,
# 映射为形状(batch_size, seq_len, 2)的张量detection_network_output, 这样方便后面进行判断句子序列中每一个字符是否为拼写错误字符的二分类任务的交叉熵损失值计算.
# (2)此时soft_masking_coefs张量可被称为:soft-masking系数张量, 其形状为(batch_size, seq_len, 1).
detection_network_output, soft_masking_coefs = self.Detection_Network(input_embeddings=input_embeddings, attention_mask=attention_mask)
# 此时需再将input_embeddings张量的第一第二维度交换, 将其形状再变回(batch_size, seq_len, embed_size)->(batch_size, seq_len, 768),
# 这样input_embeddings张量才方便输入进self.soft.masking_connection模块中计算soft_masked_embeddings.
input_embeddings = input_embeddings.permute(1,0,2)
# 遮罩特殊符[MASK]的张量self.mask_embedding的形状要变为和Bert模型嵌入层BertEmbeddings()的输出input_embeddings张量的形状一样,
# 此时self.mask_embeddings张量的形状要为(batch_size, seq_len, embed_size)->(batch_size, seq_len, 768).
self.mask_embeddings = self.mask_embeddings.unsqueeze(0).unsqueeze(0).repeat(1,input_embeddings.shape[1],1).repeat(input_embeddings.shape[0],1,1)
# 在错误探查网络detection network输出一个句子中每个位置的字符为错误拼写字符的概率之后,利用此概率作为[MASK] embeddings的权重,
# 而1减去这个概率作为句子中每个字符character的input embeddings的权重,[MASK] embeddings乘以权重的结果再加上input embeddings乘以权重的结果后
# 所得到的嵌入结果soft-masked embeddings即为之后的拼写错误纠正网络correction network的输入。
soft_masked_embeddings = self.Soft_Masking_Connection(input_embeddings=input_embeddings, mask_embeddings=self.mask_embeddings,
soft_masking_coefs=soft_masking_coefs)
'''拼写错误纠正网络Correction_Network'''
'''soft_masked_embeddings输入错误纠正网络correction network的Bert模型后的结果经过最后的输出层与Softmax层后, 即为句子中每个位置的字符经过错误纠正网络correction network计算后预测的正确字符索引结果的概率。'''
'''注意: 最新版本的transformers.modeling_bert中的BertEncoder()类中forward()方法所需传入的参数中不再有output_attentions这个参数.'''
encoder_outputs = self.encoder(soft_masked_embeddings,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
# add hidden_states and attentions if they are here
# outputs为一个包含四个元素的tuple:sequence_output, pooled_output, (hidden_states), (attentions)
outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]
# outputs[0]代表Bert模型中最后一个隐藏层的输出(此时Bert模型中的隐藏层有12层,即num_hidden_layers参数为12),
# 注意此处和循环神经网络的输出形状不同,循环网络隐藏层状态的输出为(seq_len, batch_size, bert_hidden_size),
# 此时outputs[0]的张量bert_output_final_hidden_layer的形状为(batch_size, seq_len, bert_hidden_size)—>(batch_size, seq_len, 768).
bert_output_final_hidden_layer = outputs[0]
# 注意!: 在soft_masked_embeddings输入拼写错误纠正网络correction network中的Bert模型后,其计算结果输入进最终的输出层与Softmax层之前,
# 拼写错误纠正网络correction network的结果需通过残差连接residual connection与输入模型一开始的input embeddings相加,
# 相加的结果才输入最终的输出层与Softmax层中做最终的正确字符预测。
residual_connection_outputs = bert_output_final_hidden_layer + input_embeddings
'''self.soft_masked_bert_dense_out即为拼写错误纠正网络correction network之后的输出层, 其会将经过残差连接模块residual connection之后 的输出的维度由768投影到纠错词表的索引空间. (此处输出层self.soft_masked_bert_dense_out的输出final_outputs张量即可被视为Soft_Masked_BERT模型的最终输出).'''
final_outputs = self.soft_masked_bert_dense_out(residual_connection_outputs)
# 此处输出层self.soft_masked_bert_dense_out的输出final_outputs张量即可被视为Soft_Masked_BERT模型的最终输出.
return final_outputs
# 测试代码
config = BertConfig.from_pretrained("/kaggle/input/bertchinese/bert_config.json")
soft_masked_bert = Soft_Masked_BERT.from_pretrained("/kaggle/input/bertchinese/pytorch_model.bin", config=config)
input_ids = torch.Tensor([[101,768,867,117,102,0]]).long()
attention_mask = torch.Tensor([[1,1,1,1,1,0]]).long()
token_type_ids = torch.Tensor([[0,0,0,0,0,0]]).long()
position_ids = torch.Tensor([[0,1,2,3,4,5]]).long()
output = soft_masked_bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids)
output, output.shape
References
Zhang, S., Huang, H., Liu, J., & Li, H. (2020). Spelling Error Correction with Soft-Masked BERT. ArXiv, abs/2005.07421.