zl程序教程

您现在的位置是:首页 >  工具

当前栏目

tensor2tensor-transformer源码学习

源码学习 transformer
2023-09-14 09:11:22 时间

1.多头注意力

多头注意力,我们可以看到源码中是进行了切割,从return的shape可以看出来。

2.transformer编码部分

可以看到它的输入就是经过emb和位置编码求和之后的输入。下面是正式使用到的编码函数:

上面的编码函数中,主要调用还是多头注意力这个函数:

调用的语句:

//注意,这里每次调用的时候第二个参数,也就是memory都是None,也就是query=momery。

 可以看到下面的query_antecedent就是经过预处理之后的输入,memory一开始是为None的。

 进入上面的函数后,因为一开始的时候memory是None,那么就调用计算qkv的函数:

 首先是对Q的计算:

在compute_attention_component函数中,看起来这个过程也非常地简单,就是之前输入的变换*一个var(服从正态分布的随机取样的矩阵),Q=pre_process(input)*var

计算KV也是调用同样的函数,但是所用的ante不同,kv需要的是memory,但是此时因为memory是None,

 compute一开始将query赋值给了memory:

 然后把qkv切成了8个部分进行之后的

下面进行attention操作:

具体的公式操作的部分标注出来:

 上面计算完attention之后,又有了一个o:

但是我不太明白这个o是干嘛用的,也许它只是用来做一个变换。

在transformer_layers.py文件中,在调用了common_attention.multihead_attention:

 可以看到返回y之后,然后进行了后处理得到x,之后又进行了全连接层,之后又后处理,然后有一个for循环,共有几层,encoder应该是6层,那么就是6次循环了。这样就获取到了encoder的输出:

 之后就返回到了这里?encoder输出结果。

3.解码部分

transformer类的body函数就主要是进行调用encode和decode的,可以看到它的输入features的要求:

 

 

应该包含inputs和targets和id。 下面decoder,也就是targets(这个是针对机器翻译吗?)首先要对它预处理

预处理中,首先是对输出的输入进行右移的操作:

 

 

 也就是在每一个seq上面都加上一行0,表示右移一位,而且第二维上还进行了[:-1]表示遗弃了最后一个word,以保证固定的seq_len长度?

 可以看到在transformer_self_attention_layer函数中,包括了self与encode的attention: 

 

 

 首先是self的部分,和之前的encode调用方式是一样的:

 

但实际上内部应该是进行的不一样的,因为这个应该是有个mask的吧?需要把后面的单词挡住,看到下面的地方我明白了:

common_attention.py文件中,multihead_attention函数内,有不同的attention方法:

 对于encoder的部分:

 

 

如果说第二个参数memory不是None的话,这样的话在计算KV的时候就可以使用encoder的输出了。

 

 

 此时Q是由解码的输入确定的,KV是由encoder的输出确定的。