RAG 系统的深度优化:从召回精度到生成质量的全链路提升

cover

一、从简单检索到智能问答:RAG 系统的演进痛点

在构建 AI 产品的早期,我们使用简单的 RAG 架构就能满足需求:文档切分、向量化存储、相似度检索、将结果拼接到 Prompt 中。但随着用户规模增长和文档数量增加,我们发现这套系统遇到了明显的瓶颈:

  • 召回的文档相关性不高,Top 5 中经常有不相关的内容
  • 对于需要跨多文档综合的问题,系统表现不佳
  • 答案经常与检索到的内容不一致,出现幻觉
  • 对长文档的处理效率低,检索时间过长

这些问题不是简单地换个嵌入模型就能解决的。技术如果不服务于真实的问答质量,那只是一个花哨的演示系统。我们需要一套完整的 RAG 优化框架,从数据预处理到最终答案生成,全链路提升系统质量。

二、RAG 系统的分层架构:优化的全景视图

flowchart TD
    subgraph 数据层
        A[原始文档] --> B[文档清洗]
        B --> C[智能切分]
        C --> D[元数据增强]
        D --> E[向量化]
    end
    
    subgraph 检索层
        F[查询理解] --> G[混合检索]
        G --> H[重排序]
        H --> I[结果过滤]
    end
    
    subgraph 生成层
        J[上下文构建] --> K[提示工程]
        K --> L[答案生成]
        L --> M[事实核查]
    end
    
    E --> F
    I --> J

2.1 智能文档切分策略

文档切分不是简单的按字数分割,我们需要根据文档结构进行智能切分:

import re
from typing import List, Dict, Any
from dataclasses import dataclass
import markdown

@dataclass
class DocumentChunk:
    content: str
    metadata: Dict[str, Any]
    start_pos: int
    end_pos: int

class SmartDocumentSplitter:
    """智能文档切分器"""
    
    def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
    
    def split_markdown(self, markdown_text: str) -> List[DocumentChunk]:
        """
        智能切分 Markdown 文档
        保留标题层级、代码块完整性
        """
        chunks = []
        lines = markdown_text.split('\n')
        
        # 识别标题层级
        sections = self._extract_sections(lines)
        
        current_chunk = ""
        current_pos = 0
        
        for section in sections:
            section_content = section['content']
            section_level = section['level']
            
            # 如果添加当前 section 会超过 chunk_size
            if len(current_chunk) + len(section_content) > self.chunk_size:
                if current_chunk:
                    chunks.append(DocumentChunk(
                        content=current_chunk,
                        metadata={"section_level": section_level},
                        start_pos=current_pos - len(current_chunk),
                        end_pos=current_pos
                    ))
                
                # 处理特别长的 section
                if len(section_content) > self.chunk_size:
                    sub_chunks = self._split_long_section(section_content)
                    for sub_chunk in sub_chunks:
                        chunks.append(DocumentChunk(
                            content=sub_chunk,
                            metadata={"section_level": section_level, "is_subchunk": True},
                            start_pos=current_pos,
                            end_pos=current_pos + len(sub_chunk)
                        ))
                        current_pos += len(sub_chunk)
                    current_chunk = ""
                else:
                    current_chunk = section_content
            else:
                current_chunk += "\n" + section_content if current_chunk else section_content
            
            current_pos += len(section_content)
        
        # 处理最后一个 chunk
        if current_chunk:
            chunks.append(DocumentChunk(
                content=current_chunk,
                metadata={"section_level": section_level if 'section_level' in locals() else 0},
                start_pos=current_pos - len(current_chunk),
                end_pos=current_pos
            ))
        
        return chunks
    
    def _extract_sections(self, lines: List[str]) -> List[Dict]:
        """从 Markdown 文本中提取章节"""
        sections = []
        current_section = {"level": 0, "content": ""}
        
        for line in lines:
            header_match = re.match(r'^(#{1,6})\s+(.*)$', line)
            
            if header_match:
                # 保存当前章节
                if current_section["content"]:
                    sections.append(current_section)
                
                # 开始新章节
                level = len(header_match.group(1))
                current_section = {
                    "level": level,
                    "content": line
                }
            else:
                if current_section["content"]:
                    current_section["content"] += "\n" + line
                else:
                    current_section["content"] = line
        
        if current_section["content"]:
            sections.append(current_section)
        
        return sections
    
    def _split_long_section(self, content: str) -> List[str]:
        """切分特别长的章节"""
        # 尝试按段落切分
        paragraphs = content.split('\n\n')
        chunks = []
        current_chunk = ""
        
        for paragraph in paragraphs:
            if len(current_chunk) + len(paragraph) > self.chunk_size:
                if current_chunk:
                    chunks.append(current_chunk)
                # 如果单段落超过 chunk_size,按句子切分
                if len(paragraph) > self.chunk_size:
                    sentences = re.split(r'[。!?.!?]', paragraph)
                    sentence_chunk = ""
                    for sentence in sentences:
                        if sentence.strip():
                            if len(sentence_chunk) + len(sentence) > self.chunk_size:
                                if sentence_chunk:
                                    chunks.append(sentence_chunk)
                                sentence_chunk = sentence
                            else:
                                sentence_chunk += sentence if sentence_chunk else sentence
                    if sentence_chunk:
                        chunks.append(sentence_chunk)
                else:
                    current_chunk = paragraph
            else:
                current_chunk += "\n\n" + paragraph if current_chunk else paragraph
        
        if current_chunk:
            chunks.append(current_chunk)
        
        return chunks

