Agent工作流Harness:DAG编排与断点恢复


一、 引言

钩子:从一个常见的自动化困境开始

你是否曾经遇到过这样的场景:你精心设计了一个自动化流程,用于处理一批数据生成报告,流程包含了数据提取、清洗、转换、分析和可视化等多个步骤。一切看起来都很完美,直到……在流程执行到第8个步骤时,因为网络超时导致了一个API调用失败。整个流程戛然而止,之前7个步骤所做的工作全部白费,你不得不从头开始重新运行整个流程。更糟糕的是,你甚至不知道在失败之前,哪些数据已经处理完毕,哪些还没有。

这种场景在数据工程、机器学习流水线、DevOps自动化以及各种复杂的业务流程中太常见了。当我们处理的任务变得越来越复杂,涉及的步骤越来越多,依赖关系越来越错综,简单的线性脚本执行方式就显得力不从心了。我们需要一种更强大、更可靠的方式来组织和执行这些复杂的工作流。

定义问题/阐述背景

在当今的软件开发和数据处理领域,工作流编排已经成为一个核心挑战。特别是随着AI Agent概念的兴起,我们需要协调多个智能体协同工作,每个Agent可能有自己的输入输出、执行逻辑和依赖关系。这就对工作流系统提出了更高的要求:

  1. 复杂依赖管理:任务之间可能存在复杂的依赖关系,不是简单的线性执行,而是形成有向无环图(DAG)结构。
  2. 容错与恢复:当某个任务失败时,我们希望能够从失败点恢复,而不是重新执行整个流程。
  3. 可观测性:我们需要能够监控工作流的执行状态,了解每个任务的执行情况。
  4. 并行执行:对于没有依赖关系的任务,我们希望能够并行执行以提高效率。
  5. 动态调整:在某些情况下,我们可能需要根据前序任务的执行结果动态调整后续任务。

正是在这样的背景下,Agent工作流Harness应运而生。Harness在这里指的是一个框架或平台,它提供了一套工具和抽象,帮助我们定义、编排和执行复杂的Agent工作流,特别是支持DAG结构的编排和强大的断点恢复能力。

亮明观点/文章目标

在本文中,我们将深入探讨Agent工作流Harness的核心概念、设计原理和实现方法。具体来说,我们将:

  1. 介绍DAG编排的基本概念和重要性
  2. 详细分析断点恢复机制的实现原理
  3. 通过一个实战项目,从零开始构建一个简单但功能完整的工作流Harness
  4. 探讨如何在实际场景中应用这些概念
  5. 分享一些最佳实践和常见陷阱

通过阅读本文,你将不仅理解Agent工作流Harness的理论基础,还将掌握实际构建和使用这样一个系统的技能。无论你是在构建数据管道、机器学习流水线,还是复杂的业务流程自动化系统,本文的内容都将对你有所帮助。


二、 基础知识/背景铺垫

核心概念定义

在深入探讨Agent工作流Harness之前,我们需要先明确一些核心概念。

什么是Agent?

在本文的上下文中,Agent指的是一个封装了特定功能的执行单元。它可以是一个简单的函数,也可以是一个复杂的服务,甚至可以是一个AI模型。Agent具有以下特点:

  • 独立性:Agent应该是相对独立的,能够自主完成特定任务。
  • 输入输出明确:Agent应该有明确的输入和输出定义,便于与其他Agent协作。
  • 可组合性:多个Agent可以组合在一起,形成更复杂的系统。
什么是工作流(Workflow)?

工作流是一系列相互关联的任务的集合,这些任务按照一定的规则和顺序执行,以达到特定的目标。工作流定义了任务之间的依赖关系、执行顺序和数据流向。

什么是DAG?

DAG是有向无环图(Directed Acyclic Graph)的缩写。它是一种特殊的图结构,具有以下特点:

  • 有向:图中的边有方向,表示从一个节点指向另一个节点。
  • 无环:图中没有循环,即不能从一个节点出发,沿着边的方向又回到该节点。

在工作流编排中,DAG是一种非常自然的表示方式:

  • 节点表示任务(或Agent)
  • 边表示任务之间的依赖关系
  • 无环保证了工作流不会陷入无限循环
什么是编排(Orchestration)?

编排指的是对工作流中的任务进行协调、调度和管理的过程。编排系统负责:

  • 解析工作流定义
  • 管理任务之间的依赖关系
  • 调度任务的执行
  • 处理任务失败和重试
  • 收集和报告执行状态
什么是断点恢复(Checkpointing and Recovery)?

断点恢复是一种机制,允许工作流在失败后从失败点继续执行,而不是从头开始。这需要:

  • 在执行过程中定期保存状态(检查点)
  • 能够从保存的状态中恢复执行环境
  • 能够确定哪些任务已经完成,哪些需要重新执行

相关工具/技术概览

在Agent工作流编排领域,已经有一些成熟的工具和技术。让我们简要了解一下:

开源工作流编排工具
  1. Apache Airflow:这是最流行的开源工作流编排工具之一,使用Python编写。它使用DAG来定义工作流,提供了丰富的操作符和强大的调度能力。

  2. Prefect:一个相对较新的工作流编排工具,旨在解决Airflow的一些痛点。它提供了更现代的API和更好的开发体验。

  3. Argo Workflows:专为Kubernetes设计的工作流编排工具,使用CRD(自定义资源定义)来定义工作流。

  4. Temporal:一个专注于可靠性和可扩展性的工作流编排平台,提供了强大的容错和恢复能力。

AI Agent框架
  1. LangChain:一个用于构建由语言模型驱动的应用程序的框架,支持链式调用和Agent工作流。

  2. AutoGPT:一个自主AI Agent项目,可以自主设定目标并执行任务。

  3. CrewAI:一个用于协调多个AI Agent协作的框架,支持角色定义和任务分配。

虽然这些工具都很强大,但它们往往比较复杂,有一定的学习曲线。在本文中,我们将构建一个简化版的工作流Harness,帮助你理解这些工具背后的核心原理。

DAG编排的数学基础

在深入实现之前,让我们先了解一些DAG编排的数学基础。

拓扑排序

拓扑排序是对DAG的顶点进行排序的一种算法,使得对于每一条有向边(u, v),顶点u在排序中都出现在顶点v的前面。拓扑排序是DAG编排的基础,因为它决定了任务的执行顺序。

Kahn算法是一种常用的拓扑排序算法,其基本思想是:

  1. 计算每个节点的入度(指向该节点的边的数量)
  2. 将所有入度为0的节点加入队列
  3. 从队列中取出一个节点,将其加入拓扑排序结果
  4. 将该节点的所有邻居的入度减1
  5. 如果某个邻居的入度变为0,将其加入队列
  6. 重复步骤3-5,直到队列为空

如果最终拓扑排序结果的节点数量小于图中节点的总数,说明图中存在环。

依赖管理

在DAG中,每个节点的执行依赖于其所有前驱节点的完成。我们可以用以下数学表达式来表示这种依赖关系:

T={t1,t2,...,tn}T = \{t_1, t_2, ..., t_n\}T={t1,t2,...,tn}是工作流中的所有任务集合,D⊆T×TD \subseteq T \times TDT×T是任务之间的依赖关系集合,其中(ti,tj)∈D(t_i, t_j) \in D(ti,tj)D表示任务tjt_jtj依赖于任务tit_iti的完成。

对于任意任务tjt_jtj,其前驱任务集合为pred(tj)={ti∣(ti,tj)∈D}pred(t_j) = \{t_i | (t_i, t_j) \in D\}pred(tj)={ti(ti,tj)D}

任务tjt_jtj可以开始执行当且仅当其所有前驱任务都已成功完成:

executable(tj)  ⟺  ∀ti∈pred(tj),completed(ti)=trueexecutable(t_j) \iff \forall t_i \in pred(t_j), completed(t_i) = trueexecutable(tj)tipred(tj),completed(ti)=true


三、 核心内容/实战演练

在这一部分,我们将从零开始构建一个简单但功能完整的Agent工作流Harness。我们将使用Python作为编程语言,因为它在数据处理和自动化领域非常流行,且有丰富的生态系统。

项目概述

我们的工作流Harness将具备以下核心功能:

  1. 支持以DAG形式定义工作流
  2. 支持任务并行执行
  3. 实现检查点机制,支持断点恢复
  4. 提供基本的监控和日志功能

环境设置

首先,让我们设置项目环境。我们将使用Python 3.9+,并依赖以下库:

  • networkx:用于处理图结构
  • pydantic:用于数据验证
  • python-dotenv:用于环境变量管理
  • uuid:用于生成唯一标识符(Python标准库)
  • json:用于序列化(Python标准库)
  • datetime:用于处理时间(Python标准库)
  • concurrent.futures:用于并行执行(Python标准库)

让我们创建一个requirements.txt文件:

networkx>=3.1
pydantic>=2.0
python-dotenv>=1.0

然后安装这些依赖:

pip install -r requirements.txt

系统设计

在开始编码之前,让我们先设计系统的核心组件和架构。

核心概念模型

我们的系统将包含以下核心概念:

  1. Task:工作流中的基本执行单元,对应DAG中的节点。
  2. Dependency:表示任务之间的依赖关系,对应DAG中的边。
  3. Workflow:由任务和依赖关系组成的DAG。
  4. TaskInstance:任务的具体执行实例,包含执行状态和结果。
  5. Checkpoint:工作流执行过程中的状态快照,用于断点恢复。
  6. Executor:负责实际执行任务的组件。
  7. Scheduler:负责根据依赖关系调度任务执行的组件。
  8. StateStore:负责存储工作流执行状态和检查点的组件。

让我们用ER图来表示这些概念之间的关系:

contains

defines

is predecessor in

is successor in

has

is executed as

creates

is included in

manages

executes

stores

stores

stores

Workflow

Task

Dependency

TaskInstance

Checkpoint

Scheduler

Executor

StateStore

系统架构

我们的系统将采用分层架构,从上到下依次为:

  1. API层:提供定义和管理工作流的接口
  2. 调度层:负责解析工作流定义,调度任务执行
  3. 执行层:负责实际执行任务
  4. 存储层:负责持久化工作流状态和检查点

让我们用架构图来表示:

用户

API层

调度层

执行层

存储层

状态模型

任务和工作流都有其生命周期状态。让我们定义这些状态:

任务状态:

  • PENDING:任务已定义,但尚未准备好执行
  • READY:任务的所有依赖已完成,可以开始执行
  • RUNNING:任务正在执行
  • COMPLETED:任务执行成功
  • FAILED:任务执行失败
  • SKIPPED:任务被跳过(通常是因为依赖任务失败)

工作流状态:

  • CREATED:工作流已创建,但尚未开始执行
  • RUNNING:工作流正在执行
  • COMPLETED:工作流中所有任务都已成功完成
  • FAILED:工作流中有一个或多个任务失败
  • PAUSED:工作流被暂停
  • CANCELLED:工作流被取消

让我们用状态图来表示任务的状态转换:

dependencies completed

start execution

success

failure

retry

dependency failed

dependency failed

PENDING

READY

RUNNING

COMPLETED

FAILED

SKIPPED

核心实现

现在让我们开始实现这些核心组件。我们将按照从底层到上层的顺序进行。

第一步:定义数据模型

首先,让我们使用Pydantic定义核心数据模型。创建一个models.py文件:

import uuid
from enum import Enum
from datetime import datetime
from typing import Any, Dict, List, Optional, Set
from pydantic import BaseModel, Field, field_validator


class TaskStatus(str, Enum):
    """任务状态枚举"""
    PENDING = "pending"
    READY = "ready"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    SKIPPED = "skipped"


class WorkflowStatus(str, Enum):
    """工作流状态枚举"""
    CREATED = "created"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    PAUSED = "paused"
    CANCELLED = "cancelled"


class Task(BaseModel):
    """任务模型"""
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    name: str
    description: Optional[str] = None
    function: str  # 函数路径,例如 "mymodule.myfunction"
    parameters: Dict[str, Any] = Field(default_factory=dict)
    retry_count: int = 0
    max_retries: int = 3
    timeout: Optional[int] = None  # 超时时间(秒)
    
    def __hash__(self):
        return hash(self.id)
    
    def __eq__(self, other):
        if isinstance(other, Task):
            return self.id == other.id
        return False


class Dependency(BaseModel):
    """依赖关系模型"""
    source_task_id: str  # 前置任务ID
    target_task_id: str  # 后续任务ID
    condition: Optional[str] = None  # 可选的依赖条件
    
    def __hash__(self):
        return hash((self.source_task_id, self.target_task_id))
    
    def __eq__(self, other):
        if isinstance(other, Dependency):
            return self.source_task_id == other.source_task_id and self.target_task_id == other.target_task_id
        return False


class TaskInstance(BaseModel):
    """任务实例模型"""
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    task_id: str
    workflow_id: str
    status: TaskStatus = TaskStatus.PENDING
    result: Optional[Any] = None
    error: Optional[str] = None
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None
    retry_count: int = 0
    input_data: Dict[str, Any] = Field(default_factory=dict)
    
    def duration(self) -> Optional[float]:
        """计算任务执行时长(秒)"""
        if self.start_time and self.end_time:
            return (self.end_time - self.start_time).total_seconds()
        return None


class Checkpoint(BaseModel):
    """检查点模型"""
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    workflow_id: str
    created_at: datetime = Field(default_factory=datetime.now)
    task_instances: Dict[str, TaskInstance] = Field(default_factory=dict)
    workflow_status: WorkflowStatus
    metadata: Dict[str, Any] = Field(default_factory=dict)


