扩散模型生成图片/视频时,本质是在很多个 denoising step 中反复调用 Transformer/DiT。TeaCache 的核心思想很简单:

如果当前 step 和上一次完整计算的 step 足够相似,就不重新完整跑 Transformer,而是复用上一次缓存的 residual / 输出近似。

它不是 LLM 的 KV Cache,也不是缓存最终图片,而是缓存扩散去噪过程中的中间计算结果。
TeaCache 论文将其称为 Timestep Embedding Aware Cache,即利用 timestep embedding 估计不同 step 之间模型输出变化,从而决定是否缓存和复用。论文报告在 Open-Sora-Plan 上最高获得 4.41x 加速,VBench 质量分数仅下降 0.07%。(arXiv)

在这里插入图片描述


1. 为什么 TeaCache 能加速?

扩散模型每一步大概都在做:

latent_t + timestep_t + prompt_condition
        ↓
DiT / Transformer
        ↓
预测噪声 / velocity / residual
        ↓
scheduler 更新 latent

最耗时的通常是中间的 DiT / Transformer blocks,包括 Attention、MLP、Norm、Residual 等。

普通推理是:

step 50:完整计算 Transformer
step 49:完整计算 Transformer
step 48:完整计算 Transformer
...
step 1 :完整计算 Transformer

TeaCache 的做法是:

step 50:完整计算 Transformer,缓存 residual
step 49:判断变化小,复用 step 50 的 residual
step 48:判断变化小,继续复用
step 47:变化累计变大,重新完整计算并更新缓存

在这里插入图片描述

vLLM-Omni 官方文档也将 TeaCache 描述为:当连续 timestep 足够相似时缓存 Transformer 计算,从而实现约 1.5x–2.0x 加速,并通过输入相似性动态判断是否复用缓存。(vLLM)


2. TeaCache 判断“能不能跳过”的依据

TeaCache 不会直接比较完整模型输出,因为如果已经完整跑了一次模型,那就没有加速意义了。

它使用一个更便宜的代理量:

timestep embedding 调制后的 noisy input

然后比较当前 step 和上一次完整计算 step 的差异:

rel_l1 = mean(abs(current_modulated_input - previous_modulated_input)) \
         / mean(abs(previous_modulated_input))

再通过模型相关的多项式系数做 rescale,估计真实输出差异。
如果累计差异低于阈值,就复用缓存;如果超过阈值,就重新完整计算。
TeaCache 论文明确指出,它不直接使用耗时的模型输出差异,而是利用与模型输出强相关、但计算成本很低的模型输入差异来判断缓存时机。(arXiv)


3. 普通推理 vs TeaCache 推理

在这里插入图片描述

TeaCache 省掉的主要是:

部分 denoising step 里的 Transformer 主体计算

它通常不会省掉

scheduler 更新
VAE decode
text encoder
prompt 编码
CPU/GPU 数据搬运

所以如果你的端到端瓶颈主要在 VAE、CPU offload 或 IO,那么 TeaCache 的实际加速会低于理论加速。


4. 最关键参数:rel_l1_thresh

TeaCache 最重要的参数是:

rel_l1_thresh

它控制缓存复用的激进程度:

阈值越小:更保守,完整计算更多,质量更稳,速度提升较小
阈值越大:更激进,缓存复用更多,速度更快,质量风险更高

vLLM-Omni 文档中 rel_l1_thresh 默认值是 0.2,建议范围是 0.1–0.8;低值优先质量,高值优先速度。(vLLM)

在这里插入图片描述

建议初始设置:

质量优先:0.10 ~ 0.20
均衡配置:0.20 ~ 0.40
速度优先:0.50 ~ 0.80

生产环境不要一上来拉到 0.8。更稳的方式是:

0.2 → 0.3 → 0.4 → 对比质量和耗时

5. 适用场景

TeaCache 更适合:

1. DiT / Transformer-based diffusion 模型
2. 图像生成、视频生成、音频扩散生成
3. denoising steps 较多的推理任务
4. 对 1.5x~2x 加速有价值,同时能容忍极小质量波动的生产服务
5. 单卡加速场景

vLLM-Omni 官方文档也建议 TeaCache 用于需要更快推理、且能容忍极小质量损失的生产场景;不太适合极致画质要求或非常短步数推理,例如小于 20 steps 的情况。(vLLM)

不太适合:

1. 4~8 step 的蒸馏模型
2. 强编辑、强控制、强文字生成场景
3. 对画质零损失要求极高的任务
4. 主要瓶颈不在 Transformer 的 pipeline

6. vLLM-Omni 正式用法

