AI 推理性能调优:KV Cache 优化与批处理策略的工程实战

cover

一、推理延迟与吞吐量的双重瓶颈

在 AI 推理服务的生产部署中,性能瓶颈通常出现在两个维度:首 Token 延迟(Time to First Token, TTFT)和吞吐量(Tokens per Second, TPS)。TTFT 决定了用户感知的响应速度,TPS 决定了单卡能承载的并发请求数。二者往往互相制约——提升吞吐量的批处理策略会增加 TTFT,降低 TTFT 的逐请求处理会牺牲吞吐量。

实测数据表明,一个 7B 模型在 A100 GPU 上,逐请求推理的 TPS 约 50 tokens/s,TTFT 约 100ms;开启批处理后 TPS 可提升到 200 tokens/s,但 TTFT 增加到 300~500ms。如何在延迟和吞吐之间找到最优平衡点,是推理服务性能调优的核心命题。

二、KV Cache 与批处理的协同优化机制

KV Cache 和批处理是推理优化的两大支柱,它们在显存管理和计算调度上相互影响。

graph TB
    A[推理请求到达] --> B{批处理调度器}
    B -->|等待凑批| C[请求队列]
    C -->|超时或满批| D[批量 Prefill 阶段]
    D --> E[生成 KV Cache]
    E --> F[批量 Decode 阶段]
    F --> G{所有请求完成?}
    G -->|否| F
    G -->|是| H[返回结果]

    subgraph 显存管理
        I[PagedAttention] --> J[虚拟内存分页]
        J --> K[按需分配 KV Block]
        K --> L[空闲 Block 回收]
    end

    E --> I
    F --> I

KV Cache 的显存占用分析:以 Qwen2.5-7B 为例,28 层 Transformer,GQA 配置下 4 组 KV 头,每头维度 128。每个 Token 的 KV Cache 大小为 2 × 4 × 128 × 28 × 2 bytes = 57,344 bytes ≈ 56 KB。128K 上下文下,单个请求的 KV Cache 约 7 GB。A100 的 80GB 显存中,模型参数占用约 14GB,剩余 66GB 最多同时承载约 9 个 128K 请求。

批处理的调度策略:连续批处理(Continuous Batching)是当前主流方案,它允许新请求在运行中的批次中加入、已完成请求随时退出,避免了静态批处理中"等最慢请求"的浪费。

三、推理服务优化的工程实现

3.1 PagedAttention 显存管理器

from dataclasses import dataclass, field
from typing import Optional
import logging

logger = logging.getLogger(__name__)


@dataclass
class KVBlock:
    """KV Cache 的物理块,固定大小"""
    block_id: int
    size_tokens: int = 16  # 每个块存储 16 个 Token 的 KV
    ref_count: int = 0     # 引用计数,支持 Copy-on-Write