class Workflow(BaseModel):
    """工作流模型"""
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    name: str
    description: Optional[str] = None
    tasks: Dict[str, Task] = Field(default_factory=dict)
    dependencies: Set[Dependency] = Field(default_factory=set)
    status: WorkflowStatus = WorkflowStatus.CREATED
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)
    current_checkpoint_id: Optional[str] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)
    
    @field_validator('dependencies')
    @classmethod
    def validate_dependencies(cls, dependencies, info):
        """验证依赖关系是否引用了存在的任务"""
        task_ids = info.data.get('tasks', {}).keys()
        for dep in dependencies:
            if dep.source_task_id not in task_ids:
                raise ValueError(f"Dependency source task {dep.source_task_id} not found in tasks")
            if dep.target_task_id not in task_ids:
                raise ValueError(f"Dependency target task {dep.target_task_id} not found in tasks")
        return dependencies
    
    def add_task(self, task: Task) -> None:
        """添加任务到工作流"""
        self.tasks[task.id] = task
        self.updated_at = datetime.now()
    
    def add_dependency(self, source_task_id: str, target_task_id: str, condition: Optional[str] = None) -> None:
        """添加任务依赖关系"""
        if source_task_id not in self.tasks:
            raise ValueError(f"Source task {source_task_id} not found")
        if target_task_id not in self.tasks:
            raise ValueError(f"Target task {target_task_id} not found")
        
        dependency = Dependency(
            source_task_id=source_task_id,
            target_task_id=target_task_id,
            condition=condition
        )
        self.dependencies.add(dependency)
        self.updated_at = datetime.now()
    
    def get_task_dependencies(self, task_id: str) -> Set[str]:
        """获取指定任务的所有依赖任务ID"""
        return {
            dep.source_task_id 
            for dep in self.dependencies 
            if dep.target_task_id == task_id
        }
    
    def get_task_dependents(self, task_id: str) -> Set[str]:
        """获取依赖于指定任务的所有任务ID"""
        return {
            dep.target_task_id 
            for dep in self.dependencies 
            if dep.source_task_id == task_id
        }

这些模型定义了我们系统的核心数据结构。注意我们使用了Pydantic的验证功能,确保依赖关系引用的任务确实存在。

第二步:实现状态存储

接下来,让我们实现状态存储组件。这个组件负责持久化工作流、任务实例和检查点。为了简单起见,我们先实现一个基于内存的版本,然后再讨论如何扩展到持久化存储。

创建一个state_store.py文件:

import json
import os
from datetime import datetime
from typing import Dict, List, Optional
from models import Workflow, TaskInstance, Checkpoint, WorkflowStatus, TaskStatus


class StateStore:
    """状态存储抽象基类"""
    
    def save_workflow(self, workflow: Workflow) -> None:
        """保存工作流"""
        raise NotImplementedError
    
    def get_workflow(self, workflow_id: str) -> Optional[Workflow]:
        """获取工作流"""
        raise NotImplementedError
    
    def list_workflows(self) -> List[Workflow]:
        """列出所有工作流"""
        raise NotImplementedError
    
    def save_task_instance(self, task_instance: TaskInstance) -> None:
        """保存任务实例"""
        raise NotImplementedError
    
    def get_task_instance(self, task_instance_id: str) -> Optional[TaskInstance]:
        """获取任务实例"""
        raise NotImplementedError
    
    def get_task_instances_by_workflow(self, workflow_id: str) -> List[TaskInstance]:
        """获取工作流的所有任务实例"""
        raise NotImplementedError
    
    def save_checkpoint(self, checkpoint: Checkpoint) -> None:
        """保存检查点"""
        raise NotImplementedError
    
    def get_checkpoint(self, checkpoint_id: str) -> Optional[Checkpoint]:
        """获取检查点"""
        raise NotImplementedError
    
    def get_latest_checkpoint(self, workflow_id: str) -> Optional[Checkpoint]:
        """获取工作流的最新检查点"""
        raise NotImplementedError
    
    def list_checkpoints(self, workflow_id: str) -> List[Checkpoint]:
        """列出工作流的所有检查点"""
        raise NotImplementedError


class InMemoryStateStore(StateStore):
    """基于内存的状态存储实现"""
    
    def __init__(self):
        self._workflows: Dict[str, Workflow] = {}
        self._task_instances: Dict[str, TaskInstance] = {}
        self._checkpoints: Dict[str, Checkpoint] = {}
        self._workflow_checkpoints: Dict[str, List[str]] = {}  # workflow_id -> list of checkpoint_ids
    
    def save_workflow(self, workflow: Workflow) -> None:
        workflow.updated_at = datetime.now()
        self._workflows[workflow.id] = workflow.model_copy(deep=True)
    
    def get_workflow(self, workflow_id: str) -> Optional[Workflow]:
        if workflow_id in self._workflows:
            return self._workflows[workflow_id].model_copy(deep=True)
        return None
    
    def list_workflows(self) -> List[Workflow]:
        return [workflow.model_copy(deep=True) for workflow in self._workflows.values()]
    
    def save_task_instance(self, task_instance: TaskInstance) -> None:
        self._task_instances[task_instance.id] = task_instance.model_copy(deep=True)
    
    def get_task_instance(self, task_instance_id: str) -> Optional[TaskInstance]:
        if task_instance_id in self._task_instances:
            return self._task_instances[task_instance_id].model_copy(deep=True)
        return None
    
    def get_task_instances_by_workflow(self, workflow_id: str) -> List[TaskInstance]:
        return [
            instance.model_copy(deep=True)
            for instance in self._task_instances.values()
            if instance.workflow_id == workflow_id
        ]
    
    def save_checkpoint(self, checkpoint: Checkpoint) -> None:
        self._checkpoints[checkpoint.id] = checkpoint.model_copy(deep=True)
        
        if checkpoint.workflow_id not in self._workflow_checkpoints:
            self._workflow_checkpoints[checkpoint.workflow_id] = []
        self._workflow_checkpoints[checkpoint.workflow_id].append(checkpoint.id)
    
    def get_checkpoint(self, checkpoint_id: str) -> Optional[Checkpoint]:
        if checkpoint_id in self._checkpoints:
            return self._checkpoints[checkpoint_id].model_copy(deep=True)
        return None
    
    def get_latest_checkpoint(self, workflow_id: str) -> Optional[Checkpoint]:
        if workflow_id not in self._workflow_checkpoints:
            return None
        
        checkpoint_ids = self._workflow_checkpoints[workflow_id]
        if not checkpoint_ids:
            return None
        
        # 假设checkpoint_ids是按时间顺序添加的,取最后一个
        latest_checkpoint_id = checkpoint_ids[-1]
        return self.get_checkpoint(latest_checkpoint_id)
    
    def list_checkpoints(self, workflow_id: str) -> List[Checkpoint]:
        if workflow_id not in self._workflow_checkpoints:
            return []
        
        return [
            self.get_checkpoint(checkpoint_id)
            for checkpoint_id in self._workflow_checkpoints[workflow_id]
            if self.get_checkpoint(checkpoint_id) is not None
        ]