如果你用的是 vLLM-Omni,并且模型后端支持 TeaCache,可以直接这样开:

from vllm_omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams

omni = Omni(
    model="Qwen/Qwen-Image",
    cache_backend="tea_cache",
)

outputs = omni.generate(
    "A cat sitting on a windowsill",
    OmniDiffusionSamplingParams(num_inference_steps=50),
)

在线服务方式:

vllm serve Qwen/Qwen-Image --omni --port 8091 \
  --cache-backend tea_cache \
  --cache-config '{"rel_l1_thresh": 0.2}'

这些参数写法来自 vLLM-Omni TeaCache 官方文档。(vLLM)

使用 facebook/DiT-XL-2-256 重点演示“根据相邻 step 输入差异决定是否复用 residual”。
它不是官方 TeaCache 的完整实现,因为官方 TeaCache 会使用 timestep embedding 调制输入、模型专属 coefficients、多项式 rescale 等细节。

安装环境

pip install -U diffusers transformers accelerate safetensors scipy pillow

代码

import gc
import time
import types

import torch
from IPython.display import display
from diffusers import DiTPipeline, DPMSolverMultistepScheduler
from diffusers.models.modeling_outputs import Transformer2DModelOutput


MODEL_ID = "facebook/DiT-XL-2-256"
NUM_STEPS = 25
CLASS_ID = 207  # ImageNet class id.


class SimpleTeaCacheState:
    """Minimal TeaCache-style state for DiT transformer calls."""

    def __init__(self, rel_l1_thresh=0.20, num_steps=25):
        self.rel_l1_thresh = rel_l1_thresh
        self.num_steps = num_steps
        self.reset()

    def reset(self):
        self.step_idx = 0
        self.accumulated_rel_l1 = 0.0
        self.previous_input = None
        self.previous_residual = None
        self.previous_sample = None
        self.full_compute_steps = 0
        self.cached_steps = 0

    @torch.no_grad()
    def should_compute(self, hidden_states: torch.Tensor) -> bool:
        is_first = self.step_idx == 0
        is_last = self.step_idx >= self.num_steps - 1

        if is_first or is_last:
            self.previous_input = hidden_states.detach().float()
            self.accumulated_rel_l1 = 0.0
            return True

        if self.previous_input is None or self.previous_sample is None:
            self.previous_input = hidden_states.detach().float()
            return True

        current = hidden_states.detach().float()
        previous = self.previous_input

        denom = previous.abs().mean().clamp_min(1e-6)
        rel_l1 = (current - previous).abs().mean() / denom
        self.accumulated_rel_l1 += float(rel_l1.item())

        if self.accumulated_rel_l1 < self.rel_l1_thresh:
            return False

        self.accumulated_rel_l1 = 0.0
        self.previous_input = current
        return True

    def next_step(self):
        self.step_idx += 1


def clear_cuda_cache():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()


def load_dit_pipeline(device):
    dtype = torch.float16 if device == "cuda" else torch.float32
    pipe = DiTPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype,
        token=False,
    )
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    return pipe.to(device)


def enable_simple_teacache_for_dit(pipe, rel_l1_thresh=0.20, num_steps=25):
    state = SimpleTeaCacheState(
        rel_l1_thresh=rel_l1_thresh,
        num_steps=num_steps,
    )

    transformer = pipe.transformer
    original_forward = transformer.forward

    def cached_forward(self, hidden_states, timestep, class_labels=None, **kwargs):
        compute_full = state.should_compute(hidden_states)

        if not compute_full and state.previous_sample is not None:
            if state.previous_residual is not None and state.previous_residual.shape == hidden_states.shape:
                sample = hidden_states + state.previous_residual.to(
                    device=hidden_states.device,
                    dtype=hidden_states.dtype,
                )
            else:
                sample = state.previous_sample.to(
                    device=hidden_states.device,
                    dtype=hidden_states.dtype,
                )
            state.cached_steps += 1
            state.next_step()
            return Transformer2DModelOutput(sample=sample)

        out = original_forward(
            hidden_states=hidden_states,
            timestep=timestep,
            class_labels=class_labels,
            **kwargs,
        )

        state.previous_sample = out.sample.detach()
        if out.sample.shape == hidden_states.shape:
            state.previous_residual = out.sample.detach() - hidden_states.detach()
        else:
            state.previous_residual = None
            print("Warning: output sample shape differs from input hidden_states shape; cannot compute residual for caching.")
        state.full_compute_steps += 1
        state.next_step()
        return out

    transformer.forward = types.MethodType(cached_forward, transformer)
    return state


