RAG项目

RAG项目案例介绍

RAG即检索、增强和生成,其主要分为2条线:

  • 离线处理:向私有知识库(向量存储)源源不断添加私有知识文档。
    • 向知识库添加来自未来的知识文档(基于模型训练完成时间)
    • 向模型添加私有知识文档
    • 给出模型参考资料,规避模型幻觉(一本正经的胡说八道)
  • 在线处理:用户提问会先基于私有知识库做检索,获取参考资料,同步组装新提示词询问大模型获取结果。

在这里插入图片描述
项目需求和思路

本次项目以"某东商品衣服"为例,以衣服属性构建本地知识。使用者可以自由更新本地知识,用户问题的答案也是基于本地知识生成的。

在这里插入图片描述
项目中实现代码主要目录:

在这里插入图片描述

离线流程:
在这里插入图片描述

RAG项目-文本上传WEB服务器

这里主要使用 Streamlit来完成简单的 网站搭建,只需要下载该库,就可以轻松的实现一个 简单的 web框架。

运行命令:streamlit run app_file_loader.py

"""
基于Steamlit 完成WEB网页上传服务 (Web框架)

简要介绍
基于tornada框架 快速搭建Web应用的python库,封装了大量常用组件方法,支持大量数据表、图表等对象的渲染,支持网格化、响应式布局。可以让不了解前端的人快速搭建网站
"""
import streamlit as st


# 添加网页标题
st.title('知识库更新服务')

# file_uploader
uploader_file = st.file_uploader(
    "请上传txt文件",
    type=["txt"],
    accept_multiple_files=False # 仅接受一个文件上传
)

if uploader_file is not None:
    # 提取文件信息
    file_name = uploader_file.name
    file_type = uploader_file.type
    file_size = uploader_file.size / 1024 # KB

    st.subheader(f"文件名:{file_name}")
    st.write(f"格式:{file_type}, 大小:{file_size:.2f} KB")

    # 获取文件内容 得到的是字节数组 bytes 需要变成字符串 (.decode("utf-8"))
    text = uploader_file.getvalue().decode("utf-8")
    st.write(text)

RAG项目-md5工具函数开发

这里主要 使用hashlib库,可以将 字符串转为md5值。

主要流程:对于传入的字符串,进行检查check_md5看是否已经存在过,使用一个md5.txt进行存储,读文件进行检查,要是不存在就保存(save_md5),保存之前需要将字符串转为md5(get_string_md5)

"""
知识库

"""
import os
import config_data as config
import hashlib # 将字符串 转为md5值

def check_md5(md5_str: str):
    """
    检查传入的md5字符串是否已经被处理过了
    基于 md5.txt文件,检查是否在文件中出现过。或者文件根本不存在 —— 检查是否为新的向量;
    return False (表示文件未处理过)
    True:(表示文件已经被处理了)
    """
    if not os.path.exists(config.md5_path):
        # 文件不存在,肯定没有处理过该md5
        open(config.md5_path, 'w', encoding='utf-8').close()
        return False
    else:
        for line in open(config.md5_path, 'r', encoding="utf-8").readlines():
            line = line.strip() # 处理字符串前后的空格和回车
            if line == md5_str:
                return True  # 已经处理过
        return False

def save_md5(md5_str: str):
    """
    将传入的md5字符串,记录到文件内保存
    :return:
    """
    with open(config.md5_path, 'a', encoding='utf-8') as f:
        f.write(md5_str + '\n')

def get_string_md5(input_str: str, encoding='utf-8'):
    """
    将传入的字符串转换为md5字符串
    :return:
    """
    # 将字符串转化为 bytes 字节字符(二进制)
    str_bytes = input_str.encode(encoding=encoding)
    # 创建md5对象
    md5_obj = hashlib.md5() # 得到md5对象
    md5_obj.update(str_bytes) # 更新内容(传入即将要转换的字节数组)
    md5_hex = md5_obj.hexdigest() # 得到md5的十六进制字符串
    return md5_hex

class KnowledgeBaseServer(object):
    def __init__(self):
        self.chroma = None # 向量存储的实例,Chroma向量库对象
        self.spliter = None # 文本分割器

    def uploader_by_str(self, data, filename):
        """将传入的字符串进行向量化,存入向量数据库中"""
        pass

