论文信息

  • 标题:ResShift: Efficient Diffusion Model for Image Super-resolution by Residual Shifting
  • 会议:NeurIPS 2023
  • 单位:南洋理工大学 S-Lab
  • 代码:https://github.com/zsyOAOA/ResShift
  • 论文:https://arxiv.org/pdf/2307.12348

一、开篇:扩散超分的"速度噩梦"终于被打破了

你有没有过这样的体验:好不容易找到一个效果不错的扩散超分模型,结果跑一张256×256的图要等十几秒?传统扩散模型做超分辨率(SR)就像让一个画家从一张白纸开始画画,需要一笔一笔慢慢涂,动辄几百上千步才能完成。虽然最后画出来的效果不错,但这个速度实在让人着急。

更气人的是,如果你想加速,用DDIM、DPM-Solver这些快速采样算法把步数压缩到几十步,结果图像立刻变得模糊不清,细节全没了。这就是扩散超分领域长期存在的"速度-质量"死结:要么慢得要死,要么快得难看

直到ResShift的出现,这个死结终于被解开了。它提出了一种全新的扩散模型设计思路,只需要15步采样,就能达到甚至超过传统扩散模型1000步的效果,而且速度比LDM-100快了整整4倍!

二、ResShift核心思想:从"从零开始画"到"在草稿上添细节"

传统扩散模型的核心思路是:先把高清图像一步步加噪声变成高斯白噪声,然后再训练模型从高斯白噪声一步步去噪恢复出高清图像。这个思路对于图像生成来说很合理,但对于超分辨率来说却很浪费——我们明明已经有了低分辨率图像这个"草稿",为什么还要从零开始呢?

ResShift的核心洞察就是:超分辨率不需要从高斯噪声开始,只需要从低分辨率图像开始,逐步转移高分辨率和低分辨率之间的残差(也就是缺失的细节)就行了

通俗解释:

  • 传统扩散超分:给你一张白纸,让你画出一只猫 → 需要画几百笔
  • ResShift超分:给你一张已经画好猫轮廓的草稿,让你添上毛发、眼睛、胡须这些细节 → 只需要画十几笔
    在这里插入图片描述

图1 ResShift整体工作流(出处:原论文图2)

  • 左边是前向过程:从高清图像x0x_0x0开始,逐步添加残差和噪声,最终变成低分辨率图像y0y_0y0
  • 右边是反向过程:从低分辨率图像y0y_0y0开始,逐步去除残差和噪声,最终恢复出高清图像x0x_0x0
  • 整个过程只需要15步,比传统扩散模型快了两个数量级

三、ResShift关键技术详解

3.1 整体架构

ResShift构建了一个连接高清图像和低分辨率图像的马尔可夫链(通俗解释:一个状态只能由前一个状态决定的链条)。这个链条的初始状态是高清图像,最终状态是低分辨率图像。超分辨率任务就是沿着这个链条反向走一遍,从低分辨率图像恢复出高清图像。

为了进一步加速计算,ResShift还可以在VQGAN的隐空间中运行,把图像压缩4倍,计算量直接减少16倍,而且几乎不会损失效果。

3.2 前向过程:从高清到低清的残差转移

前向过程的目标是把高清图像x0x_0x0逐步变成低分辨率图像y0y_0y0。我们首先定义残差e0e_0e0
e0=y0−x0e_0 = y_0 - x_0e0=y0x0
通俗解释:残差就是低分辨率图像和高清图像之间的差异,也就是我们需要恢复的所有细节信息。

然后我们引入一个单调递增的转移序列{ηt}t=1T\{\eta_t\}_{t=1}^T{ηt}t=1T,满足η1→0\eta_1 \to 0η10ηT→1\eta_T \to 1ηT1。前向过程的转移分布定义为:
q(xt∣xt−1,y0)=N(xt;xt−1+αte0,κ2αtI)q(x_t | x_{t-1}, y_0) = \mathcal{N}(x_t; x_{t-1} + \alpha_t e_0, \kappa^2 \alpha_t I)q(xtxt1,y0)=N(xt;xt1+αte0,κ2αtI)

