Seq2Seq-AI轮回基本原理复习2
9.7 Seq2Seq:从RNN到机器翻译完整实现
理论 → 数学公式 → 代码实现 → 张量维度变化 → 函数调用流程 → 为什么这样设计
1 为什么需要Seq2Seq?
在上一节Encoder-Decoder架构中,我们已经知道:
机器翻译本质上属于:
X=(x1,x2,⋯ ,xT)X=(x_1,x_2,\cdots,x_T)X=(x1,x2,⋯,xT)到Y=(y1,y2,⋯ ,yN)Y=(y_1,y_2,\cdots,y_N)Y=(y1,y2,⋯,yN)的映射问题。
例如:
English:
They are watching .
French:
Ils regardent .
这里:
- 输入长度 T 不固定
- 输出长度 N 不固定
- T ≠ N
因此:
传统神经网络:y=f(x)y=f(x)y=f(x)
无法直接处理。
2 Seq2Seq核心思想
Seq2Seq:
Sequence To Sequence
即:
输入序列
↓
Encoder
↓
Context
↓
Decoder
↓
输出序列
数学表示:P(Y∣X)P(Y|X)P(Y∣X)
展开:$ P(Y|X)=\prod_{t=1}^{N}P(y_t|y_{<t},X)]$
其中:y<t=(y1,y2,⋯ ,yt−1)y_{<t}=(y_1,y_2,\cdots,y_{t-1})y<t=(y1,y2,⋯,yt−1)表示已经生成的词。
3 Encoder设计
3.1 为什么需要Encoder
Decoder无法直接看到整个输入句子。
因此需要:
输入句子
↓
压缩
↓
语义向量
即:C=f(X)C=f(X)C=f(X)
3.2 Encoder结构
class Seq2SeqEncoder(d2l.Encoder):
def __init__(
self,
vocab_size,
embed_size,
num_hiddens,
num_layers,
dropout=0
):
super().__init__()
# Embedding层
self.embedding = nn.Embedding(
vocab_size,
embed_size
)
# GRU层
self.rnn = nn.GRU(
embed_size,
num_hiddens,
num_layers,
dropout=dropout
)
3.3 为什么先Embedding?
输入:
[5,8,12,20]
实际上是:
They
are
watching
.
数字没有语义。
Embedding负责:token→vectortoken\rightarrow vectortoken→vector
例如:5→[0.2,0.7,0.4]5\rightarrow[0.2,0.7,0.4]5→[0.2,0.7,0.4]
3.4 Encoder前向传播
def forward(self, X):
X = self.embedding(X)
X = X.permute(1,0,2)
output,state = self.rnn(X)
return output,state
3.5 为什么permute?
Embedding输出:(batch_size,num_steps,embed_size)即:
(64,10,32)
而GRU要求:(num_steps,batch_size,embed_size)即:
(10,64,32)
因此:
permute(1,0,2)
3.6 Encoder张量流动
输入
(64,10)
↓ Embedding
(64,10,32)
↓ Permute
(10,64,32)
↓ GRU
output:
(10,64,32)
state:
(2,64,32)
4 Decoder设计
4.1 为什么需要Decoder
Encoder只负责:
理解句子
Decoder负责:
生成句子
4.2 Decoder结构
class Seq2SeqDecoder(d2l.Decoder):
def __init__(
self,
vocab_size,
embed_size,
num_hiddens,
num_layers,
dropout=0
):
super().__init__()
self.embedding = nn.Embedding(
vocab_size,
embed_size
)
self.rnn = nn.GRU(
embed_size + num_hiddens,
num_hiddens,
num_layers,
dropout=dropout
)
self.dense = nn.Linear(
num_hiddens,
vocab_size
)
4.3 为什么GRU输入变大?
注意:
embed_size + num_hiddens
因为:Decoder每个时刻输入:[xt,C][x_t,C][xt,C]
拼接:32+32=6432+32=6432+32=64
4.4 Decoder初始化
def init_state(self, enc_outputs):
return enc_outputs[1]
为什么?
Encoder返回:
(output,state)
其中:
state
包含整个句子的语义。
因此:
Encoder状态
↓
Decoder初始状态
4.5 Decoder前向传播
def forward(self,X,state):
X=self.embedding(X)
X=X.permute(1,0,2)
context=state[-1].repeat(
X.shape[0],
1,
1
)
X_and_context=torch.cat(
(X,context),
2
)
output,state=self.rnn(
X_and_context,
state
)
output=self.dense(output)
return output,state
4.6 为什么拼接Context?
如果不拼接:
Decoder只能看到当前词
容易遗忘源句子。
拼接后:[xt,C][x_t,C][xt,C]
Decoder始终知道:
原句在表达什么
5 Loss设计
5.1 为什么需要Mask
句子长度不同:
I go .
长度:3
I am studying .
长度:4
需要补齐:
I go . <pad>
这些:
<pad>
没有意义。
必须忽略。
5.2 Sequence Mask
def sequence_mask(
X,
valid_len,
value=0
):
例如:
valid_len=[3]
生成:
[1,1,1,0,0]
得到:
loss * mask
6 Masked Softmax CE
class MaskedSoftmaxCELoss(
nn.CrossEntropyLoss
):
核心:
weights=sequence_mask(
weights,
valid_len
)
数学表示:
普通交叉熵:L=−logp(y)L=-\log p(y)L=−logp(y)
加入Mask:
L=mask×(−logp(y))L=mask\times(-\log p(y))L=mask×(−logp(y))
7 Teacher Forcing
为什么提出?
训练时:若使用预测值:
错误
↓
继续错误
↓
彻底崩溃
Teacher Forcing:
dec_input=torch.cat(
[bos,Y[:,:-1]],
1
)
即:真实标签:
Ils regardent .
作为Decoder输入。
8 训练流程
Y_hat,_=net(
X,
dec_input,
X_valid_len
)
l=loss(
Y_hat,
Y,
Y_valid_len
)
l.sum().backward()
optimizer.step()
训练公式:
$ L=-\sum_tlogP(y_t)$
梯度更新:
θ=θ−η∇θL\theta= \theta-\eta \nabla_\theta Lθ=θ−η∇θL
9 推理流程
训练:
真实标签输入
推理:
预测结果输入
代码:
for _ in range(num_steps):
Y,state=
decoder(
dec_X,
state
)
dec_X=
Y.argmax(dim=2)
流程:
<bos>
↓
Ils
↓
regardent
↓
.
↓
<eos>
停止。
10 整个函数调用关系
11 整个网络计算图
12 小结
Seq2Seq模型本质是在学习:P(Y∣X)P(Y|X)P(Y∣X)
其核心组成:
| 模块 | 作用 |
|---|---|
| Embedding | 词向量表示 |
| Encoder GRU | 理解输入句子 |
| Context | 保存语义 |
| Decoder GRU | 生成输出句子 |
| Teacher Forcing | 加速训练 |
| Mask Loss | 忽略PAD |
| BLEU | 评价翻译质量 |
整个模型完成了:
$
输入句子
\rightarrow
语义表示
\rightarrow
目标句子
$
这也是后续 Attention、Bahdanau Attention、Transformer 的直接起点。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐

所有评论(0)