从零手写RNN诗歌生成器:逐行代码解释+数学案例彻底搞懂

很多人学RNN的时候,总觉得代码里有很多"魔法操作":为什么要转置维度?为什么损失要乘以batch size?嵌入层到底在做什么数学运算?

这篇博客我会带你逐行拆解上面这段PyTorch古诗生成代码,所有不好理解的地方都配上具体的数学案例,保证你看完不仅能跑通代码,还能彻底明白每一行为什么这么写。

一、项目概览

我们要实现的是一个基于RNN的语言模型(RNNLM),它能学习古诗的用字规律,然后根据给定的开头字自动生成七言绝句。

完整代码运行后,你会得到类似这样的输出:

一川烟草平如剪,两岸杨花雪作团。
燕子不来春又晚,东风吹泪湿阑干。

二、数据预处理:把古诗变成模型能懂的数字

2.1 数据集合

兰叶春葳蕤,桂华秋皎洁。欣欣此生意,自尔为佳节。谁知林栖者,闻风坐相悦。草木有本心,何求美人折?
江南有丹桔,经冬犹绿林。岂伊地气暖,自有岁寒心。可以荐佳客,奈何阻重深。运命唯所遇,循环不可寻。徒言树桃李,此木岂无陰。
暮从碧山下,山月随人归。却顾所来径,苍苍横翠微。相携及田家,童稚开荆扉。绿竹入幽径,青萝拂行衣。欢言得所憩,美酒聊共挥。长歌吟松风,曲尽河星稀。我醉君复乐,陶然共忘机。
花间一壶酒,独酌无相亲。举杯邀明月,对影成三人。月既不解饮,影徒随我身。暂伴月将影,行乐须及春。我歌月徘徊,我舞影零乱。醒时同交欢,醉后各分散。永结无情游,相期邈云汉。
燕草如碧丝,秦桑低绿枝。当君怀归日,是妾断肠时。春风不相识,何事入罗帏?

2.2 核心思路

模型只能处理数字,不能直接处理汉字。所以我们需要做三件事:

  1. 清洗原始数据,去掉标点和空白
  2. 构建字表:给每个出现过的汉字分配一个唯一的ID
  3. 把所有古诗转换成ID序列

2.3 逐行代码+数学案例

def preprocess_poems(file_path):
    char_set = set()
    poems = []
    with open(file_path, 'r', encoding="utf-8") as f:
        for line in f:
            # 去掉所有标点和两侧空白
            line = re.sub(r"[,。?!、:]", "", line).strip()
            # 把诗句拆成单个字,加入字集合去重
            char_set.update(list(line))
            poems.append(list(line))
    
    # 构建字表:ID到字,字到ID
    id2word = list(char_set) + ["<UNK>"]  # <UNK>表示未见过的字
    word2id = { word:id for id, word in enumerate(id2word) }
    
    # 把所有诗句转换成ID序列
    id_seqs = []
    for poem in poems:
        id_seq = [ word2id.get(word) for word in poem ]
        id_seqs.append(id_seq)
    
    return id_seqs, id2word, word2id

数学案例
假设我们只有两句诗:“床前明月光"和"疑是地上霜”

  • 清洗后:["床前明月光", "疑是地上霜"]
  • 字集合:{"床","前","明","月","光","疑","是","地","上","霜"}
  • 字表:
    id2word = ["床","前","明","月","光","疑","是","地","上","霜","<UNK>"]
    word2id = {"床":0, "前":1, "明":2, "月":3, "光":4, "疑":5, "是":6, "地":7, "上":8, "霜":9, "<UNK>":10}
    
  • ID序列:
    id_seqs = [
        [0,1,2,3,4],  # 床前明月光
        [5,6,7,8,9]   # 疑是地上霜
    ]
    

三、构建数据集:用滑动窗口生成训练样本

3.1 核心思路

RNN语言模型的任务是:给定前N个字,预测下一个字。所以我们需要用滑动窗口把长序列切成很多个(x,y)样本对:

  • x:输入序列(前N个字)
  • y:目标序列(后N个字,也就是x每个字的下一个字)

3.2 逐行代码+数学案例

class PoetryDataset(Dataset):
    def __init__(self, id_seqs, seq_len):
        self.seq_len = seq_len
        self.data = []
        for id_seq in id_seqs:
            # 滑动窗口遍历整个序列
            for i in range(0, len(id_seq) - self.seq_len):
                # x是从i开始的seq_len个字
                # y是从i+1开始的seq_len个字
                self.data.append( (id_seq[i:i+seq_len], id_seq[i+1:i+1+seq_len]) )
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        x = torch.LongTensor(self.data[idx][0])
        y = torch.LongTensor(self.data[idx][1])
        return x, y

