RAG 与大模型微调:技术选型与实战指南

目录导航

一、引言

在大型语言模型(LLM)应用开发中,我们经常面临一个关键问题:如何让模型更好地适应特定任务或领域?

比如:

  • 企业内部知识库问答
  • 特定行业的专业术语理解
  • 符合特定格式的输出生成

面对这些需求,业界主要有两种主流方案:

  1. RAG(检索增强生成) - 通过外部知识库增强模型能力
  2. 大模型微调(Fine-tuning) - 通过训练调整模型参数

这两种方案各有优劣,如何选择往往让开发者困惑。本文将从原理出发,详细对比两种方案,并提供实战代码示例,帮助你做出正确的技术选型。


二、RAG 原理详解

2.1 什么是 RAG

RAG(Retrieval-Augmented Generation),即检索增强生成,是一种将信息检索与文本生成相结合的技术架构。

核心思想:不改变模型权重,而是通过检索外部知识库,为模型提供实时、准确的上下文信息。

类比理解:就像考试时允许查阅参考书,模型可以"查资料"来回答问题,而不是只依赖记忆。

2.2 RAG 工作流程

RAG 的完整工作流程可以分为以下几个步骤:

用户输入 → 文档切分 → 向量化存储 → 检索匹配 → 上下文组装 → LLM 生成

详细流程

第一步:文档处理与切分

将原始文档进行预处理和切分。切分策略直接影响检索效果。

# 示例:文档切分
def split_documents(text, chunk_size=500, overlap=50):
    chunks = []
    start = 0
    while start < len(text):
        end = start + chunk_size
        chunk = text[start:end]
        chunks.append(chunk)
        start = end - overlap  # 保持重叠,避免上下文丢失
    return chunks

关键参数

  • chunk_size:每个文本块的大小,通常 300-500 字
  • overlap:相邻块的重叠长度,通常 20-50 字
第二步:向量化存储

使用 Embedding 模型将文本块转换为向量,存储到向量数据库中。

# 示例:文本向量化
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
chunks = ["文档块1", "文档块2", ...]
embeddings = model.encode(chunks)  # 返回 numpy 数组

常用 Embedding 模型

  • text-embedding-ada-002(OpenAI)
  • paraphrase-multilingual-MiniLM-L12-v2(开源)
  • bge-large-zh-v1.5(中文优化)
第三步:检索匹配

当用户提问时,将问题同样转换为向量,在向量数据库中检索最相关的文本块。

# 示例:向量检索
query = "用户的问题"
query_embedding = model.encode([query])
similarities = cosine_similarity(query_embedding, embeddings)
top_k_indices = similarities.argsort()[-3:][::-1]  # 取最相似的3个
relevant_chunks = [chunks[i] for i in top_k_indices]
第四步:上下文组装与生成

将检索到的相关文本块组装成提示词(Prompt),发送给 LLM 生成答案。

# 示例:组装提示词
prompt = f"""基于以下参考资料回答问题。如果资料中没有相关信息,请说明。

参考资料:
{chr(10).join(relevant_chunks)}

用户问题:{query}
答案:"""

RAG 工作流程图

┌─────────────────────────────────────────────────────────────────┐
│                         RAG 工作流程                              │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  ┌──────────┐    ┌──────────┐    ┌──────────┐    ┌──────────┐  │
│  │ 文档输入  │───→│ 文档切分 │───→│ 向量化   │───→│ 存储到   │  │
│  └──────────┘    └──────────┘    └──────────┘    │ 向量数据库│  │
│                                                 └──────────┘  │
│                                                        ↑       │
│                                                        │       │
│  ┌──────────┐    ┌──────────┐    ┌──────────┐    ┌──────────┐  │
│  │ 最终输出 │←───│ LLM 生成 │←───│ 组装Prompt│←───│ 检索匹配 │  │
│  └──────────┘    └──────────┘    └──────────┘    └──────────┘  │
│                                                                 │
│                         用户提问                                  │
│                            ↓                                    │
│  ┌──────────┐    ┌──────────┐                                  │
│  │ 查询向量 │←───│ 问题向量化│                                  │
│  └──────────┘    └──────────┘                                  │
└─────────────────────────────────────────────────────────────────┘

2.3 RAG 优势与局限

