项目实训实验报告3
·
目录
一、上周回顾与本周目标
上周完成的工作
| 模块 | 功能 |
|---|---|
| 项目初始化 | 环境搭建、依赖安装 |
| 核心架构 | 分层模块化设计 |
| 工作流引擎 | LangGraph StateGraph |
| 智能路由 | 规则引擎 + LLM |
| 流式 API | SSE 实时推送 |
| 小红书子图 | 内容→图片→发布流程 |
| 基础测试 | 单元测试 |
本周开发目标
- 优化工作流引擎:提升图执行性能,增强状态管理
- 扩展智能路由系统:支持更多路由策略,增强意图分析能力
- 完善流式 API:优化 SSE 协议实现,添加心跳机制
- 扩展小红书工作流:支持多平台发布,优化内容生成
- 增强错误处理:实现统一的错误处理和重试机制
- 编写集成测试:确保各模块协同工作正常
二、工作流引擎优化
1. 图执行器重构
重构目标
提高图执行器的性能和可扩展性,支持更多执行模式。
核心实现
class GraphExecutor:
    """Run a compiled graph once or in bounded-concurrency batches.

    Capabilities:
        1. Asynchronous invocation with a timeout.
        2. Batch invocation capped by a semaphore.
        3. An optional callback awaited once the run has finished.
    """

    def __init__(self, compiled_graph):
        """Store the compiled graph.

        Args:
            compiled_graph: Object exposing ``await ainvoke(input, config)``.
        """
        self.graph = compiled_graph
        self._executor = None  # Not read by any method shown in this class.

    async def invoke(
        self,
        input_data: dict,
        thread_id: str,
        timeout: int = 300,
        progress_callback: Optional[callable] = None,
    ) -> dict:
        """Run the graph for one thread and return its result.

        Args:
            input_data: Input passed to the graph.
            thread_id: Session id, forwarded as ``configurable.thread_id``.
            timeout: Time limit in seconds covering the run and the callback.
            progress_callback: Optional async callable awaited as
                ``progress_callback(100, "执行完成")`` after the graph returns.

        Returns:
            Whatever ``graph.ainvoke`` returned.

        Raises:
            GraphExecutionError: If the time limit is exceeded.
        """
        config = {"configurable": {"thread_id": thread_id}}

        async def run_once():
            result = await self.graph.ainvoke(input_data, config)
            if progress_callback:
                await progress_callback(100, "执行完成")
            return result

        # wait_for is used instead of asyncio.timeout so the block also runs
        # on interpreters older than 3.11; the time limit semantics are equal.
        try:
            return await asyncio.wait_for(run_once(), timeout)
        except asyncio.TimeoutError as exc:
            raise GraphExecutionError(f"图执行超时({timeout}秒)") from exc

    async def batch_invoke(
        self,
        inputs: list[dict],
        thread_ids: list[str],
        max_concurrent: int = 5,
    ) -> list[dict]:
        """Run the graph for several inputs with limited concurrency.

        Args:
            inputs: Input payloads, one per run.
            thread_ids: Session ids paired with ``inputs`` by position.
                Surplus ids beyond ``len(inputs)`` are ignored.
            max_concurrent: Maximum number of runs in flight at once.

        Returns:
            Results in the same order as ``inputs``.

        Raises:
            ValueError: If ``max_concurrent`` is below 1 or there are fewer
                thread ids than inputs.
        """
        if max_concurrent < 1:
            # Semaphore(0) would block every run forever.
            raise ValueError("max_concurrent must be at least 1")
        if len(thread_ids) < len(inputs):
            raise ValueError(
                f"thread_ids has {len(thread_ids)} entries "
                f"but inputs has {len(inputs)}"
            )

        semaphore = asyncio.Semaphore(max_concurrent)

        async def run_limited(input_data, thread_id):
            async with semaphore:
                return await self.invoke(input_data, thread_id)

        return await asyncio.gather(
            *(run_limited(data, tid) for data, tid in zip(inputs, thread_ids))
        )
2. 状态管理增强
新增功能
支持多种存储后端和状态恢复机制。
class StateManager:
    """Persist per-thread state in a pluggable backend behind a TTL cache."""

    def __init__(
        self,
        storage_backend: str = "memory",
        storage_config: Optional[dict] = None,
    ):
        """Create the storage backend and an empty cache.

        Args:
            storage_backend: One of ``memory``, ``sqlite`` or ``redis``.
            storage_config: Backend options. Keys read here: ``db_path``
                (sqlite), ``redis`` (keyword arguments for the redis
                storage) and ``cache_ttl`` (seconds, default 3600).

        Raises:
            ValueError: If ``storage_backend`` is not a supported name.
        """
        self.backend = storage_backend
        self.config = storage_config or {}

        if storage_backend not in ("memory", "sqlite", "redis"):
            raise ValueError(f"不支持的存储后端: {storage_backend}")

        if storage_backend == "memory":
            self._storage = MemoryStorage()
        elif storage_backend == "sqlite":
            self._storage = SQLiteStorage(self.config.get("db_path", "./state.db"))
        else:
            self._storage = RedisStorage(**self.config.get("redis", {}))

        # thread_id -> {"state": ..., "timestamp": epoch seconds}
        self._cache = {}
        self._cache_ttl = self.config.get("cache_ttl", 3600)

    def _remember(self, thread_id: str, state: dict) -> None:
        """Record ``state`` in the cache, stamped with the current time."""
        self._cache[thread_id] = {
            "state": state,
            "timestamp": time.time(),
        }

    async def save_state(self, thread_id: str, state: dict):
        """Write the state to the backend, cache it, then purge stale entries."""
        await self._storage.save(thread_id, state)
        self._remember(thread_id, state)
        await self._cleanup_cache()

    async def load_state(self, thread_id: str) -> Optional[dict]:
        """Return the state for ``thread_id``, preferring a fresh cache entry.

        A truthy state loaded from the backend is cached before returning.
        """
        entry = self._cache.get(thread_id)
        if entry and (time.time() - entry["timestamp"]) < self._cache_ttl:
            return entry["state"]

        loaded = await self._storage.load(thread_id)
        if loaded:
            self._remember(thread_id, loaded)
        return loaded

    async def _cleanup_cache(self):
        """Drop cache entries older than the configured TTL."""
        now = time.time()
        stale = []
        for key, entry in self._cache.items():
            if (now - entry["timestamp"]) > self._cache_ttl:
                stale.append(key)
        for key in stale:
            del self._cache[key]

    async def get_session_stats(self, thread_id: str) -> dict:
        """Summarise a session; returns ``{}`` when no state is available."""
        state = await self.load_state(thread_id)
        if not state:
            return {}

        messages = state.get("messages", [])
        truthy_tasks = sum(1 for task in state.get("tasks", []) if task)
        return {
            "thread_id": thread_id,
            "message_count": len(messages),
            "task_count": truthy_tasks,
            "created_at": state.get("created_at"),
            "last_updated": state.get("updated_at"),
        }