if __name__ == '__main__':
    # md5:结果算出来都是 32位的16进制 (方便做记录,效率高)
    # r1 = get_string_md5("周杰伦")
    # r2 = get_string_md5("周杰伦")
    # r3 = get_string_md5("周杰伦2")
    #
    # print(r1)
    # print(r2)
    # print(r3)
    save_md5("7a8941058aaf4df5147042ce104568da")
    print(check_md5('7a8941058aaf4df5147042ce104568da'))

RAG项目-知识库更新服务

核心:写知识库更新服务逻辑,主要完成 KnowledgeBaseServer 中的init,主要流程:输入 -> md5查重 -> 若是新内容,进行文本分割,就存数据库中,否则不处理。

"""
知识库

"""
import os
import config_data as config
import hashlib # 将字符串 转为md5值
from langchain_chroma import Chroma
from langchain_community.embeddings import DashScopeEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from datetime import datetime

def check_md5(md5_str: str):
    """
    检查传入的md5字符串是否已经被处理过了
    基于 md5.txt文件,检查是否在文件中出现过。或者文件根本不存在 —— 检查是否为新的向量;
    return False (表示文件未处理过)
    True:(表示文件已经被处理了)
    """
    if not os.path.exists(config.md5_path):
        # 文件不存在,肯定没有处理过该md5
        open(config.md5_path, 'w', encoding='utf-8').close()
        return False
    else:
        for line in open(config.md5_path, 'r', encoding="utf-8").readlines():
            line = line.strip() # 处理字符串前后的空格和回车
            if line == md5_str:
                return True  # 已经处理过
        return False

def save_md5(md5_str: str):
    """
    将传入的md5字符串,记录到文件内保存
    :return:
    """
    with open(config.md5_path, 'a', encoding='utf-8') as f:
        f.write(md5_str + '\n')

def get_string_md5(input_str: str, encoding='utf-8'):
    """
    将传入的字符串转换为md5字符串
    :return:
    """
    # 将字符串转化为 bytes 字节字符(二进制)
    str_bytes = input_str.encode(encoding=encoding)
    # 创建md5对象
    md5_obj = hashlib.md5() # 得到md5对象
    md5_obj.update(str_bytes) # 更新内容(传入即将要转换的字节数组)
    md5_hex = md5_obj.hexdigest() # 得到md5的十六进制字符串
    return md5_hex

class KnowledgeBaseServer(object):
    def __init__(self):
        # 如果文件夹不存在则 创建,存在则跳过
        os.makedirs(config.persist_directory, exist_ok=True)
        self.chroma = Chroma(
            collection_name=config.collection_name, # 数据库的表名
            embedding_function=DashScopeEmbeddings(model="text-embedding-v4"),
            persist_directory=config.persist_directory, # 数据库本地存储文件夹
        ) # 向量存储的实例,Chroma向量库对象
        self.spliter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size, # 分割后的文本段最大长度
            chunk_overlap=config.chunk_overlap, # 连续文本段之间的字符重叠数量
            separators=config.separators, # 自然段落划分的符号
            length_function=len, # 使用python自带的len函数 做长度统计的依据
        ) # 文本分割器

    def uploader_by_str(self, data: str, filename):
        """将传入的字符串进行向量化,存入向量数据库中"""
        # 先得到传入字符串的md5 值
        md5_hex = get_string_md5(data)
        if check_md5(md5_hex):
            return "[跳过]内容已经存在知识库中"
        if len(data) > config.max_split_char_number:
            knowledge_chunks = self.spliter.split_text(data)
        else:
            knowledge_chunks = [data]

        metadata = {
            "source": filename,
            # 2025-10-10 12:12:12
            "create_time": datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            "operator": "zz"
        }
        self.chroma.add_texts( # 将内容加载到向量库中
            knowledge_chunks, # iterable -> list \ tuple
            metadatas=[metadata for _ in knowledge_chunks],
        )
        save_md5(md5_hex)
        return "[成功]内容已经成功载入向量库"

if __name__ == '__main__':
    # md5:结果算出来都是 32位的16进制 (方便做记录,效率高)
    # r1 = get_string_md5("周杰伦")
    # r2 = get_string_md5("周杰伦")
    # r3 = get_string_md5("周杰伦2")
    #
    # print(r1)
    # print(r2)
    # print(r3)
    # save_md5("7a8941058aaf4df5147042ce104568da")
    # print(check_md5('7a8941058aaf4df5147042ce104568da'))

    service = KnowledgeBaseServer()
    r = service.uploader_by_str("周杰伦2", "testfile")
    print(r)

RAG项目-完成离线流程开发

将文件上传代码和知识库更新 合起来;其中 app_file_uploader.py脚本是控制streamlit WEB页面进行上传页面;knowledge_base.py是 接受文本,分割文本,传到知识库(有查重功能)

