作者:昇腾实战派

背景概述

在处理长序列的大语言模型推理时,内存和计算资源的瓶颈问题尤为突出。为了有效应对这一挑战,本文将深入探讨 vllm-ascend 项目中实现的显存优化方案——Context Parallel (CP),特别是其子策略 Prefill Context Parallel (PCP) 和 Decode Context Parallel (DCP) 的设计思路、实现细节及适用场景。通过本文,读者将能够更好地理解如何在实际业务中应用这些策略,以实现性能的显著提升。

名词表

缩写 全称 含义
TP Tensor Parallelism 张量并行
CP Context Parallelism 上下文并行
AR / RS / AG AllReduce / ReduceScatter / AllGather 集合通信原语
MLA Multi-head Latent Attention 深度学习模型中采用的注意力变体
GQA Grouped Query Attention 注意力机制的一种变体
MoE Mixture of Experts 混合专家模型

1. 什么是 Context Parallel (CP)

Context Parallel (CP) 是一种针对长序列推理的并行优化策略,其核心思想是沿着序列维度将长序列的计算任务分散到多个计算设备上。CP 可以分为两个独立的子策略:

  • PCP (Prefill Context Parallel, 预填充上下文并行)

    • 目标:加速 Prefill 阶段,降低首次输出时间 (TTFT)。
    • 原理:将输入序列切分,不同设备同时计算序列的不同片段。
    • 代价:需要引入新的通信域,总设备数 = tensor_parallel_size × prefill_context_parallel_size
    • 实现:KV cache 沿序列维度分片存储到各设备。
  • DCP (Decode Context Parallel, 解码上下文并行)

    • 目标:消除 KV cache 的冗余副本,节省显存,提升 Decode 吞吐量。
    • 原理:复用 TP 的通信域,不引入额外计算设备,将原本在 TP 组内重复存储的 KV cache 沿序列维度分片。

1.1 PCP 与 DCP 的整体关系

PCP 和 DCP 可以同时启用,也可以单独使用。它们的关系如下:

cp_size = pcp_size × dcp_size  
cp_rank = pcp_rank × dcp_size + dcp_rank

2. CP 实现中的 KV Cache 分片:Block Table

CP 对 KV cache 进行序列维度分片,采用 交错(interleaved)方式 存储 token,交错粒度由参数 cp_kv_cache_interleave_size 控制(默认值为 1,即 token 级交错)。

virtual_block_size = block_size × cp_size

BlockTable.compute_slot_mapping 调用 Triton kernel,传入 TOTAL_CP_WORLD_SIZE = pcp_size * dcp_sizeTOTAL_CP_RANK = pcp_rank * dcp_size + dcp_rank,将 token 的物理存储位置按交错方式分配到各设备:

# vllm_ascend/worker/block_table.py  
total_cp_world_size = self.pcp_world_size * self.dcp_world_size  
total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
_compute_slot_mapping_kernel[(num_reqs + 1,)](
      ...
      TOTAL_CP_WORLD_SIZE=total_cp_world_size,
      TOTAL_CP_RANK=total_cp_rank,
      CP_KV_CACHE_INTERLEAVE_SIZE=self.cp_kv_cache_interleave_size,
  )

每个设备只存储属于自己 rank 的 token slot,实现 KV cache 沿序列维度的分片。

3. 针对长序列输入优化 PCP

3.1 序列切分:PCPManager.update_tokens_for_pcp

这是 PCP 的核心预处理步骤,实现 DualChunkSwap(首尾风格) 切分:

原始序列长度 L → 填充到 2*pcp_size 的倍数  
→ 分成 2*pcp_size 个 chunk  
→ rank i 拿 chunk[i](head)和 chunk[2*pcp_size-1-i](tail)

从而保证各设备计算负载均衡。

3.2 PCP — Prefill 阶段

Step 1:AllGather KV