2.2 元数据增强与混合检索

除了文本内容,我们还会为每个 Chunk 添加丰富的元数据,支持多维度检索:

from typing import List, Dict, Any
import hashlib

class MetadataEnhancer:
    """元数据增强器"""
    
    def enhance_chunk(self, chunk: DocumentChunk, 
                     document_metadata: Dict[str, Any]) -> DocumentChunk:
        """增强 Chunk 的元数据"""
        # 复制文档级元数据
        chunk.metadata.update(document_metadata)
        
        # 计算内容哈希
        chunk.metadata["content_hash"] = hashlib.md5(
            chunk.content.encode()
        ).hexdigest()
        
        # 估算 Token 数量
        chunk.metadata["estimated_tokens"] = len(chunk.content) // 4
        
        # 提取关键词
        chunk.metadata["keywords"] = self._extract_keywords(chunk.content)
        
        # 识别内容类型
        chunk.metadata["content_type"] = self._classify_content_type(chunk.content)
        
        return chunk
    
    def _extract_keywords(self, content: str, top_k: int = 5) -> List[str]:
        """提取关键词(简单实现,生产环境可用 TF-IDF 或 LLM)"""
        # 这里用简单的词频统计作为示例
        words = re.findall(r'[\w\d]+', content.lower())
        word_freq = {}
        for word in words:
            if len(word) > 2:  # 过滤短词
                word_freq[word] = word_freq.get(word, 0) + 1
        
        sorted_words = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)
        return [word for word, freq in sorted_words[:top_k]]
    
    def _classify_content_type(self, content: str) -> str:
        """分类内容类型"""
        if re.search(r'```[\s\S]*?```', content):
            return "code_example"
        elif re.search(r'^(?:步骤|Step|STEP)\s*\d+', content, re.MULTILINE):
            return "tutorial"
        elif re.search(r'^(?:定义|Definition|DEFINITION)\s*[::]', content, re.MULTILINE):
            return "definition"
        else:
            return "general"

class HybridRetriever:
    """混合检索器:关键词检索 + 语义检索"""
    
    def __init__(self, vector_store, keyword_index):
        self.vector_store = vector_store
        self.keyword_index = keyword_index
    
    def retrieve(self, query: str, top_k: int = 10, 
                 keyword_weight: float = 0.3, 
                 semantic_weight: float = 0.7) -> List[Dict]:
        """混合检索"""
        # 并行执行两种检索
        keyword_results = self.keyword_index.search(query, top_k * 2)
        semantic_results = self.vector_store.search(query, top_k * 2)
        
        # 结果合并与重排序
        combined = self._merge_results(
            keyword_results, semantic_results,
            keyword_weight, semantic_weight
        )
        
        return combined[:top_k]
    
    def _merge_results(self, keyword_results: List[Dict], 
                      semantic_results: List[Dict],
                      keyword_weight: float, 
                      semantic_weight: float) -> List[Dict]:
        """合并检索结果"""
        # 创建结果字典
        result_dict = {}
        
        # 处理关键词检索结果
        for rank, result in enumerate(keyword_results):
            doc_id = result["doc_id"]
            score = (1.0 - (rank / len(keyword_results))) * keyword_weight
            if doc_id in result_dict:
                result_dict[doc_id]["score"] += score
            else:
                result_dict[doc_id] = {**result, "score": score}
        
        # 处理语义检索结果
        for rank, result in enumerate(semantic_results):
            doc_id = result["doc_id"]
            score = (1.0 - (rank / len(semantic_results))) * semantic_weight
            if doc_id in result_dict:
                result_dict[doc_id]["score"] += score
            else:
                result_dict[doc_id] = {**result, "score": score}
        
        # 按分数排序
        sorted_results = sorted(
            result_dict.values(),
            key=lambda x: x["score"],
            reverse=True
        )
        
        return sorted_results

三、重排序与上下文压缩:提升有效信息密度

3.1 交叉编码器重排序

使用 Cross-Encoder 对初步检索结果进行精细重排序:

from sentence_transformers import CrossEncoder
from typing import List, Dict

