create_agent函数返回的Agent的类型声明如下,其第一个泛型参数代表承载整个Agent的状态类型,具体的类型为AgentState[ResponseT],后者的泛型参数ResponseT代表格式化输出的Schema。

CompiledStateGraph[AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]]

1. AgentState

我们不断在强调Agent本质上是一个Pregel对象,其状态完全利用通道来维护,所以作为状态的AgentState对象的数据成员会转换成对应的通道。当我们看到AgentState的三个字段定义的时候,是不是感到很熟悉:前面演示实例输出的通道就有三个与它们同名。

CompiledStateGraph[
    AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]

class AgentState(TypedDict, Generic[ResponseT]):
    messages: Required[Annotated[list[AnyMessage], add_messages]]
    jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue, PrivateStateAttr]]
    structured_response: NotRequired[Annotated[ResponseT, OmitFromInput]]

JumpTo = Literal["tools", "model", "end"]

定义在AgentState中,绑定为通道的状态字段可以利用标注的Annotated来定义数据类型。如果定义了reducer函数,最终创建的将是一个BinaryOperatorAggregate类型的通道,reducer函数会作为它的操作符。通过Annotated标注可以看出,messages字段的类型是一个AnyMessage列表,AnyMessage是对众多预定义的针对BaseMessage/BaseMessageChunk具体实现类型的统称(我的文章消息——Agent与模型交互的媒介中具有针对它们的详细介绍)。

通过前面的演示程序可知:messagsjump_to这两个通道的类型分别为BinaryOperatorAggregateEphemeralValue。这样设计很好理解,因为前者用于收集生成的消息,只有设计成BinaryOperatorAggregate类型并采用基于“追加”的reducer函数才能实现;后者用于指导下一步跳转到何处,具有“邻步有效”的特性,EphemeralValue类型的通道正式针对这种场景设计的。

除了类型,前面的演示程序还体现了三个通道针对输入/输出的差异:messages通道同时作为输入和输出,structured_response通道只作为输出,而jump_to通道既非输入也非输出。这是因为structured_responsejump_to字段分别被标注了OmitFromInputPrivateStateAttr,前者将数据成员从输入中剔除,后者则同时从输入和输出中除名。如果只作为输入,则可以是标注OmitFromOutput。没有对此作显式标注的messages意味这输入和输出均可见。

@dataclass
class OmitFromSchema:
    input: bool = True
    output: bool = True
OmitFromInput = OmitFromSchema(input=True, output=False)
OmitFromOutput = OmitFromSchema(input=False, output=True)
PrivateStateAttr = OmitFromSchema(input=True, output=True)

2. 自定义状态成员

默认的状态类型AgentState是一个TypedDict,它只定义了三个数据成员,我们也可以自定义任意的TypedDict作为Agent的状态Schema。但是很多现有的组件都依赖于这三个成员,所以最稳妥的方式还是定义它的子类。接下来我们将通过一个实例演示:

  • 添加自定义状态数据成员;
  • 将状态注入工具函数;
  • 调用Agent是初始化自定义状态成员;
  • 从调用结果中提取自定义状态成员;
from typing import Annotated, Any,Callable, Sequence, override
import builtins
from langchain.agents import create_agent
from langchain.agents.middleware.types import OmitFromOutput,OmitFromInput
from langchain_core.language_models import BaseChatModel,LanguageModelInput
from langchain_core.messages import BaseMessage, AIMessage, ToolMessage, ToolCall
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.outputs.chat_result import ChatResult, ChatGeneration
from langchain_core.tools import BaseTool
from langchain_core.runnables import Runnable
from langgraph.prebuilt import InjectedState
from langchain.agents.middleware import AgentState
from langgraph.types import Command

class ModelSimulator(BaseChatModel):
    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        
        if[message for message in messages if isinstance(message, ToolMessage)]:
             generation = ChatGeneration(message=AIMessage(""))
             return ChatResult(generations=[generation], llm_output={})
        
        tool_call: ToolCall = {
                "name": "fake_tool",
                "args": {},
                "id": "tool_call_001",
            }            
        generation = ChatGeneration(message=AIMessage(content="", tool_calls=[tool_call]))
        return ChatResult(generations=[generation], llm_output={})

    @property
    def _llm_type(self) -> str:
        return "model-simulator"

    @override
    def bind_tools(
        self,
        tools: Sequence[builtins.dict[str, Any] | type | Callable | BaseTool],
        *,
        tool_choice: str | None = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, AIMessage]:
        return self

class ExtendedAgentState(AgentState):
    foo: Annotated[str,OmitFromOutput]
    bar: Annotated[str,OmitFromOutput]
    baz: Annotated[str,OmitFromInput]
    qux: Annotated[str,OmitFromInput]


def fake_tool(state:Annotated[dict, InjectedState]) -> Command:
    """A fake tool"""

    tool_call_id = state["messages"][-1].tool_calls[0]["id"]
    return Command(
        update={
            "messages":[ToolMessage("", tool_call_id=tool_call_id)],
            "baz": state.get("foo", "N/A"),
            "qux": state.get("bar", "N/A"), 
            },
    )

agent = create_agent(
    model= ModelSimulator(),
    tools=[fake_tool],
    state_schema=ExtendedAgentState,
)

result = agent.invoke(input= {"foo": "Hello", "bar":"World"}) # type: ignore
assert result["baz"] == "Hello"
assert result["qux"]== "World"

我们通过继承AgentState创建了自定义状态类型ExtendedAgentState,并为它额外添加了四个字符串类型的成员(foo、bar、baz和qux),我们利用针对OmitFromOutputOmitFromInput的注解,将foo和bar定义成输入,将baz和qux定义成输出。

模拟的工具函数fake_tool并没有外部输入,而是将当前状态视为输入。为此我们在参数state上应用了Annotated[dict, InjectedState]实现了针对状态的参数注入(以字典形式)。在默认情况下,Agent状态的所有成员都会注入到参数绑定的字典中,如果需要对注入的字典进行过滤,过滤的字段可以在InjectedState对象中指定。比如Annotated[str, InjectedState("foo")]

class InjectedState(InjectedToolArg):
    def __init__(self, field: str | None = None) -> None:
        self.field = field

由于作为LangChain引擎的Pregel采用基于BSP的执行机制,承载状态的通道不允许被节点直接修改,节点只能提交通道更新请求,并最终由执行引擎统一完成针对通道的更新。反映在工具函数上就是:它不能在函数内部直接修改某个状态成员的值,只能将针对状态的更新封装返回的Command对象上。所以fake_tool会返回一个Command对象,它的update字段体现了针对状态的更新:在messages列表中添加一个ToolMessage,将状态成员foo和bar的值赋值给成员baz和qux。

在调用create_agent创建Agent的时候,我们直接将ExtendedAgentState类型设置为state_schema参数。model参数设置的是我们用来模拟模型的ModelSimulator对象,它采用与FakeModel类似的定义。在调用invoke方法时,我们在input参数的字典中添加了成员foo和bar的值。并从作为执行结果的字典中提取通过工具函数设置的baz和qux的值。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