深入理解LSTM:从结构到PyTorch实践
1. 引言:为什么需要LSTM?
循环神经网络(RNN)因其天然的时序结构,被广泛应用于自然语言处理、时间序列预测等任务。然而,传统RNN在处理长序列时容易遭遇梯度消失或梯度爆炸问题,导致模型难以捕捉远距离的语义依赖。例如,在“我出生在法国……我会说法语”中,“法语”依赖于远在前面的“法国”,传统RNN往往难以建立这种长距离关联。
为了解决这一问题,Sepp Hochreiter 和 Jürgen Schmidhuber 于1997年提出了长短期记忆网络(Long Short-Term Memory, LSTM)。LSTM通过精巧的门控机制和细胞状态,选择性地记忆或遗忘信息,从而有效缓解了长序列训练中的梯度消失问题。后来(2000年左右),Gers等人又引入了遗忘门,进一步完善了LSTM结构。
2. LSTM的核心思想
LSTM与传统RNN最大的区别在于:它引入了一条细胞状态(Cell State)的“传送带”,信息可以在时间步上几乎无损地流动。同时,LSTM使用三个门控单元(遗忘门、输入门、输出门)来控制信息的遗忘、写入和读出。
-
细胞状态 Ct:负责长期记忆,贯穿整个序列。
-
隐状态 ht:负责短期记忆,也是每个时间步的输出。
-
门:使用sigmoid函数输出0~1之间的值,表示信息“通过”的比例(0表示完全阻断,1表示完全通过)。
3. LSTM内部结构详解(含公式)
下图示意了单个LSTM单元的内部结构(图中省略了偏置项,但在实际实现中存在)。

3.1 遗忘门(Forget Gate)
遗忘门决定上一时刻的细胞状态 Ct−1 中有多少信息需要被丢弃。它读取当前输入 xt和上一时刻隐状态 ht−1,输出一个0~1的向量 ft。

-
σ 为sigmoid函数。
-
[ht−1,xt] 表示将两个向量拼接。
-
Wf 和 bf 为可学习参数。
直观理解:如果 ft中的某个分量接近0,则对应的历史信息将被遗忘;接近1则保留。
3.2 输入门(Input Gate)
输入门决定将多少新信息写入细胞状态。它由两部分组成:
-
门控部分 iti:决定哪些位置要更新。
-
候选细胞状态 C~t:利用tanh层生成新的候选值向量。

-
tanh 将输出值压缩到-1到1之间,起到调节作用。
3.3 细胞状态更新
旧细胞状态 Ct−1经过遗忘门进行选择性遗忘,再与输入门筛选后的候选状态相加,得到新的细胞状态 Ct。

-
+表示逐元素相乘(Hadamard积)。
意义:这一步完美融合了“忘记过去不重要的”和“记住当前新的重要信息”。
3.4 输出门(Output Gate)
输出门决定当前时刻的隐状态 htht(同时也是该时刻的输出)。它基于更新后的细胞状态 CtCt,并经过一个门控筛选。

