此文章还是根据黑马服务员的学习视频所做的个人笔记,是关于一个服装推荐助手AI的一个简单RAG项目的案例。学习视频地址:https://www.bilibili.com/video/BV1yjz5BLEoY/?spm_id_from=333.1391.0.0&p=55

目录

​编辑

RAG项目案例

离线流程开发

文本上传web服务

MD5工具函数开发

知识库更新服务

完成离线流程开发

在线流程开发

向量存储服务代码

Rag服务核心代码开发

历史记录会话功能的实现

聊天界面开发


RAG项目案例

本次项目以商品衣服为例,以衣服属性构建本地知识。使用者可以自由更新本地知识,用户问题的答案也是基于本地生成的。

离线流程开发

文本上传web服务

首先我们需要下载streamlit库,因为streamlit框架能够完成web网页上传服务,为了掌握初次使用,我们可以先做以下示例,设置标题,上传文件接口,与提取文字的信息。

"""
基于streamlit框架完成web网页上传服务
"""

import streamlit as st

st.title("知识库更新服务")

uploader_file = st.file_uploader(
    "请上传txt文件",
    type=["txt"],
    accept_multiple_files=False  #只接受一个文件的上传
)

if uploader_file is not None:
    #提取文件的信息
    file_name = uploader_file.name,
    file_type = uploader_file.type,
    file_size = uploader_file.size / 1024

    st.subheader(f"文件名:{file_name}")
    st.write(f"格式:{file_type} | 大小:{file_size:.2f} KB")

    #get_value获取文件内的内容
    text = uploader_file.getvalue().decode("UTF-8")
    st.write(text)

因为get_value获取的文件内容是byte数组形式的,所以我们要进行UTF-8的编码使我们获取到的文件内容正常。

但是要启动这个基于 Streamlit 的网页应用,我们不能像运行普通 Python 脚本那样直接运行,而是需要使用 Streamlit 提供的专用命令。我们要打开终端(pycharm终端),切换到我们这个py文件的所在目录,然后使用streamlit run app.py命令启动。

如图所示

成功启动后会有个email的可选选项,直接回车跳过就行。

实现结果

上传文件前

上传文件后

MD5工具函数开发

主要由三个函数组成,分别是字符串转md5值(get_string_md5),保存md5值(save_md5),和md5值的查重(check_md5)。我们需要下载hashlib的依赖库来实现md5对象的操作。代码简单,不多赘述。

"""
知识库
"""
import os.path

import config_data as config
import hashlib

#对文件进行查重
def check_md5(md5_str : str):
    """检查传入的md5字符是否已经被处理过了"""
    if not os.path.exists(config.md5_path):
        #文件不存在则创建文件
        open(config.md5_path,"w",encoding="UTF-8").close()
        return False
    else:
        for line in open(config.md5_path,"r",encoding="UTF-8").readlines():
            line = line.strip()     #处理字符串前后的空格和回车
            if line == md5_str:
                return True
        return False

#将得到的md5值进行保存
def save_md5(md5_str : str):
    """将传入的md5字符串记录到文件内保存"""
    with open(config.md5_path,"a",encoding="UTF-8") as f:
        f.write(md5_str + "\n")

#将字符串转换为md5字符串
def get_string_md5(input_str:str,   encoding='UTF-8'):
    """将传入的字符串转换为md5字符串"""
    str_byte = input_str.encode(encoding=encoding)
    md5_obj = hashlib.md5()
    md5_obj.update(str_byte)
    return md5_obj.hexdigest()

知识库更新服务

首先初始化KnowledgeBaseService类的向量库Chroma和文本分割器RecursiveCharacterTextSplitter,并将相关的配置都存储到config_data.py文件中。

