04-RAG 项目
·
一、项目介绍
1、RAG 的本质
RAG即检索、增强和生成,其主要分为2条线:
•离线处理:向私有知识库(向量存储)源源不断添加私有知识文档。•向知识库添加来自未来的知识文档(基于模型训练完成时间)•向模型添加私有知识文档•给出模型参考资料,规避模型幻觉(一本正经的胡说八道)•在线处理:用户提问会先基于私有知识库做检索,获取参考资料,同步组装新提示词询问大模型获取结果。

2、项目
本次项目以"某东商品衣服"为例,以衣服属性构建本地知识。使用者可以自由更新本地知识,用户问题的答案也是基于本地知识生成的。

3、目录结构

二、项目实现
离线流程


1、文件上传实现
使用 streamlit 完成简单的页面搭载
一级标题 、
# 添加网页标题 st.title("知识库更新服务")二级标题和正文 、
st.subheader(f"文件名:{file_name}") st.write(f"格式:{file_type} | 大小:{file_size:.2f} KB")文件上传
# file_uploader uploader_file = st.file_uploader( "请上传TXT文件", type=['txt'], accept_multiple_files=False, # False表示仅接受一个文件的上传 )

"""
基于Streamlit完成WEB网页上传服务
pip install streamlit
Streamlit:当WEB页面元素发生变化,则代码重新执行一遍
"""
import time
import streamlit as st
from knowledge_base import KnowledgeBaseService
# 添加网页标题
st.title("知识库更新服务")
# file_uploader
uploader_file = st.file_uploader(
"请上传TXT文件",
type=['txt'],
accept_multiple_files=False, # False表示仅接受一个文件的上传
)
# session_state就是一个字典
if "service" not in st.session_state:
st.session_state["service"] = KnowledgeBaseService()
if uploader_file is not None:
# 提取文件的信息
file_name = uploader_file.name
file_type = uploader_file.type
file_size = uploader_file.size / 1024 # KB
st.subheader(f"文件名:{file_name}")
st.write(f"格式:{file_type} | 大小:{file_size:.2f} KB")
# get_value -> bytes -> decode('utf-8')
text = uploader_file.getvalue().decode("utf-8")
with st.spinner("载入知识库中。。。"): # 在spinner内的代码执行过程中,会有一个转圈动画
time.sleep(1)
result = st.session_state["service"].upload_by_str(text, file_name)
st.write(result)
2、md5检查文献是否存入
md5,算法,是将字符串转成 32位 的16进制,并且很快,用来判断是否重复很合适
def check_md5(md5_str: str):
"""检查传入的md5字符串是否已经被处理过了
return False(md5未处理过) True(已经处理过,已有记录)
"""
if not os.path.exists(config.md5_path):
# if进入表示文件不存在,那肯定没有处理过这个md5了
open(config.md5_path, 'w', encoding='utf-8').close()
return False
else:
for line in open(config.md5_path, 'r', encoding='utf-8').readlines():
line = line.strip() # 处理字符串前后的空格和回车
if line == md5_str:
return True # 已处理过
return False
def save_md5(md5_str: str):
"""将传入的md5字符串,记录到文件内保存"""
with open(config.md5_path, 'a', encoding="utf-8") as f:
f.write(md5_str + '\n')
def get_string_md5(input_str: str, encoding='utf-8'):
"""将传入的字符串转换为md5字符串"""
# 将字符串转换为bytes字节数组
str_bytes = input_str.encode(encoding=encoding)
# 创建md5对象
md5_obj = hashlib.md5() # 得到md5对象
md5_obj.update(str_bytes) # 更新内容(传入即将要转换的字节数组)
md5_hex = md5_obj.hexdigest() # 得到md5的十六进制字符串
return md5_hex
3、KnowledgeBaseService
完成文件存入向量数据库的流程
class KnowledgeBaseService(object):
def __init__(self):
# 如果文件夹不存在则创建,如果存在则跳过
os.makedirs(config.persist_directory, exist_ok=True)
self.chroma = Chroma(
collection_name=config.collection_name, # 数据库的表名
embedding_function=DashScopeEmbeddings(model="text-embedding-v4"),
persist_directory=config.persist_directory, # 数据库本地存储文件夹
) # 向量存储的实例 Chroma向量库对象
self.spliter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size, # 分割后的文本段最大长度
chunk_overlap=config.chunk_overlap, # 连续文本段之间的字符重叠数量
separators=config.separators, # 自然段落划分的符号
length_function=len, # 使用Python自带的len函数做长度统计的依据
) # 文本分割器的对象
def upload_by_str(self, data: str, filename):
"""将传入的字符串,进行向量化,存入向量数据库中"""
# 先得到传入字符串的md5值
md5_hex = get_string_md5(data)
if check_md5(md5_hex):
return "[跳过]内容已经存在知识库中"
if len(data) > config.max_split_char_number:
knowledge_chunks: list[str] = self.spliter.split_text(data)
else:
knowledge_chunks = [data]
metadata = {
"source": filename,
# 2025-01-01 10:00:00
"create_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"operator": "小曹",
}
self.chroma.add_texts( # 内容就加载到向量库中了
# iterable -> list \ tuple
knowledge_chunks,
metadatas=[metadata for _ in knowledge_chunks],
)
#
save_md5(md5_hex)
return "[成功]内容已经成功载入向量库"
4、文件上传配合向量存入
使用 streamlit 完成简单的页面搭载
这个页面每刷新一次/上传一次文件,整个代码都会重复执行一次,
这样就不会随着页面的刷新而丢失回话信息
# session_state就是一个字典 if "service" not in st.session_state: st.session_state["service"] = KnowledgeBaseService()