公式中每个符号的含义:

  • q(xt∣xt−1,y0)q(x_t | x_{t-1}, y_0)q(xtxt1,y0):给定前一个状态xt−1x_{t-1}xt1和低分辨率图像y0y_0y0,当前状态xtx_txt的条件概率分布
  • N\mathcal{N}N:正态分布(高斯分布)
  • xtx_txt:第t步的中间图像
  • xt−1x_{t-1}xt1:第t-1步的中间图像
  • αt\alpha_tαt:第t步的残差转移步长,αt=ηt−ηt−1\alpha_t = \eta_t - \eta_{t-1}αt=ηtηt1(t>1时),α1=η1\alpha_1 = \eta_1α1=η1
  • e0e_0e0:高分辨率图像x0x_0x0和低分辨率图像y0y_0y0之间的残差
  • κ\kappaκ:控制噪声强度的超参数
  • III:单位矩阵

通过数学推导,我们可以得到任意时刻t的边际分布:
q(xt∣x0,y0)=N(xt;x0+ηte0,κ2ηtI)q(x_t | x_0, y_0) = \mathcal{N}(x_t; x_0 + \eta_t e_0, \kappa^2 \eta_t I)q(xtx0,y0)=N(xt;x0+ηte0,κ2ηtI)

这个公式非常重要,它告诉我们:

  • 当t=1时,η1→0\eta_1 \to 0η10x1x_1x1几乎等于高清图像x0x_0x0
  • 当t=T时,ηT→1\eta_T \to 1ηT1xTx_TxT几乎等于低分辨率图像y0y_0y0加上一点噪声

这就完美实现了我们的目标:构建一个从高清图像到低分辨率图像的平滑过渡链条。

3.3 反向过程:从低清到高清的残差恢复

反向过程的目标是从低分辨率图像y0y_0y0恢复出高清图像x0x_0x0。根据贝叶斯定理,我们可以推导出反向过程的后验分布:
q(xt−1∣xt,x0,y0)=N(xt−1∣ηt−1ηtxt+αtηtx0,κ2ηt−1ηtαtI)q(x_{t-1}|x_t, x_0, y_0) = \mathcal{N}\left( x_{t-1} \bigg| \frac{\eta_{t-1}}{\eta_t}x_t + \frac{\alpha_t}{\eta_t}x_0, \kappa^2 \frac{\eta_{t-1}}{\eta_t}\alpha_t I \right)q(xt1xt,x0,y0)=N(xt1 ηtηt1xt+ηtαtx0,κ2ηtηt1αtI)

这个分布的均值由两部分组成:

  1. 前一个状态xtx_txt的加权平均
  2. 原始高清图像x0x_0x0的加权平均

因此,我们可以把反向过程的均值参数化为:
μθ(xt,y0,t)=ηt−1ηtxt+αtηtfθ(xt,y0,t)\mu_\theta(x_t, y_0, t) = \frac{\eta_{t-1}}{\eta_t} x_t + \frac{\alpha_t}{\eta_t} f_\theta(x_t, y_0, t)μθ(xt,y0,t)=ηtηt1xt+ηtαtfθ(xt,y0,t)

公式中每个符号的含义:

  • μθ\mu_\thetaμθ:反向过程预测的均值
  • ηt\eta_tηt:第t步的累积残差转移比例
  • fθf_\thetafθ:我们训练的神经网络,用来预测原始高分辨率图像x0x_0x0
  • θ\thetaθ:神经网络的可学习参数

这样,我们的训练目标就变得非常简单:让神经网络fθf_\thetafθ尽可能准确地预测出原始高清图像x0x_0x0。最终的损失函数是:
min⁡θ∑t∥fθ(xt,y0,t)−x0∥22\min_\theta \sum_t \left\| f_\theta(x_t, y_0, t) - x_0 \right\|_2^2θmintfθ(xt,y0,t)x022

通俗解释:我们训练一个神经网络,给它输入中间状态xtx_txt、低分辨率图像y0y_0y0和当前步数t,让它输出原始高清图像x0x_0x0的预测值。训练的目标就是让这个预测值和真实的高清图像尽可能接近。

3.4 灵活的噪声调度

ResShift设计了一个非常灵活的噪声调度方案,可以精确控制残差转移的速度和噪声强度。对于中间时刻的ηt\sqrt{\eta_t}ηt ,我们采用非均匀几何序列:
ηt=η1×b0βt,t=2,⋯ ,T−1\sqrt{\eta_t} = \sqrt{\eta_1} \times b_0^{\beta_t}, \quad t=2, \cdots, T-1ηt =η1 ×b0βt,t=2,,T1
βt=(t−1T−1)p×(T−1),b0=exp[12(T−1)logηTη1]\beta_t = \left(\frac{t-1}{T-1}\right)^p \times (T-1), \quad b_0 = exp\left[\frac{1}{2(T-1)} log \frac{\eta_T}{\eta_1}\right]βt=(T1t1)p×(T1),b0=exp[2(T1)1logη1ηT]