class FileStateStore(InMemoryStateStore):
    """基于文件的状态存储实现"""
    
    def __init__(self, storage_dir: str = "./workflow_data"):
        super().__init__()
        self._storage_dir = storage_dir
        self._workflows_dir = os.path.join(storage_dir, "workflows")
        self._task_instances_dir = os.path.join(storage_dir, "task_instances")
        self._checkpoints_dir = os.path.join(storage_dir, "checkpoints")
        
        # 创建存储目录
        for dir_path in [self._workflows_dir, self._task_instances_dir, self._checkpoints_dir]:
            os.makedirs(dir_path, exist_ok=True)
        
        # 从文件加载数据
        self._load_from_files()
    
    def _load_from_files(self) -> None:
        """从文件加载数据到内存"""
        # 加载工作流
        for filename in os.listdir(self._workflows_dir):
            if filename.endswith(".json"):
                file_path = os.path.join(self._workflows_dir, filename)
                with open(file_path, "r") as f:
                    workflow_data = json.load(f)
                    workflow = Workflow.model_validate(workflow_data)
                    self._workflows[workflow.id] = workflow
        
        # 加载任务实例
        for filename in os.listdir(self._task_instances_dir):
            if filename.endswith(".json"):
                file_path = os.path.join(self._task_instances_dir, filename)
                with open(file_path, "r") as f:
                    instance_data = json.load(f)
                    instance = TaskInstance.model_validate(instance_data)
                    self._task_instances[instance.id] = instance
        
        # 加载检查点
        for filename in os.listdir(self._checkpoints_dir):
            if filename.endswith(".json"):
                file_path = os.path.join(self._checkpoints_dir, filename)
                with open(file_path, "r") as f:
                    checkpoint_data = json.load(f)
                    checkpoint = Checkpoint.model_validate(checkpoint_data)
                    self._checkpoints[checkpoint.id] = checkpoint
                    
                    if checkpoint.workflow_id not in self._workflow_checkpoints:
                        self._workflow_checkpoints[checkpoint.workflow_id] = []
                    self._workflow_checkpoints[checkpoint.workflow_id].append(checkpoint.id)
    
    def _save_to_file(self, obj, directory: str, obj_id: str) -> None:
        """保存对象到文件"""
        file_path = os.path.join(directory, f"{obj_id}.json")
        with open(file_path, "w") as f:
            # 使用model_dump()而不是dict()
            json.dump(obj.model_dump(), f, default=str)
    
    def save_workflow(self, workflow: Workflow) -> None:
        super().save_workflow(workflow)
        self._save_to_file(workflow, self._workflows_dir, workflow.id)
    
    def save_task_instance(self, task_instance: TaskInstance) -> None:
        super().save_task_instance(task_instance)
        self._save_to_file(task_instance, self._task_instances_dir, task_instance.id)
    
    def save_checkpoint(self, checkpoint: Checkpoint) -> None:
        super().save_checkpoint(checkpoint)
        self._save_to_file(checkpoint, self._checkpoints_dir, checkpoint.id)

这里我们实现了两个状态存储:

  1. InMemoryStateStore:完全基于内存,简单但不持久
  2. FileStateStore:继承自InMemoryStateStore,增加了文件持久化功能

在实际生产环境中,你可能会使用数据库(如PostgreSQL、MongoDB)或分布式存储系统(如Redis、Etcd)来实现状态存储。

第三步:实现执行器

执行器负责实际执行任务。它需要能够动态加载函数并执行它们。

创建一个executor.py文件:

import importlib
import time
import traceback
from datetime import datetime
from typing import Any, Callable, Dict, Optional
from models import Task, TaskInstance, TaskStatus


class TaskExecutor:
    """任务执行器"""
    
    def __init__(self):
        self._function_cache: Dict[str, Callable] = {}
    
    def _load_function(self, function_path: str) -> Callable:
        """动态加载函数"""
        if function_path in self._function_cache:
            return self._function_cache[function_path]
        
        try:
            module_name, function_name = function_path.rsplit('.', 1)
            module = importlib.import_module(module_name)
            function = getattr(module, function_name)
            
            if not callable(function):
                raise ValueError(f"{function_path} is not callable")
            
            self._function_cache[function_path] = function
            return function
        except (ImportError, AttributeError, ValueError) as e:
            raise RuntimeError(f"Failed to load function {function_path}: {str(e)}")
    
    def execute_task(
        self,
        task: Task,
        task_instance: TaskInstance,
        input_data: Dict[str, Any],
        timeout: Optional[int] = None
    ) -> TaskInstance:
        """执行任务"""
        # 更新任务实例状态为运行中
        task_instance.status = TaskStatus.RUNNING
        task_instance.start_time = datetime.now()
        task_instance.input_data = input_data
        
        try:
            # 加载并执行函数
            function = self._load_function(task.function)
            
            # 合并任务参数和输入数据
            kwargs = {**task.parameters, **input_data}
            
            # 执行函数(这里简化处理,没有实现真正的超时机制)
            # 实际生产中可能需要使用多进程或线程来实现超时
            start_time = time.time()
            result = function(**kwargs)
            execution_time = time.time() - start_time
            
            # 更新任务实例为成功状态
            task_instance.status = TaskStatus.COMPLETED
            task_instance.result = result
            task_instance.end_time = datetime.now()
            
        except Exception as e:
            # 更新任务实例为失败状态
            task_instance.status = TaskStatus.FAILED
            task_instance.error = f"{str(e)}\n{traceback.format_exc()}"
            task_instance.end_time = datetime.now()
            
            # 如果还有重试次数,可以考虑标记为重试
            task_instance.retry_count += 1
        
        return task_instance

这个执行器实现了基本的任务执行功能,包括动态加载函数、执行函数、处理成功和失败情况。注意我们这里简化了超时机制的实现,实际生产中可能需要更健壮的实现。

第四步:实现调度器

调度器是整个系统的核心,它负责解析工作流定义、管理任务依赖关系、调度任务执行,以及创建检查点。

创建一个scheduler.py文件:

import networkx as nx
from datetime import datetime
from typing import Dict, List, Optional, Set
from concurrent.futures import ThreadPoolExecutor, as_completed

from models import (
    Workflow, Task, TaskInstance, Checkpoint,
    WorkflowStatus, TaskStatus, Dependency
)
from state_store import StateStore
from executor import TaskExecutor


