大模型算法应用:RAG 中的混合检索策略与 Cross-Encoder 重排序

cover

一、RAG 的"检索瓶颈":为什么简单向量搜索不够用

检索增强生成(RAG)的核心假设是:只要能找到相关文档,大模型就能生成正确答案。但实际生产中,"找到相关文档"本身就是最大的挑战。纯向量检索(Dense Retrieval)基于语义相似度匹配,擅长处理同义词和语义关联,但对精确关键词匹配表现不佳——用户搜索"Rust 所有权",向量检索可能返回"Rust 责任体系"的文档,因为"所有权"和"责任"在语义空间中距离很近。

纯关键词检索(Sparse Retrieval,如 BM25)则恰好相反:精确匹配关键词,但无法理解语义——"内存安全"和"Memory Safety"在 BM25 看来毫无关联。混合检索(Hybrid Retrieval)将两种方法的结果融合,取长补短。但融合后的候选集仍然可能包含大量"看似相关实则无关"的文档,此时需要 Cross-Encoder 重排序进行精排。理解从检索到重排序的完整链路,是构建高质量 RAG 系统的关键。

二、混合检索与重排序的架构原理

2.1 Dense 与 Sparse 检索的互补性

Dense Retrieval 使用双编码器(Bi-Encoder)架构:查询和文档分别通过独立的编码器生成向量,再计算余弦相似度。优势是检索速度快(向量索引支持 ANN 查询),劣势是查询和文档之间没有交互计算,精度有限。

Sparse Retrieval 使用 BM25 算法:基于词频-逆文档频率(TF-IDF)的改进,考虑词频饱和度和文档长度归一化。优势是精确匹配能力强,劣势是无法处理同义词和语义关联。

flowchart TD
    A[用户查询] --> B[Dense Retrieval<br/>双编码器向量搜索]
    A --> C[Sparse Retrieval<br/>BM25 关键词匹配]

    B --> D[Dense 候选集<br/>Top-K1]
    C --> E[Sparse 候选集<br/>Top-K2]

    D --> F[结果融合<br/>Reciprocal Rank Fusion]
    E --> F

    F --> G[融合候选集<br/>Top-N]
    G --> H[Cross-Encoder 重排序<br/>查询-文档交互编码]
    H --> I[最终结果<br/>Top-M]

    style B fill:#e1f5fe
    style C fill:#fff3e0
    style H fill:#ffebee
    style I fill:#e8f5e9

2.2 Reciprocal Rank Fusion(RRF)融合算法

RRF 是最常用的多路检索融合算法,核心公式:

RRF_score(d) = Σ 1 / (k + rank_i(d))

其中 k 是平滑常数(通常取 60),rank_i(d) 是文档 d 在第 i 路检索中的排名。RRF 的优势是:不需要归一化不同检索路的分数,直接基于排名融合,对分数分布差异不敏感。

2.3 Cross-Encoder 重排序的精度优势

Cross-Encoder 将查询和文档拼接后输入同一个 Transformer,进行全注意力交互计算。相比 Bi-Encoder 的独立编码,Cross-Encoder 能捕捉查询和文档之间的细粒度交互特征,精度显著提升。

代价是计算成本:Cross-Encoder 需要对每个候选文档单独推理,无法预计算文档向量。因此只能对少量候选(通常 50-100 条)进行重排序,而非全量文档。

三、生产级代码实现:混合检索与重排序管线

3.1 混合检索器

from dataclasses import dataclass
from typing import List, Tuple
import numpy as np

@dataclass
class RetrievalResult:
    """检索结果条目"""
    doc_id: str
    content: str
    dense_score: float
    sparse_score: float
    rrf_score: float
    cross_encoder_score: float = 0.0

class HybridRetriever:
    """混合检索器:Dense + Sparse + RRF 融合"""

    def __init__(
        self,
        dense_index,       # 向量索引(如 FAISS)
        sparse_index,      # BM25 索引
        rrf_k: int = 60,   # RRF 平滑常数
        dense_top_k: int = 50,
        sparse_top_k: int = 50,
    ):
        self.dense_index = dense_index
        self.sparse_index = sparse_index
        self.rrf_k = rrf_k
        self.dense_top_k = dense_top_k
        self.sparse_top_k = sparse_top_k

    def retrieve(self, query: str, query_embedding: np.ndarray) -> List[RetrievalResult]:
        """执行混合检索并融合结果"""

        # 1. Dense 检索
        dense_results = self.dense_index.search(
            query_embedding, top_k=self.dense_top_k
        )

        # 2. Sparse 检索
        sparse_results = self.sparse_index.search(
            query, top_k=self.sparse_top_k
        )

        # 3. RRF 融合
        return self._rrf_fuse(dense_results, sparse_results)

    def _rrf_fuse(
        self,
        dense_results: List[Tuple[str, float]],
        sparse_results: List[Tuple[str, float]],
    ) -> List[RetrievalResult]:
        """Reciprocal Rank Fusion 融合"""
        rrf_scores = {}

        # Dense 路贡献
        for rank, (doc_id, score) in enumerate(dense_results):
            if doc_id not in rrf_scores:
                rrf_scores[doc_id] = {
                    "dense_score": score,
                    "sparse_score": 0.0,
                    "rrf_score": 0.0,
                }
            rrf_scores[doc_id]["dense_score"] = score
            rrf_scores[doc_id]["rrf_score"] += 1.0 / (self.rrf_k + rank + 1)

        # Sparse 路贡献
        for rank, (doc_id, score) in enumerate(sparse_results):
            if doc_id not in rrf_scores:
                rrf_scores[doc_id] = {
                    "dense_score": 0.0,
                    "sparse_score": score,
                    "rrf_score": 0.0,
                }
            rrf_scores[doc_id]["sparse_score"] = score
            rrf_scores[doc_id]["rrf_score"] += 1.0 / (self.rrf_k + rank + 1)

        # 按 RRF 分数排序
        sorted_results = sorted(
            rrf_scores.items(),
            key=lambda x: x[1]["rrf_score"],
            reverse=True,
        )

        return [
            RetrievalResult(
                doc_id=doc_id,
                content="",  # 后续填充
                dense_score=scores["dense_score"],
                sparse_score=scores["sparse_score"],
                rrf_score=scores["rrf_score"],
            )
            for doc_id, scores in sorted_results
        ]