然后传入字符串时,先转成md5值,然后检查md5值有没有重复,没有重复的话则检查数据的长度有没有小于文本分割的阈值,小于的话则将这个数据直接一整个装进数组里面,反之则用for循环将分割后的文本装进数组。因为是一个文档分割出来的文本,所以他们共享一个元数据,metadata for _ in knowledge_chunks通过这个列表推导式,我们复制了与knowledge_chunks同等长度的metadatas,一对一对加入到向量数据库中。最后对md5的值进行保存。

#知识库更新服务
class KnowledgeBaseService(object):

    def __init__(self):
        os.makedirs(config.persist_directory,exist_ok=True)
        self.chroma = Chroma(
            collection_name=config.collection_name, #数据库表明
            embedding_function=DashScopeEmbeddings(model="text-embedding-v4"),
            persist_directory = config.persist_directory #数据库本地文件夹
        )  #向量存储的实例
        self.spliter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            separators=config.separators,
            length_function=len
        ) #文本分割器的对象

    def upload_by_str(self,data,filename):
        """将传入的字符串,进行向量化存入向量数据库中"""
        md5_hex = get_string_md5(data)
        if(check_md5(md5_hex)):
            return "[跳过]内容已经存在在知识库中"
        if len(data) > config.max_split_char_number :
            knowledge_chunks:list[str] = self.spliter.split_text(data)
        else:
            knowledge_chunks = [data]

        metadata = {
            "source": filename,
            "create_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "operator": "孙权"
        }

        #将内容加载到数据库中
        self.chroma.add_texts(
            knowledge_chunks,
            metadatas=[metadata for _ in knowledge_chunks]
        )

        save_md5(md5_hex)

        return "[成功]内容已经成功载入到向量库"

if __name__ == "__main__":
    service = KnowledgeBaseService()
    f = service.upload_by_str("周杰伦","testfile")
    print(f)

运行结果

再次运行结果

完成离线流程开发

因为需要知识库的更新与网页中的上传服务对接上,所以我们需要将KnowledgeBaseService类在app_file_uploader.py中进行实例化。而且我们需要保证,在streamlit启动后,我们动用的都是同一个KnowledgeBaseService类,但是streamlit在每次用户进行交互的时候都会重新运行,streamlit.session_state能够做到这一点是因为它为每个用户会话在服务器内部创建了一个持久化的字典对象,它的生命周期独立于脚本的重新运行,当脚本重新运行时,streamlit.session_state对象本身不会被重建,他会一直存在知道会话结束。

#session_state是一个字典
if "service" not in st.session_state:
    st.session_state["service"] = KnowledgeBaseService()

 with st.spinner("载入知识库中......"):
        time.sleep(1)
        result = st.session_state["service"].upload_by_str(text,file_name)
        st.write(result)

在线流程开发

向量存储服务代码

就是创建一个向量检索器,很简单,不多赘述。(调用向量检索器时,它会自己将传入的字符串转换为向量到向量库中进行匹配,并返回一定数目的文档数组)写在vector_stores.py文件中。

class VectorStoreService(object):
    def __init__(self,embedding):
        self.embedding = embedding
        self.vector_store = Chroma(
            collection_name=config.collection_name,
            embedding_function=self.embedding,
            persist_directory=config.persist_directory
        )

    def get_retriever(self):
        #返回向量检索器,方便加入chain
        return self.vector_store.as_retriever(search_kwargs={"k":config.similarity_threshold})

Rag服务核心代码开发

在写到这里时,记得将资料中的data资料通过离线模式上传到知识库进行更新!

四个核心成员:向量库服务类VectorStoreService,提示词模板,聊天模型和链对象。

首先将调用链时传入的字符串,使用向量库服务得到返回的文档列表,然后创建一个函数把文档列表转换为字符串得到content,接着将input和content放入提示词模板丢给ai,再把ai的返回内容字符串化得到结果。

from xml.dom.minidom import Document

from langchain_community.chat_models import ChatTongyi
from langchain_community.embeddings import DashScopeEmbeddings
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough

from vector_stores import VectorStoreService
import config_data as config


