15步搞定超分辨率:ResShift高效扩散模型的残差转移魔法
论文信息
- 标题:ResShift: Efficient Diffusion Model for Image Super-resolution by Residual Shifting
- 会议:NeurIPS 2023
- 单位:南洋理工大学 S-Lab
- 代码:https://github.com/zsyOAOA/ResShift
- 论文:https://arxiv.org/pdf/2307.12348
一、开篇:扩散超分的"速度噩梦"终于被打破了
你有没有过这样的体验:好不容易找到一个效果不错的扩散超分模型,结果跑一张256×256的图要等十几秒?传统扩散模型做超分辨率(SR)就像让一个画家从一张白纸开始画画,需要一笔一笔慢慢涂,动辄几百上千步才能完成。虽然最后画出来的效果不错,但这个速度实在让人着急。
更气人的是,如果你想加速,用DDIM、DPM-Solver这些快速采样算法把步数压缩到几十步,结果图像立刻变得模糊不清,细节全没了。这就是扩散超分领域长期存在的"速度-质量"死结:要么慢得要死,要么快得难看。
直到ResShift的出现,这个死结终于被解开了。它提出了一种全新的扩散模型设计思路,只需要15步采样,就能达到甚至超过传统扩散模型1000步的效果,而且速度比LDM-100快了整整4倍!
二、ResShift核心思想:从"从零开始画"到"在草稿上添细节"
传统扩散模型的核心思路是:先把高清图像一步步加噪声变成高斯白噪声,然后再训练模型从高斯白噪声一步步去噪恢复出高清图像。这个思路对于图像生成来说很合理,但对于超分辨率来说却很浪费——我们明明已经有了低分辨率图像这个"草稿",为什么还要从零开始呢?
ResShift的核心洞察就是:超分辨率不需要从高斯噪声开始,只需要从低分辨率图像开始,逐步转移高分辨率和低分辨率之间的残差(也就是缺失的细节)就行了。
通俗解释:
- 传统扩散超分:给你一张白纸,让你画出一只猫 → 需要画几百笔
- ResShift超分:给你一张已经画好猫轮廓的草稿,让你添上毛发、眼睛、胡须这些细节 → 只需要画十几笔

