
1. Introduction

In modern AI application development, organizing and orchestrating complex AI workflows effectively is a key challenge. LangGraph was created to address it: a framework built on LangChain for constructing stateful, multi-actor agent workflows. Unlike traditional linear execution chains, LangGraph organizes computation as a graph, supporting cycles, branching, parallelism, and state management, and thereby provides a solid foundation for building complex AI applications.

Why LangGraph? Many real-world AI tasks are not simple "input-process-output" pipelines. A customer-service bot, for example, may need to: 1. understand the user's question; 2. query a knowledge base; 3. ask the user for clarification when information is missing; 4. call a calculation tool; 5. generate an answer. Workflows like this, involving decisions, loops, tool calls, and state persistence, are exactly what LangGraph excels at.


2. LangGraph Core Capabilities

2.1 The Graph API

The Graph API is LangGraph's most fundamental capability. It lets developers define a workflow as a set of nodes, where each node represents a processing step and edges define how execution flows between steps. This graph representation is particularly well suited to complex, stateful workflows.

2.1.1 Basic Concepts

In LangGraph, a graph is built from the following elements:

  • State: a state object that flows through the entire graph execution and is read and updated by each node
  • Nodes: functions that perform the actual work; each receives the current state and returns an update
  • Edges: transitions between nodes, either normal edges (unconditional) or conditional edges (the next node is chosen based on the state)
2.1.2 Building a Basic Workflow

Let's walk through a simple example of building a LangGraph workflow. Suppose we want a conversational agent that can: 1. receive user input; 2. call a language model to generate a reply; 3. end the conversation when the user says goodbye.

# Example 1: a basic LangGraph workflow
from typing import TypedDict
from langgraph.graph import StateGraph, END
from langchain_openai import ChatOpenAI

# 1. Define the state schema
class ConversationState(TypedDict):
    """Conversation state: user message, AI reply, and bookkeeping"""
    user_input: str     # user input
    ai_response: str    # AI reply (overwritten each turn, so no reducer is needed)
    message_count: int  # message counter
    should_end: bool    # whether to end the conversation

# 2. Define node functions
def user_input_node(state: ConversationState) -> ConversationState:
    """Node that receives user input"""
    # In a real application this would read from an API, terminal, or UI;
    # for this example we fall back to a fixed input
    if state.get("user_input") is None:
        state["user_input"] = "Hi, I'd like to learn about LangGraph"

    # Increment the message counter
    state["message_count"] = state.get("message_count", 0) + 1

    print(f"[user input node] received input: {state['user_input']}")
    print(f"[user input node] message count: {state['message_count']}")

    return state

def ai_response_node(state: ConversationState) -> ConversationState:
    """Node that generates the AI reply"""
    # Initialize the language model
    llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.7)

    user_message = state["user_input"]

    # Check whether the user wants to end the conversation
    if "bye" in user_message.lower() or "goodbye" in user_message.lower():
        state["ai_response"] = "Thanks for chatting. Goodbye!"
        state["should_end"] = True
    else:
        # Generate a reply
        response = llm.invoke(f"The user said: {user_message}\nPlease write a friendly, helpful reply:")
        state["ai_response"] = response.content

    print(f"[AI reply node] reply: {state['ai_response']}")
    print(f"[AI reply node] should end: {state.get('should_end', False)}")

    return state

def end_conversation_node(state: ConversationState) -> ConversationState:
    """Node that ends the conversation"""
    print("[end node] Conversation over. Thanks for chatting!")
    print(f"[end node] total messages: {state['message_count']}")
    return state

# 3. Build the graph
def build_conversation_graph():
    """Build the conversation graph"""
    # Create the state graph
    workflow = StateGraph(ConversationState)

    # Add nodes
    workflow.add_node("user_input", user_input_node)
    workflow.add_node("ai_response", ai_response_node)
    workflow.add_node("end_conversation", end_conversation_node)

    # Set the entry point
    workflow.set_entry_point("user_input")

    # Add edges
    workflow.add_edge("user_input", "ai_response")

    # Conditional edge: decide the next step based on should_end
    def should_end(state: ConversationState) -> str:
        """Route to the farewell node or finish this run"""
        if state.get("should_end", False):
            return "end_conversation"
        else:
            return "__end__"  # finish this run; the caller invokes the graph again for the next turn

    workflow.add_conditional_edges(
        "ai_response",
        should_end,
        {
            "end_conversation": "end_conversation",
            "__end__": END
        }
    )

    workflow.add_edge("end_conversation", END)

    # Compile the graph
    graph = workflow.compile()

    return graph

# 4. Run the graph
if __name__ == "__main__":
    # Build the graph
    conversation_graph = build_conversation_graph()

    # Initial state
    initial_state = {
        "user_input": None,
        "ai_response": "",
        "message_count": 0,
        "should_end": False
    }

    print("=== Conversation start ===")

    # First turn
    print("\n--- Turn 1 ---")
    result1 = conversation_graph.invoke(initial_state)
    print(f"AI reply: {result1['ai_response']}")

    # Second turn: simulate the user saying goodbye
    print("\n--- Turn 2 ---")
    result1["user_input"] = "Thanks, bye!"
    result2 = conversation_graph.invoke(result1)
    print(f"AI reply: {result2['ai_response']}")

    print(f"\n=== Conversation over, {result2['message_count']} messages exchanged ===")

2.1.3 Conditional Edges and Branching

LangGraph's power lies in its ability to choose execution paths dynamically based on state. The following example demonstrates more complex branching logic:

# Example 2: a workflow with conditional branching
import re
from typing import TypedDict, Literal
from langgraph.graph import StateGraph, END

class DecisionState(TypedDict):
    """Decision state"""
    user_query: str
    query_type: str   # query type: question, calculation, greeting, unknown
    response: str
    steps_taken: list # record of executed steps (mutated in place, so no reducer)

def classify_query_node(state: DecisionState) -> DecisionState:
    """Node that classifies the query type"""
    query = state["user_query"].lower()

    # Match greetings on word boundaries so words like "something" don't
    # trigger the "hi" check; check for arithmetic before the generic
    # question check so "What is 2+2?" is routed to calculation
    if re.search(r"\b(hi|hello)\b", query):
        state["query_type"] = "greeting"
    elif any(op in query for op in "+-*/"):
        state["query_type"] = "calculation"
    elif "?" in query or "what" in query or "how" in query:
        state["query_type"] = "question"
    else:
        state["query_type"] = "unknown"

    state["steps_taken"].append(f"classified query: {state['user_query']} -> {state['query_type']}")
    return state

def handle_greeting_node(state: DecisionState) -> DecisionState:
    """Handle a greeting"""
    state["response"] = "Hello! I'm an AI assistant. How can I help you?"
    state["steps_taken"].append("handled greeting")
    return state

def handle_question_node(state: DecisionState) -> DecisionState:
    """Handle a question"""
    state["response"] = f"Your question was: {state['user_query']}\nGood question, but I need more context to answer it accurately."
    state["steps_taken"].append("handled question")
    return state

def handle_calculation_node(state: DecisionState) -> DecisionState:
    """Handle a calculation"""
    try:
        # Extract the first arithmetic expression (digits and + - * / only)
        match = re.search(r"[\d.]+(?:\s*[-+*/]\s*[\d.]+)+", state["user_query"])
        if match:
            # The regex restricts the expression to digits and operators,
            # but avoid eval on untrusted input in production code
            result = eval(match.group())
            state["response"] = f"Result: {result}"
        else:
            state["response"] = "I couldn't find an expression to evaluate"
    except Exception:
        state["response"] = "Sorry, I can't evaluate that expression"

    state["steps_taken"].append("handled calculation")
    return state

def handle_unknown_node(state: DecisionState) -> DecisionState:
    """Handle an unknown query"""
    state["response"] = f"I'm not sure how to handle: {state['user_query']}. Could you rephrase?"
    state["steps_taken"].append("handled unknown query")
    return state

def build_decision_graph():
    """Build the decision graph"""
    workflow = StateGraph(DecisionState)

    # Add all nodes
    workflow.add_node("classify", classify_query_node)
    workflow.add_node("greeting", handle_greeting_node)
    workflow.add_node("question", handle_question_node)
    workflow.add_node("calculation", handle_calculation_node)
    workflow.add_node("unknown", handle_unknown_node)

    # Set the entry point
    workflow.set_entry_point("classify")

    # Route to a handler node based on the query type
    def route_by_type(state: DecisionState) -> Literal["greeting", "question", "calculation", "unknown"]:
        return state["query_type"]

    workflow.add_conditional_edges(
        "classify",
        route_by_type,
        {
            "greeting": "greeting",
            "question": "question",
            "calculation": "calculation",
            "unknown": "unknown"
        }
    )

    # Every handler node flows to END
    for node in ["greeting", "question", "calculation", "unknown"]:
        workflow.add_edge(node, END)

    return workflow.compile()

# Test the decision graph
if __name__ == "__main__":
    decision_graph = build_decision_graph()

    test_cases = [
        "Hello, how is the weather today?",
        "What is 2+2?",
        "What is LangGraph?",
        "Tell me a random fact"
    ]

    for query in test_cases:
        print(f"\n=== Test query: {query} ===")
        initial_state = {
            "user_query": query,
            "query_type": "",
            "response": "",
            "steps_taken": []
        }

        result = decision_graph.invoke(initial_state)
        print(f"Query type: {result['query_type']}")
        print(f"Response: {result['response']}")
        print(f"Steps: {result['steps_taken']}")

2.2 Persistence

Persistence is one of LangGraph's key features: it allows a workflow to resume after an interruption. This is essential for long-running tasks, conversational systems, and any application that must survive interruptions.

2.2.1 Checkpointing

LangGraph implements persistence through checkpoints. A checkpoint captures the complete state of the graph, including the values of all state variables and the current execution position.

# Example 3: persistence and checkpoints
from datetime import datetime
from typing import TypedDict
from langgraph.checkpoint.memory import MemorySaver  # import path for langgraph >= 0.1
from langgraph.graph import StateGraph, START, END

class PersistentState(TypedDict):
    """Persisted state"""
    task: str
    steps_completed: list
    current_step: int
    result: str
    is_complete: bool
    created_at: str
    updated_at: str

def create_task_node(state: PersistentState) -> PersistentState:
    """Create the task"""
    if not state.get("task"):
        state["task"] = "default task"

    if not state.get("created_at"):
        state["created_at"] = datetime.now().isoformat()

    state["updated_at"] = datetime.now().isoformat()
    state["steps_completed"].append("task created")
    state["current_step"] = 1

    print(f"[create task] task: {state['task']}")
    return state

def process_step_1(state: PersistentState) -> PersistentState:
    """Step 1"""
    state["steps_completed"].append("step 1: data collection")
    state["current_step"] = 2
    state["updated_at"] = datetime.now().isoformat()

    print("[step 1] data collection done")
    return state

def process_step_2(state: PersistentState) -> PersistentState:
    """Step 2"""
    state["steps_completed"].append("step 2: data processing")
    state["current_step"] = 3
    state["updated_at"] = datetime.now().isoformat()

    # Simulate a slow operation
    import time
    time.sleep(1)

    print("[step 2] data processing done")
    return state

def process_step_3(state: PersistentState) -> PersistentState:
    """Step 3"""
    state["steps_completed"].append("step 3: result generation")
    state["current_step"] = 4
    state["result"] = f"Task '{state['task']}' complete!"
    state["is_complete"] = True
    state["updated_at"] = datetime.now().isoformat()

    print(f"[step 3] result: {state['result']}")
    return state

def build_persistent_graph():
    """Build a graph with persistence enabled"""
    workflow = StateGraph(PersistentState)

    # Add nodes
    workflow.add_node("create", create_task_node)
    workflow.add_node("step1", process_step_1)
    workflow.add_node("step2", process_step_2)
    workflow.add_node("step3", process_step_3)

    # Wire the edges
    workflow.add_edge(START, "create")
    workflow.add_edge("create", "step1")
    workflow.add_edge("step1", "step2")
    workflow.add_edge("step2", "step3")
    workflow.add_edge("step3", END)

    # Create an in-memory checkpoint store
    memory = MemorySaver()

    # Compile the graph with persistence enabled
    graph = workflow.compile(checkpointer=memory)

    return graph, memory

# Test persistence
if __name__ == "__main__":
    # Build the graph
    persistent_graph, checkpointer = build_persistent_graph()

    # A thread id (in a real application this might be a user or session id)
    thread_id = "user_123_session_1"

    print("=== First run (full flow) ===")

    config = {
        "configurable": {
            "thread_id": thread_id
        }
    }

    initial_state = {
        "task": "analyze user data",
        "steps_completed": [],
        "current_step": 0,
        "result": "",
        "is_complete": False,
        "created_at": "",
        "updated_at": ""
    }

    # Run the graph
    print("Starting the workflow...")
    try:
        for step in persistent_graph.stream(initial_state, config, stream_mode="values"):
            print(f"current state: step {step['current_step']}, last completed: {step['steps_completed'][-1] if step['steps_completed'] else 'none'}")
    except Exception as e:
        print(f"Execution interrupted: {e}")

    print("\n=== Checkpoint state ===")

    # List the checkpoints; checkpointer.list() yields CheckpointTuple objects
    # (the exact fields can vary between langgraph versions)
    checkpoints = list(checkpointer.list(config))
    print(f"number of checkpoints: {len(checkpoints)}")

    for i, ckpt in enumerate(checkpoints):
        print(f"\nCheckpoint {i}:")
        print(f"  id: {ckpt.checkpoint['id']}")
        print(f"  ts: {ckpt.checkpoint['ts']}")

    # Resume execution
    print("\n=== Resume from a checkpoint ===")

    if checkpoints:
        # MemorySaver.list typically yields the newest checkpoint first;
        # pick the checkpoint you want to resume from explicitly
        target = checkpoints[0]

        resume_config = {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_id": target.checkpoint["id"]
            }
        }

        # Passing None as the input resumes from the stored state
        print(f"Resuming from checkpoint {target.checkpoint['id']}...")
        for step in persistent_graph.stream(None, resume_config, stream_mode="values"):
            print(f"resumed state: step {step['current_step']}, result: {step.get('result', 'incomplete')}")

2.2.2 Persistent Storage Backends

In practice we usually want checkpoints saved to a database or the file system. The following example shows how a custom persistence backend might look:

# Example 4: a custom persistence backend
#
# NOTE: this class is a simplified illustration of the idea. The real
# BaseCheckpointSaver interface (get_tuple / put / put_writes / list and
# their exact signatures) varies between langgraph versions; for production
# use an official saver such as SqliteSaver or PostgresSaver.
import os
import pickle
from typing import Any, Dict, List, Optional
from langgraph.checkpoint.base import BaseCheckpointSaver, Checkpoint, CheckpointMetadata
from langgraph.graph import StateGraph, START, END

class FileCheckpointSaver(BaseCheckpointSaver):
    """Checkpoint storage backed by the file system"""

    def __init__(self, storage_path: str = "./checkpoints"):
        self.storage_path = storage_path
        os.makedirs(storage_path, exist_ok=True)

    def get(self, config: Dict[str, Any]) -> Optional[Checkpoint]:
        """Load a checkpoint"""
        thread_id = config["configurable"]["thread_id"]
        checkpoint_id = config["configurable"].get("checkpoint_id", "latest")

        if checkpoint_id == "latest":
            # Find the most recent checkpoint for this thread
            checkpoint_files = [f for f in os.listdir(self.storage_path)
                                if f.startswith(f"{thread_id}_")]
            if not checkpoint_files:
                return None
            checkpoint_files.sort()
            checkpoint_id = checkpoint_files[-1].replace(f"{thread_id}_", "").replace(".pkl", "")

        file_path = os.path.join(self.storage_path, f"{thread_id}_{checkpoint_id}.pkl")

        if not os.path.exists(file_path):
            return None

        with open(file_path, "rb") as f:
            checkpoint_data = pickle.load(f)

        return Checkpoint(
            checkpoint_id=checkpoint_id,
            ts=checkpoint_data["ts"],
            channel_values=checkpoint_data["channel_values"],
            channel_versions=checkpoint_data["channel_versions"],
            versions_seen=checkpoint_data["versions_seen"],
            parent_checkpoint_id=checkpoint_data["parent_checkpoint_id"],
            metadata=checkpoint_data.get("metadata", {})
        )

    def put(self, config: Dict[str, Any], checkpoint: Checkpoint) -> Dict[str, Any]:
        """Save a checkpoint"""
        thread_id = config["configurable"]["thread_id"]

        # Generate a checkpoint id from the current timestamp
        import time
        checkpoint_id = str(int(time.time() * 1000))

        checkpoint_data = {
            "ts": checkpoint.ts,
            "channel_values": checkpoint.channel_values,
            "channel_versions": checkpoint.channel_versions,
            "versions_seen": checkpoint.versions_seen,
            "parent_checkpoint_id": checkpoint.parent_checkpoint_id,
            "metadata": checkpoint.metadata
        }

        file_path = os.path.join(self.storage_path, f"{thread_id}_{checkpoint_id}.pkl")

        with open(file_path, "wb") as f:
            pickle.dump(checkpoint_data, f)

        # Also save a copy as "latest"
        latest_path = os.path.join(self.storage_path, f"{thread_id}_latest.pkl")
        with open(latest_path, "wb") as f:
            pickle.dump(checkpoint_data, f)

        return {"checkpoint_id": checkpoint_id}

    def list(self, config: Dict[str, Any]) -> List[CheckpointMetadata]:
        """List all checkpoints for a thread"""
        thread_id = config["configurable"]["thread_id"]
        checkpoints = []

        for filename in os.listdir(self.storage_path):
            if filename.startswith(f"{thread_id}_") and filename.endswith(".pkl"):
                checkpoint_id = filename.replace(f"{thread_id}_", "").replace(".pkl", "")
                if checkpoint_id == "latest":
                    continue

                file_path = os.path.join(self.storage_path, filename)
                with open(file_path, "rb") as f:
                    checkpoint_data = pickle.load(f)

                checkpoints.append(CheckpointMetadata(
                    checkpoint_id=checkpoint_id,
                    ts=checkpoint_data["ts"],
                    parent_checkpoint_id=checkpoint_data["parent_checkpoint_id"]
                ))

        return sorted(checkpoints, key=lambda x: x.ts, reverse=True)

# Use the custom checkpoint storage
def build_graph_with_file_checkpoints():
    """Build a graph that uses file-based checkpoints"""
    workflow = StateGraph(dict)

    # Add two simple nodes
    def step1(state):
        state["step"] = 1
        state["message"] = "step 1 done"
        return state

    def step2(state):
        state["step"] = 2
        state["message"] = "step 2 done"
        return state

    workflow.add_node("step1", step1)
    workflow.add_node("step2", step2)
    workflow.add_edge(START, "step1")
    workflow.add_edge("step1", "step2")
    workflow.add_edge("step2", END)

    # Plug in the custom file-based checkpointer
    file_checkpointer = FileCheckpointSaver("./my_checkpoints")

    # Compile the graph
    graph = workflow.compile(checkpointer=file_checkpointer)

    return graph

# Test the file checkpointer
if __name__ == "__main__":
    graph = build_graph_with_file_checkpoints()

    thread_id = "test_thread_001"
    config = {"configurable": {"thread_id": thread_id}}

    print("=== First run ===")
    initial_state = {"step": 0, "message": "start"}
    result = graph.invoke(initial_state, config)
    print(f"result: {result}")

    print("\n=== Inspect the checkpoint files ===")
    import glob
    checkpoint_files = glob.glob(f"./my_checkpoints/{thread_id}_*.pkl")
    print(f"checkpoint files: {checkpoint_files}")

    for file in checkpoint_files:
        print(f"\nfile: {file}")
        with open(file, "rb") as f:
            data = pickle.load(f)
            print(f"  step: {data['channel_values'].get('step', 'N/A')}")
            print(f"  message: {data['channel_values'].get('message', 'N/A')}")

2.3 Memory

Memory is one of the core capabilities of an agent system. LangGraph provides mechanisms that let an agent remember previous interactions, maintain conversational context, and make better decisions based on history.

2.3.1 Short-Term vs. Long-Term Memory

In LangGraph, memory falls into two categories:

  • Short-term memory: persists within a single workflow execution
  • Long-term memory: persists across multiple sessions or execution cycles
# Example 5: a memory system
import hashlib
from datetime import datetime
from typing import TypedDict
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, START, END

class MemoryState(TypedDict):
    """Memory state"""
    user_input: str
    conversation_history: list  # full conversation history (mutated in place)
    summary: str                # conversation summary
    context: dict               # contextual information
    user_preferences: dict      # user preferences
    last_interaction_time: str  # time of the last interaction
    ai_response: str            # latest assistant reply

class MemoryEnhancedAgent:
    """An agent with hand-rolled short- and long-term memory"""

    def __init__(self):
        # Long-term user-preference store (simulating a database)
        self.user_preferences_db = {}

        # Context cache
        self.context_cache = {}

    def get_user_id(self, user_input: str, session_id: str = "default") -> str:
        """Derive a user identifier (simplified; a real app would use an actual user id)"""
        return f"user_{hashlib.md5(session_id.encode()).hexdigest()[:8]}"

    def update_conversation_history(self, state: MemoryState) -> MemoryState:
        """Append the user message to the history"""
        user_input = state["user_input"]

        # Add the message to the conversation history
        history_entry = {
            "role": "user",
            "content": user_input,
            "timestamp": datetime.now().isoformat()
        }
        state["conversation_history"].append(history_entry)

        # Record the interaction time
        state["last_interaction_time"] = datetime.now().isoformat()

        print(f"[memory update] user input appended, history length: {len(state['conversation_history'])}")
        return state

    def generate_summary(self, state: MemoryState) -> MemoryState:
        """Summarize the conversation"""
        if len(state["conversation_history"]) >= 3:  # summarize once there are at least 3 messages
            # Take the most recent messages
            recent_messages = state["conversation_history"][-5:]  # last 5

            # Build the conversation text
            conversation_text = ""
            for msg in recent_messages:
                role = "User" if msg["role"] == "user" else "Assistant"
                conversation_text += f"{role}: {msg['content']}\n"

            # Use the LLM to produce a summary
            llm = ChatOpenAI(temperature=0.3)

            summary_prompt = f"""Summarize the key points of the following conversation:

{conversation_text}

Summary (50 words or fewer):"""

            try:
                response = llm.invoke(summary_prompt)
                state["summary"] = response.content
                print(f"[summary] new summary: {state['summary'][:50]}...")
            except Exception as e:
                print(f"[summary] error: {e}")

        return state

    def update_user_preferences(self, state: MemoryState) -> MemoryState:
        """Update the stored user preferences"""
        user_input = state["user_input"].lower()
        user_id = self.get_user_id(user_input)

        # Initialize this user's record
        if user_id not in self.user_preferences_db:
            self.user_preferences_db[user_id] = {
                "topics": set(),
                "interaction_count": 0,
                "last_topics": []
            }

        # Detect topics the user mentions
        topics_of_interest = ["weather", "time", "news", "music", "movies", "sports", "tech", "food"]

        detected_topics = []
        for topic in topics_of_interest:
            if topic in user_input:
                detected_topics.append(topic)
                self.user_preferences_db[user_id]["topics"].add(topic)

        # Bump the interaction counter
        self.user_preferences_db[user_id]["interaction_count"] += 1

        # Record the most recent topics
        if detected_topics:
            self.user_preferences_db[user_id]["last_topics"].extend(detected_topics)
            # Keep only the 10 most recent topics
            self.user_preferences_db[user_id]["last_topics"] = self.user_preferences_db[user_id]["last_topics"][-10:]

        # Reflect the preferences into the state
        state["user_preferences"] = {
            "user_id": user_id,
            "preferred_topics": list(self.user_preferences_db[user_id]["topics"]),
            "interaction_count": self.user_preferences_db[user_id]["interaction_count"],
            "recent_topics": self.user_preferences_db[user_id]["last_topics"][-3:]  # last 3 topics
        }

        print(f"[preferences] user {user_id} preferences: {state['user_preferences']['preferred_topics']}")
        return state

    def retrieve_context(self, state: MemoryState) -> MemoryState:
        """Assemble relevant context"""
        user_input = state["user_input"]
        user_id = self.get_user_id(user_input)

        # Initialize the context dict
        if "context" not in state:
            state["context"] = {}

        # Add basic context
        state["context"].update({
            "current_time": datetime.now().isoformat(),
            "conversation_length": len(state["conversation_history"]),
            "has_history": len(state["conversation_history"]) > 0
        })

        # If there is history, include the most recent messages
        if state["conversation_history"]:
            recent_messages = state["conversation_history"][-3:]  # last 3
            state["context"]["recent_messages"] = [
                {"role": msg["role"], "content": msg["content"][:50]}
                for msg in recent_messages
            ]

        # If we know the user's preferences, include them too
        if state.get("user_preferences"):
            prefs = state["user_preferences"]
            state["context"]["user_info"] = {
                "preferred_topics": prefs.get("preferred_topics", []),
                "interaction_count": prefs.get("interaction_count", 0)
            }

        print(f"[context] context keys: {list(state['context'].keys())}")
        return state

    def generate_response_with_memory(self, state: MemoryState) -> MemoryState:
        """Generate a reply informed by memory"""
        llm = ChatOpenAI(temperature=0.7, model="gpt-3.5-turbo")

        # Build a prompt that incorporates the memory
        prompt_parts = []

        # 1. System prompt
        system_prompt = "You are a helpful AI assistant. Use the conversation history and user preferences to personalize your answers."
        prompt_parts.append(f"System: {system_prompt}")

        # 2. Conversation summary
        if state.get("summary"):
            prompt_parts.append(f"Conversation summary: {state['summary']}")

        # 3. Context
        if state.get("context"):
            context = state["context"]
            prompt_parts.append("Context:")
            prompt_parts.append(f"- current conversation length: {context.get('conversation_length', 0)} messages")

            if context.get("user_info"):
                user_info = context["user_info"]
                if user_info.get("preferred_topics"):
                    prompt_parts.append(f"- topics the user cares about: {', '.join(user_info['preferred_topics'])}")
                if user_info.get("interaction_count", 0) > 1:
                    prompt_parts.append(f"- this is interaction #{user_info['interaction_count']} with the user")

        # 4. Current user input
        prompt_parts.append(f"User: {state['user_input']}")
        prompt_parts.append("Assistant:")

        full_prompt = "\n".join(prompt_parts)

        # Generate the reply
        response = llm.invoke(full_prompt)

        # Append the assistant message to the history
        assistant_entry = {
            "role": "assistant",
            "content": response.content,
            "timestamp": datetime.now().isoformat()
        }
        state["conversation_history"].append(assistant_entry)

        # Set the reply
        state["ai_response"] = response.content

        print(f"[response] generated with memory, prompt length: {len(full_prompt)} chars")
        return state

def build_memory_graph():
    """Build the memory-enhanced graph"""
    agent = MemoryEnhancedAgent()

    workflow = StateGraph(MemoryState)

    # Add nodes
    workflow.add_node("update_history", agent.update_conversation_history)
    workflow.add_node("update_preferences", agent.update_user_preferences)
    workflow.add_node("retrieve_context", agent.retrieve_context)
    workflow.add_node("generate_summary", agent.generate_summary)
    workflow.add_node("generate_response", agent.generate_response_with_memory)

    # Wire the main flow
    workflow.add_edge(START, "update_history")
    workflow.add_edge("update_history", "update_preferences")
    workflow.add_edge("update_preferences", "retrieve_context")

    # Decide whether to summarize based on the history length
    def should_summarize(state: MemoryState) -> str:
        """Summarize every third message once there are at least three"""
        if len(state["conversation_history"]) >= 3 and len(state["conversation_history"]) % 3 == 0:
            return "generate_summary"
        else:
            return "generate_response"

    workflow.add_conditional_edges(
        "retrieve_context",
        should_summarize,
        {
            "generate_summary": "generate_summary",
            "generate_response": "generate_response"
        }
    )

    workflow.add_edge("generate_summary", "generate_response")
    workflow.add_edge("generate_response", END)

    return workflow.compile()

# Test the memory system
if __name__ == "__main__":
    memory_graph = build_memory_graph()

    # Simulated conversation
    test_conversations = [
        "Hello, how is the weather today?",
        "I like music and movies",
        "Can you recommend some good movies?",
        "I'm also interested in tech news",
        "Any tech news lately?"
    ]

    initial_state = {
        "user_input": "",
        "conversation_history": [],
        "summary": "",
        "context": {},
        "user_preferences": {},
        "last_interaction_time": "",
        "ai_response": ""
    }

    current_state = initial_state

    for i, user_input in enumerate(test_conversations, 1):
        print(f"\n{'='*50}")
        print(f"Turn {i}: {user_input}")
        print(f"{'='*50}")

        current_state["user_input"] = user_input
        result = memory_graph.invoke(current_state)

        print(f"\nAssistant: {result['ai_response'][:100]}...")
        print(f"history length: {len(result['conversation_history'])}")
        print(f"user preferences: {result.get('user_preferences', {}).get('preferred_topics', [])}")

        if result.get("summary"):
            print(f"summary: {result['summary'][:50]}...")

        # Carry the state forward to the next turn
        current_state = result

    print(f"\n{'='*50}")
    print("Conversation over")
    print(f"final history: {len(current_state['conversation_history'])} messages")
    print(f"preferred topics: {current_state.get('user_preferences', {}).get('preferred_topics', [])}")

2.4 Human-in-the-Loop Interaction

Human interaction is essential in real-world AI applications. LangGraph provides several mechanisms for human-in-the-loop patterns, allowing an AI system to pause execution and wait for human input when needed.

2.4.1 Interrupts and Resumption

LangGraph can interrupt execution at specific points in a workflow, wait for human input, and then resume. This is especially useful for scenarios requiring manual approval, clarification, or additional information.

# Example 6: human-in-the-loop interaction and interrupts
from typing import TypedDict, Literal
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver

class InteractionState(TypedDict):
    """Interaction state"""
    user_query: str
    current_step: str
    requires_human_input: bool
    human_input: str
    steps_completed: list
    response: str
    validation_result: str
    is_approved: bool

def receive_query_node(state: InteractionState) -> InteractionState:
    """Receive the user query"""
    print(f"[receive query] user query: {state['user_query']}")
    state["current_step"] = "query_received"
    state["steps_completed"].append("received query")
    return state

def analyze_query_node(state: InteractionState) -> InteractionState:
    """Analyze the query"""
    query = state["user_query"].lower()

    # Detect situations that require human intervention
    needs_human = False
    reason = ""

    if "sensitive" in query or "confidential" in query:
        needs_human = True
        reason = "query involves sensitive information"
    elif "admin" in query:
        needs_human = True
        reason = "administrator privileges required"
    elif len(query.split()) > 50:  # very long query
        needs_human = True
        reason = "query too long, needs simplification"
    elif "approve" in query:
        needs_human = True
        reason = "human approval required"

    state["requires_human_input"] = needs_human
    if needs_human:
        state["response"] = f"Human intervention needed: {reason}"
        print(f"[analyze query] human intervention needed: {reason}")
    else:
        state["response"] = "Query analyzed; processing can continue"
        print("[analyze query] analysis complete, no human intervention needed")

    state["current_step"] = "query_analyzed"
    state["steps_completed"].append("analyzed query")
    return state

def wait_for_human_node(state: InteractionState) -> InteractionState:
    """Wait for human input"""
    print(f"[wait for human] current status: {state['response']}")

    # A real application would pause here and wait for actual input
    # (e.g. via interrupt_before); for this example we simulate it
    if not state.get("human_input"):
        print("[wait for human] simulating human input...")

        # Simulate different review outcomes
        if "sensitive" in state["user_query"].lower():
            state["human_input"] = "Reviewed; may continue but must be logged"
            state["is_approved"] = True
        elif "admin" in state["user_query"].lower():
            state["human_input"] = "Privileges granted"
            state["is_approved"] = True
        elif "approve" in state["user_query"].lower():
            state["human_input"] = "Approved"
            state["is_approved"] = True
        else:
            state["human_input"] = "Reviewed; continue processing"
            state["is_approved"] = True

    state["current_step"] = "human_input_received"
    state["steps_completed"].append("received human input")
    return state

def process_with_approval_node(state: InteractionState) -> InteractionState:
    """Process after approval"""
    if state.get("is_approved", False):
        state["response"] = f"Approved: {state['human_input']}. Continuing with query: {state['user_query'][:50]}..."
        print("[process] approved, continuing")
    else:
        state["response"] = "Not approved; processing aborted"
        print("[process] not approved, aborting")

    state["current_step"] = "processed_with_approval"
    state["steps_completed"].append("processed after approval")
    return state

def process_automatically_node(state: InteractionState) -> InteractionState:
    """Process automatically"""
    state["response"] = f"Automatically processing query: {state['user_query'][:50]}..."
    print("[auto process] processing query")

    state["current_step"] = "processed_automatically"
    state["steps_completed"].append("processed automatically")
    return state

def validate_result_node(state: InteractionState) -> InteractionState:
    """Validate the result"""
    # Simulated validation
    if "error" in state["response"].lower() or "failed" in state["response"].lower():
        state["validation_result"] = "validation failed"
        state["requires_human_input"] = True  # failed validation needs human intervention
        print("[validate] validation failed, human intervention needed")
    else:
        state["validation_result"] = "validation succeeded"
        print("[validate] validation succeeded")

    state["current_step"] = "result_validated"
    state["steps_completed"].append("validated result")
    return state

def build_interaction_graph():
    """Build the interactive graph"""
    workflow = StateGraph(InteractionState)

    # Add all nodes
    workflow.add_node("receive_query", receive_query_node)
    workflow.add_node("analyze_query", analyze_query_node)
    workflow.add_node("wait_for_human", wait_for_human_node)
    workflow.add_node("process_with_approval", process_with_approval_node)
    workflow.add_node("process_automatically", process_automatically_node)
    workflow.add_node("validate_result", validate_result_node)

    # Set the entry point
    workflow.set_entry_point("receive_query")

    # Add edges
    workflow.add_edge("receive_query", "analyze_query")

    # Conditional edge: route based on whether human input is required
    def route_by_analysis(state: InteractionState) -> Literal["wait_for_human", "process_automatically"]:
        if state["requires_human_input"]:
            return "wait_for_human"
        else:
            return "process_automatically"

    workflow.add_conditional_edges(
        "analyze_query",
        route_by_analysis,
        {
            "wait_for_human": "wait_for_human",
            "process_automatically": "process_automatically"
        }
    )
    
    # 人工输入后继续处理
    workflow.add_edge("wait_for_human", "process_with_approval")
    
    # 处理完成后验证
    workflow.add_edge("process_with_approval", "validate_result")
    workflow.add_edge("process_automatically", "validate_result")
    
    # 验证后根据结果路由
    def route_after_validation(state: InteractionState) -> Literal["wait_for_human", "__end__"]:
        if state["requires_human_input"] and state["validation_result"] == "验证失败":
            return "wait_for_human"
        else:
            return "__end__"
    
    workflow.add_conditional_edges(
        "validate_result",
        route_after_validation,
        {
            "wait_for_human": "wait_for_human",
            "__end__": END
        }
    )
    
    # 使用检查点实现中断/恢复
    memory = MemorySaver()
    graph = workflow.compile(checkpointer=memory)
    
    return graph, memory

# 模拟交互式执行
async def simulate_human_in_the_loop():
    """模拟人在循环的交互"""
    graph, memory = build_interaction_graph()
    
    # 测试用例
    test_cases = [
        {
            "query": "我想查询今天的天气",
            "description": "普通查询,无需人工干预"
        },
        {
            "query": "我需要访问敏感数据报告",
            "description": "敏感查询,需要人工批准"
        },
        {
            "query": "请批准我的管理员权限申请",
            "description": "权限申请,需要人工批准"
        }
    ]
    
    for i, test_case in enumerate(test_cases, 1):
        print(f"\n{'='*60}")
        print(f"测试用例 {i}: {test_case['description']}")
        print(f"查询: {test_case['query']}")
        print(f"{'='*60}")
        
        # 初始状态
        initial_state = {
            "user_query": test_case["query"],
            "current_step": "",
            "requires_human_input": False,
            "human_input": "",
            "steps_completed": [],
            "response": "",
            "validation_result": "",
            "is_approved": False
        }
        
        # 配置
        thread_id = f"test_thread_{i}"
        config = {"configurable": {"thread_id": thread_id}}
        
        # 执行图
        print("\n开始执行工作流...")
        
        # 流式执行,可以看到每一步
        async for event in graph.astream(initial_state, config, stream_mode="values"):
            print(f"\n[步骤更新] 当前步骤: {event['current_step']}")
            print(f"[步骤更新] 响应: {event['response']}")
            print(f"[步骤更新] 完成步骤: {event['steps_completed'][-1] if event['steps_completed'] else '无'}")
            
            # 模拟人工输入延迟
            if event["current_step"] == "human_input_received":
                print("[模拟交互] 模拟人类审核中...")
                await asyncio.sleep(1)
        
        # 获取最终结果
        print(f"\n测试用例 {i} 完成")
        print(f"最终响应: {event['response']}")
        print(f"总步骤: {', '.join(event['steps_completed'])}")
        
        # 显示检查点
        checkpoints = list(memory.list(config))
        print(f"创建的检查点: {len(checkpoints)} 个")

# 运行示例
if __name__ == "__main__":
    asyncio.run(simulate_human_in_the_loop())

2.5 状态管理

状态管理是LangGraph的核心特性之一。它提供了一个类型安全、可组合的状态管理系统,使得在复杂的、有状态的工作流中跟踪和管理数据变得简单而直观。

2.5.1 状态结构与注解

LangGraph使用类型化的状态字典和Python的Annotated类型来定义状态结构。这种设计提供了类型安全性和良好的开发体验。
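在进入完整示例之前,先用一个纯 Python 的最小片段模拟归约函数(reducer)的合并行为。这里手动调用归约函数仅作示意,实际合并由 LangGraph 在图执行时自动完成:

```python
import operator

# 示意:LangGraph 合并状态更新时,会对带注解的字段调用 reducer(旧值, 新值)
merge_list = operator.add                      # Annotated[list, operator.add]:追加列表
merge_dict = lambda old, new: {**old, **new}   # 合并字典
replace_value = lambda old, new: new           # 总是取新值

old_state = {"steps": ["创建"], "meta": {"a": 1}, "updated_at": "t1"}
update = {"steps": ["验证"], "meta": {"b": 2}, "updated_at": "t2"}

new_state = {
    "steps": merge_list(old_state["steps"], update["steps"]),
    "meta": merge_dict(old_state["meta"], update["meta"]),
    "updated_at": replace_value(old_state["updated_at"], update["updated_at"]),
}
print(new_state)
# {'steps': ['创建', '验证'], 'meta': {'a': 1, 'b': 2}, 'updated_at': 't2'}
```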

# 示例7:高级状态管理
from typing import TypedDict, Annotated
from typing_extensions import TypedDict
import operator
from enum import Enum
from datetime import datetime
from langgraph.graph import StateGraph, END

# 定义状态枚举
class TaskStatus(Enum):
    PENDING = "pending"
    PROCESSING = "processing"
    WAITING_FOR_INPUT = "waiting_for_input"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"

class TaskPriority(Enum):
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"

# 定义复杂的状态结构
class TaskState(TypedDict):
    """任务状态定义"""
    # 基本任务信息
    task_id: str
    task_name: str
    description: str
    created_at: str
    updated_at: Annotated[str, lambda old, new: new]  # 总是更新为新值
    status: TaskStatus
    priority: TaskPriority
    
    # 进度跟踪
    current_step: int
    total_steps: int
    progress_percentage: float
    steps_completed: Annotated[list, operator.add]  # 使用追加操作
    
    # 数据流
    input_data: dict
    intermediate_results: Annotated[dict, lambda old, new: {**old, **new}]  # 合并字典
    final_result: dict
    
    # 错误处理
    errors: Annotated[list, operator.add]
    retry_count: int
    max_retries: int
    
    # 元数据
    metadata: Annotated[dict, lambda old, new: {**old, **new}]
    tags: Annotated[set, lambda old, new: old.union(new) if isinstance(new, set) else old.union({new})]  # 合并集合

class TaskManager:
    """任务管理器"""
    
    def __init__(self):
        self.task_counter = 0
    
    def generate_task_id(self) -> str:
        """生成任务ID"""
        self.task_counter += 1
        return f"task_{self.task_counter:06d}"
    
    def create_task_node(self, state: TaskState) -> TaskState:
        """创建任务节点"""
        if not state.get("task_id"):
            state["task_id"] = self.generate_task_id()
        
        current_time = datetime.now().isoformat()
        
        # 设置默认值
        defaults = {
            "created_at": current_time,
            "updated_at": current_time,
            "status": TaskStatus.PENDING,
            "priority": TaskPriority.MEDIUM,
            "current_step": 0,
            "total_steps": 5,  # 默认5个步骤
            "progress_percentage": 0.0,
            "steps_completed": [],
            "input_data": {},
            "intermediate_results": {},
            "final_result": {},
            "errors": [],
            "retry_count": 0,
            "max_retries": 3,
            "metadata": {},
            "tags": set()
        }
        
        # 应用默认值
        for key, value in defaults.items():
            if key not in state or not state[key]:
                state[key] = value
        
        # 添加创建步骤
        state["steps_completed"].append("任务创建")
        state["current_step"] = 1
        state["progress_percentage"] = state["current_step"] / state["total_steps"] * 100
        
        print(f"[创建任务] 任务ID: {state['task_id']}, 名称: {state.get('task_name', '未命名')}")
        return state
    
    def validate_input_node(self, state: TaskState) -> TaskState:
        """验证输入节点"""
        state["updated_at"] = datetime.now().isoformat()
        state["status"] = TaskStatus.PROCESSING
        
        # 验证逻辑
        errors = []
        
        if not state.get("task_name"):
            errors.append("任务名称不能为空")
        
        if not state.get("description"):
            errors.append("任务描述不能为空")
        
        if state.get("priority") not in [p.value for p in TaskPriority]:
            errors.append(f"无效的优先级: {state.get('priority')}")
        
        if errors:
            state["errors"].extend(errors)
            state["status"] = TaskStatus.FAILED
            print(f"[验证输入] 验证失败: {errors}")
        else:
            state["steps_completed"].append("输入验证")
            state["current_step"] = 2
            state["progress_percentage"] = state["current_step"] / state["total_steps"] * 100
            print(f"[验证输入] 验证成功")
        
        return state
    
    def process_data_node(self, state: TaskState) -> TaskState:
        """处理数据节点"""
        state["updated_at"] = datetime.now().isoformat()
        
        if state["status"] == TaskStatus.FAILED:
            print(f"[处理数据] 跳过,因为任务失败")
            return state
        
        try:
            # 模拟数据处理
            print(f"[处理数据] 处理任务: {state['task_name']}")
            
            # 添加中间结果
            state["intermediate_results"]["processed_at"] = state["updated_at"]
            state["intermediate_results"]["step"] = "data_processing"
            
            # 模拟处理时间
            import time
            time.sleep(0.5)
            
            state["steps_completed"].append("数据处理")
            state["current_step"] = 3
            state["progress_percentage"] = state["current_step"] / state["total_steps"] * 100
            
        except Exception as e:
            state["errors"].append(f"数据处理错误: {str(e)}")
            state["status"] = TaskStatus.FAILED
        
        return state
    
    def analyze_results_node(self, state: TaskState) -> TaskState:
        """分析结果节点"""
        state["updated_at"] = datetime.now().isoformat()
        
        if state["status"] == TaskStatus.FAILED:
            print(f"[分析结果] 跳过,因为任务失败")
            return state
        
        try:
            # 模拟结果分析
            print(f"[分析结果] 分析任务结果")
            
            # 添加分析结果
            state["intermediate_results"]["analyzed_at"] = state["updated_at"]
            state["intermediate_results"]["step"] = "analysis"
            
            # 基于优先级的模拟分析
            if state["priority"] == TaskPriority.HIGH.value or state["priority"] == TaskPriority.CRITICAL.value:
                state["intermediate_results"]["analysis"] = "深度分析完成"
            else:
                state["intermediate_results"]["analysis"] = "基础分析完成"
            
            state["steps_completed"].append("结果分析")
            state["current_step"] = 4
            state["progress_percentage"] = state["current_step"] / state["total_steps"] * 100
            
        except Exception as e:
            state["errors"].append(f"结果分析错误: {str(e)}")
            state["status"] = TaskStatus.FAILED
        
        return state
    
    def generate_output_node(self, state: TaskState) -> TaskState:
        """生成输出节点"""
        state["updated_at"] = datetime.now().isoformat()
        
        if state["status"] == TaskStatus.FAILED:
            print(f"[生成输出] 跳过,因为任务失败")
            return state
        
        try:
            # 生成最终输出
            print(f"[生成输出] 生成最终结果")
            
            state["final_result"] = {
                "task_id": state["task_id"],
                "task_name": state["task_name"],
                "status": state["status"].value,
                "completed_steps": state["steps_completed"],
                "progress": f"{state['progress_percentage']:.1f}%",
                "intermediate_results": state["intermediate_results"],
                "completed_at": state["updated_at"],
                "has_errors": len(state["errors"]) > 0,
                "error_count": len(state["errors"])
            }
            
            if state["errors"]:
                state["final_result"]["errors"] = state["errors"]
            
            state["steps_completed"].append("生成输出")
            state["current_step"] = 5
            state["progress_percentage"] = state["current_step"] / state["total_steps"] * 100
            state["status"] = TaskStatus.COMPLETED
            
        except Exception as e:
            state["errors"].append(f"生成输出错误: {str(e)}")
            state["status"] = TaskStatus.FAILED
        
        return state
    
    def handle_errors_node(self, state: TaskState) -> TaskState:
        """处理错误节点"""
        state["updated_at"] = datetime.now().isoformat()
        
        if state["status"] != TaskStatus.FAILED:
            print(f"[处理错误] 无错误需要处理")
            return state
        
        print(f"[处理错误] 处理任务错误")
        
        # 重试逻辑
        state["retry_count"] += 1
        
        if state["retry_count"] <= state["max_retries"]:
            print(f"[处理错误] 重试 {state['retry_count']}/{state['max_retries']}")
            state["status"] = TaskStatus.PENDING
            state["current_step"] = 1  # 重置到第一步
            state["progress_percentage"] = 0.0
            # 清除之前的错误(或保留用于日志)
            # state["errors"] = []
        else:
            print(f"[处理错误] 达到最大重试次数 {state['max_retries']}")
            state["status"] = TaskStatus.FAILED
            state["final_result"] = {
                "task_id": state["task_id"],
                "status": "failed",
                "error_count": len(state["errors"]),
                "errors": state["errors"],
                "retry_count": state["retry_count"]
            }
        
        state["steps_completed"].append("错误处理")
        return state

def build_state_management_graph():
    """构建状态管理图"""
    task_manager = TaskManager()
    
    workflow = StateGraph(TaskState)
    
    # 添加节点
    workflow.add_node("create_task", task_manager.create_task_node)
    workflow.add_node("validate_input", task_manager.validate_input_node)
    workflow.add_node("process_data", task_manager.process_data_node)
    workflow.add_node("analyze_results", task_manager.analyze_results_node)
    workflow.add_node("generate_output", task_manager.generate_output_node)
    workflow.add_node("handle_errors", task_manager.handle_errors_node)
    
    # 设置入口点
    workflow.set_entry_point("create_task")
    
    # 主流程边:只显式连接入口后的第一步;
    # validate_input 之后的流转全部由下面的条件边决定,
    # 避免普通边与条件边并存导致分支被重复触发
    workflow.add_edge("create_task", "validate_input")
    
    # 条件边:根据状态路由
    def route_by_status(state: TaskState) -> str:
        """根据状态路由"""
        if state["status"] == TaskStatus.FAILED:
            return "handle_errors"
        elif state["status"] == TaskStatus.COMPLETED:
            return "__end__"
        else:
            # 继续下一个节点
            if state["current_step"] == 1:
                return "validate_input"
            elif state["current_step"] == 2:
                return "process_data"
            elif state["current_step"] == 3:
                return "analyze_results"
            elif state["current_step"] == 4:
                return "generate_output"
            else:
                return "__end__"
    
    # 为每个节点添加条件边
    for node in ["validate_input", "process_data", "analyze_results"]:
        workflow.add_conditional_edges(
            node,
            route_by_status,
            {
                "handle_errors": "handle_errors",
                "validate_input": "validate_input",
                "process_data": "process_data",
                "analyze_results": "analyze_results",
                "generate_output": "generate_output",
                "__end__": END
            }
        )
    
    # 错误处理后的路由
    def route_after_error_handling(state: TaskState) -> str:
        """错误处理后的路由"""
        if state["status"] == TaskStatus.PENDING:
            # 重试,回到验证输入
            return "validate_input"
        else:
            # 失败,结束
            return "__end__"
    
    workflow.add_conditional_edges(
        "handle_errors",
        route_after_error_handling,
        {
            "validate_input": "validate_input",
            "__end__": END
        }
    )
    
    # 生成输出后的路由
    workflow.add_conditional_edges(
        "generate_output",
        route_by_status,
        {
            "handle_errors": "handle_errors",
            "__end__": END
        }
    )
    
    return workflow.compile()

# 测试状态管理
if __name__ == "__main__":
    state_graph = build_state_management_graph()
    
    # 测试用例
    test_tasks = [
        {
            "name": "有效数据处理任务",
            "description": "处理用户数据并生成报告",
            "priority": TaskPriority.HIGH.value,
            "should_fail": False
        },
        {
            "name": "",
            "description": "无效任务(缺少名称)",
            "priority": TaskPriority.MEDIUM.value,
            "should_fail": True
        },
        {
            "name": "高优先级分析任务",
            "description": "执行深度数据分析",
            "priority": TaskPriority.CRITICAL.value,
            "should_fail": False
        }
    ]
    
    for i, task_config in enumerate(test_tasks, 1):
        print(f"\n{'='*60}")
        print(f"执行任务 {i}: {task_config['name'] or '未命名任务'}")
        print(f"描述: {task_config['description']}")
        print(f"优先级: {task_config['priority']}")
        print(f"{'='*60}")
        
        # 准备状态
        initial_state = {
            "task_name": task_config["name"],
            "description": task_config["description"],
            "priority": task_config["priority"],
            "input_data": {
                "source": "test_data",
                "records": 1000
            },
            "tags": {"test", f"priority_{task_config['priority']}"},
            "metadata": {
                "test_id": i,
                "created_by": "state_management_test"
            }
        }
        
        if task_config["should_fail"]:
            # 模拟会导致失败的状态
            initial_state["task_name"] = ""  # 空名称会导致验证失败
        
        # 执行图
        result = state_graph.invoke(initial_state)
        
        # 显示结果
        print(f"\n任务结果:")
        print(f"  任务ID: {result.get('task_id', 'N/A')}")
        print(f"  最终状态: {result.get('status', 'N/A')}")
        print(f"  进度: {result.get('progress_percentage', 0):.1f}%")
        print(f"  完成步骤: {len(result.get('steps_completed', []))}/{result.get('total_steps', 0)}")
        print(f"  重试次数: {result.get('retry_count', 0)}")
        
        if result.get("errors"):
            print(f"  错误: {len(result['errors'])} 个")
            for j, error in enumerate(result["errors"][:3], 1):  # 显示前3个错误
                print(f"    {j}. {error}")
            if len(result["errors"]) > 3:
                print(f"    ... 还有 {len(result['errors']) - 3} 个错误")
        
        if result.get("final_result"):
            print(f"  最终结果键: {list(result['final_result'].keys())}")
        
        print(f"  标签: {result.get('tags', set())}")

三、课后练习

3.1 选择题

  1. 在LangGraph中,哪个组件用于定义工作流中不同步骤的执行顺序?
    A. 节点
    B. 边
    C. 状态
    D. 工具

  2. 以下哪个不是LangGraph内置的持久化存储后端?
    A. MemorySaver
    B. FileCheckpointSaver
    C. RedisCheckpointSaver
    D. 以上都是内置的

  3. 在LangGraph中,如何实现工作流的中断和恢复?
    A. 使用中断函数
    B. 通过检查点机制
    C. 使用异常处理
    D. 无法实现

  4. 以下哪个工具调用模式允许代理在需要时自动选择并执行工具?
    A. 手动工具调用
    B. 自动工具调用
    C. 条件工具调用
    D. 循环工具调用

  5. LangGraph中,哪种记忆类型主要用于保持对话的长期上下文?
    A. 短期记忆
    B. 工作记忆
    C. 长期记忆
    D. 瞬时记忆

  6. 在状态管理中,使用Annotated[str, operator.add]注解表示什么?
    A. 字符串类型的状态字段
    B. 可以追加操作的字符串字段
    C. 数值累加字段
    D. 只读字符串字段

  7. LangGraph中人机交互的主要实现方式是什么?
    A. 通过API调用
    B. 通过中断和恢复机制
    C. 通过消息队列
    D. 通过数据库

  8. 以下哪个是LangGraph状态管理的特点?
    A. 类型安全
    B. 不可变性
    C. 全局共享
    D. 无状态

答案:

  1. B
  2. C(RedisCheckpointSaver不是内置的,但可以实现自定义)
  3. B
  4. B
  5. C
  6. B
  7. B
  8. A

3.2 填空题

  1. LangGraph 使用 ______ 和 ______ 来构建工作流,其中节点代表处理步骤,边定义步骤之间的流转逻辑。
  2. 在LangGraph中,______ 允许代理记住先前的交互,保持对话上下文。
  3. ______ 机制允许工作流在中断后恢复执行,这对于长时间运行的任务特别重要。
  4. LangGraph的 ______ 系统提供了类型安全、可组合的状态管理。
  5. 通过 ______ 功能,LangGraph代理可以执行搜索、计算、API调用等实际操作。
  6. 在条件边中,通过返回不同的 ______ 值来决定下一步执行哪个节点。
  7. ______ 类型的记忆通常保持在单个工作流执行期间,而 ______ 记忆可以跨多个会话保持。

答案:

  1. 节点,边
  2. 记忆
  3. 检查点(或持久性)
  4. 状态管理
  5. 调用工具
  6. 字符串(或路由键)
  7. 短期,长期

3.3 简答题

  1. 简述LangGraph中状态管理的工作原理。

答案:
LangGraph使用类型化的状态字典和Python的Annotated类型来定义状态结构。状态在整个图中传递,每个节点可以读取和修改状态。状态更新通过定义在Annotated中的归约函数来处理,这些函数指定了如何合并状态更新。例如:

  • Annotated[list, operator.add]:使用operator.add追加列表
  • Annotated[dict, lambda old, new: {**old, **new}]:合并字典
  • Annotated[str, lambda old, new: new]:总是使用新值替换旧值
    这种设计提供了类型安全性和可预测的状态更新,使得复杂状态管理变得简单直观。
  2. 解释LangGraph中持久性机制的重要性及其实现方式。

答案: 持久性机制允许工作流在中断后恢复执行,这对于长时间运行的任务、对话系统和需要处理中断的应用程序至关重要。实现方式包括:

  • 检查点(Checkpoints):保存图的完整状态,包括所有变量的值和当前的执行位置
  • 检查点存储后端:如MemorySaver(内存存储)、SqliteSaver(SQLite文件存储)等
  • 线程ID:用于标识不同的执行会话
    持久性使得工作流可以在意外中断(如网络故障、系统崩溃)后从中断点继续执行,提高了系统的可靠性和用户体验。
  3. 描述LangGraph中工具调用的基本流程。

答案: LangGraph中工具调用的基本流程包括以下步骤:

  1. 代理接收用户输入并分析请求
  2. 决策是否调用工具以及调用哪个工具
  3. 准备工具调用参数
  4. 执行工具调用并获取结果
  5. 根据工具结果生成响应或决定下一步操作
  6. 工具调用可以通过条件边和循环实现复杂的多工具调用流程
  4. 解释LangGraph中人机交互的实现方式及其应用场景。

答案: LangGraph中的人机交互主要通过中断和恢复机制实现:

  • 在工作流的特定点,可以暂停执行并等待人类输入
  • 通过检查点保存当前状态
  • 在获得人工输入后,从检查点恢复执行

    应用场景包括:
  • 需要人工审批的流程
  • 需要人工澄清或确认的操作
  • 敏感操作需要人工监督
  • 复杂决策需要人类专家介入

3.4 实操题

题目: 创建一个简单的LangGraph工作流,实现一个天气查询代理。要求:

  1. 工作流包含至少3个节点:接收用户输入、调用天气API、生成响应
  2. 使用条件边处理无效输入
  3. 实现简单的错误处理
  4. 包含记忆功能,记住用户的查询历史
  5. 实现工具调用来获取天气信息

参考实现:

from typing import TypedDict
from langgraph.graph import StateGraph, END
from datetime import datetime
from langchain.tools import tool

class WeatherState(TypedDict):
    """天气查询状态"""
    user_input: str
    city: str
    weather_data: dict
    response: str
    has_error: bool
    error_message: str
    query_history: list  # 查询历史(节点内直接 append 追加;若改用 Annotated[list, operator.add] 归约,节点应只返回增量而非完整状态,否则会重复合并)
    last_updated: str

# 天气查询工具
@tool
def get_weather_tool(city: str) -> str:
    """获取城市天气信息"""
    # 模拟天气API调用
    mock_weather_data = {
        "北京": {"weather": "晴", "temperature": "20-25°C", "humidity": "40%", "wind": "微风"},
        "上海": {"weather": "小雨", "temperature": "22-26°C", "humidity": "85%", "wind": "东风3级"},
        "广州": {"weather": "雷阵雨", "temperature": "25-30°C", "humidity": "90%", "wind": "南风2级"},
        "深圳": {"weather": "阵雨", "temperature": "24-29°C", "humidity": "88%", "wind": "东南风3级"},
        "杭州": {"weather": "多云", "temperature": "19-24°C", "humidity": "65%", "wind": "北风2级"}
    }
    
    if city in mock_weather_data:
        weather = mock_weather_data[city]
        return f"{city}天气: {weather['weather']},温度{weather['temperature']},湿度{weather['humidity']}{weather['wind']}"
    else:
        return f"抱歉,找不到{city}的天气信息"

def receive_input_node(state: WeatherState) -> WeatherState:
    """接收用户输入节点"""
    user_input = state["user_input"]
    
    # 记录查询
    query_record = {
        "query": user_input,
        "timestamp": datetime.now().isoformat()
    }
    state["query_history"].append(query_record)
    
    # 简单解析城市名称
    if "天气" in user_input:
        # 尝试提取城市名称
        city = user_input.split("天气")[0].strip()
        if city:
            state["city"] = city
            print(f"[接收输入] 解析到城市: {city}")
        else:
            state["has_error"] = True
            state["error_message"] = "无法解析城市名称,请包含'天气'关键词,如'北京天气'"
    else:
        state["has_error"] = True
        state["error_message"] = "请输入包含'天气'的查询,如'北京天气怎么样?'"
    
    state["last_updated"] = datetime.now().isoformat()
    return state

def validate_input_node(state: WeatherState) -> WeatherState:
    """验证输入节点"""
    if state["has_error"]:
        return state
    
    city = state["city"]
    
    # 简单验证城市名称
    if not city or len(city) < 2:
        state["has_error"] = True
        state["error_message"] = f"城市名称'{city}'无效,请输入有效的城市名称"
    elif len(city) > 10:
        state["has_error"] = True
        state["error_message"] = f"城市名称'{city}'过长,请输入有效的城市名称"
    
    state["last_updated"] = datetime.now().isoformat()
    return state

def fetch_weather_node(state: WeatherState) -> WeatherState:
    """获取天气数据节点"""
    if state["has_error"]:
        return state
    
    city = state["city"]
    
    print(f"[获取天气] 查询城市: {city}")
    
    try:
        # 调用天气工具
        weather_info = get_weather_tool.invoke({"city": city})
        
        # 解析天气信息
        if "抱歉" in weather_info:
            state["has_error"] = True
            state["error_message"] = weather_info
        else:
            # 提取天气数据
            state["weather_data"] = {
                "city": city,
                "info": weather_info,
                "timestamp": datetime.now().isoformat()
            }
            print(f"[获取天气] 获取成功: {weather_info[:50]}...")
    
    except Exception as e:
        state["has_error"] = True
        state["error_message"] = f"获取天气信息时出错: {str(e)}"
        print(f"[获取天气] 错误: {str(e)}")
    
    state["last_updated"] = datetime.now().isoformat()
    return state

def generate_response_node(state: WeatherState) -> WeatherState:
    """生成响应节点"""
    if state["has_error"]:
        state["response"] = f"抱歉,{state['error_message']}"
    else:
        weather = state["weather_data"]
        state["response"] = weather["info"]
        
        # 添加历史记录信息
        history_count = len(state["query_history"])
        if history_count > 1:
            state["response"] += f"\n\n(这是您今天的第{history_count}次天气查询)"
    
    # 记录响应
    state["query_history"][-1]["response"] = state["response"]
    state["last_updated"] = datetime.now().isoformat()
    
    print(f"[生成响应] 响应: {state['response'][:100]}...")
    return state

def display_history_node(state: WeatherState) -> WeatherState:
    """显示历史节点(可选)"""
    if "历史" in state["user_input"] or "记录" in state["user_input"]:
        history = state["query_history"]
        if history:
            history_text = "查询历史:\n"
            for i, record in enumerate(history[-5:], 1):  # 显示最近5条
                query = record.get("query", "")
                response = record.get("response", "尚未回复")
                time = record.get("timestamp", "")
                if time:
                    time = datetime.fromisoformat(time).strftime("%H:%M:%S")
                
                history_text += f"{i}. [{time}] {query}\n   回复: {response[:50]}...\n"
            
            state["response"] = history_text
        else:
            state["response"] = "暂无查询历史"
    
    return state

def build_weather_agent_graph():
    """构建天气查询代理图"""
    workflow = StateGraph(WeatherState)
    
    # 添加节点
    workflow.add_node("receive_input", receive_input_node)
    workflow.add_node("validate_input", validate_input_node)
    workflow.add_node("fetch_weather", fetch_weather_node)
    workflow.add_node("generate_response", generate_response_node)
    workflow.add_node("display_history", display_history_node)
    
    # 设置入口点
    workflow.set_entry_point("receive_input")
    
    # 条件边:根据用户输入决定是否显示历史
    def check_for_history(state: WeatherState) -> str:
        if "历史" in state["user_input"] or "记录" in state["user_input"]:
            return "display_history"
        else:
            return "validate_input"
    
    workflow.add_conditional_edges(
        "receive_input",
        check_for_history,
        {
            "display_history": "display_history",
            "validate_input": "validate_input"
        }
    )
    
    # 验证后的路由
    def route_after_validation(state: WeatherState) -> str:
        if state["has_error"]:
            return "generate_response"
        else:
            return "fetch_weather"
    
    workflow.add_conditional_edges(
        "validate_input",
        route_after_validation,
        {
            "fetch_weather": "fetch_weather",
            "generate_response": "generate_response"
        }
    )
    
    # 获取天气后生成响应
    workflow.add_edge("fetch_weather", "generate_response")
    
    # 显示历史后结束
    workflow.add_edge("display_history", END)
    
    # 生成响应后结束
    workflow.add_edge("generate_response", END)
    
    return workflow.compile()

# 测试天气查询代理
if __name__ == "__main__":
    weather_agent = build_weather_agent_graph()
    
    test_queries = [
        "北京天气怎么样?",
        "上海天气",
        "显示历史记录",
        " invalid city weather",
        "广州天气情况",
        "杭州天气"
    ]
    
    # 初始化状态
    state = {
        "user_input": "",
        "city": "",
        "weather_data": {},
        "response": "",
        "has_error": False,
        "error_message": "",
        "query_history": [],
        "last_updated": ""
    }
    
    print("=== 天气查询代理测试 ===")
    
    for i, query in enumerate(test_queries, 1):
        print(f"\n{'='*50}")
        print(f"测试 {i}: {query}")
        print(f"{'='*50}")
        
        # 更新用户输入
        state["user_input"] = query
        state["has_error"] = False
        state["error_message"] = ""
        
        # 执行工作流
        result = weather_agent.invoke(state)
        
        # 显示结果
        print(f"\n用户: {query}")
        print(f"助手: {result['response']}")
        
        # 更新状态
        state = result
    
    # 显示最终历史
    print(f"\n{'='*50}")
    print("最终查询历史:")
    for i, record in enumerate(state["query_history"], 1):
        query = record.get("query", "")
        response = record.get("response", "")[:50]
        if response:
            print(f"{i}. 查询: {query}")
            print(f"   回复: {response}...")

四、总结

LangGraph是一个强大的框架,专门用于构建有状态的、多参与者的AI应用工作流。通过图形化的方式组织计算流程,它支持循环、分支、并行和复杂的状态管理,使得开发者能够构建出更加智能和灵活的AI应用。

本文详细介绍了LangGraph的六大核心能力:

  1. 图形API:提供了直观的节点和边抽象,使得构建复杂工作流变得简单直观。通过条件边、循环和并行执行,LangGraph能够表达丰富的控制流逻辑。

  2. 持久性:通过检查点机制支持工作流的中断与恢复,这对于长时间运行的任务、用户会话和容错性至关重要的应用场景尤为关键。

  3. 记忆:使代理能够记住历史交互,实现上下文感知的对话和处理。支持短期记忆和长期记忆,以及用户偏好的学习和适应。

  4. 人机交互:支持人类在循环的交互模式,使AI系统能够在需要时请求人工干预,实现人机协同的智能流程。

  5. 状态管理:提供了类型安全、可组合的状态管理系统,简化了复杂状态的处理。通过注解系统,可以灵活定义状态字段的合并行为。

  6. 调用工具:允许AI代理无缝集成和调用外部工具,扩展了AI的能力边界。支持自动工具选择、参数提取和多工具协同工作。

通过这些核心能力的组合,LangGraph使得开发者能够构建出从前难以实现的复杂AI应用,如智能对话代理、自动化工作流、决策支持系统等。无论是初级开发者还是经验丰富的工程师,都能从LangGraph的抽象中受益,快速构建出强大、可靠的AI应用。


🌟 感谢您耐心阅读到这里!
🚀 技术成长没有捷径,但每一次的阅读、思考和实践,都在默默缩短您与成功的距离。
💡 如果本文对您有所启发,欢迎点赞👍、收藏📌、分享📤给更多需要的伙伴!
🗣️ 期待在评论区看到您的想法、疑问或建议,我会认真回复,让我们共同探讨、一起进步~
🔔 关注我,持续获取更多干货内容!
🤗 我们下篇文章见!
