《动手学深度学习》-66使用注意力机制seq2seq
一、Seq2Seq 模型概述
Seq2Seq 模型的核心思想是将一个输入序列(例如一段话)映射到一个输出序列(例如翻译后的文本)。其结构包括两个主要组件:
-
编码器 (Encoder):接受输入序列,并将其编码为一个固定长度的上下文向量(即隐藏状态)。
-
解码器 (Decoder):根据编码器的上下文向量生成输出序列。
在传统的 Seq2Seq 模型中,编码器将输入序列压缩成一个固定长度的上下文向量,然后解码器从这个向量生成输出序列。这种方式可能存在信息丢失的问题,尤其在处理较长的输入序列时。
二. 注意力机制(Attention Mechanism)
为了弥补传统 Seq2Seq 模型在处理长序列时的信息丢失问题,引入了注意力机制。注意力机制可以让解码器在生成每个输出时,动态地关注输入序列中的不同部分,从而有效捕捉输入的关键信息。
注意力机制的工作原理:
-
对于每个解码步骤,解码器不仅依赖于固定的上下文向量,而是通过计算输入序列各个位置的注意力权重,来决定当前输出时应关注的输入位置。
-
这些权重是根据输入序列和当前解码状态的相关性计算得出的。通过这种方式,模型能够在生成输出时,动态地选择性地关注输入序列中的关键部分。
Seq2Seq + Attention 结构
通常,结合了注意力机制的 Seq2Seq 模型包含以下步骤:
编码器:对输入序列进行编码,生成一个包含所有输入信息的隐藏状态序列(不仅是单一的上下文向量)。
注意力层:计算每个输入位置对当前解码位置的贡献(注意力权重),通过加权和的方式将这些隐藏状态结合成一个动态上下文向量。
解码器:使用编码器的上下文向量和当前的解码状态生成下一个输出。
模型结构
假设我们使用的是一个基于 LSTM 或 GRU 的 Seq2Seq 模型,注意力机制通常在解码器中实现,如下:
编码器(Encoder):
-
输入序列被逐步传入编码器,每个词经过词嵌入层,然后通过一个 RNN(如 LSTM 或 GRU)进行处理,产生每个时间步的隐藏状态。
注意力机制(Attention Mechanism):
-
对于每个时间步的解码器输入,计算与编码器所有隐藏状态的注意力权重,得到一个加权的上下文向量。
解码器(Decoder):
-
解码器根据当前的输入、先前的输出(通常通过 teacher forcing 技术)和加权的上下文向量,生成输出序列。
1. Q(Query) - 查询
-
Q 代表“查询(Query)”,是当前解码器状态(通常来自上一时间步的隐层状态)与输入序列进行匹配时的参考点。
-
它是当前解码器需要寻找相关信息的“请求”部分。
-
例如,在机器翻译中,如果解码器正在翻译一个单词,它会用该单词对应的隐层状态作为查询。
2. K(Key) - 键
-
K 代表“键(Key)”,它来自编码器的输出。
-
它是输入序列中各个时间步的表示。在注意力计算中,每个编码器的输出(通常是一个隐层状态)被认为是一个“键”。
-
键用于与查询进行匹配,以衡量输入序列每个部分与当前解码步骤的相关性。
-
通过计算查询和每个键之间的相似度,确定哪些部分的输入序列对当前输出是最重要的。
3. V(Value) - 值
-
V 代表“值(Value)”,它与键(K)一一对应,通常是与每个键关联的特征或信息。
-
它是实际的输出信息,当我们知道了当前查询与哪些键最相关后,最终我们要从这些值(V)中提取相关的信息。
-
在注意力机制中,计算完查询与键的相似度之后,通过加权的方式从
V中提取信息,这样我们得到一个上下文向量(context vector)。
三、Bahdanau 注意力代码
1.
class AttentionDecoder(test_60en_decorder.Decoder):#
def __init__(self,**kwargs):
super(AttentionDecoder,self).__init__(**kwargs)
@property
def attentionweights(self):
raise NotImplementedError
class Seq2seqAttentionDecoder(AttentionDecoder):
def __init__(self,vocab_size,embed_size,num_hiddens,num_layers,dropout=0,**kwargs):
super(Seq2seqAttentionDecoder,self).__init__(**kwargs)
self.attention=test_65attentionscore.AdditiveAttention(num_hiddens,num_hiddens,num_hiddens,dropout)#Q,K,注意力分数的维度,加性中计算查询、键和值之间的关系,而这些通常都是相同大小的隐层向量,通常与 RNN 或 GRU 的隐藏状态维度一致。
self.embedding = nn.Embedding(vocab_size, embed_size)
self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)
self.dense = nn.Linear(num_hiddens, vocab_size)
def init_state(self,enc_outputs,enc_valid_lens,*args):
#outputs形状(batch,steps,hiddens)
#hidden_state形状(layers,batch,hiddens)
outputs,hidden_state=enc_outputs
return (outputs.permute(1,0,2),hidden_state,enc_valid_lens)
def forward(self,X,state):
enc_outputs,hidden_state,enc_valid_lens = state
X=self.embedding(X).permute(1,0,2)
outputs,self.attention_weight=[],[]
for x in X:
query=torch.unsqueeze(hidden_state[-1],dim=1)#(batch_size, 1, num_hiddens)
context=self.attention(query,enc_outputs,enc_outputs,enc_valid_lens)
X=torch.cat((context,torch.unsqueeze(x,dim=1)),dim=1)#将上下文和当前输入连接,(batch_size, 1, embed_size + num_hiddens)
out,hidden_state=self.rnn(X.permute(1,0,2),hidden_state)
outputs.append(out)
self._attention_weight.append(self.attention.weight)
outputs=self.dense(torch.cat(outputs,dim=0))
return outputs.permute(1,0,2),[enc_outputs,hidden_state,enc_valid_lens]#返回 outputs 和解码器的状态
@property
def attentionweights(self):
return self._attention_weight
encoder=test_62seq2seq.Seq2seqEncoder(vocab_size=10,embed_size=8,num_hiddens=16,num_layers=2)
encoder.eval()
decoder=Seq2seqAttentionDecoder(vocab_size=10,embed_size=8,num_hiddens=16,num_layers=2)
decoder.eval()
X=torch.zeros((4,7),dtype=torch.long)
state=decoder.init_state(encoder(X),None)
ouputs,state=decoder(X,state)
print(ouputs.shape,len(state),state[0].shape)
embed_size,num_hiddens,num_layers,dropout=32,32,2,0.1
batch_size,num_steps=64,10
lr,num_epochs,device=0.005,250,d2l.try_gpu()
train_iter,scr_bocab,tgt_vocab=d2l.load_data_nmt(batch_size,num_steps)
encoder=test_62seq2seq.Seq2seqEncoder(len(scr_bocab),embed_size,num_hiddens,num_layers,dropout)
decoder=Seq2seqAttentionDecoder(len(tgt_vocab),embed_size,num_hiddens,num_layers,dropout)
net=test_60en_decorder.EncoderDecoder(encoder,decoder)
test_62seq2seq.train_seq2seq(net,train_iter,lr,num_epochs,tgt_vocab,device)
engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
translation, dec_attention_weight_seq = d2l.predict_seq2seq(
net, eng, scr_bocab, tgt_vocab, num_steps, device, True)
print(f'{eng} => {translation}, ',
f'bleu {d2l.bleu(translation, fra, k=2):.3f}')
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)