3. 错误处理机制
统一错误处理
class GraphExecutionError(Exception):
    """Error raised when a graph run fails.

    Attributes are set from the constructor arguments: ``error_code`` as
    given, ``details`` defaulting to an empty dict.
    """

    def __init__(self, message: str, error_code: Optional[int] = None, details: Optional[dict] = None):
        super().__init__(message)
        self.error_code = error_code
        self.details = details or {}


class RetryPolicy:
    """Retry an async callable with exponential backoff between attempts."""

    def __init__(
        self,
        max_retries: int = 3,
        initial_delay: float = 1.0,
        backoff_factor: float = 2.0,
        retry_on_exceptions: Optional[list[type]] = None,
    ):
        """Configure the policy.

        Args:
            max_retries: Total number of attempts; values below 1 are
                treated as a single attempt.
            initial_delay: Delay in seconds before the second attempt.
            backoff_factor: Multiplier applied to the delay per attempt.
            retry_on_exceptions: Exception types that trigger a retry.
                An empty or missing list means every ``Exception`` does.
        """
        self.max_retries = max_retries
        self.initial_delay = initial_delay
        self.backoff_factor = backoff_factor
        self.retry_on_exceptions = retry_on_exceptions or []

    async def execute(
        self,
        func: callable,
        *args,
        **kwargs,
    ):
        """Await ``func(*args, **kwargs)``, retrying on retryable failures.

        Returns:
            The first successful result of ``func``.

        Raises:
            Exception: A non-retryable exception is re-raised unchanged.
            GraphExecutionError: When every attempt failed; the last error
                is attached as ``__cause__`` and in ``details["last_error"]``.
        """
        # Always try at least once, even with a non-positive max_retries.
        attempts = max(1, self.max_retries)
        last_exception = None

        for attempt in range(attempts):
            try:
                return await func(*args, **kwargs)
            except Exception as e:
                if not self._should_retry(e):
                    raise
                last_exception = e

                # No attempt follows the last one, so neither log a retry
                # nor make the caller wait for a backoff that buys nothing.
                if attempt + 1 >= attempts:
                    break

                delay = self._calculate_delay(attempt)
                logger.warning(f"操作失败,第 {attempt + 1}/{self.max_retries} 次重试,等待 {delay:.2f} 秒")
                await asyncio.sleep(delay)

        raise GraphExecutionError(
            f"操作失败,已重试 {self.max_retries} 次",
            details={"last_error": str(last_exception)},
        ) from last_exception

    def _should_retry(self, exception: Exception) -> bool:
        """Return True when ``exception`` is covered by the retry list."""
        if not self.retry_on_exceptions:
            return True
        return isinstance(exception, tuple(self.retry_on_exceptions))

    def _calculate_delay(self, attempt: int) -> float:
        """Return ``initial_delay * backoff_factor ** attempt`` (0-based)."""
        return self.initial_delay * (self.backoff_factor ** attempt)
# 使用示例
async def safe_execute_graph(executor, input_data, thread_id):
    """Call ``executor.invoke(input_data, thread_id)`` under a retry policy.

    The policy makes up to three attempts, starts with a one-second delay,
    and retries only on ``GraphExecutionError`` and ``asyncio.TimeoutError``.
    """
    retryable = [GraphExecutionError, asyncio.TimeoutError]
    retry_policy = RetryPolicy(
        max_retries=3,
        initial_delay=1.0,
        retry_on_exceptions=retryable,
    )
    return await retry_policy.execute(executor.invoke, input_data, thread_id)
三、智能路由系统扩展
1. 路由策略优化
新增路由策略
class RouterStrategy:
    """Routing strategy constants (extended).

    Each attribute is the string identifier of one strategy.
    """
    RULE_FIRST = "rule_first" # rules first
    LLM_FIRST = "llm_first" # LLM first
    HYBRID = "hybrid" # hybrid strategy
    LLM_ONLY = "llm_only" # LLM only
    RULE_ONLY = "rule_only" # rules only
    CONFIDENCE_THRESHOLD = "confidence_threshold" # confidence-threshold strategy
    ROUND_ROBIN = "round_robin" # round-robin strategy (load balancing)