5、配置文件
md5_path = "./md5.text"
# Chroma
collection_name = "rag"
persist_directory = "./chroma_db"
# spliter
chunk_size = 1000
chunk_overlap = 100
separators = ["\n\n", "\n", ".", "!", "?", "。", "!", "?", " ", ""]
max_split_char_number = 1000 # 文本分割的阈值
#
similarity_threshold = 1 # 检索返回匹配的文档数量
embedding_model_name = "text-embedding-v4"
chat_model_name = "qwen3-max"
session_config = {
"configurable": {
"session_id": "user_001",
}
}
在线流程


1、VectorStoreService
from langchain_community.vectorstores import Chroma
import config_data as config
class VectorStoreService(object):
def __init__(self, embedding):
"""
:param embedding: 嵌入模型的传入
"""
self.embedding = embedding
self.vector_store = Chroma(
collection_name=config.collection_name,
embedding_function=self.embedding,
persist_directory=config.persist_directory,
)
def get_retriever(self):
"""返回向量检索器,方便加入chain"""
return self.vector_store.as_retriever(search_kwargs={"k": config.similarity_threshold})
if __name__ == '__main__':
from langchain_community.embeddings import DashScopeEmbeddings
retriever = VectorStoreService(DashScopeEmbeddings(model="text-embedding-v4")).get_retriever()
res = retriever.invoke("我的体重180斤,尺码推荐")
print(res)
2、RAG的流程
将用户提问与参考资料,组装后交给LLM
存在类型的匹配的问题,所以要做一个类型的转换函数

from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableWithMessageHistory, RunnableLambda
from file_history_store import get_history
from vector_stores import VectorStoreService
from langchain_community.embeddings import DashScopeEmbeddings
import config_data as config
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.chat_models.tongyi import ChatTongyi
def print_prompt(prompt):
print("="*20)
print(prompt.to_string())
print("="*20)
return prompt
class RagService(object):
def __init__(self):
self.vector_service = VectorStoreService(
embedding=DashScopeEmbeddings(model=config.embedding_model_name)
)
self.prompt_template = ChatPromptTemplate.from_messages(
[
("system", "以我提供的已知参考资料为主,"
"简洁和专业的回答用户问题。参考资料:{context}。"),
("system", "并且我提供用户的对话历史记录,如下:"),
MessagesPlaceholder("history"),
("user", "请回答用户提问:{input}")
]
)
self.chat_model = ChatTongyi(model=config.chat_model_name)
self.chain = self.__get_chain()
def __get_chain(self):
"""获取最终的执行链"""
retriever = self.vector_service.get_retriever()
def format_document(docs: list[Document]):
if not docs:
return "无相关参考资料"
formatted_str = ""
for doc in docs:
formatted_str += f"文档片段:{doc.page_content}\n文档元数据:{doc.metadata}\n\n"
return formatted_str
def format_for_retriever(value: dict) -> str:
return value["input"]
def format_for_prompt_template(value):
# {input, context, history}
new_value = {}
new_value["input"] = value["input"]["input"]
new_value["context"] = value["context"]
new_value["history"] = value["input"]["history"]
return new_value
chain = (
{
"input": RunnablePassthrough(),
"context": RunnableLambda(format_for_retriever) | retriever | format_document
} | RunnableLambda(format_for_prompt_template) | self.prompt_template | print_prompt | self.chat_model | StrOutputParser()
)
conversation_chain = RunnableWithMessageHistory(
chain,
get_history,
input_messages_key="input",
history_messages_key="history",
)
return conversation_chain
if __name__ == '__main__':
# session id 配置
session_config = {
"configurable": {
"session_id": "user_001",
}
}
res = RagService().chain.invoke({"input": "针织毛衣如何保养?"}, session_config)
print(res)
3、搭建聊天的网页
import time
from rag import RagService
import streamlit as st
import config_data as config
# 标题
st.title("智能客服")
st.divider() # 分隔符
if "message" not in st.session_state:
st.session_state["message"] = [{"role": "assistant", "content": "你好,有什么可以帮助你?"}]
if "rag" not in st.session_state:
st.session_state["rag"] = RagService()
for message in st.session_state["message"]:
st.chat_message(message["role"]).write(message["content"])
# 在页面最下方提供用户输入栏
prompt = st.chat_input()
if prompt:
# 在页面输出用户的提问
st.chat_message("user").write(prompt)
st.session_state["message"].append({"role": "user", "content": prompt})
ai_res_list = []
with st.spinner("AI思考中..."):
res_stream = st.session_state["rag"].chain.stream({"input": prompt}, config.session_config)
# yield
def capture(generator, cache_list):
for chunk in generator:
cache_list.append(chunk)
yield chunk
st.chat_message("assistant").write_stream(capture(res_stream, ai_res_list))
st.session_state["message"].append({"role": "assistant", "content": "".join(ai_res_list)})
# ["a", "b", "c"] "".join(list) -> abc
# ["a", "b", "c"] ",".join(list) -> a,b,c
三、成品
streamlit run .\app_qa.py

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐

所有评论(0)