1. 引言:为什么需要LSTM?

循环神经网络(RNN)因其天然的时序结构,被广泛应用于自然语言处理、时间序列预测等任务。然而,传统RNN在处理长序列时容易遭遇梯度消失梯度爆炸问题,导致模型难以捕捉远距离的语义依赖。例如,在“我出生在法国……我会说法语”中,“法语”依赖于远在前面的“法国”,传统RNN往往难以建立这种长距离关联。

为了解决这一问题,Sepp Hochreiter 和 Jürgen Schmidhuber 于1997年提出了长短期记忆网络(Long Short-Term Memory, LSTM)。LSTM通过精巧的门控机制细胞状态,选择性地记忆或遗忘信息,从而有效缓解了长序列训练中的梯度消失问题。后来(2000年左右),Gers等人又引入了遗忘门,进一步完善了LSTM结构。


2. LSTM的核心思想

LSTM与传统RNN最大的区别在于:它引入了一条细胞状态(Cell State)的“传送带”,信息可以在时间步上几乎无损地流动。同时,LSTM使用三个门控单元(遗忘门、输入门、输出门)来控制信息的遗忘写入读出

  • 细胞状态 Ct:负责长期记忆,贯穿整个序列。

  • 隐状态 ht:负责短期记忆,也是每个时间步的输出。

  • :使用sigmoid函数输出0~1之间的值,表示信息“通过”的比例(0表示完全阻断,1表示完全通过)。


3. LSTM内部结构详解(含公式)

下图示意了单个LSTM单元的内部结构(图中省略了偏置项,但在实际实现中存在)。

3.1 遗忘门(Forget Gate)

遗忘门决定上一时刻的细胞状态 Ct−1 中有多少信息需要被丢弃。它读取当前输入 xt和上一时刻隐状态 ht−1,输出一个0~1的向量 ft​。

  • σ 为sigmoid函数。

  • [ht−1​,xt​] 表示将两个向量拼接。

  • Wf​ 和 bf​ 为可学习参数。

直观理解:如果 ft中的某个分量接近0,则对应的历史信息将被遗忘;接近1则保留。

3.2 输入门(Input Gate)

输入门决定将多少新信息写入细胞状态。它由两部分组成:

  • 门控部分 iti:决定哪些位置要更新。

  • 候选细胞状态 C~t:利用tanh层生成新的候选值向量。

  • tanh 将输出值压缩到-1到1之间,起到调节作用。

3.3 细胞状态更新

旧细胞状态 Ct−1经过遗忘门进行选择性遗忘,再与输入门筛选后的候选状态相加,得到新的细胞状态 Ct。

  • +表示逐元素相乘(Hadamard积)。

意义:这一步完美融合了“忘记过去不重要的”和“记住当前新的重要信息”。

3.4 输出门(Output Gate)

输出门决定当前时刻的隐状态 htht​(同时也是该时刻的输出)。它基于更新后的细胞状态 CtCt​,并经过一个门控筛选。

  • 先用tanh将 Ct 的值缩放至-1~1,再通过输出门 ot​ 决定哪些信息最终输出。

总结:LSTM通过上述四个步骤,实现了对长序列信息的选择性存储和读取。其中 遗忘门 和 输入门 配合完成细胞状态的更新,输出门 控制隐状态的表达。


4. PyTorch中的LSTM实现

PyTorch提供了便捷的 torch.nn.LSTM 模块,我们可以直接调用。

4.1 参数说明

nn.LSTM(input_size, hidden_size, num_layers=1, bias=True, 
        batch_first=False, dropout=0, bidirectional=False)
  • input_size:输入特征维度(例如词向量的长度)。

  • hidden_size:隐状态 htht​ 的维度。

  • num_layers:LSTM堆叠的层数(大于1时为多层LSTM)。

  • batch_first:若为 True,输入形状为 (batch, seq_len, input_size),否则为 (seq_len, batch, input_size)

  • 注意bidirectional 参数在这里应保持 False(本文不涉及Bi-LSTM)。

4.2 输入与输出形状

  • 输入

    • input:形状 (seq_len, batch, input_size)

    • h0(可选):初始隐状态,形状 (num_layers, batch, hidden_size)

    • c0(可选):初始细胞状态,形状 (num_layers, batch, hidden_size)

  • 输出

    • output:所有时间步的隐状态,形状 (seq_len, batch, hidden_size)

    • (hn, cn):最后一个时间步的隐状态和细胞状态,形状均为 (num_layers, batch, hidden_size)

