zl程序教程

您现在的位置是:首页 >  其它

当前栏目

详解机器翻译任务中的BLEU

详解 任务 机器翻译
2023-09-27 14:19:49 时间

一、 n n n 元语法(N-Gram)

n n n 元语法(n-gram)是指文本中连续出现的 n n n词元。当 n n n 分别为 1 , 2 , 3 1,2,3 1,2,3 时,n-gram 又叫作 unigram(一元语法)、bigram(二元语法)和 trigram(三元语法)。

n n n 元语法模型是基于 n − 1 n-1 n1 阶马尔可夫链的一种概率语言模型(即只考虑前 n − 1 n-1 n1 个词出现的情况下,后一个词出现的概率):

unigram: P ( w 1 , w 2 , ⋯   , w T ) = ∏ i = 1 T P ( w i ) bigram: P ( w 1 , w 2 , ⋯   , w T ) = P ( w 1 ) ∏ i = 1 T − 1 P ( w i + 1 ∣ w i ) trigram: P ( w 1 , w 2 , ⋯   , w T ) = P ( w 1 ) P ( w 2 ∣ w 1 ) ∏ i = 1 T − 2 P ( w i + 2 ∣ w i , w i + 1 ) \begin{aligned} \text{unigram:}\quad&P(w_1,w_2,\cdots,w_T)=\prod_{i=1}^T P(w_i) \\ \text{bigram:}\quad&P(w_1,w_2,\cdots,w_T)=P(w_1)\prod_{i=1}^{T-1} P(w_{i+1}|w_i) \\ \text{trigram:}\quad&P(w_1,w_2,\cdots,w_T)=P(w_1)P(w_2|w_1)\prod_{i=1}^{T-2} P(w_{i+2}|w_{i},w_{i+1}) \\ \end{aligned} unigram:bigram:trigram:P(w1,w2,,wT)=i=1TP(wi)P(w1,w2,,wT)=P(w1)i=1T1P(wi+1wi)P(w1,w2,,wT)=P(w1)P(w2w1)i=1T2P(wi+2wi,wi+1)

二、BLEU(Bilingual Evaluation Understudy)

2.1 BLEU 定义

BLEU(发音与单词 blue 相同) 最早是用于评估机器翻译的结果, 但现在它已经被广泛用于评估许多应用的输出序列的质量。对于预测序列 pred 中的任意 n n n 元语法, BLEU 的评估都是这个 n n n 元语法是否出现在标签序列 label 中。

BLEU 定义如下:

BLEU = exp ⁡ ( min ⁡ ( 0 , 1 − len(label) len(pred) ) ) ∏ n = 1 k p n 1 / 2 n \text{BLEU}=\exp\left(\min\left(0,1-\frac{\text{len(label)}}{\text{len(pred)}}\right)\right)\prod_{n=1}^kp_n^{1/2^n} BLEU=exp(min(0,1len(pred)len(label)))n=1kpn1/2n

其中 len(*) \text{len(*)} len(*) 代表序列 ∗ * 中的词元个数, k k k 用于匹配最长的 n n n 元语法(常取 4 4 4), p n p_n pn 表示 n n n 元语法的精确度。

具体而言,给定 label A , B , C , D , E , F A,B,C,D,E,F A,B,C,D,E,Fpred A , B , B , C , D A,B,B,C,D A,B,B,C,D,取 k = 3 k=3 k=3

首先看 p 1 p_1 p1 如何计算。我们先将 pred 中的每个 unigram 都统计出来: ( A ) , ( B ) , ( B ) , ( C ) , ( D ) (A),(B),(B),(C),(D) (A),(B),(B),(C),(D),再将 label 中的每个 unigram 都统计出来: ( A ) , ( B ) , ( C ) , ( D ) , ( E ) , ( F ) (A),(B),(C),(D),(E),(F) (A),(B),(C),(D),(E),(F),然后看它们之间有多少匹配的(不可以重复匹配,即必须保持一一对应的关系)。可以看出一共有 4 4 4 个匹配的,而 pred 中一共有 5 5 5 个 unigram,于是 p 1 = 4 / 5 p_1=4/5 p1=4/5