✅ 优势
  1. 知识可更新性强
    • 更新知识库不需要重新训练模型
    • 适合知识频繁变化的场景(新闻、产品信息等)
  2. 可解释性好
    • 可以直接查看检索到的参考文档
    • 答案可溯源,增强可信度
  3. 成本较低
    • 只需支付 Embedding 和 LLM 调用费用
    • 无需 GPU 训练资源
  4. 部署灵活
    • 支持多种向量数据库(Milvus、Pinecone、Chroma 等)
    • 可轻松扩展知识库规模
❌ 局限
  1. 依赖检索质量
    • 检索效果差,整体性能受限
    • 对 Embedding 模型敏感
  2. 上下文长度限制
    • LLM 上下文窗口有限
    • 无法一次性利用大量检索文档
  3. 响应延迟
    • 检索+生成的两阶段流程
    • 延迟高于纯生成式模型
  4. 复杂推理受限
    • 对于需要复杂推理的任务
    • 可能无法充分利用分散的检索片段

三、大模型微调原理详解

3.1 什么是微调

微调(Fine-tuning) 是在预训练模型的基础上,使用特定领域或任务的数据进行额外训练,调整模型参数,使模型适应目标任务。

核心思想:通过训练,让模型"学会"特定任务的知识和模式。

类比理解:就像一个已经学过基础课程的学生,再参加专业培训来掌握特定技能。

3.2 常见微调方法

全参数微调(Full Fine-tuning)

训练模型的所有参数。

# 全参数微调示例(伪代码)
from transformers import AutoModelForSequenceClassification, Trainer

model = AutoModelForSequenceClassification.from_pretrained("base-model")
trainer = Trainer(
    model=model,
    train_dataset=train_data,
    args=TrainingArguments(
        learning_rate=2e-5,
        num_train_epochs=3,
        per_device_train_batch_size=8
    )
)
trainer.train()

问题:需要大量 GPU 显存,训练成本高,容易过拟合。

LoRA(Low-Rank Adaptation)

核心思想:不直接修改原始权重,而是添加低秩矩阵来捕捉任务特定知识。

原始权重 W:保持冻结
新增权重 ΔW = B × A,其中 B ∈ R^(d×r), A ∈ R^(r×k), r << min(d,k)
最终权重:W' = W + ΔW
# LoRA 微调示例
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=8,                    # 低秩维度
    lora_alpha=16,           # 缩放因子
    target_modules=["q_proj", "v_proj"],  # 应用 LoRA 的层
    lora_dropout=0.05,
    task_type="CAUSAL_LM"
)

model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()
# 输出: trainable params: 0.1% || all params: 6.7B || trainable%: 0.1%

优势:大幅减少可训练参数,节省显存,训练速度快。

QLoRA(Quantized LoRA)

在 LoRA 基础上,对模型进行量化,进一步减少显存占用。

# QLoRA 微调示例
from transformers import BitsAndBytesConfig
from peft import prepare_model_for_kbit_training

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

model = AutoModelForCausalLM.from_pretrained(
    "base-model",
    quantization_config=quantization_config,
    device_map="auto"
)
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)

优势:可在消费级 GPU(如 RTX 3090)上微调 70B 参数的模型。

对比表格
方法 可训练参数 显存需求 训练速度 效果
全参数微调 100% 非常高 最好
LoRA 0.1-5% 接近全参数
QLoRA 0.1-5% 中等 中等 接近 LoRA

3.3 微调优势与局限

✅ 优势
  1. 任务适应性强
    • 可以学习特定领域的表达方式
    • 输出格式更加可控
  2. 推理速度快
    • 训练完成后是纯生成模型
    • 无需额外的检索步骤
  3. 个性化程度高
    • 可以学习复杂的指令遵循模式
    • 适合需要强一致性输出的场景
❌ 局限
  1. 更新成本高
    • 每次知识更新都需要重新训练
    • 训练需要 GPU 资源
  2. 可能产生幻觉
    • 模型可能产生与训练数据不一致的输出
    • 可解释性不如 RAG
  3. 训练不稳定
    • 学习率、epoch 等超参敏感
    • 需要一定调参经验
  4. 灾难性遗忘
    • 新任务可能覆盖原有能力
    • 需要混合训练策略