公式中每个符号的含义:

  • η1\eta_1η1:第一步的累积残差比例,设置为(0.04/κ)2(0.04/\kappa)^2(0.04/κ)2和0.001中的较小值
  • ηT\eta_TηT:最后一步的累积残差比例,设置为0.999
  • b0b_0b0:几何序列的底数
  • βt\beta_tβt:第t步的指数
  • ppp:控制残差转移速度的超参数,是整个噪声调度的核心

超参数ppp的作用非常关键:

  • ppp很小时(如0.3),残差在早期就快速转移,模型有更多时间生成丰富的细节,视觉效果更好,但保真度略有下降
  • ppp很大时(如3.0),残差转移比较均匀,模型更注重保真度,但生成的细节会比较少
    在这里插入图片描述

图2 不同超参数下的噪声调度(出处:原论文图3(h))

  • 横轴是时间步t,纵轴是ηt\sqrt{\eta_t}ηt
  • 可以看到,p越小,曲线越陡峭,早期残差转移越快;p越大,曲线越平缓,残差转移越均匀

另一个重要的超参数是κ\kappaκ,它控制整个扩散过程的噪声强度:

  • κ\kappaκ太小时(如0.5),噪声太弱,模型生成的细节不够丰富
  • κ\kappaκ太大时(如16.0),噪声太强,模型容易生成不真实的纹理
  • 论文中选择κ=2.0\kappa=2.0κ=2.0作为默认值,在细节丰富度和真实性之间取得了最佳平衡

四、实验结果与分析

4.1 实验设置

  • 训练数据:ImageNet训练集,随机裁剪256×256的高清图像
  • 退化模型:RealESRGAN的退化管道,包括模糊、下采样、噪声和JPEG压缩
  • 优化器:Adam,学习率=5e-5,训练50万次迭代
  • 网络架构:基于UNet,把自注意力层替换成Swin Transformer块,提高对任意分辨率的支持
  • 测试数据集
    • ImageNet-Test:3000张合成退化图像
    • RealSR:100张真实世界图像
    • RealSet65:65张真实世界图像(35张来自文献,30张来自互联网)

4.2 模型超参数分析

我们首先分析了不同超参数对模型性能的影响,结果如下表所示:

表1 不同超参数配置的性能对比(出处:原论文表1)

T p κ PSNR ↑ SSIM ↑ LPIPS ↓ CLIPIQA ↑ MUSIQ ↑
10 0.3 2.0 25.20 0.6828 0.2517 0.5492 50.6617
15 0.3 2.0 25.01 0.6769 0.2312 0.5922 53.6596
30 0.3 2.0 24.52 0.6585 0.2253 0.6273 55.7904
15 0.3 0.5 24.90 0.6709 0.2437 0.5700 50.6101
15 0.3 1.0 24.84 0.6699 0.2354 0.5914 52.9933
15 0.3 2.0 25.01 0.6769 0.2312 0.5922 53.6596
15 0.3 8.0 25.31 0.6858 0.2592 0.5231 49.3182
15 0.1 2.0 25.01 0.6769 0.2312 0.5922 53.6596
15 0.5 2.0 25.05 0.6745 0.2387 0.5816 52.4475
15 1.0 2.0 25.12 0.6780 0.2613 0.5314 48.4964
15 3.0 2.0 25.39 0.5813 0.3432 0.4041 38.5324

结果分析

  • 采样步数T:随着T增加,全参考指标(PSNR、SSIM)下降,无参考指标(CLIPIQA、MUSIQ)上升。这是因为更多的步数让模型有更多时间生成细节,但也会稍微偏离原始图像。T=15是速度和质量的最佳平衡点。
  • 超参数p:随着p增加,全参考指标上升,无参考指标下降。这验证了我们之前的分析:p小细节多,p大保真度高。
  • 超参数κ:κ在1.0-2.0之间时,模型取得了最佳的综合性能。κ太小或太大都会导致图像质量下降。
    在这里插入图片描述

图3 不同超参数配置的视觉效果对比(出处:原论文图4)

  • 可以直观地看到,T=15、p=0.3、κ=2.0的配置取得了最佳的视觉效果,细节丰富且自然真实

4.3 与其他方法的对比