再来看 p 2 p_2 p2 如何计算。我们先将 pred 中的每个 bigram 都统计出来: ( A , B ) , ( B , B ) , ( B , C ) , ( C , D ) (A,B),(B,B),(B,C),(C,D) (A,B),(B,B),(B,C),(C,D),再将 label 中的每个 bigram 都统计出来: ( A , B ) , ( B , C ) , ( C , D ) , ( D , E ) , ( E , F ) (A,B),(B,C),(C,D),(D,E),(E,F) (A,B),(B,C),(C,D),(D,E),(E,F),然后看它们之间有多少匹配的。可以看出一共有 3 3 3 个匹配的,而 pred 中一共有 4 4 4 个 bigram,于是 p 2 = 3 / 4 p_2=3/4 p2=3/4

最后看 p 3 p_3 p3 如何计算。我们先将 pred 中的每个 trigram 都统计出来: ( A , B , B ) , ( B , B , C ) , ( B , C , D ) (A,B,B),(B,B,C),(B,C,D) (A,B,B),(B,B,C),(B,C,D),再将 label 中的每个 trigram 都统计出来: ( A , B , C ) , ( B , C , D ) , ( C , D , E ) , ( D , E , F ) (A,B,C),(B,C,D),(C,D,E),(D,E,F) (A,B,C),(B,C,D),(C,D,E),(D,E,F),然后看它们之间有多少匹配的。可以看出只有 1 1 1 个匹配,而 pred 中一共有 3 3 3 个 trigram,于是 p 3 = 1 / 3 p_3=1/3 p3=1/3

因此此例的 BLEU 分数为

BLEU = exp ⁡ ( min ⁡ ( 0 , 1 − 6 / 5 ) ) ⋅ p 1 1 / 2 ⋅ p 2 1 / 4 ⋅ p 3 1 / 8 = e − 0.2 ⋅ ( 4 5 ) 1 / 2 ⋅ ( 3 4 ) 1 / 4 ⋅ ( 1 3 ) 1 / 8 ≈ 0.5940 \begin{aligned} \text{BLEU}&=\exp(\min(0,1-6/5))\cdot p_1^{1/2}\cdot p_2^{1/4}\cdot p_3^{1/8} \\ &=e^{-0.2}\cdot \left(\frac45\right)^{1/2}\cdot \left(\frac34\right)^{1/4}\cdot\left(\frac13\right)^{1/8} \\ &\approx0.5940 \end{aligned} BLEU=exp(min(0,16/5))p11/2p21/4p31/8=e0.2(54)1/2(43)1/4(31)1/80.5940

2.2 BLEU 的探讨

根据 BLEU 的定义,当预测序列与标签序列完全相同时,BLEU 的值为 1 1 1。另一方面,由于 e x > 0 e^x>0 ex>0 p n ≥ 0 p_n\geq0 pn0,因此有

BLEU ∈ [ 0 , 1 ] \text{BLEU}\in[0,1] BLEU[0,1]

BLEU 的值越接近 1 1 1,则代表预测效果越好;BLEU 的值越接近 0 0 0,则代表预测效果越差。

此外,由于 n n n 元语法越长匹配难度越大, 所以 BLEU 为更长的 n n n 元语法的精确度分配更大的权重(固定 a ∈ ( 0 , 1 ) a\in(0,1) a(0,1),则 a 1 / 2 n a^{1/2^n} a1/2n 会随着 n n n 的增加而增加)。而且,由于预测序列越短获得的 p n p_n pn 值越高,所以系数 exp ⁡ ( ⋅ ) \exp(\cdot) exp() 这一项用于惩罚较短的预测序列。

2.3 BLEU 的简单实现

import math
from collections import Counter


def bleu(label, pred, k=4):
	assert len(pred) >= k
    # 我们假设输入的label和pred都已经进行了分词
    score = math.exp(min(0, 1 - len(label) / len(pred)))
    for n in range(1, k + 1):
        # 使用哈希表用来存放label中所有的n-gram
        hashtable = Counter([' '.join(label[i:i + n]) for i in range(len(label) - n + 1)])
        # 匹配成功的个数
        num_matches = 0
        for i in range(len(pred) - n + 1):
            ngram = ' '.join(pred[i:i + n])
            if ngram in hashtable and hashtable[ngram] > 0:
                num_matches += 1
                hashtable[ngram] -= 1
        score *= math.pow(num_matches / (len(pred) - n + 1), math.pow(0.5, n))
    return score

例如:

label = 'A B C D E F'
pred = 'A B B C D'
for i in range(4):
    print(bleu(label.split(), pred.split(), k=i + 1))
# 0.7322950476607851
# 0.6814773296495302
# 0.5940339360503315
# 0.0

References

[1] d2l. Sequence to Sequence Learning