class RouterSystem:
    """Routing system extended with a confidence-threshold strategy."""

    async def _route_confidence_threshold(
        self,
        user_input: str,
        context: Optional[dict],
        messages: Optional[list],
        threshold: float = 0.7,
    ) -> RouteDecision:
        """Pick a decision by comparing rule and LLM confidence to a threshold.

        Rules:
            1. Run the rule engine, and the intent analyzer when configured.
            2. If both reach the threshold, return the merged decision.
            3. If only the rule decision reaches it, return the rule decision.
            4. If only the LLM decision reaches it, return the LLM decision.
            5. Otherwise return the fallback decision.

        Args:
            user_input: Raw user text.
            context: Optional context forwarded to both engines.
            messages: Optional history forwarded to the intent analyzer.
            threshold: Minimum confidence for a decision to be accepted.
        """
        rule_decision = self.rule_engine.match(user_input, context)
        rule_confidence = rule_decision.confidence if rule_decision else 0.0

        # Default the LLM side so the comparison below is defined even when
        # no analyzer is configured or it yields no decision.
        llm_decision = None
        llm_confidence = 0.0
        if self.intent_analyzer:
            llm_decision = await self.intent_analyzer.analyze(user_input, context, messages)
            if llm_decision is not None:
                llm_confidence = llm_decision.confidence

        # A missing decision never qualifies, whatever the threshold is.
        rule_ok = rule_decision is not None and rule_confidence >= threshold
        llm_ok = llm_decision is not None and llm_confidence >= threshold

        if rule_ok and llm_ok:
            return self._merge_decisions(rule_decision, llm_decision)
        if rule_ok:
            return rule_decision
        if llm_ok:
            return llm_decision
        return self._create_fallback_decision("Confidence below threshold")

    def _merge_decisions(self, rule_decision, llm_decision) -> RouteDecision:
        """Combine a rule decision and an LLM decision into one.

        Intent, response and ``should_wait`` come from the LLM decision;
        confidence is the lower of the two; target nodes prefer the rule
        decision; extracted parameters are merged with LLM values winning.
        """
        merged_params = {
            **(rule_decision.extracted_params or {}),
            **(llm_decision.extracted_params or {}),
        }
        return RouteDecision(
            intent=llm_decision.intent,
            confidence=min(rule_decision.confidence, llm_decision.confidence),
            reasoning=f"规则: {rule_decision.reasoning} | LLM: {llm_decision.reasoning}",
            response=llm_decision.response,
            extracted_params=merged_params,
            target_nodes=rule_decision.target_nodes or llm_decision.target_nodes,
            should_wait=llm_decision.should_wait,
            metadata={"merged": True},
        )
2. 意图分析器增强
多模型支持和上下文感知
class IntentAnalyzer:
    """Intent analyzer with round-robin providers and a fallback provider."""

    def __init__(
        self,
        llm_providers: Optional[list[str]] = None,
        fallback_provider: str = "qwen-plus",
        temperature: float = 0.3,
    ):
        """Configure the analyzer.

        Args:
            llm_providers: Provider names rotated across calls; defaults to
                ``["qwen-plus"]``.
            fallback_provider: Provider used after the selected one fails.
            temperature: Sampling temperature stored on the instance.
        """
        self.providers = llm_providers or ["qwen-plus"]
        self.fallback_provider = fallback_provider
        self.temperature = temperature
        # Index of the provider the next round-robin call will use.
        self._provider_index = 0
        # provider name -> model instance, created on first use.
        self._models = {}

    async def analyze(
        self,
        user_input: str,
        context: Optional[dict] = None,
        messages: Optional[list] = None,
        use_fallback: bool = False,
    ) -> RouteDecision:
        """Analyse the user's intent, falling back once on failure.

        Raises:
            Exception: Whatever the fallback provider raised, when it also
                fails.
        """
        provider = self._select_provider(use_fallback)
        try:
            return await self._analyze_with_provider(provider, user_input, context, messages)
        except Exception as e:
            logger.error(f"Provider {provider} failed: {e}")
            # Already on the fallback provider: nothing left to try.
            if use_fallback:
                raise
            logger.info(f"Falling back to {self.fallback_provider}")
            return await self.analyze(user_input, context, messages, use_fallback=True)

    def _select_provider(self, use_fallback: bool = False) -> str:
        """Return the fallback provider, or the next one in rotation."""
        if use_fallback:
            return self.fallback_provider
        provider = self.providers[self._provider_index]
        self._provider_index = (self._provider_index + 1) % len(self.providers)
        return provider

    async def _analyze_with_provider(
        self,
        provider: str,
        user_input: str,
        context: Optional[dict],
        messages: Optional[list],
    ) -> RouteDecision:
        """Run one structured-output analysis against ``provider``.

        The message list is: one system message (prompt plus serialised
        context, if any), the last five history messages, then the user's
        input as the final message.
        """
        if provider not in self._models:
            self._models[provider] = self._create_model(provider)
        llm = self._models[provider]
        structured_llm = llm.with_structured_output(IntentAnalysisOutput)

        # Context is folded into the leading system message rather than
        # appended after the user turn: chat backends commonly require
        # system content first, and the user's input should stay last.
        system_prompt = self._build_prompt()
        if context:
            context_str = json.dumps(context, ensure_ascii=False)
            system_prompt = f"{system_prompt}\n\n上下文信息:{context_str}"

        msg_list = [SystemMessage(content=system_prompt)]
        if messages:
            msg_list.extend(messages[-5:])
        msg_list.append(HumanMessage(content=user_input))

        output = await structured_llm.ainvoke(msg_list)
        return self._convert_to_decision(output, provider)
3. 规则引擎扩展
支持复杂规则和动态规则
class RuleEngine:
    """Rule engine holding static rule objects and dynamic callable rules."""

    def __init__(self):
        self.rules = []
        self.functions = {}  # custom functions, keyed by name

    def add_rule(self, rule: RuleConfig):
        """Append a static rule."""
        self.rules.append(rule)

    def add_dynamic_rule(self, condition: callable, action: callable):
        """Append a dynamic rule made of a condition and an action.

        Both callables are invoked as ``f(user_input, context)``.

        Example::

            def condition(user_input, context):
                return "紧急" in user_input and context.get("priority") == "high"

            def action(user_input, context):
                return RouteDecision(
                    intent="urgent_request",
                    target_nodes=["priority_handler"],
                    confidence=1.0,
                )

            rule_engine.add_dynamic_rule(condition, action)
        """
        self.rules.append({"condition": condition, "action": action})

    def register_function(self, name: str, func: callable):
        """Register a custom function under ``name``."""
        self.functions[name] = func

    def match(self, user_input: str, context: Optional[dict]) -> Optional[RouteDecision]:
        """Return the decision of the first matching rule, or ``None``.

        Rules are tried in insertion order. A dict entry is a dynamic rule:
        its action's return value is the decision. Any other entry is a
        static rule whose decision is built by ``_create_decision``.
        """
        for candidate in self.rules:
            is_dynamic = isinstance(candidate, dict)
            if is_dynamic and candidate["condition"](user_input, context):
                return candidate["action"](user_input, context)
            if not is_dynamic and candidate.match(user_input, context):
                return self._create_decision(candidate)
        return None