def print_prompt(prompt):
    print("="*20)
    print(prompt.to_string())
    print("="*20)
    return  prompt

class RagService(object):
    def __init__(self):
        self.vector_service = VectorStoreService(
            embedding=DashScopeEmbeddings(model=config.embedding_model_name)

        )

        self.prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system","以我提供的已知参考资料为主,"
                 "简洁和专业的回答用户问题,参考资料:{context}"),
                ("user","请回答用户提问:{input}")
            ]
        )

        self.chat_model = ChatTongyi(
            model_name=config.chat_model_name,
        )

        self.chain = self.__get__chain()

    def __get__chain(self):
        retriever = self.vector_service.get_retriever()

        def format_document(docs: list[Document]):

            if not docs:
                return "无相关参考资料"

            formatted_str = ""
            for doc in docs:
                formatted_str += f"文档片段:{doc.page_content}\n\n"

            return formatted_str

        chain = (
            {
                "input":RunnablePassthrough(),
                "context":retriever | format_document
            }   | self.prompt_template | print_prompt |self.chat_model | StrOutputParser()
        )

        return chain

if __name__ == '__main__':
    res = RagService().chain.invoke("我体重180斤,尺码推荐")
    print(res)

实现效果

历史记录会话功能的实现

这里应该是很难理解的部分,因为我们对原有链进行了携带历史会话记录的加强,使我们调用链时输入变成了字典,导致向量检索器的传参出错(它需要字符串),你想出来把input提取出来,又会导致history的丢失。

所以能够成功实现的思路如下:

在进行chain调用的时候,input和context先依靠一个函数提取到传入的字典中用户的输入"input",RunnableLambda(format_for_retriever),返回的结果是依然是一个字典,这个字典里的key分别是input(传入给向量检索器的参数)和context(向量检索器返回的文档经过转换后的字符串),但是input的value,也就是调用链时的传参也是个字典包含input(调用加强链时传入的参数)和history(用session_id找到的历史会话记录),所以最终我们需要生成一个新字典,把这三个都包含上,也就是format_for_prompt_template这个函数的实现,才能够实现提示词模板的正常嵌入,完成后续功能。

from xml.dom.minidom import Document

from langchain_community.chat_models import ChatTongyi
from langchain_community.embeddings import DashScopeEmbeddings
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough, RunnableWithMessageHistory, RunnableLambda
from file_history_store import get_history
from vector_stores import VectorStoreService
import config_data as config


def print_prompt(prompt):
    print("="*20)
    print(prompt.to_string())
    print("="*20)
    return  prompt

class RagService(object):
    def __init__(self):
        self.vector_service = VectorStoreService(
            embedding=DashScopeEmbeddings(model=config.embedding_model_name)

        )

        self.prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system","以我提供的已知参考资料为主,"
                 "简洁和专业的回答用户问题,参考资料:{context}"),
                ("system","并且我提供用户的对话历史记录如下"),
                MessagesPlaceholder("history"),
                ("user","请回答用户提问:{input}")
            ]
        )

        self.chat_model = ChatTongyi(
            model_name=config.chat_model_name,
        )

        self.chain = self.__get__chain()

    def __get__chain(self):
        retriever = self.vector_service.get_retriever()

        def format_document(docs: list[Document]):

            if not docs:
                return "无相关参考资料"

            formatted_str = ""
            for doc in docs:
                formatted_str += f"文档片段:{doc.page_content}\n\n"

            return formatted_str

        def format_for_retriever(value:dict) -> str:
            return value["input"]

        def format_for_prompt_template(value):
            new_value = {}
            new_value["input"] = value["input"]["input"]
            new_value["history"] = value["input"]["history"]
            new_value["context"] = value["context"]
            return new_value

        chain = (
            {
                "input":RunnablePassthrough(),
                "context":RunnableLambda(format_for_retriever)| retriever | format_document
            } |RunnableLambda(format_for_prompt_template) | self.prompt_template | print_prompt |self.chat_model | StrOutputParser()
        )

        conversation_chain = RunnableWithMessageHistory(
            chain,
            get_history,
            input_messages_key="input",
            history_messages_key="history"
        )

        return conversation_chain

