LangChain 索引增强对话链详解

本文介绍如何使用 ConversationalRetrievalChain 构建完整的 RAG(检索增强生成)对话系统。

1. 什么是 ConversationalRetrievalChain?

1.1 概念

ConversationalRetrievalChain 是 LangChain 提供的开箱即用的对话检索链,它将以下组件整合在一起:

组件 作用
Retriever 从向量数据库检索相关文档
LLM 根据检索结果生成答案
Memory 存储对话历史,维持上下文

1.2 对比:手动实现 vs ConversationalRetrievalChain

# 手动实现(繁琐)
retrieved_docs = retriever.invoke(query)
context = "\n".join([doc.page_content for doc in retrieved_docs])
prompt = f"基于上下文回答:{context}\n\n问题:{query}"
answer = llm.invoke(prompt)

# ConversationalRetrievalChain(简洁)
qa = ConversationalRetrievalChain.from_llm(llm, retriever, memory)
answer = qa.invoke({"question": query})

2. 核心组件

2.1 完整代码框架

from langchain_classic.chains import ConversationalRetrievalChain
from langchain_classic.memory import ConversationBufferMemory
from langchain_classic.document_loaders import TextLoader
from langchain_classic.vectorstores import FAISS
from langchain_classic.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import ChatOpenAI
from langchain_openai.embeddings import OpenAIEmbeddings

2.2 组件详解

导入 说明
ConversationalRetrievalChain 对话检索链
ConversationBufferMemory 对话记忆
TextLoader 文档加载器
FAISS 向量数据库
RecursiveCharacterTextSplitter 文本分块器
ChatOpenAI ChatGPT 模型
OpenAIEmbeddings 嵌入模型

3. 代码解析

3.1 文档加载与分块

# 加载文档
loader = TextLoader("./demo.txt", encoding="utf-8")
docs = loader.load()

# 文本分块
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=40,
    separators=["\n", "。", "!", "?", ",", "、", ""]
)
texts = text_splitter.split_documents(docs)

3.2 向量数据库与检索器

# 嵌入模型
embeddings_model = OpenAIEmbeddings(
    model="text-embedding-3-large",
    openai_api_key="sk-xxxx",
    base_url="https://api.xxx.com/v1"
)

# 构建向量数据库
db = FAISS.from_documents(texts, embeddings_model)

# 创建检索器(返回 top-3 相关文档)
retriever = db.as_retriever(search_kwargs={"k": 3})

3.3 LLM 模型配置

model = ChatOpenAI(
    model="gpt-3.5-turbo",      # 模型名称
    openai_api_key="sk-xxxx",
    openai_api_base="https://api.xxx.com/v1",
    temperature=0.7              # 创造性参数
)
参数 说明 建议值
temperature 控制回答随机性 0.0-1.0
0.0 确定性强,一致性高 事实问答
0.7 平衡创造性和准确性 一般对话

3.4 对话记忆配置

memory = ConversationBufferMemory(
    memory_key="chat_history",    # 历史记录 key
    return_messages=True,        # 返回消息对象
    output_key="answer"           # 输出 key
)
参数 说明
memory_key 在 prompt 中引用历史记录的变量名
return_messages True 返回消息对象,False 返回字符串
output_key 指定哪个输出字段存入记忆

3.5 创建对话链

qa = ConversationalRetrievalChain.from_llm(
    llm=model,           # LLM 模型
    retriever=retriever, # 检索器
    memory=memory,       # 对话记忆
    verbose=False        # True 显示详细过程
)

4. 对话流程

4.1 单轮对话

question = "卢浮宫这个名字怎么来的?"
result = qa.invoke({"question": question})

print(f"问题: {question}")
print(f"回答: {result['answer']}")

4.2 多轮对话(上下文理解)

# 第一轮
question1 = "卢浮宫这个名字怎么来的?"
result1 = qa.invoke({"question": question1})

# 第二轮(可引用上文)
question2 = "对应的拉丁语是什么呢?"
result2 = qa.invoke({"question": question2})

多轮对话的魔力

用户: 卢浮宫这个名字怎么来的?     → 第一轮检索
AI:  卢浮宫...源自...

用户: 对应的拉丁语是什么呢?       → 第二轮:自动理解"对应的"指"卢浮宫"
AI:  卢浮宫的拉丁语是...           → 利用第一轮的上下文 + 检索

4.3 查看对话历史

print(memory.chat_memory.messages)

5. 带来源文档的问答

5.1 启用 source_documents

