zl程序教程

您现在的位置是:首页 >  后端

当前栏目

NLP-生成模型-2015:Seq2Seq+Copy【 Pointer网络的Copy机制是对传统Attention机制的简化:输出针对输出词汇表的一个概率分布 --> 输出针对输入文本序列的概率分布】

序列网络输入输出 一个 -- 模型 生成
2023-09-27 14:20:38 时间

《原始论文:Pointer Networks》

后续应用了Pointer Networks的三篇文章:

  • 《Get To The Point: Summarization with Pointer-Generator Networks》;
  • 《Incorporating Copying Mechanism in Sequence-to-Sequence Learning》;
  • 《Multi-Source Pointer Network for Product Title Summarization》;

一、从Seq2Seq说起

Sequence2Sequence(简称seq2seq)模型是RNN的一个重要的应用场景,顾名思义,它实现了把一个序列转换成另外一个序列的功能,并且不要求输入序列和输出序列等长。比较典型的如机器翻译,一个英语句子“Who are you”和它对应的中文句子“你是谁”是两个不同的序列,seq2seq模型要做的就是把这样的序列对应起来。

由于类似语言这样的序列都存在时序关系,而RNN天生便适合处理具有时序关系的序列,因此seq2seq模型往往使用RNN来构建,如LSTM和GRU。具体结构见Sequence to Sequence Learning with Neural Networks 这篇文章提供的模型结构图:

在这里插入图片描述

在这幅图中,模型把序列“ABC”转换成了序列“WXYZ”。分析其结构,我们可以把seq2seq模型分为encoder和decoder两个部分。encoder部分接收“ABC”作为输入,然后将这个序列转换成为一个中间向量C,向量C可以认为是对输入序列的一种理解和表示形式。然后decoder部分把中间向量C作为自己的输入,通过解码操作得到输出序列“WXYZ”。

后来,Attention Mechanism[6]的加入使得seq2seq模型的性能大幅提升,从而大放异彩。那么Attention Mechanism做了些什么事呢?一言以蔽之,Attention Mechanism的作用就是将encoder的隐状态按照一定权重加和之后拼接(或者直接加和)到decoder的隐状态上,以此作为额外信息,起到所谓“软对齐”的作用,并且提高了整个模型的预测准确度。简单举个例子,在机器翻译中一直存在对齐的问题,也就是说源语言的某个单词应该和目标语言的哪个单词对应,如“Who are you”对应“你是谁”,如果我们简单地按照顺序进行匹配的话会发现单词的语义并不对应,显然“who”不能被翻译为“你”。而Attention Mechanism非常好地解决了这个问题。如前所述,Attention Mechanism会给输入序列的每一个元素分配一个权重,如在预测“你”这个字的时候输入序列中的“you”这个词的权重最大,这样模型就知道“你”是和“you”对应的,从而实现了软对齐。

二、Pointer Networks

背景讲完,我们就可以正式进入Pointer Networks这部分了。

为什么在讨论Pointer Networks之前要先说seq2seq以及Attention Mechanism呢,因为Pointer Networks正是通过对Attention Mechanism的简化而得到的。

作者开篇就提到,传统的seq2seq模型是无法解决输出序列的词汇表会随着输入序列长度的改变而改变的问题的,如寻找凸包等。因为对于这类问题,输出往往是输入集合的子集。基于这种特点,作者考虑能不能找到一种结构类似编程语言中的指针,每个指针对应输入序列的一个元素,从而我们可以直接操作输入序列而不需要特意设定输出词汇表。作者给出的答案是 指针网络(Pointer Networks)。我们来看作者给出的一个例子:

在这里插入图片描述
这个图的例子是给定 P 1 P_1 P1 p 4 p_4 p4 四个二维点的坐标,要求找到一个凸包。

显然答案是 P 1 P_1 P1-> P 4 P_4 P4-> P 2 P_2 P2-> P 1 P_1 P1

  • 图a是传统seq2seq模型的做法,就是把四个点的坐标作为输入序列输入进去,然后提供一个词汇表:[start, 1, 2, 3, 4, end],最后依据词汇表预测出序列[start, 1, 4, 2, 1, end],缺点作者也提到过了,对于图a的传统seq2seq模型来说,它的输出词汇表已经限定,当输入序列的长度变化的时候(如变为10个点)它根本无法预测大于4的数字。
  • 图b是作者提出的Pointer Networks,它预测的时候每一步都找当前输入序列中权重最大的那个元素,而由于输出序列完全来自输入序列,它可以适应输入序列的长度变化。

那么Pointer Networks具体是怎样实现的呢?

我们首先来看传统 Attention 机制的公式:
在这里插入图片描述

  • e j e_j ej 是Encoder的隐状态,
  • d i d_i di 是Decoder的隐状态,
  • v v v W 1 W_1 W1 W 2 W_2 W2都是可学习的参数,
  • 在得到 u j i u^i_j uji 之后对其执行softmax操作即得到 a j i a^i_j aji a j i a^i_j aji 就是Decoder端的第 j j j 个 token 分配给Encoder端序列中的第 j j j 个token 的权重,
  • 依据Encoder端的第 j j j 个token 的 权重 a j i a^i_j aji 求加权和得到带有权重的Encoder的输出Context Vector,即: d i ′ d^{'}_i di

然后把得到的 d i ′ d^{'}_i di 拼接(或者加和)到 decoder的隐状态 d i d_i di 上,最后让Decoder部分根据拼接后新的隐状态进行解码和预测。

根据传统的Attention机制,作者想到,所谓的Attention权重系数 α \textbf{α} α 正是针对输入序列的权重,完全可以把它拿出来作为指向输入序列的指针,在每次预测一个元素的时候找到输入序列中权重最大的那个元素不就好了嘛!于是作者就按照这个思路对传统注意力机制进行了修改和简化,公式变成了这个样子:
在这里插入图片描述

第一个公式和之前没有区别,然后第二个公式则是说Pointer Networks直接将softmax之后得到的 α \textbf{α} α 当成了输出,让 α \textbf{α} α 承担指向输入序列特定元素的指针角色。

所以总结一下:

  • 传统的带有注意力机制的seq2seq模型的运行过程是这样的,先使用encoder部分对输入序列进行编码,然后对编码后的向量做attention,最后使用decoder部分对attention后的向量进行解码从而得到预测结果。
  • 但是作为Pointer Networks,得到预测结果的方式便是输出一个概率分布 α \textbf{α} α,也即所谓的指针。

换句话说,

  • 传统带有Attention机制的seq2seq模型输出的是针对输出词汇表的一个概率分布
  • Pointer Networks输出的则是针对输入文本序列的概率分布

其实我们可以发现,因为输出元素来自输入元素的特点,Pointer Networks特别适合用来直接复制输入序列中的某些元素给输出序列。而事实证明,后来的许多文章也确实是以这种方式使用Pointer Networks的。




参考资料:
Pointer Networks简介及其应用