-
先用tanh将 Ct 的值缩放至-1~1,再通过输出门 ot 决定哪些信息最终输出。
总结:LSTM通过上述四个步骤,实现了对长序列信息的选择性存储和读取。其中 遗忘门 和 输入门 配合完成细胞状态的更新,输出门 控制隐状态的表达。
4. PyTorch中的LSTM实现
PyTorch提供了便捷的 torch.nn.LSTM 模块,我们可以直接调用。
4.1 参数说明
nn.LSTM(input_size, hidden_size, num_layers=1, bias=True,
batch_first=False, dropout=0, bidirectional=False)
-
input_size:输入特征维度(例如词向量的长度)。 -
hidden_size:隐状态 htht 的维度。 -
num_layers:LSTM堆叠的层数(大于1时为多层LSTM)。 -
batch_first:若为True,输入形状为(batch, seq_len, input_size),否则为(seq_len, batch, input_size)。 -
注意:
bidirectional参数在这里应保持False(本文不涉及Bi-LSTM)。
4.2 输入与输出形状
-
输入:
-
input:形状(seq_len, batch, input_size) -
h0(可选):初始隐状态,形状(num_layers, batch, hidden_size) -
c0(可选):初始细胞状态,形状(num_layers, batch, hidden_size)
-
-
输出:
-
output:所有时间步的隐状态,形状(seq_len, batch, hidden_size) -
(hn, cn):最后一个时间步的隐状态和细胞状态,形状均为(num_layers, batch, hidden_size)
-
4.3 完整示例
# 定义LSTM的参数含义: (input_size, hidden_size, num_layers)
# 定义输入张量的参数含义: (sequence_length, batch_size, input_size)
# 定义隐藏层初始张量和细胞初始状态张量的参数含义:
# (num_layers * num_directions, batch_size, hidden_size)
>>> import torch.nn as nn
>>> import torch
>>> rnn = nn.LSTM(5, 6, 2)
>>> input = torch.randn(1, 3, 5)
>>> h0 = torch.randn(2, 3, 6)
>>> c0 = torch.randn(2, 3, 6)
>>> output, (hn, cn) = rnn(input, (h0, c0))
>>> output
tensor([[[ 0.0447, -0.0335, 0.1454, 0.0438, 0.0865, 0.0416],
[ 0.0105, 0.1923, 0.5507, -0.1742, 0.1569, -0.0548],
[-0.1186, 0.1835, -0.0022, -0.1388, -0.0877, -0.4007]]],
grad_fn=<StackBackward>)
>>> hn
tensor([[[ 0.4647, -0.2364, 0.0645, -0.3996, -0.0500, -0.0152],
[ 0.3852, 0.0704, 0.2103, -0.2524, 0.0243, 0.0477],
[ 0.2571, 0.0608, 0.2322, 0.1815, -0.0513, -0.0291]],
[[ 0.0447, -0.0335, 0.1454, 0.0438, 0.0865, 0.0416],
[ 0.0105, 0.1923, 0.5507, -0.1742, 0.1569, -0.0548],
[-0.1186, 0.1835, -0.0022, -0.1388, -0.0877, -0.4007]]],
grad_fn=<StackBackward>)
>>> cn
tensor([[[ 0.8083, -0.5500, 0.1009, -0.5806, -0.0668, -0.1161],
[ 0.7438, 0.0957, 0.5509, -0.7725, 0.0824, 0.0626],
[ 0.3131, 0.0920, 0.8359, 0.9187, -0.4826, -0.0717]],
[[ 0.1240, -0.0526, 0.3035, 0.1099, 0.5915, 0.0828],
[ 0.0203, 0.8367, 0.9832, -0.4454, 0.3917, -0.1983],
[-0.2976, 0.7764, -0.0074, -0.1965, -0.1343, -0.6683]]],
grad_fn=<StackBackward>)
在实际任务中(如情感分析),我们通常取 output[:, -1, :] 作为最后一个时间步的隐状态,再接入全连接层进行分类。
5. LSTM的优缺点
✅ 优势
-
长距离依赖建模能力强:相比传统RNN,LSTM通过门控机制有效缓解了梯度消失/爆炸,可以处理长达数百步的序列。
-
灵活性高:可以堆叠多层,也可以与其他网络(如CNN、Attention)结合。
-
工程成熟:各种深度学习框架均有高效实现,且有很多预训练变体。
❌ 缺点
-
计算复杂度高:每个时间步需要计算4个全连接层(遗忘门、输入门、输出门、候选状态),参数量约为传统RNN的4倍,训练和推理较慢。
-
难以并行:LSTM本质是递归结构,后一个时间步依赖前一步的输出,无法像Transformer那样进行大规模并行计算。
-
并非万能:在超长序列(数千步)上仍有信息衰减,且对随机打乱的序列不敏感。
6. 总结
LSTM是RNN家族中最经典、最成功的变体之一。它通过遗忘门、输入门、输出门和细胞状态实现了对长期记忆的精细控制,解决了原始RNN的梯度问题。虽然近年来Transformer等模型在多数NLP任务上取得了更好效果,但LSTM在时间序列预测、语音识别、小规模序列建模等场景中依然具有重要价值。掌握LSTM的内部原理和PyTorch实现,是深入理解序列模型的关键一步。
参考文献
Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory. Neural computation, 9(8), 1735-1780.
Gers, F. A., Schmidhuber, J., & Cummins, F. (2000). Learning to forget: Continual prediction with LSTM. Neural computation, 12(10), 2451-2471.
希望本文能帮助你彻底搞懂LSTM!如果有任何疑问,欢迎在评论区留言讨论。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)