qa_with_source = ConversationalRetrievalChain.from_llm(
    llm=model,
    retriever=retriever,
    memory=memory,
    return_source_documents=True  # 返回参考文档
)

result = qa_with_source.invoke({"question": "卢浮宫在什么时候对公众开放?"})
print(f"回答: {result['answer']}")
print("参考来源:")
for i, doc in enumerate(result['source_documents'], 1):
    print(f"{i}. {doc.page_content[:200]}...")

5.2 返回值说明

result = qa.invoke({"question": "..."})
字段 说明
result['answer'] LLM 生成的回答
result['source_documents'] 参考文档列表(需启用 return_source_documents)
result['chat_history'] 对话历史

6. 完整示例

# 导入
from langchain_classic.chains import ConversationalRetrievalChain
from langchain_classic.memory import ConversationBufferMemory
from langchain_classic.document_loaders import TextLoader
from langchain_classic.vectorstores import FAISS
from langchain_classic.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import ChatOpenAI
from langchain_openai.embeddings import OpenAIEmbeddings

# 加载文档
loader = TextLoader("./demo.txt", encoding="utf-8")
docs = loader.load()

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=40,
    separators=["\n", "。", "!", "?", ",", "、", ""]
)

texts = text_splitter.split_documents(docs)

# 嵌入模型
embeddings_model = OpenAIEmbeddings(
    model="text-embedding-3-large",
    openai_api_key="xxxx",
    openai_api_base="https://api.xxxx"
)

# 向量数据库
db = FAISS.from_documents(texts, embeddings_model)
retriever = db.as_retriever(search_kwargs={"k": 3})  # 设置返回3个相关文档

# 模型
model = ChatOpenAI(
    model="gpt-3.5-turbo", 
    openai_api_key="xxxx",
    openai_api_base="https://xxxx",
    temperature=0.7  # 控制回答的创造性
)

# 创建 memory - 使用正确的方式避免警告
memory = ConversationBufferMemory(
    memory_key="chat_history",  # 对话历史的key
    return_messages=True,        # 返回消息对象
    output_key="answer"          # 输出key
)

# 创建对话链
qa = ConversationalRetrievalChain.from_llm(
    llm=model,
    retriever=retriever,
    memory=memory,
    verbose=False  # 设置为True可以看到详细过程
)

# 开始对话
print("=== 第一轮对话 ===")
question1 = "卢浮宫这个名字怎么来的?"
result1 = qa.invoke({"question": question1})
print(f"问题: {question1}")
print(f"回答: {result1['answer']}\n")

print("=== 第二轮对话(上下文理解)===")
question2 = "对应的拉丁语是什么呢?"
result2 = qa.invoke({"question": question2})
print(f"问题: {question2}")
print(f"回答: {result2['answer']}\n")

# 查看对话历史
print("=== 对话历史 ===")
print(memory.chat_memory.messages)

# 如果需要查看来源文档
print("\n=== 带来源文档的问答 ===")
qa_with_source = ConversationalRetrievalChain.from_llm(
    llm=model,
    retriever=retriever,
    memory=memory,  # 使用相同的memory,保持上下文
    return_source_documents=True
)

question3 = "卢浮宫在什么时候对公众开放?"
result3 = qa_with_source.invoke({"question": question3})
print(f"问题: {question3}")
print(f"回答: {result3['answer']}")
print("\n参考来源:")
for i, doc in enumerate(result3['source_documents'], 1):
    print(f"{i}. {doc.page_content[:200]}...")  # 只显示前200字符

结果:
在这里插入图片描述


7. 常见问题

Q1: memory 有什么作用?

  • 存储对话历史
  • 让 LLM 理解上下文(“那”、"它"指代什么)
  • 实现多轮对话

Q2: return_messages=True vs False 区别?

# return_messages=True
memory.chat_memory.messages  # [HumanMessage(...), AIMessage(...)]

# return_messages=False
memory.chat_memory.messages  # "Human: xxx\nAI: yyy"

Q3: 如何清除对话历史?

memory.clear()

Q4: verbose 参数有什么用?

qa = ConversationalRetrievalChain.from_llm(..., verbose=True)

开启后可以看到完整的 chain 执行过程,便于调试。

Q5: 如何自定义 prompt?

qa = ConversationalRetrievalChain.from_llm(
    llm=model,
    retriever=retriever,
    memory=memory,
    condense_question_prompt=CustomPrompt,  # 自定义问题改写 prompt
    qa_prompt=CustomPrompt                   # 自定义问答 prompt
)
Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