Vision Transformer(ViT)简介理解
参考:https://gitee.com/mindspore/vision/blob/master/examples/classification/vit/vit.ipynb
模型特点
ViT模型是应用于图像分类领域。因此,其模型结构相较于传统的Transformer有以下几个特点:
数据集的原图像被划分为多个patch后,将二维patch(不考虑channel)转换为一维向量,再加上类别向量与位置向量作为模型输入。
模型主体的Block基于Transformer的Encoder部分,但是调整了normaliztion的位置,其中,最主要的结构依然是Multi-head Attention结构。
模型在Blocks堆叠后接全连接层接受类别向量输出用于分类。通常情况下,我们将最后的全连接层称为Head,Transformer Encoder部分为backbone。
ViT模型的输入
传统的Transformer结构主要用于处理自然语言领域的词向量(Word Embedding or Word Vector),词向量与传统图像数据的主要区别在于,词向量通常是1维向量进行堆叠,而图片则是二维矩阵的堆叠,多头注意力机制在处理1维词向量的堆叠时会提取词向量之间的联系也就是上下文语义,这使得Transformer在自然语言处理领域非常好用,而2维图片矩阵如何与1维词向量进行转化就成为了Transformer进军图像处理领域的一个小门槛。
在ViT模型中:
通过将输入图像在每个channel上划分为16*16个patch,这一步是通过卷积操作来完成的,当然也可以人工进行划分,但卷积操作也可以达到目的同时还可以进行一次而外的数据处理;例如一幅输入224 x 224的图像,首先经过卷积处理得到16 x 16个patch,那么每一个patch的大小就是14 x 14。
再将每一个patch的矩阵拉伸成为一个1维向量,从而获得了近似词向量堆叠的效果。上一步得道的14 x 14的patch就转换为长度为196的向量。
由论文中的模型结构可以得知,输入图像在划分为patch之后,会经过pos_embedding 和 class_embedding两个过程。
1、class_embedding主要借鉴了BERT模型的用于文本分类时的思想,在每一个word vector之前增加一个类别值,通常是加在向量的第一位,上一步得到的196维的向量加上class_embedding后变为197维。
2、增加的class_embedding是一个可以学习的参数,经过网络的不断训练,最终以输出向量的第一个维度的输出来决定最后的输出类别;由于输入是16 x 16个patch,所以输出进行分类时是取 16 x 16个class_embedding进行分类。
3、pos_embedding也是一组可以学习的参数,会被加入到经过处理的patch矩阵中。
由于pos_embedding也是可以学习的参数,所以它的加入类似于全链接网络和卷积的bias。这一步就是创造一个长度维197的可训练向量加入到经过class_embedding的向量中。
从论文中可以得到,pos_embedding总共有4中方案。但是经过作者的论证,只有加上pos_embedding和不加pos_embedding有明显影响,至于pos_embedding是1维还是2维对分类结果影响不大,所以,在我们的代码中,也是采用了1维的pos_embedding,由于class_embedding是加在pos_embedding之前,所以pos_embedding的维度会比patch拉伸后的维度加1。
总的而言,ViT模型还是利用了Transformer模型在处理上下文语义时的优势,将图像转换为一种“变种词向量”然后进行处理,而这样转换的意义在于,多个patch之间本身具有空间联系,这类似于一种“空间语义”,从而获得了比较好的处理效果。
相关文章
- [Web] 深入理解现代浏览器
- 深入理解CSS3 Animation 帧动画
- 深入理解重建索引
- 【转载】理解 Linux 的处理器负载均值
- 理解 Python 中s可变参数的 *args 和 **kwargs
- LeetCode-779. 第K个语法符号【递归,绝对好理解】
- Kubernets 通过例子理解 k8s 架构
- Atitit 人工智能体系树培训列表应用较为广泛的技术.docx Atitit 人工智能体系培训列表 目录 1. 1.NLP自然语言处理文本处理2 1.1. 语言理解 分词2 1.2. 抽取
- titit. 深入理解 内聚( Cohesion)原理and attilax大总结
- DL之DNN优化技术:神经网络算法简介之GD/SGD算法(BP的梯度下降算法)的简介、理解、代码实现、SGD缺点及改进(Momentum/NAG/Ada系列/RMSProp)之详细攻略
- XAI/ML:机器学习模型可解释性之explainability和interpretability区别的简介、区别解读、案例理解之详细攻略
- BigData之Storm:Apache Storm的简介、深入理解、下载、案例应用之详细攻略
- High&NewTech:元宇宙(metaverse)的简介(多角度理解与探讨)、发展历史、现状与未来
- DL之CNN:卷积神经网络算法简介之原理简介(步幅/填充/特征图)、七大层级结构(动态图详解卷积/池化+方块法理解卷积运算)、CNN各层作用及其可视化等之详细攻略
- DL之DNN优化技术:神经网络算法简介之GD/SGD算法(BP的梯度下降算法)的简介、理解、代码实现、SGD缺点及改进(Momentum/NAG/Ada系列/RMSProp)之详细攻略
- MDX 中 drilldownmember函数的理解和用法
- XAI/ML:机器学习模型可解释性之explainability和interpretability区别的简介、区别解读、案例理解之详细攻略
- 谈谈对ThreadLocal的理解?(基于jdk1.8)
- 深入理解css3中的flex-grow、flex-shrink、flex-basis