TeaCache:让扩散模型少算几步,但尽量不掉画质
文章目录
扩散模型生成图片/视频时,本质是在很多个 denoising step 中反复调用 Transformer/DiT。TeaCache 的核心思想很简单:
如果当前 step 和上一次完整计算的 step 足够相似,就不重新完整跑 Transformer,而是复用上一次缓存的 residual / 输出近似。
它不是 LLM 的 KV Cache,也不是缓存最终图片,而是缓存扩散去噪过程中的中间计算结果。
TeaCache 论文将其称为 Timestep Embedding Aware Cache,即利用 timestep embedding 估计不同 step 之间模型输出变化,从而决定是否缓存和复用。论文报告在 Open-Sora-Plan 上最高获得 4.41x 加速,VBench 质量分数仅下降 0.07%。(arXiv)

1. 为什么 TeaCache 能加速?
扩散模型每一步大概都在做:
latent_t + timestep_t + prompt_condition
↓
DiT / Transformer
↓
预测噪声 / velocity / residual
↓
scheduler 更新 latent
最耗时的通常是中间的 DiT / Transformer blocks,包括 Attention、MLP、Norm、Residual 等。
普通推理是:
step 50:完整计算 Transformer
step 49:完整计算 Transformer
step 48:完整计算 Transformer
...
step 1 :完整计算 Transformer
TeaCache 的做法是:
step 50:完整计算 Transformer,缓存 residual
step 49:判断变化小,复用 step 50 的 residual
step 48:判断变化小,继续复用
step 47:变化累计变大,重新完整计算并更新缓存

vLLM-Omni 官方文档也将 TeaCache 描述为:当连续 timestep 足够相似时缓存 Transformer 计算,从而实现约 1.5x–2.0x 加速,并通过输入相似性动态判断是否复用缓存。(vLLM)
2. TeaCache 判断“能不能跳过”的依据
TeaCache 不会直接比较完整模型输出,因为如果已经完整跑了一次模型,那就没有加速意义了。
它使用一个更便宜的代理量:
timestep embedding 调制后的 noisy input
然后比较当前 step 和上一次完整计算 step 的差异:
rel_l1 = mean(abs(current_modulated_input - previous_modulated_input)) \
/ mean(abs(previous_modulated_input))
再通过模型相关的多项式系数做 rescale,估计真实输出差异。
如果累计差异低于阈值,就复用缓存;如果超过阈值,就重新完整计算。
TeaCache 论文明确指出,它不直接使用耗时的模型输出差异,而是利用与模型输出强相关、但计算成本很低的模型输入差异来判断缓存时机。(arXiv)
3. 普通推理 vs TeaCache 推理

TeaCache 省掉的主要是:
部分 denoising step 里的 Transformer 主体计算
它通常不会省掉:
scheduler 更新
VAE decode
text encoder
prompt 编码
CPU/GPU 数据搬运
所以如果你的端到端瓶颈主要在 VAE、CPU offload 或 IO,那么 TeaCache 的实际加速会低于理论加速。
4. 最关键参数:rel_l1_thresh
TeaCache 最重要的参数是:
rel_l1_thresh
它控制缓存复用的激进程度:
阈值越小:更保守,完整计算更多,质量更稳,速度提升较小
阈值越大:更激进,缓存复用更多,速度更快,质量风险更高
vLLM-Omni 文档中 rel_l1_thresh 默认值是 0.2,建议范围是 0.1–0.8;低值优先质量,高值优先速度。(vLLM)