# 自定义规则示例
class TimeBasedRule(RouteConfig):
    """Rule that only matches inside a daily hour window.

    NOTE(review): the base class here is ``RouteConfig`` while
    ``RuleEngine.add_rule`` is annotated with ``RuleConfig`` — confirm which
    name is intended.
    """

    def __init__(self, start_hour: int, end_hour: int, *args, **kwargs):
        """Store the window and forward remaining arguments to the base.

        Args:
            start_hour: First hour of the window (inclusive).
            end_hour: End hour of the window (exclusive).
        """
        super().__init__(*args, **kwargs)
        self.start_hour = start_hour
        self.end_hour = end_hour

    def match(self, user_input: str, context: Optional[dict]) -> bool:
        """Return True when the local hour is in the window and the base matches."""
        current_hour = datetime.now().hour
        if self.start_hour <= self.end_hour:
            in_time_range = self.start_hour <= current_hour < self.end_hour
        else:
            # Window wraps past midnight, e.g. 22 -> 6.
            in_time_range = current_hour >= self.start_hour or current_hour < self.end_hour
        # Both the time window and the base rule's own match must hold.
        return in_time_range and super().match(user_input, context)
# 使用示例
# Build an engine and register one time-window rule.
rule_engine = RuleEngine()
# Working-hours rule: between 09:00 and 18:00, requests containing the
# work-related keywords below are targeted at "work_handler".
work_hours_rule = TimeBasedRule(
    start_hour=9,
    end_hour=18,
    route_id="work_hours",
    keywords=["工作", "任务", "报告"],
    target_nodes=["work_handler"],
)
rule_engine.add_rule(work_hours_rule)
四、流式 API 完善
1. SSE 协议优化
增强的 SSE 实现
def create_streaming_router(executor: StreamingGraphExecutor) -> APIRouter:
    """Build the ``/api/v1`` router exposing the streaming chat endpoints."""
    router = APIRouter(prefix="/api/v1", tags=["streaming"])

    @router.get("/chat/stream")
    async def chat_stream(
        message: str = Query(..., description="用户消息"),
        thread_id: Optional[str] = Query(None, description="会话ID"),
        user_id: Optional[str] = Query(None, description="用户ID"),
        stream_mode: str = Query("full", description="流式模式:full/light"),
    ):
        """Stream a chat run as Server-Sent Events.

        Stream modes:
        - full: every node event is included.
        - light: only key events are included.

        A random UUID is used as the thread id when none is supplied.
        """
        session_id = thread_id or str(uuid.uuid4())
        event_source = stream_graph_sse(
            executor=executor,
            user_input=message,
            thread_id=session_id,
            user_id=user_id,
            stream_mode=stream_mode,
        )
        # NOTE(review): a wildcard Access-Control-Allow-Origin is not
        # accepted by browsers for credentialed requests — confirm against
        # how the frontend opens its EventSource.
        sse_headers = {
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
            "Access-Control-Allow-Origin": "*",
        }
        return StreamingResponse(
            event_source,
            media_type="text/event-stream",
            headers=sse_headers,
        )

    @router.get("/chat/stream/status")
    async def stream_status(thread_id: str = Query(..., description="会话ID")):
        """Return the executor's status for one streaming session."""
        session_status = await executor.get_session_status(thread_id)
        checked_at = datetime.now().isoformat()
        return {
            "thread_id": thread_id,
            "status": session_status,
            "timestamp": checked_at,
        }

    return router
