AI大模型RAG与Agent开发学习
RAG项目
RAG项目案例介绍
RAG即检索、增强和生成,其主要分为2条线:
- 离线处理:向私有知识库(向量存储)源源不断添加私有知识文档。
- 向知识库添加来自未来的知识文档(基于模型训练完成时间)
- 向模型添加私有知识文档
- 给出模型参考资料,规避模型幻觉(一本正经的胡说八道)
- 在线处理:用户提问会先基于私有知识库做检索,获取参考资料,同步组装新提示词询问大模型获取结果。

项目需求和思路:
本次项目以"某东商品衣服"为例,以衣服属性构建本地知识。使用者可以自由更新本地知识,用户问题的答案也是基于本地知识生成的。

项目中实现代码主要目录:

离线流程:
RAG项目-文本上传WEB服务器
这里主要使用 Streamlit来完成简单的 网站搭建,只需要下载该库,就可以轻松的实现一个 简单的 web框架。
运行命令:streamlit run app_file_loader.py
"""
基于Steamlit 完成WEB网页上传服务 (Web框架)
简要介绍
基于tornada框架 快速搭建Web应用的python库,封装了大量常用组件方法,支持大量数据表、图表等对象的渲染,支持网格化、响应式布局。可以让不了解前端的人快速搭建网站
"""
import streamlit as st
# 添加网页标题
st.title('知识库更新服务')
# file_uploader
uploader_file = st.file_uploader(
"请上传txt文件",
type=["txt"],
accept_multiple_files=False # 仅接受一个文件上传
)
if uploader_file is not None:
# 提取文件信息
file_name = uploader_file.name
file_type = uploader_file.type
file_size = uploader_file.size / 1024 # KB
st.subheader(f"文件名:{file_name}")
st.write(f"格式:{file_type}, 大小:{file_size:.2f} KB")
# 获取文件内容 得到的是字节数组 bytes 需要变成字符串 (.decode("utf-8"))
text = uploader_file.getvalue().decode("utf-8")
st.write(text)
RAG项目-md5工具函数开发
这里主要 使用hashlib库,可以将 字符串转为md5值。
主要流程:对于传入的字符串,进行检查check_md5看是否已经存在过,使用一个md5.txt进行存储,读文件进行检查,要是不存在就保存(save_md5),保存之前需要将字符串转为md5(get_string_md5)
"""
知识库
"""
import os
import config_data as config
import hashlib # 将字符串 转为md5值
def check_md5(md5_str: str):
"""
检查传入的md5字符串是否已经被处理过了
基于 md5.txt文件,检查是否在文件中出现过。或者文件根本不存在 —— 检查是否为新的向量;
return False (表示文件未处理过)
True:(表示文件已经被处理了)
"""
if not os.path.exists(config.md5_path):
# 文件不存在,肯定没有处理过该md5
open(config.md5_path, 'w', encoding='utf-8').close()
return False
else:
for line in open(config.md5_path, 'r', encoding="utf-8").readlines():
line = line.strip() # 处理字符串前后的空格和回车
if line == md5_str:
return True # 已经处理过
return False
def save_md5(md5_str: str):
"""
将传入的md5字符串,记录到文件内保存
:return:
"""
with open(config.md5_path, 'a', encoding='utf-8') as f:
f.write(md5_str + '\n')
def get_string_md5(input_str: str, encoding='utf-8'):
"""
将传入的字符串转换为md5字符串
:return:
"""
# 将字符串转化为 bytes 字节字符(二进制)
str_bytes = input_str.encode(encoding=encoding)
# 创建md5对象
md5_obj = hashlib.md5() # 得到md5对象
md5_obj.update(str_bytes) # 更新内容(传入即将要转换的字节数组)
md5_hex = md5_obj.hexdigest() # 得到md5的十六进制字符串
return md5_hex
class KnowledgeBaseServer(object):
def __init__(self):
self.chroma = None # 向量存储的实例,Chroma向量库对象
self.spliter = None # 文本分割器
def uploader_by_str(self, data, filename):
"""将传入的字符串进行向量化,存入向量数据库中"""
pass
if __name__ == '__main__':
# md5:结果算出来都是 32位的16进制 (方便做记录,效率高)
# r1 = get_string_md5("周杰伦")
# r2 = get_string_md5("周杰伦")
# r3 = get_string_md5("周杰伦2")
#
# print(r1)
# print(r2)
# print(r3)
save_md5("7a8941058aaf4df5147042ce104568da")
print(check_md5('7a8941058aaf4df5147042ce104568da'))
RAG项目-知识库更新服务
核心:写知识库更新服务逻辑,主要完成 KnowledgeBaseServer 中的init,主要流程:输入 -> md5查重 -> 若是新内容,进行文本分割,就存数据库中,否则不处理。
"""
知识库
"""
import os
import config_data as config
import hashlib # 将字符串 转为md5值
from langchain_chroma import Chroma
from langchain_community.embeddings import DashScopeEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from datetime import datetime
def check_md5(md5_str: str):
"""
检查传入的md5字符串是否已经被处理过了
基于 md5.txt文件,检查是否在文件中出现过。或者文件根本不存在 —— 检查是否为新的向量;
return False (表示文件未处理过)
True:(表示文件已经被处理了)
"""
if not os.path.exists(config.md5_path):
# 文件不存在,肯定没有处理过该md5
open(config.md5_path, 'w', encoding='utf-8').close()
return False
else:
for line in open(config.md5_path, 'r', encoding="utf-8").readlines():
line = line.strip() # 处理字符串前后的空格和回车
if line == md5_str:
return True # 已经处理过
return False
def save_md5(md5_str: str):
"""
将传入的md5字符串,记录到文件内保存
:return:
"""
with open(config.md5_path, 'a', encoding='utf-8') as f:
f.write(md5_str + '\n')
def get_string_md5(input_str: str, encoding='utf-8'):
"""
将传入的字符串转换为md5字符串
:return:
"""
# 将字符串转化为 bytes 字节字符(二进制)
str_bytes = input_str.encode(encoding=encoding)
# 创建md5对象
md5_obj = hashlib.md5() # 得到md5对象
md5_obj.update(str_bytes) # 更新内容(传入即将要转换的字节数组)
md5_hex = md5_obj.hexdigest() # 得到md5的十六进制字符串
return md5_hex
class KnowledgeBaseServer(object):
def __init__(self):
# 如果文件夹不存在则 创建,存在则跳过
os.makedirs(config.persist_directory, exist_ok=True)
self.chroma = Chroma(
collection_name=config.collection_name, # 数据库的表名
embedding_function=DashScopeEmbeddings(model="text-embedding-v4"),
persist_directory=config.persist_directory, # 数据库本地存储文件夹
) # 向量存储的实例,Chroma向量库对象
self.spliter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size, # 分割后的文本段最大长度
chunk_overlap=config.chunk_overlap, # 连续文本段之间的字符重叠数量
separators=config.separators, # 自然段落划分的符号
length_function=len, # 使用python自带的len函数 做长度统计的依据
) # 文本分割器
def uploader_by_str(self, data: str, filename):
"""将传入的字符串进行向量化,存入向量数据库中"""
# 先得到传入字符串的md5 值
md5_hex = get_string_md5(data)
if check_md5(md5_hex):
return "[跳过]内容已经存在知识库中"
if len(data) > config.max_split_char_number:
knowledge_chunks = self.spliter.split_text(data)
else:
knowledge_chunks = [data]
metadata = {
"source": filename,
# 2025-10-10 12:12:12
"create_time": datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
"operator": "zz"
}
self.chroma.add_texts( # 将内容加载到向量库中
knowledge_chunks, # iterable -> list \ tuple
metadatas=[metadata for _ in knowledge_chunks],
)
save_md5(md5_hex)
return "[成功]内容已经成功载入向量库"
if __name__ == '__main__':
# md5:结果算出来都是 32位的16进制 (方便做记录,效率高)
# r1 = get_string_md5("周杰伦")
# r2 = get_string_md5("周杰伦")
# r3 = get_string_md5("周杰伦2")
#
# print(r1)
# print(r2)
# print(r3)
# save_md5("7a8941058aaf4df5147042ce104568da")
# print(check_md5('7a8941058aaf4df5147042ce104568da'))
service = KnowledgeBaseServer()
r = service.uploader_by_str("周杰伦2", "testfile")
print(r)
RAG项目-完成离线流程开发
将文件上传代码和知识库更新 合起来;其中 app_file_uploader.py脚本是控制streamlit WEB页面进行上传页面;knowledge_base.py是 接受文本,分割文本,传到知识库(有查重功能)
为了不让streamlit页面一直刷新,状态不能保存,使用 st.session_state进行保存状态。
完整的 app_file_uploader代码如下:
"""
基于Steamlit 完成WEB网页上传服务 (Web框架)
简要介绍
基于tornada框架 快速搭建Web应用的python库,封装了大量常用组件方法,支持大量数据表、图表等对象的渲染,支持网格化、响应式布局。可以让不了解前端的人快速搭建网站
streamlit: 当WEB页面元素发生变化,代码重新执行一次
—— 缺点:造成状态的丢失
"""
import streamlit as st
from knowledge_base import KnowledgeBaseServer
import time
# 添加网页标题
st.title('知识库更新服务')
# file_uploader
uploader_file = st.file_uploader(
"请上传txt文件",
type=["txt"],
accept_multiple_files=False # 仅接受一个文件上传
)
count = 0
# st.session_state() # 就是是字典 —— 状态不会刷新
if "service" not in st.session_state:
st.session_state["service"] = KnowledgeBaseServer()
if uploader_file is not None:
# 提取文件信息
file_name = uploader_file.name
file_type = uploader_file.type
file_size = uploader_file.size / 1024 # KB
st.subheader(f"文件名:{file_name}")
st.write(f"格式:{file_type}, 大小:{file_size:.2f} KB")
# 获取文件内容 得到的是字节数组 bytes 需要变成字符串 (.decode("utf-8"))
text = uploader_file.getvalue().decode("utf-8")
# st.write(text)
with st.spinner("载入知识库中。。。"): # 在spinner 内的代码 执行过程中,会有转圈动画
time.sleep(1)
results = st.session_state["service"].uploader_by_str(text, file_name)
st.write(results)
RAG项目-在线流程向量存储服务代码

