前 言

咱们前面的代码是通过手写Agent工作流程,实现了一个论文RAG问答系统,但是在实际生产环境中不会用这种纯手写工作逻辑项目,更多的是使用现有框架比如LangGraph和LangChain,LangChain是一种高级封装后的框架,更适合需要借助智能体完成固定任务的非专业用户,而LangGraph是一种更细粒度的智能体开发框架,允许完全自定义图结构,适合定制更复杂的智能体,以满足个性化的任务需求。

打个比较形象的比方,相信很多人都喜欢摄影,专业摄影师会有各种长枪短炮,并且在摄影时会调整相机参数,比如曝光度、白平衡等等,然后看到有灵感的风景会考虑景深,和前后景的关系,最终拍出来一张美图。基于LangGraph的开发与这类似,整个过程细节可调的地方非常多,所以做出来的东西比较专业,定制化也比较高。相反,基于LangChain的智能体开发有点像是一个人,拿着那种开了美颜的傻瓜机在拍照,各种参数和接口都已经被提前设定好了,也可以拍出能看的照片,但是这种东西有上限。

开始动手

咱们手写Agent是写了三个函数choose_toolexecute_toolgenerate_final_answer,通过数据在三个函数间的流动实现了一个最小的智能体循环:

用户问题 → 思考 → 调工具 → 看结果 → 再思考 → 再决定下一步 → 最终回答

当进入到LangGraph后咱们对数据的概念需要发生一点变化,在这个框架中数据称为State,对数据进行操作函数称为Node,不同节点间通过Edge连接构成流,这个框架命中带Graph,可能是因为这个工作流结构和图论中的图一样。

为了更好的理解LangGraph,咱们还是在原有的代码上做迁移,但是最好还是新建一个项目,Python版本选择3.11,随后安装咱们前面用到的依赖,并且还要把LangGraph等依赖安装上:

pip install -U langgraph langchain langchain-openai
pip install -U fastapi uvicorn openai python-dotenv pypdf numpy pydantic
pip install -U faiss-cpu

项目结构重构

咱们的项目需要重构,可以按照下面的目录重新组织:

LangChain-for-A-Paper-Rag-Agent/
├─ app/
│  ├─ __init__.py
│  ├─ config.py
│  ├─ logger_config.py
│  ├─ llm_utils.py
│  ├─ data_loader.py
│  ├─ rag_system.py
│  ├─ tools.py
│  ├─ session_manager.py
│  ├─ main.py
│  └─ graph/
│     ├─ __init__.py
│     ├─ state.py
│     ├─ nodes.py
│     └─ builder.py
├─ data/
├─ .env
└─ requirements.txt

数据State化

智能体对数据的操作原理上是:Graph中的每个节点只负责读state,改state,再把结果交给下一个节点。所以我们首先在app/graph/目录下创建state.py,用于定义在图中流的state:

from typing import Any

from typing_extensions import TypedDict


class AgentState(TypedDict, total=False):
    # 当前会话信息
    session_id: str

    # 当前用户问题
    query: str

    # 历史对话
    chat_history: list[dict[str, str]]

    # 路由决策结果
    # 例如: {"tool": "rag", "input": "what is ..."}
    decision: dict[str, Any]

    # 工具执行结果
    # 例如:
    # {
    #   "tool_name": "rag",
    #   "tool_input": "...",
    #   "tool_output": "..."
    # }
    tool_result: dict[str, Any]

    # 最终返回给用户的答案
    final_answer: str

    # 预留错误字段,后面做异常兜底会用到
    error: str

函数Node化

在LangGraph中函数称为Node,我们需要咱们最早在agent.py中定义的三个函数进行改造,让他们变成可以接受数据state的node,首先需要在app/graph/目录下新建node.py,然后定义两个工厂函数和一个节点:

关于为什么用工厂函数后面会详细解释。

import json
from typing import Any

from app.graph.state import AgentState
from app.llm_utils import client
from app.config import CHAT_MODEL
from app.logger_config import setup_logger

logger = setup_logger()


def build_choose_tool_node(tools: list[dict[str, Any]]):
    def choose_tool_node(state: AgentState) -> AgentState:
        query = state["query"]

        tool_desc = "\n".join([
            f"{t['name']}: {t['description']}" for t in tools
        ])

        prompt = f"""
					You are an AI agent.
					
					Available tools:
					{tool_desc}
					
					User question:
					{query}
					
					Return JSON:
					{{"tool": "...", "input": "..."}}
				 """

        content = ""
        try:
            response = client.chat.completions.create(
                model=CHAT_MODEL,
                messages=[{"role": "user", "content": prompt}]
            )
            content = response.choices[0].message.content
            decision = json.loads(content)
        except Exception:
            logger.warning(f"Tool decision parse failed: {content}")
            decision = {"tool": "llm", "input": query}

        return {
            "decision": decision
        }

    return choose_tool_node