if __name__ == '__main__':
    session_config = {
        "configurable":{
            "session_id": "user_001"
        }
    }
    res = RagService().chain.invoke({"input":"春天推荐我穿什么衣服"},session_config)
    print(res)

成功实现的模型的回答

聊天ye

ai回答如下

System: 并且我提供用户的对话历史记录如下
Human: 我体重180斤,尺码推荐
AI: 根据您提供的体重180斤,结合参考资料中的尺码建议:

  • 若您的身高在 180–190cm 范围内,推荐尺码为 4XL
  • 若您的身高 超过190cm,则推荐尺码为 5XL

建议您同时参考身高数据以获得更准确的尺码推荐。
Human: 请回答用户提问:春天推荐我穿什么衣服
====================
根据您的体型(体重180斤,属于丰满体型)和春季特点,结合参考资料,为您推荐如下穿搭建议:

一、颜色选择

  • 主色调:优先选择深色系(如藏蓝、深灰、黑色),视觉显瘦;
  • 点缀色:可搭配低饱和度亮色(如雾霾蓝、浅紫)或小面积暖色配饰(如豆沙色丝巾),提亮气色又不显臃肿;
  • 避免:大面积亮色、横条纹、冷调荧光色(如荧光绿)。

二、款式与材质(适合春季)

  • 上衣:选择纯棉或针织棉的宽松衬衫、薄开衫,避免紧身;
  • 下装:推荐薄款直筒牛仔裤或深色休闲裤,版型挺括更修饰腿型;
  • 外套:可选轻薄化纤风衣(深色),防风又利落;
  • 整体搭配:采用“上浅下深”或“同色系深色搭配”,拉长比例、显瘦显高。

三、养护小贴士

  • 春季潮湿,衣物洗后务必阴干,收纳时放干燥剂防霉;
  • 牛仔和针织类避免暴晒,防止褪色和变形。

综上,推荐组合示例:
藏蓝色薄针织开衫 + 白色纯棉T恤(内搭)+ 深灰直筒休闲裤 + 黑色轻薄风衣(早晚穿),既符合春季氛围,又修饰体型。

聊天界面开发

基于streamlit框架,将RagService服务类(避免服务类在重新运行时被重新覆盖)和历史消息(每次重新运行都要将历史消息的记录显示出来)保存到服务器内部,实现持久化。

但是我们要想实现流式输出的效果,又会和历史消息的添加类型不兼容,所以我们新创建一个list队列,每次流式输出时,都会将信息添加到这个队列,最后将这个队列字符串化存入到历史消息队列中。

import streamlit as st
from rag import RagService
import config_data as config

st.title("智能客服")
st.divider() #分隔符

if "message" not in st.session_state:
    st.session_state["message"] = [{"role":"assistant","content":"你好,有什么可以帮助你"}]

if "rag" not in st.session_state:
    st.session_state["rag"] = RagService()

for message in st.session_state["message"]:
    st.chat_message(message["role"]).write(message["content"])

prompt = st.chat_input()

if prompt:
    st.chat_message("user").write(prompt)
    st.session_state["message"].append({"role":"user","content":prompt})

    ai_list = []
    with st.spinner("AI思考中..."):
        #流式输出
        res_stream = st.session_state["rag"].chain.stream({"input":prompt},config.session_config)

        def capture(generator,cache_list):
            for chunk in generator:
                cache_list.append(chunk)
                yield chunk

        st.chat_message("assistant").write_stream(capture(res_stream,ai_list))
        st.session_state["message"].append({"role": "assistant", "content": "".join(ai_list)})

实现的AI聊天如下

服装穿搭助手AI-实现效果

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