数学案例
还是用刚才的两句诗,假设seq_len=3

  • 对于第一句[0,1,2,3,4](长度5),可以生成5-3=2个样本:
    1. i=0: x=[0,1,2], y=[1,2,3] → 输入"床前明",预测"前明月"
    2. i=1: x=[1,2,3], y=[2,3,4] → 输入"前明月",预测"明月光"
  • 对于第二句[5,6,7,8,9],同样生成2个样本
  • 最终数据集共有4个样本

四、RNN模型结构:从字向量到预测概率

4.1 模型整体架构

我们的模型由三层组成:

  1. 嵌入层:把字ID转换成固定长度的向量
  2. RNN层:处理序列信息,捕捉上下文依赖
  3. 全连接层:把RNN的输出转换成每个字的概率
class PoetryRNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, num_layers=1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
    
    def forward(self, input, hx=None):
        embed = self.embed(input)
        output, hn = self.rnn(embed, hx)
        output = self.linear(output)
        return output, hn

4.2 嵌入层:最容易被误解的层

很多人以为嵌入层是"黑魔法",其实它就是一个简单的查表操作

数学定义
嵌入层是一个形状为(vocab_size, embedding_dim)的矩阵W。对于输入的字ID i,嵌入层的输出就是矩阵W的第i行。

数学案例
假设vocab_size=3embedding_dim=2,嵌入层矩阵W为:

W = [
    [0.1, 0.2],  # ID=0对应的向量
    [0.3, 0.4],  # ID=1对应的向量
    [0.5, 0.6]   # ID=2对应的向量
]

输入ID序列[0,1,2],嵌入层输出就是:

[
    [0.1, 0.2],
    [0.3, 0.4],
    [0.5, 0.6]
]

维度变化
输入:(batch_size, seq_len) → 比如(32,24)表示32个样本,每个样本24个字
输出:(batch_size, seq_len, embedding_dim) → 比如(32,24,256)表示每个字变成了256维的向量

4.3 RNN层:处理序列的核心

RNN的核心思想是:当前时刻的输出不仅取决于当前输入,还取决于上一时刻的隐藏状态

前向传播公式

h_t = tanh(W_ih * x_t + b_ih + W_hh * h_{t-1} + b_hh)

其中:

  • h_t:当前时刻的隐藏状态
  • x_t:当前时刻的输入(嵌入向量)
  • h_{t-1}:上一时刻的隐藏状态
  • W_ih, W_hh, b_ih, b_hh:可学习的参数

维度变化
输入:(batch_size, seq_len, embedding_dim) → (32,24,256)
输出:(batch_size, seq_len, hidden_size) → (32,24,512)
隐藏状态hn:(num_layers, batch_size, hidden_size) → (2,32,512)

4.4 全连接层:输出每个字的概率

全连接层把RNN输出的隐藏状态映射到词汇表大小,这样每个时间步都会输出一个长度为vocab_size的向量,代表每个字的得分(还不是概率,需要经过softmax转换)。

维度变化
输入:(batch_size, seq_len, hidden_size) → (32,24,512)
输出:(batch_size, seq_len, vocab_size) → (32,24,2439)

五、模型训练:最容易踩坑的部分

这部分是整个代码里最难理解的地方,我会把每个"魔法操作"都拆解开,配上数学案例。

5.1 损失函数:为什么要转置维度?

loss_value = loss(output.transpose(1,2), y)

这行代码是90%的初学者都会卡住的地方。要理解它,我们首先要搞清楚CrossEntropyLoss的输入要求。

CrossEntropyLoss的输入输出要求

PyTorch的CrossEntropyLoss要求:

  • 输入(预测值):形状为(N, C, d1, d2, ...),其中C是类别数
  • 目标(真实值):形状为(N, d1, d2, ...),每个元素是0到C-1之间的整数
我们的维度问题
  • 模型输出output:(batch_size, seq_len, vocab_size) → (32,24,2439) (注:vocab_size即是类别数)
  • 真实标签y:(batch_size, seq_len) → (32,24)

如果直接把outputy传给CrossEntropyLoss,会报错!因为:

  • 我们的C(类别数=vocab_size)在第3维
  • 而CrossEntropyLoss要求C在第2维

所以我们需要用transpose(1,2)把第2维和第3维交换:

output.transpose(1,2) → (batch_size, vocab_size, seq_len) → (32,2439,24)

现在维度就完全匹配了:

  • 输入:(32,2439,24) → N=32, C=2439, d1=24
  • 目标:(32,24) → N=32, d1=24

数学案例
假设batch_size=2seq_len=3vocab_size=4

  • output:
    [
      [[0.1,0.2,0.3,0.4], [0.5,0.6,0.7,0.8], [0.9,1.0,1.1,1.2]],  # 第一个样本的3个时间步
      [[1.3,1.4,1.5,1.6], [1.7,1.8,1.9,2.0], [2.1,2.2,2.3,2.4]]   # 第二个样本的3个时间步
    ]
    
  • y:
    [
      [1,2,3],  # 第一个样本的真实标签
      [0,1,2]   # 第二个样本的真实标签
    ]
    
  • transpose之后:
    [
      [[0.1,0.5,0.9], [0.2,0.6,1.0], [0.3,0.7,1.1], [0.4,0.8,1.2]],  # 第一个样本
      [[1.3,1.7,2.1], [1.4,1.8,2.2], [1.5,1.9,2.3], [1.6,2.0,2.4]]   # 第二个样本
    ]
    