async def stream_graph_sse(
executor: StreamingGraphExecutor,
user_input: str,
thread_id: str,
user_id: Optional[str] = None,
stream_mode: str = "full",
):
"""生成 SSE 事件流(增强版)"""
async def event_generator():
# 开始事件
yield f"event: started\ndata: {json.dumps({'thread_id': thread_id, 'mode': stream_mode})}\n\n"
# 启动心跳协程
heartbeat_task = asyncio.create_task(_heartbeat_generator())
try:
async for event in executor.stream_async(
user_input=user_input,
thread_id=thread_id,
user_id=user_id,
):
event_type = event.get("type")
event_data = event.get("data", {})
# 轻量模式过滤
if stream_mode == "light":
if event_type not in ["node_output", "completed", "error"]:
continue
# 发送事件
if event_type == "node_output":
# 压缩数据(移除不必要的字段)
event_data = _compress_event_data(event_data)
yield f"event: {event_type}\ndata: {json.dumps(event_data, ensure_ascii=False)}\n\n"
except Exception as e:
yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
finally:
heartbeat_task.cancel()
yield f"event: completed\ndata: {json.dumps({'thread_id': thread_id})}\n\n"
return event_generator()
async def _heartbeat_generator():
"""心跳生成器"""
while True:
yield ": heartbeat\n\n"
await asyncio.sleep(15)
def _compress_event_data(data: dict) -> dict:
"""压缩事件数据"""
# 移除不必要的字段
fields_to_keep = ["node", "message", "progress", "timestamp"]
return {k: v for k, v in data.items() if k in fields_to_keep}
2. 流式执行器增强
支持多模式执行
class StreamingGraphExecutor:
    """Streaming graph executor supporting several output modes."""

    def __init__(self, compiled_graph):
        self.graph = compiled_graph
        self._session_tasks = {}  # session -> task mapping

    async def stream_async(
        self,
        user_input: str,
        thread_id: str,
        user_id: Optional[str] = None,
        mode: str = "streaming",
    ):
        """Run the graph for one thread and yield event dicts.

        Modes:
        - streaming: one ``node_output`` event per node per step.
        - progressive: one ``node_complete`` event per finished node.
        - batch: a single ``completed`` event holding the final result.

        Raises:
            ValueError: For any other mode, after the current state has
                been fetched and the input built.
        """
        config = {"configurable": {"thread_id": thread_id}}
        current_state = await self.graph.aget_state(config)
        input_data = self._build_input(user_input, current_state, user_id)

        if mode == "batch":
            final_result = await self.graph.ainvoke(input_data, config)
            yield {"type": "completed", "data": final_result}
            return

        if mode == "streaming":
            event_source = self._streaming_execute
        elif mode == "progressive":
            event_source = self._progressive_execute
        else:
            raise ValueError(f"Unknown mode: {mode}")

        async for event in event_source(input_data, config):
            yield event

    async def _streaming_execute(self, input_data: dict, config: dict):
        """Yield a ``node_output`` event for every node in every step."""
        async for step in self.graph.astream(input_data, config):
            for node_name, node_state in step.items():
                payload = {
                    "node": node_name,
                    "state": node_state,
                    "timestamp": datetime.now().isoformat(),
                }
                yield {"type": "node_output", "data": payload}

    async def _progressive_execute(self, input_data: dict, config: dict):
        """Yield a ``node_complete`` event for each node judged complete."""
        async for step in self.graph.astream(input_data, config):
            for node_name, node_state in step.items():
                if not self._is_node_complete(node_state):
                    continue
                payload = {
                    "node": node_name,
                    "state": node_state,
                    "timestamp": datetime.now().isoformat(),
                }
                yield {"type": "node_complete", "data": payload}

    def _is_node_complete(self, state: dict) -> bool:
        """Return True when the state carries an ``output`` or ``error`` key."""
        return any(marker in state for marker in ("output", "error"))
3. 前端集成示例
完整的前端集成代码
// frontend/src/services/streamingService.js
/**
 * Thin wrapper around EventSource for the streaming chat endpoint.
 * Server events are re-emitted to listeners registered with `on()`.
 */
class StreamingService {
  constructor(baseUrl = '/api/v1') {
    this.baseUrl = baseUrl;
    this.eventSource = null;
    this.listeners = {};
  }

  /**
   * Open a stream for `message`, closing any stream that is already open.
   * `options.mode` selects the stream mode and defaults to 'full'.
   */
  connect(message, threadId = null, options = {}) {
    this.disconnect();

    const query = new URLSearchParams({
      message: message,
      stream_mode: options.mode || 'full',
    });
    if (threadId) {
      query.set('thread_id', threadId);
    }

    const source = new EventSource(
      `${this.baseUrl}/chat/stream?${query.toString()}`,
      { withCredentials: true }
    );
    this.eventSource = source;

    // Server event name -> local event name; payloads are parsed as JSON.
    const forwarded = [
      ['started', 'started'],
      ['node_start', 'nodeStart'],
      ['node_output', 'nodeOutput'],
      ['subgraph_start', 'subgraphStart'],
      ['subgraph_node_output', 'subgraphNodeOutput'],
    ];
    forwarded.forEach(([serverName, localName]) => {
      source.addEventListener(serverName, (event) => {
        this._emit(localName, JSON.parse(event.data));
      });
    });

    // Terminal events: notify listeners, then close the connection.
    source.addEventListener('completed', (event) => {
      this._emit('completed', JSON.parse(event.data));
      this.disconnect();
    });
    source.addEventListener('error', (event) => {
      this._emit('error', event);
      this.disconnect();
    });

    // Heartbeat handler: intentionally does nothing.
    source.addEventListener('heartbeat', () => {});
  }

  /** Close the current stream, if any. */
  disconnect() {
    if (!this.eventSource) {
      return;
    }
    this.eventSource.close();
    this.eventSource = null;
  }

  /** Register `callback` for `eventName`. */
  on(eventName, callback) {
    const registered = this.listeners[eventName] || (this.listeners[eventName] = []);
    registered.push(callback);
  }

  /** Remove a previously registered `callback` for `eventName`. */
  off(eventName, callback) {
    const registered = this.listeners[eventName];
    if (!registered) {
      return;
    }
    this.listeners[eventName] = registered.filter((cb) => cb !== callback);
  }