向量存储代码实现:
from langchain_chroma import Chroma
from openai.types import embedding
import config_data
class VectorStoreService(object):
def __init__(self, embedding):
"""
:param embedding: 嵌入模型的传入
"""
self.embedding = embedding
self.vector_store = Chroma(
collection_name=config_data.collection_name,
embedding_function=self.embedding,
persist_directory=config_data.persist_directory,
)
def get_retriever(self):
"""
返回向量检索器,方便加入chain
:return:
"""
return self.vector_store.as_retriever(search_kwargs={"k": config_data.similarity_threshold})
if __name__ == '__main__':
from langchain_community.embeddings import DashScopeEmbeddings
retriever = VectorStoreService(DashScopeEmbeddings(model='text-embedding-v4')).get_retriever()
res = retriever.invoke("我的体重100斤,尺码推荐")
print(res)
RAG项目-RAG服务核心代码开发
主要涉及rag.py的开发,将 检索内容、用户提问组织新的提示词,并将向量模型和 对话模型 链起来。
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableWithMessageHistory, RunnableLambda
from file_history_store import get_history
from vector_stores import VectorStoreService
from langchain_community.embeddings import DashScopeEmbeddings
import config_data as config
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.chat_models.tongyi import ChatTongyi
def print_prompt(prompt):
print("="*20)
print(prompt.to_string())
print("="*20)
return prompt
class RagService(object):
def __init__(self):
self.vector_service = VectorStoreService(
embedding=DashScopeEmbeddings(model=config.embedding_model_name)
)
self.prompt_template = ChatPromptTemplate.from_messages(
[
("system", "以我提供的已知参考资料为主,"
"简洁和专业的回答用户问题。参考资料:{context}。"),
("system", "并且我提供用户的对话历史记录,如下:"),
MessagesPlaceholder("history"),
("user", "请回答用户提问:{input}")
]
)
self.chat_model = ChatTongyi(model=config.chat_model_name)
self.chain = self.__get_chain()
def __get_chain(self):
"""获取最终的执行链"""
retriever = self.vector_service.get_retriever()
def format_document(docs: list[Document]):
if not docs:
return "无相关参考资料"
formatted_str = ""
for doc in docs:
formatted_str += f"文档片段:{doc.page_content}\n文档元数据:{doc.metadata}\n\n"
return formatted_str
def format_for_retriever(value: dict):
return value["input"]
def format_for_prompt_template(value):
# {input, context, history}
new_value = {}
new_value["input"] = value["input"]["input"]
new_value["context"] = value["context"]
new_value["history"] = value["input"]["history"]
return new_value
chain = (
{
"input": RunnablePassthrough(),
"context": RunnableLambda(format_for_retriever) | retriever | format_document
} | RunnableLambda(format_for_prompt_template) | self.prompt_template | print_prompt | self.chat_model | StrOutputParser()
)
conversation_chain = RunnableWithMessageHistory(
chain,
get_history,
input_messages_key="input",
history_messages_key="history",
)
return conversation_chain
if __name__ == '__main__':
# session_id配置
session_config = {
"configurable": {
"session_id": "user_001"
}
}
# res = RagService().chain.invoke({"input": "我体重 100 斤, 身高155cm,尺码推荐"}, session_config)
res = RagService().chain.invoke({"input": "春天穿什么颜色的衣服?"}, session_config)
print(res)
RAG项目-历史会话记录功能的实现
import os
import json
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict
from typing import Sequence
def get_history(session_id):
return FileChatMessageHistory(session_id, "./chat_history")
class FileChatMessageHistory(BaseChatMessageHistory):
def __init__(self, session_id, storage_path):
self.session_id = session_id # 会话ID
self.storage_path = storage_path # 不同会话id的存储路径,所在的文件夹路径
# 完整的文件路径
self.file_path = os.path.join(self.storage_path, self.session_id)
# 确保文件夹是存在的
os.makedirs(os.path.dirname(self.file_path), exist_ok=True)
def add_messages(self, messages: Sequence[BaseMessage]) -> None:
# Sequence序列,类似list,tuple
all_messages = list(self.messages) # 已有的消息列表
all_messages.extend(messages) # 新的消息和已有的融合成一个list
# 将数据同步写入到本地文件中
# 类对象写入文件 -> 二进制
# 为了方便,将BaseMessage消息 转为字典,借助json模块以json字符串写入文件
# message_to_dict : 单个消息对象(BaseMessage类实例) -> 字典
# new_messages = []
# for message in all_messages:
# d = message_to_dict(message)
# new_messages.append(d)
new_messages = [message_to_dict(message) for message in all_messages]
# 将数据写入文件
with open(self.file_path, "w", encoding="utf-8") as f:
json.dump(new_messages, f) # 将整体变成json写入文件中
@property # @property 装饰器将messages方法变成 成员属性用
def messages(self) -> list[BaseMessage]:
# 当前文件内: List[字典], 需要转为 list[BaseMessage]
try:
with open(self.file_path, "r", encoding="utf-8") as f:
messages_data = json.load(f) # 返回List[字典]
return messages_from_dict(messages_data)
except FileNotFoundError:
return []
def clear(self) -> None:
with open(self.file_path, "w", encoding="utf-8") as f:
json.dump([], f) # 清空文件
RAG项目-聊天页面的开发
网页的实现:主要实现 用户提问界面、流式输出、保存历史记录
import streamlit as st
import time
from rag import RagService
import config_data
# 标题
st.title('智能客服问答')
st.divider() # 分隔符
if 'rag_service' not in st.session_state:
st.session_state['rag_service'] = RagService()
if "message" not in st.session_state:
st.session_state["message"] = [{'role': 'assistant', 'content': '你好,有什么可以帮助你?'}]
for message in st.session_state["message"]:
st.chat_message(message['role']).write(message['content'])
# 用户输入框
prompt = st.chat_input()
if prompt:
# 在页面输出
st.chat_message("user").write(prompt)
st.session_state["message"].append({'role': 'user', 'content': prompt})
ai_res_list = []
with st.spinner("AI思考中。。。"):
# res = st.session_state['rag_service'].chain.invoke({"input": prompt}, config_data.session_config)
res_stream = st.session_state['rag_service'].chain.stream({"input": prompt}, config_data.session_config)
# yield
def capture(generator, cache_list):
for chunk in generator:
cache_list.append(chunk)
yield chunk
# st.chat_message("assistant").write(res)
st.chat_message("assistant").write_stream(capture(res_stream, ai_res_list))
# st.session_state["message"].append({'role': 'assistant', 'content': res})
# 历史记录保存
st.session_state["message"].append({'role': 'assistant', 'content': ''.join(ai_res_list)})
# ['a', 'b', 'c'] ''.join(list) -> abc
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)