四、技术选型指南

4.1 何时选择 RAG

推荐场景

  1. 知识库频繁更新
    • 产品文档、新闻资讯、客服知识库
    • 需要实时反映最新信息
  2. 需要答案可溯源
    • 法律、医疗、金融等高风险领域
    • 需要展示参考来源
  3. 多源异构数据
    • 来自不同格式、不同来源的信息
    • 需要统一检索和整合
  4. 预算有限
    • 无 GPU 训练资源
    • 希望降低运营成本

判断标准

✓ 需要处理大量外部知识
✓ 知识库需要频繁更新
✓ 需要可解释性和可溯源性
✓ 没有 GPU 训练资源
✓ 延迟要求不是特别高

4.2 何时选择微调

推荐场景

  1. 任务模式固定
    • 特定的输出格式、风格要求
    • 复杂的指令遵循任务
  2. 领域专业性强
    • 需要学习专业术语和表达方式
    • 特定行业的复杂推理
  3. 推理延迟要求高
    • 实时对话系统
    • 高并发场景
  4. 数据量充足
    • 有大量高质量的训练数据
    • 能够承担训练成本

判断标准

✓ 任务模式相对固定
✓ 需要复杂指令遵循能力
✓ 推理延迟要求严格
✓ 训练数据充足
✓ 有 GPU 训练资源

4.3 混合策略

实际应用中,RAG 和微调并非互斥,可以结合使用:

┌─────────────────────────────────────────────┐
│              混合策略架构                     │
├─────────────────────────────────────────────┤
│                                             │
│  ┌─────────┐                               │
│  │ 用户输入 │                               │
│  └────┬────┘                               │
│       ↓                                     │
│  ┌────┴────┐                               │
│  │ 意图判断 │                               │
│  └────┬────┘                               │
│       ↓                                     │
│  ┌────┴────┐     ┌────────────┐            │
│  │ 知识密集 │────→│   RAG 模块  │            │
│  └────┬────┘     └──────┬─────┘            │
│       │                 ↓                    │
│       ↓           ┌────────────┐            │
│  ┌────┴────┐       │ 检索知识库  │            │
│  │ 任务密集 │────→└──────┬─────┘            │
│  └────┬────┘            ↓                    │
│       ↓          ┌────────────┐            │
│  ┌────┴────┐      │ 微调模型   │            │
│  │  融合层  │←────│ 生成响应   │            │
│  └────┬────┘      └────────────┘            │
│       ↓                                     │
│  ┌────┴────┐                                │
│  │ 最终输出 │                                │
│  └─────────┘                                │
└─────────────────────────────────────────────┘

典型模式

  1. 微调基础模型 + RAG 增强
    • 先微调模型学习领域知识
    • 运行时结合 RAG 提供最新信息
  2. 微调 Embedding + RAG
    • 微调 Embedding 模型提升检索质量
    • 结合 RAG 提供上下文
  3. 微调 Router + 路由选择
    • 微调一个小模型作为 Router
    • 决定使用 RAG 还是微调模型

五、实战代码示例

5.1 RAG 完整实现

环境准备
pip install langchain sentence-transformers chromadb openai python-dotenv
完整代码
"""
RAG 完整实现示例
使用 LangChain + ChromaDB + OpenAI
"""

import os
from langchain.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate

class RAGSystem:
    def __init__(self, openai_api_key: str, persist_directory: str = "./chroma_db"):
        self.persist_directory = persist_directory
        
        # 初始化 Embedding 模型
        self.embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
        
        # 初始化 LLM
        self.llm = ChatOpenAI(
            model_name="gpt-3.5-turbo",
            temperature=0.7,
            openai_api_key=openai_api_key
        )
        
        self.vectorstore = None
        self.qa_chain = None
    
    def load_documents(self, file_path: str):
        """加载文档"""
        loader = TextLoader(file_path, encoding='utf-8')
        documents = loader.load()
        return documents
    
    def split_documents(self, documents, chunk_size: int = 500, chunk_overlap: int = 50):
        """文档切分"""
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len
        )
        texts = text_splitter.split_documents(documents)
        return texts
    
    def create_vectorstore(self, texts):
        """创建向量数据库"""
        self.vectorstore = Chroma.from_documents(
            documents=texts,
            embedding=self.embeddings,
            persist_directory=self.persist_directory
        )
        return self.vectorstore
    
    def setup_qa_chain(self, search_kwargs: dict = None):
        """设置问答链"""
        if search_kwargs is None:
            search_kwargs = {"k": 3}  # 检索最相关的3个文档块
        
        # 自定义提示词模板
        prompt_template = """基于以下参考资料回答问题。如果资料中没有相关信息,请说明。

参考资料:
{context}

用户问题:{question}
答案:"""
        
        PROMPT = PromptTemplate(
            template=prompt_template,
            input_variables=["context", "question"]
        )
        
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.vectorstore.as_retriever(
                search_kwargs=search_kwargs
            ),
            chain_type_kwargs={"prompt": PROMPT},
            return_source_documents=True
        )
        
        return self.qa_chain
    
    def query(self, question: str):
        """执行查询"""
        if self.qa_chain is None:
            raise ValueError("QA chain not initialized. Call setup_qa_chain() first.")
        
        result = self.qa_chain({"query": question})
        return {
            "answer": result["result"],
            "source_documents": result["source_documents"]
        }


def main():
    # 初始化 RAG 系统
    rag = RAGSystem(
        openai_api_key=os.getenv("OPENAI_API_KEY"),
        persist_directory="./my_knowledge_base"
    )
    
    # 加载和处理文档
    documents = rag.load_documents("knowledge_base.txt")
    texts = rag.split_documents(documents)
    rag.create_vectorstore(texts)
    rag.setup_qa_chain()
    
    # 执行查询
    question = "产品的主要功能有哪些?"
    result = rag.query(question)
    
    print(f"问题:{question}")
    print(f"答案:{result['answer']}")
    print(f"\n参考来源:")
    for i, doc in enumerate(result['source_documents'], 1):
        print(f"{i}. {doc.page_content[:100]}...")


if __name__ == "__main__":
    main()
使用说明
  1. 创建 knowledge_base.txt 文件,放入你要处理的知识库内容
  2. 设置环境变量 OPENAI_API_KEY
  3. 运行脚本即可实现基于知识库的问答

5.2 LoRA 微调完整实现

环境准备
pip install transformers peft datasets accelerate bitsandbytes
完整代码
"""
LoRA 微调完整实现示例
使用 HuggingFace Transformers + PEFT
"""

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import Dataset
import json

class LoRAFineTuner:
    def __init__(
        self,
        model_name: str = "THUDM/chatglm2-6b",
        lora_r: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.05
    ):
        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        self.lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"],
            lora_dropout=lora_dropout,
            bias="none",
            task_type="CAUSAL_LM"
        )
    
    def load_model(self, use_4bit: bool = False):
        """加载模型"""
        print(f"Loading model: {self.model_name}")
        
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            trust_remote_code=True
        )
        
        if use_4bit:
            # QLoRA 配置
            from transformers import BitsAndBytesConfig
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                quantization_config=quantization_config,
                device_map="auto",
                trust_remote_code=True
            )
            self.model = prepare_model_for_kbit_training(self.model)
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                device_map="auto",
                trust_remote_code=True
            )
        
        print("Model loaded successfully!")
        return self.model, self.tokenizer
    
    def prepare_lora_model(self):
        """应用 LoRA"""
        self.model = get_peft_model(self.model, self.lora_config)
        self.model.print_trainable_parameters()
        return self.model
    
    def tokenize_function(self, examples, max_length: int = 512):
        """Tokenize 数据"""
        result = self.tokenizer(
            examples["text"],
            truncation=True,
            max_length=max_length,
            padding="max_length"
        )
        result["labels"] = result["input_ids"].copy()
        return result
    
    def prepare_dataset(self, data_path: str):
        """准备数据集"""
        # 从 JSONL 文件加载数据
        texts = []
        with open(data_path, 'r', encoding='utf-8') as f:
            for line in f:
                data = json.loads(line)
                # 格式:[系统提示] 用户: 问题\n助手: 回答
                text = f"{data.get('instruction', '')} User: {data['input']}\nAssistant: {data['output']}"
                texts.append(text)
        
        dataset = Dataset.from_dict({"text": texts})
        tokenized_dataset = dataset.map(
            self.tokenize_function,
            batched=True,
            remove_columns=dataset.column_names
        )
        
        return tokenized_dataset
    
    def train(
        self,
        train_dataset,
        output_dir: str = "./lora_model",
        num_train_epochs: int = 3,
        per_device_train_batch_size: int = 4,
        gradient_accumulation_steps: int = 4,
        learning_rate: float = 2e-4,
        warmup_steps: int = 100,
        logging_steps: int = 10,
        save_steps: int = 100
    ):
        """执行训练"""
        
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_train_epochs,
            per_device_train_batch_size=per_device_train_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            logging_steps=logging_steps,
            save_steps=save_steps,
            fp16=True,
            optim="paged_adamw_8bit",
            report_to="tensorboard"
        )
        
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False  # Causal LM,不需要 MLM
        )
        
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            data_collator=data_collator
        )
        
        print("Starting training...")
        trainer.train()
        
        # 保存 LoRA 权重
        self.model.save_pretrained(output_dir)
        print(f"Model saved to {output_dir}")
        
        return trainer
    
    def chat(self, prompt: str, max_length: int = 512):
        """使用微调后的模型对话"""
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                temperature=0.7,
                top_p=0.9
            )
        
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response