为了不让streamlit页面一直刷新,状态不能保存,使用 st.session_state进行保存状态。

完整的 app_file_uploader代码如下:

"""
基于Steamlit 完成WEB网页上传服务 (Web框架)

简要介绍
基于tornada框架 快速搭建Web应用的python库,封装了大量常用组件方法,支持大量数据表、图表等对象的渲染,支持网格化、响应式布局。可以让不了解前端的人快速搭建网站

streamlit: 当WEB页面元素发生变化,代码重新执行一次
—— 缺点:造成状态的丢失
"""
import streamlit as st
from knowledge_base import KnowledgeBaseServer
import time

# 添加网页标题
st.title('知识库更新服务')

# file_uploader
uploader_file = st.file_uploader(
    "请上传txt文件",
    type=["txt"],
    accept_multiple_files=False # 仅接受一个文件上传
)
count = 0
# st.session_state() # 就是是字典 —— 状态不会刷新
if "service" not in st.session_state:
    st.session_state["service"] = KnowledgeBaseServer()

if uploader_file is not None:
    # 提取文件信息
    file_name = uploader_file.name
    file_type = uploader_file.type
    file_size = uploader_file.size / 1024 # KB

    st.subheader(f"文件名:{file_name}")
    st.write(f"格式:{file_type}, 大小:{file_size:.2f} KB")

    # 获取文件内容 得到的是字节数组 bytes 需要变成字符串 (.decode("utf-8"))
    text = uploader_file.getvalue().decode("utf-8")
    # st.write(text)
    with st.spinner("载入知识库中。。。"): # 在spinner 内的代码 执行过程中,会有转圈动画
        time.sleep(1)
        results = st.session_state["service"].uploader_by_str(text, file_name)
        st.write(results)

RAG项目-在线流程向量存储服务代码

在这里插入图片描述
向量存储代码实现:

from langchain_chroma import Chroma
from openai.types import embedding

import config_data

class VectorStoreService(object):
    def __init__(self, embedding):
        """
        :param embedding: 嵌入模型的传入
        """
        self.embedding = embedding

        self.vector_store = Chroma(
            collection_name=config_data.collection_name,
            embedding_function=self.embedding,
            persist_directory=config_data.persist_directory,
        )
    def get_retriever(self):
        """
        返回向量检索器,方便加入chain
        :return:
        """
        return self.vector_store.as_retriever(search_kwargs={"k": config_data.similarity_threshold})

if __name__ == '__main__':
    from langchain_community.embeddings import DashScopeEmbeddings
    retriever = VectorStoreService(DashScopeEmbeddings(model='text-embedding-v4')).get_retriever()
    res = retriever.invoke("我的体重100斤,尺码推荐")
    print(res)

RAG项目-RAG服务核心代码开发

主要涉及rag.py的开发,将 检索内容、用户提问组织新的提示词,并将向量模型和 对话模型 链起来。

from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableWithMessageHistory, RunnableLambda
from file_history_store import get_history
from vector_stores import VectorStoreService
from langchain_community.embeddings import DashScopeEmbeddings
import config_data as config
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.chat_models.tongyi import ChatTongyi


def print_prompt(prompt):
    print("="*20)
    print(prompt.to_string())
    print("="*20)

    return prompt


class RagService(object):
    def __init__(self):

        self.vector_service = VectorStoreService(
            embedding=DashScopeEmbeddings(model=config.embedding_model_name)
        )

        self.prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", "以我提供的已知参考资料为主,"
                 "简洁和专业的回答用户问题。参考资料:{context}。"),
                ("system", "并且我提供用户的对话历史记录,如下:"),
                MessagesPlaceholder("history"),
                ("user", "请回答用户提问:{input}")
            ]
        )

        self.chat_model = ChatTongyi(model=config.chat_model_name)

        self.chain = self.__get_chain()

    def __get_chain(self):
        """获取最终的执行链"""
        retriever = self.vector_service.get_retriever()

        def format_document(docs: list[Document]):
            if not docs:
                return "无相关参考资料"

            formatted_str = ""
            for doc in docs:
                formatted_str += f"文档片段:{doc.page_content}\n文档元数据:{doc.metadata}\n\n"

            return formatted_str

        def format_for_retriever(value: dict):
            return value["input"]

        def format_for_prompt_template(value):
            # {input, context, history}
            new_value = {}
            new_value["input"] = value["input"]["input"]
            new_value["context"] = value["context"]
            new_value["history"] = value["input"]["history"]
            return new_value

        chain = (
            {
                "input": RunnablePassthrough(),
                "context": RunnableLambda(format_for_retriever) | retriever | format_document
            } | RunnableLambda(format_for_prompt_template) | self.prompt_template | print_prompt | self.chat_model | StrOutputParser()
        )

        conversation_chain = RunnableWithMessageHistory(
            chain,
            get_history,
            input_messages_key="input",
            history_messages_key="history",
        )

        return conversation_chain


