zl程序教程

您现在的位置是:首页 >  硬件

当前栏目

Bert机器问答模型QA(阅读理解)

机器 模型 理解 阅读 问答 bert QA
2023-09-27 14:27:52 时间

Github参考代码:https://github.com/edmondchensj/ChineseQA-with-BERT

https://zhuanlan.zhihu.com/p/333682032

数据集来源于DuReader Dataset,即百度经验上的问答,在上述链接中提供下载方式。

感谢作者提供的代码。

1、数据集处理

(1)首先数据集格式需要转为 SQuAD数据集格式,SQuAD数据集介绍参考。https://blog.csdn.net/m0_45478865/article/details/106568237

(2)然后将每条数据转换为样本example字典,包含qas_id、question_text、doc_tokens(答案的切词结果)。模型使用过程中进一步转换为bert所需要的特征。

2、Bert机器问答模型训练

(1)根据样本example构建模型输入,包括 input_ids、input_mask、segment_ids、 start_positions、 end_positions。其中input_ids为输入文本的切词编号,这个输入文本是一个完整的bert输入,格式为“[CLS]问题文本[SEP]答案文本[SEP]填充文本”,设置总长度上限为384,那么input_ids的维度为(B,L),L=384。input_mask(B,L)是input_ids的掩模,由于区分输入文本在长度不足384时的实际长度。segment_ids(B,L)用于区分输入文本的每个字所属的文本,这个属于问题文本为0,属于答案文本为1,属于填充文本也为0。start_positons(B)、end_positions(B)为答案文本在输入文本中的起止索引。

(2)将上述数据输入到BertForQuestionAnswering模型,这个模型是由pytorch bert默认定义的由一个基础的bert模型和一个全连接层组成。这里bert预训练模型选择的是bert-base-chinsese模型,该模型隐藏维度为768。输入数据经bert处理之后的维度为(B,L,768)。全连接层的为Linear(384,2),那么输出进一步转换为(B,L, 2)。这两个维度可以理解为答案在阅读文本中的起止位置。进一步表示为起始位置模型结果为start_logits(B,L),结束位置模型结果为end_logits。

(3)损失函数计算:模型损失包含起始位置损失和结束位置损失。起始位置损失是start_positions和start_logits的交叉熵损失;结束位置损失是end_positions和end_logits的交叉损失。总的损失为二者的平均值。

3、Bert机器问答模型推理

(1)数据输入类似训练(1),但是输入不再需要start_positions和end_positions。

(2)类似训练步骤(2),模型不再输出损失结果,而是直接输出start_logits和end_logits。通常分类时会取最大值所对应的类别为最终结果。这里作者选取了N个最大值作为备选结果,N=20,即分别从start_logits和end_logits选出20个最大值,并记录他们的位置索引。

(3)对所有的start_logits和end_logits进行遍历,每遍历一次,根据其位置索引得到模型输出的答案文本,同时用start_logit与end_logit的和作为预测结果的概率。

(4)对(3)中的结果按照概率从大到小进行排列,并设置阈值,即可得到模型最终预测的答案文本。

4、部分变量理解

orig_to_tok_index:原始答案的个数。

tok_to_orig_index:token每个字符属于第几个答案。

all_doc_tokens:所有文字的token。

token_to_ori_map:字典,标注字符属于第几个答案。

token_is_max_content:表示字符当前是否处于住家的截取片段。