def main():
    # 初始化微调器
    trainer = LoRAFineTuner(
        model_name="microsoft/phi-2",  # 使用较小的模型作为示例
        lora_r=8,
        lora_alpha=16
    )
    
    # 加载模型(根据显存调整是否使用 QLoRA)
    trainer.load_model(use_4bit=True)
    trainer.prepare_lora_model()
    
    # 准备数据集(需要准备 JSONL 格式的训练数据)
    # 格式:{"instruction": "系统指令", "input": "用户输入", "output": "模型输出"}
    train_dataset = trainer.prepare_dataset("train_data.jsonl")
    
    # 训练
    trainer.train(
        train_dataset=train_dataset,
        output_dir="./my_lora_model",
        num_train_epochs=3,
        per_device_train_batch_size=2
    )
    
    # 测试对话
    response = trainer.chat("用户: 你好\n助手:")
    print(f"Response: {response}")


if __name__ == "__main__":
    main()
训练数据格式示例

创建 train_data.jsonl

{"instruction": "你是一个专业的客服助手", "input": "如何重置密码?", "output": "您可以通过以下步骤重置密码:1. 点击登录页面的'忘记密码'链接;2. 输入您的注册邮箱;3. 查收邮件并点击验证链接;4. 设置新密码。如果有任何问题,请联系我们的客服团队。"}
{"instruction": "你是一个专业的客服助手", "input": "产品支持哪些支付方式?", "output": "我们支持多种支付方式,包括:信用卡(Visa、MasterCard)、借记卡、支付宝、微信支付以及银行转账。您可以在结账页面选择最方便的方式完成支付。"}

六、总结与建议

核心对比

维度 RAG 微调
知识更新 实时更新,无需训练 需要重新训练
成本 Embedding + LLM 调用 GPU 训练资源
可解释性 高,可溯源 低,隐式学习
任务适配 通用知识检索 特定任务优化
推理延迟 较高(检索+生成) 低(纯生成)
部署复杂度 中(需向量数据库) 低(单一模型)

选型建议

优先选择 RAG

  • 知识库需要频繁更新
  • 需要答案可溯源
  • 预算有限,无 GPU
  • 数据源多样化

优先选择微调

  • 任务模式固定
  • 需要复杂的指令遵循
  • 推理延迟要求高
  • 有充足的训练数据

考虑混合策略

  • 知识密集且任务复杂
  • 需要兼顾实时性和专业性
  • 有一定的技术储备

实践建议

  1. 从小规模开始
    • 先用 RAG 验证需求
    • 再根据效果决定是否微调
  2. 重视数据质量
    • 无论 RAG 还是微调,数据质量都是关键
    • 做好数据清洗和标注
  3. 建立评估体系
    • 设计合理的评估指标
    • A/B 测试不同方案
  4. 持续迭代优化
    • 根据用户反馈优化
    • 保持模型的知识更新

希望本文能帮助你更好地理解 RAG 和微调技术,并在实际项目中做出正确的技术选型!


Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