3.2 Cross-Encoder 重排序器

from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

class CrossEncoderReranker:
    """Cross-Encoder 重排序器:对候选文档精排"""

    def __init__(
        self,
        model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
        max_length: int = 512,
        batch_size: int = 32,
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.max_length = max_length
        self.batch_size = batch_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        self.model.eval()

    def rerank(
        self,
        query: str,
        candidates: List[RetrievalResult],
        top_k: int = 10,
    ) -> List[RetrievalResult]:
        """对候选文档进行 Cross-Encoder 重排序"""

        if not candidates:
            return []

        # 构造查询-文档对
        pairs = [(query, cand.content) for cand in candidates]

        # 批量推理
        all_scores = []
        for i in range(0, len(pairs), self.batch_size):
            batch_pairs = pairs[i : i + self.batch_size]

            features = self.tokenizer(
                batch_pairs,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            ).to(self.device)

            with torch.no_grad():
                scores = self.model(**features).logits.squeeze(-1)
                all_scores.extend(scores.cpu().tolist())

        # 更新 Cross-Encoder 分数并重排序
        for cand, score in zip(candidates, all_scores):
            cand.cross_encoder_score = score

        reranked = sorted(
            candidates,
            key=lambda x: x.cross_encoder_score,
            reverse=True,
        )

        return reranked[:top_k]

3.3 完整 RAG 管线

class RAGPipeline:
    """完整的 RAG 管线:检索 → 融合 → 重排序 → 生成"""

    def __init__(
        self,
        retriever: HybridRetriever,
        reranker: CrossEncoderReranker,
        llm_client,           # 大模型调用客户端
        embed_model,          # 向量编码模型
        rerank_top_k: int = 5,
    ):
        self.retriever = retriever
        self.reranker = reranker
        self.llm_client = llm_client
        self.embed_model = embed_model
        self.rerank_top_k = rerank_top_k

    def query(self, question: str) -> str:
        """执行完整 RAG 查询"""

        # 1. 编码查询
        query_embedding = self.embed_model.encode(question)

        # 2. 混合检索
        candidates = self.retriever.retrieve(question, query_embedding)

        # 填充文档内容(从存储中获取)
        candidates = self._fill_content(candidates)

        # 3. Cross-Encoder 重排序
        reranked = self.reranker.rerank(
            question, candidates, top_k=self.rerank_top_k
        )

        # 4. 构造 Prompt 并生成
        context = self._build_context(reranked)
        prompt = self._build_prompt(question, context)

        return self.llm_client.generate(prompt)

    def _fill_content(self, candidates: List[RetrievalResult]) -> List[RetrievalResult]:
        """从文档存储中填充候选文档的完整内容"""
        # 实际实现会从数据库或文件系统加载
        return candidates

    def _build_context(self, results: List[RetrievalResult]) -> str:
        """将重排序结果构造为上下文文本"""
        context_parts = []
        for i, result in enumerate(results, 1):
            context_parts.append(
                f"[文档{i}] (相关度: {result.cross_encoder_score:.3f})\n{result.content}"
            )
        return "\n\n".join(context_parts)

    def _build_prompt(self, question: str, context: str) -> str:
        """构造 RAG Prompt"""
        return (
            f"请根据以下参考文档回答问题。如果文档中没有相关信息,请说明。\n\n"
            f"参考文档:\n{context}\n\n"
            f"问题:{question}\n\n"
            f"回答:"
        )

四、混合检索与重排序的架构权衡

4.1 延迟与精度的权衡

混合检索 + Cross-Encoder 重排序的总延迟 = Dense 检索延迟 + Sparse 检索延迟 + Cross-Encoder 推理延迟。在典型配置下(50 条 Dense + 50 条 Sparse → 100 条重排序 → Top-5),总延迟约为 200-500ms。如果要求 < 100ms 的响应时间,需要减少重排序候选数或使用更轻量的 Cross-Encoder 模型。

4.2 索引维护成本

Dense 索引(FAISS)和 Sparse 索引(Elasticsearch)需要独立维护和更新。文档新增时,两个索引必须同步更新,否则会出现检索结果不一致。在文档频繁更新的场景中,索引同步是运维的持续负担。

4.3 Cross-Encoder 的领域适配

预训练的 Cross-Encoder 模型(如 ms-marco 系列)在通用领域表现良好,但在垂直领域(医疗、法律、金融)精度显著下降。领域适配需要在领域数据上微调 Cross-Encoder,但标注"查询-文档"相关性数据的成本极高。

五、总结

RAG 系统的检索质量决定了生成质量的上限。三个关键设计决策:第一,使用 Dense + Sparse 混合检索,通过 RRF 融合算法兼顾语义匹配和精确匹配;第二,在融合结果上应用 Cross-Encoder 重排序,利用交互编码提升精排质量;第三,控制重排序候选数量在 50-100 条,在精度和延迟之间取得平衡。检索不是"搜一下"的简单操作,而是需要系统性设计的核心模块。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