class PagedKVCacheManager:
    """基于 PagedAttention 的 KV Cache 管理器"""

    def __init__(self, total_blocks: int, block_size: int = 16):
        self._block_size = block_size
        self._free_blocks: list[KVBlock] = [
            KVBlock(block_id=i, size_tokens=block_size)
            for i in range(total_blocks)
        ]
        self._used_blocks: dict[str, list[KVBlock]] = {}  # request_id -> blocks
        self._total_blocks = total_blocks

    def allocate(self, request_id: str, num_tokens: int) -> list[KVBlock]:
        """为请求分配 KV Cache 块"""
        num_blocks_needed = (num_tokens + self._block_size - 1) // self._block_size
        available = len(self._free_blocks)

        if available < num_blocks_needed:
            # 显存不足时,尝试驱逐最早未使用的请求
            freed = self._evict_oldest(num_blocks_needed - available)
            if len(freed) < num_blocks_needed - available:
                raise MemoryError(
                    f"KV Cache 显存不足: 需要 {num_blocks_needed} 块, "
                    f"可用 {available + len(freed)} 块"
                )

        allocated = []
        for _ in range(num_blocks_needed):
            block = self._free_blocks.pop()
            block.ref_count += 1
            allocated.append(block)

        self._used_blocks[request_id] = allocated
        logger.debug(
            f"请求 {request_id} 分配 {num_blocks_needed} 块, "
            f"剩余 {len(self._free_blocks)}/{self._total_blocks}"
        )
        return allocated

    def free(self, request_id: str) -> int:
        """释放请求的 KV Cache 块"""
        blocks = self._used_blocks.pop(request_id, [])
        freed_count = 0
        for block in blocks:
            block.ref_count -= 1
            if block.ref_count <= 0:
                self._free_blocks.append(block)
                freed_count += 1
        return freed_count

    def _evict_oldest(self, num_needed: int) -> list[KVBlock]:
        """驱逐最早的请求以释放块"""
        if not self._used_blocks:
            return []

        # 简化实现:驱逐第一个请求
        oldest_id = next(iter(self._used_blocks))
        logger.warning(f"显存不足,驱逐请求: {oldest_id}")
        return [b for b in self._used_blocks.pop(oldest_id) if b.ref_count <= 1]

    @property
    def utilization(self) -> float:
        """显存利用率"""
        used = self._total_blocks - len(self._free_blocks)
        return used / self._total_blocks if self._total_blocks > 0 else 0.0

3.2 连续批处理调度器

import time
from enum import Enum
from typing import Any


class RequestState(Enum):
    WAITING = "waiting"
    PREFILL = "prefill"
    DECODING = "decoding"
    COMPLETED = "completed"


@dataclass
class InferenceRequest:
    request_id: str
    prompt_tokens: list[int]
    max_output_tokens: int
    state: RequestState = RequestState.WAITING
    generated_tokens: int = 0
    start_time: float = 0.0
    output_tokens: list[int] = field(default_factory=list)


class ContinuousBatchScheduler:
    """连续批处理调度器"""

    def __init__(self, max_batch_size: int = 32,
                 max_waiting_time_ms: int = 50):
        self._max_batch_size = max_batch_size
        self._max_waiting_ms = max_waiting_time_ms
        self._waiting_queue: list[InferenceRequest] = []
        self._running_batch: list[InferenceRequest] = []

    def add_request(self, request: InferenceRequest) -> None:
        """将请求加入等待队列"""
        request.start_time = time.monotonic()
        request.state = RequestState.WAITING
        self._waiting_queue.append(request)

    def schedule(self) -> list[InferenceRequest]:
        """调度下一批请求,返回当前应执行的请求列表"""
        # 移除已完成的请求
        self._running_batch = [
            r for r in self._running_batch
            if r.state != RequestState.COMPLETED
        ]

        # 计算可用槽位
        available_slots = self._max_batch_size - len(self._running_batch)
        if available_slots <= 0:
            return self._running_batch

        # 从等待队列中取请求加入批次
        batch_ready = []
        while available_slots > 0 and self._waiting_queue:
            # 检查是否达到最大等待时间
            if not batch_ready:
                oldest_wait = (time.monotonic() - self._waiting_queue[0].start_time) * 1000
                if oldest_wait < self._max_waiting_ms and len(batch_ready) < 4:
                    # 等待更多请求凑批,但不超过最小凑批数
                    break

            request = self._waiting_queue.pop(0)
            request.state = RequestState.PREFILL
            batch_ready.append(request)
            available_slots -= 1

        self._running_batch.extend(batch_ready)
        return self._running_batch

    def mark_prefill_done(self, request_id: str) -> None:
        """标记请求 Prefill 完成,进入 Decode 阶段"""
        for r in self._running_batch:
            if r.request_id == request_id:
                r.state = RequestState.DECODING
                break

    def mark_completed(self, request_id: str) -> None:
        """标记请求完成"""
        for r in self._running_batch:
            if r.request_id == request_id:
                r.state = RequestState.COMPLETED
                break

    @property
    def stats(self) -> dict:
        return {
            "waiting": len(self._waiting_queue),
            "running": len(self._running_batch),
            "batch_utilization": (
                len(self._running_batch) / self._max_batch_size
                if self._max_batch_size > 0 else 0
            ),
        }