图1 ResShift整体工作流(出处:原论文图2)
- 左边是前向过程:从高清图像x0x_0x0开始,逐步添加残差和噪声,最终变成低分辨率图像y0y_0y0
- 右边是反向过程:从低分辨率图像y0y_0y0开始,逐步去除残差和噪声,最终恢复出高清图像x0x_0x0
- 整个过程只需要15步,比传统扩散模型快了两个数量级
三、ResShift关键技术详解
3.1 整体架构
ResShift构建了一个连接高清图像和低分辨率图像的马尔可夫链(通俗解释:一个状态只能由前一个状态决定的链条)。这个链条的初始状态是高清图像,最终状态是低分辨率图像。超分辨率任务就是沿着这个链条反向走一遍,从低分辨率图像恢复出高清图像。
为了进一步加速计算,ResShift还可以在VQGAN的隐空间中运行,把图像压缩4倍,计算量直接减少16倍,而且几乎不会损失效果。
3.2 前向过程:从高清到低清的残差转移
前向过程的目标是把高清图像x0x_0x0逐步变成低分辨率图像y0y_0y0。我们首先定义残差e0e_0e0:
e0=y0−x0e_0 = y_0 - x_0e0=y0−x0
通俗解释:残差就是低分辨率图像和高清图像之间的差异,也就是我们需要恢复的所有细节信息。
然后我们引入一个单调递增的转移序列{ηt}t=1T\{\eta_t\}_{t=1}^T{ηt}t=1T,满足η1→0\eta_1 \to 0η1→0和ηT→1\eta_T \to 1ηT→1。前向过程的转移分布定义为:
q(xt∣xt−1,y0)=N(xt;xt−1+αte0,κ2αtI)q(x_t | x_{t-1}, y_0) = \mathcal{N}(x_t; x_{t-1} + \alpha_t e_0, \kappa^2 \alpha_t I)q(xt∣xt−1,y0)=N(xt;xt−1+αte0,κ2αtI)
公式中每个符号的含义:
- q(xt∣xt−1,y0)q(x_t | x_{t-1}, y_0)q(xt∣xt−1,y0):给定前一个状态xt−1x_{t-1}xt−1和低分辨率图像y0y_0y0,当前状态xtx_txt的条件概率分布
- N\mathcal{N}N:正态分布(高斯分布)
- xtx_txt:第t步的中间图像
- xt−1x_{t-1}xt−1:第t-1步的中间图像
- αt\alpha_tαt:第t步的残差转移步长,αt=ηt−ηt−1\alpha_t = \eta_t - \eta_{t-1}αt=ηt−ηt−1(t>1时),α1=η1\alpha_1 = \eta_1α1=η1
- e0e_0e0:高分辨率图像x0x_0x0和低分辨率图像y0y_0y0之间的残差
- κ\kappaκ:控制噪声强度的超参数
- III:单位矩阵
通过数学推导,我们可以得到任意时刻t的边际分布:
q(xt∣x0,y0)=N(xt;x0+ηte0,κ2ηtI)q(x_t | x_0, y_0) = \mathcal{N}(x_t; x_0 + \eta_t e_0, \kappa^2 \eta_t I)q(xt∣x0,y0)=N(xt;x0+ηte0,κ2ηtI)
这个公式非常重要,它告诉我们:
- 当t=1时,η1→0\eta_1 \to 0η1→0,x1x_1x1几乎等于高清图像x0x_0x0
- 当t=T时,ηT→1\eta_T \to 1ηT→1,xTx_TxT几乎等于低分辨率图像y0y_0y0加上一点噪声
这就完美实现了我们的目标:构建一个从高清图像到低分辨率图像的平滑过渡链条。
3.3 反向过程:从低清到高清的残差恢复
反向过程的目标是从低分辨率图像y0y_0y0恢复出高清图像x0x_0x0。根据贝叶斯定理,我们可以推导出反向过程的后验分布:
q(xt−1∣xt,x0,y0)=N(xt−1∣ηt−1ηtxt+αtηtx0,κ2ηt−1ηtαtI)q(x_{t-1}|x_t, x_0, y_0) = \mathcal{N}\left( x_{t-1} \bigg| \frac{\eta_{t-1}}{\eta_t}x_t + \frac{\alpha_t}{\eta_t}x_0, \kappa^2 \frac{\eta_{t-1}}{\eta_t}\alpha_t I \right)q(xt−1∣xt,x0,y0)=N(xt−1
ηtηt−1xt+ηtαtx0,κ2ηtηt−1αtI)
这个分布的均值由两部分组成:
- 前一个状态xtx_txt的加权平均
- 原始高清图像x0x_0x0的加权平均
因此,我们可以把反向过程的均值参数化为:
μθ(xt,y0,t)=ηt−1ηtxt+αtηtfθ(xt,y0,t)\mu_\theta(x_t, y_0, t) = \frac{\eta_{t-1}}{\eta_t} x_t + \frac{\alpha_t}{\eta_t} f_\theta(x_t, y_0, t)μθ(xt,y0,t)=ηtηt−1xt+ηtαtfθ(xt,y0,t)
公式中每个符号的含义:
- μθ\mu_\thetaμθ:反向过程预测的均值
- ηt\eta_tηt:第t步的累积残差转移比例
- fθf_\thetafθ:我们训练的神经网络,用来预测原始高分辨率图像x0x_0x0
- θ\thetaθ:神经网络的可学习参数
这样,我们的训练目标就变得非常简单:让神经网络fθf_\thetafθ尽可能准确地预测出原始高清图像x0x_0x0。最终的损失函数是:
minθ∑t∥fθ(xt,y0,t)−x0∥22\min_\theta \sum_t \left\| f_\theta(x_t, y_0, t) - x_0 \right\|_2^2θmint∑∥fθ(xt,y0,t)−x0∥22
通俗解释:我们训练一个神经网络,给它输入中间状态xtx_txt、低分辨率图像y0y_0y0和当前步数t,让它输出原始高清图像x0x_0x0的预测值。训练的目标就是让这个预测值和真实的高清图像尽可能接近。
3.4 灵活的噪声调度
ResShift设计了一个非常灵活的噪声调度方案,可以精确控制残差转移的速度和噪声强度。对于中间时刻的ηt\sqrt{\eta_t}ηt,我们采用非均匀几何序列:
ηt=η1×b0βt,t=2,⋯ ,T−1\sqrt{\eta_t} = \sqrt{\eta_1} \times b_0^{\beta_t}, \quad t=2, \cdots, T-1ηt=η1×b0βt,t=2,⋯,T−1
βt=(t−1T−1)p×(T−1),b0=exp[12(T−1)logηTη1]\beta_t = \left(\frac{t-1}{T-1}\right)^p \times (T-1), \quad b_0 = exp\left[\frac{1}{2(T-1)} log \frac{\eta_T}{\eta_1}\right]βt=(T−1t−1)p×(T−1),b0=exp[2(T−1)1logη1ηT]
公式中每个符号的含义:
- η1\eta_1η1:第一步的累积残差比例,设置为(0.04/κ)2(0.04/\kappa)^2(0.04/κ)2和0.001中的较小值
- ηT\eta_TηT:最后一步的累积残差比例,设置为0.999
- b0b_0b0:几何序列的底数
- βt\beta_tβt:第t步的指数
- ppp:控制残差转移速度的超参数,是整个噪声调度的核心
超参数ppp的作用非常关键:
- 当ppp很小时(如0.3),残差在早期就快速转移,模型有更多时间生成丰富的细节,视觉效果更好,但保真度略有下降
- 当ppp很大时(如3.0),残差转移比较均匀,模型更注重保真度,但生成的细节会比较少