if __name__ == '__main__':
    # session_id配置
    session_config = {
        "configurable": {
            "session_id": "user_001"
        }
    }
    # res = RagService().chain.invoke({"input": "我体重 100 斤, 身高155cm,尺码推荐"}, session_config)
    res = RagService().chain.invoke({"input": "春天穿什么颜色的衣服?"}, session_config)
    print(res)

RAG项目-历史会话记录功能的实现

import os
import json
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict
from typing import Sequence


def get_history(session_id):
    return FileChatMessageHistory(session_id, "./chat_history")


class FileChatMessageHistory(BaseChatMessageHistory):
    def __init__(self, session_id, storage_path):
        self.session_id = session_id # 会话ID
        self.storage_path = storage_path # 不同会话id的存储路径,所在的文件夹路径
        # 完整的文件路径
        self.file_path = os.path.join(self.storage_path, self.session_id)
        # 确保文件夹是存在的
        os.makedirs(os.path.dirname(self.file_path), exist_ok=True)

    def add_messages(self, messages: Sequence[BaseMessage]) -> None:
        # Sequence序列,类似list,tuple
        all_messages = list(self.messages) # 已有的消息列表
        all_messages.extend(messages) # 新的消息和已有的融合成一个list

        # 将数据同步写入到本地文件中
        # 类对象写入文件 -> 二进制
        # 为了方便,将BaseMessage消息 转为字典,借助json模块以json字符串写入文件
        # message_to_dict : 单个消息对象(BaseMessage类实例) -> 字典

        # new_messages = []
        # for message in all_messages:
        #     d = message_to_dict(message)
        #     new_messages.append(d)
        new_messages = [message_to_dict(message) for message in all_messages]
        # 将数据写入文件
        with open(self.file_path, "w", encoding="utf-8") as f:
            json.dump(new_messages, f) # 将整体变成json写入文件中

    @property # @property 装饰器将messages方法变成 成员属性用
    def messages(self) -> list[BaseMessage]:
        # 当前文件内: List[字典], 需要转为 list[BaseMessage]
        try:
            with open(self.file_path, "r", encoding="utf-8") as f:
                messages_data = json.load(f) # 返回List[字典]
                return messages_from_dict(messages_data)
        except FileNotFoundError:
            return []

    def clear(self) -> None:
        with open(self.file_path, "w", encoding="utf-8") as f:
            json.dump([], f) # 清空文件

RAG项目-聊天页面的开发

网页的实现:主要实现 用户提问界面、流式输出、保存历史记录

import streamlit as st
import time
from rag import RagService
import config_data

# 标题
st.title('智能客服问答')
st.divider() # 分隔符

if 'rag_service' not in st.session_state:
    st.session_state['rag_service'] = RagService()

if "message" not in st.session_state:
    st.session_state["message"] = [{'role': 'assistant', 'content': '你好,有什么可以帮助你?'}]

for message in st.session_state["message"]:
    st.chat_message(message['role']).write(message['content'])


# 用户输入框
prompt = st.chat_input()

if prompt:
    # 在页面输出
    st.chat_message("user").write(prompt)
    st.session_state["message"].append({'role': 'user', 'content': prompt})
    ai_res_list = []
    with st.spinner("AI思考中。。。"):
        # res = st.session_state['rag_service'].chain.invoke({"input": prompt}, config_data.session_config)
        res_stream = st.session_state['rag_service'].chain.stream({"input": prompt}, config_data.session_config)
        # yield
        def capture(generator, cache_list):
            for chunk in generator:
                cache_list.append(chunk)
                yield chunk
        # st.chat_message("assistant").write(res)
        st.chat_message("assistant").write_stream(capture(res_stream, ai_res_list))
        # st.session_state["message"].append({'role': 'assistant', 'content': res})
        # 历史记录保存
        st.session_state["message"].append({'role': 'assistant', 'content': ''.join(ai_res_list)})
        # ['a', 'b', 'c']   ''.join(list) -> abc
Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