AI 推理性能调优:推理引擎选型与批处理策略的工程实战

cover

一、推理延迟的构成:为什么模型推理总是"慢在等"

大模型推理的延迟由三个阶段构成:Prefill(预填充)、Decode(解码)和调度开销。Prefill 阶段处理输入 Prompt 的所有 Token,计算 KV Cache,延迟与输入长度线性相关。Decode 阶段逐 Token 生成输出,每一步都需要读取完整的 KV Cache,延迟与输出长度和 KV Cache 大小相关。调度开销则来自请求排队、批处理组建和 GPU 内存管理。

在生产环境中,推理延迟的瓶颈往往不在模型计算本身,而在调度策略。当多个请求并发到达时,推理引擎需要决定如何将它们组织成 Batch 送入 GPU。简单的"等够 N 个请求再执行"策略会导致短请求被长请求拖慢;"来一个执行一个"的策略则浪费了 GPU 的并行计算能力。批处理策略的优劣,直接决定了推理服务的吞吐量和尾延迟。

graph LR
    A[请求队列] --> B[调度器]
    B --> C1[Batch 1<br/>3 个请求]
    B --> C2[Batch 2<br/>5 个请求]
    B --> C3[Batch 3<br/>2 个请求]

    C1 --> D[GPU 执行]
    C2 --> D
    C3 --> D

    D --> E1[请求1 输出]
    D --> E2[请求2 输出]
    D --> E3[请求3 输出]

    F[连续批处理<br/>Continuous Batching] --> B
    F -->|请求完成即释放| G[KV Cache 回收]
    G -->|新请求即时加入| B

    style F fill:#e1f5fe
    style G fill:#e8f5e9

二、推理引擎的选型对比:vLLM、TensorRT-LLM 与 TGI

当前主流的开源推理引擎有三个:vLLM、TensorRT-LLM 和 TGI(Text Generation Inference)。选型需要从性能、易用性、模型覆盖度和社区活跃度四个维度评估。

vLLM 的核心优势是 PagedAttention 技术,通过分页管理 KV Cache 显存,将显存利用率从传统的 20-40% 提升到 90% 以上。配合 Continuous Batching,vLLM 在高并发场景下的吞吐量显著优于静态批处理方案。vLLM 的劣势是对新模型的支持有时滞后于 Hugging Face,且在极低延迟场景下不如 TensorRT-LLM。

TensorRT-LLM 是 NVIDIA 推出的推理加速方案,通过算子融合、Kernel 自动调优和 FP8 量化实现极致性能。在 A100/H100 GPU 上,TensorRT-LLM 的推理速度通常比 vLLM 快 30-50%。但代价是工程复杂度极高——模型需要先编译为 TensorRT Engine,编译过程耗时且对环境敏感,模型更新后需要重新编译。

TGI 是 Hugging Face 推出的推理服务,优势在于与 Hugging Face 生态的无缝集成,支持 Flash Attention、量化、水印等功能。TGI 的性能介于 vLLM 和 TensorRT-LLM 之间,但在易用性上最优——从模型下载到推理服务上线,一条命令即可完成。

引擎 吞吐量 延迟 易用性 模型覆盖
vLLM 广
TensorRT-LLM 最高 最低
TGI 中高 最广

三、批处理策略的代码实现

以下实现展示了 Continuous Batching 的核心调度逻辑,以及基于 Token 预算的批处理组建策略。

from dataclasses import dataclass, field
from enum import Enum
from typing import Optional
from collections import deque
import time

class RequestStatus(Enum):
    QUEUED = "queued"
    PREFILLING = "prefilling"
    DECODING = "decoding"
    COMPLETED = "completed"
    FAILED = "failed"

@dataclass
class InferenceRequest:
    """推理请求"""
    request_id: str
    prompt_tokens: list[int]       # 输入 Token IDs
    max_output_tokens: int         # 最大输出长度
    temperature: float = 0.7
    status: RequestStatus = RequestStatus.QUEUED
    output_tokens: list[int] = field(default_factory=list)
    kv_cache_pages: list[int] = field(default_factory=list)  # 分配的 KV Cache 页
    arrival_time: float = field(default_factory=time.time)
    start_time: Optional[float] = None
    completion_time: Optional[float] = None

@dataclass
class BatchConfig:
    """批处理配置"""
    max_batch_size: int = 32          # 最大批大小
    max_tokens_per_batch: int = 8192  # 每批最大 Token 数(含输入+输出)
    max_waiting_time_ms: int = 50     # 最大等待时间(毫秒)
    kv_cache_page_size: int = 16      # 每页 Token 数