  /** Call every listener of `eventName` with `data`. */
  _emit(eventName, data) {
    const registered = this.listeners[eventName];
    if (registered) {
      registered.forEach((callback) => callback(data));
    }
  }
}
// 使用示例
// Wire up handlers for the stream's lifecycle events, then start a request.
const streamingService = new StreamingService();
streamingService.on('started', (data) => {
console.log('会话开始:', data.thread_id);
});
// NOTE(review): the server-side _compress_event_data keeps only node,
// message, progress and timestamp for node_output events, so data.state
// may be undefined here — confirm.
streamingService.on('nodeOutput', (data) => {
console.log('节点输出:', data.node, data.state);
// Update the UI to show progress
updateProgress(data.node, data.state);
});
streamingService.on('completed', (data) => {
console.log('会话完成:', data.thread_id);
showCompletionMessage();
});
streamingService.on('error', (error) => {
console.error('流式连接错误:', error);
showErrorMessage('连接中断,请重试');
});
// Start the request
streamingService.connect('帮我写一篇小红书笔记', userThreadId);
五、小红书工作流扩展
1. 内容生成优化
智能内容生成策略
class XHSContentGenerator:
    """Xiaohongshu note generator: builds prompts and parses the LLM reply."""

    def __init__(self):
        self._llm_client = None

    async def generate(
        self,
        topic: str,
        content_type: str = "normal",
        style: str = "professional",
        length: str = "medium",
    ) -> dict:
        """Generate a note for ``topic``.

        Args:
            topic: Subject of the note.
            content_type: ``"lifestyle"`` selects the first-person prompt;
                every other value uses the normal prompt.
            style: One of the styles known to ``_style_to_description``;
                unknown styles get a neutral description.
            length: ``short``/``medium``/``long``; unknown values fall back
                to ``medium``. Ignored for lifestyle content.

        Returns:
            Dict with ``title``, ``content`` and ``tags``.
        """
        if content_type == "lifestyle":
            return await self._generate_lifestyle_content(topic, style)
        return await self._generate_normal_content(topic, style, length)

    async def _generate_normal_content(
        self,
        topic: str,
        style: str,
        length: str,
    ) -> dict:
        """Build the normal prompt, call the LLM and parse its reply."""
        length_map = {
            "short": 300,
            "medium": 600,
            "long": 1000,
        }
        # Unknown lengths fall back to medium, mirroring how unknown styles
        # fall back to a default description instead of raising.
        target_length = length_map.get(length, length_map["medium"])
        prompt = f"""请帮我写一篇关于「{topic}」的小红书笔记。
要求:
- 风格:{self._style_to_description(style)}
- 字数:约{target_length}字
- 包含吸引人的标题
- 使用合适的标签(3-5个)
- 符合小红书平台风格
输出格式:
标题:[标题内容]
正文:[正文内容]
标签:#[标签1] #[标签2] #[标签3]"""
        result = await self._call_llm(prompt)
        return self._parse_content(result)

    async def _generate_lifestyle_content(self, topic: str, style: str) -> dict:
        """Build the first-person lifestyle prompt, call the LLM and parse."""
        prompt = f"""请以第一人称视角写一篇关于「{topic}」的生活化小红书笔记。
要求:
- 风格:{self._style_to_description(style)}
- 语气亲切自然,像和朋友聊天
- 可以加入个人感受和情绪
- 避免过于正式的语言
- 包含表情符号增加亲切感
- 使用合适的标签(3-5个)
输出格式:
标题:[标题内容]
正文:[正文内容]
标签:#[标签1] #[标签2] #[标签3]"""
        result = await self._call_llm(prompt)
        return self._parse_content(result)

    def _style_to_description(self, style: str) -> str:
        """Map a style key to its prompt description (neutral default)."""
        styles = {
            "professional": "专业、正式,适合知识分享",
            "casual": "轻松、随意,适合日常分享",
            "humorous": "幽默、有趣,带点调侃",
            "emotional": "感性、温暖,触动人心",
        }
        return styles.get(style, "自然、适中")

    def _parse_content(self, llm_result: str) -> dict:
        """Parse the labelled LLM reply into title, content and tags.

        Lines are stripped before matching. A label is recognised only at
        the start of a line and with either an ASCII or a full-width colon.
        Lines after the content label are appended to the content until the
        next label. Empty tags are dropped.
        """
        parsed = {
            "title": "",
            "content": "",
            "tags": [],
        }
        in_content = False
        for raw_line in llm_result.strip().split("\n"):
            line = raw_line.strip()
            label, value = self._split_label(line)
            if label == "标题":
                parsed["title"] = value
                in_content = False
            elif label == "正文":
                parsed["content"] = value
                in_content = True
            elif label == "标签":
                stripped_tags = (token.strip("#") for token in value.split())
                parsed["tags"] = [tag for tag in stripped_tags if tag]
                in_content = False
            elif in_content:
                parsed["content"] += "\n" + line
        return parsed

    @staticmethod
    def _split_label(line: str) -> tuple:
        """Return ``(label, rest)`` for a labelled line, else ``(None, line)``."""
        for label in ("标题", "正文", "标签"):
            for colon in (":", ":"):
                prefix = label + colon
                if line.startswith(prefix):
                    # Slice off the prefix only; a later occurrence of the
                    # label inside the text is part of the value.
                    return label, line[len(prefix):].strip()
        return None, line
2. 发布流程完善
增强的发布流程
class XHSPublishFlow:
    """Publishing flow: generate content, optionally images, then publish."""

    def __init__(self):
        self._platform_manager = PlatformManager()
        self._content_generator = XHSContentGenerator()

    async def execute(
        self,
        topic: str,
        platforms: list = None,
        should_publish: bool = True,
        content_options: Optional[dict] = None,
    ) -> dict:
        """Run the whole flow and report every step.

        Args:
            topic: Subject of the content.
            platforms: Target platforms; defaults to ``["xiaohongshu"]``.
            should_publish: Skip the publishing step when False.
            content_options: Generation options; ``need_images`` enables
                the image step.

        Returns:
            Dict with ``success``, ``steps`` and ``output``, plus ``error``
            when a step failed. Exceptions are captured, not raised.
        """
        result = {
            "success": False,
            "steps": [],
            "output": {},
        }
        try:
            # Step 1: content. Each step keeps its own result variable so a
            # later step cannot overwrite what publishing needs.
            content_result = await self._step_generate_content(topic, content_options)
            result["steps"].append({"step": "content_generation", "status": "completed"})
            result["output"]["content"] = content_result
            if not content_result.get("success"):
                result["error"] = "内容生成失败"
                return result

            # Step 2: images, only when requested.
            images = []
            if content_options and content_options.get("need_images", False):
                image_result = await self._step_generate_images(topic)
                result["steps"].append({"step": "image_generation", "status": "completed"})
                result["output"]["images"] = image_result
                images = image_result.get("images", [])

            # Step 3: publish, only when requested.
            if should_publish:
                publish_result = await self._step_publish(
                    content_result["content"],
                    images,
                    platforms or ["xiaohongshu"],
                )
                published = bool(publish_result.get("success"))
                result["steps"].append(
                    {"step": "publishing", "status": "completed" if published else "failed"}
                )
                result["output"]["publish"] = publish_result
                if not published:
                    # A failed platform must not be reported as success.
                    result["error"] = "发布失败"
                    return result

            result["success"] = True
        except Exception as e:
            result["success"] = False
            result["error"] = str(e)
            result["steps"].append({"step": "error", "status": "failed", "error": str(e)})
        return result

    async def _step_generate_content(self, topic: str, options: Optional[dict]) -> dict:
        """Generate content with the options given (defaults when absent)."""
        options = options or {}
        result = await self._content_generator.generate(
            topic=topic,
            content_type=options.get("content_type", "normal"),
            style=options.get("style", "professional"),
            length=options.get("length", "medium"),
        )
        return {
            "success": True,
            "content": result,
        }

    async def _step_generate_images(self, topic: str) -> dict:
        """Image step placeholder: currently returns an empty image list."""
        # Call the image generation service here.
        # ...
        return {
            "success": True,
            "images": [],  # list of image URLs
        }

    async def _step_publish(
        self,
        content: dict,
        images: list,
        platforms: list,
    ) -> dict:
        """Publish to every platform; success requires all to succeed."""
        results = await self._platform_manager.publish_to_platforms(
            content=content,
            platforms=platforms,
            images=images,
        )
        all_success = all(r["success"] for r in results.values())
        return {
            "success": all_success,
            "results": results,
        }