class WorkflowScheduler:
    """工作流调度器"""
    
    def __init__(self, state_store: StateStore, max_workers: int = 4):
        self._state_store = state_store
        self._executor = TaskExecutor()
        self._thread_pool = ThreadPoolExecutor(max_workers=max_workers)
    
    def _build_dag(self, workflow: Workflow) -> nx.DiGraph:
        """从工作流构建DAG"""
        graph = nx.DiGraph()
        
        # 添加节点
        for task_id in workflow.tasks:
            graph.add_node(task_id)
        
        # 添加边
        for dep in workflow.dependencies:
            graph.add_edge(dep.source_task_id, dep.target_task_id)
        
        # 验证是否为DAG
        if not nx.is_directed_acyclic_graph(graph):
            raise ValueError("Workflow is not a DAG")
        
        return graph
    
    def _create_task_instances(self, workflow: Workflow) -> Dict[str, TaskInstance]:
        """为工作流中的所有任务创建实例"""
        task_instances = {}
        
        for task_id, task in workflow.tasks.items():
            instance = TaskInstance(
                task_id=task_id,
                workflow_id=workflow.id,
                status=TaskStatus.PENDING
            )
            task_instances[task_id] = instance
            self._state_store.save_task_instance(instance)
        
        return task_instances
    
    def _get_ready_tasks(
        self,
        graph: nx.DiGraph,
        task_instances: Dict[str, TaskInstance]
    ) -> List[str]:
        """获取所有准备好执行的任务ID"""
        ready_tasks = []
        
        for task_id in graph.nodes:
            instance = task_instances.get(task_id)
            if not instance:
                continue
            
            # 只处理待处理或准备好的任务
            if instance.status not in [TaskStatus.PENDING, TaskStatus.READY]:
                continue
            
            # 检查所有依赖是否已完成
            dependencies = list(graph.predecessors(task_id))
            all_dependencies_completed = True
            
            for dep_id in dependencies:
                dep_instance = task_instances.get(dep_id)
                if not dep_instance or dep_instance.status != TaskStatus.COMPLETED:
                    all_dependencies_completed = False
                    break
            
            if all_dependencies_completed:
                # 检查是否有依赖失败导致需要跳过
                should_skip = False
                for dep_id in dependencies:
                    dep_instance = task_instances.get(dep_id)
                    if dep_instance and dep_instance.status == TaskStatus.FAILED:
                        should_skip = True
                        break
                
                if should_skip:
                    instance.status = TaskStatus.SKIPPED
                    self._state_store.save_task_instance(instance)
                else:
                    instance.status = TaskStatus.READY
                    ready_tasks.append(task_id)
                    self._state_store.save_task_instance(instance)
        
        return ready_tasks
    
    def _create_checkpoint(
        self,
        workflow: Workflow,
        task_instances: Dict[str, TaskInstance],
        metadata: Optional[Dict] = None
    ) -> Checkpoint:
        """创建检查点"""
        checkpoint = Checkpoint(
            workflow_id=workflow.id,
            task_instances={
                task_id: instance.model_copy(deep=True)
                for task_id, instance in task_instances.items()
            },
            workflow_status=workflow.status,
            metadata=metadata or {}
        )
        
        self._state_store.save_checkpoint(checkpoint)
        workflow.current_checkpoint_id = checkpoint.id
        self._state_store.save_workflow(workflow)
        
        return checkpoint
    
    def _collect_input_data(
        self,
        task_id: str,
        graph: nx.DiGraph,
        task_instances: Dict[str, TaskInstance]
    ) -> Dict:
        """收集任务的输入数据(来自依赖任务的输出)"""
        input_data = {}
        
        for dep_id in graph.predecessors(task_id):
            dep_instance = task_instances.get(dep_id)
            if dep_instance and dep_instance.status == TaskStatus.COMPLETED:
                # 这里简单地将所有依赖的结果合并到输入数据中
                # 实际应用中可能需要更复杂的数据传递机制
                if isinstance(dep_instance.result, dict):
                    input_data.update(dep_instance.result)
                else:
                    input_data[f"{dep_id}_result"] = dep_instance.result
        
        return input_data
    
    def execute_workflow(
        self,
        workflow_id: str,
        checkpoint_id: Optional[str] = None
    ) -> Workflow:
        """执行工作流"""
        # 获取工作流
        workflow = self._state_store.get_workflow(workflow_id)
        if not workflow:
            raise ValueError(f"Workflow {workflow_id} not found")
        
        # 构建DAG
        graph = self._build_dag(workflow)
        
        # 加载或创建任务实例
        if checkpoint_id:
            # 从检查点恢复
            checkpoint = self._state_store.get_checkpoint(checkpoint_id)
            if not checkpoint or checkpoint.workflow_id != workflow_id:
                raise ValueError(f"Checkpoint {checkpoint_id} not found or not for this workflow")
            
            task_instances = checkpoint.task_instances
            workflow.status = checkpoint.workflow_status
            
            # 保存恢复的任务实例
            for instance in task_instances.values():
                self._state_store.save_task_instance(instance)
        else:
            # 创建新的任务实例
            task_instances = self._create_task_instances(workflow)
        
        # 更新工作流状态为运行中
        workflow.status = WorkflowStatus.RUNNING
        self._state_store.save_workflow(workflow)
        
        # 创建初始检查点
        self._create_checkpoint(workflow, task_instances, {"type": "start"})
        
        try:
            # 执行循环
            while True:
                # 检查工作流是否已完成或失败
                completed_count = sum(
                    1 for instance in task_instances.values()
                    if instance.status == TaskStatus.COMPLETED
                )
                failed_count = sum(
                    1 for instance in task_instances.values()
                    if instance.status == TaskStatus.FAILED
                )
                skipped_count = sum(
                    1 for instance in task_instances.values()
                    if instance.status == TaskStatus.SKIPPED
                )
                
                total_tasks = len(task_instances)
                
                if completed_count + failed_count + skipped_count == total_tasks:
                    # 所有任务都已处理完毕
                    if failed_count > 0:
                        workflow.status = WorkflowStatus.FAILED
                    else:
                        workflow.status = WorkflowStatus.COMPLETED
                    break
                
                # 获取准备好执行的任务
                ready_task_ids = self._get_ready_tasks(graph, task_instances)
                
                if not ready_task_ids:
                    # 没有准备好的任务,但也没有完成所有任务,可能是死锁或依赖问题
                    # 这里简化处理,标记为失败
                    workflow.status = WorkflowStatus.FAILED
                    break
                
                # 并行执行准备好的任务
                futures = {}
                
                for task_id in ready_task_ids:
                    task = workflow.tasks[task_id]
                    instance = task_instances[task_id]
                    input_data = self._collect_input_data(task_id, graph, task_instances)
                    
                    # 提交任务到线程池
                    future = self._thread_pool.submit(
                        self._executor.execute_task,
                        task,
                        instance,
                        input_data,
                        task.timeout
                    )
                    futures[future] = task_id
                
                # 处理执行结果
                for future in as_completed(futures):
                    task_id = futures[future]
                    try:
                        updated_instance = future.result()
                        task_instances[task_id] = updated_instance
                        self._state_store.save_task_instance(updated_instance)
                    except Exception as e:
                        # 处理执行过程中的异常
                        instance = task_instances[task_id]
                        instance.status = TaskStatus.FAILED
                        instance.error = str(e)
                        instance.end_time = datetime.now()
                        self._state_store.save_task_instance(instance)
                
                # 创建检查点
                self._create_checkpoint(workflow, task_instances, {"type": "progress"})
        
        except Exception as e:
            # 处理工作流执行过程中的异常
            workflow.status = WorkflowStatus.FAILED
            # 这里可以记录错误信息到workflow.metadata
            workflow.metadata["error"] = str(e)
        
        finally:
            # 保存最终状态
            self._state_store.save_workflow(workflow)
            # 创建最终检查点
            self._create_checkpoint(workflow, task_instances, {"type": "end"})
        
        return workflow
    
    def resume_workflow(self, workflow_id: str) -> Workflow:
        """从最新检查点恢复工作流"""
        workflow = self._state_store.get_workflow(workflow_id)
        if not workflow:
            raise ValueError(f"Workflow {workflow_id} not found")
        
        latest_checkpoint = self._state_store.get_latest_checkpoint(workflow_id)
        if not latest_checkpoint:
            raise ValueError(f"No checkpoints found for workflow {workflow_id}")
        
        return self.execute_workflow(workflow_id, latest_checkpoint.id)