图2 不同超参数下的噪声调度(出处:原论文图3(h))
- 横轴是时间步t,纵轴是ηt\sqrt{\eta_t}ηt
- 可以看到,p越小,曲线越陡峭,早期残差转移越快;p越大,曲线越平缓,残差转移越均匀
另一个重要的超参数是κ\kappaκ,它控制整个扩散过程的噪声强度:
- 当κ\kappaκ太小时(如0.5),噪声太弱,模型生成的细节不够丰富
- 当κ\kappaκ太大时(如16.0),噪声太强,模型容易生成不真实的纹理
- 论文中选择κ=2.0\kappa=2.0κ=2.0作为默认值,在细节丰富度和真实性之间取得了最佳平衡
四、实验结果与分析
4.1 实验设置
- 训练数据:ImageNet训练集,随机裁剪256×256的高清图像
- 退化模型:RealESRGAN的退化管道,包括模糊、下采样、噪声和JPEG压缩
- 优化器:Adam,学习率=5e-5,训练50万次迭代
- 网络架构:基于UNet,把自注意力层替换成Swin Transformer块,提高对任意分辨率的支持
- 测试数据集:
- ImageNet-Test:3000张合成退化图像
- RealSR:100张真实世界图像
- RealSet65:65张真实世界图像(35张来自文献,30张来自互联网)
4.2 模型超参数分析
我们首先分析了不同超参数对模型性能的影响,结果如下表所示:
表1 不同超参数配置的性能对比(出处:原论文表1)
| T | p | κ | PSNR ↑ | SSIM ↑ | LPIPS ↓ | CLIPIQA ↑ | MUSIQ ↑ |
|---|---|---|---|---|---|---|---|
| 10 | 0.3 | 2.0 | 25.20 | 0.6828 | 0.2517 | 0.5492 | 50.6617 |
| 15 | 0.3 | 2.0 | 25.01 | 0.6769 | 0.2312 | 0.5922 | 53.6596 |
| 30 | 0.3 | 2.0 | 24.52 | 0.6585 | 0.2253 | 0.6273 | 55.7904 |
| 15 | 0.3 | 0.5 | 24.90 | 0.6709 | 0.2437 | 0.5700 | 50.6101 |
| 15 | 0.3 | 1.0 | 24.84 | 0.6699 | 0.2354 | 0.5914 | 52.9933 |
| 15 | 0.3 | 2.0 | 25.01 | 0.6769 | 0.2312 | 0.5922 | 53.6596 |
| 15 | 0.3 | 8.0 | 25.31 | 0.6858 | 0.2592 | 0.5231 | 49.3182 |
| 15 | 0.1 | 2.0 | 25.01 | 0.6769 | 0.2312 | 0.5922 | 53.6596 |
| 15 | 0.5 | 2.0 | 25.05 | 0.6745 | 0.2387 | 0.5816 | 52.4475 |
| 15 | 1.0 | 2.0 | 25.12 | 0.6780 | 0.2613 | 0.5314 | 48.4964 |
| 15 | 3.0 | 2.0 | 25.39 | 0.5813 | 0.3432 | 0.4041 | 38.5324 |
结果分析:
- 采样步数T:随着T增加,全参考指标(PSNR、SSIM)下降,无参考指标(CLIPIQA、MUSIQ)上升。这是因为更多的步数让模型有更多时间生成细节,但也会稍微偏离原始图像。T=15是速度和质量的最佳平衡点。
- 超参数p:随着p增加,全参考指标上升,无参考指标下降。这验证了我们之前的分析:p小细节多,p大保真度高。
- 超参数κ:κ在1.0-2.0之间时,模型取得了最佳的综合性能。κ太小或太大都会导致图像质量下降。