六、遇到的问题与解决方案
问题1:图执行性能优化
现象:图执行时间过长,特别是在复杂工作流中
原因分析:
- 每次执行都重新编译图
- 状态检查点操作频繁
- 缺乏缓存机制
解决方案:
class CachedGraphExecutor:
    """Graph executor that compiles its graph once and reuses the result."""

    def __init__(self, graph_builder: GraphBuilder):
        self.graph_builder = graph_builder
        self._compiled_graph = None
        self._compile_lock = asyncio.Lock()

    async def get_graph(self):
        """Return the compiled graph, compiling it on first use.

        The lock ensures concurrent callers trigger a single compilation.
        """
        async with self._compile_lock:
            graph = self._compiled_graph
            if graph is None:
                graph = await self.graph_builder.compile()
                self._compiled_graph = graph
            return graph

    async def invalidate_cache(self):
        """Forget the compiled graph (call when the configuration changes)."""
        async with self._compile_lock:
            self._compiled_graph = None
问题2:并发会话管理
现象:多用户并发时会话状态混淆
原因分析:
- 状态存储没有正确隔离
- 缺少会话级别的锁机制
解决方案:
class SessionManager:
    """Session manager that serialises work per thread id with locks."""

    def __init__(self):
        # thread_id -> asyncio.Lock. Entries are never removed here.
        self._session_locks = {}

    async def acquire_lock(self, thread_id: str):
        """Return the lock for ``thread_id``, creating it on first use.

        Despite the name, the lock is returned unlocked; the caller enters
        it with ``async with``.
        """
        lock = self._session_locks.get(thread_id)
        if lock is None:
            lock = self._session_locks[thread_id] = asyncio.Lock()
        return lock

    async def with_session_lock(self, thread_id: str, func: callable, *args, **kwargs):
        """Await ``func(*args, **kwargs)`` while holding the session's lock."""
        # acquire_lock is a coroutine: it must be awaited to get the lock
        # object before that object can be used as a context manager.
        lock = await self.acquire_lock(thread_id)
        async with lock:
            return await func(*args, **kwargs)
问题3:流式输出数据量过大
现象:流式输出包含大量重复数据,网络传输压力大
原因分析:
- 每次状态更新都发送完整状态
- 缺少数据压缩机制
解决方案:
class DeltaEncoder:
    """Incremental encoder: reports only keys that are new or changed."""

    def __init__(self):
        self._previous_state = {}

    def encode(self, current_state: dict) -> dict:
        """Return the changed keys of ``current_state`` and remember it.

        A key is included when it was absent from the previous state or its
        value compares unequal. Removed keys are not reported. The previous
        state is kept as a shallow copy.
        """
        previous = self._previous_state
        delta = {
            key: value
            for key, value in current_state.items()
            if key not in previous or previous[key] != value
        }
        self._previous_state = current_state.copy()
        return delta

    def reset(self):
        """Forget the previous state (call when a new session starts)."""
        self._previous_state = {}


# Usage example
encoder = DeltaEncoder()  # kept for code that uses the module-level encoder


async def stream_with_delta(executor, input_data):
    """Yield only the non-empty deltas between successive streamed states."""
    # One encoder per stream: sharing the module-level instance lets
    # concurrent streams overwrite each other's previous state.
    stream_encoder = DeltaEncoder()
    async for state in executor.stream(input_data):
        delta = stream_encoder.encode(state)
        if delta:  # send only when something changed
            yield delta
问题4:MCP 服务动态注册
现象:MCP 服务需要在运行时动态注册
原因分析:
- 服务配置可能在运行时变更
- 需要支持热更新
解决方案:
class MCPRegistry:
    """MCP service registry with dynamic registration and lazy clients."""

    def __init__(self):
        self._services = {}           # name -> live client
        self._service_factories = {}  # name -> {"factory", "config"}
        self._registry_lock = asyncio.Lock()

    async def _dispose(self, name: str) -> None:
        """Close and forget the live client for ``name``, if there is one.

        Must be called with the registry lock held.
        """
        if name in self._services:
            await self._services[name].close()
            del self._services[name]

    async def register_service(
        self,
        name: str,
        factory: callable,
        config: Optional[dict] = None,
    ):
        """Register or replace the factory for ``name``.

        A client already created for ``name`` is closed and dropped, so the
        next ``get_client`` call builds one from the new factory.
        """
        async with self._registry_lock:
            self._service_factories[name] = {
                "factory": factory,
                "config": config or {},
            }
            await self._dispose(name)

    async def get_client(self, name: str):
        """Return the client for ``name``, creating it on first use.

        Raises:
            ValueError: If no factory is registered under ``name``.
        """
        async with self._registry_lock:
            if name in self._services:
                return self._services[name]
            if name not in self._service_factories:
                raise ValueError(f"Service {name} not registered")
            spec = self._service_factories[name]
            client = await spec["factory"](**spec["config"])
            self._services[name] = client
            return client

    async def reload_service(self, name: str):
        """Close and drop the client for ``name`` so it is rebuilt lazily."""
        async with self._registry_lock:
            await self._dispose(name)
