zl程序教程

您现在的位置是:首页 >  其他

当前栏目

新加坡国立大学 | 建立一个具有鲁棒性的QA模型(抗分布变化 & 含源码)

amp源码 一个 模型 建立 变化 分布 具有
2023-06-13 09:15:31 时间

引言

 情人节,你遇到的一切都是最好得礼物。今天给大家分享的这篇文章是新加坡国立大学发表的一篇文章,该文介绍了COLDQA,它是针对文本损坏、语言更改和域更改的分布变化的鲁棒QA的统一评估基准,进而从“测试集与训练集数据分布变化会影响模型效果”引入Test-time Adaptation(TTA),通过对TTA的分析,提出了一种新的TTA方法:Online Imitation Learning(OIL)方法;通过大量实验,发现TTA与RT方法相当,在RT之后应用TTA可以显着提高模型在COLDQA的上性能。

背景介绍

 如何构建一个可靠的、对分布变化具有鲁棒性的NLP系统是很重要的,因为现实世界是动态变化的,当测试数据集的分布不同于训练数据集的分布时,NLP模型系统很容易出现问题。针对模型的鲁棒性评估,先前的许多工作发现:当测试数据集分布发生变化时,模型结果会受到很大影响。例如,问答(QA)模型在处理扩展问答时很脆弱;面向任务的对话模型无法理解有损坏的输入;在有噪声的文本输入时神经机器翻译性能下降。

 为了建立一个对分布变化具有鲁棒性的模型,以前的大多数工作都集中在鲁棒性调优(RT)方法上,这些方法可以改善模型部署前的泛化能力,例如对抗性训练。但是,我们能否在模型部署完成之后继续提升模型效果吗?针对QA模型部署后鲁棒性这一问题,本文研究了Test-time Adaptation (TTA) ,TTA通过使用测试时数据不断更新模型来增强模型的泛化能力。

 如上图1所示,在这项工作中,本文专注于实时的测试时自适应(Test-Time Adaptation,TTA),其中模型对数据流进行动态预测和更新。对于每个测试数据实例,模型首先返回它的预测,然后用测试数据更新自己。与NLP中研究的无监督域适应不同,TTA适用于域泛化,因为它对目标分布不做假设,并且在测试时可以使模型适应任意分布。

TTA

「TTA定义」:基于源分布S训练得到的模型源模型为

\pi_0

,TTA利用测试数据使模型适应测试分布T,来增强模型部署后的性能。在线上适配的设置中,测试时数据以流的形式传入,如上面循环图所示。在时间t时,对于测试数据

x_t ∼ \tau

,模型

π_t

首先预测其标签

y_t

返回给终端用户。其次,

π_t

采用TTA方法进行自适应,并将自适应模型推进到t+1时刻。随着更多测试数据的到来,这个过程可以不间断地进行,并且在整个过程中无法获得测试数据的gold labels。

TTA的两个阶段

「TTA具有Tent和PL两个阶段」,其中Tent通过熵最小化对模型进行调整,模型利用测试时数据预测输出,并计算熵损失进行优化;PL是一种伪标记方法,预测测试时数据上的伪标记,并计算交叉熵损失。理论上,Tent和PL从源模型开始,在时间t,模型

π_t

利用测试数据

x_t

来自我更新,优化损失函数如下:

其中,

p_t

是模型

π_t

x_t

输出类的预测概率,

yt = arg max_ip_t [i]

H(·)

H(,)

分别为熵和交叉熵损失。在数据

x_t

上,对模型进行优化,只需要一个梯度步就可以得到

\pi^{'}_{t}:\pi^{'}_{t}\gets \pi_t

,模型