def run_generation(pipe, device):
    generator = torch.Generator(device=device).manual_seed(42)
    return pipe(
        class_labels=[CLASS_ID],
        num_inference_steps=NUM_STEPS,
        generator=generator,
    ).images[0]


def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"device = {device}")

    pipe = load_dit_pipeline(device)

    clear_cuda_cache()
    t0 = time.time()
    image = run_generation(pipe, device)
    baseline_time = time.time() - t0
    baseline_image = image
    baseline_image.save("dit_baseline.png")
    print("Baseline image:")
    display(baseline_image)
    print(f"[Baseline] time = {baseline_time:.2f}s")

    del pipe
    clear_cuda_cache()

    pipe = load_dit_pipeline(device)
    state = enable_simple_teacache_for_dit(
        pipe,
        rel_l1_thresh=0.20,
        num_steps=NUM_STEPS,
    )

    clear_cuda_cache()
    t0 = time.time()
    image = run_generation(pipe, device)
    cached_time = time.time() - t0
    cached_image = image
    cached_image.save("dit_simple_teacache.png")
    print("SimpleTeaCache image:")
    display(cached_image)

    print(f"[SimpleTeaCache] time = {cached_time:.2f}s")
    print(f"full_compute_steps = {state.full_compute_steps}")
    print(f"cached_steps       = {state.cached_steps}")
    print(f"speedup            = {baseline_time / cached_time:.2f}x")


if __name__ == "__main__":
    main()

输出:

[Baseline] time = 1.42s
[SimpleTeaCache] time = 0.57s
full_compute_steps = 9
cached_steps       = 16
speedup            = 2.48x

对比

threshold time(s) speedup full cached
0.1 0.59 2.80x 10 15 在这里插入图片描述
0.2 0.68 2.43x 9 16
0.3 0.42 3.89x 7 18
0.4 0.40 4.12x 6 19
0.5 0.42 3.98x 6 19
0.6 0.39 4.19x 6 19
0.7 0.38 4.38x 5 20
0.8 0.35 4.77x 5 20
0.9 0.33 4.94x 5 20
1.0 0.36 4.63x 5 20 在这里插入图片描述

如何调参?

先用:

rel_l1_thresh = 0.20

然后逐步尝试:

0.10:更稳,速度提升较小
0.20:均衡起点
0.30:更快,但可能有质量波动
0.50:偏激进,容易出画质问题

如果发现生成图像细节变差、主体变形、纹理糊,先把阈值降回:

rel_l1_thresh = 0.10

7. 工程踩坑清单

坑 1:把 TeaCache 当成 LLM KV Cache

LLM KV Cache 缓存的是 token 历史的 key/value。

TeaCache 缓存的是 diffusion denoising step 之间的中间输出 / residual。

二者不是一个东西。


坑 2:rel_l1_thresh 太大

表现:

图像细节糊
视频运动不稳定
人物脸部漂移
文字生成质量下降
编辑任务不稳定

解决:

cache_config = {"rel_l1_thresh": 0.1}

vLLM-Omni 文档在质量下降场景下也建议降低 threshold,使缓存更保守。(vLLM)


坑 3:步数太少时加速不明显

TeaCache 需要足够多的 denoising steps 才有跳过空间。
vLLM-Omni 文档也提到,非常短的推理过程,例如小于 20 steps,缓存开销可能抵消收益;如果加速低于预期,建议使用足够多的 inference steps,例如 35+。(vLLM)


坑 4:coefficients 不能乱迁移

官方 TeaCache 里有模型相关的多项式 coefficients。不同模型的 timestep embedding、Transformer 结构、scheduler 都可能不同。

TeaCache 官方仓库也提示:结构相近的模型可以尝试迁移 coefficients,否则需要参考已有适配或重新适配。(GitHub)


坑 5:服务化时 cache state 必须按请求隔离

不要在多用户服务里把这些状态做成全局变量:

previous_residual
previous_input
accumulated_rel_l1
step_idx

否则请求 A 的 cache 可能污染请求 B。

正确做法:

每个 request 独立 TeaCacheState
请求结束后 reset
CFG cond/uncond 分支分别维护 cache
多线程/异步场景避免共享状态

8. 总结

TeaCache 的本质可以压缩成一句话:

利用 timestep embedding 感知相邻 denoising step 的输出变化;变化小时复用缓存,变化大时完整计算。

它最适合 DiT 类图像/视频/音频扩散模型,尤其是 denoising steps 较多、Transformer 计算占主要瓶颈的场景。生产里建议从 rel_l1_thresh=0.2 开始,逐步调大,同时用固定 prompt、seed、分辨率和 steps 对比不开缓存、保守缓存、激进缓存三组结果。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