图3 不同超参数配置的视觉效果对比(出处:原论文图4)
- 可以直观地看到,T=15、p=0.3、κ=2.0的配置取得了最佳的视觉效果,细节丰富且自然真实
4.3 与其他方法的对比
我们在ImageNet-Test数据集上与当前最先进的方法进行了对比,包括BSRGAN、RealESRGAN、SwinIR和LDM。
表2 效率与性能对比(出处:原论文表2)
| 方法 | PSNR ↑ | LPIPS ↓ | CLIPIQA ↑ | 运行时间(s) ↓ | 参数量(M) |
|---|---|---|---|---|---|
| BSRGAN | 24.42 | 0.259 | 0.581 | 0.012 | 16.70 |
| RealESRGAN | 24.04 | 0.254 | 0.523 | 0.013 | 16.70 |
| SwinIR | 23.99 | 0.238 | 0.564 | 0.046 | 28.01 |
| LDM-15 | 24.89 | 0.269 | 0.512 | 0.102 | 113.60 |
| LDM-30 | 24.49 | 0.248 | 0.572 | 0.184 | 113.60 |
| LDM-100 | 23.90 | 0.244 | 0.620 | 0.413 | 113.60 |
| ResShift-15 | 25.01 | 0.231 | 0.592 | 0.105 | 118.59 |
结果分析:
- ResShift-15在PSNR和LPIPS上都取得了最好的成绩,说明它在保真度和感知质量上都优于其他方法
- ResShift-15的运行时间和LDM-15几乎相同,但效果远远好于LDM-15
- ResShift-15比LDM-100快了整整4倍,而且PSNR更高,LPIPS更低,真正实现了"又快又好"
表3 合成数据定量对比(出处:原论文表3)
| 方法 | PSNR ↑ | SSIM ↑ | LPIPS ↓ | CLIPIQA ↑ | MUSIQ ↑ |
|---|---|---|---|---|---|
| ESRGAN | 20.67 | 0.448 | 0.485 | 0.451 | 43.615 |
| RealSR-JPEG | 23.11 | 0.591 | 0.326 | 0.537 | 46.981 |
| BSRGAN | 24.42 | 0.659 | 0.259 | 0.581 | 54.697 |
| SwinIR | 23.99 | 0.667 | 0.238 | 0.564 | 53.790 |
| RealESRGAN | 24.04 | 0.665 | 0.254 | 0.523 | 52.538 |
| DASR | 24.75 | 0.675 | 0.250 | 0.536 | 48.337 |
| LDM-15 | 24.89 | 0.670 | 0.269 | 0.512 | 46.419 |
| ResShift | 25.01 | 0.677 | 0.231 | 0.592 | 53.660 |
结果分析:
- ResShift在所有5个指标上都取得了最好或第二好的成绩
- 特别是在PSNR和LPIPS上,ResShift大幅领先于其他扩散方法和GAN方法
- 这说明ResShift在生成真实细节的同时,能够很好地保持对原始图像的保真度
4.4 真实世界数据对比
在真实世界数据集上,我们主要使用无参考指标CLIPIQA和MUSIQ来评估图像质量,因为真实世界图像没有对应的高清真值。
表4 真实世界数据定量对比(出处:原论文表4)
| 方法 | RealSR | RealSet65 | ||
|---|---|---|---|---|
| CLIPIQA ↑ | MUSIQ ↑ | CLIPIQA ↑ | MUSIQ ↑ | |
| ESRGAN | 0.2362 | 29.048 | 0.3739 | 42.369 |
| RealSR-JPEG | 0.3615 | 36.076 | 0.5282 | 50.539 |
| BSRGAN | 0.5439 | 63.586 | 0.6163 | 65.582 |
| SwinIR | 0.4654 | 59.636 | 0.5782 | 63.822 |
| RealESRGAN | 0.4898 | 59.678 | 0.5995 | 63.220 |
| DASR | 0.3629 | 45.825 | 0.4965 | 55.708 |
| LDM-15 | 0.3836 | 49.317 | 0.4274 | 47.488 |
| ResShift | 0.5958 | 59.873 | 0.6537 | 61.330 |
结果分析:
- ResShift在两个真实世界数据集的CLIPIQA指标上都取得了压倒性的第一
- CLIPIQA是基于CLIP模型的无参考质量指标,和人类视觉感知高度相关,这说明ResShift生成的图像最符合人类的审美
- 在MUSIQ指标上,ResShift也取得了有竞争力的成绩