class ContinuousBatchScheduler:
    """连续批处理调度器:请求完成即释放,新请求即时加入"""

    def __init__(self, config: BatchConfig, total_kv_pages: int):
        self.config = config
        self.total_kv_pages = total_kv_pages
        self.available_pages = total_kv_pages
        self.waiting_queue: deque[InferenceRequest] = deque()
        self.running_requests: list[InferenceRequest] = []

    def add_request(self, request: InferenceRequest):
        """添加推理请求到等待队列"""
        self.waiting_queue.append(request)

    def schedule(self) -> list[InferenceRequest]:
        """调度一批请求:从等待队列中选择可执行的请求"""
        batch = list(self.running_requests)  # 保留正在解码的请求
        current_tokens = self._count_batch_tokens(batch)

        # 尝试从等待队列中加入新请求
        while self.waiting_queue:
            # 检查批大小限制
            if len(batch) >= self.config.max_batch_size:
                break

            # 检查 Token 预算
            next_req = self.waiting_queue[0]
            required_tokens = len(next_req.prompt_tokens) + next_req.max_output_tokens

            if current_tokens + required_tokens > self.config.max_tokens_per_batch:
                # Token 预算不足,检查是否可以加入更小的请求
                # 避免长请求阻塞短请求
                skipped = self._try_find_smaller_request(current_tokens)
                if skipped:
                    batch.append(skipped)
                    current_tokens += len(skipped.prompt_tokens) + skipped.max_output_tokens
                    continue
                break

            # 检查 KV Cache 页是否足够
            required_pages = self._estimate_kv_pages(next_req)
            if required_pages > self.available_pages:
                break  # 显存不足,等待运行中请求释放

            # 加入批处理
            req = self.waiting_queue.popleft()
            req.status = RequestStatus.PREFILLING
            req.start_time = time.time()
            allocated = self._allocate_kv_pages(required_pages)
            req.kv_cache_pages = allocated
            batch.append(req)
            current_tokens += required_tokens

        self.running_requests = batch
        return batch

    def on_step_complete(self, completed_ids: list[str]) -> list[InferenceRequest]:
        """处理一步解码完成后的状态更新"""
        completed = []

        for req in self.running_requests:
            if req.request_id in completed_ids:
                # 请求完成:释放 KV Cache 页
                self._release_kv_pages(req.kv_cache_pages)
                req.status = RequestStatus.COMPLETED
                req.completion_time = time.time()
                completed.append(req)
            else:
                # 请求仍在解码:检查是否超过最大输出长度
                if len(req.output_tokens) >= req.max_output_tokens:
                    self._release_kv_pages(req.kv_cache_pages)
                    req.status = RequestStatus.COMPLETED
                    req.completion_time = time.time()
                    completed.append(req)
                else:
                    req.status = RequestStatus.DECODING

        # 从运行列表中移除已完成的请求
        self.running_requests = [
            r for r in self.running_requests
            if r.request_id not in {c.request_id for c in completed}
        ]

        return completed

    def _count_batch_tokens(self, batch: list[InferenceRequest]) -> int:
        """统计批处理中的总 Token 数"""
        total = 0
        for req in batch:
            total += len(req.prompt_tokens) + len(req.output_tokens)
        return total

    def _estimate_kv_pages(self, req: InferenceRequest) -> int:
        """估算请求所需的 KV Cache 页数"""
        total_tokens = len(req.prompt_tokens) + req.max_output_tokens
        pages = (total_tokens + self.config.kv_cache_page_size - 1) // self.config.kv_cache_page_size
        return pages

    def _allocate_kv_pages(self, count: int) -> list[int]:
        """分配 KV Cache 页"""
        if count > self.available_pages:
            return []
        self.available_pages -= count
        # 简化实现:返回页 ID 列表
        return list(range(count))

    def _release_kv_pages(self, pages: list[int]):
        """释放 KV Cache 页"""
        self.available_pages += len(pages)

    def _try_find_smaller_request(self, current_tokens: int) -> Optional[InferenceRequest]:
        """在等待队列中寻找 Token 预算允许的较小请求"""
        for i, req in enumerate(self.waiting_queue):
            required = len(req.prompt_tokens) + req.max_output_tokens
            if current_tokens + required <= self.config.max_tokens_per_batch:
                return self.waiting_queue[i]
        return None

四、推理性能优化的边界与权衡

吞吐量与延迟的矛盾。 大 Batch 提升吞吐量但增加单请求延迟(排队等待时间长),小 Batch 降低延迟但浪费 GPU 算力。需要根据业务 SLA 选择平衡点:对延迟敏感的场景(如对话),限制最大批大小和等待时间;对吞吐敏感的场景(如批量处理),增大批大小。

量化精度与推理质量。 INT8 量化通常只损失 1-2% 的精度,INT4 量化可能损失 5-10%。量化选择需要基于业务对精度的容忍度。对于代码生成场景,INT4 量化可能导致变量名拼写错误,不可接受;对于文本摘要场景,INT4 量化的精度损失通常可接受。

KV Cache 的显存瓶颈。 在长上下文场景下,KV Cache 占用的显存可能超过模型权重。例如,Llama-3-70B 在 128K 上下文长度下,单个请求的 KV Cache 就需要约 8GB 显存。PagedAttention 缓解了显存碎片问题,但不减少总显存需求。长上下文场景下,需要配合 KV Cache 压缩(如稀疏注意力)或卸载到 CPU 内存。

引擎锁定的风险。 不同推理引擎的 API 和优化策略差异很大,切换引擎需要重新适配。建议在应用层建立统一的推理接口,将引擎选择作为可替换的实现细节,而非硬编码依赖。

优化维度 收益 代价
Continuous Batching 高吞吐 调度逻辑复杂
量化(INT8/INT4) 降低显存和延迟 精度损失
PagedAttention 高显存利用率 页管理开销
Speculative Decoding 降低解码延迟 额外计算开销

五、总结

AI 推理性能调优的核心在于理解延迟的构成,并在吞吐量与延迟之间找到业务场景的最优平衡点。推理引擎选型需要权衡性能、易用性和模型覆盖度。Continuous Batching 是高并发场景的必备策略,PagedAttention 是显存优化的基础设施。但优化不是免费的——量化损失精度、批处理增加延迟、引擎切换成本高。

落地路线建议:第一,从 vLLM 起步,验证基础性能后再考虑 TensorRT-LLM 的极致优化;第二,建立推理延迟的 P50/P95/P99 基线,任何优化都必须量化对尾延迟的影响;第三,在应用层建立统一的推理接口,避免引擎锁定。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