我们在ImageNet-Test数据集上与当前最先进的方法进行了对比,包括BSRGAN、RealESRGAN、SwinIR和LDM。

表2 效率与性能对比(出处:原论文表2)

方法 PSNR ↑ LPIPS ↓ CLIPIQA ↑ 运行时间(s) ↓ 参数量(M)
BSRGAN 24.42 0.259 0.581 0.012 16.70
RealESRGAN 24.04 0.254 0.523 0.013 16.70
SwinIR 23.99 0.238 0.564 0.046 28.01
LDM-15 24.89 0.269 0.512 0.102 113.60
LDM-30 24.49 0.248 0.572 0.184 113.60
LDM-100 23.90 0.244 0.620 0.413 113.60
ResShift-15 25.01 0.231 0.592 0.105 118.59

结果分析

  • ResShift-15在PSNR和LPIPS上都取得了最好的成绩,说明它在保真度和感知质量上都优于其他方法
  • ResShift-15的运行时间和LDM-15几乎相同,但效果远远好于LDM-15
  • ResShift-15比LDM-100快了整整4倍,而且PSNR更高,LPIPS更低,真正实现了"又快又好"

表3 合成数据定量对比(出处:原论文表3)

方法 PSNR ↑ SSIM ↑ LPIPS ↓ CLIPIQA ↑ MUSIQ ↑
ESRGAN 20.67 0.448 0.485 0.451 43.615
RealSR-JPEG 23.11 0.591 0.326 0.537 46.981
BSRGAN 24.42 0.659 0.259 0.581 54.697
SwinIR 23.99 0.667 0.238 0.564 53.790
RealESRGAN 24.04 0.665 0.254 0.523 52.538
DASR 24.75 0.675 0.250 0.536 48.337
LDM-15 24.89 0.670 0.269 0.512 46.419
ResShift 25.01 0.677 0.231 0.592 53.660

结果分析

  • ResShift在所有5个指标上都取得了最好或第二好的成绩
  • 特别是在PSNR和LPIPS上,ResShift大幅领先于其他扩散方法和GAN方法
  • 这说明ResShift在生成真实细节的同时,能够很好地保持对原始图像的保真度

4.4 真实世界数据对比

在真实世界数据集上,我们主要使用无参考指标CLIPIQA和MUSIQ来评估图像质量,因为真实世界图像没有对应的高清真值。

表4 真实世界数据定量对比(出处:原论文表4)

方法 RealSR RealSet65
CLIPIQA ↑ MUSIQ ↑ CLIPIQA ↑ MUSIQ ↑
ESRGAN 0.2362 29.048 0.3739 42.369
RealSR-JPEG 0.3615 36.076 0.5282 50.539
BSRGAN 0.5439 63.586 0.6163 65.582
SwinIR 0.4654 59.636 0.5782 63.822
RealESRGAN 0.4898 59.678 0.5995 63.220
DASR 0.3629 45.825 0.4965 55.708
LDM-15 0.3836 49.317 0.4274 47.488
ResShift 0.5958 59.873 0.6537 61.330

结果分析

  • ResShift在两个真实世界数据集的CLIPIQA指标上都取得了压倒性的第一
  • CLIPIQA是基于CLIP模型的无参考质量指标,和人类视觉感知高度相关,这说明ResShift生成的图像最符合人类的审美
  • 在MUSIQ指标上,ResShift也取得了有竞争力的成绩
    在这里插入图片描述

图4 真实世界图像恢复效果对比(出处:原论文图6)

  • 第一列是低分辨率输入
  • 可以看到,ResShift生成的图像细节最丰富,纹理最自然,没有明显的人工痕迹
  • LDM-15的结果过于平滑,丢失了很多细节
  • GAN方法的结果虽然有细节,但经常出现不自然的 artifacts

4.5 感知-失真权衡

在图像恢复领域,存在一个著名的"感知-失真权衡":感知质量好的模型,保真度往往较低;保真度高的模型,感知质量往往较差。
在这里插入图片描述

图5 感知-失真权衡曲线(出处:原论文图7)

  • 横轴是失真度(MSE,越小越好),纵轴是感知质量(LPIPS,越小越好)
  • 曲线越靠近左下角,说明模型在感知和失真之间的平衡越好
  • 可以看到,ResShift的曲线始终位于LDM的下方,说明ResShift在平衡感知和失真方面明显优于LDM

4.6 失败案例

ResShift也不是万能的,它在处理一些严重退化的图像时仍然会遇到困难。
在这里插入图片描述

