AI 推理性能调优:推理引擎选型与批处理策略的工程实战
AI 推理性能调优:推理引擎选型与批处理策略的工程实战

一、推理延迟的构成:为什么模型推理总是"慢在等"
大模型推理的延迟由三个阶段构成:Prefill(预填充)、Decode(解码)和调度开销。Prefill 阶段处理输入 Prompt 的所有 Token,计算 KV Cache,延迟与输入长度线性相关。Decode 阶段逐 Token 生成输出,每一步都需要读取完整的 KV Cache,延迟与输出长度和 KV Cache 大小相关。调度开销则来自请求排队、批处理组建和 GPU 内存管理。
在生产环境中,推理延迟的瓶颈往往不在模型计算本身,而在调度策略。当多个请求并发到达时,推理引擎需要决定如何将它们组织成 Batch 送入 GPU。简单的"等够 N 个请求再执行"策略会导致短请求被长请求拖慢;"来一个执行一个"的策略则浪费了 GPU 的并行计算能力。批处理策略的优劣,直接决定了推理服务的吞吐量和尾延迟。
graph LR
A[请求队列] --> B[调度器]
B --> C1[Batch 1<br/>3 个请求]
B --> C2[Batch 2<br/>5 个请求]
B --> C3[Batch 3<br/>2 个请求]
C1 --> D[GPU 执行]
C2 --> D
C3 --> D
D --> E1[请求1 输出]
D --> E2[请求2 输出]
D --> E3[请求3 输出]
F[连续批处理<br/>Continuous Batching] --> B
F -->|请求完成即释放| G[KV Cache 回收]
G -->|新请求即时加入| B
style F fill:#e1f5fe
style G fill:#e8f5e9
二、推理引擎的选型对比:vLLM、TensorRT-LLM 与 TGI
当前主流的开源推理引擎有三个:vLLM、TensorRT-LLM 和 TGI(Text Generation Inference)。选型需要从性能、易用性、模型覆盖度和社区活跃度四个维度评估。
vLLM 的核心优势是 PagedAttention 技术,通过分页管理 KV Cache 显存,将显存利用率从传统的 20-40% 提升到 90% 以上。配合 Continuous Batching,vLLM 在高并发场景下的吞吐量显著优于静态批处理方案。vLLM 的劣势是对新模型的支持有时滞后于 Hugging Face,且在极低延迟场景下不如 TensorRT-LLM。
TensorRT-LLM 是 NVIDIA 推出的推理加速方案,通过算子融合、Kernel 自动调优和 FP8 量化实现极致性能。在 A100/H100 GPU 上,TensorRT-LLM 的推理速度通常比 vLLM 快 30-50%。但代价是工程复杂度极高——模型需要先编译为 TensorRT Engine,编译过程耗时且对环境敏感,模型更新后需要重新编译。
TGI 是 Hugging Face 推出的推理服务,优势在于与 Hugging Face 生态的无缝集成,支持 Flash Attention、量化、水印等功能。TGI 的性能介于 vLLM 和 TensorRT-LLM 之间,但在易用性上最优——从模型下载到推理服务上线,一条命令即可完成。
| 引擎 | 吞吐量 | 延迟 | 易用性 | 模型覆盖 |
|---|---|---|---|---|
| vLLM | 高 | 中 | 中 | 广 |
| TensorRT-LLM | 最高 | 最低 | 低 | 中 |
| TGI | 中高 | 中 | 高 | 最广 |
三、批处理策略的代码实现
以下实现展示了 Continuous Batching 的核心调度逻辑,以及基于 Token 预算的批处理组建策略。
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional
from collections import deque
import time
class RequestStatus(Enum):
QUEUED = "queued"
PREFILLING = "prefilling"
DECODING = "decoding"
COMPLETED = "completed"
FAILED = "failed"
@dataclass
class InferenceRequest:
"""推理请求"""
request_id: str
prompt_tokens: list[int] # 输入 Token IDs
max_output_tokens: int # 最大输出长度
temperature: float = 0.7
status: RequestStatus = RequestStatus.QUEUED
output_tokens: list[int] = field(default_factory=list)
kv_cache_pages: list[int] = field(default_factory=list) # 分配的 KV Cache 页
arrival_time: float = field(default_factory=time.time)
start_time: Optional[float] = None
completion_time: Optional[float] = None
@dataclass
class BatchConfig:
"""批处理配置"""
max_batch_size: int = 32 # 最大批大小
max_tokens_per_batch: int = 8192 # 每批最大 Token 数(含输入+输出)
max_waiting_time_ms: int = 50 # 最大等待时间(毫秒)
kv_cache_page_size: int = 16 # 每页 Token 数
class ContinuousBatchScheduler:
"""连续批处理调度器:请求完成即释放,新请求即时加入"""
def __init__(self, config: BatchConfig, total_kv_pages: int):
self.config = config
self.total_kv_pages = total_kv_pages
self.available_pages = total_kv_pages
self.waiting_queue: deque[InferenceRequest] = deque()
self.running_requests: list[InferenceRequest] = []
def add_request(self, request: InferenceRequest):
"""添加推理请求到等待队列"""
self.waiting_queue.append(request)
def schedule(self) -> list[InferenceRequest]:
"""调度一批请求:从等待队列中选择可执行的请求"""
batch = list(self.running_requests) # 保留正在解码的请求
current_tokens = self._count_batch_tokens(batch)
# 尝试从等待队列中加入新请求
while self.waiting_queue:
# 检查批大小限制
if len(batch) >= self.config.max_batch_size:
break
# 检查 Token 预算
next_req = self.waiting_queue[0]
required_tokens = len(next_req.prompt_tokens) + next_req.max_output_tokens
if current_tokens + required_tokens > self.config.max_tokens_per_batch:
# Token 预算不足,检查是否可以加入更小的请求
# 避免长请求阻塞短请求
skipped = self._try_find_smaller_request(current_tokens)
if skipped:
batch.append(skipped)
current_tokens += len(skipped.prompt_tokens) + skipped.max_output_tokens
continue
break
# 检查 KV Cache 页是否足够
required_pages = self._estimate_kv_pages(next_req)
if required_pages > self.available_pages:
break # 显存不足,等待运行中请求释放
# 加入批处理
req = self.waiting_queue.popleft()
req.status = RequestStatus.PREFILLING
req.start_time = time.time()
allocated = self._allocate_kv_pages(required_pages)
req.kv_cache_pages = allocated
batch.append(req)
current_tokens += required_tokens
self.running_requests = batch
return batch
def on_step_complete(self, completed_ids: list[str]) -> list[InferenceRequest]:
"""处理一步解码完成后的状态更新"""
completed = []
for req in self.running_requests:
if req.request_id in completed_ids:
# 请求完成:释放 KV Cache 页
self._release_kv_pages(req.kv_cache_pages)
req.status = RequestStatus.COMPLETED
req.completion_time = time.time()
completed.append(req)
else:
# 请求仍在解码:检查是否超过最大输出长度
if len(req.output_tokens) >= req.max_output_tokens:
self._release_kv_pages(req.kv_cache_pages)
req.status = RequestStatus.COMPLETED
req.completion_time = time.time()
completed.append(req)
else:
req.status = RequestStatus.DECODING
# 从运行列表中移除已完成的请求
self.running_requests = [
r for r in self.running_requests
if r.request_id not in {c.request_id for c in completed}
]
return completed
def _count_batch_tokens(self, batch: list[InferenceRequest]) -> int:
"""统计批处理中的总 Token 数"""
total = 0
for req in batch:
total += len(req.prompt_tokens) + len(req.output_tokens)
return total
def _estimate_kv_pages(self, req: InferenceRequest) -> int:
"""估算请求所需的 KV Cache 页数"""
total_tokens = len(req.prompt_tokens) + req.max_output_tokens
pages = (total_tokens + self.config.kv_cache_page_size - 1) // self.config.kv_cache_page_size
return pages
def _allocate_kv_pages(self, count: int) -> list[int]:
"""分配 KV Cache 页"""
if count > self.available_pages:
return []
self.available_pages -= count
# 简化实现:返回页 ID 列表
return list(range(count))
def _release_kv_pages(self, pages: list[int]):
"""释放 KV Cache 页"""
self.available_pages += len(pages)
def _try_find_smaller_request(self, current_tokens: int) -> Optional[InferenceRequest]:
"""在等待队列中寻找 Token 预算允许的较小请求"""
for i, req in enumerate(self.waiting_queue):
required = len(req.prompt_tokens) + req.max_output_tokens
if current_tokens + required <= self.config.max_tokens_per_batch:
return self.waiting_queue[i]
return None
四、推理性能优化的边界与权衡
吞吐量与延迟的矛盾。 大 Batch 提升吞吐量但增加单请求延迟(排队等待时间长),小 Batch 降低延迟但浪费 GPU 算力。需要根据业务 SLA 选择平衡点:对延迟敏感的场景(如对话),限制最大批大小和等待时间;对吞吐敏感的场景(如批量处理),增大批大小。
量化精度与推理质量。 INT8 量化通常只损失 1-2% 的精度,INT4 量化可能损失 5-10%。量化选择需要基于业务对精度的容忍度。对于代码生成场景,INT4 量化可能导致变量名拼写错误,不可接受;对于文本摘要场景,INT4 量化的精度损失通常可接受。
KV Cache 的显存瓶颈。 在长上下文场景下,KV Cache 占用的显存可能超过模型权重。例如,Llama-3-70B 在 128K 上下文长度下,单个请求的 KV Cache 就需要约 8GB 显存。PagedAttention 缓解了显存碎片问题,但不减少总显存需求。长上下文场景下,需要配合 KV Cache 压缩(如稀疏注意力)或卸载到 CPU 内存。
引擎锁定的风险。 不同推理引擎的 API 和优化策略差异很大,切换引擎需要重新适配。建议在应用层建立统一的推理接口,将引擎选择作为可替换的实现细节,而非硬编码依赖。
| 优化维度 | 收益 | 代价 |
|---|---|---|
| Continuous Batching | 高吞吐 | 调度逻辑复杂 |
| 量化(INT8/INT4) | 降低显存和延迟 | 精度损失 |
| PagedAttention | 高显存利用率 | 页管理开销 |
| Speculative Decoding | 降低解码延迟 | 额外计算开销 |
五、总结
AI 推理性能调优的核心在于理解延迟的构成,并在吞吐量与延迟之间找到业务场景的最优平衡点。推理引擎选型需要权衡性能、易用性和模型覆盖度。Continuous Batching 是高并发场景的必备策略,PagedAttention 是显存优化的基础设施。但优化不是免费的——量化损失精度、批处理增加延迟、引擎切换成本高。
落地路线建议:第一,从 vLLM 起步,验证基础性能后再考虑 TensorRT-LLM 的极致优化;第二,建立推理延迟的 P50/P95/P99 基线,任何优化都必须量化对尾延迟的影响;第三,在应用层建立统一的推理接口,避免引擎锁定。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)