相关项目下载链接

基于 Transformer 的自回归图像生成模型完整的链路是:1、先用 Patch AutoEncoder + BSQ 量化器,把原始图像压缩为离散的 token 序列(每个整数 token 对应原图的一个小图像 patch)2、训练这个自回归 Transformer 模型,学习 token 之间的空间共现规律;3、通过generate方法生成全新的token序列;4、用 BSQ 量化器把 token 序列解码回可保存的 png 图片。
本节内容主要介绍如何通过generate方法生成全新的 token 序列。

定义主模型

主模型对应的代码在autoregressive.py,在上一节中我们并没有定义generation方法的具体实现,本节对其逻辑进行补全。为了兼容补全后的generation方法,还需要对前向传播算法进行维度匹配调整。

补全generation方法

 @torch.no_grad()
    def generate(self, B: int = 1, h: int = 20, w: int = 30, device=None) -> torch.Tensor:
        if device is None:
            device = self.embedding.weight.device

        gen_seq = torch.zeros((B, h, w), dtype=torch.long, device=device)
        total_len = h * w

        for k in range(total_len):
            # 把 1D 索引 k 转回 2D 坐标 (i,j)
            i = k // w  # 行号
            j = k % w   # 列号

            logits, _ = self.forward(gen_seq)

            next_token_logits = logits[:, i, j, :] / 0.9
            
            next_token = torch.multinomial(
                F.softmax(next_token_logits, dim=-1), 
                num_samples=1
            ).squeeze(1)

            gen_seq[:, i, j] = next_token

        return gen_seq

调整前向传播算法

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
    	# 对训练和推理进行维度匹配
        if x.dim() == 4:
            x = x.squeeze(1)
        B, h, w = x.shape
        L = h * w

        # 展平成序列
        x_flat = x.reshape(B, L)

        # 嵌入 + 位置编码
        token_emb = self.embedding(x_flat)
        pos_idx = torch.arange(L, device=x.device)
        pos_emb = self.pos_emb(pos_idx)
        x_emb = token_emb + pos_emb

        # 自回归右移(关键)
        x_emb = F.pad(x_emb, (0,0,1,0))[:, :-1]

        # 因果掩码
        mask = self._generate_causal_mask(L, x.device)
        trans_out = self.transformer(x_emb, mask=mask)

        # 输出
        logits = self.fc_out(trans_out)
        logits_2d = logits.reshape(B, h, w, self.n_tokens)

        return logits_2d, {}

模块测评

下面进行图像生成的功能测试:

mkdir test
python -m homework.generation checkpoints/BSQPatchAutoEncoder.pth checkpoints/AutoregressiveModel.pth 8 test 

所得的解码后的PNG图片如下所示:

将代码打包为压缩文件

python bundle.py homework 20260412

进行评分自测:

python -m grader 20260412.zip

最终测试得分如下:
在这里插入图片描述

可选的优化方向

更优的量化器(更小的图像块、更高的码率)

  1. 缩小 patch 尺寸:你当前patch_size=5,可改为 3或2。更小的图像块意味着更细的图像粒度,大幅减少单 patch 的信息损失,生成的图像细节更丰富、块效应更少。
  2. 提升码本码率:你当前codebook_bits=10(仅 1024 个码本),可提升到 12或14。码本容量越大,量化的精度越高,单个 token 能表达的图像信息越丰富,生成的画面连贯性更强。
  3. 辅助优化:提升 Patch AutoEncoder 的重建能力(比如增加卷积层、调整 latent_dim),降低量化器的基础重建 MSE,从根源上提升 token 的质量。

更大的 Transformer 模型参数量

  1. 增加 Transformer 深度:把 Encoder 层数从 2 层提升到 4/6 层,更深的网络能拟合更复杂的 token 序列分布。
  2. 提升隐层维度:把d_latent从 128 提升到 256/512(注意nhead必须能整除d_latent),更高的维度能承载更丰富的图像语义信息。

更优的训练策略

  1. 增加训练轮次:可提升到 10/20/50 轮,配合学习率衰减策略,让模型充分学习 patch 的空间分布和长距离依赖关系。
  2. 优化学习率策略:在 AdamW 优化器中加入「warmup 预热 + 余弦退火衰减」,避免训练初期梯度爆炸,同时让模型在训练后期更精细地拟合分布,大幅提升生成效果。
Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