class Reranker:
    """检索结果重排序器"""
    
    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        self.model = CrossEncoder(model_name)
    
    def rerank(self, query: str, results: List[Dict], 
              top_k: int = 5) -> List[Dict]:
        """重排序检索结果"""
        if not results:
            return []
        
        # 准备输入对
        pairs = [(query, result["content"]) for result in results]
        
        # 预测相似度分数
        scores = self.model.predict(pairs)
        
        # 添加分数并排序
        for result, score in zip(results, scores):
            result["rerank_score"] = float(score)
        
        # 按重排序分数排序
        reranked = sorted(
            results,
            key=lambda x: x["rerank_score"],
            reverse=True
        )
        
        return reranked[:top_k]

3.2 上下文压缩与优化

即使检索到了相关内容,直接全部拼接到 Prompt 中也不是最优策略。我们使用 LLM 进行上下文压缩:

from typing import List, Dict, Any
from openai import OpenAI

class ContextCompressor:
    """上下文压缩器"""
    
    def __init__(self, client: OpenAI, model: str = "gpt-3.5-turbo"):
        self.client = client
        self.model = model
    
    def compress(self, query: str, chunks: List[Dict], 
                max_tokens: int = 4000) -> str:
        """压缩上下文"""
        # 先过滤掉不相关的内容
        relevant_chunks = self._filter_relevant(query, chunks)
        
        # 对每个 chunk 进行摘要
        compressed_parts = []
        for chunk in relevant_chunks:
            summary = self._summarize_chunk(query, chunk["content"])
            compressed_parts.append(summary)
            
            # 检查 Token 限制
            total_text = "\n\n".join(compressed_parts)
            if len(total_text) // 4 > max_tokens:
                break
        
        return "\n\n".join(compressed_parts)
    
    def _filter_relevant(self, query: str, chunks: List[Dict]) -> List[Dict]:
        """过滤不相关的 Chunk"""
        # 可以使用简单的关键词匹配或小模型分类
        # 这里简化处理,假设 rerank 后的结果已经足够相关
        return chunks
    
    def _summarize_chunk(self, query: str, content: str) -> str:
        """摘要单个 Chunk,保留与 Query 相关的信息"""
        prompt = f"""基于以下用户问题,请提取文档内容中最相关的信息。
只保留与问题直接相关的内容,删除无关细节。

用户问题:{query}

文档内容:{content}

相关信息摘要:"""
        
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1,
            max_tokens=500
        )
        
        return response.choices[0].message.content

四、提示工程与事实核查:提升答案质量

4.1 结构化的 RAG 提示模板

我们设计了结构化的提示模板,引导 LLM 基于检索到的内容生成答案:

class RAGPromptBuilder:
    """RAG 提示构建器"""
    
    def build_prompt(self, query: str, context: str, 
                    conversation_history: List[Dict] = None) -> str:
        """构建 RAG 提示"""
        prompt = self._get_system_prompt()
        
        # 添加上下文
        prompt += f"""<context>
{context}
</context>

"""
        
        # 添加对话历史(如果有)
        if conversation_history:
            prompt += "<conversation_history>\n"
            for msg in conversation_history[-5:]:  # 只保留最近 5 轮
                prompt += f"{msg['role']}: {msg['content']}\n"
            prompt += "</conversation_history>\n\n"
        
        # 添加用户问题
        prompt += f"""请基于上述上下文回答用户问题。如果上下文中没有足够信息,请直接说明,不要编造答案。

用户问题:{query}

回答:"""
        
        return prompt
    
    def _get_system_prompt(self) -> str:
        """获取系统提示"""
        return """你是一个专业的问答助手,需要基于提供的上下文回答用户问题。

回答要求:
1. 只使用上下文中的信息,不要引入外部知识
2. 如果上下文中没有答案,请明确说明
3. 回答要简洁、准确、有条理
4. 引用上下文信息时,可以说明信息来源

"""

4.2 答案的事实核查

为了减少幻觉,我们在生成答案后进行事实核查:

class FactChecker:
    """事实核查器"""
    
    def __init__(self, client: OpenAI, model: str = "gpt-3.5-turbo"):
        self.client = client
        self.model = model
    
    def check(self, answer: str, context: str) -> Dict[str, Any]:
        """核查答案是否基于上下文"""
        prompt = f"""请检查以下回答是否基于提供的上下文。

回答:
{answer}

上下文:
{context}

请按以下 JSON 格式输出核查结果:
{{
    "is_factual": true/false,  // 答案是否基于上下文
    "issues": [  // 发现的问题列表
        {{
            "type": "hallucination" | "misinformation" | "missing_source",
            "content": "问题内容描述",
            "position": "问题在回答中的大概位置"
        }}
    ],
    "suggested_revision": "如果有问题,提供修正建议"
}}"""
        
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1,
            max_tokens=1000
        )
        
        import json
        try:
            result = json.loads(response.choices[0].message.content)
            return result
        except:
            return {
                "is_factual": False,
                "issues": [{"type": "check_failed", "content": "核查过程出错"}],
                "suggested_revision": answer
            }

