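As described in the Pregel walkthrough in "Tracing the Agent's Execution Flow", at the end of each superstep the engine decides which nodes to run next. It bases that decision on the nodes' channel subscriptions and on the Send objects stored in the "__pregel_tasks" channel, and it creates a PregelExecutableTask for each selected node. A node cannot update channels directly while it runs; it can only create objects of the following three types to express its intent to write to channels. By default, these write intents are submitted to the engine by a ChannelWrite object, and the engine applies them to the channels in one pass when the superstep reaches its synchronization barrier.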

  • ChannelWriteEntry
  • ChannelWriteTupleEntry
  • Send

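PregelNode, which represents a node, submits its channel write intents through the list of Runnable objects held in its writers field. By default, that Runnable is a ChannelWrite object. ChannelWrite inherits from RunnableCallable, and its writes field is a list whose elements are exactly the three types listed above.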

class PregelNode:
    writers : list[Runnable]

class ChannelWrite(RunnableCallable):
    writes: list[ChannelWriteEntry | ChannelWriteTupleEntry | Send]

1. ChannelWriteEntry

ChannelWriteEntry represents a write to a single channel; it is just a named tuple.

PASSTHROUGH = object()
class ChannelWriteEntry(NamedTuple):
    channel: str
    value: Any = PASSTHROUGH
    skip_none: bool = False
    mapper: Callable | None = None

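Field descriptions:

  • channel: the name of the channel to write to;
  • value: the value to write; the default PASSTHROUGH stands for the input value, so when the ChannelWrite sits in an LCEL chain, the value written is the output of the preceding Runnable;
  • skip_none: whether to skip the write when the value is None;
  • mapper: a mapping function applied to value; the mapped result is what gets written to the channel.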

2. ChannelWriteTupleEntry

ChannelWriteTupleEntry represents an intent to write to multiple channels.

class ChannelWriteTupleEntry(NamedTuple):
    mapper: Callable[[Any], Sequence[tuple[str, Any]] | None]
    value: Any = PASSTHROUGH
    static: Sequence[tuple[str, Any, str | None]] | None = None
    """Optional, declared writes for static analysis."""

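Field descriptions:

  • mapper: a mapping function that produces a sequence of 2-tuples (or None), each 2-tuple being a channel name and the value to write to it;
  • value: the input passed to mapper; the default PASSTHROUGH stands for the input value, so when the ChannelWrite sits in an LCEL chain, mapper receives the output of the preceding Runnable;
  • static: used mainly for static analysis, so that the channels a node writes to can be determined before the program runs. It is a sequence of 3-tuples holding the channel name, the value to write, and an optional string of extra metadata.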
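To make the mapper contract concrete, here is a minimal sketch of a function that could serve as the mapper of a ChannelWriteTupleEntry. The function name split_result and the channel names are hypothetical and chosen only for illustration; the sketch uses plain Python and does not import LangGraph.

```python
from __future__ import annotations

from typing import Any, Sequence


def split_result(result: dict[str, Any]) -> Sequence[tuple[str, Any]] | None:
    """Hypothetical mapper: turn a dict into (channel name, value) pairs."""
    # Keys are treated as channel names; None values are dropped.
    pairs = [(key, value) for key, value in result.items() if value is not None]
    # Returning None signals that there is nothing to write.
    return pairs or None


print(split_result({"bar": 10, "baz": None}))  # [('bar', 10)]
print(split_result({}))                        # None
```

A single mapper can therefore fan one input out to any number of channels, or to none at all.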

3. ChannelWrite

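As a RunnableCallable, ChannelWrite does its work through the callables held in its func and afunc fields. The constructor shows that these two fields point to the _write and _awrite methods respectively.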

class ChannelWrite(RunnableCallable):
    writes: list[ChannelWriteEntry | ChannelWriteTupleEntry | Send]

    def __init__(
        self,
        writes: Sequence[ChannelWriteEntry | ChannelWriteTupleEntry | Send],
        *,
        tags: Sequence[str] | None = None,
    ):
        super().__init__(
            func=self._write,
            afunc=self._awrite,
            name=None,
            tags=tags,
            trace=False,
        )
        self.writes = cast(
            list[ChannelWriteEntry | ChannelWriteTupleEntry | Send], writes
        )

    def _write(self, input: Any, config: RunnableConfig) -> None:
        writes = [
            ChannelWriteEntry(write.channel, input, write.skip_none, write.mapper)
            if isinstance(write, ChannelWriteEntry) and write.value is PASSTHROUGH
            else ChannelWriteTupleEntry(write.mapper, input)
            if isinstance(write, ChannelWriteTupleEntry) and write.value is PASSTHROUGH
            else write
            for write in self.writes
        ]
        self.do_write(
            config,
            writes,
        )
        return input

    # Async counterpart of _write (body omitted here).
    async def _awrite(self, input: Any, config: RunnableConfig) -> None: ...

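The definition of _write shows how PASSTHROUGH is handled. When a ChannelWriteEntry or ChannelWriteTupleEntry has PASSTHROUGH as its value, the entry only passes the node's input through, so _write discards the original object and creates a new one from the raw input (the input parameter). It then calls the static method do_write, shown below, to submit the write intents.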

class ChannelWrite(RunnableCallable):
    @staticmethod
    def do_write(
        config: RunnableConfig,
        writes: Sequence[ChannelWriteEntry | ChannelWriteTupleEntry | Send],
        allow_passthrough: bool = True,
    ) -> None:
        # validate
        for w in writes:
            if isinstance(w, ChannelWriteEntry):
                if w.channel == TASKS:
                    raise InvalidUpdateError(
                        "Cannot write to the reserved channel TASKS"
                    )
                if w.value is PASSTHROUGH and not allow_passthrough:
                    raise InvalidUpdateError("PASSTHROUGH value must be replaced")
            if isinstance(w, ChannelWriteTupleEntry):
                if w.value is PASSTHROUGH and not allow_passthrough:
                    raise InvalidUpdateError("PASSTHROUGH value must be replaced")
        write: TYPE_SEND = config[CONF][CONFIG_KEY_SEND]
        write(_assemble_writes(writes))
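The PASSTHROUGH substitution that _write performs before it calls do_write can be looked at in isolation. The sketch below is a simplified stand-in, not the library code: Entry and resolve are hypothetical names, and Entry keeps only two of the fields of ChannelWriteEntry.

```python
from __future__ import annotations

from typing import Any, NamedTuple

PASSTHROUGH = object()  # stand-in sentinel meaning "use the node's input"


class Entry(NamedTuple):
    """Simplified stand-in for ChannelWriteEntry (channel and value only)."""
    channel: str
    value: Any = PASSTHROUGH


def resolve(entries: list[Entry], node_input: Any) -> list[Entry]:
    """Replace every PASSTHROUGH value with the actual input."""
    return [
        Entry(e.channel, node_input) if e.value is PASSTHROUGH else e
        for e in entries
    ]


resolved = resolve([Entry("foo"), Entry("status", "done")], node_input=42)
print(resolved)
```

Entries that already carry an explicit value are left untouched; only the placeholder is swapped for the input. The sentinel is compared with is, so a legitimate value such as None or 0 is never mistaken for it.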

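The "__pregel_tasks" channel (its name is held by the constant TASKS) is reserved by the engine for storing node tasks, so any ChannelWriteEntry that tries to write to it is rejected by this method. If we really do want to write to this channel (which is rarely needed), the only option is ChannelWriteTupleEntry; we discussed this in the article "The __pregel_tasks channel: the hero behind PUSH tasks". All three kinds of write-intent objects are eventually converted into a single sequence of 2-tuples, each pairing a channel name with a value to write. This conversion is done by the _assemble_writes function below.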

def _assemble_writes(
    writes: Sequence[ChannelWriteEntry | ChannelWriteTupleEntry | Send],
) -> list[tuple[str, Any]]:
    """Assembles the writes into a list of tuples."""
    tuples: list[tuple[str, Any]] = []
    for w in writes:
        if isinstance(w, Send):
            tuples.append((TASKS, w))
        elif isinstance(w, ChannelWriteTupleEntry):
            if ww := w.mapper(w.value):
                tuples.extend(ww)
        elif isinstance(w, ChannelWriteEntry):
            value = w.mapper(w.value) if w.mapper is not None else w.value
            if value is SKIP_WRITE:
                continue
            if w.skip_none and value is None:
                continue
            tuples.append((w.channel, value))
        else:
            raise ValueError(f"Invalid write entry: {w}")
    return tuples

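_assemble_writes converts the three types as follows:

  • ChannelWriteEntry: the channel field is the channel name. If there is no mapper, the value field is written as is; otherwise the result of applying mapper to value is written. The write is skipped if the resulting value is SKIP_WRITE, or if it is None and skip_none is True;
  • ChannelWriteTupleEntry: applying mapper to value yields a sequence of 2-tuples, each holding a channel name and a value to write; an empty or None result adds nothing;
  • Send: a Send object creates a node task in PUSH mode and is always written to the "__pregel_tasks" channel.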
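The rules above can be exercised with a small self-contained sketch. Entry, TupleEntry, Push and assemble are hypothetical stand-ins for ChannelWriteEntry, ChannelWriteTupleEntry, Send and _assemble_writes. PASSTHROUGH and SKIP_WRITE handling are left out to keep the sketch short.

```python
from __future__ import annotations

from typing import Any, Callable, NamedTuple

TASKS = "__pregel_tasks"


class Entry(NamedTuple):       # stand-in for ChannelWriteEntry
    channel: str
    value: Any
    skip_none: bool = False
    mapper: Callable[[Any], Any] | None = None


class TupleEntry(NamedTuple):  # stand-in for ChannelWriteTupleEntry
    mapper: Callable[[Any], list[tuple[str, Any]] | None]
    value: Any


class Push(NamedTuple):        # stand-in for Send
    node: str
    arg: Any


def assemble(writes: list[Any]) -> list[tuple[str, Any]]:
    """Flatten the write intents into (channel name, value) pairs."""
    pairs: list[tuple[str, Any]] = []
    for w in writes:
        if isinstance(w, Push):
            pairs.append((TASKS, w))               # always goes to the tasks channel
        elif isinstance(w, TupleEntry):
            pairs.extend(w.mapper(w.value) or [])  # None or [] adds nothing
        else:
            value = w.mapper(w.value) if w.mapper is not None else w.value
            if w.skip_none and value is None:
                continue                           # skipped write
            pairs.append((w.channel, value))
    return pairs


pairs = assemble([
    Entry("a", 1),
    Entry("b", 2, mapper=lambda v: v * 10),
    Entry("c", None, skip_none=True),   # dropped
    Entry("d", None),                   # kept: skip_none defaults to False
    TupleEntry(lambda v: [("e", v), ("f", -v)], 3),
    Push("worker", {"job": 1}),
])
for pair in pairs:
    print(pair)
```

Note that the None check applies to the mapped value, so an entry whose mapper turns a non-None value into None is also skipped when skip_none is True.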

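When the engine creates the PregelExecutableTask for a node, it also creates a double-ended queue of type deque[tuple[str, Any]]. The 2-tuples stored in this queue are the (channel name, value) pairs that the node ultimately needs to write. The engine places the queue's extend method, which appends 2-tuples to it, into the "configurable" section of the RunnableConfig under the key "__pregel_send"; the constant CONFIG_KEY_SEND holds exactly this key.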

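Returning to ChannelWrite's do_write method: before it calls _assemble_writes to build the list of (channel name, value) 2-tuples, it uses this key to fetch that function from the RunnableConfig. The 2-tuples are then passed to the function, which appends them to the deque[tuple[str, Any]] queue. In the end, the engine only needs to read the 2-tuples stored in these queues, locate each channel by name, and apply the write.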

from collections import deque
from typing import Any

from langchain_core.runnables import RunnableConfig
from langgraph._internal._constants import CONFIG_KEY_SEND
from langgraph.pregel._write import (
    PASSTHROUGH,
    ChannelWrite,
    ChannelWriteEntry,
    ChannelWriteTupleEntry,
    Send,
)

write = ChannelWrite(
    [
        ChannelWriteEntry(channel="foo", value=PASSTHROUGH, skip_none=True),
        ChannelWriteTupleEntry(mapper=lambda x: [("bar", x * 2), ("baz", x * 3)]),
        Send(node="qux", arg="hello,world!"),
    ]
)

# The deque collects the (channel name, value) tuples submitted by ChannelWrite.
collected: deque[tuple[str, Any]] = deque()
config: RunnableConfig = {
    "configurable": {CONFIG_KEY_SEND: collected.extend}
}

write.invoke(input=5, config=config)
for update in collected:
    print(update)

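The simple example above demonstrates how ChannelWrite submits channel write intents. We create a ChannelWrite from a ChannelWriteEntry, a ChannelWriteTupleEntry, and a Send. The ChannelWriteEntry passes the input through to channel foo; the ChannelWriteTupleEntry multiplies the input by 2 and by 3 and writes the results to channels bar and baz; the Send creates a task for node qux with the argument "hello,world!".

We create a deque to collect the (channel name, value) 2-tuples and place its extend method in the RunnableConfig under the key CONFIG_KEY_SEND ("__pregel_send"). We then invoke the ChannelWrite with the integer 5 and this RunnableConfig, and print the 2-tuples collected in the queue. The output is: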

('foo', 5)
('bar', 10)
('baz', 15)
('__pregel_tasks', Send(node='qux', arg='hello,world!'))
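The final step, where collected 2-tuples are matched to channels by name, can be pictured with the following sketch. It is an illustration only: it uses plain Python with no LangGraph imports, and grouping values into lists is an assumption made here for simplicity. How a real channel combines several writes depends on the channel type.

```python
from __future__ import annotations

from collections import defaultdict, deque
from typing import Any

# Queue owned by one (hypothetical) task; its extend method plays the role
# of the function stored under "__pregel_send".
collected: deque[tuple[str, Any]] = deque()
send = collected.extend

send([("foo", 5), ("bar", 10)])  # first batch of write intents
send([("foo", 6)])               # a later batch from the same task

# Group the pending values by channel name before applying them.
pending: dict[str, list[Any]] = defaultdict(list)
for channel, value in collected:
    pending[channel].append(value)

print(dict(pending))  # {'foo': [5, 6], 'bar': [10]}
```

Because the node only ever sees the extend function, it has no way to touch channel state directly, which matches the rule stated at the start of this article.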

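4. Using other Runnables

As noted at the start, PregelNode uses the list of Runnable objects in its writers field to submit channel write intents, and ChannelWrite is only the default implementation. ChannelWrite provides the static method register_writer to register any given Runnable as a "channel writer". The implementation merely sets an attribute named _is_channel_writer on the Runnable, whose value is the sequence of 2-tuples passed as the static parameter. The first element of each 2-tuple is a ChannelWriteEntry or Send object, and the second is supplementary metadata (a string or None). The static method is_writer uses this attribute to decide whether a given Runnable is a "channel writer".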

class ChannelWrite(RunnableCallable):
    @staticmethod
    def register_writer(
        runnable: R,
        static: Sequence[tuple[ChannelWriteEntry | Send, str | None]] | None = None,
    ) -> R:
        object.__setattr__(runnable, "_is_channel_writer", static)
        return runnable

    @staticmethod
    def is_writer(runnable: Runnable) -> bool:
        return (
            isinstance(runnable, ChannelWrite)
            or getattr(runnable, "_is_channel_writer", MISSING) is not MISSING
        )

    @staticmethod
    def get_static_writes(
        runnable: Runnable,
    ) -> Sequence[tuple[str, Any, str | None]] | None:
        """Used to get conditional writes a writer declares for static analysis."""
        if isinstance(runnable, ChannelWrite):
            return [
                w
                for entry in runnable.writes
                if isinstance(entry, ChannelWriteTupleEntry) and entry.static
                for w in entry.static
            ] or None
        elif writes := getattr(runnable, "_is_channel_writer", MISSING):
            if writes is not MISSING:
                writes = cast(
                    Sequence[tuple[ChannelWriteEntry | Send, str | None]],
                    writes,
                )
                entries = [e for e, _ in writes]
                labels = [la for _, la in writes]
                return [(*t, la) for t, la in zip(_assemble_writes(entries), labels)]    
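The marker-attribute technique used by register_writer and is_writer can be reproduced in a few lines of plain Python. The sketch below re-implements the two functions as free functions on an arbitrary object; CustomWriter is a hypothetical class, and the isinstance check against ChannelWrite is omitted.

```python
from __future__ import annotations

from typing import Any

MISSING = object()  # sentinel: distinguishes "attribute absent" from "attribute is None"


def register_writer(obj: Any, static: Any = None) -> Any:
    """Mark obj as a channel writer by attaching a marker attribute."""
    object.__setattr__(obj, "_is_channel_writer", static)
    return obj


def is_writer(obj: Any) -> bool:
    """An object counts as a writer as soon as the marker attribute exists."""
    return getattr(obj, "_is_channel_writer", MISSING) is not MISSING


class CustomWriter:
    """Hypothetical callable that would submit writes in some custom way."""

    def __call__(self, value: Any) -> Any:
        return value


plain = CustomWriter()
marked = register_writer(CustomWriter())  # static defaults to None

print(is_writer(plain))   # False
print(is_writer(marked))  # True, even though the attribute value is None
```

The sentinel matters because static defaults to None: checking the attribute against None would wrongly report a writer registered without static declarations as unregistered.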

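The core role of the static method get_static_writes is "static probing": it lets LangGraph work out, from declared metadata and without actually running the node's code, which writes the node is expected to make. LangGraph needs to validate the graph's topology at compile time. With this method, the compiler can learn:

  • Which channels will be modified: to ensure that the channels being written actually exist in the State Schema.
  • Data flow: to identify logical infinite loops or isolated nodes.
  • Concurrency control: to plan node execution order and state synchronization in advance, based on the static declarations.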