RAG 与大模型微调:技术选型与实战指南
RAG 与大模型微调:技术选型与实战指南
目录导航
一、引言
在大型语言模型(LLM)应用开发中,我们经常面临一个关键问题:如何让模型更好地适应特定任务或领域?
比如:
- 企业内部知识库问答
- 特定行业的专业术语理解
- 符合特定格式的输出生成
面对这些需求,业界主要有两种主流方案:
- RAG(检索增强生成) - 通过外部知识库增强模型能力
- 大模型微调(Fine-tuning) - 通过训练调整模型参数
这两种方案各有优劣,如何选择往往让开发者困惑。本文将从原理出发,详细对比两种方案,并提供实战代码示例,帮助你做出正确的技术选型。
二、RAG 原理详解
2.1 什么是 RAG
RAG(Retrieval-Augmented Generation),即检索增强生成,是一种将信息检索与文本生成相结合的技术架构。
核心思想:不改变模型权重,而是通过检索外部知识库,为模型提供实时、准确的上下文信息。
类比理解:就像考试时允许查阅参考书,模型可以"查资料"来回答问题,而不是只依赖记忆。
2.2 RAG 工作流程
RAG 的完整工作流程可以分为以下几个步骤:
用户输入 → 文档切分 → 向量化存储 → 检索匹配 → 上下文组装 → LLM 生成
详细流程:
第一步:文档处理与切分
将原始文档进行预处理和切分。切分策略直接影响检索效果。
# 示例:文档切分
def split_documents(text, chunk_size=500, overlap=50):
chunks = []
start = 0
while start < len(text):
end = start + chunk_size
chunk = text[start:end]
chunks.append(chunk)
start = end - overlap # 保持重叠,避免上下文丢失
return chunks
关键参数:
chunk_size:每个文本块的大小,通常 300-500 字overlap:相邻块的重叠长度,通常 20-50 字
第二步:向量化存储
使用 Embedding 模型将文本块转换为向量,存储到向量数据库中。
# 示例:文本向量化
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
chunks = ["文档块1", "文档块2", ...]
embeddings = model.encode(chunks) # 返回 numpy 数组
常用 Embedding 模型:
text-embedding-ada-002(OpenAI)paraphrase-multilingual-MiniLM-L12-v2(开源)bge-large-zh-v1.5(中文优化)
第三步:检索匹配
当用户提问时,将问题同样转换为向量,在向量数据库中检索最相关的文本块。
# 示例:向量检索
query = "用户的问题"
query_embedding = model.encode([query])
similarities = cosine_similarity(query_embedding, embeddings)
top_k_indices = similarities.argsort()[-3:][::-1] # 取最相似的3个
relevant_chunks = [chunks[i] for i in top_k_indices]
第四步:上下文组装与生成
将检索到的相关文本块组装成提示词(Prompt),发送给 LLM 生成答案。
# 示例:组装提示词
prompt = f"""基于以下参考资料回答问题。如果资料中没有相关信息,请说明。
参考资料:
{chr(10).join(relevant_chunks)}
用户问题:{query}
答案:"""
RAG 工作流程图:
┌─────────────────────────────────────────────────────────────────┐
│ RAG 工作流程 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │
│ │ 文档输入 │───→│ 文档切分 │───→│ 向量化 │───→│ 存储到 │ │
│ └──────────┘ └──────────┘ └──────────┘ │ 向量数据库│ │
│ └──────────┘ │
│ ↑ │
│ │ │
│ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │
│ │ 最终输出 │←───│ LLM 生成 │←───│ 组装Prompt│←───│ 检索匹配 │ │
│ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │
│ │
│ 用户提问 │
│ ↓ │
│ ┌──────────┐ ┌──────────┐ │
│ │ 查询向量 │←───│ 问题向量化│ │
│ └──────────┘ └──────────┘ │
└─────────────────────────────────────────────────────────────────┘
2.3 RAG 优势与局限
✅ 优势
- 知识可更新性强
- 更新知识库不需要重新训练模型
- 适合知识频繁变化的场景(新闻、产品信息等)
- 可解释性好
- 可以直接查看检索到的参考文档
- 答案可溯源,增强可信度
- 成本较低
- 只需支付 Embedding 和 LLM 调用费用
- 无需 GPU 训练资源
- 部署灵活
- 支持多种向量数据库(Milvus、Pinecone、Chroma 等)
- 可轻松扩展知识库规模
❌ 局限
- 依赖检索质量
- 检索效果差,整体性能受限
- 对 Embedding 模型敏感
- 上下文长度限制
- LLM 上下文窗口有限
- 无法一次性利用大量检索文档
- 响应延迟
- 检索+生成的两阶段流程
- 延迟高于纯生成式模型
- 复杂推理受限
- 对于需要复杂推理的任务
- 可能无法充分利用分散的检索片段
三、大模型微调原理详解
3.1 什么是微调
微调(Fine-tuning) 是在预训练模型的基础上,使用特定领域或任务的数据进行额外训练,调整模型参数,使模型适应目标任务。
核心思想:通过训练,让模型"学会"特定任务的知识和模式。
类比理解:就像一个已经学过基础课程的学生,再参加专业培训来掌握特定技能。
3.2 常见微调方法
全参数微调(Full Fine-tuning)
训练模型的所有参数。
# 全参数微调示例(伪代码)
from transformers import AutoModelForSequenceClassification, Trainer
model = AutoModelForSequenceClassification.from_pretrained("base-model")
trainer = Trainer(
model=model,
train_dataset=train_data,
args=TrainingArguments(
learning_rate=2e-5,
num_train_epochs=3,
per_device_train_batch_size=8
)
)
trainer.train()
问题:需要大量 GPU 显存,训练成本高,容易过拟合。
LoRA(Low-Rank Adaptation)
核心思想:不直接修改原始权重,而是添加低秩矩阵来捕捉任务特定知识。
原始权重 W:保持冻结
新增权重 ΔW = B × A,其中 B ∈ R^(d×r), A ∈ R^(r×k), r << min(d,k)
最终权重:W' = W + ΔW
# LoRA 微调示例
from peft import LoraConfig, get_peft_model
lora_config = LoraConfig(
r=8, # 低秩维度
lora_alpha=16, # 缩放因子
target_modules=["q_proj", "v_proj"], # 应用 LoRA 的层
lora_dropout=0.05,
task_type="CAUSAL_LM"
)
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()
# 输出: trainable params: 0.1% || all params: 6.7B || trainable%: 0.1%
优势:大幅减少可训练参数,节省显存,训练速度快。
QLoRA(Quantized LoRA)
在 LoRA 基础上,对模型进行量化,进一步减少显存占用。
# QLoRA 微调示例
from transformers import BitsAndBytesConfig
from peft import prepare_model_for_kbit_training
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4"
)
model = AutoModelForCausalLM.from_pretrained(
"base-model",
quantization_config=quantization_config,
device_map="auto"
)
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)
优势:可在消费级 GPU(如 RTX 3090)上微调 70B 参数的模型。
对比表格
| 方法 | 可训练参数 | 显存需求 | 训练速度 | 效果 |
|---|---|---|---|---|
| 全参数微调 | 100% | 非常高 | 慢 | 最好 |
| LoRA | 0.1-5% | 高 | 快 | 接近全参数 |
| QLoRA | 0.1-5% | 中等 | 中等 | 接近 LoRA |
3.3 微调优势与局限
✅ 优势
- 任务适应性强
- 可以学习特定领域的表达方式
- 输出格式更加可控
- 推理速度快
- 训练完成后是纯生成模型
- 无需额外的检索步骤
- 个性化程度高
- 可以学习复杂的指令遵循模式
- 适合需要强一致性输出的场景
❌ 局限
- 更新成本高
- 每次知识更新都需要重新训练
- 训练需要 GPU 资源
- 可能产生幻觉
- 模型可能产生与训练数据不一致的输出
- 可解释性不如 RAG
- 训练不稳定
- 学习率、epoch 等超参敏感
- 需要一定调参经验
- 灾难性遗忘
- 新任务可能覆盖原有能力
- 需要混合训练策略
四、技术选型指南
4.1 何时选择 RAG
推荐场景:
- 知识库频繁更新
- 产品文档、新闻资讯、客服知识库
- 需要实时反映最新信息
- 需要答案可溯源
- 法律、医疗、金融等高风险领域
- 需要展示参考来源
- 多源异构数据
- 来自不同格式、不同来源的信息
- 需要统一检索和整合
- 预算有限
- 无 GPU 训练资源
- 希望降低运营成本
判断标准:
✓ 需要处理大量外部知识
✓ 知识库需要频繁更新
✓ 需要可解释性和可溯源性
✓ 没有 GPU 训练资源
✓ 延迟要求不是特别高
4.2 何时选择微调
推荐场景:
- 任务模式固定
- 特定的输出格式、风格要求
- 复杂的指令遵循任务
- 领域专业性强
- 需要学习专业术语和表达方式
- 特定行业的复杂推理
- 推理延迟要求高
- 实时对话系统
- 高并发场景
- 数据量充足
- 有大量高质量的训练数据
- 能够承担训练成本
判断标准:
✓ 任务模式相对固定
✓ 需要复杂指令遵循能力
✓ 推理延迟要求严格
✓ 训练数据充足
✓ 有 GPU 训练资源
4.3 混合策略
实际应用中,RAG 和微调并非互斥,可以结合使用:
┌─────────────────────────────────────────────┐
│ 混合策略架构 │
├─────────────────────────────────────────────┤
│ │
│ ┌─────────┐ │
│ │ 用户输入 │ │
│ └────┬────┘ │
│ ↓ │
│ ┌────┴────┐ │
│ │ 意图判断 │ │
│ └────┬────┘ │
│ ↓ │
│ ┌────┴────┐ ┌────────────┐ │
│ │ 知识密集 │────→│ RAG 模块 │ │
│ └────┬────┘ └──────┬─────┘ │
│ │ ↓ │
│ ↓ ┌────────────┐ │
│ ┌────┴────┐ │ 检索知识库 │ │
│ │ 任务密集 │────→└──────┬─────┘ │
│ └────┬────┘ ↓ │
│ ↓ ┌────────────┐ │
│ ┌────┴────┐ │ 微调模型 │ │
│ │ 融合层 │←────│ 生成响应 │ │
│ └────┬────┘ └────────────┘ │
│ ↓ │
│ ┌────┴────┐ │
│ │ 最终输出 │ │
│ └─────────┘ │
└─────────────────────────────────────────────┘
典型模式:
- 微调基础模型 + RAG 增强
- 先微调模型学习领域知识
- 运行时结合 RAG 提供最新信息
- 微调 Embedding + RAG
- 微调 Embedding 模型提升检索质量
- 结合 RAG 提供上下文
- 微调 Router + 路由选择
- 微调一个小模型作为 Router
- 决定使用 RAG 还是微调模型
五、实战代码示例
5.1 RAG 完整实现
环境准备
pip install langchain sentence-transformers chromadb openai python-dotenv
完整代码
"""
RAG 完整实现示例
使用 LangChain + ChromaDB + OpenAI
"""
import os
from langchain.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
class RAGSystem:
def __init__(self, openai_api_key: str, persist_directory: str = "./chroma_db"):
self.persist_directory = persist_directory
# 初始化 Embedding 模型
self.embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
# 初始化 LLM
self.llm = ChatOpenAI(
model_name="gpt-3.5-turbo",
temperature=0.7,
openai_api_key=openai_api_key
)
self.vectorstore = None
self.qa_chain = None
def load_documents(self, file_path: str):
"""加载文档"""
loader = TextLoader(file_path, encoding='utf-8')
documents = loader.load()
return documents
def split_documents(self, documents, chunk_size: int = 500, chunk_overlap: int = 50):
"""文档切分"""
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
length_function=len
)
texts = text_splitter.split_documents(documents)
return texts
def create_vectorstore(self, texts):
"""创建向量数据库"""
self.vectorstore = Chroma.from_documents(
documents=texts,
embedding=self.embeddings,
persist_directory=self.persist_directory
)
return self.vectorstore
def setup_qa_chain(self, search_kwargs: dict = None):
"""设置问答链"""
if search_kwargs is None:
search_kwargs = {"k": 3} # 检索最相关的3个文档块
# 自定义提示词模板
prompt_template = """基于以下参考资料回答问题。如果资料中没有相关信息,请说明。
参考资料:
{context}
用户问题:{question}
答案:"""
PROMPT = PromptTemplate(
template=prompt_template,
input_variables=["context", "question"]
)
self.qa_chain = RetrievalQA.from_chain_type(
llm=self.llm,
chain_type="stuff",
retriever=self.vectorstore.as_retriever(
search_kwargs=search_kwargs
),
chain_type_kwargs={"prompt": PROMPT},
return_source_documents=True
)
return self.qa_chain
def query(self, question: str):
"""执行查询"""
if self.qa_chain is None:
raise ValueError("QA chain not initialized. Call setup_qa_chain() first.")
result = self.qa_chain({"query": question})
return {
"answer": result["result"],
"source_documents": result["source_documents"]
}
def main():
# 初始化 RAG 系统
rag = RAGSystem(
openai_api_key=os.getenv("OPENAI_API_KEY"),
persist_directory="./my_knowledge_base"
)
# 加载和处理文档
documents = rag.load_documents("knowledge_base.txt")
texts = rag.split_documents(documents)
rag.create_vectorstore(texts)
rag.setup_qa_chain()
# 执行查询
question = "产品的主要功能有哪些?"
result = rag.query(question)
print(f"问题:{question}")
print(f"答案:{result['answer']}")
print(f"\n参考来源:")
for i, doc in enumerate(result['source_documents'], 1):
print(f"{i}. {doc.page_content[:100]}...")
if __name__ == "__main__":
main()
使用说明
- 创建
knowledge_base.txt文件,放入你要处理的知识库内容 - 设置环境变量
OPENAI_API_KEY - 运行脚本即可实现基于知识库的问答
5.2 LoRA 微调完整实现
环境准备
pip install transformers peft datasets accelerate bitsandbytes
完整代码
"""
LoRA 微调完整实现示例
使用 HuggingFace Transformers + PEFT
"""
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import Dataset
import json
class LoRAFineTuner:
def __init__(
self,
model_name: str = "THUDM/chatglm2-6b",
lora_r: int = 8,
lora_alpha: int = 16,
lora_dropout: float = 0.05
):
self.model_name = model_name
self.tokenizer = None
self.model = None
self.lora_config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
target_modules=["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"],
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM"
)
def load_model(self, use_4bit: bool = False):
"""加载模型"""
print(f"Loading model: {self.model_name}")
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_name,
trust_remote_code=True
)
if use_4bit:
# QLoRA 配置
from transformers import BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4"
)
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
quantization_config=quantization_config,
device_map="auto",
trust_remote_code=True
)
self.model = prepare_model_for_kbit_training(self.model)
else:
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
device_map="auto",
trust_remote_code=True
)
print("Model loaded successfully!")
return self.model, self.tokenizer
def prepare_lora_model(self):
"""应用 LoRA"""
self.model = get_peft_model(self.model, self.lora_config)
self.model.print_trainable_parameters()
return self.model
def tokenize_function(self, examples, max_length: int = 512):
"""Tokenize 数据"""
result = self.tokenizer(
examples["text"],
truncation=True,
max_length=max_length,
padding="max_length"
)
result["labels"] = result["input_ids"].copy()
return result
def prepare_dataset(self, data_path: str):
"""准备数据集"""
# 从 JSONL 文件加载数据
texts = []
with open(data_path, 'r', encoding='utf-8') as f:
for line in f:
data = json.loads(line)
# 格式:[系统提示] 用户: 问题\n助手: 回答
text = f"{data.get('instruction', '')} User: {data['input']}\nAssistant: {data['output']}"
texts.append(text)
dataset = Dataset.from_dict({"text": texts})
tokenized_dataset = dataset.map(
self.tokenize_function,
batched=True,
remove_columns=dataset.column_names
)
return tokenized_dataset
def train(
self,
train_dataset,
output_dir: str = "./lora_model",
num_train_epochs: int = 3,
per_device_train_batch_size: int = 4,
gradient_accumulation_steps: int = 4,
learning_rate: float = 2e-4,
warmup_steps: int = 100,
logging_steps: int = 10,
save_steps: int = 100
):
"""执行训练"""
training_args = TrainingArguments(
output_dir=output_dir,
num_train_epochs=num_train_epochs,
per_device_train_batch_size=per_device_train_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
learning_rate=learning_rate,
warmup_steps=warmup_steps,
logging_steps=logging_steps,
save_steps=save_steps,
fp16=True,
optim="paged_adamw_8bit",
report_to="tensorboard"
)
data_collator = DataCollatorForLanguageModeling(
tokenizer=self.tokenizer,
mlm=False # Causal LM,不需要 MLM
)
trainer = Trainer(
model=self.model,
args=training_args,
train_dataset=train_dataset,
data_collator=data_collator
)
print("Starting training...")
trainer.train()
# 保存 LoRA 权重
self.model.save_pretrained(output_dir)
print(f"Model saved to {output_dir}")
return trainer
def chat(self, prompt: str, max_length: int = 512):
"""使用微调后的模型对话"""
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_length=max_length,
temperature=0.7,
top_p=0.9
)
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
return response
def main():
# 初始化微调器
trainer = LoRAFineTuner(
model_name="microsoft/phi-2", # 使用较小的模型作为示例
lora_r=8,
lora_alpha=16
)
# 加载模型(根据显存调整是否使用 QLoRA)
trainer.load_model(use_4bit=True)
trainer.prepare_lora_model()
# 准备数据集(需要准备 JSONL 格式的训练数据)
# 格式:{"instruction": "系统指令", "input": "用户输入", "output": "模型输出"}
train_dataset = trainer.prepare_dataset("train_data.jsonl")
# 训练
trainer.train(
train_dataset=train_dataset,
output_dir="./my_lora_model",
num_train_epochs=3,
per_device_train_batch_size=2
)
# 测试对话
response = trainer.chat("用户: 你好\n助手:")
print(f"Response: {response}")
if __name__ == "__main__":
main()
训练数据格式示例
创建 train_data.jsonl:
{"instruction": "你是一个专业的客服助手", "input": "如何重置密码?", "output": "您可以通过以下步骤重置密码:1. 点击登录页面的'忘记密码'链接;2. 输入您的注册邮箱;3. 查收邮件并点击验证链接;4. 设置新密码。如果有任何问题,请联系我们的客服团队。"}
{"instruction": "你是一个专业的客服助手", "input": "产品支持哪些支付方式?", "output": "我们支持多种支付方式,包括:信用卡(Visa、MasterCard)、借记卡、支付宝、微信支付以及银行转账。您可以在结账页面选择最方便的方式完成支付。"}
六、总结与建议
核心对比
| 维度 | RAG | 微调 |
|---|---|---|
| 知识更新 | 实时更新,无需训练 | 需要重新训练 |
| 成本 | Embedding + LLM 调用 | GPU 训练资源 |
| 可解释性 | 高,可溯源 | 低,隐式学习 |
| 任务适配 | 通用知识检索 | 特定任务优化 |
| 推理延迟 | 较高(检索+生成) | 低(纯生成) |
| 部署复杂度 | 中(需向量数据库) | 低(单一模型) |
选型建议
优先选择 RAG:
- 知识库需要频繁更新
- 需要答案可溯源
- 预算有限,无 GPU
- 数据源多样化
优先选择微调:
- 任务模式固定
- 需要复杂的指令遵循
- 推理延迟要求高
- 有充足的训练数据
考虑混合策略:
- 知识密集且任务复杂
- 需要兼顾实时性和专业性
- 有一定的技术储备
实践建议
- 从小规模开始
- 先用 RAG 验证需求
- 再根据效果决定是否微调
- 重视数据质量
- 无论 RAG 还是微调,数据质量都是关键
- 做好数据清洗和标注
- 建立评估体系
- 设计合理的评估指标
- A/B 测试不同方案
- 持续迭代优化
- 根据用户反馈优化
- 保持模型的知识更新
希望本文能帮助你更好地理解 RAG 和微调技术,并在实际项目中做出正确的技术选型!
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐
所有评论(0)