A detailed guide to implementing custom middleware: core features, best practices, and hands-on examples, helping developers quickly master intercepting and extending the Agent execution flow.

1. What is LangChain middleware?

Middleware is an interceptor in the LangChain Agent execution flow. By implementing specific hooks, you can insert custom logic at key points of an Agent run, for example:

  • Logging and monitoring
  • Request/response transformation
  • Retries and caching
  • Permission checks and rate limiting
  • Dynamic model/tool switching

Middleware supports two core hook types, covering the full Agent execution lifecycle:

2. Core hook types (Hooks)

2.1 Node-style hooks

Triggered in execution order; suited to logging, validation, state updates, and similar tasks. The following hooks are supported:

  • before_agent: before the Agent starts (runs once per invocation)
  • before_model: before each model call
  • after_model: after each model response
  • after_agent: after the Agent finishes (runs once per invocation)

Implementation

Decorator pattern

from langchain.agents.middleware import before_model, after_model, AgentState
from langchain.messages import AIMessage
from langgraph.runtime import Runtime
from typing import Any

@before_model(can_jump_to=["end"])
def check_message_limit(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """检查对话消息数量限制"""
    if len(state["messages"]) >= 50:
        return {
            "messages": [AIMessage("Conversation limit reached.")],
            "jump_to": "end"  # 触发 Agent 提前结束
        }
    return None

@after_model
def log_response(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """记录模型响应日志"""
    print(f"Model returned: {state['messages'][-1].content}")
    return None

Class pattern

from langchain.agents.middleware import AgentMiddleware, AgentState, hook_config
from langchain.messages import AIMessage
from langgraph.runtime import Runtime
from typing import Any

class MessageLimitMiddleware(AgentMiddleware):
    def __init__(self, max_messages: int = 50):
        super().__init__()
        self.max_messages = max_messages

    @hook_config(can_jump_to=["end"])
    def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        if len(state["messages"]) >= self.max_messages:
            return {
                "messages": [AIMessage("Conversation limit reached.")],
                "jump_to": "end"
            }
        return None

    def after_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        print(f"Model returned: {state['messages'][-1].content}")
        return None

2.2 Wrap-style hooks

Wrap around model/tool calls and can control the call flow (short-circuit, retry, or call multiple times). Suitable for:

  • Retry logic
  • Cache control
  • Request/response transformation
  • Exception capture

The following hooks are supported:

  • wrap_model_call: wraps each model call
  • wrap_tool_call: wraps each tool call

Implementation

Decorator pattern

from langchain.agents.middleware import wrap_model_call, ModelRequest, ModelResponse
from typing import Callable

@wrap_model_call
def retry_model(
    request: ModelRequest,
    handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
    """模型调用重试逻辑"""
    for attempt in range(3):
        try:
            return handler(request)  # execute the original model call
        except Exception as e:
            if attempt == 2:
                raise  # re-raise after the final failed attempt
            print(f"Retry {attempt + 1}/3 after error: {e}")

Class pattern

from langchain.agents.middleware import AgentMiddleware, ModelRequest, ModelResponse
from typing import Callable

class RetryMiddleware(AgentMiddleware):
    def __init__(self, max_retries: int = 3):
        super().__init__()
        self.max_retries = max_retries

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        for attempt in range(self.max_retries):
            try:
                return handler(request)
            except Exception as e:
                if attempt == self.max_retries - 1:
                    raise
                print(f"Retry {attempt + 1}/{self.max_retries} after error: {e}")

3. State updates

Middleware can update the Agent state; the two hook types use different update mechanisms:

3.1 State updates from node-style hooks

Return a dictionary directly; it is merged into the Agent state through the graph reducers:

from langchain.agents.middleware import after_model, AgentState
from langgraph.runtime import Runtime
from typing import Any
from typing_extensions import NotRequired

# Define a custom state schema
class TrackingState(AgentState):
    model_call_count: NotRequired[int]  # optional field

@after_model(state_schema=TrackingState)
def increment_after_model(state: TrackingState, runtime: Runtime) -> dict[str, Any] | None:
    """统计模型调用次数"""
    return {"model_call_count": state.get("model_call_count", 0) + 1}

3.2 State updates from wrap-style hooks

Inject state updates with ExtendedModelResponse + Command:

from typing import Callable
from langchain.agents.middleware import (
    wrap_model_call,
    ModelRequest,
    ModelResponse,
    AgentState,
    ExtendedModelResponse
)
from langgraph.types import Command
from typing_extensions import NotRequired

class UsageTrackingState(AgentState):
    """带 Token 用量统计的状态"""
    last_model_call_tokens: NotRequired[int]

@wrap_model_call(state_schema=UsageTrackingState)
def track_usage(
    request: ModelRequest,
    handler: Callable[[ModelRequest], ModelResponse],
) -> ExtendedModelResponse:
    response = handler(request)
    return ExtendedModelResponse(
        model_response=response,
        command=Command(update={"last_model_call_tokens": 150}),  # inject a state update
    )

Merge rules for multiple middleware

  1. Message fields: additive (updates are appended)
  2. Plain fields: inner updates apply first and outer layers overwrite them (the outer value wins on conflict)
  3. Retry safety: on a retry, state updates from earlier attempts are discarded
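
A plain-Python sketch can illustrate these semantics (a simplified model of the behavior, not LangGraph's actual reducer code):

```python
# Simplified sketch of the merge semantics: the "messages" field is
# additive, while plain fields are overwritten by later (outer) updates.

def merge_updates(state: dict, updates: list[dict]) -> dict:
    merged = dict(state)
    for update in updates:                 # inner updates are applied first
        for key, value in update.items():
            if key == "messages":          # message field: append
                merged[key] = merged.get(key, []) + value
            else:                          # plain field: later writer wins
                merged[key] = value
    return merged

inner = {"messages": ["tool result"], "model_call_count": 1}
outer = {"messages": ["final answer"], "model_call_count": 2}
state = merge_updates({"messages": ["hi"]}, [inner, outer])
print(state["messages"])          # all three messages are kept
print(state["model_call_count"])  # the outer value (2) wins
```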

4. Custom state schemas

Extend AgentState to define custom state fields, which enables:

  • Data sharing across hooks
  • Lifecycle state tracking
  • Inputs for conditional decisions

Complete example

from langchain.agents import create_agent
from langchain.messages import HumanMessage
from langchain.agents.middleware import AgentState, before_model, after_model
from typing_extensions import NotRequired
from typing import Any

# Define the custom state
class CustomState(AgentState):
    model_call_count: NotRequired[int]  # model-call counter
    user_id: NotRequired[str]           # user ID

@before_model(state_schema=CustomState, can_jump_to=["end"])
def check_call_limit(state: CustomState) -> dict[str, Any] | None:
    """限制最大调用次数"""
    count = state.get("model_call_count", 0)
    if count > 10:
        return {"jump_to": "end"}  # 超过限制则终止
    return None

@after_model(state_schema=CustomState)
def increment_counter(state: CustomState) -> dict[str, Any] | None:
    """递增计数器"""
    return {"model_call_count": state.get("model_call_count", 0) + 1}

# Create the Agent and register the middleware
agent = create_agent(
    model="gpt-4.1",
    middleware=[check_call_limit, increment_counter],
    tools=[],
)

# Invoke with custom state values
result = agent.invoke({
    "messages": [HumanMessage("Hello")],
    "model_call_count": 0,
    "user_id": "user-123",
})

5. Middleware execution order

When multiple middleware are configured, execution follows these rules:

# Example: middleware1 → middleware2 → middleware3
agent = create_agent(
    model="gpt-4.1",
    middleware=[middleware1, middleware2, middleware3],
    tools=[...],
)

Execution flow

  1. Before hooks: run in registration order (middleware1 → middleware2 → middleware3)
  2. Wrap hooks: nest (middleware1 wraps middleware2, which wraps middleware3)
  3. After hooks: run in reverse registration order (middleware3 → middleware2 → middleware1)
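
The three rules can be verified with a small framework-free simulation: before hooks fire in registration order, the wrap chain nests with the first middleware outermost, and after hooks fire in reverse (all names and plumbing here are illustrative, not LangChain internals):

```python
# Framework-free simulation of middleware ordering.
trace: list[str] = []

def make_middleware(name: str):
    def before():
        trace.append(f"{name}.before")
    def wrap(handler):
        def wrapped():
            trace.append(f"{name}.wrap:enter")
            handler()
            trace.append(f"{name}.wrap:exit")
        return wrapped
    def after():
        trace.append(f"{name}.after")
    return before, wrap, after

middleware = [make_middleware(f"m{i}") for i in (1, 2, 3)]

def run_agent():
    for before, _, _ in middleware:           # before: registration order
        before()
    call = lambda: trace.append("model")      # innermost: the model call
    for _, wrap, _ in reversed(middleware):   # build chain inside-out
        call = wrap(call)                     # m1 ends up outermost
    call()
    for _, _, after in reversed(middleware):  # after: reverse order
        after()

run_agent()
print(trace)
```

Building the chain from the innermost call outward (iterating the list in reverse) is exactly what leaves middleware1 as the outermost wrapper.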

Flow diagram

┌──────────────┐    ┌──────────────┐    ┌──────────────┐
│ before_agent │    │ before_agent │    │ before_agent │
│ middleware1  │ →  │ middleware2  │ →  │ middleware3  │
└──────────────┘    └──────────────┘    └──────────────┘
        │                   │                   │
        ▼                   ▼                   ▼
┌──────────────┐    ┌──────────────┐    ┌──────────────┐
│ before_model │    │ before_model │    │ before_model │
│ middleware1  │ →  │ middleware2  │ →  │ middleware3  │
└──────────────┘    └──────────────┘    └──────────────┘
        │                   │                   │
        ▼                   ▼                   ▼
┌──────────────┐    ┌──────────────┐    ┌──────────────┐
│ wrap_model   │    │ wrap_model   │    │ wrap_model   │
│ middleware1  │ ←  │ middleware2  │ ←  │ middleware3  │
└──────────────┘    └──────────────┘    └──────────────┘
        │                   │                   │
        ▼                   ▼                   ▼
┌──────────────┐    ┌──────────────┐    ┌──────────────┐
│ after_model  │    │ after_model  │    │ after_model  │
│ middleware1  │ ←  │ middleware2  │ ←  │ middleware3  │
└──────────────┘    └──────────────┘    └──────────────┘

6. Agent jumps

Use jump_to to redirect the execution flow. The following targets are supported:

  • end: terminate the Agent immediately
  • tools: jump to the tool-calling node
  • model: jump to the model-calling node

Example

from langchain.agents.middleware import after_model, hook_config, AgentState
from langchain.messages import AIMessage
from typing import Any

@after_model
@hook_config(can_jump_to=["end"])
def check_for_blocked(state: AgentState) -> dict[str, Any] | None:
    """检测敏感内容并终止"""
    last_message = state["messages"][-1]
    if "BLOCKED" in last_message.content:
        return {
            "messages": [AIMessage("I cannot respond to that request.")],
            "jump_to": "end"  # 触发跳转
        }
    return None

7. Hands-on examples

7.1 Dynamic model selection

Switch models automatically based on conversation length:

from langchain.agents.middleware import wrap_model_call, ModelRequest, ModelResponse
from langchain.chat_models import init_chat_model
from typing import Callable

# Initialize the candidate models
complex_model = init_chat_model("gpt-4.1")
simple_model = init_chat_model("gpt-4.1-mini")

@wrap_model_call
def dynamic_model(
    request: ModelRequest,
    handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
    # use the larger model for long conversations, a lighter one otherwise
    if len(request.messages) > 10:
        model = complex_model
    else:
        model = simple_model
    return handler(request.override(model=model))

7.2 Tool-call monitoring

Monitor tool execution and log the results:

from langchain.agents.middleware import wrap_tool_call
from langchain.tools.tool_node import ToolCallRequest
from langchain.messages import ToolMessage
from typing import Callable

@wrap_tool_call
def monitor_tool(
    request: ToolCallRequest,
    handler: Callable[[ToolCallRequest], ToolMessage],
) -> ToolMessage:
    print(f"=== 工具调用开始 ===")
    print(f"工具名:{request.tool_call['name']}")
    print(f"参数:{request.tool_call['args']}")
    try:
        result = handler(request)
        print(f"=== 工具调用成功 ===")
        return result
    except Exception as e:
        print(f"=== 工具调用失败:{e} ===")
        raise

7.3 Dynamic tool selection

Filter tools by context so the model chooses from a smaller, more relevant set:

from langchain.agents import create_agent
from langchain.agents.middleware import wrap_model_call, ModelRequest, ModelResponse
from typing import Callable

def select_relevant_tools(state, runtime):
    """根据状态筛选相关工具(示例逻辑)"""
    user_query = state["messages"][-1].content
    if "search" in user_query.lower():
        return ["search_tool"]  # 只保留搜索工具
    return ["calculator_tool", "weather_tool"]

@wrap_model_call
def dynamic_tool_selector(
    request: ModelRequest,
    handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
    # filter the tool set dynamically
    relevant_tools = select_relevant_tools(request.state, request.runtime)
    return handler(request.override(tools=relevant_tools))

# Register all tools; the middleware narrows them per call
all_tools = ["search_tool", "calculator_tool", "weather_tool", "email_tool"]
agent = create_agent(
    model="gpt-4.1",
    tools=all_tools,
    middleware=[dynamic_tool_selector],
)

7.4 Dynamically modifying the system message

Append dynamic context to the system message:

from langchain.agents.middleware import wrap_model_call, ModelRequest, ModelResponse
from langchain.messages import SystemMessage
from typing import Callable

@wrap_model_call
def add_dynamic_context(
    request: ModelRequest,
    handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
    # Read the original system message
    original_blocks = list(request.system_message.content_blocks)
    # Append dynamic context (e.g., user permissions, current time)
    dynamic_blocks = original_blocks + [
        {"type": "text", "text": "当前用户权限:管理员\n允许访问所有工具"}
    ]
    # Build the new system message
    new_system_msg = SystemMessage(content=dynamic_blocks)
    return handler(request.override(system_message=new_system_msg))

8. Best practices

  1. Single responsibility: each middleware focuses on one concern (e.g., logging, retries, permissions)
  2. Error handling: catch exceptions inside the middleware so failures don't disrupt the Agent's main flow
  3. Hook choice
    • Use node-style hooks for sequential logic
    • Use wrap-style hooks for control flow
  4. State design: give each custom state field a clear purpose and avoid redundant fields
  5. Test first: test middleware in isolation before integrating it into the Agent
  6. Ordering: put critical middleware (e.g., permission checks) early in the registration list
  7. Prefer built-in middleware: avoid reinventing the wheel
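
Practice 2 can be sketched framework-free: a wrap-style hook that converts a failing call into a fallback response instead of letting the exception escape (illustrative only; the same structure adapts to wrap_model_call or wrap_tool_call):

```python
# Framework-free sketch: isolate errors inside the middleware so a
# failing call degrades to a fallback instead of crashing the Agent.

def safe_wrap(handler):
    def wrapped(request):
        try:
            return handler(request)
        except Exception as e:
            # log and degrade rather than propagate
            print(f"handler failed: {e}")
            return "Sorry, something went wrong."
    return wrapped

def flaky_model(request):
    raise RuntimeError("upstream timeout")

guarded = safe_wrap(flaky_model)
print(guarded("hello"))  # falls back instead of raising
```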