\pi^{'}_{t}

将会被前移到t+1时刻即:

\pi_{t+1} \gets \pi^{'}_{t}

在线模拟学习OIL

 仅仅通过模型自适应,Tent和PL很容易失去预测正确标签的能力,因为他们预测的标签没有被验证的,使用这样的噪声信号学习可能会降低模型的性能,并且模型效果一旦开始恶化,就可能无法恢复。为了克服这一问题,受到模仿学习的启发,本文提出了在线模仿学习(OIL)。OIL旨在通过数据流中的专家模型

π_e

有监督的来训练模型π。专家模型可以帮助模型在整个自适应过程中变得更加健壮,因为专家模型是固定的,训练模型会克隆专家模型的行为。

 理论上,在每一个时间t,专家模型

π_e

x_t ∼ \tau

上做出预测

\hat{y}_t \sim \pi _e

。然后模型

π_t

通过优化代理目标

l_t

来学习克隆这样一个动作:

其中,

y_t

表示学习模型在t时刻的预测,L用来表示计算专家模型和训练模型输出结果之间的距离。理论上,在时刻T,线上损失函数序列为:

\{l_t\}^T_{t=1}

,对应每个时刻的学习模型为:

\prod =\{\pi _t\}^T_{t=1}

,那么regret可以被定义为:

基于OIL的TTA实例化

 在时刻0,专家模型

\pi_e

和学习模型

\pi

都基于模型

\pi_0

进行初始化。在时刻t,需要优化的损失函数

l_t(\pi_t)

为:

其中,

p_t

表示基于学习模型

\pi_t

对于

x_t

输出类的预测概率,

\hat{y}_t = arg max_i\hat{p}_t [i]

,这里的

hat{p}_t

表示基于专家模型

\pi_e

的输出概率。与Tent和PL类似,该模型的优化还是采用一个梯度步骤就可以得到

\pi^{'}_{t}:\pi^{'}_{t}\gets \pi_t

,模型

\pi^{'}_{t}

将会被前移到t+1时刻即:

\pi_{t+1} \gets \pi^{'}_{t}

 对于专家模型,这里依然会采用学习模型的参数对其进行更新。在时刻t,采用以下方式对专家模型进行更新:

其中,

\theta

表示模型参数,

\alpha

是个超参数用来控制专家模型的更新,

\alpha

会设定一个比较高的值,例如0.99或者1,这样在自适应过程中,专家模型可以尽可能的与源模型

\pi_0

保持一致。

 除此之外,由于专家模型是由源模型初始化而来以及测试分布的变化,专家模型的预测值可能存在干扰,为此本文采用过滤的方式来减少对学习模型的干扰。那么损失函数如下:

其中,交叉熵

H(p_t,\hat{y}_t)

用来判定是否是干扰,

\gamma

是一个阈值超参数。

利用因果推理增强OIL

 由于专家模型是由源模型初始化的,当它预测测试数据上的标签时,它的行为会受到它从源分布中学到的知识的影响,这就是我们在这项工作中所说的模型偏差。由于测试分布与源分布不同,专家模型向学习模型提供克隆指示时,这种模型偏差会对学习模型的产生负面影响。为此本文,进一步使用因果推断来减少模型偏差造成的影响。

「因果图」:这里假设学习模型

\pi

的输出会受到输入直接或者间接的影响,如上图(a)所示。那么因果图展示了输入X,输出Y以及专家模型潜在的偏差M。

X\to Y

表示直接影响,

X \to M \to Y

表示间接影响,其中M是X和Y之间的中间值。这里的M是由输入X决定的,而X它既可以来自分布内数据,也可以是来自分布外数据。

「因果影响」:这里做因果推论的目标是保持直接影响但控制或者移除间接影响。如上图(b)所示,可以计算

X \to Y

的所有直接影响(TDE)如下:

其中

d_0

表示因果干预,即去除X的干扰因素。然而由于在假设中并没有X的干扰因素,所以这里忽直接略。

「模型训练」:基于上面TDE计算公式,首先需要学习公式左边这一项,它包含了从X到Y的直接影响以及X到M再到Y的间接影响。这里利用学习模型

\pi

来学习直接影响,对于间接影响,模型偏差M对于不同的分布则表现出不同的行为。由于学习模型

\pi

和专家模型

\pi_e

分别对应测试分布和源分布,我们使用输出中的差异来表示模型偏差。考虑模型偏差,损失函数则为:

其中,

p_t

\hat{p}_t

分别是学习模型

\pi

和专家模型

\pi_e

输出分类概率。其中

p_t

获取直接影响,

p_t-\hat{p}_t

获取间接影响。

「推理」:当进行推理的时候,本文取的y值能够让TDE具有最大的值,对于输入

x_t

其经过学习模型

\pi_t

得到

y_t

如下所示:

其中

\beta

来控制间接影响。这里当计算TDE值得时候,假设输入

x_0

为null时,模型输出为0,因此这里在输入为空得时候,模型得预测也为空。通过实验发现,这里设置

\beta

为1能够完全消除模型偏置得影响。

将TTA应用于QA任务

 对于提取性问答,模型需要预测开始和结束的位置。上述TTA方法分别对这两个位置采用相同的损失,即

l_t(\pi_t)

,最终的损失取两者的平均值。在上面算法1中给出了OIL的伪代码,其中Tent和PL遵循相同的过程,但更新的损失不同。每个时间t的数据

x_t

是一批实例。我们保留了一个大小为K的内存库,用于存储t- K到t时间的数据,从而更充分地利用测试时间数据进行模型自适应。在每一次时间t,从内存库中排队

x_t

和出队列

x_{t=K}

。然后使用内存库中的每一批数据优化在线损失,在此过程中,专家模型也进行相应得更新。

实验结果

 1、基于ClodQA的测试结果如下图所示。其中模型在RT之后应用TTA。

 2、基于MRQA测试集的测试结果如下图所示。

推荐阅读

[1] 「自然语言处理(NLP)」 你必须要知道的 “ 十二个国际顶级会议 ” !

[2] 快看!Transformer中的自注意力机制(Self-attention)竟有这么多变体

[3]GPT-3有Bug!基于Transformer的大型语言模型「鲁棒性」的定量分析

[4]Transformer变体!用于时间序列预测的指数平滑Transformer(含源码)

论文&&源码

Paper:https://arxiv.org/pdf/2302.04618v1.pdf

Code:https://github.com/oceanypt/coldqa-tta