五、评估框架与持续优化:数据驱动的改进

5.1 全面的 RAG 评估指标

我们建立了多维度的评估框架:

from typing import List, Dict, Any
import statistics

class RAGEvaluator:
    """RAG 系统评估器"""
    
    def __init__(self, client: OpenAI, model: str = "gpt-4"):
        self.client = client
        self.model = model
    
    def evaluate_query(self, query: str, expected_answer: str,
                      actual_answer: str, retrieved_contexts: List[str]) -> Dict[str, float]:
        """评估单个 Query 的结果"""
        # 检索质量评估
        retrieval_metrics = self._evaluate_retrieval(
            query, retrieved_contexts, expected_answer
        )
        
        # 答案质量评估
        answer_metrics = self._evaluate_answer(
            query, actual_answer, expected_answer
        )
        
        return {**retrieval_metrics, **answer_metrics}
    
    def _evaluate_retrieval(self, query: str, contexts: List[str], 
                           expected_answer: str) -> Dict[str, float]:
        """评估检索质量"""
        # 检查 Top N 召回率
        has_relevant = False
        for i, context in enumerate(contexts):
            if self._is_context_relevant(query, context, expected_answer):
                has_relevant = True
                break
        
        # 计算上下文有用性分数(使用 LLM 评估)
        prompt = f"""请评估以下检索到的上下文对回答问题是否有用。

问题:{query}

期望答案要点:{expected_answer}

检索到的上下文:
{chr(10).join([f"{i+1}. {ctx[:200]}..." for i, ctx in enumerate(contexts)])}

请在 0-5 分之间打分(5分表示非常相关有用,0分表示完全不相关)。
只返回分数数字:"""
        
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1,
            max_tokens=10
        )
        
        try:
            usefulness_score = float(response.choices[0].message.content.strip())
        except:
            usefulness_score = 2.5
        
        return {
            "retrieval_recall": 1.0 if has_relevant else 0.0,
            "context_usefulness": usefulness_score
        }
    
    def _evaluate_answer(self, query: str, actual_answer: str, 
                        expected_answer: str) -> Dict[str, float]:
        """评估答案质量"""
        prompt = f"""请评估以下回答的质量,从三个维度打分(每个维度 0-5 分):

问题:{query}

期望答案要点:{expected_answer}

实际回答:{actual_answer}

请按以下 JSON 格式输出:
{{
    "accuracy": 0-5,  // 答案的准确性
    "completeness": 0-5,  // 答案的完整性
    "relevance": 0-5  // 答案的相关性
}}"""
        
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1,
            max_tokens=200
        )
        
        import json
        try:
            result = json.loads(response.choices[0].message.content)
            return {
                "answer_accuracy": float(result.get("accuracy", 2.5)),
                "answer_completeness": float(result.get("completeness", 2.5)),
                "answer_relevance": float(result.get("relevance", 2.5))
            }
        except:
            return {
                "answer_accuracy": 2.5,
                "answer_completeness": 2.5,
                "answer_relevance": 2.5
            }
    
    def _is_context_relevant(self, query: str, context: str, 
                            expected_answer: str) -> bool:
        """判断上下文是否相关"""
        # 这里简化实现,生产环境可用 LLM 或交叉编码器
        for keyword in expected_answer.split()[:5]:
            if keyword in context:
                return True
        return False
    
    def evaluate_batch(self, test_cases: List[Dict]) -> Dict[str, float]:
        """批量评估"""
        all_metrics = []
        for case in test_cases:
            metrics = self.evaluate_query(
                case["query"],
                case["expected_answer"],
                case["actual_answer"],
                case["retrieved_contexts"]
            )
            all_metrics.append(metrics)
        
        # 计算平均指标
        avg_metrics = {}
        for key in all_metrics[0].keys():
            values = [m[key] for m in all_metrics]
            avg_metrics[key] = statistics.mean(values)
        
        return avg_metrics

六、总结

RAG 系统的优化是一个持续迭代的过程。通过智能文档切分、混合检索、重排序、上下文压缩、精心设计的提示工程和事实核查,我们显著提升了问答质量。

在优化过程中,建立完善的评估框架至关重要。数据驱动的优化让我们能够清晰地知道哪个环节的改进带来了最大的价值,避免盲目尝试。

对于 AI 创业公司来说,高质量的 RAG 系统是产品竞争力的关键。它让我们能够利用自有知识资产,为用户提供准确、可靠的 AI 服务,而不只是依赖通用大模型的能力。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