建议初始设置:
质量优先:0.10 ~ 0.20
均衡配置:0.20 ~ 0.40
速度优先:0.50 ~ 0.80
生产环境不要一上来拉到 0.8。更稳的方式是:
0.2 → 0.3 → 0.4 → 对比质量和耗时
5. 适用场景
TeaCache 更适合:
1. DiT / Transformer-based diffusion 模型
2. 图像生成、视频生成、音频扩散生成
3. denoising steps 较多的推理任务
4. 对 1.5x~2x 加速有价值,同时能容忍极小质量波动的生产服务
5. 单卡加速场景
vLLM-Omni 官方文档也建议 TeaCache 用于需要更快推理、且能容忍极小质量损失的生产场景;不太适合极致画质要求或非常短步数推理,例如小于 20 steps 的情况。(vLLM)
不太适合:
1. 4~8 step 的蒸馏模型
2. 强编辑、强控制、强文字生成场景
3. 对画质零损失要求极高的任务
4. 主要瓶颈不在 Transformer 的 pipeline
6. vLLM-Omni 正式用法
如果你用的是 vLLM-Omni,并且模型后端支持 TeaCache,可以直接这样开:
from vllm_omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
omni = Omni(
model="Qwen/Qwen-Image",
cache_backend="tea_cache",
)
outputs = omni.generate(
"A cat sitting on a windowsill",
OmniDiffusionSamplingParams(num_inference_steps=50),
)
在线服务方式:
vllm serve Qwen/Qwen-Image --omni --port 8091 \
--cache-backend tea_cache \
--cache-config '{"rel_l1_thresh": 0.2}'
这些参数写法来自 vLLM-Omni TeaCache 官方文档。(vLLM)
使用 facebook/DiT-XL-2-256 重点演示“根据相邻 step 输入差异决定是否复用 residual”。
它不是官方 TeaCache 的完整实现,因为官方 TeaCache 会使用 timestep embedding 调制输入、模型专属 coefficients、多项式 rescale 等细节。
安装环境
pip install -U diffusers transformers accelerate safetensors scipy pillow
代码
import gc
import time
import types
import torch
from IPython.display import display
from diffusers import DiTPipeline, DPMSolverMultistepScheduler
from diffusers.models.modeling_outputs import Transformer2DModelOutput
MODEL_ID = "facebook/DiT-XL-2-256"
NUM_STEPS = 25
CLASS_ID = 207 # ImageNet class id.
class SimpleTeaCacheState:
"""Minimal TeaCache-style state for DiT transformer calls."""
def __init__(self, rel_l1_thresh=0.20, num_steps=25):
self.rel_l1_thresh = rel_l1_thresh
self.num_steps = num_steps
self.reset()
def reset(self):
self.step_idx = 0
self.accumulated_rel_l1 = 0.0
self.previous_input = None
self.previous_residual = None
self.previous_sample = None
self.full_compute_steps = 0
self.cached_steps = 0
@torch.no_grad()
def should_compute(self, hidden_states: torch.Tensor) -> bool:
is_first = self.step_idx == 0
is_last = self.step_idx >= self.num_steps - 1
if is_first or is_last:
self.previous_input = hidden_states.detach().float()
self.accumulated_rel_l1 = 0.0
return True
if self.previous_input is None or self.previous_sample is None:
self.previous_input = hidden_states.detach().float()
return True
current = hidden_states.detach().float()
previous = self.previous_input
denom = previous.abs().mean().clamp_min(1e-6)
rel_l1 = (current - previous).abs().mean() / denom
self.accumulated_rel_l1 += float(rel_l1.item())
if self.accumulated_rel_l1 < self.rel_l1_thresh:
return False
self.accumulated_rel_l1 = 0.0
self.previous_input = current
return True
def next_step(self):
self.step_idx += 1
def clear_cuda_cache():
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
def load_dit_pipeline(device):
dtype = torch.float16 if device == "cuda" else torch.float32
pipe = DiTPipeline.from_pretrained(
MODEL_ID,
torch_dtype=dtype,
token=False,
)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
return pipe.to(device)
def enable_simple_teacache_for_dit(pipe, rel_l1_thresh=0.20, num_steps=25):
state = SimpleTeaCacheState(
rel_l1_thresh=rel_l1_thresh,
num_steps=num_steps,
)
transformer = pipe.transformer
original_forward = transformer.forward
def cached_forward(self, hidden_states, timestep, class_labels=None, **kwargs):
compute_full = state.should_compute(hidden_states)
if not compute_full and state.previous_sample is not None:
if state.previous_residual is not None and state.previous_residual.shape == hidden_states.shape:
sample = hidden_states + state.previous_residual.to(
device=hidden_states.device,
dtype=hidden_states.dtype,
)
else:
sample = state.previous_sample.to(
device=hidden_states.device,
dtype=hidden_states.dtype,
)
state.cached_steps += 1
state.next_step()
return Transformer2DModelOutput(sample=sample)
out = original_forward(
hidden_states=hidden_states,
timestep=timestep,
class_labels=class_labels,
**kwargs,
)
state.previous_sample = out.sample.detach()
if out.sample.shape == hidden_states.shape:
state.previous_residual = out.sample.detach() - hidden_states.detach()
else:
state.previous_residual = None
print("Warning: output sample shape differs from input hidden_states shape; cannot compute residual for caching.")
state.full_compute_steps += 1
state.next_step()
return out
transformer.forward = types.MethodType(cached_forward, transformer)
return state
def run_generation(pipe, device):
generator = torch.Generator(device=device).manual_seed(42)
return pipe(
class_labels=[CLASS_ID],
num_inference_steps=NUM_STEPS,
generator=generator,
).images[0]
def main():
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device = {device}")
pipe = load_dit_pipeline(device)
clear_cuda_cache()
t0 = time.time()
image = run_generation(pipe, device)
baseline_time = time.time() - t0
baseline_image = image
baseline_image.save("dit_baseline.png")
print("Baseline image:")
display(baseline_image)
print(f"[Baseline] time = {baseline_time:.2f}s")
del pipe
clear_cuda_cache()
pipe = load_dit_pipeline(device)
state = enable_simple_teacache_for_dit(
pipe,
rel_l1_thresh=0.20,
num_steps=NUM_STEPS,
)
clear_cuda_cache()
t0 = time.time()
image = run_generation(pipe, device)
cached_time = time.time() - t0
cached_image = image
cached_image.save("dit_simple_teacache.png")
print("SimpleTeaCache image:")
display(cached_image)
print(f"[SimpleTeaCache] time = {cached_time:.2f}s")
print(f"full_compute_steps = {state.full_compute_steps}")
print(f"cached_steps = {state.cached_steps}")
print(f"speedup = {baseline_time / cached_time:.2f}x")
if __name__ == "__main__":
main()
输出:
[Baseline] time = 1.42s
[SimpleTeaCache] time = 0.57s
full_compute_steps = 9
cached_steps = 16
speedup = 2.48x
对比
| threshold | time(s) | speedup | full | cached | 图 |
|---|---|---|---|---|---|
| 0.1 | 0.59 | 2.80x | 10 | 15 | ![]() |
| 0.2 | 0.68 | 2.43x | 9 | 16 | |
| 0.3 | 0.42 | 3.89x | 7 | 18 | |
| 0.4 | 0.40 | 4.12x | 6 | 19 | |
| 0.5 | 0.42 | 3.98x | 6 | 19 | |
| 0.6 | 0.39 | 4.19x | 6 | 19 | |
| 0.7 | 0.38 | 4.38x | 5 | 20 | |
| 0.8 | 0.35 | 4.77x | 5 | 20 | |
| 0.9 | 0.33 | 4.94x | 5 | 20 | |
| 1.0 | 0.36 | 4.63x | 5 | 20 | ![]() |
如何调参?
先用:
rel_l1_thresh = 0.20
然后逐步尝试:
0.10:更稳,速度提升较小
0.20:均衡起点
0.30:更快,但可能有质量波动
0.50:偏激进,容易出画质问题
如果发现生成图像细节变差、主体变形、纹理糊,先把阈值降回:
rel_l1_thresh = 0.10
7. 工程踩坑清单
坑 1:把 TeaCache 当成 LLM KV Cache
LLM KV Cache 缓存的是 token 历史的 key/value。
TeaCache 缓存的是 diffusion denoising step 之间的中间输出 / residual。
二者不是一个东西。
坑 2:rel_l1_thresh 太大
表现:
图像细节糊
视频运动不稳定
人物脸部漂移
文字生成质量下降
编辑任务不稳定
解决:
cache_config = {"rel_l1_thresh": 0.1}
vLLM-Omni 文档在质量下降场景下也建议降低 threshold,使缓存更保守。(vLLM)
坑 3:步数太少时加速不明显
TeaCache 需要足够多的 denoising steps 才有跳过空间。
vLLM-Omni 文档也提到,非常短的推理过程,例如小于 20 steps,缓存开销可能抵消收益;如果加速低于预期,建议使用足够多的 inference steps,例如 35+。(vLLM)
坑 4:coefficients 不能乱迁移
官方 TeaCache 里有模型相关的多项式 coefficients。不同模型的 timestep embedding、Transformer 结构、scheduler 都可能不同。
TeaCache 官方仓库也提示:结构相近的模型可以尝试迁移 coefficients,否则需要参考已有适配或重新适配。(GitHub)
坑 5:服务化时 cache state 必须按请求隔离
不要在多用户服务里把这些状态做成全局变量:
previous_residual
previous_input
accumulated_rel_l1
step_idx
否则请求 A 的 cache 可能污染请求 B。
正确做法:
每个 request 独立 TeaCacheState
请求结束后 reset
CFG cond/uncond 分支分别维护 cache
多线程/异步场景避免共享状态
8. 总结
TeaCache 的本质可以压缩成一句话:
利用 timestep embedding 感知相邻 denoising step 的输出变化;变化小时复用缓存,变化大时完整计算。
它最适合 DiT 类图像/视频/音频扩散模型,尤其是 denoising steps 较多、Transformer 计算占主要瓶颈的场景。生产里建议从 rel_l1_thresh=0.2 开始,逐步调大,同时用固定 prompt、seed、分辨率和 steps 对比不开缓存、保守缓存、激进缓存三组结果。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)