def build_execute_tool_node(tools: list[dict[str, Any]], rag=None):
    def execute_tool_node(state: AgentState) -> AgentState:
        decision = state["decision"]
        chat_history = state.get("chat_history", [])

        tool_name = decision["tool"]
        tool_input = decision["input"]

        for t in tools:
            if t["name"] == tool_name:
                if tool_name == "rag":
                    result = t["func"](tool_input, rag, chat_history=chat_history)
                elif tool_name == "llm":
                    result = t["func"](tool_input, chat_history=chat_history)
                else:
                    result = t["func"](tool_input)

                return {
                    "tool_result": {
                        "tool_name": tool_name,
                        "tool_input": tool_input,
                        "tool_output": result
                    }
                }

        return {
            "tool_result": {
                "tool_name": "none",
                "tool_input": tool_input,
                "tool_output": "No valid tool found."
            }
        }

    return execute_tool_node


def generate_answer_node(state: AgentState) -> AgentState:
    query = state["query"]
    tool_result = state["tool_result"]

    prompt = f"""
				You are an AI assistant.
				
				The user asked:
				{query}
				
				A tool was used:
				Tool name: {tool_result['tool_name']}
				Tool input: {tool_result['tool_input']}
				Tool output: {tool_result['tool_output']}
				
				Now provide a final helpful answer to the user.
			 """

    response = client.chat.completions.create(
        model=CHAT_MODEL,
        messages=[{"role": "user", "content": prompt}]
    )

    final_answer = response.choices[0].message.content

    return {
        "final_answer": final_answer
    }

串起Node形成Graph

到这一步咱们就到了LangGraph的核心操作,将各个node串起形成工作流,前面咱们说过咱们在创建节点的时候用到了工厂函数,在这一节我也想解释原因,不够首先咱们还是先在app/graph/目录下新建一个builder.py文件:

from langgraph.graph import StateGraph, START, END

from app.graph.state import AgentState
from app.graph.nodes import (
    build_choose_tool_node,
    build_execute_tool_node,
    generate_answer_node,
)


def build_agent_graph(tools, rag=None):
    graph_builder = StateGraph(AgentState)

    # 1. 注册节点
    graph_builder.add_node("choose_tool", build_choose_tool_node(tools))
    graph_builder.add_node("execute_tool", build_execute_tool_node(tools, rag=rag))
    graph_builder.add_node("generate_answer", generate_answer_node)

    # 2. 连接流程
    graph_builder.add_edge(START, "choose_tool")
    graph_builder.add_edge("choose_tool", "execute_tool")
    graph_builder.add_edge("execute_tool", "generate_answer")
    graph_builder.add_edge("generate_answer", END)

    # 3. 编译 graph
    return graph_builder.compile()

相信看到这里,你已经看明白了LangGraph的逻辑,这个框架将数据定义state,将函数定义为node,通过添加edge实现业务逻辑间的工作流。可能打一个不太恰当的比方,通过LangGraph创建的智能体他身上像是绑定了一堆state数据,你给它定义了node和edge,它就会按照你定义的工作顺序去拿node修改自身绑定的state,最后返回给你操作结果。

关于工厂函数这里我解释一下(我也是踩了坑了):选择工具node和执行node不同于咱们的生成回答node,生成回答node只需要给她传入state就可以得到响应的结果,但是前两个node你发现他们在创建的时候是需要一些外部依赖的,所以咱们需要定义工厂函数,将外部依赖传入,再生成出响应的node,所以这里用到了工厂函数。一句话概括:一些node的生成需要外部依赖,所以咱们需要用工厂函数加工再生成node。

收尾

最后就是启动咱们的项目了,在app目标下新建main.py文件,然后在原有的代码基础上增加一步创建图的操作graph = build_agent_graph(TOOLS, rag=rag),就可以启动新的系统了,这样便实现了手写Agent到LangGraph的迁移:

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from app.config import DATA_DIR
from app.data_loader import load_pdfs, process_documents
from app.rag_system import RAGSystem
from app.tools import TOOLS
from app.session_manager import SessionManager
from app.graph.builder import build_agent_graph
from app.logger_config import setup_logger

logger = setup_logger()

app = FastAPI()

session_manager = SessionManager(max_turns=3)

rag = None
graph = None


class QueryRequest(BaseModel):
    session_id: str
    question: str


@app.on_event("startup")
def startup_event():
    global rag, graph

    logger.info("Loading RAG system...")

    docs = load_pdfs(DATA_DIR)
    logger.info(f"docs数量: {len(docs)}")

    chunks = process_documents(docs)
    logger.info(f"chunks数量: {len(chunks)}")

    rag = RAGSystem(chunks)
    rag.build_index()

    graph = build_agent_graph(TOOLS, rag=rag)

    logger.info("RAG + LangGraph ready!")


@app.post("/ask")
def ask_question(req: QueryRequest):
    try:
        history = session_manager.get_history(req.session_id)

        state = {
            "session_id": req.session_id,
            "query": req.question,
            "chat_history": history,
        }

        result = graph.invoke(state)
        answer = result["final_answer"]

        session_manager.append_turn(
            req.session_id,
            req.question,
            answer
        )

        return {
            "session_id": req.session_id,
            "question": req.question,
            "answer": answer
        }

    except Exception as e:
        logger.exception("Error occurred in /ask")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/clear/{session_id}")
def clear_session(session_id: str):
    session_manager.clear_session(session_id)
    return {
        "session_id": session_id,
        "message": "session cleared"
    }

如果这篇文章对你有帮助,可以点个赞~
完整代码地址:https://github.com/1186141415/LangChain-for-A-Paper-Rag-Agent

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