每个 rank 只有本地序列片段的 KV,通过 allgather 聚合完整 KV,并用 pcp_allgather_restore_idx 恢复原始顺序:

# vllm_ascend/attention/context_parallel/attention_cp.py
  kv = torch.cat([key, value], dim=-1)
  all_kv = get_pcp_group().all_gather(kv[:num_actual_tokens_pcp_padded].contiguous(), dim=0)
  all_kv = torch.index_select(all_kv, 0, pcp_allgather_restore_idx)
  key, value = all_kv.split([self.head_size, self.head_size], dim=-1)

Step 2:Head/Tail 分组计算 Attention

每个 rank 的 Q 被分成 head 和 tail 两部分,分别与对应的 KV 做 attention(有 mask 和无 mask 两段):

# attention_cp.py
# _forward_prefill_cp_pre: 按预计算的索引切分 Q 和 KV
  q_head = torch.index_select(query, 0, q_head_idx)
  q_tail = torch.index_select(query, 0, q_tail_idx)
  k_head_nomask = torch.index_select(key, 0, kv_with_q_head_nomask_idx)  # 无 causal mask 部分
  k_head_mask   = torch.index_select(key, 0, kv_with_q_head_mask_idx)    # 有 causal mask 部分

对 head 和 tail 分别调用 _attention_with_nomask_and_mask,用 npu_fused_infer_attention_score 分别计算无 mask 和有 mask 的 attention,再用 online softmax 合并(_npu_attn_out_lse_update):

# attention_cp.py
def _attention_with_nomask_and_mask(
        self,
        q: torch.Tensor,
        q_seqlens: list[int],
        k_nomask: torch.Tensor,
        v_nomask: torch.Tensor,
        kv_seqlens_nomask: list[int],
        k_mask: torch.Tensor,
        v_mask: torch.Tensor,
        kv_seqlens_mask: list[int],
        mask: torch.Tensor,
        attn_metadata,
    ) -> torch.Tensor:
        # nomask Attention
        if k_nomask is not None:
            attn_out_nomask, attn_lse_nomask = torch.ops.npu.npu_fused_infer_attention_score(
                q,
                k_nomask,
                v_nomask,
                ...
            )

        # mask Attention
        attn_out_mask, attn_lse_mask = torch.ops.npu.npu_fused_infer_attention_score(
            q,
            k_mask,
            v_mask,
            ...
        )
        # update
        output = attn_out_mask
        attn_lse = attn_lse_mask
        if k_nomask is not None:
            if attn_metadata.prefill is not None and attn_metadata.prefill.chunked_context is None:
                output = _npu_attn_out_lse_update(attn_lse_mask, attn_lse_nomask, attn_out_mask, attn_out_nomask)
                attn_lse = None
            else:
                output, attn_lse = _update_out_and_lse(
                    torch.stack([attn_out_nomask, attn_out_mask], dim=0),
                    torch.stack([attn_lse_nomask, attn_lse_mask], dim=0),
                )

        return output, attn_lse

Step 3:合并 head/tail 输出

# attention_cp.py
def _forward_prefill_cp_post(self, outputs, lses, attn_metadata):
       q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx
       output = torch.index_select(torch.cat(outputs, dim=0), 0, q_full_idx)
       attn_lse = None
       if attn_metadata.prefill is not None and attn_metadata.prefill.chunked_context is not None:
           attn_lse = torch.index_select(torch.cat(lses, dim=0), 0, q_full_idx)
       return output, attn_lse
3.3 PCP — Decode 阶段

在 DCP 的 all-to-all 通信交换 output 和 LSE 之后,在 PCP 组内额外增加一次 allgather,再进行输出更新:

# common_cp.py
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
    attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1)
    if dcp_size > 1:
        # permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
        attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
        attn_out_lse_all2all = torch.empty_like(attn_out_lse)
        dist.all_to_all_single(attn_out_lse_all2all, attn_out_lse, group=dcp_group)
        attn_out_lse = attn_out_lse_all2all.permute([2, 0, 1])

    if pcp_size > 1:
        # AllGather out&lse within CP group
        attn_out_lse = get_pcp_group().all_gather(attn_out_lse.contiguous(), dim=0)

    return attn_out_lse

然后调用 _npu_attention_update 做 online softmax 归约,得到最终输出:

# common_cp.py
def _npu_attention_update(head_size, attn_out_lse: torch.Tensor) -> torch.Tensor:
    pcp_size = get_pcp_group().world_size
    dcp_size = get_decode_context_model_parallel_world_size()
    # [PCP * S, DCP * H, D+1]
    B_total, H_total, D_plus_1 = attn_out_lse.shape
    S = B_total // pcp_size
    H = H_total // dcp_size
    D = head_size
    assert D_plus_1 == D + 1
    # [PCP, S, DCP, H, D+1]
    x = attn_out_lse.view(pcp_size, S, dcp_size, H, D_plus_1)
    # [PCP, DCP, S, H, D+1]
    x = x.permute(0, 2, 1, 3, 4).contiguous()
    # Flatten [N, S, H, D+1], N = pcp_size * dcp_size
    x = x.view(-1, S, H, D_plus_1)
    # Split out lse
    out_flat, lse_flat = torch.split(x, [D, 1], dim=-1)  # [N, S, H, D], [N, S, H, 1]
    #    out: [N, S, H, D] -> [N, S*H, D]
    #    lse: [N, S, H, 1] -> [N, S*H]
    out_flat = out_flat.flatten(1, 2)  # [N, S*H, D]
    lse_flat = lse_flat.flatten(1, -1)  # [N, S*H]
    #  unbind to list
    out_list = out_flat.unbind(0)  # [S*H, D]
    lse_list = lse_flat.unbind(0)  # [S*H]
    attn_out, _ = torch_npu.npu_attention_update(lse_list, out_list, 0)
    attn_out = attn_out.view(-1, H, D)
    return attn_out
3.4 PCP + Chunkprefill (GQA 后端)

对于 chunked prefill,GQA 后端采用 AllGatherQ 方案:先在 PCP 组内 allgather Q,再恢复顺序,然后走与 decode 相同的逻辑:

# attention_cp.py
def _prefill_query_all_gather(self, attn_metadata, prefill_query):
      if self.pcp_size > 1:
          prefill_query = get_pcp_group().all_gather(prefill_query, 0)
          prefill_query = torch.index_select(
              prefill_query, 0, attn_metadata.prefill.chunked_context.cp_kv_recover_idx_for_chunk
          )
      if self.dcp_size > 1:
          prefill_query = get_dcp_group().all_gather(prefill_query, 1)
      return prefill_query

4. 针对长序列输出优化 DCP

4.1 DCP — Prefill 阶段

当启用 chunk-prefill 时,两个后端采用了完全不同的方案:

GQA MLA
方案 AllGatherQ AllGatherKV
通信对象 AllGather Q(head 维度,DCP 组) AllGather 压缩 KV cache(DCP 组)
KV 来源 npu_paged_cache_load 从本地 cache 加载 _reorg_kvcache 重组 allgather 后的 KV
4.1.1 GQA:AllGatherQ

_prefill_query_all_gather 先在 PCP 组内 allgather Q(若 PCP 启用),再在 DCP 组内沿 head 维度 allgather Q:

# attention_cp.py: _prefill_query_all_gather
  if self.pcp_size > 1:
      prefill_query = get_pcp_group().all_gather(prefill_query, 0)
      prefill_query = torch.index_select(prefill_query, 0, cp_kv_recover_idx_for_chunk)
  if self.dcp_size > 1:
      prefill_query = get_dcp_group().all_gather(prefill_query, 1)  # head 维度
  return prefill_query

