基于去噪Transformer的无监督句子编码
EMNLP2021 Findings上有一篇名为TSDAE: Using Transformer-based Sequential Denoising Auto-Encoder for Unsupervised Sentence Embedding Learning的论文,利用Transformer结构无监督训练句子编码,网络架构如下所示
具体来说,输入的文本添加了一些确定的噪声,例如删除、交换、添加、Mask一些词等方法。Encoder需要将含有噪声的句子编码为一个固定大小的向量,然后利用Decoder将原本的不带噪声的句子还原。说是这么说,但是其中有非常多细节,首先是训练目标
egin{aligned} J_{ ext{SDAE}}( heta) &= mathbb{E}_{xsim D}[log P_{ heta}(xmid ilde{x})]\ &=mathbb{E}_{xsim D}[sum_{t=1}^l log P_{ heta}(x_tmid ilde{x})]\ &=mathbb{E}_{xsim D}[sum_{t=1}^l log frac{exp(h_t^T e_t)}{sum_{i=1}^N exp(h_t^T e_i)}] end{aligned}
其中,D是训练集;x = x_1x_2cdots x_l是长度为l的输入句子; ilde{x}是x添加噪声之后的句子;e_t是词x_t的word embedding;N为Vocabulary size;h_t是Decoder第t步输出的hidden state
不同于原始的Transformer,作者提出的方法,Decoder只利用Encoder输出的固定大小的向量进行解码,具体来说,Encoder-Decoder之间的cross-attention形式化地表示如下:
egin{aligned} &H^{(k)}= ext{Attention}(H^{(k-1)}, [s^T], [s^T])\ & ext{Attention}(Q,K,V) = ext{Softmax}(frac{QK^T}{sqrt{d}})V end{aligned}
其中,H^{(k)}in mathbb{R}^{t imes d}是Decoder第k层t个解码步骤内的hidden state;d是句向量的维度(Encoder输出向量的维度);[s^T]in mathbb{R}^{1 imes d}是Encoder输出的句子(行)向量。从上面的公式我们可以看出,不论哪一层的cross-attention,K和V永远都是s^T,作者这样设计的目的是为了人为给模型添加一个瓶颈,如果Encoder编码的句向量s^T不够准确,Decoder就很难解码成功,换句话说,这样设计是为了使得Encoder编码的更加准确。训练结束后如果需要提取句向量只需要用Encoder即可
作者通过在STS数据集上调参,发现最好的组合方法如下:
- 采用删除单词这种添加噪声的方法,并且比例设置为60%
- 使用[CLS]位置的输出作为句向量
Results
从TSDAE的结果来看,基本上是拳打SimCSE,脚踢BERT-flow
个人总结
如果我是reviewer,我特别想问的一个问题是:"你们这种方法,与BART有什么区别?"
论文源码在UKPLab/sentence-transformers/,其实sentence-transformers已经把TSDAE封装成pip包,完整的训练流程可以参考Sentence-Transformer的使用及fine-tune教程,在此基础上只需要修改dataset和loss就可以轻松的训练TSDAE
# 创建可即时添加噪声的特殊去噪数据集
train_dataset = datasets.DenoisingAutoEncoderDataset(train_sentences)
# DataLoader 批量处理数据
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
# 使用去噪自动编码器损失
train_loss = losses.DenoisingAutoEncoderLoss(model, decoder_name_or_path=model_name, tie_encoder_decoder=True)
# 模型训练
model.fit(
train_objectives=[(train_dataloader, train_loss)],
epochs=1,
weight_decay=0,
scheduler='constantlr',
optimizer_params={'lr': 3e-5},
show_progress_bar=True
)
相关文章
- 金融服务领域的大数据:即时分析
- 影响大数据、机器学习和人工智能未来发展的8个因素
- 从0开始构建一个属于你自己的PHP框架
- 如何将Hadoop集成到工作流程中?这6个优秀实践必看
- SEO公司使用大数据优化其模型的5种方法
- 关于Web Workers你需要了解的七件事
- 深入理解HTTPS原理、过程与实践
- 增强分析:数据和分析的未来
- PHP协程实现过程详解
- AI专家:大数据知识图谱——实战经验总结
- 关于PHP的错误机制总结
- 利用数据分析量化协同过滤算法的两大常见难题
- 怎么做大数据工作流调度系统?大厂架构师一语点破!
- 2019大数据处理必备的十大工具,从Linux到架构师必修
- OpenCV中的KMeans算法介绍与应用
- 教大家如果搭建一套phpstorm+wamp+xdebug调试PHP的环境
- CentOS下三种PHP拓展安装方法
- Go语言HTTP Server源码分析
- Go语言HTTP Server源码分析
- 2017年4月编程语言排行榜:Hack首次进入前五十