图6 典型失败案例(出处:原论文图9)

  • 这是一张严重退化的漫画图像,所有方法都无法很好地恢复
  • 这是因为大多数超分模型都是在自然图像上训练的,对漫画这种特殊风格的图像泛化能力较差
  • 未来的工作可以通过在更多样化的数据集上训练来解决这个问题

五、核心代码实现

以下是ResShift的简化推理代码,基于官方实现:

import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from torchvision import transforms
from resshift import ResShiftModel

# 设备设置
device = "cuda" if torch.cuda.is_available() else "cpu"

# 1. 加载预训练模型
model = ResShiftModel.from_pretrained(
    "zsyoa/resshift-realbsr4x-v1",
    torch_dtype=torch.float16
).to(device)

# 2. 图像预处理
def preprocess_image(image_path):
    image = Image.open(image_path).convert("RGB")
    # 调整图像大小为64×64(4倍超分到256×256)
    transform = transforms.Compose([
        transforms.Resize((64, 64), interpolation=Image.Resampling.LANCZOS),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    return transform(image).unsqueeze(0).to(device, dtype=torch.float16)

# 3. 后处理
def postprocess_image(tensor):
    tensor = tensor.squeeze(0).cpu().float()
    tensor = (tensor * 0.5 + 0.5).clamp(0, 1)
    tensor = tensor.permute(1, 2, 0).numpy()
    tensor = (tensor * 255).astype(np.uint8)
    return Image.fromarray(tensor)

# 4. ResShift采样函数
def resshift_sampling(model, y0, num_steps=15, p=0.3, kappa=2.0):
    # 计算噪声调度
    eta1 = min((0.04 / kappa) ** 2, 0.001)
    etaT = 0.999
    b0 = np.exp((1 / (2 * (num_steps - 1))) * np.log(etaT / eta1))
    
    etas = [eta1]
    for t in range(2, num_steps + 1):
        beta_t = ((t - 1) / (num_steps - 1)) ** p * (num_steps - 1)
        eta_t = (np.sqrt(eta1) * (b0 ** beta_t)) ** 2
        etas.append(eta_t)
    
    alphas = [etas[0]]
    for t in range(1, num_steps):
        alphas.append(etas[t] - etas[t-1])
    
    # 初始化x_T为y0加上噪声
    x_t = y0 + torch.randn_like(y0) * kappa * np.sqrt(etaT)
    
    # 反向采样
    for t in reversed(range(1, num_steps + 1)):
        eta_t = etas[t-1]
        eta_t_prev = etas[t-2] if t > 1 else 0
        alpha_t = alphas[t-1]
        
        # 预测x0
        x0_pred = model(x_t, y0, t)
        
        # 计算均值
        mu = (eta_t_prev / eta_t) * x_t + (alpha_t / eta_t) * x0_pred
        
        # 计算方差
        sigma = kappa * np.sqrt((eta_t_prev / eta_t) * alpha_t)
        
        # 添加噪声(最后一步不加)
        if t > 1:
            x_t = mu + torch.randn_like(mu) * sigma
        else:
            x_t = mu
    
    return x_t

# 5. 运行示例
if __name__ == "__main__":
    # 加载低分辨率图像
    lq_image = preprocess_image("low_quality_image.jpg")
    
    # 进行4倍超分辨率
    hq_image = resshift_sampling(
        model=model,
        y0=lq_image,
        num_steps=15,
        p=0.3,
        kappa=2.0
    )
    
    # 保存结果
    hq_image = postprocess_image(hq_image)
    hq_image.save("high_quality_image.jpg")
    print("超分辨率完成,已保存为high_quality_image.jpg")

六、结论与展望

ResShift通过重新设计扩散模型的马尔可夫链,从根本上解决了扩散超分的速度问题。它的主要贡献包括:

  1. 提出了一种全新的残差转移扩散模型,只需要15步采样就能达到SOTA效果
  2. 设计了灵活的噪声调度方案,可以精确控制保真度和感知质量之间的平衡
  3. 在合成和真实世界数据集上都取得了优异的性能,同时保持了很高的推理效率
  4. 提供了完整的理论推导和实验验证,证明了所提方法的有效性

未来,ResShift可以在以下方向进一步改进:

  • 进一步减少采样步数,探索10步以内的超分辨率
  • 扩展到视频超分辨率任务
  • 设计更真实的退化模型,提高对真实世界退化的泛化能力
  • 探索轻量级网络架构,进一步降低模型参数量和计算量
Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