transform是2017年提出来的,当时横扫NLP领域的多个任务,Vaswani et al. Attention Is All You Need. In NIPS,2017.

transform模型是Seq2Seq模型

transform不是RNN

transform是基于attention机制和全连接层的

这里通过最初的基于RNN的Seq2Seq模型,到基于RNN+attention的Seq2Seq模型,在到把RNN去掉全部基于attention的Seq2Seq模型,然后transform模型的思路进行讲解,这里前面几个都是前面讲过的,这里只是简单介绍即可

基于RNN的Seq2Seq模型

先介绍RNN是怎么更新状态的

基于RNN的Seq2Seq模型

这里简单的说一下,为什么在编码器中只需要使用最后一个状态的输出,为什么不使用中间的状态,因为最后一个状态包含了输入的所有信息,中间的状态,包含的信息不全

接着:

 

基于RNN+attention的Seq2Seq模型

可以发现基于attention的Seq2Seq的模型增加了context状态,通过context状态在计算输出状态h如下:

因此基于attention的RNN解决了长序列依赖问题,因为每个状态都包含了输入的整个信息。

计算context方法

上面采用的计算context方法并不是很常用,常用的方法是如下:

即引入三个矩阵和三个状态,其实本质上和上面的目的是一样的,只是处理方式不同,增加更多的参数,使模型表达能力更强。

他们的区别是什么?其实大家可以这样想,在获取解码器的输出状态时,使用RNN对seq2seq进行建模时,会发现,输入和上一个状态的S进行concat,然后乘上一个矩阵A,那么我们是否可以拆开呢?单独对输入创建一个参数矩阵和单独对状态创建一个矩阵,当然可以啊,两个矩阵的目的都是一样的,增加参数矩阵的表达能力,为了引入注意力机制,就需要计算相关性 ,那么也可以引入一个矩阵即可,因此,本质上transform就是通过计算相关性,然后把输出状态和每个输入的状态都关联起来,至于关联度其实就是相关性了,如何计算相关性,方法很多,而目前比较受欢迎的方法就是transform的这个方法,即设置参数矩阵Q、K、V:

Q矩阵是query就是当前的状态需要和编码器的状态计算相识度,因此需要解码器当前状态S乘上一个参数矩阵Q,表示该矩阵需要询问编码器所有的状态

K矩阵是key的意思,因为解码输出状态Q需要和编码器输出状态计算相关性,那么编码器的输出就是key了,因为每个编码器的输出都需要和Q相乘,因此称为key

V值其实很简单,在Q和K计算输出后是以矩阵,而相关性系数是一个值,因此需要另一个矩阵把Q和K的矩阵值在乘上一个矩阵V使其结果为一个相关性向量,这样就可以获取完整的α了,下面介绍基于RNN和attention的seq2seq模型

从上可以发现key和value的输出向量都是编码器的输出h状态,还是所有的状态,而q的输入是解码器当前的输出S状态,其中S0的状态为h_m,C0初始化为0,把S0和C0进行concat在通过激活函数就可以得到S1了,关键是如何通过attention计算法C1,其实很简单,此时S1和编码器的h_i都是知道的,那么分别创建W_Q矩阵和W_K矩阵和W_V矩阵,q矩阵是解码器当前输出状态S的参数矩阵,K是编码器的输出状态h的参数矩阵,V是为了计算相识度α的矩阵,该输入也是编码器的输出状态。

通过K和q的矩阵相乘,在通过softmax获取相识度α权重,然后计算当前输出状态的上下文C了

基于Attention的Seq2Seq模型

上面的是全部基于attention实现的seq2seq模型,从上可以发现,K和V的输入都是基于编码器的输入x_i,q的输入是解码器的上一个输出x^',其他的很基于rnn的attention类似,这里不过多解释,下面介绍如何组建成transform模型

Attention Layer

把上面模块化形成attention层

Self-attention层

Self-attention层和attention类似,只是这里的输入全为x_i,同样没有RNN网络,只有attention,其中Q、K、V的输入都是x,通过类似的方式及时α值,具体如下:

计算α值:

计算C值:

Self-attention层:

简化:

GitHub 加速计划 / vi / vision
15.85 K
6.89 K
下载
pytorch/vision: 一个基于 PyTorch 的计算机视觉库,提供了各种计算机视觉算法和工具,适合用于实现计算机视觉应用程序。
最近提交(Master分支:2 个月前 )
868a3b42 5 天前
e9a32135 14 天前
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