然后 _compute_prefill_context 用 AllGathered Q 与本地 KV cache(npu_paged_cache_load 加载)做 attention,actual_seq_lengths_kv 取本 rank 存储的 KV 长度:

local_chunked_kv_lens_rank = local_chunked_kv_lens[:, self.pcp_rank, self.dcp_rank]
key, value = self._load_kv_for_chunk(attn_metadata, kv_cache, local_chunked_kv_lens_rank, ...)

context attention 结果通过 _gather_global_context_output(all-to-all + allgather)聚合,再与当前 chunk 的 attention 结果做 online softmax 合并。

4.1.2 MLA:AllGatherKV

_reorg_kvcache 先在 DCP 组内 allgather 压缩 KV cache(kv_c_normed + k_pe),再在 PCP 组内 allgather,然后去掉 padding、重组为连续布局:

# mla_cp.py:
 _reorg_kvcache  cache_kv_c_k_pe = torch.cat([kv_c_normed, k_pe], dim=-1)
  if self.dcp_size > 1:
      cache_kv_c_k_pe = get_dcp_group().all_gather(cache_kv_c_k_pe, 0)  # token 维度
  if self.pcp_size > 1:
      cache_kv_c_k_pe = get_pcp_group().all_gather(cache_kv_c_k_pe, 0)  # 按请求逐段重组,去掉 padding
  reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)

重组后的完整 KV 直接与当前 chunk 的 Q 做 attention,无需额外的 output 聚合通信。

4.2 DCP — Decode 阶段

解码阶段的逻辑与 GQA 的分块预填充一致:首先沿 Q 头维度执行 all-gather 操作,以确保 DCP 组内的一致性。使用本地 KV 缓存计算结果后,通过 cp_lse_ag_out_rs 函数更新结果。

5. 总结

5.1 PCP 与 DCP 特性对比
功能维度 PCP (Prefill Context Parallel) DCP (Decode Context Parallel)
目标 加速 prefill,降低 TTFT 消除 KV cache 冗余,提升 decode 吞吐
通信域 独立通信域,扩展 world size 复用 TP 通信域,不增加设备数
核心参数 prefill_context_parallel_size decode_context_parallel_size
KV Cache 分片 沿序列维度分片存储 沿序列维度分片存储
主要影响阶段 Prefill + Decode Decode + Chunked Prefill
5.2 代码示例
# 在线示例
vllm serve deepseek-ai/DeepSeek-V2-Lite \
    --tensor-parallel-size 2 \
    --decode-context-parallel-size 2 \
    --prefill-context-parallel-size 2 \

# 若与 KV 池化,PD 分离等 KV 传输功能共用
# 为简化 KV 缓存传输,必须将 cp_kv_cache_interleave_size 设置与 KV 缓存 block_size(默认值:128)相同
# 这指定了上下文并行以块交错方式切分 KV 缓存。
vllm serve deepseek-ai/DeepSeek-V2-Lite \
    --tensor-parallel-size 2 \
    --decode-context-parallel-size 2 \
    --prefill-context-parallel-size 2 \
    --cp-kv-cache-interleave-size 128 \
    --kv-transfer-config {...} \
5.3 关键约束
模型类型 DCP 约束条件
MLA (DeepSeek-R1) tensor_parallel_size % decode_context_parallel_size == 0 & tensor_parallel_size >= decode_context_parallel_size 1
GQA (Qwen3-235B) (tensor_parallel_size // num_key_value_heads) % decode_context_parallel_size == 0 & (tensor_parallel_size // num_key_value_heads) >= decode_context_parallel_size
5.4 相关代码文件
  • slot_mapping 计算:vllm_ascend/worker/block_table.py
  • 序列切分与元数据准备:vllm_ascend/worker/model_runner_v1.py
  • GQA 后端:vllm_ascend/attention/attention_cp.py
  • MLA 后端:vllm_ascend/attention/mla_cp.py

6. 参考

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