[LangChain智能体本质论]Agent的状态即通道(Channel)
create_agent函数返回的Agent的类型声明如下,其第一个泛型参数代表承载整个Agent的状态类型,具体的类型为AgentState[ResponseT],后者的泛型参数ResponseT代表格式化输出的Schema。
CompiledStateGraph[AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]]
1. AgentState
我们不断在强调Agent本质上是一个Pregel对象,其状态完全利用通道来维护,所以作为状态的AgentState对象的数据成员会转换成对应的通道。当我们看到AgentState的三个字段定义的时候,是不是感到很熟悉:前面演示实例输出的通道就有三个与它们同名。
CompiledStateGraph[
AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]
class AgentState(TypedDict, Generic[ResponseT]):
messages: Required[Annotated[list[AnyMessage], add_messages]]
jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue, PrivateStateAttr]]
structured_response: NotRequired[Annotated[ResponseT, OmitFromInput]]
JumpTo = Literal["tools", "model", "end"]
定义在AgentState中,绑定为通道的状态字段可以利用标注的Annotated来定义数据类型。如果定义了reducer函数,最终创建的将是一个BinaryOperatorAggregate类型的通道,reducer函数会作为它的操作符。通过Annotated标注可以看出,messages字段的类型是一个AnyMessage列表,AnyMessage是对众多预定义的针对BaseMessage/BaseMessageChunk具体实现类型的统称(我的文章消息——Agent与模型交互的媒介中具有针对它们的详细介绍)。
通过前面的演示程序可知:messags和jump_to这两个通道的类型分别为BinaryOperatorAggregate和EphemeralValue。这样设计很好理解,因为前者用于收集生成的消息,只有设计成BinaryOperatorAggregate类型并采用基于“追加”的reducer函数才能实现;后者用于指导下一步跳转到何处,具有“邻步有效”的特性,EphemeralValue类型的通道正式针对这种场景设计的。
除了类型,前面的演示程序还体现了三个通道针对输入/输出的差异:messages通道同时作为输入和输出,structured_response通道只作为输出,而jump_to通道既非输入也非输出。这是因为structured_response和jump_to字段分别被标注了OmitFromInput和PrivateStateAttr,前者将数据成员从输入中剔除,后者则同时从输入和输出中除名。如果只作为输入,则可以是标注OmitFromOutput。没有对此作显式标注的messages意味这输入和输出均可见。
@dataclass
class OmitFromSchema:
input: bool = True
output: bool = True
OmitFromInput = OmitFromSchema(input=True, output=False)
OmitFromOutput = OmitFromSchema(input=False, output=True)
PrivateStateAttr = OmitFromSchema(input=True, output=True)
2. 自定义状态成员
默认的状态类型AgentState是一个TypedDict,它只定义了三个数据成员,我们也可以自定义任意的TypedDict作为Agent的状态Schema。但是很多现有的组件都依赖于这三个成员,所以最稳妥的方式还是定义它的子类。接下来我们将通过一个实例演示:
- 添加自定义状态数据成员;
- 将状态注入工具函数;
- 调用Agent是初始化自定义状态成员;
- 从调用结果中提取自定义状态成员;
from typing import Annotated, Any,Callable, Sequence, override
import builtins
from langchain.agents import create_agent
from langchain.agents.middleware.types import OmitFromOutput,OmitFromInput
from langchain_core.language_models import BaseChatModel,LanguageModelInput
from langchain_core.messages import BaseMessage, AIMessage, ToolMessage, ToolCall
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.outputs.chat_result import ChatResult, ChatGeneration
from langchain_core.tools import BaseTool
from langchain_core.runnables import Runnable
from langgraph.prebuilt import InjectedState
from langchain.agents.middleware import AgentState
from langgraph.types import Command
class ModelSimulator(BaseChatModel):
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
if[message for message in messages if isinstance(message, ToolMessage)]:
generation = ChatGeneration(message=AIMessage(""))
return ChatResult(generations=[generation], llm_output={})
tool_call: ToolCall = {
"name": "fake_tool",
"args": {},
"id": "tool_call_001",
}
generation = ChatGeneration(message=AIMessage(content="", tool_calls=[tool_call]))
return ChatResult(generations=[generation], llm_output={})
@property
def _llm_type(self) -> str:
return "model-simulator"
@override
def bind_tools(
self,
tools: Sequence[builtins.dict[str, Any] | type | Callable | BaseTool],
*,
tool_choice: str | None = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, AIMessage]:
return self
class ExtendedAgentState(AgentState):
foo: Annotated[str,OmitFromOutput]
bar: Annotated[str,OmitFromOutput]
baz: Annotated[str,OmitFromInput]
qux: Annotated[str,OmitFromInput]
def fake_tool(state:Annotated[dict, InjectedState]) -> Command:
"""A fake tool"""
tool_call_id = state["messages"][-1].tool_calls[0]["id"]
return Command(
update={
"messages":[ToolMessage("", tool_call_id=tool_call_id)],
"baz": state.get("foo", "N/A"),
"qux": state.get("bar", "N/A"),
},
)
agent = create_agent(
model= ModelSimulator(),
tools=[fake_tool],
state_schema=ExtendedAgentState,
)
result = agent.invoke(input= {"foo": "Hello", "bar":"World"}) # type: ignore
assert result["baz"] == "Hello"
assert result["qux"]== "World"
我们通过继承AgentState创建了自定义状态类型ExtendedAgentState,并为它额外添加了四个字符串类型的成员(foo、bar、baz和qux),我们利用针对OmitFromOutput和OmitFromInput的注解,将foo和bar定义成输入,将baz和qux定义成输出。
模拟的工具函数fake_tool并没有外部输入,而是将当前状态视为输入。为此我们在参数state上应用了Annotated[dict, InjectedState]实现了针对状态的参数注入(以字典形式)。在默认情况下,Agent状态的所有成员都会注入到参数绑定的字典中,如果需要对注入的字典进行过滤,过滤的字段可以在InjectedState对象中指定。比如Annotated[str, InjectedState("foo")]。
class InjectedState(InjectedToolArg):
def __init__(self, field: str | None = None) -> None:
self.field = field
由于作为LangChain引擎的Pregel采用基于BSP的执行机制,承载状态的通道不允许被节点直接修改,节点只能提交通道更新请求,并最终由执行引擎统一完成针对通道的更新。反映在工具函数上就是:它不能在函数内部直接修改某个状态成员的值,只能将针对状态的更新封装返回的Command对象上。所以fake_tool会返回一个Command对象,它的update字段体现了针对状态的更新:在messages列表中添加一个ToolMessage,将状态成员foo和bar的值赋值给成员baz和qux。
在调用create_agent创建Agent的时候,我们直接将ExtendedAgentState类型设置为state_schema参数。model参数设置的是我们用来模拟模型的ModelSimulator对象,它采用与FakeModel类似的定义。在调用invoke方法时,我们在input参数的字典中添加了成员foo和bar的值。并从作为执行结果的字典中提取通过工具函数设置的baz和qux的值。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)