大模型算法应用:RAG 中的混合检索策略与 Cross-Encoder 重排序
大模型算法应用:RAG 中的混合检索策略与 Cross-Encoder 重排序

一、RAG 的"检索瓶颈":为什么简单向量搜索不够用
检索增强生成(RAG)的核心假设是:只要能找到相关文档,大模型就能生成正确答案。但实际生产中,"找到相关文档"本身就是最大的挑战。纯向量检索(Dense Retrieval)基于语义相似度匹配,擅长处理同义词和语义关联,但对精确关键词匹配表现不佳——用户搜索"Rust 所有权",向量检索可能返回"Rust 责任体系"的文档,因为"所有权"和"责任"在语义空间中距离很近。
纯关键词检索(Sparse Retrieval,如 BM25)则恰好相反:精确匹配关键词,但无法理解语义——"内存安全"和"Memory Safety"在 BM25 看来毫无关联。混合检索(Hybrid Retrieval)将两种方法的结果融合,取长补短。但融合后的候选集仍然可能包含大量"看似相关实则无关"的文档,此时需要 Cross-Encoder 重排序进行精排。理解从检索到重排序的完整链路,是构建高质量 RAG 系统的关键。
二、混合检索与重排序的架构原理
2.1 Dense 与 Sparse 检索的互补性
Dense Retrieval 使用双编码器(Bi-Encoder)架构:查询和文档分别通过独立的编码器生成向量,再计算余弦相似度。优势是检索速度快(向量索引支持 ANN 查询),劣势是查询和文档之间没有交互计算,精度有限。
Sparse Retrieval 使用 BM25 算法:基于词频-逆文档频率(TF-IDF)的改进,考虑词频饱和度和文档长度归一化。优势是精确匹配能力强,劣势是无法处理同义词和语义关联。
flowchart TD
A[用户查询] --> B[Dense Retrieval<br/>双编码器向量搜索]
A --> C[Sparse Retrieval<br/>BM25 关键词匹配]
B --> D[Dense 候选集<br/>Top-K1]
C --> E[Sparse 候选集<br/>Top-K2]
D --> F[结果融合<br/>Reciprocal Rank Fusion]
E --> F
F --> G[融合候选集<br/>Top-N]
G --> H[Cross-Encoder 重排序<br/>查询-文档交互编码]
H --> I[最终结果<br/>Top-M]
style B fill:#e1f5fe
style C fill:#fff3e0
style H fill:#ffebee
style I fill:#e8f5e9
2.2 Reciprocal Rank Fusion(RRF)融合算法
RRF 是最常用的多路检索融合算法,核心公式:
RRF_score(d) = Σ 1 / (k + rank_i(d))
其中 k 是平滑常数(通常取 60),rank_i(d) 是文档 d 在第 i 路检索中的排名。RRF 的优势是:不需要归一化不同检索路的分数,直接基于排名融合,对分数分布差异不敏感。
2.3 Cross-Encoder 重排序的精度优势
Cross-Encoder 将查询和文档拼接后输入同一个 Transformer,进行全注意力交互计算。相比 Bi-Encoder 的独立编码,Cross-Encoder 能捕捉查询和文档之间的细粒度交互特征,精度显著提升。
代价是计算成本:Cross-Encoder 需要对每个候选文档单独推理,无法预计算文档向量。因此只能对少量候选(通常 50-100 条)进行重排序,而非全量文档。
三、生产级代码实现:混合检索与重排序管线
3.1 混合检索器
from dataclasses import dataclass
from typing import List, Tuple
import numpy as np
@dataclass
class RetrievalResult:
"""检索结果条目"""
doc_id: str
content: str
dense_score: float
sparse_score: float
rrf_score: float
cross_encoder_score: float = 0.0
class HybridRetriever:
"""混合检索器:Dense + Sparse + RRF 融合"""
def __init__(
self,
dense_index, # 向量索引(如 FAISS)
sparse_index, # BM25 索引
rrf_k: int = 60, # RRF 平滑常数
dense_top_k: int = 50,
sparse_top_k: int = 50,
):
self.dense_index = dense_index
self.sparse_index = sparse_index
self.rrf_k = rrf_k
self.dense_top_k = dense_top_k
self.sparse_top_k = sparse_top_k
def retrieve(self, query: str, query_embedding: np.ndarray) -> List[RetrievalResult]:
"""执行混合检索并融合结果"""
# 1. Dense 检索
dense_results = self.dense_index.search(
query_embedding, top_k=self.dense_top_k
)
# 2. Sparse 检索
sparse_results = self.sparse_index.search(
query, top_k=self.sparse_top_k
)
# 3. RRF 融合
return self._rrf_fuse(dense_results, sparse_results)
def _rrf_fuse(
self,
dense_results: List[Tuple[str, float]],
sparse_results: List[Tuple[str, float]],
) -> List[RetrievalResult]:
"""Reciprocal Rank Fusion 融合"""
rrf_scores = {}
# Dense 路贡献
for rank, (doc_id, score) in enumerate(dense_results):
if doc_id not in rrf_scores:
rrf_scores[doc_id] = {
"dense_score": score,
"sparse_score": 0.0,
"rrf_score": 0.0,
}
rrf_scores[doc_id]["dense_score"] = score
rrf_scores[doc_id]["rrf_score"] += 1.0 / (self.rrf_k + rank + 1)
# Sparse 路贡献
for rank, (doc_id, score) in enumerate(sparse_results):
if doc_id not in rrf_scores:
rrf_scores[doc_id] = {
"dense_score": 0.0,
"sparse_score": score,
"rrf_score": 0.0,
}
rrf_scores[doc_id]["sparse_score"] = score
rrf_scores[doc_id]["rrf_score"] += 1.0 / (self.rrf_k + rank + 1)
# 按 RRF 分数排序
sorted_results = sorted(
rrf_scores.items(),
key=lambda x: x[1]["rrf_score"],
reverse=True,
)
return [
RetrievalResult(
doc_id=doc_id,
content="", # 后续填充
dense_score=scores["dense_score"],
sparse_score=scores["sparse_score"],
rrf_score=scores["rrf_score"],
)
for doc_id, scores in sorted_results
]
3.2 Cross-Encoder 重排序器
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
class CrossEncoderReranker:
"""Cross-Encoder 重排序器:对候选文档精排"""
def __init__(
self,
model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
max_length: int = 512,
batch_size: int = 32,
):
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
self.max_length = max_length
self.batch_size = batch_size
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model = self.model.to(self.device)
self.model.eval()
def rerank(
self,
query: str,
candidates: List[RetrievalResult],
top_k: int = 10,
) -> List[RetrievalResult]:
"""对候选文档进行 Cross-Encoder 重排序"""
if not candidates:
return []
# 构造查询-文档对
pairs = [(query, cand.content) for cand in candidates]
# 批量推理
all_scores = []
for i in range(0, len(pairs), self.batch_size):
batch_pairs = pairs[i : i + self.batch_size]
features = self.tokenizer(
batch_pairs,
padding=True,
truncation=True,
max_length=self.max_length,
return_tensors="pt",
).to(self.device)
with torch.no_grad():
scores = self.model(**features).logits.squeeze(-1)
all_scores.extend(scores.cpu().tolist())
# 更新 Cross-Encoder 分数并重排序
for cand, score in zip(candidates, all_scores):
cand.cross_encoder_score = score
reranked = sorted(
candidates,
key=lambda x: x.cross_encoder_score,
reverse=True,
)
return reranked[:top_k]
3.3 完整 RAG 管线
class RAGPipeline:
"""完整的 RAG 管线:检索 → 融合 → 重排序 → 生成"""
def __init__(
self,
retriever: HybridRetriever,
reranker: CrossEncoderReranker,
llm_client, # 大模型调用客户端
embed_model, # 向量编码模型
rerank_top_k: int = 5,
):
self.retriever = retriever
self.reranker = reranker
self.llm_client = llm_client
self.embed_model = embed_model
self.rerank_top_k = rerank_top_k
def query(self, question: str) -> str:
"""执行完整 RAG 查询"""
# 1. 编码查询
query_embedding = self.embed_model.encode(question)
# 2. 混合检索
candidates = self.retriever.retrieve(question, query_embedding)
# 填充文档内容(从存储中获取)
candidates = self._fill_content(candidates)
# 3. Cross-Encoder 重排序
reranked = self.reranker.rerank(
question, candidates, top_k=self.rerank_top_k
)
# 4. 构造 Prompt 并生成
context = self._build_context(reranked)
prompt = self._build_prompt(question, context)
return self.llm_client.generate(prompt)
def _fill_content(self, candidates: List[RetrievalResult]) -> List[RetrievalResult]:
"""从文档存储中填充候选文档的完整内容"""
# 实际实现会从数据库或文件系统加载
return candidates
def _build_context(self, results: List[RetrievalResult]) -> str:
"""将重排序结果构造为上下文文本"""
context_parts = []
for i, result in enumerate(results, 1):
context_parts.append(
f"[文档{i}] (相关度: {result.cross_encoder_score:.3f})\n{result.content}"
)
return "\n\n".join(context_parts)
def _build_prompt(self, question: str, context: str) -> str:
"""构造 RAG Prompt"""
return (
f"请根据以下参考文档回答问题。如果文档中没有相关信息,请说明。\n\n"
f"参考文档:\n{context}\n\n"
f"问题:{question}\n\n"
f"回答:"
)
四、混合检索与重排序的架构权衡
4.1 延迟与精度的权衡
混合检索 + Cross-Encoder 重排序的总延迟 = Dense 检索延迟 + Sparse 检索延迟 + Cross-Encoder 推理延迟。在典型配置下(50 条 Dense + 50 条 Sparse → 100 条重排序 → Top-5),总延迟约为 200-500ms。如果要求 < 100ms 的响应时间,需要减少重排序候选数或使用更轻量的 Cross-Encoder 模型。
4.2 索引维护成本
Dense 索引(FAISS)和 Sparse 索引(Elasticsearch)需要独立维护和更新。文档新增时,两个索引必须同步更新,否则会出现检索结果不一致。在文档频繁更新的场景中,索引同步是运维的持续负担。
4.3 Cross-Encoder 的领域适配
预训练的 Cross-Encoder 模型(如 ms-marco 系列)在通用领域表现良好,但在垂直领域(医疗、法律、金融)精度显著下降。领域适配需要在领域数据上微调 Cross-Encoder,但标注"查询-文档"相关性数据的成本极高。
五、总结
RAG 系统的检索质量决定了生成质量的上限。三个关键设计决策:第一,使用 Dense + Sparse 混合检索,通过 RRF 融合算法兼顾语义匹配和精确匹配;第二,在融合结果上应用 Cross-Encoder 重排序,利用交互编码提升精排质量;第三,控制重排序候选数量在 50-100 条,在精度和延迟之间取得平衡。检索不是"搜一下"的简单操作,而是需要系统性设计的核心模块。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐
所有评论(0)