这个调度器实现了工作流执行的核心逻辑:

  1. 构建和验证DAG
  2. 管理任务实例的生命周期
  3. 确定哪些任务准备好执行
  4. 并行执行任务
  5. 创建检查点用于恢复
  6. 从检查点恢复工作流执行
第五步:构建工作流API

现在让我们创建一个简单的API层,使用户能够方便地定义和执行工作流。

创建一个workflow_api.py文件:

from typing import Any, Callable, Dict, List, Optional
from models import Workflow, Task, WorkflowStatus
from state_store import StateStore, FileStateStore
from scheduler import WorkflowScheduler


class WorkflowAPI:
    """工作流API"""
    
    def __init__(self, state_store: Optional[StateStore] = None, max_workers: int = 4):
        self._state_store = state_store or FileStateStore()
        self._scheduler = WorkflowScheduler(self._state_store, max_workers)
    
    def create_workflow(self, name: str, description: Optional[str] = None) -> Workflow:
        """创建一个新的工作流"""
        workflow = Workflow(name=name, description=description)
        self._state_store.save_workflow(workflow)
        return workflow
    
    def add_task(
        self,
        workflow_id: str,
        name: str,
        function: str,
        parameters: Optional[Dict[str, Any]] = None,
        description: Optional[str] = None,
        max_retries: int = 3,
        timeout: Optional[int] = None
    ) -> Task:
        """向工作流添加任务"""
        workflow = self._state_store.get_workflow(workflow_id)
        if not workflow:
            raise ValueError(f"Workflow {workflow_id} not found")
        
        task = Task(
            name=name,
            description=description,
            function=function,
            parameters=parameters or {},
            max_retries=max_retries,
            timeout=timeout
        )
        
        workflow.add_task(task)
        self._state_store.save_workflow(workflow)
        return task
    
    def add_dependency(
        self,
        workflow_id: str,
        source_task_id: str,
        target_task_id: str,
        condition: Optional[str] = None
    ) -> None:
        """添加任务依赖关系"""
        workflow = self._state_store.get_workflow(workflow_id)
        if not workflow:
            raise ValueError(f"Workflow {workflow_id} not found")
        
        workflow.add_dependency(source_task_id, target_task_id, condition)
        self._state_store.save_workflow(workflow)
    
    def get_workflow(self, workflow_id: str) -> Optional[Workflow]:
        """获取工作流"""
        return self._state_store.get_workflow(workflow_id)
    
    def list_workflows(self) -> List[Workflow]:
        """列出所有工作流"""
        return self._state_store.list_workflows()
    
    def execute_workflow(self, workflow_id: str, checkpoint_id: Optional[str] = None) -> Workflow:
        """执行工作流"""
        return self._scheduler.execute_workflow(workflow_id, checkpoint_id)
    
    def resume_workflow(self, workflow_id: str) -> Workflow:
        """从最新检查点恢复工作流"""
        return self._scheduler.resume_workflow(workflow_id)
    
    def get_workflow_status(self, workflow_id: str) -> Optional[WorkflowStatus]:
        """获取工作流状态"""
        workflow = self._state_store.get_workflow(workflow_id)
        if workflow:
            return workflow.status
        return None
Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