现在CrossEntropyLoss会计算每个时间步的损失,然后求平均。

5.2 为什么损失要乘以x.shape[0]?

train_loss += loss_value.item() * x.shape[0]

这又是一个很多人搞不懂的地方。答案很简单:因为CrossEntropyLoss默认返回的是每个样本的平均损失

详细解释
  • loss_value.item():返回的是当前batch中每个样本的平均损失
  • x.shape[0]:当前batch的大小(batch_size)
  • 相乘之后得到的是当前batch的总损失

数学案例
假设我们有3个batch:

  1. batch1:size=32,平均损失=0.5 → 总损失=0.5×32=16
  2. batch2:size=32,平均损失=0.4 → 总损失=0.4×32=12.8
  3. batch3:size=36,平均损失=0.3 → 总损失=0.3×36=10.8

整个epoch的总损失=16+12.8+10.8=39.6
整个epoch的平均损失=39.6/(32+32+36)=39.6/100=0.396

如果不乘以batch_size,直接加平均损失:
总损失=0.5+0.4+0.3=1.2
平均损失=1.2/3=0.4 → 这是错误的!因为每个batch的大小不一样。

5.3 完整训练流程

def train(model, dataset, lr, epoch_num, batch_size, device):
    model.to(device)
    model.train()
    loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epoch_num):
        train_loss = 0
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        for batch_idx, (x, y) in enumerate(dataloader):
            x, y = x.to(device), y.to(device)
            
            # 前向传播
            output, _ = model(x)
            # 计算损失
            loss_value = loss(output.transpose(1,2), y)
            # 反向传播
            loss_value.backward()
            # 更新参数
            optimizer.step()
            # 梯度清零(必须!否则梯度会累加)
            optimizer.zero_grad()

            train_loss += loss_value.item() * x.shape[0]

        # 计算本轮平均损失
        this_loss = train_loss / len(dataset)
        print(f"Epoch:{epoch + 1:0>2} train loss: {this_loss:.4f}")

六、诗歌生成:让模型自己写诗

训练好模型之后,我们就可以用它来生成新诗了。生成的核心逻辑是:用前一个字预测下一个字,然后把预测出来的字作为下一次的输入,循环往复

def generate_poem(model, id2word, word2id, start_token, line_num=4, line_length=7):
    model.eval()
    poem = []
    current_rest_len = line_length
    
    # 把开头字转换成ID
    start_id = word2id.get(start_token, word2id["<UNK>"])
    if start_id != word2id["<UNK>"]:
        poem.append(start_token)
        current_rest_len -= 1
    
    input = torch.LongTensor([[start_id]]).to(device)
    
    with torch.no_grad():  # 生成时不需要计算梯度
        for i in range(line_num):
            for interpunction in [",","。\n"]:
                while current_rest_len > 0:
                    # 前向传播,得到下一个字的得分
                    output, _ = model(input)
                    # 转换成概率分布
                    prob = torch.softmax(output[0,0], dim=-1)
                    # 基于概率分布采样下一个字(比argmax更有多样性)
                    next_id = torch.multinomial(prob, num_samples=1)
                    # 把ID转换成字,加入结果
                    poem.append(id2word[next_id.item()])
                    # 更新输入为刚生成的字
                    input = next_id.unsqueeze(0)
                    current_rest_len -= 1
                
                poem.append(interpunction)
                current_rest_len = line_length

    return "".join(poem)

关键技巧:为什么用multinomial而不是argmax

  • argmax:总是选择概率最大的字 → 生成的诗歌很单调,重复很多字
  • multinomial:基于概率分布随机采样 → 生成的诗歌更有多样性和创造性

七、常见问题与优化方向

  1. 生成的诗歌不通顺

    • 增加训练轮数(建议50-100轮)
    • 用LSTM或GRU代替普通RNN(解决梯度消失问题)
    • 增加训练数据量
  2. 生成重复的字

    • 加入温度系数(temperature)控制采样的随机性
    • 使用beam search代替随机采样
  3. 训练损失下降很慢

    • 调整学习率(建议1e-3到1e-4之间)
    • 增加嵌入层维度和隐藏层大小

八、总结

通过这个项目,我们彻底搞懂了RNN语言模型的完整流程:

  1. 数据预处理:把文本转换成ID序列
  2. 数据集构建:用滑动窗口生成(x,y)样本对
  3. 模型结构:嵌入层→RNN层→全连接层
  4. 训练过程:重点理解CrossEntropyLoss的维度要求和损失计算
  5. 生成过程:逐字采样生成新文本
Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