图4 真实世界图像恢复效果对比(出处:原论文图6)
- 第一列是低分辨率输入
- 可以看到,ResShift生成的图像细节最丰富,纹理最自然,没有明显的人工痕迹
- LDM-15的结果过于平滑,丢失了很多细节
- GAN方法的结果虽然有细节,但经常出现不自然的 artifacts
4.5 感知-失真权衡
在图像恢复领域,存在一个著名的"感知-失真权衡":感知质量好的模型,保真度往往较低;保真度高的模型,感知质量往往较差。
图5 感知-失真权衡曲线(出处:原论文图7)
- 横轴是失真度(MSE,越小越好),纵轴是感知质量(LPIPS,越小越好)
- 曲线越靠近左下角,说明模型在感知和失真之间的平衡越好
- 可以看到,ResShift的曲线始终位于LDM的下方,说明ResShift在平衡感知和失真方面明显优于LDM
4.6 失败案例
ResShift也不是万能的,它在处理一些严重退化的图像时仍然会遇到困难。
图6 典型失败案例(出处:原论文图9)
- 这是一张严重退化的漫画图像,所有方法都无法很好地恢复
- 这是因为大多数超分模型都是在自然图像上训练的,对漫画这种特殊风格的图像泛化能力较差
- 未来的工作可以通过在更多样化的数据集上训练来解决这个问题
五、核心代码实现
以下是ResShift的简化推理代码,基于官方实现:
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from torchvision import transforms
from resshift import ResShiftModel
# 设备设置
device = "cuda" if torch.cuda.is_available() else "cpu"
# 1. 加载预训练模型
model = ResShiftModel.from_pretrained(
"zsyoa/resshift-realbsr4x-v1",
torch_dtype=torch.float16
).to(device)
# 2. 图像预处理
def preprocess_image(image_path):
image = Image.open(image_path).convert("RGB")
# 调整图像大小为64×64(4倍超分到256×256)
transform = transforms.Compose([
transforms.Resize((64, 64), interpolation=Image.Resampling.LANCZOS),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
return transform(image).unsqueeze(0).to(device, dtype=torch.float16)
# 3. 后处理
def postprocess_image(tensor):
tensor = tensor.squeeze(0).cpu().float()
tensor = (tensor * 0.5 + 0.5).clamp(0, 1)
tensor = tensor.permute(1, 2, 0).numpy()
tensor = (tensor * 255).astype(np.uint8)
return Image.fromarray(tensor)
# 4. ResShift采样函数
def resshift_sampling(model, y0, num_steps=15, p=0.3, kappa=2.0):
# 计算噪声调度
eta1 = min((0.04 / kappa) ** 2, 0.001)
etaT = 0.999
b0 = np.exp((1 / (2 * (num_steps - 1))) * np.log(etaT / eta1))
etas = [eta1]
for t in range(2, num_steps + 1):
beta_t = ((t - 1) / (num_steps - 1)) ** p * (num_steps - 1)
eta_t = (np.sqrt(eta1) * (b0 ** beta_t)) ** 2
etas.append(eta_t)
alphas = [etas[0]]
for t in range(1, num_steps):
alphas.append(etas[t] - etas[t-1])
# 初始化x_T为y0加上噪声
x_t = y0 + torch.randn_like(y0) * kappa * np.sqrt(etaT)
# 反向采样
for t in reversed(range(1, num_steps + 1)):
eta_t = etas[t-1]
eta_t_prev = etas[t-2] if t > 1 else 0
alpha_t = alphas[t-1]
# 预测x0
x0_pred = model(x_t, y0, t)
# 计算均值
mu = (eta_t_prev / eta_t) * x_t + (alpha_t / eta_t) * x0_pred
# 计算方差
sigma = kappa * np.sqrt((eta_t_prev / eta_t) * alpha_t)
# 添加噪声(最后一步不加)
if t > 1:
x_t = mu + torch.randn_like(mu) * sigma
else:
x_t = mu
return x_t
# 5. 运行示例
if __name__ == "__main__":
# 加载低分辨率图像
lq_image = preprocess_image("low_quality_image.jpg")
# 进行4倍超分辨率
hq_image = resshift_sampling(
model=model,
y0=lq_image,
num_steps=15,
p=0.3,
kappa=2.0
)
# 保存结果
hq_image = postprocess_image(hq_image)
hq_image.save("high_quality_image.jpg")
print("超分辨率完成,已保存为high_quality_image.jpg")
六、结论与展望
ResShift通过重新设计扩散模型的马尔可夫链,从根本上解决了扩散超分的速度问题。它的主要贡献包括:
- 提出了一种全新的残差转移扩散模型,只需要15步采样就能达到SOTA效果
- 设计了灵活的噪声调度方案,可以精确控制保真度和感知质量之间的平衡
- 在合成和真实世界数据集上都取得了优异的性能,同时保持了很高的推理效率
- 提供了完整的理论推导和实验验证,证明了所提方法的有效性
未来,ResShift可以在以下方向进一步改进:
- 进一步减少采样步数,探索10步以内的超分辨率
- 扩展到视频超分辨率任务
- 设计更真实的退化模型,提高对真实世界退化的泛化能力
- 探索轻量级网络架构,进一步降低模型参数量和计算量
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)