七、测试与验证
集成测试
# tests/test_integration.py
@pytest.mark.asyncio
async def test_full_workflow():
    """Run the whole workflow without publishing and check its outputs."""
    from ai_social_scheduler.subgraphs import XHSWorkflowSubgraph

    workflow = XHSWorkflowSubgraph()
    await workflow.initialize()

    result = await workflow.invoke({
        "description": "咖啡制作教程",
        "image_count": 3,
        "should_publish": False,
    })

    # Identity check instead of ``== True`` so a merely truthy value
    # (for example 1) does not pass as success.
    assert result.get("success") is True
    assert "content_result" in result
    assert "image_result" in result
@pytest.mark.asyncio
async def test_multi_platform_publish():
    """Publish one note through the platform manager and check the result keys."""
    from ai_social_scheduler.platform import PlatformManager

    note = {
        "title": "测试标题",
        "content": "测试内容",
        "tags": ["测试", "自动化"],
    }
    targets = ["xiaohongshu"]

    manager = PlatformManager()
    result = await manager.publish_to_platforms(content=note, platforms=targets)

    assert "xiaohongshu" in result
性能测试
# tests/test_performance.py
@pytest.mark.asyncio
async def test_execution_time():
    """Check that a single workflow run finishes within 30 seconds."""
    import time

    # Imported locally, like test_full_workflow does, so the test does not
    # depend on a module-level import being present.
    from ai_social_scheduler.subgraphs import XHSWorkflowSubgraph

    workflow = XHSWorkflowSubgraph()
    await workflow.initialize()

    # perf_counter is monotonic, so clock adjustments cannot skew the result.
    started = time.perf_counter()
    await workflow.invoke({
        "description": "测试内容",
        "image_count": 1,
        "should_publish": False,
    })
    elapsed = time.perf_counter() - started

    assert elapsed < 30  # must finish within 30 seconds
@pytest.mark.asyncio
async def test_concurrent_execution():
    """Run five workflow invocations concurrently and check all succeed."""
    import asyncio
    import time  # used below; the original body never imported it

    from ai_social_scheduler.subgraphs import XHSWorkflowSubgraph

    workflow = XHSWorkflowSubgraph()
    await workflow.initialize()

    tasks = [
        workflow.invoke({
            "description": f"测试任务 {i}",
            "image_count": 1,
            "should_publish": False,
        })
        for i in range(5)
    ]

    started = time.perf_counter()
    results = await asyncio.gather(*tasks)
    elapsed = time.perf_counter() - started

    # Every task must report success.
    for result in results:
        assert result.get("success") is True

    # Only an absolute bound is checked; no serial baseline is measured.
    assert elapsed < 60  # five tasks must finish within 60 seconds
压力测试
# tests/test_stress.py
@pytest.mark.asyncio
async def test_high_concurrency():
    """Stress test: 20 concurrent runs, 2-minute budget, >= 90% success."""
    import asyncio
    import time

    # Imported locally so the test does not depend on a module-level import.
    from ai_social_scheduler.subgraphs import XHSWorkflowSubgraph

    workflow = XHSWorkflowSubgraph()
    await workflow.initialize()

    # Simulate 20 concurrent requests.
    tasks = [
        asyncio.create_task(
            workflow.invoke({
                "description": f"压力测试任务 {i}",
                "image_count": 1,
                "should_publish": False,
            })
        )
        for i in range(20)
    ]

    started = time.perf_counter()
    # Overall time limit for the whole batch.
    try:
        results = await asyncio.wait_for(
            asyncio.gather(*tasks),
            timeout=120,  # two minutes
        )
    except asyncio.TimeoutError:
        pytest.fail("高并发测试超时")
    elapsed = time.perf_counter() - started

    # Check the share of successful runs.
    success_count = sum(1 for r in results if r.get("success"))
    success_rate = success_count / len(results)
    assert success_rate >= 0.9  # at least 90% must succeed

    print(f"高并发测试完成: {len(results)} 个任务, 耗时 {elapsed:.2f} 秒, 成功率 {success_rate:.2%}")
八、开发进度与下一步计划
本周完成的工作
| 模块 | 功能 |
|---|---|
| 工作流引擎 | 图执行器重构、状态管理增强 |
| 智能路由 | 新增路由策略、意图分析器增强 |
| 流式 API | SSE协议优化、心跳机制、多模式支持 |
| 小红书工作流 | 多平台支持、内容生成优化 |
| 错误处理 | 统一错误处理、重试机制 |
| 集成测试 | 完整工作流测试、性能测试 |
下一步计划
- 优化系统性能
  - 实现图执行缓存
  - 优化数据库查询
  - 添加性能监控
- 开发管理界面
  - Web 管理后台
  - 任务监控面板
  - 数据分析报表
- 完善部署方案
  - Docker 容器化部署
  - Kubernetes 支持
  - 自动化部署脚本
总结
本周完成了既定的开发目标,重点工作包括:
- 工作流引擎优化:重构了图执行器,增强了状态管理,实现了统一的错误处理和重试机制
- 智能路由系统扩展:新增了多种路由策略,增强了意图分析器的多模型支持能力
- 流式 API 完善:优化了 SSE 协议实现,添加了心跳机制和多模式支持
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)