3.3 性能基准测试框架

import asyncio
import statistics
import time
from typing import Callable


class InferenceBenchmark:
    """推理性能基准测试工具"""

    def __init__(self, inference_fn: Callable):
        self._inference_fn = inference_fn
        self._results: list[dict] = []

    async def run(self, prompts: list[str], concurrency: int = 1) -> dict:
        """运行基准测试"""
        semaphore = asyncio.Semaphore(concurrency)

        async def benchmark_single(prompt: str) -> dict:
            async with semaphore:
                start = time.monotonic()
                result = await self._inference_fn(prompt)
                end = time.monotonic()
                return {
                    "latency_ms": (end - start) * 1000,
                    "output_tokens": result.get("tokens", 0),
                    "ttft_ms": result.get("ttft_ms", 0),
                }

        tasks = [benchmark_single(p) for p in prompts]
        self._results = await asyncio.gather(*tasks)

        latencies = [r["latency_ms"] for r in self._results]
        ttfts = [r["ttft_ms"] for r in self._results]
        total_tokens = sum(r["output_tokens"] for r in self._results)
        total_time = max(latencies) / 1000  # 总耗时取最长请求

        return {
            "concurrency": concurrency,
            "total_requests": len(prompts),
            "total_tokens": total_tokens,
            "total_time_s": round(total_time, 2),
            "throughput_tps": round(total_tokens / total_time, 1),
            "latency_p50_ms": round(statistics.median(latencies), 1),
            "latency_p99_ms": round(sorted(latencies)[int(len(latencies) * 0.99)], 1),
            "ttft_p50_ms": round(statistics.median(ttfts), 1),
            "ttft_p99_ms": round(sorted(ttfts)[int(len(ttfts) * 0.99)], 1),
        }

四、推理优化的工程权衡

批处理大小与延迟的矛盾:批处理越大,GPU 利用率越高,吞吐量越大,但 TTFT 也越高。对于实时对话场景,TTFT 超过 500ms 用户就能感知到延迟。建议根据场景设置不同的批处理策略:实时场景最大等待 50ms、最小凑批 4 个请求;离线场景最大等待 200ms、最小凑批 16 个请求。

KV Cache 显存与并发数的博弈:更大的 KV Cache 意味着更长的上下文支持,但也意味着更少的并发请求。在 A100 80GB 上,128K 上下文最多 9 个并发,4K 上下文可达 200+ 并发。建议根据实际业务分布动态调整——大部分请求的上下文远小于最大值,可以为短上下文请求分配更小的 Cache。

Prefill 与 Decode 的资源竞争:Prefill 阶段是计算密集型(需要处理整个 Prompt),Decode 阶段是显存带宽密集型(逐 Token 生成)。二者在同一 GPU 上运行时会互相干扰。高端方案是使用分离式推理(Disaggregated Serving),Prefill 和 Decode 分别在不同 GPU 上执行,通过高速网络传输 KV Cache。

量化对推理性能的影响:INT8 量化可以将推理吞吐量提升 30%~50%,但精度损失在复杂推理任务上可能达到 2%~5%。建议对延迟敏感但精度要求适中的场景(如对话、摘要)使用量化,对精度要求严格的场景(如代码生成、数学推理)保持 FP16。

五、总结

AI 推理性能调优的核心是在延迟和吞吐之间找到业务场景的最优平衡点。KV Cache 优化通过 PagedAttention 实现显存的高效管理,连续批处理通过动态调度提升 GPU 利用率,两者协同决定了推理服务的整体性能。在工程落地时,需要根据场景特征配置不同的调度策略:实时场景优先保证 TTFT,离线场景优先最大化吞吐。性能调优不是一次性工作,而是持续的过程——需要建立基准测试框架,在每次配置变更后量化评估效果。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