Agent工作流Harness:DAG编排与断点恢复
Agent工作流Harness:DAG编排与断点恢复
一、 引言
钩子:从一个常见的自动化困境开始
你是否曾经遇到过这样的场景:你精心设计了一个自动化流程,用于处理一批数据生成报告,流程包含了数据提取、清洗、转换、分析和可视化等多个步骤。一切看起来都很完美,直到……在流程执行到第8个步骤时,因为网络超时导致了一个API调用失败。整个流程戛然而止,之前7个步骤所做的工作全部白费,你不得不从头开始重新运行整个流程。更糟糕的是,你甚至不知道在失败之前,哪些数据已经处理完毕,哪些还没有。
这种场景在数据工程、机器学习流水线、DevOps自动化以及各种复杂的业务流程中太常见了。当我们处理的任务变得越来越复杂,涉及的步骤越来越多,依赖关系越来越错综,简单的线性脚本执行方式就显得力不从心了。我们需要一种更强大、更可靠的方式来组织和执行这些复杂的工作流。
定义问题/阐述背景
在当今的软件开发和数据处理领域,工作流编排已经成为一个核心挑战。特别是随着AI Agent概念的兴起,我们需要协调多个智能体协同工作,每个Agent可能有自己的输入输出、执行逻辑和依赖关系。这就对工作流系统提出了更高的要求:
- 复杂依赖管理:任务之间可能存在复杂的依赖关系,不是简单的线性执行,而是形成有向无环图(DAG)结构。
- 容错与恢复:当某个任务失败时,我们希望能够从失败点恢复,而不是重新执行整个流程。
- 可观测性:我们需要能够监控工作流的执行状态,了解每个任务的执行情况。
- 并行执行:对于没有依赖关系的任务,我们希望能够并行执行以提高效率。
- 动态调整:在某些情况下,我们可能需要根据前序任务的执行结果动态调整后续任务。
正是在这样的背景下,Agent工作流Harness应运而生。Harness在这里指的是一个框架或平台,它提供了一套工具和抽象,帮助我们定义、编排和执行复杂的Agent工作流,特别是支持DAG结构的编排和强大的断点恢复能力。
亮明观点/文章目标
在本文中,我们将深入探讨Agent工作流Harness的核心概念、设计原理和实现方法。具体来说,我们将:
- 介绍DAG编排的基本概念和重要性
- 详细分析断点恢复机制的实现原理
- 通过一个实战项目,从零开始构建一个简单但功能完整的工作流Harness
- 探讨如何在实际场景中应用这些概念
- 分享一些最佳实践和常见陷阱
通过阅读本文,你将不仅理解Agent工作流Harness的理论基础,还将掌握实际构建和使用这样一个系统的技能。无论你是在构建数据管道、机器学习流水线,还是复杂的业务流程自动化系统,本文的内容都将对你有所帮助。
二、 基础知识/背景铺垫
核心概念定义
在深入探讨Agent工作流Harness之前,我们需要先明确一些核心概念。
什么是Agent?
在本文的上下文中,Agent指的是一个封装了特定功能的执行单元。它可以是一个简单的函数,也可以是一个复杂的服务,甚至可以是一个AI模型。Agent具有以下特点:
- 独立性:Agent应该是相对独立的,能够自主完成特定任务。
- 输入输出明确:Agent应该有明确的输入和输出定义,便于与其他Agent协作。
- 可组合性:多个Agent可以组合在一起,形成更复杂的系统。
什么是工作流(Workflow)?
工作流是一系列相互关联的任务的集合,这些任务按照一定的规则和顺序执行,以达到特定的目标。工作流定义了任务之间的依赖关系、执行顺序和数据流向。
什么是DAG?
DAG是有向无环图(Directed Acyclic Graph)的缩写。它是一种特殊的图结构,具有以下特点:
- 有向:图中的边有方向,表示从一个节点指向另一个节点。
- 无环:图中没有循环,即不能从一个节点出发,沿着边的方向又回到该节点。
在工作流编排中,DAG是一种非常自然的表示方式:
- 节点表示任务(或Agent)
- 边表示任务之间的依赖关系
- 无环保证了工作流不会陷入无限循环
什么是编排(Orchestration)?
编排指的是对工作流中的任务进行协调、调度和管理的过程。编排系统负责:
- 解析工作流定义
- 管理任务之间的依赖关系
- 调度任务的执行
- 处理任务失败和重试
- 收集和报告执行状态
什么是断点恢复(Checkpointing and Recovery)?
断点恢复是一种机制,允许工作流在失败后从失败点继续执行,而不是从头开始。这需要:
- 在执行过程中定期保存状态(检查点)
- 能够从保存的状态中恢复执行环境
- 能够确定哪些任务已经完成,哪些需要重新执行
相关工具/技术概览
在Agent工作流编排领域,已经有一些成熟的工具和技术。让我们简要了解一下:
开源工作流编排工具
-
Apache Airflow:这是最流行的开源工作流编排工具之一,使用Python编写。它使用DAG来定义工作流,提供了丰富的操作符和强大的调度能力。
-
Prefect:一个相对较新的工作流编排工具,旨在解决Airflow的一些痛点。它提供了更现代的API和更好的开发体验。
-
Argo Workflows:专为Kubernetes设计的工作流编排工具,使用CRD(自定义资源定义)来定义工作流。
-
Temporal:一个专注于可靠性和可扩展性的工作流编排平台,提供了强大的容错和恢复能力。
AI Agent框架
-
LangChain:一个用于构建由语言模型驱动的应用程序的框架,支持链式调用和Agent工作流。
-
AutoGPT:一个自主AI Agent项目,可以自主设定目标并执行任务。
-
CrewAI:一个用于协调多个AI Agent协作的框架,支持角色定义和任务分配。
虽然这些工具都很强大,但它们往往比较复杂,有一定的学习曲线。在本文中,我们将构建一个简化版的工作流Harness,帮助你理解这些工具背后的核心原理。
DAG编排的数学基础
在深入实现之前,让我们先了解一些DAG编排的数学基础。
拓扑排序
拓扑排序是对DAG的顶点进行排序的一种算法,使得对于每一条有向边(u, v),顶点u在排序中都出现在顶点v的前面。拓扑排序是DAG编排的基础,因为它决定了任务的执行顺序。
Kahn算法是一种常用的拓扑排序算法,其基本思想是:
- 计算每个节点的入度(指向该节点的边的数量)
- 将所有入度为0的节点加入队列
- 从队列中取出一个节点,将其加入拓扑排序结果
- 将该节点的所有邻居的入度减1
- 如果某个邻居的入度变为0,将其加入队列
- 重复步骤3-5,直到队列为空
如果最终拓扑排序结果的节点数量小于图中节点的总数,说明图中存在环。
依赖管理
在DAG中,每个节点的执行依赖于其所有前驱节点的完成。我们可以用以下数学表达式来表示这种依赖关系:
设T={t1,t2,...,tn}T = \{t_1, t_2, ..., t_n\}T={t1,t2,...,tn}是工作流中的所有任务集合,D⊆T×TD \subseteq T \times TD⊆T×T是任务之间的依赖关系集合,其中(ti,tj)∈D(t_i, t_j) \in D(ti,tj)∈D表示任务tjt_jtj依赖于任务tit_iti的完成。
对于任意任务tjt_jtj,其前驱任务集合为pred(tj)={ti∣(ti,tj)∈D}pred(t_j) = \{t_i | (t_i, t_j) \in D\}pred(tj)={ti∣(ti,tj)∈D}。
任务tjt_jtj可以开始执行当且仅当其所有前驱任务都已成功完成:
executable(tj) ⟺ ∀ti∈pred(tj),completed(ti)=trueexecutable(t_j) \iff \forall t_i \in pred(t_j), completed(t_i) = trueexecutable(tj)⟺∀ti∈pred(tj),completed(ti)=true
三、 核心内容/实战演练
在这一部分,我们将从零开始构建一个简单但功能完整的Agent工作流Harness。我们将使用Python作为编程语言,因为它在数据处理和自动化领域非常流行,且有丰富的生态系统。
项目概述
我们的工作流Harness将具备以下核心功能:
- 支持以DAG形式定义工作流
- 支持任务并行执行
- 实现检查点机制,支持断点恢复
- 提供基本的监控和日志功能
环境设置
首先,让我们设置项目环境。我们将使用Python 3.9+,并依赖以下库:
networkx:用于处理图结构pydantic:用于数据验证python-dotenv:用于环境变量管理uuid:用于生成唯一标识符(Python标准库)json:用于序列化(Python标准库)datetime:用于处理时间(Python标准库)concurrent.futures:用于并行执行(Python标准库)
让我们创建一个requirements.txt文件:
networkx>=3.1
pydantic>=2.0
python-dotenv>=1.0
然后安装这些依赖:
pip install -r requirements.txt
系统设计
在开始编码之前,让我们先设计系统的核心组件和架构。
核心概念模型
我们的系统将包含以下核心概念:
- Task:工作流中的基本执行单元,对应DAG中的节点。
- Dependency:表示任务之间的依赖关系,对应DAG中的边。
- Workflow:由任务和依赖关系组成的DAG。
- TaskInstance:任务的具体执行实例,包含执行状态和结果。
- Checkpoint:工作流执行过程中的状态快照,用于断点恢复。
- Executor:负责实际执行任务的组件。
- Scheduler:负责根据依赖关系调度任务执行的组件。
- StateStore:负责存储工作流执行状态和检查点的组件。
让我们用ER图来表示这些概念之间的关系:
系统架构
我们的系统将采用分层架构,从上到下依次为:
- API层:提供定义和管理工作流的接口
- 调度层:负责解析工作流定义,调度任务执行
- 执行层:负责实际执行任务
- 存储层:负责持久化工作流状态和检查点
让我们用架构图来表示:
状态模型
任务和工作流都有其生命周期状态。让我们定义这些状态:
任务状态:
PENDING:任务已定义,但尚未准备好执行READY:任务的所有依赖已完成,可以开始执行RUNNING:任务正在执行COMPLETED:任务执行成功FAILED:任务执行失败SKIPPED:任务被跳过(通常是因为依赖任务失败)
工作流状态:
CREATED:工作流已创建,但尚未开始执行RUNNING:工作流正在执行COMPLETED:工作流中所有任务都已成功完成FAILED:工作流中有一个或多个任务失败PAUSED:工作流被暂停CANCELLED:工作流被取消
让我们用状态图来表示任务的状态转换:
核心实现
现在让我们开始实现这些核心组件。我们将按照从底层到上层的顺序进行。
第一步:定义数据模型
首先,让我们使用Pydantic定义核心数据模型。创建一个models.py文件:
import uuid
from enum import Enum
from datetime import datetime
from typing import Any, Dict, List, Optional, Set
from pydantic import BaseModel, Field, field_validator
class TaskStatus(str, Enum):
"""任务状态枚举"""
PENDING = "pending"
READY = "ready"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
SKIPPED = "skipped"
class WorkflowStatus(str, Enum):
"""工作流状态枚举"""
CREATED = "created"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
PAUSED = "paused"
CANCELLED = "cancelled"
class Task(BaseModel):
"""任务模型"""
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
name: str
description: Optional[str] = None
function: str # 函数路径,例如 "mymodule.myfunction"
parameters: Dict[str, Any] = Field(default_factory=dict)
retry_count: int = 0
max_retries: int = 3
timeout: Optional[int] = None # 超时时间(秒)
def __hash__(self):
return hash(self.id)
def __eq__(self, other):
if isinstance(other, Task):
return self.id == other.id
return False
class Dependency(BaseModel):
"""依赖关系模型"""
source_task_id: str # 前置任务ID
target_task_id: str # 后续任务ID
condition: Optional[str] = None # 可选的依赖条件
def __hash__(self):
return hash((self.source_task_id, self.target_task_id))
def __eq__(self, other):
if isinstance(other, Dependency):
return self.source_task_id == other.source_task_id and self.target_task_id == other.target_task_id
return False
class TaskInstance(BaseModel):
"""任务实例模型"""
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
task_id: str
workflow_id: str
status: TaskStatus = TaskStatus.PENDING
result: Optional[Any] = None
error: Optional[str] = None
start_time: Optional[datetime] = None
end_time: Optional[datetime] = None
retry_count: int = 0
input_data: Dict[str, Any] = Field(default_factory=dict)
def duration(self) -> Optional[float]:
"""计算任务执行时长(秒)"""
if self.start_time and self.end_time:
return (self.end_time - self.start_time).total_seconds()
return None
class Checkpoint(BaseModel):
"""检查点模型"""
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
workflow_id: str
created_at: datetime = Field(default_factory=datetime.now)
task_instances: Dict[str, TaskInstance] = Field(default_factory=dict)
workflow_status: WorkflowStatus
metadata: Dict[str, Any] = Field(default_factory=dict)
class Workflow(BaseModel):
"""工作流模型"""
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
name: str
description: Optional[str] = None
tasks: Dict[str, Task] = Field(default_factory=dict)
dependencies: Set[Dependency] = Field(default_factory=set)
status: WorkflowStatus = WorkflowStatus.CREATED
created_at: datetime = Field(default_factory=datetime.now)
updated_at: datetime = Field(default_factory=datetime.now)
current_checkpoint_id: Optional[str] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
@field_validator('dependencies')
@classmethod
def validate_dependencies(cls, dependencies, info):
"""验证依赖关系是否引用了存在的任务"""
task_ids = info.data.get('tasks', {}).keys()
for dep in dependencies:
if dep.source_task_id not in task_ids:
raise ValueError(f"Dependency source task {dep.source_task_id} not found in tasks")
if dep.target_task_id not in task_ids:
raise ValueError(f"Dependency target task {dep.target_task_id} not found in tasks")
return dependencies
def add_task(self, task: Task) -> None:
"""添加任务到工作流"""
self.tasks[task.id] = task
self.updated_at = datetime.now()
def add_dependency(self, source_task_id: str, target_task_id: str, condition: Optional[str] = None) -> None:
"""添加任务依赖关系"""
if source_task_id not in self.tasks:
raise ValueError(f"Source task {source_task_id} not found")
if target_task_id not in self.tasks:
raise ValueError(f"Target task {target_task_id} not found")
dependency = Dependency(
source_task_id=source_task_id,
target_task_id=target_task_id,
condition=condition
)
self.dependencies.add(dependency)
self.updated_at = datetime.now()
def get_task_dependencies(self, task_id: str) -> Set[str]:
"""获取指定任务的所有依赖任务ID"""
return {
dep.source_task_id
for dep in self.dependencies
if dep.target_task_id == task_id
}
def get_task_dependents(self, task_id: str) -> Set[str]:
"""获取依赖于指定任务的所有任务ID"""
return {
dep.target_task_id
for dep in self.dependencies
if dep.source_task_id == task_id
}
这些模型定义了我们系统的核心数据结构。注意我们使用了Pydantic的验证功能,确保依赖关系引用的任务确实存在。
第二步:实现状态存储
接下来,让我们实现状态存储组件。这个组件负责持久化工作流、任务实例和检查点。为了简单起见,我们先实现一个基于内存的版本,然后再讨论如何扩展到持久化存储。
创建一个state_store.py文件:
import json
import os
from datetime import datetime
from typing import Dict, List, Optional
from models import Workflow, TaskInstance, Checkpoint, WorkflowStatus, TaskStatus
class StateStore:
"""状态存储抽象基类"""
def save_workflow(self, workflow: Workflow) -> None:
"""保存工作流"""
raise NotImplementedError
def get_workflow(self, workflow_id: str) -> Optional[Workflow]:
"""获取工作流"""
raise NotImplementedError
def list_workflows(self) -> List[Workflow]:
"""列出所有工作流"""
raise NotImplementedError
def save_task_instance(self, task_instance: TaskInstance) -> None:
"""保存任务实例"""
raise NotImplementedError
def get_task_instance(self, task_instance_id: str) -> Optional[TaskInstance]:
"""获取任务实例"""
raise NotImplementedError
def get_task_instances_by_workflow(self, workflow_id: str) -> List[TaskInstance]:
"""获取工作流的所有任务实例"""
raise NotImplementedError
def save_checkpoint(self, checkpoint: Checkpoint) -> None:
"""保存检查点"""
raise NotImplementedError
def get_checkpoint(self, checkpoint_id: str) -> Optional[Checkpoint]:
"""获取检查点"""
raise NotImplementedError
def get_latest_checkpoint(self, workflow_id: str) -> Optional[Checkpoint]:
"""获取工作流的最新检查点"""
raise NotImplementedError
def list_checkpoints(self, workflow_id: str) -> List[Checkpoint]:
"""列出工作流的所有检查点"""
raise NotImplementedError
class InMemoryStateStore(StateStore):
"""基于内存的状态存储实现"""
def __init__(self):
self._workflows: Dict[str, Workflow] = {}
self._task_instances: Dict[str, TaskInstance] = {}
self._checkpoints: Dict[str, Checkpoint] = {}
self._workflow_checkpoints: Dict[str, List[str]] = {} # workflow_id -> list of checkpoint_ids
def save_workflow(self, workflow: Workflow) -> None:
workflow.updated_at = datetime.now()
self._workflows[workflow.id] = workflow.model_copy(deep=True)
def get_workflow(self, workflow_id: str) -> Optional[Workflow]:
if workflow_id in self._workflows:
return self._workflows[workflow_id].model_copy(deep=True)
return None
def list_workflows(self) -> List[Workflow]:
return [workflow.model_copy(deep=True) for workflow in self._workflows.values()]
def save_task_instance(self, task_instance: TaskInstance) -> None:
self._task_instances[task_instance.id] = task_instance.model_copy(deep=True)
def get_task_instance(self, task_instance_id: str) -> Optional[TaskInstance]:
if task_instance_id in self._task_instances:
return self._task_instances[task_instance_id].model_copy(deep=True)
return None
def get_task_instances_by_workflow(self, workflow_id: str) -> List[TaskInstance]:
return [
instance.model_copy(deep=True)
for instance in self._task_instances.values()
if instance.workflow_id == workflow_id
]
def save_checkpoint(self, checkpoint: Checkpoint) -> None:
self._checkpoints[checkpoint.id] = checkpoint.model_copy(deep=True)
if checkpoint.workflow_id not in self._workflow_checkpoints:
self._workflow_checkpoints[checkpoint.workflow_id] = []
self._workflow_checkpoints[checkpoint.workflow_id].append(checkpoint.id)
def get_checkpoint(self, checkpoint_id: str) -> Optional[Checkpoint]:
if checkpoint_id in self._checkpoints:
return self._checkpoints[checkpoint_id].model_copy(deep=True)
return None
def get_latest_checkpoint(self, workflow_id: str) -> Optional[Checkpoint]:
if workflow_id not in self._workflow_checkpoints:
return None
checkpoint_ids = self._workflow_checkpoints[workflow_id]
if not checkpoint_ids:
return None
# 假设checkpoint_ids是按时间顺序添加的,取最后一个
latest_checkpoint_id = checkpoint_ids[-1]
return self.get_checkpoint(latest_checkpoint_id)
def list_checkpoints(self, workflow_id: str) -> List[Checkpoint]:
if workflow_id not in self._workflow_checkpoints:
return []
return [
self.get_checkpoint(checkpoint_id)
for checkpoint_id in self._workflow_checkpoints[workflow_id]
if self.get_checkpoint(checkpoint_id) is not None
]
class FileStateStore(InMemoryStateStore):
"""基于文件的状态存储实现"""
def __init__(self, storage_dir: str = "./workflow_data"):
super().__init__()
self._storage_dir = storage_dir
self._workflows_dir = os.path.join(storage_dir, "workflows")
self._task_instances_dir = os.path.join(storage_dir, "task_instances")
self._checkpoints_dir = os.path.join(storage_dir, "checkpoints")
# 创建存储目录
for dir_path in [self._workflows_dir, self._task_instances_dir, self._checkpoints_dir]:
os.makedirs(dir_path, exist_ok=True)
# 从文件加载数据
self._load_from_files()
def _load_from_files(self) -> None:
"""从文件加载数据到内存"""
# 加载工作流
for filename in os.listdir(self._workflows_dir):
if filename.endswith(".json"):
file_path = os.path.join(self._workflows_dir, filename)
with open(file_path, "r") as f:
workflow_data = json.load(f)
workflow = Workflow.model_validate(workflow_data)
self._workflows[workflow.id] = workflow
# 加载任务实例
for filename in os.listdir(self._task_instances_dir):
if filename.endswith(".json"):
file_path = os.path.join(self._task_instances_dir, filename)
with open(file_path, "r") as f:
instance_data = json.load(f)
instance = TaskInstance.model_validate(instance_data)
self._task_instances[instance.id] = instance
# 加载检查点
for filename in os.listdir(self._checkpoints_dir):
if filename.endswith(".json"):
file_path = os.path.join(self._checkpoints_dir, filename)
with open(file_path, "r") as f:
checkpoint_data = json.load(f)
checkpoint = Checkpoint.model_validate(checkpoint_data)
self._checkpoints[checkpoint.id] = checkpoint
if checkpoint.workflow_id not in self._workflow_checkpoints:
self._workflow_checkpoints[checkpoint.workflow_id] = []
self._workflow_checkpoints[checkpoint.workflow_id].append(checkpoint.id)
def _save_to_file(self, obj, directory: str, obj_id: str) -> None:
"""保存对象到文件"""
file_path = os.path.join(directory, f"{obj_id}.json")
with open(file_path, "w") as f:
# 使用model_dump()而不是dict()
json.dump(obj.model_dump(), f, default=str)
def save_workflow(self, workflow: Workflow) -> None:
super().save_workflow(workflow)
self._save_to_file(workflow, self._workflows_dir, workflow.id)
def save_task_instance(self, task_instance: TaskInstance) -> None:
super().save_task_instance(task_instance)
self._save_to_file(task_instance, self._task_instances_dir, task_instance.id)
def save_checkpoint(self, checkpoint: Checkpoint) -> None:
super().save_checkpoint(checkpoint)
self._save_to_file(checkpoint, self._checkpoints_dir, checkpoint.id)
这里我们实现了两个状态存储:
InMemoryStateStore:完全基于内存,简单但不持久FileStateStore:继承自InMemoryStateStore,增加了文件持久化功能
在实际生产环境中,你可能会使用数据库(如PostgreSQL、MongoDB)或分布式存储系统(如Redis、Etcd)来实现状态存储。
第三步:实现执行器
执行器负责实际执行任务。它需要能够动态加载函数并执行它们。
创建一个executor.py文件:
import importlib
import time
import traceback
from datetime import datetime
from typing import Any, Callable, Dict, Optional
from models import Task, TaskInstance, TaskStatus
class TaskExecutor:
"""任务执行器"""
def __init__(self):
self._function_cache: Dict[str, Callable] = {}
def _load_function(self, function_path: str) -> Callable:
"""动态加载函数"""
if function_path in self._function_cache:
return self._function_cache[function_path]
try:
module_name, function_name = function_path.rsplit('.', 1)
module = importlib.import_module(module_name)
function = getattr(module, function_name)
if not callable(function):
raise ValueError(f"{function_path} is not callable")
self._function_cache[function_path] = function
return function
except (ImportError, AttributeError, ValueError) as e:
raise RuntimeError(f"Failed to load function {function_path}: {str(e)}")
def execute_task(
self,
task: Task,
task_instance: TaskInstance,
input_data: Dict[str, Any],
timeout: Optional[int] = None
) -> TaskInstance:
"""执行任务"""
# 更新任务实例状态为运行中
task_instance.status = TaskStatus.RUNNING
task_instance.start_time = datetime.now()
task_instance.input_data = input_data
try:
# 加载并执行函数
function = self._load_function(task.function)
# 合并任务参数和输入数据
kwargs = {**task.parameters, **input_data}
# 执行函数(这里简化处理,没有实现真正的超时机制)
# 实际生产中可能需要使用多进程或线程来实现超时
start_time = time.time()
result = function(**kwargs)
execution_time = time.time() - start_time
# 更新任务实例为成功状态
task_instance.status = TaskStatus.COMPLETED
task_instance.result = result
task_instance.end_time = datetime.now()
except Exception as e:
# 更新任务实例为失败状态
task_instance.status = TaskStatus.FAILED
task_instance.error = f"{str(e)}\n{traceback.format_exc()}"
task_instance.end_time = datetime.now()
# 如果还有重试次数,可以考虑标记为重试
task_instance.retry_count += 1
return task_instance
这个执行器实现了基本的任务执行功能,包括动态加载函数、执行函数、处理成功和失败情况。注意我们这里简化了超时机制的实现,实际生产中可能需要更健壮的实现。
第四步:实现调度器
调度器是整个系统的核心,它负责解析工作流定义、管理任务依赖关系、调度任务执行,以及创建检查点。
创建一个scheduler.py文件:
import networkx as nx
from datetime import datetime
from typing import Dict, List, Optional, Set
from concurrent.futures import ThreadPoolExecutor, as_completed
from models import (
Workflow, Task, TaskInstance, Checkpoint,
WorkflowStatus, TaskStatus, Dependency
)
from state_store import StateStore
from executor import TaskExecutor
class WorkflowScheduler:
"""工作流调度器"""
def __init__(self, state_store: StateStore, max_workers: int = 4):
self._state_store = state_store
self._executor = TaskExecutor()
self._thread_pool = ThreadPoolExecutor(max_workers=max_workers)
def _build_dag(self, workflow: Workflow) -> nx.DiGraph:
"""从工作流构建DAG"""
graph = nx.DiGraph()
# 添加节点
for task_id in workflow.tasks:
graph.add_node(task_id)
# 添加边
for dep in workflow.dependencies:
graph.add_edge(dep.source_task_id, dep.target_task_id)
# 验证是否为DAG
if not nx.is_directed_acyclic_graph(graph):
raise ValueError("Workflow is not a DAG")
return graph
def _create_task_instances(self, workflow: Workflow) -> Dict[str, TaskInstance]:
"""为工作流中的所有任务创建实例"""
task_instances = {}
for task_id, task in workflow.tasks.items():
instance = TaskInstance(
task_id=task_id,
workflow_id=workflow.id,
status=TaskStatus.PENDING
)
task_instances[task_id] = instance
self._state_store.save_task_instance(instance)
return task_instances
def _get_ready_tasks(
self,
graph: nx.DiGraph,
task_instances: Dict[str, TaskInstance]
) -> List[str]:
"""获取所有准备好执行的任务ID"""
ready_tasks = []
for task_id in graph.nodes:
instance = task_instances.get(task_id)
if not instance:
continue
# 只处理待处理或准备好的任务
if instance.status not in [TaskStatus.PENDING, TaskStatus.READY]:
continue
# 检查所有依赖是否已完成
dependencies = list(graph.predecessors(task_id))
all_dependencies_completed = True
for dep_id in dependencies:
dep_instance = task_instances.get(dep_id)
if not dep_instance or dep_instance.status != TaskStatus.COMPLETED:
all_dependencies_completed = False
break
if all_dependencies_completed:
# 检查是否有依赖失败导致需要跳过
should_skip = False
for dep_id in dependencies:
dep_instance = task_instances.get(dep_id)
if dep_instance and dep_instance.status == TaskStatus.FAILED:
should_skip = True
break
if should_skip:
instance.status = TaskStatus.SKIPPED
self._state_store.save_task_instance(instance)
else:
instance.status = TaskStatus.READY
ready_tasks.append(task_id)
self._state_store.save_task_instance(instance)
return ready_tasks
def _create_checkpoint(
self,
workflow: Workflow,
task_instances: Dict[str, TaskInstance],
metadata: Optional[Dict] = None
) -> Checkpoint:
"""创建检查点"""
checkpoint = Checkpoint(
workflow_id=workflow.id,
task_instances={
task_id: instance.model_copy(deep=True)
for task_id, instance in task_instances.items()
},
workflow_status=workflow.status,
metadata=metadata or {}
)
self._state_store.save_checkpoint(checkpoint)
workflow.current_checkpoint_id = checkpoint.id
self._state_store.save_workflow(workflow)
return checkpoint
def _collect_input_data(
self,
task_id: str,
graph: nx.DiGraph,
task_instances: Dict[str, TaskInstance]
) -> Dict:
"""收集任务的输入数据(来自依赖任务的输出)"""
input_data = {}
for dep_id in graph.predecessors(task_id):
dep_instance = task_instances.get(dep_id)
if dep_instance and dep_instance.status == TaskStatus.COMPLETED:
# 这里简单地将所有依赖的结果合并到输入数据中
# 实际应用中可能需要更复杂的数据传递机制
if isinstance(dep_instance.result, dict):
input_data.update(dep_instance.result)
else:
input_data[f"{dep_id}_result"] = dep_instance.result
return input_data
def execute_workflow(
self,
workflow_id: str,
checkpoint_id: Optional[str] = None
) -> Workflow:
"""执行工作流"""
# 获取工作流
workflow = self._state_store.get_workflow(workflow_id)
if not workflow:
raise ValueError(f"Workflow {workflow_id} not found")
# 构建DAG
graph = self._build_dag(workflow)
# 加载或创建任务实例
if checkpoint_id:
# 从检查点恢复
checkpoint = self._state_store.get_checkpoint(checkpoint_id)
if not checkpoint or checkpoint.workflow_id != workflow_id:
raise ValueError(f"Checkpoint {checkpoint_id} not found or not for this workflow")
task_instances = checkpoint.task_instances
workflow.status = checkpoint.workflow_status
# 保存恢复的任务实例
for instance in task_instances.values():
self._state_store.save_task_instance(instance)
else:
# 创建新的任务实例
task_instances = self._create_task_instances(workflow)
# 更新工作流状态为运行中
workflow.status = WorkflowStatus.RUNNING
self._state_store.save_workflow(workflow)
# 创建初始检查点
self._create_checkpoint(workflow, task_instances, {"type": "start"})
try:
# 执行循环
while True:
# 检查工作流是否已完成或失败
completed_count = sum(
1 for instance in task_instances.values()
if instance.status == TaskStatus.COMPLETED
)
failed_count = sum(
1 for instance in task_instances.values()
if instance.status == TaskStatus.FAILED
)
skipped_count = sum(
1 for instance in task_instances.values()
if instance.status == TaskStatus.SKIPPED
)
total_tasks = len(task_instances)
if completed_count + failed_count + skipped_count == total_tasks:
# 所有任务都已处理完毕
if failed_count > 0:
workflow.status = WorkflowStatus.FAILED
else:
workflow.status = WorkflowStatus.COMPLETED
break
# 获取准备好执行的任务
ready_task_ids = self._get_ready_tasks(graph, task_instances)
if not ready_task_ids:
# 没有准备好的任务,但也没有完成所有任务,可能是死锁或依赖问题
# 这里简化处理,标记为失败
workflow.status = WorkflowStatus.FAILED
break
# 并行执行准备好的任务
futures = {}
for task_id in ready_task_ids:
task = workflow.tasks[task_id]
instance = task_instances[task_id]
input_data = self._collect_input_data(task_id, graph, task_instances)
# 提交任务到线程池
future = self._thread_pool.submit(
self._executor.execute_task,
task,
instance,
input_data,
task.timeout
)
futures[future] = task_id
# 处理执行结果
for future in as_completed(futures):
task_id = futures[future]
try:
updated_instance = future.result()
task_instances[task_id] = updated_instance
self._state_store.save_task_instance(updated_instance)
except Exception as e:
# 处理执行过程中的异常
instance = task_instances[task_id]
instance.status = TaskStatus.FAILED
instance.error = str(e)
instance.end_time = datetime.now()
self._state_store.save_task_instance(instance)
# 创建检查点
self._create_checkpoint(workflow, task_instances, {"type": "progress"})
except Exception as e:
# 处理工作流执行过程中的异常
workflow.status = WorkflowStatus.FAILED
# 这里可以记录错误信息到workflow.metadata
workflow.metadata["error"] = str(e)
finally:
# 保存最终状态
self._state_store.save_workflow(workflow)
# 创建最终检查点
self._create_checkpoint(workflow, task_instances, {"type": "end"})
return workflow
def resume_workflow(self, workflow_id: str) -> Workflow:
"""从最新检查点恢复工作流"""
workflow = self._state_store.get_workflow(workflow_id)
if not workflow:
raise ValueError(f"Workflow {workflow_id} not found")
latest_checkpoint = self._state_store.get_latest_checkpoint(workflow_id)
if not latest_checkpoint:
raise ValueError(f"No checkpoints found for workflow {workflow_id}")
return self.execute_workflow(workflow_id, latest_checkpoint.id)
这个调度器实现了工作流执行的核心逻辑:
- 构建和验证DAG
- 管理任务实例的生命周期
- 确定哪些任务准备好执行
- 并行执行任务
- 创建检查点用于恢复
- 从检查点恢复工作流执行
第五步:构建工作流API
现在让我们创建一个简单的API层,使用户能够方便地定义和执行工作流。
创建一个workflow_api.py文件:
from typing import Any, Callable, Dict, List, Optional
from models import Workflow, Task, WorkflowStatus
from state_store import StateStore, FileStateStore
from scheduler import WorkflowScheduler
class WorkflowAPI:
"""工作流API"""
def __init__(self, state_store: Optional[StateStore] = None, max_workers: int = 4):
self._state_store = state_store or FileStateStore()
self._scheduler = WorkflowScheduler(self._state_store, max_workers)
def create_workflow(self, name: str, description: Optional[str] = None) -> Workflow:
"""创建一个新的工作流"""
workflow = Workflow(name=name, description=description)
self._state_store.save_workflow(workflow)
return workflow
def add_task(
self,
workflow_id: str,
name: str,
function: str,
parameters: Optional[Dict[str, Any]] = None,
description: Optional[str] = None,
max_retries: int = 3,
timeout: Optional[int] = None
) -> Task:
"""向工作流添加任务"""
workflow = self._state_store.get_workflow(workflow_id)
if not workflow:
raise ValueError(f"Workflow {workflow_id} not found")
task = Task(
name=name,
description=description,
function=function,
parameters=parameters or {},
max_retries=max_retries,
timeout=timeout
)
workflow.add_task(task)
self._state_store.save_workflow(workflow)
return task
def add_dependency(
self,
workflow_id: str,
source_task_id: str,
target_task_id: str,
condition: Optional[str] = None
) -> None:
"""添加任务依赖关系"""
workflow = self._state_store.get_workflow(workflow_id)
if not workflow:
raise ValueError(f"Workflow {workflow_id} not found")
workflow.add_dependency(source_task_id, target_task_id, condition)
self._state_store.save_workflow(workflow)
def get_workflow(self, workflow_id: str) -> Optional[Workflow]:
"""获取工作流"""
return self._state_store.get_workflow(workflow_id)
def list_workflows(self) -> List[Workflow]:
"""列出所有工作流"""
return self._state_store.list_workflows()
def execute_workflow(self, workflow_id: str, checkpoint_id: Optional[str] = None) -> Workflow:
"""执行工作流"""
return self._scheduler.execute_workflow(workflow_id, checkpoint_id)
def resume_workflow(self, workflow_id: str) -> Workflow:
"""从最新检查点恢复工作流"""
return self._scheduler.resume_workflow(workflow_id)
def get_workflow_status(self, workflow_id: str) -> Optional[WorkflowStatus]:
"""获取工作流状态"""
workflow = self._state_store.get_workflow(workflow_id)
if workflow:
return workflow.status
return None
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)