4.3 完整示例

# 定义LSTM的参数含义: (input_size, hidden_size, num_layers)
# 定义输入张量的参数含义: (sequence_length, batch_size, input_size)
# 定义隐藏层初始张量和细胞初始状态张量的参数含义:
# (num_layers * num_directions, batch_size, hidden_size)

>>> import torch.nn as nn
>>> import torch
>>> rnn = nn.LSTM(5, 6, 2)
>>> input = torch.randn(1, 3, 5)
>>> h0 = torch.randn(2, 3, 6)
>>> c0 = torch.randn(2, 3, 6)
>>> output, (hn, cn) = rnn(input, (h0, c0))
>>> output
tensor([[[ 0.0447, -0.0335,  0.1454,  0.0438,  0.0865,  0.0416],
         [ 0.0105,  0.1923,  0.5507, -0.1742,  0.1569, -0.0548],
         [-0.1186,  0.1835, -0.0022, -0.1388, -0.0877, -0.4007]]],
       grad_fn=<StackBackward>)
>>> hn
tensor([[[ 0.4647, -0.2364,  0.0645, -0.3996, -0.0500, -0.0152],
         [ 0.3852,  0.0704,  0.2103, -0.2524,  0.0243,  0.0477],
         [ 0.2571,  0.0608,  0.2322,  0.1815, -0.0513, -0.0291]],

        [[ 0.0447, -0.0335,  0.1454,  0.0438,  0.0865,  0.0416],
         [ 0.0105,  0.1923,  0.5507, -0.1742,  0.1569, -0.0548],
         [-0.1186,  0.1835, -0.0022, -0.1388, -0.0877, -0.4007]]],
       grad_fn=<StackBackward>)
>>> cn
tensor([[[ 0.8083, -0.5500,  0.1009, -0.5806, -0.0668, -0.1161],
         [ 0.7438,  0.0957,  0.5509, -0.7725,  0.0824,  0.0626],
         [ 0.3131,  0.0920,  0.8359,  0.9187, -0.4826, -0.0717]],

        [[ 0.1240, -0.0526,  0.3035,  0.1099,  0.5915,  0.0828],
         [ 0.0203,  0.8367,  0.9832, -0.4454,  0.3917, -0.1983],
         [-0.2976,  0.7764, -0.0074, -0.1965, -0.1343, -0.6683]]],
       grad_fn=<StackBackward>)

在实际任务中(如情感分析),我们通常取 output[:, -1, :] 作为最后一个时间步的隐状态,再接入全连接层进行分类。


5. LSTM的优缺点

✅ 优势

  1. 长距离依赖建模能力强:相比传统RNN,LSTM通过门控机制有效缓解了梯度消失/爆炸,可以处理长达数百步的序列。

  2. 灵活性高:可以堆叠多层,也可以与其他网络(如CNN、Attention)结合。

  3. 工程成熟:各种深度学习框架均有高效实现,且有很多预训练变体。

❌ 缺点

  1. 计算复杂度高:每个时间步需要计算4个全连接层(遗忘门、输入门、输出门、候选状态),参数量约为传统RNN的4倍,训练和推理较慢。

  2. 难以并行:LSTM本质是递归结构,后一个时间步依赖前一步的输出,无法像Transformer那样进行大规模并行计算。

  3. 并非万能:在超长序列(数千步)上仍有信息衰减,且对随机打乱的序列不敏感。


6. 总结

LSTM是RNN家族中最经典、最成功的变体之一。它通过遗忘门、输入门、输出门细胞状态实现了对长期记忆的精细控制,解决了原始RNN的梯度问题。虽然近年来Transformer等模型在多数NLP任务上取得了更好效果,但LSTM在时间序列预测、语音识别、小规模序列建模等场景中依然具有重要价值。掌握LSTM的内部原理和PyTorch实现,是深入理解序列模型的关键一步。

参考文献

  • Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory. Neural computation, 9(8), 1735-1780.

  • Gers, F. A., Schmidhuber, J., & Cummins, F. (2000). Learning to forget: Continual prediction with LSTM. Neural computation, 12(10), 2451-2471.

希望本文能帮助你彻底搞懂LSTM!如果有任何疑问,欢迎在评论区留言讨论。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