长序列大语言模型推理中的显存优化方案:Context Parallel (CP) 深度解析
作者:昇腾实战派
背景概述
在处理长序列的大语言模型推理时,内存和计算资源的瓶颈问题尤为突出。为了有效应对这一挑战,本文将深入探讨 vllm-ascend 项目中实现的显存优化方案——Context Parallel (CP),特别是其子策略 Prefill Context Parallel (PCP) 和 Decode Context Parallel (DCP) 的设计思路、实现细节及适用场景。通过本文,读者将能够更好地理解如何在实际业务中应用这些策略,以实现性能的显著提升。
名词表
| 缩写 | 全称 | 含义 |
|---|---|---|
| TP | Tensor Parallelism | 张量并行 |
| CP | Context Parallelism | 上下文并行 |
| AR / RS / AG | AllReduce / ReduceScatter / AllGather | 集合通信原语 |
| MLA | Multi-head Latent Attention | 深度学习模型中采用的注意力变体 |
| GQA | Grouped Query Attention | 注意力机制的一种变体 |
| MoE | Mixture of Experts | 混合专家模型 |
1. 什么是 Context Parallel (CP)
Context Parallel (CP) 是一种针对长序列推理的并行优化策略,其核心思想是沿着序列维度将长序列的计算任务分散到多个计算设备上。CP 可以分为两个独立的子策略:
-
PCP (Prefill Context Parallel, 预填充上下文并行)
- 目标:加速 Prefill 阶段,降低首次输出时间 (TTFT)。
- 原理:将输入序列切分,不同设备同时计算序列的不同片段。
- 代价:需要引入新的通信域,总设备数 =
tensor_parallel_size × prefill_context_parallel_size。 - 实现:KV cache 沿序列维度分片存储到各设备。
-
DCP (Decode Context Parallel, 解码上下文并行)
- 目标:消除 KV cache 的冗余副本,节省显存,提升 Decode 吞吐量。
- 原理:复用 TP 的通信域,不引入额外计算设备,将原本在 TP 组内重复存储的 KV cache 沿序列维度分片。
1.1 PCP 与 DCP 的整体关系
PCP 和 DCP 可以同时启用,也可以单独使用。它们的关系如下:
cp_size = pcp_size × dcp_size
cp_rank = pcp_rank × dcp_size + dcp_rank
2. CP 实现中的 KV Cache 分片:Block Table
CP 对 KV cache 进行序列维度分片,采用 交错(interleaved)方式 存储 token,交错粒度由参数 cp_kv_cache_interleave_size 控制(默认值为 1,即 token 级交错)。
virtual_block_size = block_size × cp_size
BlockTable.compute_slot_mapping 调用 Triton kernel,传入 TOTAL_CP_WORLD_SIZE = pcp_size * dcp_size 和 TOTAL_CP_RANK = pcp_rank * dcp_size + dcp_rank,将 token 的物理存储位置按交错方式分配到各设备:
# vllm_ascend/worker/block_table.py
total_cp_world_size = self.pcp_world_size * self.dcp_world_size
total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
_compute_slot_mapping_kernel[(num_reqs + 1,)](
...
TOTAL_CP_WORLD_SIZE=total_cp_world_size,
TOTAL_CP_RANK=total_cp_rank,
CP_KV_CACHE_INTERLEAVE_SIZE=self.cp_kv_cache_interleave_size,
)
每个设备只存储属于自己 rank 的 token slot,实现 KV cache 沿序列维度的分片。
3. 针对长序列输入优化 PCP
3.1 序列切分:PCPManager.update_tokens_for_pcp
这是 PCP 的核心预处理步骤,实现 DualChunkSwap(首尾风格) 切分:
原始序列长度 L → 填充到 2*pcp_size 的倍数
→ 分成 2*pcp_size 个 chunk
→ rank i 拿 chunk[i](head)和 chunk[2*pcp_size-1-i](tail)
从而保证各设备计算负载均衡。
3.2 PCP — Prefill 阶段
Step 1:AllGather KV
每个 rank 只有本地序列片段的 KV,通过 allgather 聚合完整 KV,并用 pcp_allgather_restore_idx 恢复原始顺序:
# vllm_ascend/attention/context_parallel/attention_cp.py
kv = torch.cat([key, value], dim=-1)
all_kv = get_pcp_group().all_gather(kv[:num_actual_tokens_pcp_padded].contiguous(), dim=0)
all_kv = torch.index_select(all_kv, 0, pcp_allgather_restore_idx)
key, value = all_kv.split([self.head_size, self.head_size], dim=-1)
Step 2:Head/Tail 分组计算 Attention
每个 rank 的 Q 被分成 head 和 tail 两部分,分别与对应的 KV 做 attention(有 mask 和无 mask 两段):
# attention_cp.py
# _forward_prefill_cp_pre: 按预计算的索引切分 Q 和 KV
q_head = torch.index_select(query, 0, q_head_idx)
q_tail = torch.index_select(query, 0, q_tail_idx)
k_head_nomask = torch.index_select(key, 0, kv_with_q_head_nomask_idx) # 无 causal mask 部分
k_head_mask = torch.index_select(key, 0, kv_with_q_head_mask_idx) # 有 causal mask 部分
对 head 和 tail 分别调用 _attention_with_nomask_and_mask,用 npu_fused_infer_attention_score 分别计算无 mask 和有 mask 的 attention,再用 online softmax 合并(_npu_attn_out_lse_update):
# attention_cp.py
def _attention_with_nomask_and_mask(
self,
q: torch.Tensor,
q_seqlens: list[int],
k_nomask: torch.Tensor,
v_nomask: torch.Tensor,
kv_seqlens_nomask: list[int],
k_mask: torch.Tensor,
v_mask: torch.Tensor,
kv_seqlens_mask: list[int],
mask: torch.Tensor,
attn_metadata,
) -> torch.Tensor:
# nomask Attention
if k_nomask is not None:
attn_out_nomask, attn_lse_nomask = torch.ops.npu.npu_fused_infer_attention_score(
q,
k_nomask,
v_nomask,
...
)
# mask Attention
attn_out_mask, attn_lse_mask = torch.ops.npu.npu_fused_infer_attention_score(
q,
k_mask,
v_mask,
...
)
# update
output = attn_out_mask
attn_lse = attn_lse_mask
if k_nomask is not None:
if attn_metadata.prefill is not None and attn_metadata.prefill.chunked_context is None:
output = _npu_attn_out_lse_update(attn_lse_mask, attn_lse_nomask, attn_out_mask, attn_out_nomask)
attn_lse = None
else:
output, attn_lse = _update_out_and_lse(
torch.stack([attn_out_nomask, attn_out_mask], dim=0),
torch.stack([attn_lse_nomask, attn_lse_mask], dim=0),
)
return output, attn_lse
Step 3:合并 head/tail 输出
# attention_cp.py
def _forward_prefill_cp_post(self, outputs, lses, attn_metadata):
q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx
output = torch.index_select(torch.cat(outputs, dim=0), 0, q_full_idx)
attn_lse = None
if attn_metadata.prefill is not None and attn_metadata.prefill.chunked_context is not None:
attn_lse = torch.index_select(torch.cat(lses, dim=0), 0, q_full_idx)
return output, attn_lse
3.3 PCP — Decode 阶段
在 DCP 的 all-to-all 通信交换 output 和 LSE 之后,在 PCP 组内额外增加一次 allgather,再进行输出更新:
# common_cp.py
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1)
if dcp_size > 1:
# permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
attn_out_lse_all2all = torch.empty_like(attn_out_lse)
dist.all_to_all_single(attn_out_lse_all2all, attn_out_lse, group=dcp_group)
attn_out_lse = attn_out_lse_all2all.permute([2, 0, 1])
if pcp_size > 1:
# AllGather out&lse within CP group
attn_out_lse = get_pcp_group().all_gather(attn_out_lse.contiguous(), dim=0)
return attn_out_lse
然后调用 _npu_attention_update 做 online softmax 归约,得到最终输出:
# common_cp.py
def _npu_attention_update(head_size, attn_out_lse: torch.Tensor) -> torch.Tensor:
pcp_size = get_pcp_group().world_size
dcp_size = get_decode_context_model_parallel_world_size()
# [PCP * S, DCP * H, D+1]
B_total, H_total, D_plus_1 = attn_out_lse.shape
S = B_total // pcp_size
H = H_total // dcp_size
D = head_size
assert D_plus_1 == D + 1
# [PCP, S, DCP, H, D+1]
x = attn_out_lse.view(pcp_size, S, dcp_size, H, D_plus_1)
# [PCP, DCP, S, H, D+1]
x = x.permute(0, 2, 1, 3, 4).contiguous()
# Flatten [N, S, H, D+1], N = pcp_size * dcp_size
x = x.view(-1, S, H, D_plus_1)
# Split out lse
out_flat, lse_flat = torch.split(x, [D, 1], dim=-1) # [N, S, H, D], [N, S, H, 1]
# out: [N, S, H, D] -> [N, S*H, D]
# lse: [N, S, H, 1] -> [N, S*H]
out_flat = out_flat.flatten(1, 2) # [N, S*H, D]
lse_flat = lse_flat.flatten(1, -1) # [N, S*H]
# unbind to list
out_list = out_flat.unbind(0) # [S*H, D]
lse_list = lse_flat.unbind(0) # [S*H]
attn_out, _ = torch_npu.npu_attention_update(lse_list, out_list, 0)
attn_out = attn_out.view(-1, H, D)
return attn_out
3.4 PCP + Chunkprefill (GQA 后端)
对于 chunked prefill,GQA 后端采用 AllGatherQ 方案:先在 PCP 组内 allgather Q,再恢复顺序,然后走与 decode 相同的逻辑:
# attention_cp.py
def _prefill_query_all_gather(self, attn_metadata, prefill_query):
if self.pcp_size > 1:
prefill_query = get_pcp_group().all_gather(prefill_query, 0)
prefill_query = torch.index_select(
prefill_query, 0, attn_metadata.prefill.chunked_context.cp_kv_recover_idx_for_chunk
)
if self.dcp_size > 1:
prefill_query = get_dcp_group().all_gather(prefill_query, 1)
return prefill_query
4. 针对长序列输出优化 DCP
4.1 DCP — Prefill 阶段
当启用 chunk-prefill 时,两个后端采用了完全不同的方案:
| GQA | MLA | |
|---|---|---|
| 方案 | AllGatherQ | AllGatherKV |
| 通信对象 | AllGather Q(head 维度,DCP 组) | AllGather 压缩 KV cache(DCP 组) |
| KV 来源 | npu_paged_cache_load 从本地 cache 加载 |
_reorg_kvcache 重组 allgather 后的 KV |
4.1.1 GQA:AllGatherQ
_prefill_query_all_gather 先在 PCP 组内 allgather Q(若 PCP 启用),再在 DCP 组内沿 head 维度 allgather Q:
# attention_cp.py: _prefill_query_all_gather
if self.pcp_size > 1:
prefill_query = get_pcp_group().all_gather(prefill_query, 0)
prefill_query = torch.index_select(prefill_query, 0, cp_kv_recover_idx_for_chunk)
if self.dcp_size > 1:
prefill_query = get_dcp_group().all_gather(prefill_query, 1) # head 维度
return prefill_query
然后 _compute_prefill_context 用 AllGathered Q 与本地 KV cache(npu_paged_cache_load 加载)做 attention,actual_seq_lengths_kv 取本 rank 存储的 KV 长度:
local_chunked_kv_lens_rank = local_chunked_kv_lens[:, self.pcp_rank, self.dcp_rank]
key, value = self._load_kv_for_chunk(attn_metadata, kv_cache, local_chunked_kv_lens_rank, ...)
context attention 结果通过 _gather_global_context_output(all-to-all + allgather)聚合,再与当前 chunk 的 attention 结果做 online softmax 合并。
4.1.2 MLA:AllGatherKV
_reorg_kvcache 先在 DCP 组内 allgather 压缩 KV cache(kv_c_normed + k_pe),再在 PCP 组内 allgather,然后去掉 padding、重组为连续布局:
# mla_cp.py:
_reorg_kvcache cache_kv_c_k_pe = torch.cat([kv_c_normed, k_pe], dim=-1)
if self.dcp_size > 1:
cache_kv_c_k_pe = get_dcp_group().all_gather(cache_kv_c_k_pe, 0) # token 维度
if self.pcp_size > 1:
cache_kv_c_k_pe = get_pcp_group().all_gather(cache_kv_c_k_pe, 0) # 按请求逐段重组,去掉 padding
reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
重组后的完整 KV 直接与当前 chunk 的 Q 做 attention,无需额外的 output 聚合通信。
4.2 DCP — Decode 阶段
解码阶段的逻辑与 GQA 的分块预填充一致:首先沿 Q 头维度执行 all-gather 操作,以确保 DCP 组内的一致性。使用本地 KV 缓存计算结果后,通过 cp_lse_ag_out_rs 函数更新结果。
5. 总结
5.1 PCP 与 DCP 特性对比
| 功能维度 | PCP (Prefill Context Parallel) | DCP (Decode Context Parallel) |
|---|---|---|
| 目标 | 加速 prefill,降低 TTFT | 消除 KV cache 冗余,提升 decode 吞吐 |
| 通信域 | 独立通信域,扩展 world size | 复用 TP 通信域,不增加设备数 |
| 核心参数 | prefill_context_parallel_size |
decode_context_parallel_size |
| KV Cache 分片 | 沿序列维度分片存储 | 沿序列维度分片存储 |
| 主要影响阶段 | Prefill + Decode | Decode + Chunked Prefill |
5.2 代码示例
# 在线示例
vllm serve deepseek-ai/DeepSeek-V2-Lite \
--tensor-parallel-size 2 \
--decode-context-parallel-size 2 \
--prefill-context-parallel-size 2 \
# 若与 KV 池化,PD 分离等 KV 传输功能共用
# 为简化 KV 缓存传输,必须将 cp_kv_cache_interleave_size 设置与 KV 缓存 block_size(默认值:128)相同
# 这指定了上下文并行以块交错方式切分 KV 缓存。
vllm serve deepseek-ai/DeepSeek-V2-Lite \
--tensor-parallel-size 2 \
--decode-context-parallel-size 2 \
--prefill-context-parallel-size 2 \
--cp-kv-cache-interleave-size 128 \
--kv-transfer-config {...} \
5.3 关键约束
| 模型类型 | DCP 约束条件 |
|---|---|
| MLA (DeepSeek-R1) | tensor_parallel_size % decode_context_parallel_size == 0 & tensor_parallel_size >= decode_context_parallel_size 1 |
| GQA (Qwen3-235B) | (tensor_parallel_size // num_key_value_heads) % decode_context_parallel_size == 0 & (tensor_parallel_size // num_key_value_heads) >= decode_context_parallel_size |
5.4 相关代码文件
- slot_mapping 计算:
vllm_ascend/worker/block_table.py - 序列切分与元数据准备:
vllm_ascend/worker/model_runner_v1.py - GQA 后端:
vllm_ascend/attention/attention_cp.py - MLA 后端:
vllm_ascend/attention/mla_cp.py
6. 参考
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)